[ray_client] Integrate with test_basic, test_basic_2 and test_actor (#12964)

This commit is contained in:
Barak Michener 2020-12-20 14:54:18 -08:00 committed by GitHub
parent bf6577c8f4
commit 7ab9164f1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 375 additions and 174 deletions

View file

@ -46,6 +46,7 @@ matrix:
script:
# bazel python tests for medium size tests. Used for parallelization.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_TEST_CLIENT_MODE=1 python/ray/tests/...; fi
- os: linux
env:

View file

@ -1,12 +1,14 @@
# py_test_module_list creates a py_test target for each
# Python file in `files`
def py_test_module_list(files, size, deps, extra_srcs, **kwargs):
def py_test_module_list(files, size, deps, extra_srcs, name_suffix="", **kwargs):
for file in files:
# remove .py
name = file[:-3]
name = file[:-3] + name_suffix
main = file
native.py_test(
name = name,
size = size,
main = file,
srcs = extra_srcs + [file],
**kwargs
)

View file

@ -4,6 +4,7 @@ from typing import Optional, List, Tuple
from contextlib import contextmanager
import logging
import os
logger = logging.getLogger(__name__)
@ -43,9 +44,11 @@ def stash_api_for_tests(in_test: bool):
is_server = _is_server
if in_test:
_is_server = True
yield _server_api
if in_test:
_is_server = is_server
try:
yield _server_api
finally:
if in_test:
_is_server = is_server
def _set_client_api(val: Optional[APIImpl]):
@ -77,18 +80,7 @@ def reset_api():
def _get_client_api() -> APIImpl:
global _client_api
global _server_api
global _is_server
api = None
if _is_server:
api = _server_api
else:
api = _client_api
if api is None:
# We're inside a raylet worker
from ray.experimental.client.server.core_ray_api import CoreRayAPI
return CoreRayAPI()
return api
return _client_api
def _get_server_instance():
@ -124,9 +116,33 @@ class RayAPIStub:
global _client_api
return _client_api is not None
def init(self, *args, **kwargs):
if _is_client_test_env():
global _test_server
import ray.experimental.client.server.server as ray_client_server
_test_server, address_info = ray_client_server.init_and_serve(
"localhost:50051", test_mode=True, *args, **kwargs)
self.connect("localhost:50051")
return address_info
else:
raise NotImplementedError(
"Please call ray.connect() in client mode")
ray = RayAPIStub()
_test_server = None
def _stop_test_server(*args):
global _test_server
_test_server.stop(*args)
def _is_client_test_env() -> bool:
return os.environ.get("RAY_TEST_CLIENT_MODE") == "1"
# Someday we might add methods in this module so that someone who
# tries to `import ray_client as ray` -- as a module, instead of
# `from ray_client import ray` -- as the API stub

View file

@ -28,11 +28,15 @@ import sys
from typing import NamedTuple
from typing import Any
from typing import Optional
from ray.experimental.client import RayAPIStub
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import SelfReferenceSentinel
import ray.core.generated.ray_client_pb2 as ray_client_pb2
@ -44,8 +48,11 @@ if sys.version_info < (3, 8):
else:
import pickle # noqa: F401
PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str),
("ref_id", bytes)])
# NOTE(barakmich): These PickleStubs are really close to
# the data for an exectuion, with no arguments. Combine the two?
PickleStub = NamedTuple("PickleStub",
[("type", str), ("client_id", str), ("ref_id", bytes),
("name", Optional[str])])
class ClientPickler(cloudpickle.CloudPickler):
@ -54,17 +61,26 @@ class ClientPickler(cloudpickle.CloudPickler):
self.client_id = client_id
def persistent_id(self, obj):
if isinstance(obj, ClientObjectRef):
if isinstance(obj, RayAPIStub):
return PickleStub(
type="Ray",
client_id=self.client_id,
ref_id=b"",
name=None,
)
elif isinstance(obj, ClientObjectRef):
return PickleStub(
type="Object",
client_id=self.client_id,
ref_id=obj.id,
name=None,
)
elif isinstance(obj, ClientActorHandle):
return PickleStub(
type="Actor",
client_id=self.client_id,
ref_id=obj._actor_id,
name=None,
)
elif isinstance(obj, ClientRemoteFunc):
# TODO(barakmich): This is going to have trouble with mutually
@ -77,11 +93,39 @@ class ClientPickler(cloudpickle.CloudPickler):
return PickleStub(
type="RemoteFuncSelfReference",
client_id=self.client_id,
ref_id=b"")
ref_id=b"",
name=None,
)
return PickleStub(
type="RemoteFunc",
client_id=self.client_id,
ref_id=obj._ref.id)
ref_id=obj._ref.id,
name=None,
)
elif isinstance(obj, ClientActorClass):
# TODO(barakmich): Mutual recursion, as above.
if obj._ref is None:
obj._ensure_ref()
if type(obj._ref) == SelfReferenceSentinel:
return PickleStub(
type="RemoteActorSelfReference",
client_id=self.client_id,
ref_id=b"",
name=None,
)
return PickleStub(
type="RemoteActor",
client_id=self.client_id,
ref_id=obj._ref.id,
name=None,
)
elif isinstance(obj, ClientRemoteMethod):
return PickleStub(
type="RemoteMethod",
client_id=self.client_id,
ref_id=obj.actor_handle.actor_ref.id,
name=obj.method_name,
)
return None

