[ray_client] Passing actors to actors (#12585)

* start building tests around passing handles to handles

Change-Id: Ie8c3de5c8ce789c3ec8d29f0702df80ba598279f

* clean up the switch statements by moving to a method, implement state transfer, extend test

Change-Id: Ie7b6493db3a6c203d3a0b262b8fbacb90e5cdbc5

* passing

Change-Id: Id88dc0a41da1c9d5ba68f754c5b57141aae47beb

* flesh out tests

Change-Id: If77c0f586e9e99449d494be4e85f854e4a7a4952

* formatting

Change-Id: I497c07cee70b52453b221ed4393f04f6f560061e

* fix python3.6 and other attributes

Change-Id: I5a2c5231e8a021184d9dfc3e346df7f71fc93257

* address documentation

Change-Id: I049d841ed1f85b7350c17c05da4a4d81d5cb03df

* formatting

Change-Id: I6a2b32a2466ffc9f03fc91ac17901b9c1a49505c

* use the pickled handle as the id bytes for actors

Change-Id: I9ddcb41d614de65d42d6f0382fe0faa7ad2c2ade

* pydoc

Change-Id: I9b32a0f383d5ff5ac052e61929b7ae3e42a89fc5

* format

Change-Id: Iac0010bb990a4025a98139ab88700030b2e9e7f5

* todos

Change-Id: I7b550800cf7499403e8a17b77484bc46f20f0afc

* tests

Change-Id: If8ebf6a335baeb113c1332acc930c41a6b4f5384

* fix lint

Change-Id: I019f41e0ec341d39bbbbd39aa43d9fb5f8b57cf0

* nits

Change-Id: I2e6813d8db34f4ce008326faa095d414c10eee95

* add some tricky, python3.6-troublesome type checking

Change-Id: Ib887fc943a6e7084002bc13dfbe113b69b4d9317
This commit is contained in:
Barak Michener 2020-12-08 21:54:55 -08:00 committed by GitHub
parent d534719af6
commit dc4b5c7aa3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 491 additions and 130 deletions

View file

@ -7,34 +7,88 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# _client_api has to be external to the API stub, below. # About these global variables: Ray 1.0 uses exported module functions to
# Otherwise, ray.remote() that contains ray.remote() # provide its API, and we need to match that. However, we want different
# contains a reference to the RayAPIStub, therefore a # behaviors depending on where, exactly, in the client stack this is running.
# reference to the _client_api, and then tries to pickle #
# the thing. # The reason for these differences depends on what's being pickled and passed
# to functions, or functions inside functions. So there are three cases to care
# about
#
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
#
# * _client_api should be set if we're inside the client
# * _server_api should be set if we're inside the clientserver
# * Both will be set if we're running both (as in a test)
# * Neither should be set if we're inside the raylet (but we still need to shim
# from the client API surface to the Ray API)
#
# The job of RayAPIStub (below) is to delegate to the appropriate one of these
# depending on what's set or not. Then, all users importing the ray object
# from this package get the stub which routes them to the appropriate APIImpl.
_client_api: Optional[APIImpl] = None _client_api: Optional[APIImpl] = None
_server_api: Optional[APIImpl] = None
# _is_server is a hack to resolve the ambiguity described above while running
# tests. If we have both a client and a server trying to control these static
# variables then we need a way to decide which to use. In this case, both
# _client_api and _server_api are set, and this boolean selects which of the
# two is active.
_is_server: bool = False
@contextmanager @contextmanager
def stash_api_for_tests(in_test: bool): def stash_api_for_tests(in_test: bool):
api = None global _is_server
is_server = _is_server
if in_test: if in_test:
api = stash_api() _is_server = True
yield api yield _server_api
if in_test: if in_test:
restore_api(api) _is_server = is_server
def stash_api() -> Optional[APIImpl]: def _set_client_api(val: Optional[APIImpl]):
global _client_api global _client_api
a = _client_api global _is_server
if _client_api is not None:
raise Exception("Trying to set more than one client API")
_client_api = val
_is_server = False
def _set_server_api(val: Optional[APIImpl]):
global _server_api
global _is_server
if _server_api is not None:
raise Exception("Trying to set more than one server API")
_server_api = val
_is_server = True
def reset_api():
global _client_api
global _server_api
global _is_server
_client_api = None _client_api = None
return a _server_api = None
_is_server = False
def restore_api(api: Optional[APIImpl]): def _get_client_api() -> APIImpl:
global _client_api global _client_api
_client_api = api global _server_api
global _is_server
api = None
if _is_server:
api = _server_api
else:
api = _client_api
if api is None:
# We're inside a raylet worker
from ray.experimental.client.server.core_ray_api import CoreRayAPI
return CoreRayAPI()
return api
class RayAPIStub: class RayAPIStub:
@ -43,11 +97,10 @@ class RayAPIStub:
secure: bool = False, secure: bool = False,
metadata: List[Tuple[str, str]] = None, metadata: List[Tuple[str, str]] = None,
stub=None): stub=None):
global _client_api
from ray.experimental.client.worker import Worker from ray.experimental.client.worker import Worker
_client_worker = Worker( _client_worker = Worker(
conn_str, secure=secure, metadata=metadata, stub=stub) conn_str, secure=secure, metadata=metadata, stub=stub)
_client_api = ClientAPI(_client_worker) _set_client_api(ClientAPI(_client_worker))
def disconnect(self): def disconnect(self):
global _client_api global _client_api
@ -56,15 +109,9 @@ class RayAPIStub:
_client_api = None _client_api = None
def __getattr__(self, key: str): def __getattr__(self, key: str):
global _client_api global _get_client_api
self.__check_client_api() api = _get_client_api()
return getattr(_client_api, key) return getattr(api, key)
def __check_client_api(self):
global _client_api
if _client_api is None:
from ray.experimental.client.server.core_ray_api import CoreRayAPI
_client_api = CoreRayAPI()
ray = RayAPIStub() ray = RayAPIStub()

