[ray_client] Integrate with test_basic, test_basic_2 and test_actor (#12964)

This commit is contained in:
Barak Michener 2020-12-20 14:54:18 -08:00 committed by GitHub
parent bf6577c8f4
commit 7ab9164f1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 375 additions and 174 deletions

View file

@ -46,6 +46,7 @@ matrix:
script: script:
# bazel python tests for medium size tests. Used for parallelization. # bazel python tests for medium size tests. Used for parallelization.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_TEST_CLIENT_MODE=1 python/ray/tests/...; fi
- os: linux - os: linux
env: env:

View file

@ -1,12 +1,14 @@
# py_test_module_list creates a py_test target for each # py_test_module_list creates a py_test target for each
# Python file in `files` # Python file in `files`
def py_test_module_list(files, size, deps, extra_srcs, **kwargs): def py_test_module_list(files, size, deps, extra_srcs, name_suffix="", **kwargs):
for file in files: for file in files:
# remove .py # remove .py
name = file[:-3] name = file[:-3] + name_suffix
main = file
native.py_test( native.py_test(
name = name, name = name,
size = size, size = size,
main = file,
srcs = extra_srcs + [file], srcs = extra_srcs + [file],
**kwargs **kwargs
) )

View file

@ -4,6 +4,7 @@ from typing import Optional, List, Tuple
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +44,9 @@ def stash_api_for_tests(in_test: bool):
is_server = _is_server is_server = _is_server
if in_test: if in_test:
_is_server = True _is_server = True
try:
yield _server_api yield _server_api
finally:
if in_test: if in_test:
_is_server = is_server _is_server = is_server
@ -77,18 +80,7 @@ def reset_api():
def _get_client_api() -> APIImpl: def _get_client_api() -> APIImpl:
global _client_api global _client_api
global _server_api return _client_api
global _is_server
api = None
if _is_server:
api = _server_api
else:
api = _client_api
if api is None:
# We're inside a raylet worker
from ray.experimental.client.server.core_ray_api import CoreRayAPI
return CoreRayAPI()
return api
def _get_server_instance(): def _get_server_instance():
@ -124,9 +116,33 @@ class RayAPIStub:
global _client_api global _client_api
return _client_api is not None return _client_api is not None
def init(self, *args, **kwargs):
    """Start a local client server for tests and connect this stub to it.

    Only supported when ``_is_client_test_env()`` is true. In that case the
    arguments are handed to ``init_and_serve`` (which forwards them on), the
    resulting server is remembered in the module-level ``_test_server`` and
    this stub connects to it.

    Args:
        *args: Positional arguments forwarded to ``init_and_serve``.
        **kwargs: Keyword arguments forwarded to ``init_and_serve``.

    Returns:
        The second element of the pair returned by ``init_and_serve``.

    Raises:
        NotImplementedError: If client test mode is not enabled.
    """
    global _test_server
    if not _is_client_test_env():
        raise NotImplementedError(
            "Please call ray.connect() in client mode")

    # Imported lazily so the server module is only loaded in test mode.
    import ray.experimental.client.server.server as ray_client_server

    address = "localhost:50051"
    # test_mode must be passed positionally: init_and_serve declares it
    # before *args, so `test_mode=True, *args` would bind the first extra
    # positional argument to test_mode and fail with "multiple values".
    _test_server, address_info = ray_client_server.init_and_serve(
        address, True, *args, **kwargs)
    self.connect(address)
    return address_info
ray = RayAPIStub() ray = RayAPIStub()
_test_server = None
def _stop_test_server(*args):
    """Stop the test server stored in ``_test_server``, if there is one.

    Args:
        *args: Forwarded unchanged to the server's ``stop`` method.

    Calling this when no server is recorded is a no-op, so teardown code can
    call it unconditionally. The module-level handle is cleared before
    stopping, so a second call does not stop a stale server again.
    """
    global _test_server
    if _test_server is None:
        return
    server, _test_server = _test_server, None
    server.stop(*args)
def _is_client_test_env() -> bool:
    """Report whether the RAY_TEST_CLIENT_MODE environment variable is "1"."""
    flag = os.getenv("RAY_TEST_CLIENT_MODE")
    return flag == "1"
# Someday we might add methods in this module so that someone who # Someday we might add methods in this module so that someone who
# tries to `import ray_client as ray` -- as a module, instead of # tries to `import ray_client as ray` -- as a module, instead of
# `from ray_client import ray` -- as the API stub # `from ray_client import ray` -- as the API stub

View file

@ -28,11 +28,15 @@ import sys
from typing import NamedTuple from typing import NamedTuple
from typing import Any from typing import Any
from typing import Optional
from ray.experimental.client import RayAPIStub
from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import SelfReferenceSentinel from ray.experimental.client.common import SelfReferenceSentinel
import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2 as ray_client_pb2
@ -44,8 +48,11 @@ if sys.version_info < (3, 8):
else: else:
import pickle # noqa: F401 import pickle # noqa: F401
PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str), # NOTE(barakmich): These PickleStubs are really close to
("ref_id", bytes)]) # the data for an execution, with no arguments. Combine the two?
PickleStub = NamedTuple("PickleStub",
[("type", str), ("client_id", str), ("ref_id", bytes),
("name", Optional[str])])
class ClientPickler(cloudpickle.CloudPickler): class ClientPickler(cloudpickle.CloudPickler):
@ -54,17 +61,26 @@ class ClientPickler(cloudpickle.CloudPickler):
self.client_id = client_id self.client_id = client_id
def persistent_id(self, obj): def persistent_id(self, obj):
if isinstance(obj, ClientObjectRef): if isinstance(obj, RayAPIStub):
return PickleStub(
type="Ray",
client_id=self.client_id,
ref_id=b"",
name=None,
)
elif isinstance(obj, ClientObjectRef):
return PickleStub( return PickleStub(
type="Object", type="Object",
client_id=self.client_id, client_id=self.client_id,
ref_id=obj.id, ref_id=obj.id,
name=None,
) )
elif isinstance(obj, ClientActorHandle): elif isinstance(obj, ClientActorHandle):
return PickleStub( return PickleStub(
type="Actor", type="Actor",
client_id=self.client_id, client_id=self.client_id,
ref_id=obj._actor_id, ref_id=obj._actor_id,
name=None,
) )
elif isinstance(obj, ClientRemoteFunc): elif isinstance(obj, ClientRemoteFunc):
# TODO(barakmich): This is going to have trouble with mutually # TODO(barakmich): This is going to have trouble with mutually
@ -77,11 +93,39 @@ class ClientPickler(cloudpickle.CloudPickler):
return PickleStub( return PickleStub(
type="RemoteFuncSelfReference", type="RemoteFuncSelfReference",
client_id=self.client_id, client_id=self.client_id,
ref_id=b"") ref_id=b"",
name=None,
)
return PickleStub( return PickleStub(
type="RemoteFunc", type="RemoteFunc",
client_id=self.client_id, client_id=self.client_id,
ref_id=obj._ref.id) ref_id=obj._ref.id,
name=None,
)
elif isinstance(obj, ClientActorClass):
# TODO(barakmich): Mutual recursion, as above.
if obj._ref is None:
obj._ensure_ref()
if type(obj._ref) == SelfReferenceSentinel:
return PickleStub(
type="RemoteActorSelfReference",
client_id=self.client_id,
ref_id=b"",
name=None,
)
return PickleStub(
type="RemoteActor",
client_id=self.client_id,
ref_id=obj._ref.id,
name=None,
)
elif isinstance(obj, ClientRemoteMethod):
return PickleStub(
type="RemoteMethod",
client_id=self.client_id,
ref_id=obj.actor_handle.actor_ref.id,
name=obj.method_name,
)
return None return None

