mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[ray_client] Passing actors to actors (#12585)
* start building tests around passing handles to handles Change-Id: Ie8c3de5c8ce789c3ec8d29f0702df80ba598279f * clean up the switch statements by moving to a method, implement state tranfer, extend test Change-Id: Ie7b6493db3a6c203d3a0b262b8fbacb90e5cdbc5 * passing Change-Id: Id88dc0a41da1c9d5ba68f754c5b57141aae47beb * flush out tests Change-Id: If77c0f586e9e99449d494be4e85f854e4a7a4952 * formatting Change-Id: I497c07cee70b52453b221ed4393f04f6f560061e * fix python3.6 and other attributes Change-Id: I5a2c5231e8a021184d9dfc3e346df7f71fc93257 * address documentation Change-Id: I049d841ed1f85b7350c17c05da4a4d81d5cb03df * formatting Change-Id: I6a2b32a2466ffc9f03fc91ac17901b9c1a49505c * use the pickled handle as the id bytes for actors Change-Id: I9ddcb41d614de65d42d6f0382fe0faa7ad2c2ade * pydoc Change-Id: I9b32a0f383d5ff5ac052e61929b7ae3e42a89fc5 * format Change-Id: Iac0010bb990a4025a98139ab88700030b2e9e7f5 * todos Change-Id: I7b550800cf7499403e8a17b77484bc46f20f0afc * tests Change-Id: If8ebf6a335baeb113c1332acc930c41a6b4f5384 * fix lint Change-Id: I019f41e0ec341d39bbbbd39aa43d9fb5f8b57cf0 * nits Change-Id: I2e6813d8db34f4ce008326faa095d414c10eee95 * add some tricky, python3.6-troublesome type checking Change-Id: Ib887fc943a6e7084002bc13dfbe113b69b4d9317
This commit is contained in:
parent
d534719af6
commit
dc4b5c7aa3
7 changed files with 491 additions and 130 deletions
|
@ -7,34 +7,88 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# _client_api has to be external to the API stub, below.
|
||||
# Otherwise, ray.remote() that contains ray.remote()
|
||||
# contains a reference to the RayAPIStub, therefore a
|
||||
# reference to the _client_api, and then tries to pickle
|
||||
# the thing.
|
||||
# About these global variables: Ray 1.0 uses exported module functions to
|
||||
# provide its API, and we need to match that. However, we want different
|
||||
# behaviors depending on where, exactly, in the client stack this is running.
|
||||
#
|
||||
# The reason for these differences depends on what's being pickled and passed
|
||||
# to functions, or functions inside functions. So there are three cases to care
|
||||
# about
|
||||
#
|
||||
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
|
||||
#
|
||||
# * _client_api should be set if we're inside the client
|
||||
# * _server_api should be set if we're inside the clientserver
|
||||
# * Both will be set if we're running both (as in a test)
|
||||
# * Neither should be set if we're inside the raylet (but we still need to shim
|
||||
# from the client API surface to the Ray API)
|
||||
#
|
||||
# The job of RayAPIStub (below) is to delegate to the appropriate one of these
|
||||
# depending on what's set or not. Then, all users importing the ray object
|
||||
# from this package get the stub which routes them to the appropriate APIImpl.
|
||||
_client_api: Optional[APIImpl] = None
|
||||
_server_api: Optional[APIImpl] = None
|
||||
|
||||
# The reason for _is_server is a hack around the above comment while running
|
||||
# tests. If we have both a client and a server trying to control these static
|
||||
# variables then we need a way to decide which to use. In this case, both
|
||||
# _client_api and _server_api are set.
|
||||
# This boolean flips between the two
|
||||
_is_server: bool = False
|
||||
|
||||
|
||||
@contextmanager
def stash_api_for_tests(in_test: bool):
    """Temporarily select the server API while running in test mode.

    When ``in_test`` is true, the module-level ``_is_server`` flag is set to
    ``True`` for the duration of the ``with`` body and then put back to the
    value it had on entry. When ``in_test`` is false the flag is not touched.

    Args:
        in_test: Whether the caller is running in test mode.

    Yields:
        The current value of the module-level ``_server_api``.
    """
    global _is_server
    previous = _is_server
    if in_test:
        _is_server = True
    try:
        yield _server_api
    finally:
        # Restore the flag even when the body raises; otherwise a failing
        # call would leave the module stuck in server mode.
        if in_test:
            _is_server = previous
|
||||
|
||||
|
||||
def stash_api() -> Optional[APIImpl]:
|
||||
def _set_client_api(val: Optional[APIImpl]):
|
||||
global _client_api
|
||||
a = _client_api
|
||||
global _is_server
|
||||
if _client_api is not None:
|
||||
raise Exception("Trying to set more than one client API")
|
||||
_client_api = val
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _set_server_api(val: Optional[APIImpl]):
|
||||
global _server_api
|
||||
global _is_server
|
||||
if _server_api is not None:
|
||||
raise Exception("Trying to set more than one server API")
|
||||
_server_api = val
|
||||
_is_server = True
|
||||
|
||||
|
||||
def reset_api():
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
_client_api = None
|
||||
return a
|
||||
_server_api = None
|
||||
_is_server = False
|
||||
|
||||
|
||||
def restore_api(api: Optional[APIImpl]):
|
||||
def _get_client_api() -> APIImpl:
|
||||
global _client_api
|
||||
_client_api = api
|
||||
global _server_api
|
||||
global _is_server
|
||||
api = None
|
||||
if _is_server:
|
||||
api = _server_api
|
||||
else:
|
||||
api = _client_api
|
||||
if api is None:
|
||||
# We're inside a raylet worker
|
||||
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
||||
return CoreRayAPI()
|
||||
return api
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
|
@ -43,11 +97,10 @@ class RayAPIStub:
|
|||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
global _client_api
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||
_client_api = ClientAPI(_client_worker)
|
||||
_set_client_api(ClientAPI(_client_worker))
|
||||
|
||||
def disconnect(self):
|
||||
global _client_api
|
||||
|
@ -56,15 +109,9 @@ class RayAPIStub:
|
|||
_client_api = None
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
global _client_api
|
||||
self.__check_client_api()
|
||||
return getattr(_client_api, key)
|
||||
|
||||
def __check_client_api(self):
|
||||
global _client_api
|
||||
if _client_api is None:
|
||||
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
||||
_client_api = CoreRayAPI()
|
||||
global _get_client_api
|
||||
api = _get_client_api()
|
||||
return getattr(api, key)
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
|
|
@ -11,35 +11,105 @@
|
|||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray._raylet import ObjectRef
|
||||
|
||||
# Use the imports for type checking. This is a python 3.6 limitation.
|
||||
# See https://www.python.org/dev/peps/pep-0563/
|
||||
PutType = Union[ClientObjectRef, ObjectRef]
|
||||
|
||||
|
||||
class APIImpl(ABC):
|
||||
"""
|
||||
APIImpl is the interface to implement for whichever version of the core
|
||||
Ray API that needs abstracting when run in client mode.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, *args, **kwargs):
|
||||
def get(self, *args, **kwargs) -> Any:
|
||||
"""
|
||||
get is the hook stub passed on to replace `ray.get`
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, *args, **kwargs):
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
|
||||
"""
|
||||
put is the hook stub passed on to replace `ray.put`
|
||||
|
||||
Args:
|
||||
vals: The value or list of values to `put`.
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, *args, **kwargs):
|
||||
"""
|
||||
wait is the hook stub passed on to replace `ray.wait`
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remote(self, *args, **kwargs):
|
||||
"""
|
||||
remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def call_remote(self, f, kind, *args, **kwargs):
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
"""
|
||||
call_remote is called by stub objects to execute them remotely.
|
||||
|
||||
This is used by stub objects in situations where they're called
|
||||
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
|
||||
This allows the client stub objects to delegate execution to be
|
||||
implemented in the most effective way whether it's in the client,
|
||||
clientserver, or raylet worker.
|
||||
|
||||
Args:
|
||||
instance: The Client-side stub reference to a remote object
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self, *args, **kwargs):
|
||||
def close(self) -> None:
|
||||
"""
|
||||
close cleans up an API connection by closing any channels or
|
||||
shutting down any servers gracefully.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
"""
|
||||
The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
|
||||
|
@ -55,10 +125,10 @@ class ClientAPI(APIImpl):
|
|||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, f, kind, *args, **kwargs):
|
||||
return self.worker.call_remote(f, kind, *args, **kwargs)
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
def close(self) -> None:
|
||||
return self.worker.close()
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from ray import cloudpickle
|
||||
|
||||
|
||||
|
@ -17,6 +18,9 @@ class ClientBaseRef:
|
|||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
pass
|
||||
|
@ -26,74 +30,222 @@ class ClientActorRef(ClientBaseRef):
|
|||
pass
|
||||
|
||||
|
||||
class ClientRemoteFunc:
|
||||
class ClientStub:
|
||||
pass
|
||||
|
||||
|
||||
class ClientRemoteFunc(ClientStub):
|
||||
"""
|
||||
A stub created on the Ray Client to represent a remote
|
||||
function that can be executed on the cluster.
|
||||
|
||||
This class is allowed to be passed around between remote functions.
|
||||
|
||||
Args:
|
||||
_func: The actual function to execute remotely
|
||||
_name: The original name of the function
|
||||
_ref: The ClientObjectRef of the pickled code of the function, _func
|
||||
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
|
||||
for this object
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
self._raylet_remote_func = None
|
||||
|
||||
# self._ref can be lazily instantiated. Rather than eagerly creating
|
||||
# function data objects in the server we can put them just before we
|
||||
# execute the function, especially in cases where many @ray.remote
|
||||
# functions exist in a library and only a handful are ever executed by
|
||||
# a user of the library.
|
||||
#
|
||||
# TODO(barakmich): This ref might actually be better as a serialized
|
||||
# ObjectRef. This requires being able to serialize the ref without
|
||||
# pinning it (as the lifetime of the ref is tied with the server, not
|
||||
# the client)
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct invocation of a remote function stub.

    Args:
        args: ignored
        kwargs: ignored

    Raises:
        TypeError: Always. The message names the function via
            ``self._name``.
    """
    # Adjacent string literals are concatenated, but only the literal that
    # carries the f-prefix is interpolated, so the prefix must sit on the
    # literal containing the placeholder.
    raise TypeError("Remote function cannot be called directly. "
                    f"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
|
||||
**kwargs)
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._raylet_remote is None:
|
||||
self._raylet_remote = ray.remote(self._func)
|
||||
return self._raylet_remote
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self._func)
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
return task
|
||||
|
||||
|
||||
class ClientActorClass:
|
||||
class ClientActorClass(ClientStub):
|
||||
""" A stub created on the Ray Client to represent an actor class.
|
||||
|
||||
It is wrapped by ray.remote and can be executed on the cluster.
|
||||
|
||||
Args:
|
||||
actor_cls: The actual class to execute remotely
|
||||
_name: The original name of the class
|
||||
_ref: The ClientObjectRef of the pickled `actor_cls`
|
||||
_raylet_remote: The Raylet-side ray.ActorClass for this object
|
||||
"""
|
||||
|
||||
def __init__(self, actor_cls):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct instantiation of a remote actor class stub.

    Args:
        args: ignored
        kwargs: ignored

    Raises:
        TypeError: Always. The message names the class via ``self._name``.
    """
    # Only the literal carrying the f-prefix is interpolated, so the prefix
    # must sit on the literal that contains the placeholder.
    raise TypeError("Remote actor cannot be instantiated directly. "
                    f"Use {self._name}.remote() instead")
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_cls": self.actor_cls,
|
||||
"_name": self._name,
|
||||
"_ref": self._ref,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_cls = state["actor_cls"]
|
||||
self._name = state["_name"]
|
||||
self._ref = state["_ref"]
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
# Actually instantiate the actor
|
||||
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
|
||||
**kwargs)
|
||||
return ClientActorHandle(ref, self)
|
||||
ref = ray.call_remote(self, *args, **kwargs)
|
||||
return ClientActorHandle(ClientActorRef(ref.id), self)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self.id)
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self.__dict__:
|
||||
raise AttributeError("Not a class attribute")
|
||||
raise NotImplementedError("static methods")
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
return task
|
||||
|
||||
class ClientActorHandle:
|
||||
def __init__(self, actor_id: ClientActorRef,
|
||||
|
||||
class ClientActorHandle(ClientStub):
|
||||
"""Client-side stub for instantiated actor.
|
||||
|
||||
A stub created on the Ray Client to represent a remote actor that
|
||||
has been started on the cluster. This class is allowed to be passed
|
||||
around between remote functions.
|
||||
|
||||
Args:
|
||||
actor_ref: A reference to the running actor given to the client. This
|
||||
is a serialized version of the actual handle as an opaque token.
|
||||
actor_class: A reference to the ClientActorClass that this actor was
|
||||
instantiated from.
|
||||
_real_actor_handle: Cached copy of the Raylet-side
|
||||
ray.actor.ActorHandle contained in the actor_id ref.
|
||||
"""
|
||||
|
||||
def __init__(self, actor_ref: ClientActorRef,
|
||||
actor_class: ClientActorClass):
|
||||
self.actor_id = actor_id
|
||||
self.actor_ref = actor_ref
|
||||
self.actor_class = actor_class
|
||||
self._real_actor_handle = None
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._real_actor_handle is None:
|
||||
self._real_actor_handle = cloudpickle.loads(self.actor_ref.id)
|
||||
return self._real_actor_handle
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_ref": self.actor_ref,
|
||||
"actor_class": self.actor_class,
|
||||
"_real_actor_handle": self._real_actor_handle,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_ref = state["actor_ref"]
|
||||
self.actor_class = state["actor_class"]
|
||||
self._real_actor_handle = state["_real_actor_handle"]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ClientRemoteMethod(self, key)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())
|
||||
|
||||
|
||||
class ClientRemoteMethod(ClientStub):
|
||||
"""A stub for a method on a remote actor.
|
||||
|
||||
Can be annotated with execution options.
|
||||
|
||||
Args:
|
||||
actor_handle: A reference to the ClientActorHandle that generated
|
||||
this method and will have this method called upon it.
|
||||
method_name: The name of this method
|
||||
"""
|
||||
|
||||
class ClientRemoteMethod:
|
||||
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
|
||||
self.actor_handle = actor_handle
|
||||
self.method_name = method_name
|
||||
self._name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||
self.method_name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct invocation of a remote actor method stub.

    Args:
        args: ignored
        kwargs: ignored

    Raises:
        TypeError: Always. The message names the method as
            ``<actor class name>.<method name>``.
    """
    # Build the display name from the handle and method name instead of
    # relying on self._name, which is not part of the pickled state.
    name = "%s.%s" % (self.actor_handle.actor_class._name,
                      self.method_name)
    # Only the literal carrying the f-prefix is interpolated, so the prefix
    # must sit on the literal that contains the placeholder.
    raise TypeError("Remote method cannot be called directly. "
                    f"Use {name}.remote() instead")
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
return getattr(self.actor_handle._get_ray_remote_impl(),
|
||||
self.method_name)
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_handle": self.actor_handle,
|
||||
"method_name": self.method_name,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
    """Restore this stub from the mapping produced by ``__getstate__``.

    Args:
        state: Mapping with the keys ``"actor_handle"`` and
            ``"method_name"``.
    """
    self.actor_handle = state["actor_handle"]
    self.method_name = state["method_name"]
    # _name is assigned in __init__ but is not part of the pickled state;
    # rebuild it the same way so restored stubs carry the attribute too.
    self._name = "%s.%s" % (self.actor_handle.actor_class._name,
                            self.method_name)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