View file

@ -11,35 +11,105 @@
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
from ray.experimental.client.common import ClientStub
from ray.experimental.client.common import ClientObjectRef
from ray._raylet import ObjectRef
# Use the imports for type checking. This is a python 3.6 limitation.
# See https://www.python.org/dev/peps/pep-0563/
PutType = Union[ClientObjectRef, ObjectRef]
class APIImpl(ABC): class APIImpl(ABC):
"""
APIImpl is the interface to implement for whichever version of the core
Ray API that needs abstracting when run in client mode.
"""
@abstractmethod @abstractmethod
def get(self, *args, **kwargs): def get(self, *args, **kwargs) -> Any:
"""
get is the hook stub passed on to replace `ray.get`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass pass
@abstractmethod @abstractmethod
def put(self, *args, **kwargs): def put(self, vals: Any, *args,
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
"""
put is the hook stub passed on to replace `ray.put`
Args:
vals: The value or list of values to `put`.
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass pass
@abstractmethod @abstractmethod
def wait(self, *args, **kwargs): def wait(self, *args, **kwargs):
"""
wait is the hook stub passed on to replace `ray.wait`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass pass
@abstractmethod @abstractmethod
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
"""
remote is the hook stub passed on to replace `ray.remote`.
This sets up remote functions or actors, as the decorator,
but does not execute them.
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass pass
@abstractmethod @abstractmethod
def call_remote(self, f, kind, *args, **kwargs): def call_remote(self, instance: "ClientStub", *args, **kwargs):
"""
call_remote is called by stub objects to execute them remotely.
This is used by stub objects in situations where they're called
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
This allows the client stub objects to delegate execution to be
implemented in the most effective way whether it's in the client,
clientserver, or raylet worker.
Args:
instance: The Client-side stub reference to a remote object
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass pass
@abstractmethod @abstractmethod
def close(self, *args, **kwargs): def close(self) -> None:
"""
close cleans up an API connection by closing any channels or
shutting down any servers gracefully.
"""
pass pass
class ClientAPI(APIImpl): class ClientAPI(APIImpl):
"""
The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
"""
def __init__(self, worker): def __init__(self, worker):
self.worker = worker self.worker = worker
@ -55,10 +125,10 @@ class ClientAPI(APIImpl):
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs) return self.worker.remote(*args, **kwargs)
def call_remote(self, f, kind, *args, **kwargs): def call_remote(self, instance: "ClientStub", *args, **kwargs):
return self.worker.call_remote(f, kind, *args, **kwargs) return self.worker.call_remote(instance, *args, **kwargs)
def close(self, *args, **kwargs): def close(self) -> None:
return self.worker.close() return self.worker.close()
def __getattr__(self, key: str): def __getattr__(self, key: str):

View file

@ -1,6 +1,7 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray from ray.experimental.client import ray
from typing import Any from typing import Any
from typing import Dict
from ray import cloudpickle from ray import cloudpickle
@ -17,6 +18,9 @@ class ClientBaseRef:
def __eq__(self, other): def __eq__(self, other):
return self.id == other.id return self.id == other.id
def binary(self):
return self.id
class ClientObjectRef(ClientBaseRef): class ClientObjectRef(ClientBaseRef):
pass pass
@ -26,74 +30,222 @@ class ClientActorRef(ClientBaseRef):
pass pass
class ClientRemoteFunc: class ClientStub:
pass
class ClientRemoteFunc(ClientStub):
"""
A stub created on the Ray Client to represent a remote
function that can be executed on the cluster.
This class is allowed to be passed around between remote functions.
Args:
_func: The actual function to execute remotely
_name: The original name of the function
_ref: The ClientObjectRef of the pickled code of the function, _func
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
for this object
"""
def __init__(self, f): def __init__(self, f):
self._func = f self._func = f
self._name = f.__name__ self._name = f.__name__
self.id = None self.id = None
self._raylet_remote_func = None
# self._ref can be lazily instantiated. Rather than eagerly creating
# function data objects in the server we can put them just before we
# execute the function, especially in cases where many @ray.remote
# functions exist in a library and only a handful are ever executed by
# a user of the library.
#
# TODO(barakmich): This ref might actually be better as a serialized
# ObjectRef. This requires being able to serialize the ref without
# pinning it (as the lifetime of the ref is tied with the server, not
# the client)
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise TypeError(f"Remote function cannot be called directly. " raise TypeError(f"Remote function cannot be called directly. "
"Use {self._name}.remote method instead") "Use {self._name}.remote method instead")
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args, return ray.call_remote(self, *args, **kwargs)
**kwargs)
def _get_ray_remote_impl(self):
if self._raylet_remote is None:
self._raylet_remote = ray.remote(self._func)
return self._raylet_remote
def __repr__(self): def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id) return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self._func)
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.id
return task
class ClientActorClass: class ClientActorClass(ClientStub):
""" A stub created on the Ray Client to represent an actor class.
It is wrapped by ray.remote and can be executed on the cluster.
Args:
actor_cls: The actual class to execute remotely
_name: The original name of the class
_ref: The ClientObjectRef of the pickled `actor_cls`
_raylet_remote: The Raylet-side ray.ActorClass for this object
"""
def __init__(self, actor_cls): def __init__(self, actor_cls):
self.actor_cls = actor_cls self.actor_cls = actor_cls
self._name = actor_cls.__name__ self._name = actor_cls.__name__
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise TypeError(f"Remote actor cannot be instantiated directly. " raise TypeError(f"Remote actor cannot be instantiated directly. "
"Use {self._name}.remote() instead") "Use {self._name}.remote() instead")
def __getstate__(self) -> Dict:
state = {
"actor_cls": self.actor_cls,
"_name": self._name,
"_ref": self._ref,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_cls = state["actor_cls"]
self._name = state["_name"]
self._ref = state["_ref"]
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
# Actually instantiate the actor # Actually instantiate the actor
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args, ref = ray.call_remote(self, *args, **kwargs)
**kwargs) return ClientActorHandle(ClientActorRef(ref.id), self)
return ClientActorHandle(ref, self)
def __repr__(self): def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self.id) return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key): def __getattr__(self, key):
if key not in self.__dict__:
raise AttributeError("Not a class attribute")
raise NotImplementedError("static methods") raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self.actor_cls)
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.id
return task
class ClientActorHandle:
def __init__(self, actor_id: ClientActorRef, class ClientActorHandle(ClientStub):
"""Client-side stub for instantiated actor.
A stub created on the Ray Client to represent a remote actor that
has been started on the cluster. This class is allowed to be passed
around between remote functions.
Args:
actor_ref: A reference to the running actor given to the client. This
is a serialized version of the actual handle as an opaque token.
actor_class: A reference to the ClientActorClass that this actor was
instantiated from.
_real_actor_handle: Cached copy of the Raylet-side
ray.actor.ActorHandle contained in the actor_ref.
"""
def __init__(self, actor_ref: ClientActorRef,
actor_class: ClientActorClass): actor_class: ClientActorClass):
self.actor_id = actor_id self.actor_ref = actor_ref
self.actor_class = actor_class self.actor_class = actor_class
self._real_actor_handle = None
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = cloudpickle.loads(self.actor_ref.id)
return self._real_actor_handle
def __getstate__(self) -> Dict:
state = {
"actor_ref": self.actor_ref,
"actor_class": self.actor_class,
"_real_actor_handle": self._real_actor_handle,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_ref = state["actor_ref"]
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
def __getattr__(self, key): def __getattr__(self, key):
return ClientRemoteMethod(self, key) return ClientRemoteMethod(self, key)
def __repr__(self):
return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())
class ClientRemoteMethod(ClientStub):
"""A stub for a method on a remote actor.
Can be annotated with execution options.
Args:
actor_handle: A reference to the ClientActorHandle that generated
this method and will have this method called upon it.
method_name: The name of this method
"""
class ClientRemoteMethod:
def __init__(self, actor_handle: ClientActorHandle, method_name: str): def __init__(self, actor_handle: ClientActorHandle, method_name: str):
self.actor_handle = actor_handle self.actor_handle = actor_handle
self.method_name = method_name self.method_name = method_name
self._name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise TypeError(f"Remote method cannot be called directly. " raise TypeError(f"Remote method cannot be called directly. "
"Use {self._name}.remote() instead") "Use {self._name}.remote() instead")
def _get_ray_remote_impl(self):
return getattr(self.actor_handle._get_ray_remote_impl(),
self.method_name)
def __getstate__(self) -> Dict:
state = {
"actor_handle": self.actor_handle,
"method_name": self.method_name,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_handle = state["actor_handle"]
self.method_name = state["method_name"]
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args, return ray.call_remote(self, *args, **kwargs)
**kwargs)
def __repr__(self): def __repr__(self):
return "ClientRemoteMethod(%s, %s)" % (self._name, self.actor_id) name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
return "ClientRemoteMethod(%s, %s)" % (name,
self.actor_handle.actor_id)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_ref.id
return task
def convert_from_arg(pb) -> Any: def convert_from_arg(pb) -> Any:

View file

@ -7,18 +7,29 @@
# While the stub is trivial, it allows us to check that the calls we're # While the stub is trivial, it allows us to check that the calls we're
# making into the core-ray module are contained and well-defined. # making into the core-ray module are contained and well-defined.
from typing import Any
from typing import Union
import ray import ray
from ray.experimental.client.api import APIImpl from ray.experimental.client.api import APIImpl
from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientStub
class CoreRayAPI(APIImpl): class CoreRayAPI(APIImpl):
"""
Implements the equivalent client-side Ray API by simply passing along to
the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
to core ray when passed client stubs.
"""
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
return ray.get(*args, **kwargs) return ray.get(*args, **kwargs)
def put(self, *args, **kwargs): def put(self, vals: Any, *args,
return ray.put(*args, **kwargs) **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
return ray.put(vals, *args, **kwargs)
def wait(self, *args, **kwargs): def wait(self, *args, **kwargs):
return ray.wait(*args, **kwargs) return ray.wait(*args, **kwargs)
@ -26,12 +37,10 @@ class CoreRayAPI(APIImpl):
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return ray.remote(*args, **kwargs) return ray.remote(*args, **kwargs)
def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs): def call_remote(self, instance: ClientStub, *args, **kwargs):
if f._raylet_remote_func is None: return instance._get_ray_remote_impl().remote(*args, **kwargs)
f._raylet_remote_func = ray.remote(f._func)
return f._raylet_remote_func.remote(*args, **kwargs)
def close(self, *args, **kwargs): def close(self) -> None:
return None return None
# Allow for generic fallback to ray.* in remote methods. This allows calls # Allow for generic fallback to ray.* in remote methods. This allows calls
@ -39,3 +48,38 @@ class CoreRayAPI(APIImpl):
# doesn't currently support them. # doesn't currently support them.
def __getattr__(self, key: str): def __getattr__(self, key: str):
return getattr(ray, key) return getattr(ray, key)
class RayServerAPI(CoreRayAPI):
"""
Ray Client server-side API shim. By default, simply calls the default Core
Ray API calls, but also accepts scheduling calls from functions running
inside of other remote functions that need to create more work.
"""
def __init__(self, server_instance):
self.server = server_instance
# Wrap single item into list if needed before calling server put.
def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
to_put = []
single = False
if isinstance(vals, list):
to_put = vals
else:
single = True
to_put.append(vals)
out = [self._put(x) for x in to_put]
if single:
out = out[0]
return out
def _put(self, val: Any):
resp = self.server._put_and_retain_obj(val)
return ClientObjectRef(resp.id)
def call_remote(self, instance: ClientStub, *args, **kwargs):
task = instance._prepare_client_task()
ticket = self.server.Schedule(task, prepared_args=args)
return ClientObjectRef(ticket.return_id)

View file

@ -7,10 +7,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time import time
import inspect import inspect
from ray.experimental.client import stash_api_for_tests from ray.experimental.client import stash_api_for_tests, _set_server_api
from ray.experimental.client.common import convert_from_arg from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.server.core_ray_api import RayServerAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,12 +32,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
item_ser = cloudpickle.dumps(item) item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(valid=True, data=item_ser) return ray_client_pb2.GetResponse(valid=True, data=item_ser)
def PutObject(self, request, context=None): def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
obj = cloudpickle.loads(request.data) obj = cloudpickle.loads(request.data)
objectref = self._put_and_retain_obj(obj)
return ray_client_pb2.PutResponse(id=objectref.binary())
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
objectref = ray.put(obj) objectref = ray.put(obj)
self.object_refs[objectref.binary()] = objectref self.object_refs[objectref.binary()] = objectref
logger.info("put: %s" % objectref) logger.info("put: %s" % objectref)
return ray_client_pb2.PutResponse(id=objectref.binary()) return objectref
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_refs] object_refs = [cloudpickle.loads(o) for o in request.object_refs]
@ -70,70 +74,83 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids, ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids) remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: def Schedule(self, task, context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" % logger.info("schedule: %s %s" %
(task.name, (task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
if task.type == ray_client_pb2.ClientTask.FUNCTION: if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context) return self._schedule_function(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.ACTOR: elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context) return self._schedule_actor(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.METHOD: elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context) return self._schedule_method(task, context, prepared_args)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Unimplemented Schedule task type: %s" % "Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
def _schedule_method(self, task: ray_client_pb2.ClientTask, def _schedule_method(
context=None) -> ray_client_pb2.ClientTaskTicket: self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id) actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None: if actor_handle is None:
raise Exception( raise Exception(
"Can't run an actor the server doesn't have a handle for") "Can't run an actor the server doesn't have a handle for")
arglist = _convert_args(task.args) arglist = _convert_args(task.args, prepared_args)
with stash_api_for_tests(self._test_mode): with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist) output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _schedule_actor(self, task: ray_client_pb2.ClientTask, def _schedule_actor(self,
context=None) -> ray_client_pb2.ClientTaskTicket: task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode): with stash_api_for_tests(self._test_mode):
if task.payload_id not in self.registered_actor_classes: if task.payload_id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[task.payload_id] actor_class_ref = self.object_refs[task.payload_id]
actor_class = ray.get(actor_class_ref) actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class): if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that " raise Exception("Attempting to schedule actor that "
"isn't a ClientActorClass.") "isn't a class.")
reg_class = ray.remote(actor_class) reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id] remote_class = self.registered_actor_classes[task.payload_id]
arglist = _convert_args(task.args) arglist = _convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist) actor = remote_class.remote(*arglist)
actor_ref = actor._actor_id actorhandle = cloudpickle.dumps(actor)
self.actor_refs[actor_ref.binary()] = actor self.actor_refs[actorhandle] = actor
return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary()) return ray_client_pb2.ClientTaskTicket(return_id=actorhandle)
def _schedule_function(self, task: ray_client_pb2.ClientTask, def _schedule_function(
context=None) -> ray_client_pb2.ClientTaskTicket: self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
if task.payload_id not in self.function_refs: if task.payload_id not in self.function_refs:
funcref = self.object_refs[task.payload_id] funcref = self.object_refs[task.payload_id]
func = ray.get(funcref) func = ray.get(funcref)
if not isinstance(func, ClientRemoteFunc): if not inspect.isfunction(func):
raise Exception("Attempting to schedule function that " raise Exception("Attempting to schedule function that "
"isn't a ClientRemoteFunc.") "isn't a function.")
self.function_refs[task.payload_id] = func self.function_refs[task.payload_id] = ray.remote(func)
remote_func = self.function_refs[task.payload_id] remote_func = self.function_refs[task.payload_id]
arglist = _convert_args(task.args) arglist = _convert_args(task.args, prepared_args)
# Prepare call if we're in a test # Prepare call if we're in a test
with stash_api_for_tests(self._test_mode): with stash_api_for_tests(self._test_mode):
output = remote_func.remote(*arglist) output = remote_func.remote(*arglist)
if output.binary() in self.object_refs:
raise Exception("already found it")
self.object_refs[output.binary()] = output self.object_refs[output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _convert_args(arg_list): def _convert_args(arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = [] out = []
for arg in arg_list: for arg in arg_list:
t = convert_from_arg(arg) t = convert_from_arg(arg)
@ -147,6 +164,7 @@ def _convert_args(arg_list):
def serve(connection_str, test_mode=False): def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode) task_servicer = RayletServicer(test_mode=test_mode)
_set_server_api(RayServerAPI(task_servicer))
ray_client_pb2_grpc.add_RayletDriverServicer_to_server( ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server) task_servicer, server)
server.add_insecure_port(connection_str) server.add_insecure_port(connection_str)

View file

@ -3,6 +3,7 @@ It implements the Ray API functions that are forwarded through grpc calls
to the server. to the server.
""" """
import inspect import inspect
import logging
from typing import List from typing import List
from typing import Tuple from typing import Tuple
@ -14,11 +15,11 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientRemoteFunc
logger = logging.getLogger(__name__)
class Worker: class Worker:
def __init__(self, def __init__(self,
@ -130,50 +131,14 @@ class Worker:
raise TypeError("The @ray.remote decorator must be applied to " raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.") "either a function or to a class.")
def call_remote(self, instance, *args, **kwargs):
    """Schedule ``instance`` on the server and return a reference to its result.

    The stub builds its own task via ``instance._prepare_client_task()``;
    every positional argument is converted with ``convert_to_arg`` and
    appended to ``task.args`` before the task is submitted through
    ``self.server.Schedule``.

    Args:
        instance: A client stub exposing ``_prepare_client_task()``.
        *args: Positional arguments for the remote call.
        **kwargs: Accepted for signature compatibility.

    Returns:
        A ``ClientObjectRef`` wrapping the ``return_id`` of the ticket the
        server hands back.
    """
    # NOTE(review): **kwargs is never read below, so keyword arguments are
    # silently dropped -- confirm whether they should be rejected or
    # serialized alongside the positional arguments.
    task = instance._prepare_client_task()
    for arg in args:
        task.args.append(convert_to_arg(arg))
    # Use the module-level logger (not the root ``logging`` module) and let
    # it format lazily, so the task is only rendered when DEBUG is enabled.
    logger.debug("Scheduling %s", task)
    ticket = self.server.Schedule(task, metadata=self.metadata)
    return ClientObjectRef(ticket.return_id)
def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs):
    """Submit an actor-method call and hand back the server's ticket.

    A ``ClientTask`` of type ``METHOD`` is filled in with the stub's
    ``method_name`` and the id of the actor it belongs to, each positional
    argument is converted with ``convert_to_arg``, and the task is passed
    to ``self.server.Schedule``. ``**kwargs`` is accepted but not read.

    Raises:
        TypeError: If ``instance`` is not a ``ClientRemoteMethod``.
    """
    if not isinstance(instance, ClientRemoteMethod):
        raise TypeError("Client not passing a ClientRemoteMethod stub")

    method_task = ray_client_pb2.ClientTask()
    method_task.type = ray_client_pb2.ClientTask.METHOD
    method_task.name = instance.method_name
    method_task.payload_id = instance.actor_handle.actor_id.id
    for positional in args:
        method_task.args.append(convert_to_arg(positional))

    return self.server.Schedule(method_task, metadata=self.metadata)
def _put_and_schedule(self, item, task_type, *args, **kwargs):
    """Store a stub's payload via ``self._put`` and schedule a task for it.

    For a ``ClientRemoteFunc`` the stub itself is stored; for a
    ``ClientActorClass`` its ``actor_cls`` is stored instead. The resulting
    reference id becomes the task's ``payload_id``. Positional arguments
    are converted with ``convert_to_arg``; ``**kwargs`` is accepted but
    not read.

    Returns:
        The ticket returned by ``self.server.Schedule``.

    Raises:
        TypeError: If ``item`` is neither of the two supported stub types.
    """
    # Work out what has to be uploaded before touching the server.
    if isinstance(item, ClientRemoteFunc):
        payload = item
    elif isinstance(item, ClientActorClass):
        payload = item.actor_cls
    else:
        raise TypeError("Client not passing a ClientRemoteFunc stub")
    stored_ref = self._put(payload)

    request = ray_client_pb2.ClientTask()
    request.type = task_type
    request.name = item._name
    request.payload_id = stored_ref.id
    for positional in args:
        request.args.append(convert_to_arg(positional))

    return self.server.Schedule(request, metadata=self.metadata)
def close(self):
    """Close the channel stored on ``self.channel``."""
    channel = self.channel
    channel.close()

View file

@ -2,7 +2,7 @@ import pytest
from contextlib import contextmanager from contextlib import contextmanager
import ray.experimental.client.server.server as ray_client_server import ray.experimental.client.server.server as ray_client_server
from ray.experimental.client import ray from ray.experimental.client import ray, reset_api
from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientObjectRef
@ -13,6 +13,7 @@ def ray_start_client_server():
yield ray yield ray
ray.disconnect() ray.disconnect()
server.stop(0) server.stop(0)
reset_api()
def test_real_ray_fallback(ray_start_regular_shared): def test_real_ray_fallback(ray_start_regular_shared):
@ -170,6 +171,70 @@ def test_basic_actor(ray_start_regular_shared):
assert count == 2 assert count == 2
def test_pass_handles(ray_start_regular_shared):
    """Client stubs must stay usable after being passed as task arguments.

    Remote-function stubs and actor handles are handed to other remote
    functions and actor methods -- both as plain positional arguments and
    tucked inside a dict -- and every callee has to be able to invoke them
    and produce the same factorial as a purely local computation.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class ExecActor:
            def exec(self, f, x):
                return ray.get(f.remote(x))

            def exec_exec(self, actor, f, x):
                return ray.get(actor.exec.remote(f, x))

        @ray.remote
        def fact(x):
            out = 1
            while x > 0:
                out *= x
                x -= 1
            return out

        @ray.remote
        def func_exec(f, x):
            return ray.get(f.remote(x))

        @ray.remote
        def func_actor_exec(actor, f, x):
            return ray.get(actor.exec.remote(f, x))

        @ray.remote
        def sneaky_func_exec(obj, x):
            return ray.get(obj["f"].remote(x))

        @ray.remote
        def sneaky_actor_exec(obj, x):
            return ray.get(obj["actor"].exec.remote(obj["f"], x))

        def local_fact(x):
            # Local reference implementation; yields 1 for x <= 0.
            product = 1
            for factor in range(2, x + 1):
                product *= factor
            return product

        # Plain remote call, then a function stub passed to a function.
        assert ray.get(fact.remote(7)) == local_fact(7)
        assert ray.get(func_exec.remote(fact, 8)) == local_fact(8)

        # Function stub hidden inside a container.
        test_obj = {"f": fact}
        assert ray.get(sneaky_func_exec.remote(test_obj, 5)) == local_fact(5)

        # Function stub passed to an actor method, and an actor handle
        # passed to a function.
        actor_handle = ExecActor.remote()
        assert ray.get(actor_handle.exec.remote(fact, 7)) == local_fact(7)
        via_function = func_actor_exec.remote(actor_handle, fact, 10)
        assert ray.get(via_function) == local_fact(10)

        # Actor handle passed to another actor.
        second_actor = ExecActor.remote()
        via_actor = actor_handle.exec_exec.remote(second_actor, fact, 9)
        assert ray.get(via_actor) == local_fact(9)

        # Actor handle and function stub both hidden inside a container.
        test_actor_obj = {"actor": second_actor, "f": fact}
        via_container = sneaky_actor_exec.remote(test_actor_obj, 4)
        assert ray.get(via_container) == local_fact(4)
if __name__ == "__main__":
    # Running the module directly hands it to pytest in verbose mode and
    # exits with pytest's return value.
    import sys

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)