View file

@ -1,6 +1,5 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray from ray.experimental.client import ray
from typing import Dict
class ClientBaseRef: class ClientBaseRef:
@ -8,17 +7,20 @@ class ClientBaseRef:
self.id: bytes = id self.id: bytes = id
ray.call_retain(id) ray.call_retain(id)
def binary(self):
return self.id
def __eq__(self, other):
return self.id == other.id
def __repr__(self): def __repr__(self):
return "%s(%s)" % ( return "%s(%s)" % (
type(self).__name__, type(self).__name__,
self.id.hex(), self.id.hex(),
) )
def __eq__(self, other): def __hash__(self):
return self.id == other.id return hash(self.id)
def binary(self):
return self.id
def __del__(self): def __del__(self):
if ray.is_connected(): if ray.is_connected():
@ -107,18 +109,13 @@ class ClientActorClass(ClientStub):
raise TypeError(f"Remote actor cannot be instantiated directly. " raise TypeError(f"Remote actor cannot be instantiated directly. "
"Use {self._name}.remote() instead") "Use {self._name}.remote() instead")
def __getstate__(self) -> Dict: def _ensure_ref(self):
state = { if self._ref is None:
"actor_cls": self.actor_cls, # As before, set the state of the reference to be an
"_name": self._name, # in-progress self reference value, which
"_ref": self._ref, # the encoding can detect and handle correctly.
} self._ref = SelfReferenceSentinel()
return state self._ref = ray.put(self.actor_cls)
def __setstate__(self, state: Dict) -> None:
self.actor_cls = state["actor_cls"]
self._name = state["_name"]
self._ref = state["_ref"]
def remote(self, *args, **kwargs) -> "ClientActorHandle": def remote(self, *args, **kwargs) -> "ClientActorHandle":
# Actually instantiate the actor # Actually instantiate the actor
@ -126,7 +123,7 @@ class ClientActorClass(ClientStub):
return ClientActorHandle(ClientActorRef(ref_id), self) return ClientActorHandle(ClientActorRef(ref_id), self)
def __repr__(self): def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) return "ClientActorClass(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key): def __getattr__(self, key):
if key not in self.__dict__: if key not in self.__dict__:
@ -134,8 +131,7 @@ class ClientActorClass(ClientStub):
raise NotImplementedError("static methods") raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask: def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None: self._ensure_ref()
self._ref = ray.put(self.actor_cls)
task = ray_client_pb2.ClientTask() task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name task.name = self._name

View file

@ -53,6 +53,7 @@ class DataClient:
resp_stream = stub.Datapath( resp_stream = stub.Datapath(
iter(self.request_queue.get, None), iter(self.request_queue.get, None),
metadata=(("client_id", self._client_id), )) metadata=(("client_id", self._client_id), ))
try:
for response in resp_stream: for response in resp_stream:
if response.req_id == 0: if response.req_id == 0:
# This is not being waited for. # This is not being waited for.
@ -61,16 +62,24 @@ class DataClient:
with self.cv: with self.cv:
self.ready_data[response.req_id] = response self.ready_data[response.req_id] = response
self.cv.notify_all() self.cv.notify_all()
except grpc.RpcError as e:
if grpc.StatusCode.CANCELLED == e.code():
# Gracefully shutting down
logger.info("Cancelling data channel")
else:
logger.error(
f"Got Error from rpc channel -- shutting down: {e}")
raise e
def close(self, close_channel: bool = False) -> None: def close(self, close_channel: bool = False) -> None:
if self.request_queue is not None: if self.request_queue is not None:
self.request_queue.put(None) self.request_queue.put(None)
self.request_queue = None self.request_queue = None
if close_channel:
self.channel.close()
if self.data_thread is not None: if self.data_thread is not None:
self.data_thread.join() self.data_thread.join()
self.data_thread = None self.data_thread = None
if close_channel:
self.channel.close()
def _blocking_send(self, req: ray_client_pb2.DataRequest def _blocking_send(self, req: ray_client_pb2.DataRequest
) -> ray_client_pb2.DataResponse: ) -> ray_client_pb2.DataResponse:

