[ray_client] Passing actors to actors (#12585)

* start building tests around passing handles to handles

Change-Id: Ie8c3de5c8ce789c3ec8d29f0702df80ba598279f

* clean up the switch statements by moving to a method, implement state transfer, extend test

Change-Id: Ie7b6493db3a6c203d3a0b262b8fbacb90e5cdbc5

* passing

Change-Id: Id88dc0a41da1c9d5ba68f754c5b57141aae47beb

* flush out tests

Change-Id: If77c0f586e9e99449d494be4e85f854e4a7a4952

* formatting

Change-Id: I497c07cee70b52453b221ed4393f04f6f560061e

* fix python3.6 and other attributes

Change-Id: I5a2c5231e8a021184d9dfc3e346df7f71fc93257

* address documentation

Change-Id: I049d841ed1f85b7350c17c05da4a4d81d5cb03df

* formatting

Change-Id: I6a2b32a2466ffc9f03fc91ac17901b9c1a49505c

* use the pickled handle as the id bytes for actors

Change-Id: I9ddcb41d614de65d42d6f0382fe0faa7ad2c2ade

* pydoc

Change-Id: I9b32a0f383d5ff5ac052e61929b7ae3e42a89fc5

* format

Change-Id: Iac0010bb990a4025a98139ab88700030b2e9e7f5

* todos

Change-Id: I7b550800cf7499403e8a17b77484bc46f20f0afc

* tests

Change-Id: If8ebf6a335baeb113c1332acc930c41a6b4f5384

* fix lint

Change-Id: I019f41e0ec341d39bbbbd39aa43d9fb5f8b57cf0

* nits

Change-Id: I2e6813d8db34f4ce008326faa095d414c10eee95

* add some tricky, python3.6-troublesome type checking

Change-Id: Ib887fc943a6e7084002bc13dfbe113b69b4d9317
This commit is contained in:
Barak Michener 2020-12-08 21:54:55 -08:00 committed by GitHub
parent d534719af6
commit dc4b5c7aa3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 491 additions and 130 deletions

View file

@ -7,34 +7,88 @@ import logging
logger = logging.getLogger(__name__)
# _client_api has to be external to the API stub, below.
# Otherwise, ray.remote() that contains ray.remote()
# contains a reference to the RayAPIStub, therefore a
# reference to the _client_api, and then tries to pickle
# the thing.
# About these global variables: Ray 1.0 uses exported module functions to
# provide its API, and we need to match that. However, we want different
# behaviors depending on where, exactly, in the client stack this is running.
#
# The reason for these differences depends on what's being pickled and passed
# to functions, or functions inside functions. So there are three cases to care
# about
#
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
#
# * _client_api should be set if we're inside the client
# * _server_api should be set if we're inside the clientserver
# * Both will be set if we're running both (as in a test)
# * Neither should be set if we're inside the raylet (but we still need to shim
# from the client API surface to the Ray API)
#
# The job of RayAPIStub (below) is to delegate to the appropriate one of these
# depending on which are set. All users importing the ray object from this
# package get the stub, which routes them to the appropriate APIImpl.
_client_api: Optional[APIImpl] = None
_server_api: Optional[APIImpl] = None
# _is_server is a hack for the test case described above, where both
# _client_api and _server_api are set in the same process. With a client and
# a server both trying to control these module-level variables, we need a way
# to decide which one to use.
# This boolean flips between the two.
_is_server: bool = False
@contextmanager
def stash_api_for_tests(in_test: bool):
    """Temporarily select the server-side API while running in test mode.

    When ``in_test`` is true, ``_is_server`` is forced to ``True`` for the
    duration of the ``with`` body and put back to its previous value
    afterwards. When ``in_test`` is false the flag is left untouched.

    Args:
        in_test: Whether the flag should be flipped for the ``with`` body.

    Yields:
        The current module-level ``_server_api`` (regardless of ``in_test``).
    """
    global _is_server
    previous = _is_server
    if in_test:
        _is_server = True
    try:
        yield _server_api
    finally:
        # Restore even when the body raises; otherwise a failing call would
        # leave the module stuck in server mode for every later caller.
        if in_test:
            _is_server = previous
def stash_api() -> Optional[APIImpl]:
def _set_client_api(val: Optional[APIImpl]):
global _client_api
a = _client_api
global _is_server
if _client_api is not None:
raise Exception("Trying to set more than one client API")
_client_api = val
_is_server = False
def _set_server_api(val: Optional[APIImpl]):
    """Install ``val`` as the module-level server API and switch to it.

    Args:
        val: The API implementation to store in ``_server_api``.

    Raises:
        Exception: if a server API is already installed.
    """
    global _server_api, _is_server
    already_installed = _server_api is not None
    if already_installed:
        raise Exception("Trying to set more than one server API")
    # Installing a server API also makes it the active one.
    _server_api, _is_server = val, True
def reset_api():
global _client_api
global _server_api
global _is_server
_client_api = None
return a
_server_api = None
_is_server = False
def restore_api(api: Optional[APIImpl]):
def _get_client_api() -> APIImpl:
global _client_api
_client_api = api
global _server_api
global _is_server
api = None
if _is_server:
api = _server_api
else:
api = _client_api
if api is None:
# We're inside a raylet worker
from ray.experimental.client.server.core_ray_api import CoreRayAPI
return CoreRayAPI()
return api
class RayAPIStub:
@ -43,11 +97,10 @@ class RayAPIStub:
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
global _client_api
from ray.experimental.client.worker import Worker
_client_worker = Worker(
conn_str, secure=secure, metadata=metadata, stub=stub)
_client_api = ClientAPI(_client_worker)
_set_client_api(ClientAPI(_client_worker))
def disconnect(self):
global _client_api
@ -56,15 +109,9 @@ class RayAPIStub:
_client_api = None
def __getattr__(self, key: str):
global _client_api
self.__check_client_api()
return getattr(_client_api, key)
def __check_client_api(self):
global _client_api
if _client_api is None:
from ray.experimental.client.server.core_ray_api import CoreRayAPI
_client_api = CoreRayAPI()
global _get_client_api
api = _get_client_api()
return getattr(api, key)
ray = RayAPIStub()

View file

@ -11,35 +11,105 @@
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
from ray.experimental.client.common import ClientStub
from ray.experimental.client.common import ClientObjectRef
from ray._raylet import ObjectRef
# Use the imports for type checking. This is a python 3.6 limitation.
# See https://www.python.org/dev/peps/pep-0563/
PutType = Union[ClientObjectRef, ObjectRef]
class APIImpl(ABC):
"""
APIImpl is the interface to implement for whichever version of the core
Ray API that needs abstracting when run in client mode.
"""
@abstractmethod
def get(self, *args, **kwargs):
def get(self, *args, **kwargs) -> Any:
"""
get is the hook stub passed on to replace `ray.get`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
@abstractmethod
def put(self, *args, **kwargs):
def put(self, vals: Any, *args,
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
"""
put is the hook stub passed on to replace `ray.put`
Args:
vals: The value or list of values to `put`.
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
@abstractmethod
def wait(self, *args, **kwargs):
"""
wait is the hook stub passed on to replace `ray.wait`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
@abstractmethod
def remote(self, *args, **kwargs):
"""
remote is the hook stub passed on to replace `ray.remote`.
This sets up remote functions or actors, as the decorator,
but does not execute them.
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
@abstractmethod
def call_remote(self, f, kind, *args, **kwargs):
def call_remote(self, instance: "ClientStub", *args, **kwargs):
"""
call_remote is called by stub objects to execute them remotely.
This is used by stub objects in situations where they're called
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
This allows the client stub objects to delegate execution to be
implemented in the most effective way whether it's in the client,
clientserver, or raylet worker.
Args:
instance: The Client-side stub reference to a remote object
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
@abstractmethod
def close(self, *args, **kwargs):
def close(self) -> None:
"""
close cleans up an API connection by closing any channels or
shutting down any servers gracefully.
"""
pass
class ClientAPI(APIImpl):
"""
The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
"""
def __init__(self, worker):
self.worker = worker
@ -55,10 +125,10 @@ class ClientAPI(APIImpl):
def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs)
def call_remote(self, f, kind, *args, **kwargs):
return self.worker.call_remote(f, kind, *args, **kwargs)
def call_remote(self, instance: "ClientStub", *args, **kwargs):
return self.worker.call_remote(instance, *args, **kwargs)
def close(self, *args, **kwargs):
def close(self) -> None:
return self.worker.close()
def __getattr__(self, key: str):

View file

@ -1,6 +1,7 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray
from typing import Any
from typing import Dict
from ray import cloudpickle
@ -17,6 +18,9 @@ class ClientBaseRef:
def __eq__(self, other):
return self.id == other.id
def binary(self):
return self.id
class ClientObjectRef(ClientBaseRef):
pass
@ -26,74 +30,222 @@ class ClientActorRef(ClientBaseRef):
pass
class ClientRemoteFunc:
class ClientStub:
pass
class ClientRemoteFunc(ClientStub):
"""
A stub created on the Ray Client to represent a remote
function that can be executed on the cluster.
This class is allowed to be passed around between remote functions.
Args:
_func: The actual function to execute remotely
_name: The original name of the function
_ref: The ClientObjectRef of the pickled code of the function, _func
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
for this object
"""
def __init__(self, f):
self._func = f
self._name = f.__name__
self.id = None
self._raylet_remote_func = None
# self._ref can be lazily instantiated. Rather than eagerly creating
# function data objects in the server we can put them just before we
# execute the function, especially in cases where many @ray.remote
# functions exist in a library and only a handful are ever executed by
# a user of the library.
#
# TODO(barakmich): This ref might actually be better as a serialized
# ObjectRef. This requires being able to serialize the ref without
# pinning it (as the lifetime of the ref is tied with the server, not
# the client)
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs):
    """Reject direct invocation of the remote-function stub.

    Args:
        args: ignored.
        kwargs: ignored.

    Raises:
        TypeError: always; the message names the ``.remote`` method to use.
    """
    # The f-prefix has to sit on the literal that holds the placeholder;
    # with it on the first fragment only, "{self._name}" was emitted
    # verbatim instead of the function name.
    raise TypeError("Remote function cannot be called directly. "
                    f"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
**kwargs)
return ray.call_remote(self, *args, **kwargs)
def _get_ray_remote_impl(self):
if self._raylet_remote is None:
self._raylet_remote = ray.remote(self._func)
return self._raylet_remote
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self._func)
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.id
return task
class ClientActorClass:
class ClientActorClass(ClientStub):
""" A stub created on the Ray Client to represent an actor class.
It is wrapped by ray.remote and can be executed on the cluster.
Args:
actor_cls: The actual class to execute remotely
_name: The original name of the class
_ref: The ClientObjectRef of the pickled `actor_cls`
_raylet_remote: The Raylet-side ray.ActorClass for this object
"""
def __init__(self, actor_cls):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs):
    """Reject direct instantiation of the actor-class stub.

    Args:
        args: ignored.
        kwargs: ignored.

    Raises:
        TypeError: always; the message names the ``.remote()`` call to use.
    """
    # The f-prefix has to sit on the literal that holds the placeholder;
    # with it on the first fragment only, "{self._name}" was emitted
    # verbatim instead of the class name.
    raise TypeError("Remote actor cannot be instantiated directly. "
                    f"Use {self._name}.remote() instead")
def __getstate__(self) -> Dict:
state = {
"actor_cls": self.actor_cls,
"_name": self._name,
"_ref": self._ref,
}
return state
def __setstate__(self, state: Dict) -> None:
    """Rebuild the stub from the dictionary produced by ``__getstate__``.

    Args:
        state: mapping with the keys ``actor_cls``, ``_name`` and ``_ref``.
    """
    self.actor_cls = state["actor_cls"]
    self._name = state["_name"]
    self._ref = state["_ref"]
    # __getstate__ does not carry the cached raylet-side class, so recreate
    # the empty slot that __init__ sets up. Without it, reading the
    # attribute would fall through to __getattr__ and raise AttributeError.
    self._raylet_remote = None
def remote(self, *args, **kwargs):
# Actually instantiate the actor
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
**kwargs)
return ClientActorHandle(ref, self)
ref = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref.id), self)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self.id)
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key):
if key not in self.__dict__:
raise AttributeError("Not a class attribute")
raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self.actor_cls)
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.id
return task
class ClientActorHandle:
def __init__(self, actor_id: ClientActorRef,
class ClientActorHandle(ClientStub):
"""Client-side stub for instantiated actor.
A stub created on the Ray Client to represent a remote actor that
has been started on the cluster. This class is allowed to be passed
around between remote functions.
Args:
actor_ref: A reference to the running actor given to the client. This
is a serialized version of the actual handle as an opaque token.
actor_class: A reference to the ClientActorClass that this actor was
instantiated from.
_real_actor_handle: Cached copy of the Raylet-side
ray.actor.ActorHandle contained in the actor_id ref.
"""
def __init__(self, actor_ref: ClientActorRef,
actor_class: ClientActorClass):
self.actor_id = actor_id
self.actor_ref = actor_ref
self.actor_class = actor_class
self._real_actor_handle = None
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = cloudpickle.loads(self.actor_ref.id)
return self._real_actor_handle
def __getstate__(self) -> Dict:
state = {
"actor_ref": self.actor_ref,
"actor_class": self.actor_class,
"_real_actor_handle": self._real_actor_handle,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_ref = state["actor_ref"]
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
def __getattr__(self, key):
return ClientRemoteMethod(self, key)
def __repr__(self):
return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())
class ClientRemoteMethod(ClientStub):
"""A stub for a method on a remote actor.
Can be annotated with execution options.
Args:
actor_handle: A reference to the ClientActorHandle that generated
this method and will have this method called upon it.
method_name: The name of this method
"""
class ClientRemoteMethod:
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
self.actor_handle = actor_handle
self.method_name = method_name
self._name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
def __call__(self, *args, **kwargs):
    """Reject direct invocation of the remote-method stub.

    Args:
        args: ignored.
        kwargs: ignored.

    Raises:
        TypeError: always; the message names the ``.remote()`` call to use.
    """
    # Build the display name from the fields that __getstate__ carries, so
    # the message also works on an unpickled stub (where the cached
    # ``_name`` set by __init__ is not restored).
    name = "%s.%s" % (self.actor_handle.actor_class._name,
                      self.method_name)
    # The f-prefix has to sit on the literal that holds the placeholder;
    # previously the placeholder text was emitted verbatim.
    raise TypeError("Remote method cannot be called directly. "
                    f"Use {name}.remote() instead")
def _get_ray_remote_impl(self):
return getattr(self.actor_handle._get_ray_remote_impl(),
self.method_name)
def __getstate__(self) -> Dict:
state = {
"actor_handle": self.actor_handle,
"method_name": self.method_name,
}
return state
def __setstate__(self, state: Dict) -> None:
    """Rebuild the stub from the dictionary produced by ``__getstate__``.

    Args:
        state: mapping with the keys ``actor_handle`` and ``method_name``.
    """
    self.actor_handle = state["actor_handle"]
    self.method_name = state["method_name"]
    # __init__ also caches a "<class>.<method>" display name; recompute it
    # here so an unpickled stub has the same attributes as a fresh one.
    self._name = "%s.%s" % (self.actor_handle.actor_class._name,
                            self.method_name)
def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
**kwargs)
return ray.call_remote(self, *args, **kwargs)
def __repr__(self):
    """Return ``ClientRemoteMethod(<class>.<method>, <actor ref>)``."""
    name = "%s.%s" % (self.actor_handle.actor_class._name,
                      self.method_name)
    # The handle stores its reference as ``actor_ref``. Asking it for
    # ``actor_id`` would go through the handle's __getattr__, which returns
    # another ClientRemoteMethod and makes this repr recurse without end.
    return "ClientRemoteMethod(%s, %s)" % (name,
                                           self.actor_handle.actor_ref)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_ref.id
return task
def convert_from_arg(pb) -> Any:

View file

@ -7,18 +7,29 @@
# While the stub is trivial, it allows us to check that the calls we're
# making into the core-ray module are contained and well-defined.
from typing import Any
from typing import Union
import ray
from ray.experimental.client.api import APIImpl
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientStub
class CoreRayAPI(APIImpl):
"""
Implements the equivalent client-side Ray API by simply passing along to
the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
to core ray when passed client stubs.
"""
def get(self, *args, **kwargs):
return ray.get(*args, **kwargs)
def put(self, *args, **kwargs):
return ray.put(*args, **kwargs)
def put(self, vals: Any, *args,
**kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
return ray.put(vals, *args, **kwargs)
def wait(self, *args, **kwargs):
return ray.wait(*args, **kwargs)
@ -26,12 +37,10 @@ class CoreRayAPI(APIImpl):
def remote(self, *args, **kwargs):
return ray.remote(*args, **kwargs)
def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs):
if f._raylet_remote_func is None:
f._raylet_remote_func = ray.remote(f._func)
return f._raylet_remote_func.remote(*args, **kwargs)
def call_remote(self, instance: ClientStub, *args, **kwargs):
return instance._get_ray_remote_impl().remote(*args, **kwargs)
def close(self, *args, **kwargs):
def close(self) -> None:
return None
# Allow for generic fallback to ray.* in remote methods. This allows calls
@ -39,3 +48,38 @@ class CoreRayAPI(APIImpl):
# doesn't currently support them.
def __getattr__(self, key: str):
return getattr(ray, key)
class RayServerAPI(CoreRayAPI):
    """
    Server-side shim for the Ray Client API.

    Inherits the CoreRayAPI behaviour and overrides `put` and `call_remote`
    so that they go through the server instance given at construction time.
    """

    def __init__(self, server_instance):
        self.server = server_instance

    def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
        """
        Store one value, or each element of a list, through the server.

        Args:
            vals: The value or list of values to `put`.
            args: opaque arguments (not used here)
            kwargs: opaque keyword arguments (not used here)

        Returns:
            A list of refs when `vals` is a list, otherwise a single ref.
        """
        if isinstance(vals, list):
            return [self._put(item) for item in vals]
        return self._put(vals)

    def _put(self, val: Any):
        # Hand the value to the server and wrap the id of its response.
        stored = self.server._put_and_retain_obj(val)
        return ClientObjectRef(stored.id)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """
        Schedule the stub's task directly on the server.

        The positional arguments are handed over as `prepared_args`;
        keyword arguments are accepted but not forwarded.
        """
        scheduled = self.server.Schedule(
            instance._prepare_client_task(), prepared_args=args)
        return ClientObjectRef(scheduled.return_id)

View file

@ -7,10 +7,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import inspect
from ray.experimental.client import stash_api_for_tests
from ray.experimental.client import stash_api_for_tests, _set_server_api
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.server.core_ray_api import RayServerAPI
logger = logging.getLogger(__name__)
@ -32,12 +32,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
def PutObject(self, request, context=None):
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
obj = cloudpickle.loads(request.data)
objectref = self._put_and_retain_obj(obj)
return ray_client_pb2.PutResponse(id=objectref.binary())
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
objectref = ray.put(obj)
self.object_refs[objectref.binary()] = objectref
logger.info("put: %s" % objectref)
return ray_client_pb2.PutResponse(id=objectref.binary())
return objectref
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
@ -70,70 +74,83 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
def Schedule(self, task, context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" %
(task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context)
return self._schedule_function(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context)
return self._schedule_actor(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context)
return self._schedule_method(task, context, prepared_args)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
def _schedule_method(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_method(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None:
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist = _convert_args(task.args)
arglist = _convert_args(task.args, prepared_args)
with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_actor(self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode):
if task.payload_id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[task.payload_id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a ClientActorClass.")
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id]
arglist = _convert_args(task.args)
arglist = _convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
actor_ref = actor._actor_id
self.actor_refs[actor_ref.binary()] = actor
return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary())
actorhandle = cloudpickle.dumps(actor)
self.actor_refs[actorhandle] = actor
return ray_client_pb2.ClientTaskTicket(return_id=actorhandle)
def _schedule_function(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_function(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
if task.payload_id not in self.function_refs:
funcref = self.object_refs[task.payload_id]
func = ray.get(funcref)
if not isinstance(func, ClientRemoteFunc):
if not inspect.isfunction(func):
raise Exception("Attempting to schedule function that "
"isn't a ClientRemoteFunc.")
self.function_refs[task.payload_id] = func
"isn't a function.")
self.function_refs[task.payload_id] = ray.remote(func)
remote_func = self.function_refs[task.payload_id]
arglist = _convert_args(task.args)
arglist = _convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with stash_api_for_tests(self._test_mode):
output = remote_func.remote(*arglist)
if output.binary() in self.object_refs:
raise Exception("already found it")
self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _convert_args(arg_list):
def _convert_args(arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = []
for arg in arg_list:
t = convert_from_arg(arg)
@ -147,6 +164,7 @@ def _convert_args(arg_list):
def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)
_set_server_api(RayServerAPI(task_servicer))
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
server.add_insecure_port(connection_str)

View file

@ -3,6 +3,7 @@ It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
import inspect
import logging
from typing import List
from typing import Tuple
@ -14,11 +15,11 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import ClientRemoteFunc
logger = logging.getLogger(__name__)
class Worker:
def __init__(self,
@ -130,50 +131,14 @@ class Worker:
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def call_remote(self, instance, kind, *args, **kwargs):
ticket = None
if kind == ray_client_pb2.ClientTask.FUNCTION:
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
elif kind == ray_client_pb2.ClientTask.ACTOR:
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
return ClientActorRef(ticket.return_id)
elif kind == ray_client_pb2.ClientTask.METHOD:
ticket = self._call_method(instance, *args, **kwargs)
if ticket is None:
raise Exception(
"Couldn't call_remote on %s for type %s" % (instance, kind))
def call_remote(self, instance, *args, **kwargs):
task = instance._prepare_client_task()
for arg in args:
pb_arg = convert_to_arg(arg)
task.args.append(pb_arg)
logging.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef(ticket.return_id)
def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs):
if not isinstance(instance, ClientRemoteMethod):
raise TypeError("Client not passing a ClientRemoteMethod stub")
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = instance.method_name
task.payload_id = instance.actor_handle.actor_id.id
for arg in args:
pb_arg = convert_to_arg(arg)
task.args.append(pb_arg)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ticket
def _put_and_schedule(self, item, task_type, *args, **kwargs):
if isinstance(item, ClientRemoteFunc):
ref = self._put(item)
elif isinstance(item, ClientActorClass):
ref = self._put(item.actor_cls)
else:
raise TypeError("Client not passing a ClientRemoteFunc stub")
task = ray_client_pb2.ClientTask()
task.type = task_type
task.name = item._name
task.payload_id = ref.id
for arg in args:
pb_arg = convert_to_arg(arg)
task.args.append(pb_arg)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ticket
def close(self):
self.channel.close()

View file

@ -2,7 +2,7 @@ import pytest
from contextlib import contextmanager
import ray.experimental.client.server.server as ray_client_server
from ray.experimental.client import ray
from ray.experimental.client import ray, reset_api
from ray.experimental.client.common import ClientObjectRef
@ -13,6 +13,7 @@ def ray_start_client_server():
yield ray
ray.disconnect()
server.stop(0)
reset_api()
def test_real_ray_fallback(ray_start_regular_shared):
@ -170,6 +171,70 @@ def test_basic_actor(ray_start_regular_shared):
assert count == 2
def test_pass_handles(ray_start_regular_shared):
"""
Test that passing client handles to actors and functions to remote actors
in functions (on the server or raylet side) works transparently to the
caller.
"""
with ray_start_client_server() as ray:
@ray.remote
class ExecActor:
def exec(self, f, x):
return ray.get(f.remote(x))
def exec_exec(self, actor, f, x):
return ray.get(actor.exec.remote(f, x))
@ray.remote
def fact(x):
out = 1
while x > 0:
out = out * x
x -= 1
return out
@ray.remote
def func_exec(f, x):
return ray.get(f.remote(x))
@ray.remote
def func_actor_exec(actor, f, x):
return ray.get(actor.exec.remote(f, x))
@ray.remote
def sneaky_func_exec(obj, x):
return ray.get(obj["f"].remote(x))
@ray.remote
def sneaky_actor_exec(obj, x):
return ray.get(obj["actor"].exec.remote(obj["f"], x))
def local_fact(x):
if x <= 0:
return 1
return x * local_fact(x - 1)
assert ray.get(fact.remote(7)) == local_fact(7)
assert ray.get(func_exec.remote(fact, 8)) == local_fact(8)
test_obj = {}
test_obj["f"] = fact
assert ray.get(sneaky_func_exec.remote(test_obj, 5)) == local_fact(5)
actor_handle = ExecActor.remote()
assert ray.get(actor_handle.exec.remote(fact, 7)) == local_fact(7)
assert ray.get(func_actor_exec.remote(actor_handle, fact,
10)) == local_fact(10)
second_actor = ExecActor.remote()
assert ray.get(actor_handle.exec_exec.remote(second_actor, fact,
9)) == local_fact(9)
test_actor_obj = {}
test_actor_obj["actor"] = second_actor
test_actor_obj["f"] = fact
assert ray.get(sneaky_actor_exec.remote(test_actor_obj,
4)) == local_fact(4)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))