mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Support retrieval of named actor handles (#13000)
Change-Id: I05d31c9c67943d2a0230782cbdaa98341584cbc7
This commit is contained in:
parent
80f6dd16b2
commit
e715ade2d1
6 changed files with 62 additions and 5 deletions
|
@ -197,6 +197,9 @@ class ClientAPI(APIImpl):
|
|||
def close(self) -> None:
|
||||
return self.worker.close()
|
||||
|
||||
def get_actor(self, name: str) -> "ClientActorHandle":
|
||||
return self.worker.get_actor(name)
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
|
|
|
@ -144,7 +144,7 @@ class ClientActorClass(ClientStub):
|
|||
# Actually instantiate the actor
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]))
|
||||
|
||||
def options(self, **kwargs):
|
||||
return ActorOptionWrapper(self, kwargs)
|
||||
|
@ -186,8 +186,7 @@ class ClientActorHandle(ClientStub):
|
|||
ray.actor.ActorHandle contained in the actor_id ref.
|
||||
"""
|
||||
|
||||
def __init__(self, actor_ref: ClientActorRef,
|
||||
actor_class: ClientActorClass):
|
||||
def __init__(self, actor_ref: ClientActorRef):
|
||||
self.actor_ref = actor_ref
|
||||
|
||||
def __del__(self) -> None:
|
||||
|
@ -266,7 +265,7 @@ class ActorOptionWrapper(OptionWrapper):
|
|||
def remote(self, *args, **kwargs):
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]))
|
||||
|
||||
|
||||
def set_task_options(task: ray_client_pb2.ClientTask,
|
||||
|
|
|
@ -222,6 +222,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
result = self._schedule_actor(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
result = self._schedule_method(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
|
||||
result = self._schedule_named_actor(task, context)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
|
@ -281,6 +283,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
ids = self.unify_and_track_outputs(output, task.client_id)
|
||||
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
|
||||
|
||||
def _schedule_named_actor(self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
assert len(task.payload_id) == 0
|
||||
actor = ray.get_actor(task.name)
|
||||
self.actor_refs[actor._actor_id.binary()] = actor
|
||||
self.actor_owners[task.client_id].add(actor._actor_id.binary())
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ids=[actor._actor_id.binary()])
|
||||
|
||||
def _convert_args(self, arg_list, kwarg_map):
|
||||
argout = []
|
||||
for arg in arg_list:
|
||||
|
|
|
@ -25,6 +25,7 @@ from ray.experimental.client.client_pickler import dumps_from_client
|
|||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientStub
|
||||
|
@ -169,8 +170,12 @@ class Worker:
|
|||
task.args.append(pb_arg)
|
||||
for k, v in kwargs.items():
|
||||
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
|
||||
task.client_id = self._client_id
|
||||
return self._call_schedule_for_task(task)
|
||||
|
||||
def _call_schedule_for_task(
|
||||
self, task: ray_client_pb2.ClientTask) -> List[bytes]:
|
||||
logger.debug("Scheduling %s" % task)
|
||||
task.client_id = self._client_id
|
||||
try:
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
except grpc.RpcError as e:
|
||||
|
@ -201,6 +206,14 @@ class Worker:
|
|||
if self.channel:
|
||||
self.channel = None
|
||||
|
||||
def get_actor(self, name: str) -> ClientActorHandle:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
|
||||
task.name = name
|
||||
ids = self._call_schedule_for_task(task)
|
||||
assert len(ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ids[0]))
|
||||
|
||||
def terminate_actor(self, actor: ClientActorHandle,
|
||||
no_restart: bool) -> None:
|
||||
if not isinstance(actor, ClientActorHandle):
|
||||
|
|
|
@ -234,6 +234,35 @@ def test_pass_handles(ray_start_regular_shared):
|
|||
4)) == local_fact(4)
|
||||
|
||||
|
||||
def test_basic_named_actor(ray_start_regular_shared):
|
||||
"""
|
||||
Test that ray.get_actor() can create and return a detached actor.
|
||||
"""
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
self.x = 0
|
||||
|
||||
def inc(self):
|
||||
self.x += 1
|
||||
|
||||
def get(self):
|
||||
return self.x
|
||||
|
||||
# Create the actor
|
||||
actor = Accumulator.options(name="test_acc").remote()
|
||||
|
||||
actor.inc.remote()
|
||||
actor.inc.remote()
|
||||
del actor
|
||||
|
||||
new_actor = ray.get_actor("test_acc")
|
||||
new_actor.inc.remote()
|
||||
assert ray.get(new_actor.get.remote()) == 3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
|
@ -54,6 +54,7 @@ message ClientTask {
|
|||
ACTOR = 1;
|
||||
METHOD = 2;
|
||||
STATIC_METHOD = 3;
|
||||
NAMED_ACTOR = 4;
|
||||
}
|
||||
// Which type of work this request represents.
|
||||
RemoteExecType type = 1;
|
||||
|
|
Loading…
Add table
Reference in a new issue