View file

@ -79,8 +79,3 @@ class RayServerAPI(CoreRayAPI):
def __init__(self, server_instance): def __init__(self, server_instance):
self.server = server_instance self.server = server_instance
def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes:
task = instance._prepare_client_task()
ticket = self.server.Schedule(task, prepared_args=args)
return ticket.return_id

View file

@ -21,7 +21,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server
from ray.experimental.client.server.server_pickler import loads_from_client from ray.experimental.client.server.server_pickler import loads_from_client
from ray.experimental.client.server.core_ray_api import RayServerAPI from ray.experimental.client.server.core_ray_api import RayServerAPI
from ray.experimental.client.server.dataservicer import DataServicer from ray.experimental.client.server.dataservicer import DataServicer
from ray.experimental.client.server.server_stubs import current_func from ray.experimental.client.server.server_stubs import current_remote
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -205,82 +205,75 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids, ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids) remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None, def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" % logger.info("schedule: %s %s" %
(task.name, (task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
with stash_api_for_tests(self._test_mode): with stash_api_for_tests(self._test_mode):
try:
if task.type == ray_client_pb2.ClientTask.FUNCTION: if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context, prepared_args) result = self._schedule_function(task, context)
elif task.type == ray_client_pb2.ClientTask.ACTOR: elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context, prepared_args) result = self._schedule_actor(task, context)
elif task.type == ray_client_pb2.ClientTask.METHOD: elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context, prepared_args) result = self._schedule_method(task, context)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Unimplemented Schedule task type: %s" % "Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) ray_client_pb2.ClientTask.RemoteExecType.Name(
task.type))
result.valid = True
return result
except Exception as e:
logger.error(f"Caught schedule exception {e}")
return ray_client_pb2.ClientTaskTicket(
valid=False, error=cloudpickle.dumps(e))
def _schedule_method( def _schedule_method(self, task: ray_client_pb2.ClientTask,
self, context=None) -> ray_client_pb2.ClientTaskTicket:
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id) actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None: if actor_handle is None:
raise Exception( raise Exception(
"Can't run an actor the server doesn't have a handle for") "Can't run an actor the server doesn't have a handle for")
arglist = self._convert_args(task.args, prepared_args) arglist, kwargs = self._convert_args(task.args, task.kwargs)
output = getattr(actor_handle, task.name).remote(*arglist) output = getattr(actor_handle, task.name).remote(*arglist, **kwargs)
self.object_refs[task.client_id][output.binary()] = output self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _schedule_actor(self, def _schedule_actor(self, task: ray_client_pb2.ClientTask,
task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket:
context=None, remote_class = self.lookup_or_register_actor(task.payload_id,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket: task.client_id)
if task.payload_id not in self.registered_actor_classes:
actor_class_ref = \ arglist, kwargs = self._convert_args(task.args, task.kwargs)
self.object_refs[task.client_id][task.payload_id] with current_remote(remote_class):
actor_class = ray.get(actor_class_ref) actor = remote_class.remote(*arglist, **kwargs)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id]
arglist = self._convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
self.actor_refs[actor._actor_id.binary()] = actor self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary()) self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket( return ray_client_pb2.ClientTaskTicket(
return_id=actor._actor_id.binary()) return_id=actor._actor_id.binary())
def _schedule_function( def _schedule_function(self, task: ray_client_pb2.ClientTask,
self, context=None) -> ray_client_pb2.ClientTaskTicket:
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
remote_func = self.lookup_or_register_func(task.payload_id, remote_func = self.lookup_or_register_func(task.payload_id,
task.client_id) task.client_id)
arglist = self._convert_args(task.args, prepared_args) arglist, kwargs = self._convert_args(task.args, task.kwargs)
# Prepare call if we're in a test with current_remote(remote_func):
with current_func(remote_func): output = remote_func.remote(*arglist, **kwargs)
output = remote_func.remote(*arglist)
if output.binary() in self.object_refs[task.client_id]: if output.binary() in self.object_refs[task.client_id]:
raise Exception("already found it") raise Exception("already found it")
self.object_refs[task.client_id][output.binary()] = output self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _convert_args(self, arg_list, prepared_args=None): def _convert_args(self, arg_list, kwarg_map):
if prepared_args is not None: argout = []
return prepared_args
out = []
for arg in arg_list: for arg in arg_list:
t = convert_from_arg(arg, self) t = convert_from_arg(arg, self)
out.append(t) argout.append(t)
return out kwargout = {}
for k in kwarg_map:
kwargout[k] = convert_from_arg(kwarg_map[k], self)
return argout, kwargout
def lookup_or_register_func(self, id: bytes, client_id: str def lookup_or_register_func(self, id: bytes, client_id: str
) -> ray.remote_function.RemoteFunction: ) -> ray.remote_function.RemoteFunction:
@ -293,6 +286,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.function_refs[id] = ray.remote(func) self.function_refs[id] = ray.remote(func)
return self.function_refs[id] return self.function_refs[id]
def lookup_or_register_actor(self, id: bytes, client_id: str):
    """Return the registered remote actor class for ``id``.

    On first use the class is fetched with ``ray.get`` from the object ref
    stored under ``self.object_refs[client_id][id]``, wrapped with
    ``ray.remote`` and cached in ``self.registered_actor_classes``.

    Args:
        id: Key of the stored actor class reference.
        client_id: Client whose object refs hold that reference.

    Raises:
        Exception: If the fetched object is not a class.
    """
    registry = self.registered_actor_classes
    if id in registry:
        return registry[id]

    actor_class = ray.get(self.object_refs[client_id][id])
    if not inspect.isclass(actor_class):
        raise Exception("Attempting to schedule actor that "
                        "isn't a class.")
    registry[id] = ray.remote(actor_class)
    return registry[id]
def return_exception_in_context(err, context): def return_exception_in_context(err, context):
if context is not None: if context is not None:
@ -319,6 +323,12 @@ def serve(connection_str, test_mode=False):
return server return server
def init_and_serve(connection_str, test_mode=False, *args, **kwargs):
    """Call ``ray.init`` and then start serving on ``connection_str``.

    Args:
        connection_str: Passed to ``serve`` as the connection string.
        test_mode: Passed to ``serve``. Because it is declared before
            ``*args``, it must be given positionally whenever extra
            positional arguments are supplied.
        *args: Forwarded to ``ray.init``.
        **kwargs: Forwarded to ``ray.init``.

    Returns:
        A ``(server, info)`` tuple: the value returned by ``serve`` and the
        value returned by ``ray.init``.
    """
    # ray.init runs first; serve is only started once it has returned.
    init_info = ray.init(*args, **kwargs)
    return (serve(connection_str, test_mode), init_info)
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level="INFO") logging.basicConfig(level="INFO")
# TODO(barakmich): Perhaps wrap ray init # TODO(barakmich): Perhaps wrap ray init