|
||||
**kwargs)
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
|
||||
def __repr__(self):
    """Return a debug string naming the method and its actor reference.

    The name is rebuilt from the actor handle and method name, and the
    second field is the handle's ``actor_ref``.
    """
    name = "%s.%s" % (self.actor_handle.actor_class._name,
                      self.method_name)
    # The handle stores its reference as ``actor_ref``. Reading a missing
    # attribute such as ``actor_id`` would go through the handle's
    # __getattr__, which returns another ClientRemoteMethod, and formatting
    # that would re-enter this method.
    return "ClientRemoteMethod(%s, %s)" % (name,
                                           self.actor_handle.actor_ref)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = self.method_name
|
||||
task.payload_id = self.actor_handle.actor_ref.id
|
||||
return task
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
|
|
|
@ -7,18 +7,29 @@
|
|||
# While the stub is trivial, it allows us to check that the calls we're
|
||||
# making into the core-ray module are contained and well-defined.
|
||||
|
||||
from typing import Any
|
||||
from typing import Union
|
||||
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientStub
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
|
||||
"""
|
||||
Implements the equivalent client-side Ray API by simply passing along to
|
||||
the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
|
||||
to core ray when passed client stubs.
|
||||
"""
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return ray.get(*args, **kwargs)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return ray.put(*args, **kwargs)
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
|
||||
return ray.put(vals, *args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return ray.wait(*args, **kwargs)
|
||||
|
@ -26,12 +37,10 @@ class CoreRayAPI(APIImpl):
|
|||
def remote(self, *args, **kwargs):
|
||||
return ray.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs):
|
||||
if f._raylet_remote_func is None:
|
||||
f._raylet_remote_func = ray.remote(f._func)
|
||||
return f._raylet_remote_func.remote(*args, **kwargs)
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||
return instance._get_ray_remote_impl().remote(*args, **kwargs)
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
||||
|
@ -39,3 +48,38 @@ class CoreRayAPI(APIImpl):
|
|||
# doesn't currently support them.
|
||||
def __getattr__(self, key: str):
|
||||
return getattr(ray, key)
|
||||
|
||||
|
||||
class RayServerAPI(CoreRayAPI):
|
||||
"""
|
||||
Ray Client server-side API shim. By default, simply calls the default Core
|
||||
Ray API calls, but also accepts scheduling calls from functions running
|
||||
inside of other remote functions that need to create more work.
|
||||
"""
|
||||
|
||||
def __init__(self, server_instance):
|
||||
self.server = server_instance
|
||||
|
||||
# Wrap single item into list if needed before calling server put.
|
||||
def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
|
||||
to_put = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_put = vals
|
||||
else:
|
||||
single = True
|
||||
to_put.append(vals)
|
||||
|
||||
out = [self._put(x) for x in to_put]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _put(self, val: Any):
|
||||
resp = self.server._put_and_retain_obj(val)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||
task = instance._prepare_client_task()
|
||||
ticket = self.server.Schedule(task, prepared_args=args)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
|
|
|
@ -7,10 +7,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
|||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import inspect
|
||||
from ray.experimental.client import stash_api_for_tests
|
||||
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -32,12 +32,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
item_ser = cloudpickle.dumps(item)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||
|
||||
def PutObject(self, request, context=None):
|
||||
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
|
||||
obj = cloudpickle.loads(request.data)
|
||||
objectref = self._put_and_retain_obj(obj)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[objectref.binary()] = objectref
|
||||
logger.info("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
return objectref
|
||||
|
||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
|
||||
|
@ -70,70 +74,83 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
ready_object_ids=ready_object_ids,
|
||||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def Schedule(self, task, context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context)
|
||||
return self._schedule_function(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context)
|
||||
return self._schedule_actor(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context)
|
||||
return self._schedule_method(task, context, prepared_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
|
||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_method(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
actor_handle = self.actor_refs.get(task.payload_id)
|
||||
if actor_handle is None:
|
||||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist = _convert_args(task.args)
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_actor(self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
if task.payload_id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[task.payload_id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a ClientActorClass.")
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[task.payload_id] = reg_class
|
||||
remote_class = self.registered_actor_classes[task.payload_id]
|
||||
arglist = _convert_args(task.args)
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
actor_ref = actor._actor_id
|
||||
self.actor_refs[actor_ref.binary()] = actor
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary())
|
||||
actorhandle = cloudpickle.dumps(actor)
|
||||
self.actor_refs[actorhandle] = actor
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=actorhandle)
|
||||
|
||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_function(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
if task.payload_id not in self.function_refs:
|
||||
funcref = self.object_refs[task.payload_id]
|
||||
func = ray.get(funcref)
|
||||
if not isinstance(func, ClientRemoteFunc):
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to schedule function that "
|
||||
"isn't a ClientRemoteFunc.")
|
||||
self.function_refs[task.payload_id] = func
|
||||
"isn't a function.")
|
||||
self.function_refs[task.payload_id] = ray.remote(func)
|
||||
remote_func = self.function_refs[task.payload_id]
|
||||
arglist = _convert_args(task.args)
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = remote_func.remote(*arglist)
|
||||
if output.binary() in self.object_refs:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
|
||||
def _convert_args(arg_list):
|
||||
def _convert_args(arg_list, prepared_args=None):
|
||||
if prepared_args is not None:
|
||||
return prepared_args
|
||||
out = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg)
|
||||
|
@ -147,6 +164,7 @@ def _convert_args(arg_list):
|
|||
def serve(connection_str, test_mode=False):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
_set_server_api(RayServerAPI(task_servicer))
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
|
|
|
@ -3,6 +3,7 @@ It implements the Ray API functions that are forwarded through grpc calls
|
|||
to the server.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
|
@ -14,11 +15,11 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
|||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self,
|
||||
|
@ -130,50 +131,14 @@ class Worker:
|
|||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def call_remote(self, instance, kind, *args, **kwargs):
|
||||
ticket = None
|
||||
if kind == ray_client_pb2.ClientTask.FUNCTION:
|
||||
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
|
||||
elif kind == ray_client_pb2.ClientTask.ACTOR:
|
||||
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
|
||||
return ClientActorRef(ticket.return_id)
|
||||
elif kind == ray_client_pb2.ClientTask.METHOD:
|
||||
ticket = self._call_method(instance, *args, **kwargs)
|
||||
|
||||
if ticket is None:
|
||||
raise Exception(
|
||||
"Couldn't call_remote on %s for type %s" % (instance, kind))
|
||||
def call_remote(self, instance, *args, **kwargs):
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
logging.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
|
||||
def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs):
|
||||
if not isinstance(instance, ClientRemoteMethod):
|
||||
raise TypeError("Client not passing a ClientRemoteMethod stub")
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = instance.method_name
|
||||
task.payload_id = instance.actor_handle.actor_id.id
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ticket
|
||||
|
||||
def _put_and_schedule(self, item, task_type, *args, **kwargs):
|
||||
if isinstance(item, ClientRemoteFunc):
|
||||
ref = self._put(item)
|
||||
elif isinstance(item, ClientActorClass):
|
||||
ref = self._put(item.actor_cls)
|
||||
else:
|
||||
raise TypeError("Client not passing a ClientRemoteFunc stub")
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = task_type
|
||||
task.name = item._name
|
||||
task.payload_id = ref.id
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ticket
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
from contextlib import contextmanager
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
from ray.experimental.client import ray
|
||||
from ray.experimental.client import ray, reset_api
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ def ray_start_client_server():
|
|||
yield ray
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
reset_api()
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
|
@ -170,6 +171,70 @@ def test_basic_actor(ray_start_regular_shared):
|
|||
assert count == 2
|
||||
|
||||
|
||||
def test_pass_handles(ray_start_regular_shared):
|
||||
"""
|
||||
Test that passing client handles to actors and functions to remote actors
|
||||
in functions (on the server or raylet side) works transparently to the
|
||||
caller.
|
||||
"""
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class ExecActor:
|
||||
def exec(self, f, x):
|
||||
return ray.get(f.remote(x))
|
||||
|
||||
def exec_exec(self, actor, f, x):
|
||||
return ray.get(actor.exec.remote(f, x))
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
out = 1
|
||||
while x > 0:
|
||||
out = out * x
|
||||
x -= 1
|
||||
return out
|
||||
|
||||
@ray.remote
|
||||
def func_exec(f, x):
|
||||
return ray.get(f.remote(x))
|
||||
|
||||
@ray.remote
|
||||
def func_actor_exec(actor, f, x):
|
||||
return ray.get(actor.exec.remote(f, x))
|
||||
|
||||
@ray.remote
|
||||
def sneaky_func_exec(obj, x):
|
||||
return ray.get(obj["f"].remote(x))
|
||||
|
||||
@ray.remote
|
||||
def sneaky_actor_exec(obj, x):
|
||||
return ray.get(obj["actor"].exec.remote(obj["f"], x))
|
||||
|
||||
def local_fact(x):
|
||||
if x <= 0:
|
||||
return 1
|
||||
return x * local_fact(x - 1)
|
||||
|
||||
assert ray.get(fact.remote(7)) == local_fact(7)
|
||||
assert ray.get(func_exec.remote(fact, 8)) == local_fact(8)
|
||||
test_obj = {}
|
||||
test_obj["f"] = fact
|
||||
assert ray.get(sneaky_func_exec.remote(test_obj, 5)) == local_fact(5)
|
||||
actor_handle = ExecActor.remote()
|
||||
assert ray.get(actor_handle.exec.remote(fact, 7)) == local_fact(7)
|
||||
assert ray.get(func_actor_exec.remote(actor_handle, fact,
|
||||
10)) == local_fact(10)
|
||||
second_actor = ExecActor.remote()
|
||||
assert ray.get(actor_handle.exec_exec.remote(second_actor, fact,
|
||||
9)) == local_fact(9)
|
||||
test_actor_obj = {}
|
||||
test_actor_obj["actor"] = second_actor
|
||||
test_actor_obj["f"] = fact
|
||||
assert ray.get(sneaky_actor_exec.remote(test_actor_obj,
|
||||
4)) == local_fact(4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
Loading…
Add table
Reference in a new issue