View file

@ -1,6 +1,5 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray
from typing import Dict
class ClientBaseRef:
@ -8,17 +7,20 @@ class ClientBaseRef:
self.id: bytes = id
ray.call_retain(id)
def binary(self):
return self.id
def __eq__(self, other):
return self.id == other.id
def __repr__(self):
return "%s(%s)" % (
type(self).__name__,
self.id.hex(),
)
def __eq__(self, other):
return self.id == other.id
def binary(self):
return self.id
def __hash__(self):
return hash(self.id)
def __del__(self):
if ray.is_connected():
@ -107,18 +109,13 @@ class ClientActorClass(ClientStub):
raise TypeError(f"Remote actor cannot be instantiated directly. "
"Use {self._name}.remote() instead")
def __getstate__(self) -> Dict:
state = {
"actor_cls": self.actor_cls,
"_name": self._name,
"_ref": self._ref,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_cls = state["actor_cls"]
self._name = state["_name"]
self._ref = state["_ref"]
def _ensure_ref(self):
if self._ref is None:
# As before, set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self.actor_cls)
def remote(self, *args, **kwargs) -> "ClientActorHandle":
# Actually instantiate the actor
@ -126,7 +123,7 @@ class ClientActorClass(ClientStub):
return ClientActorHandle(ClientActorRef(ref_id), self)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key):
if key not in self.__dict__:
@ -134,8 +131,7 @@ class ClientActorClass(ClientStub):
raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self.actor_cls)
self._ensure_ref()
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name

View file

@ -53,24 +53,33 @@ class DataClient:
resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
metadata=(("client_id", self._client_id), ))
for response in resp_stream:
if response.req_id == 0:
# This is not being waited for.
logger.debug(f"Got unawaited response {response}")
continue
with self.cv:
self.ready_data[response.req_id] = response
self.cv.notify_all()
try:
for response in resp_stream:
if response.req_id == 0:
# This is not being waited for.
logger.debug(f"Got unawaited response {response}")
continue
with self.cv:
self.ready_data[response.req_id] = response
self.cv.notify_all()
except grpc.RpcError as e:
if grpc.StatusCode.CANCELLED == e.code():
# Gracefully shutting down
logger.info("Cancelling data channel")
else:
logger.error(
f"Got Error from rpc channel -- shutting down: {e}")
raise e
def close(self, close_channel: bool = False) -> None:
if self.request_queue is not None:
self.request_queue.put(None)
self.request_queue = None
if close_channel:
self.channel.close()
if self.data_thread is not None:
self.data_thread.join()
self.data_thread = None
if close_channel:
self.channel.close()
def _blocking_send(self, req: ray_client_pb2.DataRequest
) -> ray_client_pb2.DataResponse:

