From e715ade2d1026d1ca381b463fbb482c30acd3d93 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Sun, 20 Dec 2020 16:34:50 -0800 Subject: [PATCH] Support retrieval of named actor handles (#13000) Change-Id: I05d31c9c67943d2a0230782cbdaa98341584cbc7 --- python/ray/experimental/client/api.py | 3 ++ python/ray/experimental/client/common.py | 7 ++--- .../ray/experimental/client/server/server.py | 12 ++++++++ python/ray/experimental/client/worker.py | 15 +++++++++- python/ray/tests/test_experimental_client.py | 29 +++++++++++++++++++ src/ray/protobuf/ray_client.proto | 1 + 6 files changed, 62 insertions(+), 5 deletions(-) diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 5167e5988..93da6382f 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -197,6 +197,9 @@ class ClientAPI(APIImpl): def close(self) -> None: return self.worker.close() + def get_actor(self, name: str) -> "ClientActorHandle": + return self.worker.get_actor(name) + def kill(self, actor: "ClientActorHandle", *, no_restart=True): return self.worker.terminate_actor(actor, no_restart) diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 49eee05d6..f68b26e2c 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -144,7 +144,7 @@ class ClientActorClass(ClientStub): # Actually instantiate the actor ref_ids = ray.call_remote(self, *args, **kwargs) assert len(ref_ids) == 1 - return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + return ClientActorHandle(ClientActorRef(ref_ids[0])) def options(self, **kwargs): return ActorOptionWrapper(self, kwargs) @@ -186,8 +186,7 @@ class ClientActorHandle(ClientStub): ray.actor.ActorHandle contained in the actor_id ref. """ - def __init__(self, actor_ref: ClientActorRef, - actor_class: ClientActorClass): + def __init__(self, actor_ref: ClientActorRef): self.actor_ref = actor_ref def __del__(self) -> None: @@ -266,7 +265,7 @@ class ActorOptionWrapper(OptionWrapper): def remote(self, *args, **kwargs): ref_ids = ray.call_remote(self, *args, **kwargs) assert len(ref_ids) == 1 - return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + return ClientActorHandle(ClientActorRef(ref_ids[0])) def set_task_options(task: ray_client_pb2.ClientTask, diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 5f86ddee2..442cf1afa 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -222,6 +222,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): result = self._schedule_actor(task, context) elif task.type == ray_client_pb2.ClientTask.METHOD: result = self._schedule_method(task, context) + elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR: + result = self._schedule_named_actor(task, context) else: raise NotImplementedError( "Unimplemented Schedule task type: %s" % @@ -281,6 +283,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ids = self.unify_and_track_outputs(output, task.client_id) return ray_client_pb2.ClientTaskTicket(return_ids=ids) + def _schedule_named_actor(self, + task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: + assert len(task.payload_id) == 0 + actor = ray.get_actor(task.name) + self.actor_refs[actor._actor_id.binary()] = actor + self.actor_owners[task.client_id].add(actor._actor_id.binary()) + return ray_client_pb2.ClientTaskTicket( + return_ids=[actor._actor_id.binary()]) + def _convert_args(self, arg_list, kwarg_map): argout = [] for arg in arg_list: diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index d2ba52d62..bba23584b 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -25,6 +25,7 @@ from ray.experimental.client.client_pickler import dumps_from_client from ray.experimental.client.client_pickler import loads_from_server from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle +from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientStub @@ -169,8 +170,12 @@ class Worker: task.args.append(pb_arg) for k, v in kwargs.items(): task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id)) - task.client_id = self._client_id + return self._call_schedule_for_task(task) + + def _call_schedule_for_task( + self, task: ray_client_pb2.ClientTask) -> List[bytes]: logger.debug("Scheduling %s" % task) + task.client_id = self._client_id try: ticket = self.server.Schedule(task, metadata=self.metadata) except grpc.RpcError as e: @@ -201,6 +206,14 @@ class Worker: if self.channel: self.channel = None + def get_actor(self, name: str) -> ClientActorHandle: + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.NAMED_ACTOR + task.name = name + ids = self._call_schedule_for_task(task) + assert len(ids) == 1 + return ClientActorHandle(ClientActorRef(ids[0])) + def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None: if not isinstance(actor, ClientActorHandle): diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index e68abb366..cc15e7272 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -234,6 +234,35 @@ def test_pass_handles(ray_start_regular_shared): 4)) == local_fact(4) +def test_basic_named_actor(ray_start_regular_shared): + """ + Test that ray.get_actor() can create and return a detached actor. + """ + with ray_start_client_server() as ray: + + @ray.remote + class Accumulator: + def __init__(self): + self.x = 0 + + def inc(self): + self.x += 1 + + def get(self): + return self.x + + # Create the actor + actor = Accumulator.options(name="test_acc").remote() + + actor.inc.remote() + actor.inc.remote() + del actor + + new_actor = ray.get_actor("test_acc") + new_actor.inc.remote() + assert ray.get(new_actor.get.remote()) == 3 + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index cbd6679dd..a566f8031 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -54,6 +54,7 @@ message ClientTask { ACTOR = 1; METHOD = 2; STATIC_METHOD = 3; + NAMED_ACTOR = 4; } // Which type of work this request represents. RemoteExecType type = 1;