View file

@ -21,7 +21,8 @@ from typing import Any
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ray.experimental.client.client_pickler import PickleStub from ray.experimental.client.client_pickler import PickleStub
from ray.experimental.client.server.server_stubs import ServerFunctionSentinel from ray.experimental.client.server.server_stubs import (
ServerSelfReferenceSentinel)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer from ray.experimental.client.server.server import RayletServicer
@ -54,6 +55,7 @@ class ServerPickler(cloudpickle.CloudPickler):
type="Object", type="Object",
client_id=self.client_id, client_id=self.client_id,
ref_id=obj_id, ref_id=obj_id,
name=None,
) )
elif isinstance(obj, ray.actor.ActorHandle): elif isinstance(obj, ray.actor.ActorHandle):
actor_id = obj._actor_id.binary() actor_id = obj._actor_id.binary()
@ -66,6 +68,7 @@ class ServerPickler(cloudpickle.CloudPickler):
type="Actor", type="Actor",
client_id=self.client_id, client_id=self.client_id,
ref_id=obj._actor_id.binary(), ref_id=obj._actor_id.binary(),
name=None,
) )
return None return None
@ -77,15 +80,25 @@ class ClientUnpickler(pickle.Unpickler):
def persistent_load(self, pid): def persistent_load(self, pid):
assert isinstance(pid, PickleStub) assert isinstance(pid, PickleStub)
if pid.type == "Object": if pid.type == "Ray":
return ray
elif pid.type == "Object":
return self.server.object_refs[pid.client_id][pid.ref_id] return self.server.object_refs[pid.client_id][pid.ref_id]
elif pid.type == "Actor": elif pid.type == "Actor":
return self.server.actor_refs[pid.ref_id] return self.server.actor_refs[pid.ref_id]
elif pid.type == "RemoteFuncSelfReference": elif pid.type == "RemoteFuncSelfReference":
return ServerFunctionSentinel() return ServerSelfReferenceSentinel()
elif pid.type == "RemoteFunc": elif pid.type == "RemoteFunc":
return self.server.lookup_or_register_func(pid.ref_id, return self.server.lookup_or_register_func(pid.ref_id,
pid.client_id) pid.client_id)
elif pid.type == "RemoteActorSelfReference":
return ServerSelfReferenceSentinel()
elif pid.type == "RemoteActor":
return self.server.lookup_or_register_actor(
pid.ref_id, pid.client_id)
elif pid.type == "RemoteMethod":
actor = self.server.actor_refs[pid.ref_id]
return getattr(actor, pid.name)
else: else:
raise NotImplementedError("Uncovered client data type") raise NotImplementedError("Uncovered client data type")

View file

@ -1,28 +1,28 @@
from contextlib import contextmanager from contextlib import contextmanager
_current_remote_func = None _current_remote_obj = None
@contextmanager @contextmanager
def current_func(f): def current_remote(r):
global _current_remote_func global _current_remote_obj
remote_func = _current_remote_func remote = _current_remote_obj
_current_remote_func = f _current_remote_obj = r
try: try:
yield yield
finally: finally:
_current_remote_func = remote_func _current_remote_obj = remote
class ServerFunctionSentinel: class ServerSelfReferenceSentinel:
def __init__(self): def __init__(self):
pass pass
def __reduce__(self): def __reduce__(self):
global _current_remote_func global _current_remote_obj
if _current_remote_func is None: if _current_remote_obj is None:
return (ServerFunctionSentinel, tuple()) return (ServerSelfReferenceSentinel, tuple())
return (identity, (_current_remote_func, )) return (identity, (_current_remote_obj, ))
def identity(x): def identity(x):

View file