View file

@ -79,8 +79,3 @@ class RayServerAPI(CoreRayAPI):
def __init__(self, server_instance):
self.server = server_instance
def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes:
task = instance._prepare_client_task()
ticket = self.server.Schedule(task, prepared_args=args)
return ticket.return_id

View file

@ -21,7 +21,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server
from ray.experimental.client.server.server_pickler import loads_from_client
from ray.experimental.client.server.core_ray_api import RayServerAPI
from ray.experimental.client.server.dataservicer import DataServicer
from ray.experimental.client.server.server_stubs import current_func
from ray.experimental.client.server.server_stubs import current_remote
logger = logging.getLogger(__name__)
@ -205,82 +205,75 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" %
(task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
with stash_api_for_tests(self._test_mode):
if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context, prepared_args)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
try:
if task.type == ray_client_pb2.ClientTask.FUNCTION:
result = self._schedule_function(task, context)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
result = self._schedule_actor(task, context)
elif task.type == ray_client_pb2.ClientTask.METHOD:
result = self._schedule_method(task, context)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(
task.type))
result.valid = True
return result
except Exception as e:
logger.error(f"Caught schedule exception {e}")
return ray_client_pb2.ClientTaskTicket(
valid=False, error=cloudpickle.dumps(e))
def _schedule_method(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_method(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None:
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist = self._convert_args(task.args, prepared_args)
output = getattr(actor_handle, task.name).remote(*arglist)
arglist, kwargs = self._convert_args(task.args, task.kwargs)
output = getattr(actor_handle, task.name).remote(*arglist, **kwargs)
self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _schedule_actor(self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
if task.payload_id not in self.registered_actor_classes:
actor_class_ref = \
self.object_refs[task.client_id][task.payload_id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id]
arglist = self._convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
remote_class = self.lookup_or_register_actor(task.payload_id,
task.client_id)
arglist, kwargs = self._convert_args(task.args, task.kwargs)
with current_remote(remote_class):
actor = remote_class.remote(*arglist, **kwargs)
self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket(
return_id=actor._actor_id.binary())
def _schedule_function(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_function(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
remote_func = self.lookup_or_register_func(task.payload_id,
task.client_id)
arglist = self._convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with current_func(remote_func):
output = remote_func.remote(*arglist)
arglist, kwargs = self._convert_args(task.args, task.kwargs)
with current_remote(remote_func):
output = remote_func.remote(*arglist, **kwargs)
if output.binary() in self.object_refs[task.client_id]:
raise Exception("already found it")
self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _convert_args(self, arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = []
def _convert_args(self, arg_list, kwarg_map):
argout = []
for arg in arg_list:
t = convert_from_arg(arg, self)
out.append(t)
return out
argout.append(t)
kwargout = {}
for k in kwarg_map:
kwargout[k] = convert_from_arg(kwarg_map[k], self)
return argout, kwargout
def lookup_or_register_func(self, id: bytes, client_id: str
) -> ray.remote_function.RemoteFunction:
@ -293,6 +286,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.function_refs[id] = ray.remote(func)
return self.function_refs[id]
def lookup_or_register_actor(self, id: bytes, client_id: str):
if id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[client_id][id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[id] = reg_class
return self.registered_actor_classes[id]
def return_exception_in_context(err, context):
if context is not None:
@ -319,6 +323,12 @@ def serve(connection_str, test_mode=False):
return server
def init_and_serve(connection_str, test_mode=False, *args, **kwargs):
info = ray.init(*args, **kwargs)
server = serve(connection_str, test_mode)
return (server, info)
if __name__ == "__main__":
logging.basicConfig(level="INFO")
# TODO(barakmich): Perhaps wrap ray init

View file

@ -21,7 +21,8 @@ from typing import Any
from typing import TYPE_CHECKING
from ray.experimental.client.client_pickler import PickleStub
from ray.experimental.client.server.server_stubs import ServerFunctionSentinel
from ray.experimental.client.server.server_stubs import (
ServerSelfReferenceSentinel)
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
@ -54,6 +55,7 @@ class ServerPickler(cloudpickle.CloudPickler):
type="Object",
client_id=self.client_id,
ref_id=obj_id,
name=None,
)
elif isinstance(obj, ray.actor.ActorHandle):
actor_id = obj._actor_id.binary()
@ -66,6 +68,7 @@ class ServerPickler(cloudpickle.CloudPickler):
type="Actor",
client_id=self.client_id,
ref_id=obj._actor_id.binary(),
name=None,
)
return None
@ -77,15 +80,25 @@ class ClientUnpickler(pickle.Unpickler):
def persistent_load(self, pid):
assert isinstance(pid, PickleStub)
if pid.type == "Object":
if pid.type == "Ray":
return ray
elif pid.type == "Object":
return self.server.object_refs[pid.client_id][pid.ref_id]
elif pid.type == "Actor":
return self.server.actor_refs[pid.ref_id]
elif pid.type == "RemoteFuncSelfReference":
return ServerFunctionSentinel()
return ServerSelfReferenceSentinel()
elif pid.type == "RemoteFunc":
return self.server.lookup_or_register_func(pid.ref_id,
pid.client_id)
elif pid.type == "RemoteActorSelfReference":
return ServerSelfReferenceSentinel()
elif pid.type == "RemoteActor":
return self.server.lookup_or_register_actor(
pid.ref_id, pid.client_id)
elif pid.type == "RemoteMethod":
actor = self.server.actor_refs[pid.ref_id]
return getattr(actor, pid.name)
else:
raise NotImplementedError("Uncovered client data type")

View file

@ -1,28 +1,28 @@
from contextlib import contextmanager
_current_remote_func = None
_current_remote_obj = None
@contextmanager
def current_func(f):
global _current_remote_func
remote_func = _current_remote_func
_current_remote_func = f
def current_remote(r):
global _current_remote_obj
remote = _current_remote_obj
_current_remote_obj = r
try:
yield
finally:
_current_remote_func = remote_func
_current_remote_obj = remote
class ServerFunctionSentinel:
class ServerSelfReferenceSentinel:
def __init__(self):
pass
def __reduce__(self):
global _current_remote_func
if _current_remote_func is None:
return (ServerFunctionSentinel, tuple())
return (identity, (_current_remote_func, ))
global _current_remote_obj
if _current_remote_obj is None:
return (ServerSelfReferenceSentinel, tuple())
return (identity, (_current_remote_obj, ))
def identity(x):

View file

@ -109,9 +109,13 @@ class Worker:
num_returns: int = 1,
timeout: float = None
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
assert isinstance(object_refs, list)
if not isinstance(object_refs, list):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got {type(object_refs)}")
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
if not isinstance(ref, ClientObjectRef):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got list containing {type(ref)}")
data = {
"object_ids": [object_ref.id for object_ref in object_refs],
"num_returns": num_returns,
@ -149,9 +153,16 @@ class Worker:
for arg in args:
pb_arg = convert_to_arg(arg, self._client_id)
task.args.append(pb_arg)
for k, v in kwargs.items():
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
task.client_id = self._client_id
logger.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
try:
ticket = self.server.Schedule(task, metadata=self.metadata)
except grpc.RpcError as e:
raise e.details()
if not ticket.valid:
raise cloudpickle.loads(ticket.error)
return ticket.return_id
def call_release(self, id: bytes) -> None:
@ -171,10 +182,9 @@ class Worker:
self.reference_count[id] += 1
def close(self):
self.data_client.close()
self.data_client.close(close_channel=True)
self.server = None
if self.channel:
self.channel.close()
self.channel = None
def terminate_actor(self, actor: ClientActorHandle,

View file

@ -443,3 +443,7 @@ def format_web_url(url):
def new_scheduler_enabled():
return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1"
def client_test_enabled() -> bool:
return os.environ.get("RAY_TEST_CLIENT_MODE") == "1"

View file

@ -153,3 +153,20 @@ py_test(
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test_module_list(
files = [
"test_actor.py",
"test_basic.py",
"test_basic_2.py",
],
size = "medium",
extra_srcs = SRCS,
name_suffix = "_client_mode",
# TODO(barakmich): py_test will support env in Bazel 4.0.0...
# Until then, we can use tags.
#env = {"RAY_TEST_CLIENT_MODE": "true"},
tags = ["exclusive", "client_tests"],
deps = ["//:ray_lib"],
)

View file

@ -0,0 +1,20 @@
import asyncio
def create_remote_signal_actor(ray):
# TODO(barakmich): num_cpus=0
@ray.remote
class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait:
await self.ready_event.wait()
return SignalActor

View file

@ -9,12 +9,18 @@ import subprocess
import ray
from ray.cluster_utils import Cluster
from ray.test_utils import init_error_pubsub
from ray.test_utils import client_test_enabled
import ray.experimental.client as ray_client
@pytest.fixture
def shutdown_only():
yield None
# The code after the yield will run as teardown code.
if client_test_enabled():
ray_client.ray.disconnect()
ray_client._stop_test_server(1)
ray_client.reset_api()
ray.shutdown()
@ -43,9 +49,17 @@ def _ray_start(**kwargs):
init_kwargs = get_default_fixture_ray_kwargs()
init_kwargs.update(kwargs)
# Start the Ray processes.
address_info = ray.init(**init_kwargs)
if client_test_enabled():
address_info = ray_client.ray.init(**init_kwargs)
else:
address_info = ray.init(**init_kwargs)
yield address_info
# The code after the yield will run as teardown code.
if client_test_enabled():
ray_client.ray.disconnect()
ray_client._stop_test_server(1)
ray_client.reset_api()
ray.shutdown()
@ -130,9 +144,16 @@ def _ray_start_cluster(**kwargs):
# We assume driver will connect to the head (first node),
# so ray init will be invoked if do_init is true
if len(remote_nodes) == 1 and do_init:
ray.init(address=cluster.address)
if client_test_enabled():
ray_client.ray.init(address=cluster.address)
else:
ray.init(address=cluster.address)
yield cluster
# The code after the yield will run as teardown code.
if client_test_enabled():
ray_client.ray.disconnect()
ray_client._stop_test_server(1)
ray_client.reset_api()
ray.shutdown()
cluster.shutdown()

View file

@ -11,15 +11,21 @@ import sys
import tempfile
import datetime
import ray
import ray.test_utils
import ray.cluster_utils
from ray.test_utils import client_test_enabled
from ray.test_utils import wait_for_condition
from ray.test_utils import wait_for_pid_to_exit
from ray.tests.client_test_utils import create_remote_signal_actor
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
# NOTE: We have to import setproctitle after ray because we bundle setproctitle
# with ray.
import setproctitle
import setproctitle # noqa
@pytest.mark.skipif(client_test_enabled(), reason="test setup order")
def test_caching_actors(shutdown_only):
# Test defining actors before ray.init() has been called.
@ -238,6 +244,7 @@ def test_actor_import_counter(ray_start_10_cpus):
assert ray.get(g.remote()) == num_remote_functions - 1
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_method_metadata_cache(ray_start_regular):
class Actor(object):
pass
@ -257,6 +264,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
assert [id(x) for x in list(cache.items())[0]] == cached_data_id
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_class_name(ray_start_regular):
@ray.remote
class Foo:
@ -556,6 +564,7 @@ def test_actor_static_attributes(ray_start_regular_shared):
assert ray.get(t.g.remote()) == 3
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_decorator_args(ray_start_regular_shared):
# This is an invalid way of using the actor decorator.
with pytest.raises(Exception):
@ -618,6 +627,8 @@ def test_random_id_generation(ray_start_regular_shared):
assert f1._actor_id != f2._actor_id
@pytest.mark.skipif(
client_test_enabled(), reason="differing inheritence structure")
def test_actor_inheritance(ray_start_regular_shared):
class NonActorBase:
def __init__(self):
@ -630,8 +641,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass
# Test that you can't instantiate an actor class directly.
with pytest.raises(
Exception, match="Actors cannot be instantiated directly."):
with pytest.raises(Exception, match="cannot be instantiated directly"):
ActorBase()
# Test that you can't inherit from an actor class.
@ -645,6 +655,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_multiple_return_values(ray_start_regular_shared):
@ray.remote
class Foo:
@ -678,6 +689,7 @@ def test_multiple_return_values(ray_start_regular_shared):
assert ray.get([id3a, id3b, id3c]) == [1, 2, 3]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_options_num_returns(ray_start_regular_shared):
@ray.remote
class Foo:
@ -693,6 +705,7 @@ def test_options_num_returns(ray_start_regular_shared):
assert ray.get([obj1, obj2]) == [1, 2]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_options_name(ray_start_regular_shared):
@ray.remote
class Foo:
@ -734,13 +747,13 @@ def test_actor_deletion(ray_start_regular_shared):
a = Actor.remote()
pid = ray.get(a.getpid.remote())
a = None
ray.test_utils.wait_for_pid_to_exit(pid)
wait_for_pid_to_exit(pid)
actors = [Actor.remote() for _ in range(10)]
pids = ray.get([a.getpid.remote() for a in actors])
a = None
actors = None
[ray.test_utils.wait_for_pid_to_exit(pid) for pid in pids]
[wait_for_pid_to_exit(pid) for pid in pids]
def test_actor_method_deletion(ray_start_regular_shared):
@ -769,7 +782,8 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
ray.get(signal.wait.remote())
return ray.get(actor.method.remote())
signal = ray.test_utils.SignalActor.remote()
SignalActor = create_remote_signal_actor(ray)
signal = SignalActor.remote()
a = Actor.remote()
pid = ray.get(a.getpid.remote())
# Pass the handle to another task that cannot run yet.
@ -780,7 +794,7 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
# Once the task finishes, the actor process should get killed.
ray.get(signal.send.remote())
assert ray.get(x_id) == 1
ray.test_utils.wait_for_pid_to_exit(pid)
wait_for_pid_to_exit(pid)
def test_multiple_actors(ray_start_regular_shared):
@ -921,7 +935,7 @@ def test_atexit_handler(ray_start_regular_shared, exit_condition):
if exit_condition == "ray.kill":
assert not check_file_written()
else:
ray.test_utils.wait_for_condition(check_file_written)
wait_for_condition(check_file_written)
if __name__ == "__main__":

View file

@ -8,14 +8,23 @@ import time
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
from ray.test_utils import (
client_test_enabled,
dicts_equal,
wait_for_pid_to_exit,
)
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__)
# https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_ignore_http_proxy(shutdown_only):
ray.init(num_cpus=1)
os.environ["http_proxy"] = "http://example.com"
@ -29,6 +38,7 @@ def test_ignore_http_proxy(shutdown_only):
# https://github.com/ray-project/ray/issues/7263
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_grpc_message_size(shutdown_only):
ray.init(num_cpus=1)
@ -45,12 +55,14 @@ def test_grpc_message_size(shutdown_only):
# https://github.com/ray-project/ray/issues/7287
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_omp_threads_set(shutdown_only):
ray.init(num_cpus=1)
# Should have been auto set by ray init.
assert os.environ["OMP_NUM_THREADS"] == "1"
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_submit_api(shutdown_only):
ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
@ -109,6 +121,7 @@ def test_submit_api(shutdown_only):
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_invalid_arguments(shutdown_only):
ray.init(num_cpus=2)
@ -163,6 +176,7 @@ def test_invalid_arguments(shutdown_only):
x = 1
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_many_fractional_resources(shutdown_only):
ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})
@ -178,7 +192,7 @@ def test_many_fractional_resources(shutdown_only):
}
if block:
ray.get(g.remote())
return ray.test_utils.dicts_equal(true_resources, accepted_resources)
return dicts_equal(true_resources, accepted_resources)
# Check that the resource are assigned correctly.
result_ids = []
@ -230,6 +244,7 @@ def test_many_fractional_resources(shutdown_only):
assert False, "Did not get correct available resources."
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_background_tasks_with_max_calls(shutdown_only):
ray.init(num_cpus=2)
@ -257,7 +272,7 @@ def test_background_tasks_with_max_calls(shutdown_only):
pid, g_id = nested.pop(0)
ray.get(g_id)
del g_id
ray.test_utils.wait_for_pid_to_exit(pid)
wait_for_pid_to_exit(pid)
def test_fair_queueing(shutdown_only):
@ -327,6 +342,7 @@ def test_wait_timing(shutdown_only):
assert len(not_ready) == 1
@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet")
def test_function_descriptor():
python_descriptor = ray._raylet.PythonFunctionDescriptor(
"module_name", "function_name", "class_name", "function_hash")
@ -344,6 +360,7 @@ def test_function_descriptor():
assert d.get(python_descriptor2) == 123
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_ray_options(shutdown_only):
@ray.remote(
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
@ -371,6 +388,7 @@ def test_ray_options(shutdown_only):
assert without_options != with_options
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"num_cpus": 0,
@ -438,8 +456,11 @@ def test_nested_functions(ray_start_shared_local_modes):
assert ray.get(factorial.remote(4)) == 24
assert ray.get(factorial.remote(5)) == 120
# Test remote functions that recursively call each other.
@pytest.mark.skipif(
client_test_enabled(), reason="mutual recursion is a known issue")
def test_mutually_recursive_functions(ray_start_shared_local_modes):
# Test remote functions that recursively call each other.
@ray.remote
def factorial_even(n):
assert n % 2 == 0
@ -710,6 +731,7 @@ def test_args_stars_after(ray_start_shared_local_modes):
ray.get(remote_test_function.remote(local_method, actor_method))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_object_id_backward_compatibility(ray_start_shared_local_modes):
# We've renamed Python's `ObjectID` to `ObjectRef`, and added a type
# alias for backward compatibility.

View file

@ -9,10 +9,16 @@ import pytest
from unittest.mock import MagicMock, patch
import ray
import ray.cluster_utils
import ray.test_utils
from ray.test_utils import client_test_enabled
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.exceptions import GetTimeoutError
from ray.exceptions import RayTaskError
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__)
@ -25,6 +31,8 @@ logger = logging.getLogger(__name__)
}],
indirect=True)
def test_variable_number_of_args(shutdown_only):
ray.init(num_cpus=1)
@ray.remote
def varargs_fct1(*a):
return " ".join(map(str, a))
@ -33,8 +41,6 @@ def test_variable_number_of_args(shutdown_only):
def varargs_fct2(a, *b):
return " ".join(map(str, b))
ray.init(num_cpus=1)
x = varargs_fct1.remote(0, 1, 2)
assert ray.get(x) == "0 1 2"
x = varargs_fct2.remote(0, 1, 2)
@ -160,7 +166,7 @@ def test_redefining_remote_functions(shutdown_only):
def g():
return nonexistent()
with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
with pytest.raises(RayTaskError, match="nonexistent"):
ray.get(g.remote())
def nonexistent():
@ -187,6 +193,7 @@ def test_redefining_remote_functions(shutdown_only):
assert ray.get(ray.get(h.remote(i))) == i
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_call_matrix(shutdown_only):
ray.init(object_store_memory=1000 * 1024 * 1024)
@ -312,6 +319,7 @@ def test_actor_pass_by_ref_order_optimization(shutdown_only):
assert delta < 10, "did not skip slow value"
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize(
"ray_start_cluster", [{
"num_cpus": 1,
@ -332,6 +340,7 @@ def test_call_chain(ray_start_cluster):
assert ray.get(x) == 100
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_system_config_when_connecting(ray_start_cluster):
config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200}
cluster = ray.cluster_utils.Cluster()
@ -368,7 +377,8 @@ def test_get_multiple(ray_start_regular_shared):
def test_get_with_timeout(ray_start_regular_shared):
signal = ray.test_utils.SignalActor.remote()
SignalActor = create_remote_signal_actor(ray)
signal = SignalActor.remote()
# Check that get() returns early if object is ready.
start = time.time()
@ -438,6 +448,7 @@ def test_inline_arg_memory_corruption(ray_start_regular_shared):
ray.get(a.add.remote(f.remote()))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_skip_plasma(ray_start_regular_shared):
@ray.remote
class Actor:
@ -454,6 +465,8 @@ def test_skip_plasma(ray_start_regular_shared):
assert ray.get(obj_ref) == 2
@pytest.mark.skipif(
client_test_enabled(), reason="internal api and message size")
def test_actor_large_objects(ray_start_regular_shared):
@ray.remote
class Actor:
@ -524,6 +537,7 @@ def test_actor_recursive(ray_start_regular_shared):
assert result == [x * 2 for x in range(100)]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_actor_concurrent(ray_start_regular_shared):
@ray.remote
class Batcher:
@ -626,6 +640,7 @@ def test_duplicate_args(ray_start_regular_shared):
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_get_correct_node_ip():
with patch("ray.worker") as worker_mock:
node_mock = MagicMock()

View file

@ -81,11 +81,11 @@ def test_wait(ray_start_regular_shared):
with pytest.raises(Exception):
# Reference not in the object store.
ray.wait([ClientObjectRef("blabla")])
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
ray.wait("blabla")
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
ray.wait(ClientObjectRef("blabla"))
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
ray.wait(["blabla"])

View file

@ -1,6 +1,6 @@
import pytest
import asyncio
from ray.tests.test_experimental_client import ray_start_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.test_utils import wait_for_condition
from ray.exceptions import TaskCancelledError
from ray.exceptions import RayTaskError
@ -45,21 +45,7 @@ def test_kill_actor_immediately_after_creation(ray_start_regular):
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
with ray_start_client_server() as ray:
@ray.remote
class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait:
await self.ready_event.wait()
SignalActor = create_remote_signal_actor(ray)
signaler = SignalActor.remote()
@ray.remote

View file

@ -50,16 +50,22 @@ message ClientTask {
string name = 2;
// A reference to the payload.
bytes payload_id = 3;
// The parameters to pass to this call.
// Positional parameters to pass to this call.
repeated Arg args = 4;
// Keyword parameters to pass to this call.
map<string, Arg> kwargs = 5;
// The ID of the client namespace associated with the Datapath stream making this
// request.
string client_id = 5;
string client_id = 6;
}
message ClientTaskTicket {
// Was the task successful?
bool valid = 1;
// A reference to the returned value from the execution.
bytes return_id = 1;
bytes return_id = 2;
// If unsuccessful, an encoding of the error.
bytes error = 3;
}
// Delivers data to the server