mirror of
https://github.com/vale981/ray
synced 2025-03-09 04:46:38 -04:00
[ray_client] Passing actors to actors (#12585)
* start building tests around passing handles to handles Change-Id: Ie8c3de5c8ce789c3ec8d29f0702df80ba598279f * clean up the switch statements by moving to a method, implement state tranfer, extend test Change-Id: Ie7b6493db3a6c203d3a0b262b8fbacb90e5cdbc5 * passing Change-Id: Id88dc0a41da1c9d5ba68f754c5b57141aae47beb * flush out tests Change-Id: If77c0f586e9e99449d494be4e85f854e4a7a4952 * formatting Change-Id: I497c07cee70b52453b221ed4393f04f6f560061e * fix python3.6 and other attributes Change-Id: I5a2c5231e8a021184d9dfc3e346df7f71fc93257 * address documentation Change-Id: I049d841ed1f85b7350c17c05da4a4d81d5cb03df * formatting Change-Id: I6a2b32a2466ffc9f03fc91ac17901b9c1a49505c * use the pickled handle as the id bytes for actors Change-Id: I9ddcb41d614de65d42d6f0382fe0faa7ad2c2ade * pydoc Change-Id: I9b32a0f383d5ff5ac052e61929b7ae3e42a89fc5 * format Change-Id: Iac0010bb990a4025a98139ab88700030b2e9e7f5 * todos Change-Id: I7b550800cf7499403e8a17b77484bc46f20f0afc * tests Change-Id: If8ebf6a335baeb113c1332acc930c41a6b4f5384 * fix lint Change-Id: I019f41e0ec341d39bbbbd39aa43d9fb5f8b57cf0 * nits Change-Id: I2e6813d8db34f4ce008326faa095d414c10eee95 * add some tricky, python3.6-troublesome type checking Change-Id: Ib887fc943a6e7084002bc13dfbe113b69b4d9317
This commit is contained in:
parent
d534719af6
commit
dc4b5c7aa3
7 changed files with 491 additions and 130 deletions
|
@ -7,34 +7,88 @@ import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# _client_api has to be external to the API stub, below.
|
# About these global variables: Ray 1.0 uses exported module functions to
|
||||||
# Otherwise, ray.remote() that contains ray.remote()
|
# provide its API, and we need to match that. However, we want different
|
||||||
# contains a reference to the RayAPIStub, therefore a
|
# behaviors depending on where, exactly, in the client stack this is running.
|
||||||
# reference to the _client_api, and then tries to pickle
|
#
|
||||||
# the thing.
|
# The reason for these differences depends on what's being pickled and passed
|
||||||
|
# to functions, or functions inside functions. So there are three cases to care
|
||||||
|
# about
|
||||||
|
#
|
||||||
|
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
|
||||||
|
#
|
||||||
|
# * _client_api should be set if we're inside the client
|
||||||
|
# * _server_api should be set if we're inside the clientserver
|
||||||
|
# * Both will be set if we're running both (as in a test)
|
||||||
|
# * Neither should be set if we're inside the raylet (but we still need to shim
|
||||||
|
# from the client API surface to the Ray API)
|
||||||
|
#
|
||||||
|
# The job of RayAPIStub (below) delegates to the appropriate one of these
|
||||||
|
# depending on what's set or not. Then, all users importing the ray object
|
||||||
|
# from this package get the stub which routes them to the appropriate APIImpl.
|
||||||
_client_api: Optional[APIImpl] = None
|
_client_api: Optional[APIImpl] = None
|
||||||
|
_server_api: Optional[APIImpl] = None
|
||||||
|
|
||||||
|
# The reason for _is_server is a hack around the above comment while running
|
||||||
|
# tests. If we have both a client and a server trying to control these static
|
||||||
|
# variables then we need a way to decide which to use. In this case, both
|
||||||
|
# _client_api and _server_api are set.
|
||||||
|
# This boolean flips between the two
|
||||||
|
_is_server: bool = False
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def stash_api_for_tests(in_test: bool):
|
def stash_api_for_tests(in_test: bool):
|
||||||
api = None
|
global _is_server
|
||||||
|
is_server = _is_server
|
||||||
if in_test:
|
if in_test:
|
||||||
api = stash_api()
|
_is_server = True
|
||||||
yield api
|
yield _server_api
|
||||||
if in_test:
|
if in_test:
|
||||||
restore_api(api)
|
_is_server = is_server
|
||||||
|
|
||||||
|
|
||||||
def stash_api() -> Optional[APIImpl]:
|
def _set_client_api(val: Optional[APIImpl]):
|
||||||
global _client_api
|
global _client_api
|
||||||
a = _client_api
|
global _is_server
|
||||||
|
if _client_api is not None:
|
||||||
|
raise Exception("Trying to set more than one client API")
|
||||||
|
_client_api = val
|
||||||
|
_is_server = False
|
||||||
|
|
||||||
|
|
||||||
|
def _set_server_api(val: Optional[APIImpl]):
|
||||||
|
global _server_api
|
||||||
|
global _is_server
|
||||||
|
if _server_api is not None:
|
||||||
|
raise Exception("Trying to set more than one server API")
|
||||||
|
_server_api = val
|
||||||
|
_is_server = True
|
||||||
|
|
||||||
|
|
||||||
|
def reset_api():
|
||||||
|
global _client_api
|
||||||
|
global _server_api
|
||||||
|
global _is_server
|
||||||
_client_api = None
|
_client_api = None
|
||||||
return a
|
_server_api = None
|
||||||
|
_is_server = False
|
||||||
|
|
||||||
|
|
||||||
def restore_api(api: Optional[APIImpl]):
|
def _get_client_api() -> APIImpl:
|
||||||
global _client_api
|
global _client_api
|
||||||
_client_api = api
|
global _server_api
|
||||||
|
global _is_server
|
||||||
|
api = None
|
||||||
|
if _is_server:
|
||||||
|
api = _server_api
|
||||||
|
else:
|
||||||
|
api = _client_api
|
||||||
|
if api is None:
|
||||||
|
# We're inside a raylet worker
|
||||||
|
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
||||||
|
return CoreRayAPI()
|
||||||
|
return api
|
||||||
|
|
||||||
|
|
||||||
class RayAPIStub:
|
class RayAPIStub:
|
||||||
|
@ -43,11 +97,10 @@ class RayAPIStub:
|
||||||
secure: bool = False,
|
secure: bool = False,
|
||||||
metadata: List[Tuple[str, str]] = None,
|
metadata: List[Tuple[str, str]] = None,
|
||||||
stub=None):
|
stub=None):
|
||||||
global _client_api
|
|
||||||
from ray.experimental.client.worker import Worker
|
from ray.experimental.client.worker import Worker
|
||||||
_client_worker = Worker(
|
_client_worker = Worker(
|
||||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||||
_client_api = ClientAPI(_client_worker)
|
_set_client_api(ClientAPI(_client_worker))
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
global _client_api
|
global _client_api
|
||||||
|
@ -56,15 +109,9 @@ class RayAPIStub:
|
||||||
_client_api = None
|
_client_api = None
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
global _client_api
|
global _get_client_api
|
||||||
self.__check_client_api()
|
api = _get_client_api()
|
||||||
return getattr(_client_api, key)
|
return getattr(api, key)
|
||||||
|
|
||||||
def __check_client_api(self):
|
|
||||||
global _client_api
|
|
||||||
if _client_api is None:
|
|
||||||
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
|
||||||
_client_api = CoreRayAPI()
|
|
||||||
|
|
||||||
|
|
||||||
ray = RayAPIStub()
|
ray = RayAPIStub()
|
||||||
|
|
|
@ -11,35 +11,105 @@
|
||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ray.experimental.client.common import ClientStub
|
||||||
|
from ray.experimental.client.common import ClientObjectRef
|
||||||
|
from ray._raylet import ObjectRef
|
||||||
|
|
||||||
|
# Use the imports for type checking. This is a python 3.6 limitation.
|
||||||
|
# See https://www.python.org/dev/peps/pep-0563/
|
||||||
|
PutType = Union[ClientObjectRef, ObjectRef]
|
||||||
|
|
||||||
|
|
||||||
class APIImpl(ABC):
|
class APIImpl(ABC):
|
||||||
|
"""
|
||||||
|
APIImpl is the interface to implement for whichever version of the core
|
||||||
|
Ray API that needs abstracting when run in client mode.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, *args, **kwargs):
|
def get(self, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
get is the hook stub passed on to replace `ray.get`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: opaque arguments
|
||||||
|
kwargs: opaque keyword arguments
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def put(self, *args, **kwargs):
|
def put(self, vals: Any, *args,
|
||||||
|
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
|
||||||
|
"""
|
||||||
|
put is the hook stub passed on to replace `ray.put`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vals: The value or list of values to `put`.
|
||||||
|
args: opaque arguments
|
||||||
|
kwargs: opaque keyword arguments
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def wait(self, *args, **kwargs):
|
def wait(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
wait is the hook stub passed on to replace `ray.wait`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: opaque arguments
|
||||||
|
kwargs: opaque keyword arguments
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
remote is the hook stub passed on to replace `ray.remote`.
|
||||||
|
|
||||||
|
This sets up remote functions or actors, as the decorator,
|
||||||
|
but does not execute them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: opaque arguments
|
||||||
|
kwargs: opaque keyword arguments
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def call_remote(self, f, kind, *args, **kwargs):
|
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||||
|
"""
|
||||||
|
call_remote is called by stub objects to execute them remotely.
|
||||||
|
|
||||||
|
This is used by stub objects in situations where they're called
|
||||||
|
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
|
||||||
|
This allows the client stub objects to delegate execution to be
|
||||||
|
implemented in the most effective way whether it's in the client,
|
||||||
|
clientserver, or raylet worker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: The Client-side stub reference to a remote object
|
||||||
|
args: opaque arguments
|
||||||
|
kwargs: opaque keyword arguments
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def close(self, *args, **kwargs):
|
def close(self) -> None:
|
||||||
|
"""
|
||||||
|
close cleans up an API connection by closing any channels or
|
||||||
|
shutting down any servers gracefully.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ClientAPI(APIImpl):
|
class ClientAPI(APIImpl):
|
||||||
|
"""
|
||||||
|
The Client-side methods corresponding to the ray API. Delegates
|
||||||
|
to the Client Worker that contains the connection to the ClientServer.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, worker):
|
def __init__(self, worker):
|
||||||
self.worker = worker
|
self.worker = worker
|
||||||
|
|
||||||
|
@ -55,10 +125,10 @@ class ClientAPI(APIImpl):
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
return self.worker.remote(*args, **kwargs)
|
return self.worker.remote(*args, **kwargs)
|
||||||
|
|
||||||
def call_remote(self, f, kind, *args, **kwargs):
|
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||||
return self.worker.call_remote(f, kind, *args, **kwargs)
|
return self.worker.call_remote(instance, *args, **kwargs)
|
||||||
|
|
||||||
def close(self, *args, **kwargs):
|
def close(self) -> None:
|
||||||
return self.worker.close()
|
return self.worker.close()
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
from ray.experimental.client import ray
|
from ray.experimental.client import ray
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
from ray import cloudpickle
|
from ray import cloudpickle
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +18,9 @@ class ClientBaseRef:
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.id == other.id
|
return self.id == other.id
|
||||||
|
|
||||||
|
def binary(self):
|
||||||
|
return self.id
|
||||||
|
|
||||||
|
|
||||||
class ClientObjectRef(ClientBaseRef):
|
class ClientObjectRef(ClientBaseRef):
|
||||||
pass
|
pass
|
||||||
|
@ -26,74 +30,222 @@ class ClientActorRef(ClientBaseRef):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ClientRemoteFunc:
|
class ClientStub:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ClientRemoteFunc(ClientStub):
|
||||||
|
"""
|
||||||
|
A stub created on the Ray Client to represent a remote
|
||||||
|
function that can be exectued on the cluster.
|
||||||
|
|
||||||
|
This class is allowed to be passed around between remote functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_func: The actual function to execute remotely
|
||||||
|
_name: The original name of the function
|
||||||
|
_ref: The ClientObjectRef of the pickled code of the function, _func
|
||||||
|
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
|
||||||
|
for this object
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, f):
|
def __init__(self, f):
|
||||||
self._func = f
|
self._func = f
|
||||||
self._name = f.__name__
|
self._name = f.__name__
|
||||||
self.id = None
|
self.id = None
|
||||||
self._raylet_remote_func = None
|
|
||||||
|
# self._ref can be lazily instantiated. Rather than eagerly creating
|
||||||
|
# function data objects in the server we can put them just before we
|
||||||
|
# execute the function, especially in cases where many @ray.remote
|
||||||
|
# functions exist in a library and only a handful are ever executed by
|
||||||
|
# a user of the library.
|
||||||
|
#
|
||||||
|
# TODO(barakmich): This ref might actually be better as a serialized
|
||||||
|
# ObjectRef. This requires being able to serialize the ref without
|
||||||
|
# pinning it (as the lifetime of the ref is tied with the server, not
|
||||||
|
# the client)
|
||||||
|
self._ref = None
|
||||||
|
self._raylet_remote = None
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
raise TypeError(f"Remote function cannot be called directly. "
|
raise TypeError(f"Remote function cannot be called directly. "
|
||||||
"Use {self._name}.remote method instead")
|
"Use {self._name}.remote method instead")
|
||||||
|
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
|
return ray.call_remote(self, *args, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
def _get_ray_remote_impl(self):
|
||||||
|
if self._raylet_remote is None:
|
||||||
|
self._raylet_remote = ray.remote(self._func)
|
||||||
|
return self._raylet_remote
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
|
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
||||||
|
|
||||||
|
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||||
|
if self._ref is None:
|
||||||
|
self._ref = ray.put(self._func)
|
||||||
|
task = ray_client_pb2.ClientTask()
|
||||||
|
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||||
|
task.name = self._name
|
||||||
|
task.payload_id = self._ref.id
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
class ClientActorClass:
|
class ClientActorClass(ClientStub):
|
||||||
|
""" A stub created on the Ray Client to represent an actor class.
|
||||||
|
|
||||||
|
It is wrapped by ray.remote and can be executed on the cluster.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor_cls: The actual class to execute remotely
|
||||||
|
_name: The original name of the class
|
||||||
|
_ref: The ClientObjectRef of the pickled `actor_cls`
|
||||||
|
_raylet_remote: The Raylet-side ray.ActorClass for this object
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, actor_cls):
|
def __init__(self, actor_cls):
|
||||||
self.actor_cls = actor_cls
|
self.actor_cls = actor_cls
|
||||||
self._name = actor_cls.__name__
|
self._name = actor_cls.__name__
|
||||||
|
self._ref = None
|
||||||
|
self._raylet_remote = None
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
||||||
"Use {self._name}.remote() instead")
|
"Use {self._name}.remote() instead")
|
||||||
|
|
||||||
|
def __getstate__(self) -> Dict:
|
||||||
|
state = {
|
||||||
|
"actor_cls": self.actor_cls,
|
||||||
|
"_name": self._name,
|
||||||
|
"_ref": self._ref,
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict) -> None:
|
||||||
|
self.actor_cls = state["actor_cls"]
|
||||||
|
self._name = state["_name"]
|
||||||
|
self._ref = state["_ref"]
|
||||||
|
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
# Actually instantiate the actor
|
# Actually instantiate the actor
|
||||||
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
|
ref = ray.call_remote(self, *args, **kwargs)
|
||||||
**kwargs)
|
return ClientActorHandle(ClientActorRef(ref.id), self)
|
||||||
return ClientActorHandle(ref, self)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "ClientRemoteActor(%s, %s)" % (self._name, self.id)
|
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||||
|
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key):
|
||||||
|
if key not in self.__dict__:
|
||||||
|
raise AttributeError("Not a class attribute")
|
||||||
raise NotImplementedError("static methods")
|
raise NotImplementedError("static methods")
|
||||||
|
|
||||||
|
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||||
|
if self._ref is None:
|
||||||
|
self._ref = ray.put(self.actor_cls)
|
||||||
|
task = ray_client_pb2.ClientTask()
|
||||||
|
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||||
|
task.name = self._name
|
||||||
|
task.payload_id = self._ref.id
|
||||||
|
return task
|
||||||
|
|
||||||
class ClientActorHandle:
|
|
||||||
def __init__(self, actor_id: ClientActorRef,
|
class ClientActorHandle(ClientStub):
|
||||||
|
"""Client-side stub for instantiated actor.
|
||||||
|
|
||||||
|
A stub created on the Ray Client to represent a remote actor that
|
||||||
|
has been started on the cluster. This class is allowed to be passed
|
||||||
|
around between remote functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor_ref: A reference to the running actor given to the client. This
|
||||||
|
is a serialized version of the actual handle as an opaque token.
|
||||||
|
actor_class: A reference to the ClientActorClass that this actor was
|
||||||
|
instantiated from.
|
||||||
|
_real_actor_handle: Cached copy of the Raylet-side
|
||||||
|
ray.actor.ActorHandle contained in the actor_id ref.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, actor_ref: ClientActorRef,
|
||||||
actor_class: ClientActorClass):
|
actor_class: ClientActorClass):
|
||||||
self.actor_id = actor_id
|
self.actor_ref = actor_ref
|
||||||
self.actor_class = actor_class
|
self.actor_class = actor_class
|
||||||
|
self._real_actor_handle = None
|
||||||
|
|
||||||
|
def _get_ray_remote_impl(self):
|
||||||
|
if self._real_actor_handle is None:
|
||||||
|
self._real_actor_handle = cloudpickle.loads(self.actor_ref.id)
|
||||||
|
return self._real_actor_handle
|
||||||
|
|
||||||
|
def __getstate__(self) -> Dict:
|
||||||
|
state = {
|
||||||
|
"actor_ref": self.actor_ref,
|
||||||
|
"actor_class": self.actor_class,
|
||||||
|
"_real_actor_handle": self._real_actor_handle,
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict) -> None:
|
||||||
|
self.actor_ref = state["actor_ref"]
|
||||||
|
self.actor_class = state["actor_class"]
|
||||||
|
self._real_actor_handle = state["_real_actor_handle"]
|
||||||
|
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key):
|
||||||
return ClientRemoteMethod(self, key)
|
return ClientRemoteMethod(self, key)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "ClientActorHandle(%s)" % (self.actor_ref.id.hex())
|
||||||
|
|
||||||
|
|
||||||
|
class ClientRemoteMethod(ClientStub):
|
||||||
|
"""A stub for a method on a remote actor.
|
||||||
|
|
||||||
|
Can be annotated with exection options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor_handle: A reference to the ClientActorHandle that generated
|
||||||
|
this method and will have this method called upon it.
|
||||||
|
method_name: The name of this method
|
||||||
|
"""
|
||||||
|
|
||||||
class ClientRemoteMethod:
|
|
||||||
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
|
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
|
||||||
self.actor_handle = actor_handle
|
self.actor_handle = actor_handle
|
||||||
self.method_name = method_name
|
self.method_name = method_name
|
||||||
self._name = "%s.%s" % (self.actor_handle.actor_class._name,
|
|
||||||
self.method_name)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
raise TypeError(f"Remote method cannot be called directly. "
|
raise TypeError(f"Remote method cannot be called directly. "
|
||||||
"Use {self._name}.remote() instead")
|
"Use {self._name}.remote() instead")
|
||||||
|
|
||||||
|
def _get_ray_remote_impl(self):
|
||||||
|
return getattr(self.actor_handle._get_ray_remote_impl(),
|
||||||
|
self.method_name)
|
||||||
|
|
||||||
|
def __getstate__(self) -> Dict:
|
||||||
|
state = {
|
||||||
|
"actor_handle": self.actor_handle,
|
||||||
|
"method_name": self.method_name,
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict) -> None:
|
||||||
|
self.actor_handle = state["actor_handle"]
|
||||||
|
self.method_name = state["method_name"]
|
||||||
|
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
|
return ray.call_remote(self, *args, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "ClientRemoteMethod(%s, %s)" % (self._name, self.actor_id)
|
name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||||
|
self.method_name)
|
||||||
|
return "ClientRemoteMethod(%s, %s)" % (name,
|
||||||
|
self.actor_handle.actor_id)
|
||||||
|
|
||||||
|
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||||
|
task = ray_client_pb2.ClientTask()
|
||||||
|
task.type = ray_client_pb2.ClientTask.METHOD
|
||||||
|
task.name = self.method_name
|
||||||
|
task.payload_id = self.actor_handle.actor_ref.id
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
def convert_from_arg(pb) -> Any:
|
def convert_from_arg(pb) -> Any:
|
||||||
|
|
|
@ -7,18 +7,29 @@
|
||||||
# While the stub is trivial, it allows us to check that the calls we're
|
# While the stub is trivial, it allows us to check that the calls we're
|
||||||
# making into the core-ray module are contained and well-defined.
|
# making into the core-ray module are contained and well-defined.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
from ray.experimental.client.api import APIImpl
|
from ray.experimental.client.api import APIImpl
|
||||||
from ray.experimental.client.common import ClientRemoteFunc
|
from ray.experimental.client.common import ClientObjectRef
|
||||||
|
from ray.experimental.client.common import ClientStub
|
||||||
|
|
||||||
|
|
||||||
class CoreRayAPI(APIImpl):
|
class CoreRayAPI(APIImpl):
|
||||||
|
"""
|
||||||
|
Implements the equivalent client-side Ray API by simply passing along to
|
||||||
|
the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
|
||||||
|
to core ray when passed client stubs.
|
||||||
|
"""
|
||||||
|
|
||||||
def get(self, *args, **kwargs):
|
def get(self, *args, **kwargs):
|
||||||
return ray.get(*args, **kwargs)
|
return ray.get(*args, **kwargs)
|
||||||
|
|
||||||
def put(self, *args, **kwargs):
|
def put(self, vals: Any, *args,
|
||||||
return ray.put(*args, **kwargs)
|
**kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
|
||||||
|
return ray.put(vals, *args, **kwargs)
|
||||||
|
|
||||||
def wait(self, *args, **kwargs):
|
def wait(self, *args, **kwargs):
|
||||||
return ray.wait(*args, **kwargs)
|
return ray.wait(*args, **kwargs)
|
||||||
|
@ -26,12 +37,10 @@ class CoreRayAPI(APIImpl):
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
return ray.remote(*args, **kwargs)
|
return ray.remote(*args, **kwargs)
|
||||||
|
|
||||||
def call_remote(self, f: ClientRemoteFunc, kind: int, *args, **kwargs):
|
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||||
if f._raylet_remote_func is None:
|
return instance._get_ray_remote_impl().remote(*args, **kwargs)
|
||||||
f._raylet_remote_func = ray.remote(f._func)
|
|
||||||
return f._raylet_remote_func.remote(*args, **kwargs)
|
|
||||||
|
|
||||||
def close(self, *args, **kwargs):
|
def close(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
||||||
|
@ -39,3 +48,38 @@ class CoreRayAPI(APIImpl):
|
||||||
# doesn't currently support them.
|
# doesn't currently support them.
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
return getattr(ray, key)
|
return getattr(ray, key)
|
||||||
|
|
||||||
|
|
||||||
|
class RayServerAPI(CoreRayAPI):
|
||||||
|
"""
|
||||||
|
Ray Client server-side API shim. By default, simply calls the default Core
|
||||||
|
Ray API calls, but also accepts scheduling calls from functions running
|
||||||
|
inside of other remote functions that need to create more work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, server_instance):
|
||||||
|
self.server = server_instance
|
||||||
|
|
||||||
|
# Wrap single item into list if needed before calling server put.
|
||||||
|
def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
|
||||||
|
to_put = []
|
||||||
|
single = False
|
||||||
|
if isinstance(vals, list):
|
||||||
|
to_put = vals
|
||||||
|
else:
|
||||||
|
single = True
|
||||||
|
to_put.append(vals)
|
||||||
|
|
||||||
|
out = [self._put(x) for x in to_put]
|
||||||
|
if single:
|
||||||
|
out = out[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _put(self, val: Any):
|
||||||
|
resp = self.server._put_and_retain_obj(val)
|
||||||
|
return ClientObjectRef(resp.id)
|
||||||
|
|
||||||
|
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||||
|
task = instance._prepare_client_task()
|
||||||
|
ticket = self.server.Schedule(task, prepared_args=args)
|
||||||
|
return ClientObjectRef(ticket.return_id)
|
||||||
|
|
|
@ -7,10 +7,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
import time
|
import time
|
||||||
import inspect
|
import inspect
|
||||||
from ray.experimental.client import stash_api_for_tests
|
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||||
from ray.experimental.client.common import convert_from_arg
|
from ray.experimental.client.common import convert_from_arg
|
||||||
from ray.experimental.client.common import ClientObjectRef
|
from ray.experimental.client.common import ClientObjectRef
|
||||||
from ray.experimental.client.common import ClientRemoteFunc
|
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,12 +32,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||||
item_ser = cloudpickle.dumps(item)
|
item_ser = cloudpickle.dumps(item)
|
||||||
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||||
|
|
||||||
def PutObject(self, request, context=None):
|
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
|
||||||
obj = cloudpickle.loads(request.data)
|
obj = cloudpickle.loads(request.data)
|
||||||
|
objectref = self._put_and_retain_obj(obj)
|
||||||
|
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||||
|
|
||||||
|
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
|
||||||
objectref = ray.put(obj)
|
objectref = ray.put(obj)
|
||||||
self.object_refs[objectref.binary()] = objectref
|
self.object_refs[objectref.binary()] = objectref
|
||||||
logger.info("put: %s" % objectref)
|
logger.info("put: %s" % objectref)
|
||||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
return objectref
|
||||||
|
|
||||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||||
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
|
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
|
||||||
|
@ -70,70 +74,83 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||||
ready_object_ids=ready_object_ids,
|
ready_object_ids=ready_object_ids,
|
||||||
remaining_object_ids=remaining_object_ids)
|
remaining_object_ids=remaining_object_ids)
|
||||||
|
|
||||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
def Schedule(self, task, context=None,
|
||||||
|
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||||
logger.info("schedule: %s %s" %
|
logger.info("schedule: %s %s" %
|
||||||
(task.name,
|
(task.name,
|
||||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||||
return self._schedule_function(task, context)
|
return self._schedule_function(task, context, prepared_args)
|
||||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||||
return self._schedule_actor(task, context)
|
return self._schedule_actor(task, context, prepared_args)
|
||||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||||
return self._schedule_method(task, context)
|
return self._schedule_method(task, context, prepared_args)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Unimplemented Schedule task type: %s" %
|
"Unimplemented Schedule task type: %s" %
|
||||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||||
|
|
||||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
def _schedule_method(
|
||||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
self,
|
||||||
|
task: ray_client_pb2.ClientTask,
|
||||||
|
context=None,
|
||||||
|
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||||
actor_handle = self.actor_refs.get(task.payload_id)
|
actor_handle = self.actor_refs.get(task.payload_id)
|
||||||
if actor_handle is None:
|
if actor_handle is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Can't run an actor the server doesn't have a handle for")
|
"Can't run an actor the server doesn't have a handle for")
|
||||||
arglist = _convert_args(task.args)
|
arglist = _convert_args(task.args, prepared_args)
|
||||||
with stash_api_for_tests(self._test_mode):
|
with stash_api_for_tests(self._test_mode):
|
||||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||||
self.object_refs[output.binary()] = output
|
self.object_refs[output.binary()] = output
|
||||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||||
|
|
||||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
|
def _schedule_actor(self,
|
||||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
task: ray_client_pb2.ClientTask,
|
||||||
|
context=None,
|
||||||
|
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||||
with stash_api_for_tests(self._test_mode):
|
with stash_api_for_tests(self._test_mode):
|
||||||
if task.payload_id not in self.registered_actor_classes:
|
if task.payload_id not in self.registered_actor_classes:
|
||||||
actor_class_ref = self.object_refs[task.payload_id]
|
actor_class_ref = self.object_refs[task.payload_id]
|
||||||
actor_class = ray.get(actor_class_ref)
|
actor_class = ray.get(actor_class_ref)
|
||||||
if not inspect.isclass(actor_class):
|
if not inspect.isclass(actor_class):
|
||||||
raise Exception("Attempting to schedule actor that "
|
raise Exception("Attempting to schedule actor that "
|
||||||
"isn't a ClientActorClass.")
|
"isn't a class.")
|
||||||
reg_class = ray.remote(actor_class)
|
reg_class = ray.remote(actor_class)
|
||||||
self.registered_actor_classes[task.payload_id] = reg_class
|
self.registered_actor_classes[task.payload_id] = reg_class
|
||||||
remote_class = self.registered_actor_classes[task.payload_id]
|
remote_class = self.registered_actor_classes[task.payload_id]
|
||||||
arglist = _convert_args(task.args)
|
arglist = _convert_args(task.args, prepared_args)
|
||||||
actor = remote_class.remote(*arglist)
|
actor = remote_class.remote(*arglist)
|
||||||
actor_ref = actor._actor_id
|
actorhandle = cloudpickle.dumps(actor)
|
||||||
self.actor_refs[actor_ref.binary()] = actor
|
self.actor_refs[actorhandle] = actor
|
||||||
return ray_client_pb2.ClientTaskTicket(return_id=actor_ref.binary())
|
return ray_client_pb2.ClientTaskTicket(return_id=actorhandle)
|
||||||
|
|
||||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
|
def _schedule_function(
|
||||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
self,
|
||||||
|
task: ray_client_pb2.ClientTask,
|
||||||
|
context=None,
|
||||||
|
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||||
if task.payload_id not in self.function_refs:
|
if task.payload_id not in self.function_refs:
|
||||||
funcref = self.object_refs[task.payload_id]
|
funcref = self.object_refs[task.payload_id]
|
||||||
func = ray.get(funcref)
|
func = ray.get(funcref)
|
||||||
if not isinstance(func, ClientRemoteFunc):
|
if not inspect.isfunction(func):
|
||||||
raise Exception("Attempting to schedule function that "
|
raise Exception("Attempting to schedule function that "
|
||||||
"isn't a ClientRemoteFunc.")
|
"isn't a function.")
|
||||||
self.function_refs[task.payload_id] = func
|
self.function_refs[task.payload_id] = ray.remote(func)
|
||||||
remote_func = self.function_refs[task.payload_id]
|
remote_func = self.function_refs[task.payload_id]
|
||||||
arglist = _convert_args(task.args)
|
arglist = _convert_args(task.args, prepared_args)
|
||||||
# Prepare call if we're in a test
|
# Prepare call if we're in a test
|
||||||
with stash_api_for_tests(self._test_mode):
|
with stash_api_for_tests(self._test_mode):
|
||||||
output = remote_func.remote(*arglist)
|
output = remote_func.remote(*arglist)
|
||||||
|
if output.binary() in self.object_refs:
|
||||||
|
raise Exception("already found it")
|
||||||
self.object_refs[output.binary()] = output
|
self.object_refs[output.binary()] = output
|
||||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||||
|
|
||||||
|
|
||||||
def _convert_args(arg_list):
|
def _convert_args(arg_list, prepared_args=None):
|
||||||
|
if prepared_args is not None:
|
||||||
|
return prepared_args
|
||||||
out = []
|
out = []
|
||||||
for arg in arg_list:
|
for arg in arg_list:
|
||||||
t = convert_from_arg(arg)
|
t = convert_from_arg(arg)
|
||||||
|
@ -147,6 +164,7 @@ def _convert_args(arg_list):
|
||||||
def serve(connection_str, test_mode=False):
|
def serve(connection_str, test_mode=False):
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||||
task_servicer = RayletServicer(test_mode=test_mode)
|
task_servicer = RayletServicer(test_mode=test_mode)
|
||||||
|
_set_server_api(RayServerAPI(task_servicer))
|
||||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||||
task_servicer, server)
|
task_servicer, server)
|
||||||
server.add_insecure_port(connection_str)
|
server.add_insecure_port(connection_str)
|
||||||
|
|
|
@ -3,6 +3,7 @@ It implements the Ray API functions that are forwarded through grpc calls
|
||||||
to the server.
|
to the server.
|
||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
@ -14,11 +15,11 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
from ray.experimental.client.common import convert_to_arg
|
from ray.experimental.client.common import convert_to_arg
|
||||||
from ray.experimental.client.common import ClientObjectRef
|
from ray.experimental.client.common import ClientObjectRef
|
||||||
from ray.experimental.client.common import ClientActorRef
|
|
||||||
from ray.experimental.client.common import ClientActorClass
|
from ray.experimental.client.common import ClientActorClass
|
||||||
from ray.experimental.client.common import ClientRemoteMethod
|
|
||||||
from ray.experimental.client.common import ClientRemoteFunc
|
from ray.experimental.client.common import ClientRemoteFunc
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -130,50 +131,14 @@ class Worker:
|
||||||
raise TypeError("The @ray.remote decorator must be applied to "
|
raise TypeError("The @ray.remote decorator must be applied to "
|
||||||
"either a function or to a class.")
|
"either a function or to a class.")
|
||||||
|
|
||||||
def call_remote(self, instance, kind, *args, **kwargs):
|
def call_remote(self, instance, *args, **kwargs):
|
||||||
ticket = None
|
task = instance._prepare_client_task()
|
||||||
if kind == ray_client_pb2.ClientTask.FUNCTION:
|
for arg in args:
|
||||||
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
|
pb_arg = convert_to_arg(arg)
|
||||||
elif kind == ray_client_pb2.ClientTask.ACTOR:
|
task.args.append(pb_arg)
|
||||||
ticket = self._put_and_schedule(instance, kind, *args, **kwargs)
|
logging.debug("Scheduling %s" % task)
|
||||||
return ClientActorRef(ticket.return_id)
|
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||||
elif kind == ray_client_pb2.ClientTask.METHOD:
|
|
||||||
ticket = self._call_method(instance, *args, **kwargs)
|
|
||||||
|
|
||||||
if ticket is None:
|
|
||||||
raise Exception(
|
|
||||||
"Couldn't call_remote on %s for type %s" % (instance, kind))
|
|
||||||
return ClientObjectRef(ticket.return_id)
|
return ClientObjectRef(ticket.return_id)
|
||||||
|
|
||||||
def _call_method(self, instance: ClientRemoteMethod, *args, **kwargs):
|
|
||||||
if not isinstance(instance, ClientRemoteMethod):
|
|
||||||
raise TypeError("Client not passing a ClientRemoteMethod stub")
|
|
||||||
task = ray_client_pb2.ClientTask()
|
|
||||||
task.type = ray_client_pb2.ClientTask.METHOD
|
|
||||||
task.name = instance.method_name
|
|
||||||
task.payload_id = instance.actor_handle.actor_id.id
|
|
||||||
for arg in args:
|
|
||||||
pb_arg = convert_to_arg(arg)
|
|
||||||
task.args.append(pb_arg)
|
|
||||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
|
||||||
return ticket
|
|
||||||
|
|
||||||
def _put_and_schedule(self, item, task_type, *args, **kwargs):
|
|
||||||
if isinstance(item, ClientRemoteFunc):
|
|
||||||
ref = self._put(item)
|
|
||||||
elif isinstance(item, ClientActorClass):
|
|
||||||
ref = self._put(item.actor_cls)
|
|
||||||
else:
|
|
||||||
raise TypeError("Client not passing a ClientRemoteFunc stub")
|
|
||||||
task = ray_client_pb2.ClientTask()
|
|
||||||
task.type = task_type
|
|
||||||
task.name = item._name
|
|
||||||
task.payload_id = ref.id
|
|
||||||
for arg in args:
|
|
||||||
pb_arg = convert_to_arg(arg)
|
|
||||||
task.args.append(pb_arg)
|
|
||||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
|
||||||
return ticket
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.channel.close()
|
self.channel.close()
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import ray.experimental.client.server.server as ray_client_server
|
import ray.experimental.client.server.server as ray_client_server
|
||||||
from ray.experimental.client import ray
|
from ray.experimental.client import ray, reset_api
|
||||||
from ray.experimental.client.common import ClientObjectRef
|
from ray.experimental.client.common import ClientObjectRef
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ def ray_start_client_server():
|
||||||
yield ray
|
yield ray
|
||||||
ray.disconnect()
|
ray.disconnect()
|
||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
reset_api()
|
||||||
|
|
||||||
|
|
||||||
def test_real_ray_fallback(ray_start_regular_shared):
|
def test_real_ray_fallback(ray_start_regular_shared):
|
||||||
|
@ -170,6 +171,70 @@ def test_basic_actor(ray_start_regular_shared):
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_pass_handles(ray_start_regular_shared):
|
||||||
|
"""
|
||||||
|
Test that passing client handles to actors and functions to remote actors
|
||||||
|
in functions (on the server or raylet side) works transparently to the
|
||||||
|
caller.
|
||||||
|
"""
|
||||||
|
with ray_start_client_server() as ray:
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class ExecActor:
|
||||||
|
def exec(self, f, x):
|
||||||
|
return ray.get(f.remote(x))
|
||||||
|
|
||||||
|
def exec_exec(self, actor, f, x):
|
||||||
|
return ray.get(actor.exec.remote(f, x))
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def fact(x):
|
||||||
|
out = 1
|
||||||
|
while x > 0:
|
||||||
|
out = out * x
|
||||||
|
x -= 1
|
||||||
|
return out
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def func_exec(f, x):
|
||||||
|
return ray.get(f.remote(x))
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def func_actor_exec(actor, f, x):
|
||||||
|
return ray.get(actor.exec.remote(f, x))
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def sneaky_func_exec(obj, x):
|
||||||
|
return ray.get(obj["f"].remote(x))
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def sneaky_actor_exec(obj, x):
|
||||||
|
return ray.get(obj["actor"].exec.remote(obj["f"], x))
|
||||||
|
|
||||||
|
def local_fact(x):
|
||||||
|
if x <= 0:
|
||||||
|
return 1
|
||||||
|
return x * local_fact(x - 1)
|
||||||
|
|
||||||
|
assert ray.get(fact.remote(7)) == local_fact(7)
|
||||||
|
assert ray.get(func_exec.remote(fact, 8)) == local_fact(8)
|
||||||
|
test_obj = {}
|
||||||
|
test_obj["f"] = fact
|
||||||
|
assert ray.get(sneaky_func_exec.remote(test_obj, 5)) == local_fact(5)
|
||||||
|
actor_handle = ExecActor.remote()
|
||||||
|
assert ray.get(actor_handle.exec.remote(fact, 7)) == local_fact(7)
|
||||||
|
assert ray.get(func_actor_exec.remote(actor_handle, fact,
|
||||||
|
10)) == local_fact(10)
|
||||||
|
second_actor = ExecActor.remote()
|
||||||
|
assert ray.get(actor_handle.exec_exec.remote(second_actor, fact,
|
||||||
|
9)) == local_fact(9)
|
||||||
|
test_actor_obj = {}
|
||||||
|
test_actor_obj["actor"] = second_actor
|
||||||
|
test_actor_obj["f"] = fact
|
||||||
|
assert ray.get(sneaky_actor_exec.remote(test_actor_obj,
|
||||||
|
4)) == local_fact(4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
Loading…
Add table
Reference in a new issue