Support retrieval of named actor handles (#13000)

Change-Id: I05d31c9c67943d2a0230782cbdaa98341584cbc7
This commit is contained in:
Barak Michener 2020-12-20 16:34:50 -08:00 committed by GitHub
parent 80f6dd16b2
commit e715ade2d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 5 deletions

View file

@ -197,6 +197,9 @@ class ClientAPI(APIImpl):
def close(self) -> None:
return self.worker.close()
def get_actor(self, name: str) -> "ClientActorHandle":
return self.worker.get_actor(name)
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
return self.worker.terminate_actor(actor, no_restart)

View file

@ -144,7 +144,7 @@ class ClientActorClass(ClientStub):
# Actually instantiate the actor
ref_ids = ray.call_remote(self, *args, **kwargs)
assert len(ref_ids) == 1
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
return ClientActorHandle(ClientActorRef(ref_ids[0]))
def options(self, **kwargs):
return ActorOptionWrapper(self, kwargs)
@ -186,8 +186,7 @@ class ClientActorHandle(ClientStub):
ray.actor.ActorHandle contained in the actor_id ref.
"""
def __init__(self, actor_ref: ClientActorRef,
actor_class: ClientActorClass):
def __init__(self, actor_ref: ClientActorRef):
self.actor_ref = actor_ref
def __del__(self) -> None:
@ -266,7 +265,7 @@ class ActorOptionWrapper(OptionWrapper):
def remote(self, *args, **kwargs):
ref_ids = ray.call_remote(self, *args, **kwargs)
assert len(ref_ids) == 1
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
return ClientActorHandle(ClientActorRef(ref_ids[0]))
def set_task_options(task: ray_client_pb2.ClientTask,

View file

@ -222,6 +222,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
result = self._schedule_actor(task, context)
elif task.type == ray_client_pb2.ClientTask.METHOD:
result = self._schedule_method(task, context)
elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
result = self._schedule_named_actor(task, context)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
@ -281,6 +283,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ids = self.unify_and_track_outputs(output, task.client_id)
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
def _schedule_named_actor(self,
task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
assert len(task.payload_id) == 0
actor = ray.get_actor(task.name)
self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket(
return_ids=[actor._actor_id.binary()])
def _convert_args(self, arg_list, kwarg_map):
argout = []
for arg in arg_list:

View file

@ -25,6 +25,7 @@ from ray.experimental.client.client_pickler import dumps_from_client
from ray.experimental.client.client_pickler import loads_from_server
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientStub
@ -169,8 +170,12 @@ class Worker:
task.args.append(pb_arg)
for k, v in kwargs.items():
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
task.client_id = self._client_id
return self._call_schedule_for_task(task)
def _call_schedule_for_task(
self, task: ray_client_pb2.ClientTask) -> List[bytes]:
logger.debug("Scheduling %s" % task)
task.client_id = self._client_id
try:
ticket = self.server.Schedule(task, metadata=self.metadata)
except grpc.RpcError as e:
@ -201,6 +206,14 @@ class Worker:
if self.channel:
self.channel = None
def get_actor(self, name: str) -> ClientActorHandle:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
task.name = name
ids = self._call_schedule_for_task(task)
assert len(ids) == 1
return ClientActorHandle(ClientActorRef(ids[0]))
def terminate_actor(self, actor: ClientActorHandle,
no_restart: bool) -> None:
if not isinstance(actor, ClientActorHandle):

View file

@ -234,6 +234,35 @@ def test_pass_handles(ray_start_regular_shared):
4)) == local_fact(4)
def test_basic_named_actor(ray_start_regular_shared):
"""
Test that ray.get_actor() can create and return a detached actor.
"""
with ray_start_client_server() as ray:
@ray.remote
class Accumulator:
def __init__(self):
self.x = 0
def inc(self):
self.x += 1
def get(self):
return self.x
# Create the actor
actor = Accumulator.options(name="test_acc").remote()
actor.inc.remote()
actor.inc.remote()
del actor
new_actor = ray.get_actor("test_acc")
new_actor.inc.remote()
assert ray.get(new_actor.get.remote()) == 3
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -54,6 +54,7 @@ message ClientTask {
ACTOR = 1;
METHOD = 2;
STATIC_METHOD = 3;
NAMED_ACTOR = 4;
}
// Which type of work this request represents.
RemoteExecType type = 1;