@ -109,9 +109,13 @@ class Worker:
num_returns: int = 1, num_returns: int = 1,
timeout: float = None timeout: float = None
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
assert isinstance(object_refs, list) if not isinstance(object_refs, list):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got {type(object_refs)}")
for ref in object_refs: for ref in object_refs:
assert isinstance(ref, ClientObjectRef) if not isinstance(ref, ClientObjectRef):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got list containing {type(ref)}")
data = { data = {
"object_ids": [object_ref.id for object_ref in object_refs], "object_ids": [object_ref.id for object_ref in object_refs],
"num_returns": num_returns, "num_returns": num_returns,
@ -149,9 +153,16 @@ class Worker:
for arg in args: for arg in args:
pb_arg = convert_to_arg(arg, self._client_id) pb_arg = convert_to_arg(arg, self._client_id)
task.args.append(pb_arg) task.args.append(pb_arg)
for k, v in kwargs.items():
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
task.client_id = self._client_id task.client_id = self._client_id
logger.debug("Scheduling %s" % task) logger.debug("Scheduling %s" % task)
try:
ticket = self.server.Schedule(task, metadata=self.metadata) ticket = self.server.Schedule(task, metadata=self.metadata)
except grpc.RpcError as e:
raise e.details()
if not ticket.valid:
raise cloudpickle.loads(ticket.error)
return ticket.return_id return ticket.return_id
def call_release(self, id: bytes) -> None: def call_release(self, id: bytes) -> None:
@ -171,10 +182,9 @@ class Worker:
self.reference_count[id] += 1 self.reference_count[id] += 1
def close(self): def close(self):
self.data_client.close() self.data_client.close(close_channel=True)
self.server = None self.server = None
if self.channel: if self.channel:
self.channel.close()
self.channel = None self.channel = None
def terminate_actor(self, actor: ClientActorHandle, def terminate_actor(self, actor: ClientActorHandle,

View file

@ -443,3 +443,7 @@ def format_web_url(url):
def new_scheduler_enabled(): def new_scheduler_enabled():
return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1" return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1"
def client_test_enabled() -> bool:
    """Return True only when RAY_TEST_CLIENT_MODE is set to exactly "1"."""
    mode = os.getenv("RAY_TEST_CLIENT_MODE")
    return mode == "1"

View file

@ -153,3 +153,20 @@ py_test(
tags = ["exclusive"], tags = ["exclusive"],
deps = ["//:ray_lib"], deps = ["//:ray_lib"],
) )
py_test_module_list(
files = [
"test_actor.py",
"test_basic.py",
"test_basic_2.py",
],
size = "medium",
extra_srcs = SRCS,
name_suffix = "_client_mode",
# TODO(barakmich): py_test will support env in Bazel 4.0.0...
# Until then, we can use tags.
#env = {"RAY_TEST_CLIENT_MODE": "1"},
tags = ["exclusive", "client_tests"],
deps = ["//:ray_lib"],
)

View file

@ -0,0 +1,20 @@
import asyncio
def create_remote_signal_actor(ray):
    """Build a SignalActor class decorated with the given ``ray.remote``.

    Args:
        ray: Object providing a ``remote`` decorator; it is applied to the
            class defined here.

    Returns:
        The decorated ``SignalActor`` class (not an instance).
    """

    # TODO(barakmich): num_cpus=0
    @ray.remote
    class SignalActor:
        def __init__(self):
            self.ready_event = asyncio.Event()

        def send(self, clear=False):
            # Set the event; with clear=True it is reset immediately after.
            self.ready_event.set()
            if clear:
                self.ready_event.clear()

        async def wait(self, should_wait=True):
            # Return right away unless the caller asked to wait.
            if not should_wait:
                return
            await self.ready_event.wait()

    return SignalActor

View file

@ -9,12 +9,18 @@ import subprocess
import ray import ray
from ray.cluster_utils import Cluster from ray.cluster_utils import Cluster
from ray.test_utils import init_error_pubsub from ray.test_utils import init_error_pubsub
from ray.test_utils import client_test_enabled
import ray.experimental.client as ray_client
@pytest.fixture @pytest.fixture
def shutdown_only(): def shutdown_only():
yield None yield None
# The code after the yield will run as teardown code. # The code after the yield will run as teardown code.
if client_test_enabled():
ray_client.ray.disconnect()
ray_client._stop_test_server(1)
ray_client.reset_api()
ray.shutdown() ray.shutdown()
@ -43,9 +49,17 @@ def _ray_start(**kwargs):
init_kwargs = get_default_fixture_ray_kwargs() init_kwargs = get_default_fixture_ray_kwargs()
init_kwargs.update(kwargs) init_kwargs.update(kwargs)
# Start the Ray processes. # Start the Ray processes.
if client_test_enabled():
address_info = ray_client.ray.init(**init_kwargs)
else:
address_info = ray.init(**init_kwargs) address_info = ray.init(**init_kwargs)
yield address_info yield address_info
# The code after the yield will run as teardown code. # The code after the yield will run as teardown code.
if client_test_enabled():
ray_client.ray.disconnect()
ray_client._stop_test_server(1)
ray_client.reset_api()
ray.shutdown() ray.shutdown()
@ -130,9 +144,16 @@ def _ray_start_cluster(**kwargs):
# We assume driver will connect to the head (first node), # We assume driver will connect to the head (first node),
# so ray init will be invoked if do_init is true # so ray init will be invoked if do_init is true
if len(remote_nodes) == 1 and do_init: if len(remote_nodes) == 1 and do_init:
if client_test_enabled():
ray_client.ray.init(address=cluster.address)
else:
ray.init(address=cluster.address) ray.init(address=cluster.address)
yield cluster yield cluster
# The code after the yield will run as teardown code. # The code after the yield will run as teardown code.
if client_test_enabled():
ray_client.ray.disconnect()
ray_client._stop_test_server(1)
ray_client.reset_api()
ray.shutdown() ray.shutdown()
cluster.shutdown() cluster.shutdown()

View file

@ -11,15 +11,21 @@ import sys
import tempfile import tempfile
import datetime import datetime
import ray from ray.test_utils import client_test_enabled
import ray.test_utils from ray.test_utils import wait_for_condition
import ray.cluster_utils from ray.test_utils import wait_for_pid_to_exit
from ray.tests.client_test_utils import create_remote_signal_actor
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
# NOTE: We have to import setproctitle after ray because we bundle setproctitle # NOTE: We have to import setproctitle after ray because we bundle setproctitle
# with ray. # with ray.
import setproctitle import setproctitle # noqa
@pytest.mark.skipif(client_test_enabled(), reason="test setup order")
def test_caching_actors(shutdown_only): def test_caching_actors(shutdown_only):
# Test defining actors before ray.init() has been called. # Test defining actors before ray.init() has been called.
@ -238,6 +244,7 @@ def test_actor_import_counter(ray_start_10_cpus):
assert ray.get(g.remote()) == num_remote_functions - 1 assert ray.get(g.remote()) == num_remote_functions - 1
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_method_metadata_cache(ray_start_regular): def test_actor_method_metadata_cache(ray_start_regular):
class Actor(object): class Actor(object):
pass pass
@ -257,6 +264,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
assert [id(x) for x in list(cache.items())[0]] == cached_data_id assert [id(x) for x in list(cache.items())[0]] == cached_data_id
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_class_name(ray_start_regular): def test_actor_class_name(ray_start_regular):
@ray.remote @ray.remote
class Foo: class Foo:
@ -556,6 +564,7 @@ def test_actor_static_attributes(ray_start_regular_shared):
assert ray.get(t.g.remote()) == 3 assert ray.get(t.g.remote()) == 3
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_decorator_args(ray_start_regular_shared): def test_decorator_args(ray_start_regular_shared):
# This is an invalid way of using the actor decorator. # This is an invalid way of using the actor decorator.
with pytest.raises(Exception): with pytest.raises(Exception):
@ -618,6 +627,8 @@ def test_random_id_generation(ray_start_regular_shared):
assert f1._actor_id != f2._actor_id assert f1._actor_id != f2._actor_id
@pytest.mark.skipif(
client_test_enabled(), reason="differing inheritence structure")
def test_actor_inheritance(ray_start_regular_shared): def test_actor_inheritance(ray_start_regular_shared):
class NonActorBase: class NonActorBase:
def __init__(self): def __init__(self):
@ -630,8 +641,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass pass
# Test that you can't instantiate an actor class directly. # Test that you can't instantiate an actor class directly.
with pytest.raises( with pytest.raises(Exception, match="cannot be instantiated directly"):
Exception, match="Actors cannot be instantiated directly."):
ActorBase() ActorBase()
# Test that you can't inherit from an actor class. # Test that you can't inherit from an actor class.
@ -645,6 +655,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass pass
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_multiple_return_values(ray_start_regular_shared): def test_multiple_return_values(ray_start_regular_shared):
@ray.remote @ray.remote
class Foo: class Foo:
@ -678,6 +689,7 @@ def test_multiple_return_values(ray_start_regular_shared):
assert ray.get([id3a, id3b, id3c]) == [1, 2, 3] assert ray.get([id3a, id3b, id3c]) == [1, 2, 3]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_options_num_returns(ray_start_regular_shared): def test_options_num_returns(ray_start_regular_shared):
@ray.remote @ray.remote
class Foo: class Foo:
@ -693,6 +705,7 @@ def test_options_num_returns(ray_start_regular_shared):
assert ray.get([obj1, obj2]) == [1, 2] assert ray.get([obj1, obj2]) == [1, 2]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_options_name(ray_start_regular_shared): def test_options_name(ray_start_regular_shared):
@ray.remote @ray.remote
class Foo: class Foo:
@ -734,13 +747,13 @@ def test_actor_deletion(ray_start_regular_shared):
a = Actor.remote() a = Actor.remote()
pid = ray.get(a.getpid.remote()) pid = ray.get(a.getpid.remote())
a = None a = None
ray.test_utils.wait_for_pid_to_exit(pid) wait_for_pid_to_exit(pid)
actors = [Actor.remote() for _ in range(10)] actors = [Actor.remote() for _ in range(10)]
pids = ray.get([a.getpid.remote() for a in actors]) pids = ray.get([a.getpid.remote() for a in actors])
a = None a = None
actors = None actors = None
[ray.test_utils.wait_for_pid_to_exit(pid) for pid in pids] [wait_for_pid_to_exit(pid) for pid in pids]
def test_actor_method_deletion(ray_start_regular_shared): def test_actor_method_deletion(ray_start_regular_shared):
@ -769,7 +782,8 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
ray.get(signal.wait.remote()) ray.get(signal.wait.remote())
return ray.get(actor.method.remote()) return ray.get(actor.method.remote())
signal = ray.test_utils.SignalActor.remote() SignalActor = create_remote_signal_actor(ray)
signal = SignalActor.remote()
a = Actor.remote() a = Actor.remote()
pid = ray.get(a.getpid.remote()) pid = ray.get(a.getpid.remote())
# Pass the handle to another task that cannot run yet. # Pass the handle to another task that cannot run yet.
@ -780,7 +794,7 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
# Once the task finishes, the actor process should get killed. # Once the task finishes, the actor process should get killed.
ray.get(signal.send.remote()) ray.get(signal.send.remote())
assert ray.get(x_id) == 1 assert ray.get(x_id) == 1
ray.test_utils.wait_for_pid_to_exit(pid) wait_for_pid_to_exit(pid)
def test_multiple_actors(ray_start_regular_shared): def test_multiple_actors(ray_start_regular_shared):
@ -921,7 +935,7 @@ def test_atexit_handler(ray_start_regular_shared, exit_condition):
if exit_condition == "ray.kill": if exit_condition == "ray.kill":
assert not check_file_written() assert not check_file_written()
else: else:
ray.test_utils.wait_for_condition(check_file_written) wait_for_condition(check_file_written)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -8,14 +8,23 @@ import time
import numpy as np import numpy as np
import pytest import pytest
import ray
import ray.cluster_utils import ray.cluster_utils
import ray.test_utils from ray.test_utils import (
client_test_enabled,
dicts_equal,
wait_for_pid_to_exit,
)
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# https://github.com/ray-project/ray/issues/6662 # https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_ignore_http_proxy(shutdown_only): def test_ignore_http_proxy(shutdown_only):
ray.init(num_cpus=1) ray.init(num_cpus=1)
os.environ["http_proxy"] = "http://example.com" os.environ["http_proxy"] = "http://example.com"
@ -29,6 +38,7 @@ def test_ignore_http_proxy(shutdown_only):
# https://github.com/ray-project/ray/issues/7263 # https://github.com/ray-project/ray/issues/7263
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_grpc_message_size(shutdown_only): def test_grpc_message_size(shutdown_only):
ray.init(num_cpus=1) ray.init(num_cpus=1)
@ -45,12 +55,14 @@ def test_grpc_message_size(shutdown_only):
# https://github.com/ray-project/ray/issues/7287 # https://github.com/ray-project/ray/issues/7287
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_omp_threads_set(shutdown_only): def test_omp_threads_set(shutdown_only):
ray.init(num_cpus=1) ray.init(num_cpus=1)
# Should have been auto set by ray init. # Should have been auto set by ray init.
assert os.environ["OMP_NUM_THREADS"] == "1" assert os.environ["OMP_NUM_THREADS"] == "1"
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_submit_api(shutdown_only): def test_submit_api(shutdown_only):
ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
@ -109,6 +121,7 @@ def test_submit_api(shutdown_only):
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_invalid_arguments(shutdown_only): def test_invalid_arguments(shutdown_only):
ray.init(num_cpus=2) ray.init(num_cpus=2)
@ -163,6 +176,7 @@ def test_invalid_arguments(shutdown_only):
x = 1 x = 1
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_many_fractional_resources(shutdown_only): def test_many_fractional_resources(shutdown_only):
ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2}) ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})
@ -178,7 +192,7 @@ def test_many_fractional_resources(shutdown_only):
} }
if block: if block:
ray.get(g.remote()) ray.get(g.remote())
return ray.test_utils.dicts_equal(true_resources, accepted_resources) return dicts_equal(true_resources, accepted_resources)
# Check that the resource are assigned correctly. # Check that the resource are assigned correctly.
result_ids = [] result_ids = []
@ -230,6 +244,7 @@ def test_many_fractional_resources(shutdown_only):
assert False, "Did not get correct available resources." assert False, "Did not get correct available resources."
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_background_tasks_with_max_calls(shutdown_only): def test_background_tasks_with_max_calls(shutdown_only):
ray.init(num_cpus=2) ray.init(num_cpus=2)
@ -257,7 +272,7 @@ def test_background_tasks_with_max_calls(shutdown_only):
pid, g_id = nested.pop(0) pid, g_id = nested.pop(0)
ray.get(g_id) ray.get(g_id)
del g_id del g_id
ray.test_utils.wait_for_pid_to_exit(pid) wait_for_pid_to_exit(pid)
def test_fair_queueing(shutdown_only): def test_fair_queueing(shutdown_only):
@ -327,6 +342,7 @@ def test_wait_timing(shutdown_only):
assert len(not_ready) == 1 assert len(not_ready) == 1
@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet")
def test_function_descriptor(): def test_function_descriptor():
python_descriptor = ray._raylet.PythonFunctionDescriptor( python_descriptor = ray._raylet.PythonFunctionDescriptor(
"module_name", "function_name", "class_name", "function_hash") "module_name", "function_name", "class_name", "function_hash")
@ -344,6 +360,7 @@ def test_function_descriptor():
assert d.get(python_descriptor2) == 123 assert d.get(python_descriptor2) == 123
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_ray_options(shutdown_only): def test_ray_options(shutdown_only):
@ray.remote( @ray.remote(
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1}) num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
@ -371,6 +388,7 @@ def test_ray_options(shutdown_only):
assert without_options != with_options assert without_options != with_options
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster_head", [{ "ray_start_cluster_head", [{
"num_cpus": 0, "num_cpus": 0,
@ -438,8 +456,11 @@ def test_nested_functions(ray_start_shared_local_modes):
assert ray.get(factorial.remote(4)) == 24 assert ray.get(factorial.remote(4)) == 24
assert ray.get(factorial.remote(5)) == 120 assert ray.get(factorial.remote(5)) == 120
# Test remote functions that recursively call each other.
@pytest.mark.skipif(
client_test_enabled(), reason="mutual recursion is a known issue")
def test_mutually_recursive_functions(ray_start_shared_local_modes):
# Test remote functions that recursively call each other.
@ray.remote @ray.remote
def factorial_even(n): def factorial_even(n):
assert n % 2 == 0 assert n % 2 == 0
@ -710,6 +731,7 @@ def test_args_stars_after(ray_start_shared_local_modes):
ray.get(remote_test_function.remote(local_method, actor_method)) ray.get(remote_test_function.remote(local_method, actor_method))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_object_id_backward_compatibility(ray_start_shared_local_modes): def test_object_id_backward_compatibility(ray_start_shared_local_modes):
# We've renamed Python's `ObjectID` to `ObjectRef`, and added a type # We've renamed Python's `ObjectID` to `ObjectRef`, and added a type
# alias for backward compatibility. # alias for backward compatibility.

View file

@ -9,10 +9,16 @@ import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import ray
import ray.cluster_utils import ray.cluster_utils
import ray.test_utils from ray.test_utils import client_test_enabled
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.exceptions import GetTimeoutError from ray.exceptions import GetTimeoutError
from ray.exceptions import RayTaskError
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,6 +31,8 @@ logger = logging.getLogger(__name__)
}], }],
indirect=True) indirect=True)
def test_variable_number_of_args(shutdown_only): def test_variable_number_of_args(shutdown_only):
ray.init(num_cpus=1)
@ray.remote @ray.remote
def varargs_fct1(*a): def varargs_fct1(*a):
return " ".join(map(str, a)) return " ".join(map(str, a))
@ -33,8 +41,6 @@ def test_variable_number_of_args(shutdown_only):
def varargs_fct2(a, *b): def varargs_fct2(a, *b):
return " ".join(map(str, b)) return " ".join(map(str, b))
ray.init(num_cpus=1)
x = varargs_fct1.remote(0, 1, 2) x = varargs_fct1.remote(0, 1, 2)
assert ray.get(x) == "0 1 2" assert ray.get(x) == "0 1 2"
x = varargs_fct2.remote(0, 1, 2) x = varargs_fct2.remote(0, 1, 2)
@ -160,7 +166,7 @@ def test_redefining_remote_functions(shutdown_only):
def g(): def g():
return nonexistent() return nonexistent()
with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"): with pytest.raises(RayTaskError, match="nonexistent"):
ray.get(g.remote()) ray.get(g.remote())
def nonexistent(): def nonexistent():
@ -187,6 +193,7 @@ def test_redefining_remote_functions(shutdown_only):
assert ray.get(ray.get(h.remote(i))) == i assert ray.get(ray.get(h.remote(i))) == i
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_call_matrix(shutdown_only): def test_call_matrix(shutdown_only):
ray.init(object_store_memory=1000 * 1024 * 1024) ray.init(object_store_memory=1000 * 1024 * 1024)
@ -312,6 +319,7 @@ def test_actor_pass_by_ref_order_optimization(shutdown_only):
assert delta < 10, "did not skip slow value" assert delta < 10, "did not skip slow value"
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster", [{ "ray_start_cluster", [{
"num_cpus": 1, "num_cpus": 1,
@ -332,6 +340,7 @@ def test_call_chain(ray_start_cluster):
assert ray.get(x) == 100 assert ray.get(x) == 100
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_system_config_when_connecting(ray_start_cluster): def test_system_config_when_connecting(ray_start_cluster):
config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200} config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200}
cluster = ray.cluster_utils.Cluster() cluster = ray.cluster_utils.Cluster()
@ -368,7 +377,8 @@ def test_get_multiple(ray_start_regular_shared):
def test_get_with_timeout(ray_start_regular_shared): def test_get_with_timeout(ray_start_regular_shared):
signal = ray.test_utils.SignalActor.remote() SignalActor = create_remote_signal_actor(ray)
signal = SignalActor.remote()
# Check that get() returns early if object is ready. # Check that get() returns early if object is ready.
start = time.time() start = time.time()
@ -438,6 +448,7 @@ def test_inline_arg_memory_corruption(ray_start_regular_shared):
ray.get(a.add.remote(f.remote())) ray.get(a.add.remote(f.remote()))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_skip_plasma(ray_start_regular_shared): def test_skip_plasma(ray_start_regular_shared):
@ray.remote @ray.remote
class Actor: class Actor:
@ -454,6 +465,8 @@ def test_skip_plasma(ray_start_regular_shared):
assert ray.get(obj_ref) == 2 assert ray.get(obj_ref) == 2
@pytest.mark.skipif(
client_test_enabled(), reason="internal api and message size")
def test_actor_large_objects(ray_start_regular_shared): def test_actor_large_objects(ray_start_regular_shared):
@ray.remote @ray.remote
class Actor: class Actor:
@ -524,6 +537,7 @@ def test_actor_recursive(ray_start_regular_shared):
assert result == [x * 2 for x in range(100)] assert result == [x * 2 for x in range(100)]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_actor_concurrent(ray_start_regular_shared): def test_actor_concurrent(ray_start_regular_shared):
@ray.remote @ray.remote
class Batcher: class Batcher:
@ -626,6 +640,7 @@ def test_duplicate_args(ray_start_regular_shared):
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1)) arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_get_correct_node_ip(): def test_get_correct_node_ip():
with patch("ray.worker") as worker_mock: with patch("ray.worker") as worker_mock:
node_mock = MagicMock() node_mock = MagicMock()

View file

@ -81,11 +81,11 @@ def test_wait(ray_start_regular_shared):
with pytest.raises(Exception): with pytest.raises(Exception):
# Reference not in the object store. # Reference not in the object store.
ray.wait([ClientObjectRef("blabla")]) ray.wait([ClientObjectRef("blabla")])
with pytest.raises(AssertionError): with pytest.raises(TypeError):
ray.wait("blabla") ray.wait("blabla")
with pytest.raises(AssertionError): with pytest.raises(TypeError):
ray.wait(ClientObjectRef("blabla")) ray.wait(ClientObjectRef("blabla"))
with pytest.raises(AssertionError): with pytest.raises(TypeError):
ray.wait(["blabla"]) ray.wait(["blabla"])

View file

@ -1,6 +1,6 @@
import pytest import pytest
import asyncio
from ray.tests.test_experimental_client import ray_start_client_server from ray.tests.test_experimental_client import ray_start_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.test_utils import wait_for_condition from ray.test_utils import wait_for_condition
from ray.exceptions import TaskCancelledError from ray.exceptions import TaskCancelledError
from ray.exceptions import RayTaskError from ray.exceptions import RayTaskError
@ -45,21 +45,7 @@ def test_kill_actor_immediately_after_creation(ray_start_regular):
@pytest.mark.parametrize("use_force", [True, False]) @pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force): def test_cancel_chain(ray_start_regular, use_force):
with ray_start_client_server() as ray: with ray_start_client_server() as ray:
SignalActor = create_remote_signal_actor(ray)
@ray.remote
class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait:
await self.ready_event.wait()
signaler = SignalActor.remote() signaler = SignalActor.remote()
@ray.remote @ray.remote

View file

@ -50,16 +50,22 @@ message ClientTask {
string name = 2; string name = 2;
// A reference to the payload. // A reference to the payload.
bytes payload_id = 3; bytes payload_id = 3;
// The parameters to pass to this call. // Positional parameters to pass to this call.
repeated Arg args = 4; repeated Arg args = 4;
// Keyword parameters to pass to this call.
map<string, Arg> kwargs = 5;
// The ID of the client namespace associated with the Datapath stream making this // The ID of the client namespace associated with the Datapath stream making this
// request. // request.
string client_id = 5; string client_id = 6;
} }
message ClientTaskTicket { message ClientTaskTicket {
// Was the task successful?
bool valid = 1;
// A reference to the returned value from the execution. // A reference to the returned value from the execution.
bytes return_id = 1; bytes return_id = 2;
// If unsuccessful, an encoding of the error.
bytes error = 3;
} }
// Delivers data to the server // Delivers data to the server