mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[ray_client]: Implement object retain/release and Data Streaming API (#12818)
This commit is contained in:
parent
55ae567f7a
commit
5cfa1934e4
15 changed files with 1000 additions and 310 deletions
|
@ -91,15 +91,22 @@ def _get_client_api() -> APIImpl:
|
|||
return api
|
||||
|
||||
|
||||
def _get_server_instance():
    """Return the server held by the registered server API, if any.

    Used inside tests to inspect the running server. Returns ``None`` when
    no server API has been registered.
    """
    global _server_api
    if _server_api is None:
        return None
    return _server_api.server
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
stub=None) -> None:
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||
_client_worker = Worker(conn_str, secure=secure, metadata=metadata)
|
||||
_set_client_api(ClientAPI(_client_worker))
|
||||
|
||||
def disconnect(self):
|
||||
|
@ -113,6 +120,10 @@ class RayAPIStub:
|
|||
api = _get_client_api()
|
||||
return getattr(api, key)
|
||||
|
||||
def is_connected(self) -> bool:
    """Report whether a client API is currently registered."""
    # Reading a module-level name needs no ``global`` declaration.
    connected = _client_api is not None
    return connected
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
||||
|
|
|
@ -138,6 +138,31 @@ class APIImpl(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def call_release(self, id: bytes) -> None:
    """Try to release an object reference.

    A client reference releases itself when it is destructed. That may
    opportunistically push a notification over the datachannel so the
    server drops the reference it is holding for the object.

    Args:
        id: The id of the reference to release on the server side.
    """
|
||||
|
||||
@abstractmethod
def call_retain(self, id: bytes) -> None:
    """Try to retain a client object reference.

    Bumps the client-side reference count so the client worker does not
    attempt to release the server reference.

    Args:
        id: The id of the reference to retain on the client side.
    """
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
"""
|
||||
|
@ -163,6 +188,12 @@ class ClientAPI(APIImpl):
|
|||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
    """Forward a remote call on ``instance`` to the underlying worker."""
    worker = self.worker
    return worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
def call_release(self, id: bytes) -> None:
    """Forward a release request for reference ``id`` to the worker."""
    worker = self.worker
    return worker.call_release(id)
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
    """Forward a retain request for reference ``id`` to the worker."""
    worker = self.worker
    return worker.call_retain(id)
|
||||
|
||||
def close(self) -> None:
    """Close the underlying worker and pass its result through."""
    worker = self.worker
    return worker.close()
|
||||
|
||||
|
|
123
python/ray/experimental/client/client_pickler.py
Normal file
123
python/ray/experimental/client/client_pickler.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
"""
|
||||
Implements the client side of the client/server pickling protocol.
|
||||
|
||||
All ray client client/server data transfer happens through this pickling
|
||||
protocol. The model is as follows:
|
||||
|
||||
* All Client objects (eg ClientObjectRef) always live on the client and
|
||||
are never represented in the server
|
||||
* All Ray objects (eg, ray.ObjectRef) always live on the server and are
|
||||
never returned to the client
|
||||
* In order to translate between these two references, PickleStub tuples
|
||||
are generated as persistent ids in the data blobs during the pickling
|
||||
and unpickling of these objects.
|
||||
|
||||
The PickleStubs have just enough information to find or generate their
|
||||
associated partner object on either side.
|
||||
|
||||
This also has the advantage of avoiding predefined pickle behavior for ray
|
||||
objects, which may include ray internal reference counting.
|
||||
|
||||
ClientPickler dumps things from the client into the appropriate stubs
|
||||
ServerUnpickler loads stubs from the server into their client counterparts.
|
||||
"""
|
||||
|
||||
import cloudpickle
|
||||
import io
|
||||
import sys
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import Any
|
||||
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import SelfReferenceSentinel
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
import pickle5 as pickle # noqa: F401
|
||||
except ImportError:
|
||||
import pickle # noqa: F401
|
||||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str),
|
||||
("ref_id", bytes)])
|
||||
|
||||
|
||||
class ClientPickler(cloudpickle.CloudPickler):
    """Pickler that swaps client-side handles for ``PickleStub`` tuples.

    Client object refs, actor handles and remote functions are written to
    the stream as persistent ids tagged with this client's id rather than
    being pickled by value.
    """

    def __init__(self, client_id, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client_id = client_id

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for client handles, else ``None``."""
        if isinstance(obj, ClientObjectRef):
            return PickleStub("Object", self.client_id, obj.id)

        if isinstance(obj, ClientActorHandle):
            return PickleStub("Actor", self.client_id, obj._actor_id)

        if not isinstance(obj, ClientRemoteFunc):
            # Anything else is pickled by value.
            return None

        # TODO(barakmich): This is going to have trouble with mutually
        # recursive functions that haven't, as yet, been executed. It's
        # relatively doable (keep track of intermediate refs in progress
        # with ensure_ref and return appropriately) But punting for now.
        if obj._ref is None:
            obj._ensure_ref()
        if type(obj._ref) is SelfReferenceSentinel:
            # The function is being pickled while its own upload is still
            # in progress, so emit a marker instead of a real ref id.
            return PickleStub("RemoteFuncSelfReference", self.client_id, b"")
        return PickleStub("RemoteFunc", self.client_id, obj._ref.id)
|
||||
|
||||
|
||||
class ServerUnpickler(pickle.Unpickler):
    """Unpickler that turns server-sent ``PickleStub`` ids into client objects."""

    def persistent_load(self, pid):
        """Build the client-side counterpart for a persistent id.

        Args:
            pid: The ``PickleStub`` found in the pickle stream.

        Returns:
            A ``ClientObjectRef`` for "Object" stubs, or a
            ``ClientActorHandle`` for "Actor" stubs.

        Raises:
            NotImplementedError: If the stub type is not one the client
                knows how to rebuild.
        """
        assert isinstance(pid, PickleStub)
        if pid.type == "Object":
            return ClientObjectRef(id=pid.ref_id)
        elif pid.type == "Actor":
            # ClientActorHandle.__init__ takes (actor_ref, actor_class).
            # The stub carries only the actor id, so no class is available
            # here; pass None explicitly instead of omitting the argument,
            # which would raise TypeError.
            return ClientActorHandle(
                ClientActorRef(id=pid.ref_id), actor_class=None)
        else:
            raise NotImplementedError("Being passed back an unknown stub")
|
||||
|
||||
|
||||
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
    """Serialize ``obj`` with ``ClientPickler`` and return the raw bytes.

    Args:
        obj: The value to pickle.
        client_id: Id recorded in every stub the pickler emits.
        protocol: Pickle protocol passed through to the pickler.
    """
    buffer = io.BytesIO()
    try:
        pickler = ClientPickler(client_id, buffer, protocol=protocol)
        pickler.dump(obj)
        return buffer.getvalue()
    finally:
        buffer.close()
|
||||
|
||||
|
||||
def loads_from_server(data: bytes,
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Deserialize bytes received from the server with ``ServerUnpickler``.

    Args:
        data: The pickled payload; must not be a ``str``.
        fix_imports: Passed through to the unpickler.
        encoding: Passed through to the unpickler.
        errors: Passed through to the unpickler.

    Raises:
        TypeError: If ``data`` is a ``str``.
    """
    if isinstance(data, str):
        raise TypeError("Can't load pickle from unicode string")
    unpickler = ServerUnpickler(
        io.BytesIO(data),
        fix_imports=fix_imports,
        encoding=encoding,
        errors=errors)
    return unpickler.load()
|
||||
|
||||
|
||||
def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
    """Wrap ``val`` in an ``Arg`` message as client-pickled, interned data."""
    return ray_client_pb2.Arg(
        local=ray_client_pb2.Arg.Locality.INTERNED,
        data=dumps_from_client(val, client_id),
    )
|
|
@ -1,16 +1,12 @@
|
|||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from ray import cloudpickle
|
||||
|
||||
import base64
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id, handle=None):
|
||||
self.id = id
|
||||
self.handle = handle
|
||||
def __init__(self, id: bytes):
|
||||
self.id: bytes = id
|
||||
ray.call_retain(id)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
|
@ -24,14 +20,13 @@ class ClientBaseRef:
|
|||
def binary(self):
    """Return the id stored on this reference."""
    ref_id = self.id
    return ref_id
|
||||
|
||||
@classmethod
|
||||
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
|
||||
return cls(id=ref.id, handle=ref.handle)
|
||||
def __del__(self):
    """Release this reference through the client API while still connected."""
    if not ray.is_connected():
        return
    ray.call_release(self.id)
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
def _unpack_ref(self):
|
||||
return cloudpickle.loads(self.handle)
|
||||
pass
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
|
@ -53,50 +48,42 @@ class ClientRemoteFunc(ClientStub):
|
|||
_func: The actual function to execute remotely
|
||||
_name: The original name of the function
|
||||
_ref: The ClientObjectRef of the pickled code of the function, _func
|
||||
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
|
||||
for this object
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
|
||||
# self._ref can be lazily instantiated. Rather than eagerly creating
|
||||
# function data objects in the server we can put them just before we
|
||||
# execute the function, especially in cases where many @ray.remote
|
||||
# functions exist in a library and only a handful are ever executed by
|
||||
# a user of the library.
|
||||
#
|
||||
# TODO(barakmich): This ref might actually be better as a serialized
|
||||
# ObjectRef. This requires being able to serialize the ref without
|
||||
# pinning it (as the lifetime of the ref is tied with the server, not
|
||||
# the client)
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote function cannot be called directly. "
|
||||
"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._raylet_remote is None:
|
||||
self._raylet_remote = ray.remote(self._func)
|
||||
return self._raylet_remote
|
||||
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
def _ensure_ref(self):
    """Upload the wrapped function once and remember its reference.

    Does nothing if ``_ref`` is already set. If the upload fails, ``_ref``
    is reset to ``None`` so a later call can retry, and the original
    exception is re-raised.
    """
    if self._ref is None:
        # While calling ray.put() on our function, if
        # our function is recursive, it will attempt to
        # encode the ClientRemoteFunc -- itself -- and
        # infinitely recurse on _ensure_ref.
        #
        # So we set the state of the reference to be an
        # in-progress self reference value, which
        # the encoding can detect and handle correctly.
        self._ref = SelfReferenceSentinel()
        try:
            self._ref = ray.put(self._func)
        except BaseException:
            # Do not leave the in-progress marker behind: it would make
            # every later call skip the upload and then fail on `_ref.id`.
            self._ref = None
            raise
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
self._ensure_ref()
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.handle
|
||||
task.payload_id = self._ref.id
|
||||
return task
|
||||
|
||||
|
||||
|
@ -109,14 +96,12 @@ class ClientActorClass(ClientStub):
|
|||
actor_cls: The actual class to execute remotely
|
||||
_name: The original name of the class
|
||||
_ref: The ClientObjectRef of the pickled `actor_cls`
|
||||
_raylet_remote: The Raylet-side ray.ActorClass for this object
|
||||
"""
|
||||
|
||||
def __init__(self, actor_cls):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
||||
|
@ -135,10 +120,10 @@ class ClientActorClass(ClientStub):
|
|||
self._name = state["_name"]
|
||||
self._ref = state["_ref"]
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
def remote(self, *args, **kwargs) -> "ClientActorHandle":
|
||||
# Actually instantiate the actor
|
||||
ref = ray.call_remote(self, *args, **kwargs)
|
||||
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
|
||||
ref_id = ray.call_remote(self, *args, **kwargs)
|
||||
return ClientActorHandle(ClientActorRef(ref_id), self)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||
|
@ -154,7 +139,7 @@ class ClientActorClass(ClientStub):
|
|||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.handle
|
||||
task.payload_id = self._ref.id
|
||||
return task
|
||||
|
||||
|
||||
|
@ -177,26 +162,9 @@ class ClientActorHandle(ClientStub):
|
|||
def __init__(self, actor_ref: ClientActorRef,
|
||||
actor_class: ClientActorClass):
|
||||
self.actor_ref = actor_ref
|
||||
self.actor_class = actor_class
|
||||
self._real_actor_handle = None
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._real_actor_handle is None:
|
||||
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
|
||||
return self._real_actor_handle
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_ref": self.actor_ref,
|
||||
"actor_class": self.actor_class,
|
||||
"_real_actor_handle": self._real_actor_handle,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_ref = state["actor_ref"]
|
||||
self.actor_class = state["actor_class"]
|
||||
self._real_actor_handle = state["_real_actor_handle"]
|
||||
def __del__(self) -> None:
    """Release the actor reference through the client API while connected."""
    # Same guard as ClientBaseRef.__del__: a handle collected after
    # disconnect should not try to reach the client API.
    if ray.is_connected():
        ray.call_release(self.actor_ref.id)
|
||||
|
||||
@property
|
||||
def _actor_id(self):
|
||||
|
@ -226,65 +194,27 @@ class ClientRemoteMethod(ClientStub):
|
|||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote method cannot be called directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
return getattr(self.actor_handle._get_ray_remote_impl(),
|
||||
self.method_name)
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_handle": self.actor_handle,
|
||||
"method_name": self.method_name,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_handle = state["actor_handle"]
|
||||
self.method_name = state["method_name"]
|
||||
f"Use {self._name}.remote() instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __repr__(self):
|
||||
name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||
self.method_name)
|
||||
return "ClientRemoteMethod(%s, %s)" % (name,
|
||||
self.actor_handle.actor_id)
|
||||
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
|
||||
self.actor_handle)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = self.method_name
|
||||
task.payload_id = self.actor_handle.actor_ref.handle
|
||||
task.payload_id = self.actor_handle.actor_ref.id
|
||||
return task
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return ClientObjectRef(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
class DataEncodingSentinel:
    """Base class for marker objects whose repr is just their class name."""

    def __repr__(self) -> str:
        return type(self).__name__
|
||||
|
||||
|
||||
def convert_to_arg(val):
|
||||
out = ray_client_pb2.Arg()
|
||||
if isinstance(val, ClientObjectRef):
|
||||
out.local = ray_client_pb2.Arg.Locality.REFERENCE
|
||||
out.reference_id = val.id
|
||||
else:
|
||||
out.local = ray_client_pb2.Arg.Locality.INTERNED
|
||||
out.data = cloudpickle.dumps(val)
|
||||
return out
|
||||
|
||||
|
||||
def encode_exception(exception) -> str:
|
||||
data = cloudpickle.dumps(exception)
|
||||
return base64.standard_b64encode(data).decode()
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
|
||||
data = base64.standard_b64decode(data)
|
||||
return cloudpickle.loads(data)
|
||||
class SelfReferenceSentinel(DataEncodingSentinel):
    """Marker stored in a remote function's ``_ref`` while it is uploaded."""
|
||||
|
|
103
python/ray/experimental/client/dataclient.py
Normal file
103
python/ray/experimental/client/dataclient.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
"""
|
||||
This file implements a threaded stream controller to abstract a data stream
|
||||
back to the ray clientserver.
|
||||
"""
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import grpc
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The maximum field value for request_id -- which is also the maximum
|
||||
# number of simultaneous in-flight requests.
|
||||
INT32_MAX = (2**31) - 1
|
||||
|
||||
|
||||
class DataClient:
    """Runs the client's data stream to the server on a background thread.

    Requests are placed on a queue that feeds the ``Datapath`` streaming
    call. Responses are matched back to blocked callers by ``req_id``.
    """

    def __init__(self, channel: "grpc._channel.Channel", client_id: str):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
            client_id: sent as ``client_id`` metadata on the stream
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.data_thread = self._start_datathread()
        # Responses that have arrived but not yet been collected,
        # keyed by request id and guarded by ``cv``.
        self.ready_data: Dict[int, Any] = {}
        self.cv = threading.Condition()
        self._req_id = 0
        # Serializes id allocation across concurrent callers.
        self._id_lock = threading.Lock()
        self._client_id = client_id
        # Start last so the thread only ever sees a fully built object.
        self.data_thread.start()

    def _next_id(self) -> int:
        """Return the next request id, wrapping from INT32_MAX back to 1."""
        # The increment-and-wrap is a read-modify-write; without the lock
        # two threads could be handed the same id and steal each other's
        # response.
        with self._id_lock:
            self._req_id += 1
            if self._req_id > INT32_MAX:
                self._req_id = 1
            # Responses that aren't tracked (like opportunistic releases)
            # have req_id=0, so make sure we never mint such an id.
            assert self._req_id != 0
            return self._req_id

    def _start_datathread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running the stream."""
        return threading.Thread(target=self._data_main, args=(), daemon=True)

    def _data_main(self) -> None:
        """Thread body: pump queued requests out and file responses by id."""
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
        # iter(callable, sentinel): keeps calling queue.get() until a None
        # is dequeued, which ends the request stream.
        resp_stream = stub.Datapath(
            iter(self.request_queue.get, None),
            metadata=(("client_id", self._client_id), ))
        for response in resp_stream:
            if response.req_id == 0:
                # This is not being waited for.
                logger.debug(f"Got unawaited response {response}")
                continue
            with self.cv:
                self.ready_data[response.req_id] = response
                self.cv.notify_all()

    def close(self, close_channel: bool = False) -> None:
        """Stop the request stream, join the thread, optionally close the channel.

        Args:
            close_channel: When true, also close the gRPC channel.
        """
        if self.request_queue is not None:
            # None is the sentinel that terminates the request iterator.
            self.request_queue.put(None)
            self.request_queue = None
        if self.data_thread is not None:
            self.data_thread.join()
            self.data_thread = None
        if close_channel:
            self.channel.close()

    def _blocking_send(self, req: ray_client_pb2.DataRequest
                       ) -> ray_client_pb2.DataResponse:
        """Send ``req`` with a fresh id and block until its response arrives."""
        req_id = self._next_id()
        req.req_id = req_id
        self.request_queue.put(req)
        data = None
        with self.cv:
            self.cv.wait_for(lambda: req_id in self.ready_data)
            data = self.ready_data[req_id]
            del self.ready_data[req_id]
        return data

    def GetObject(self, request: ray_client_pb2.GetRequest,
                  context=None) -> ray_client_pb2.GetResponse:
        """Send a get request over the stream and return its response."""
        datareq = ray_client_pb2.DataRequest(get=request, )
        resp = self._blocking_send(datareq)
        return resp.get

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        """Send a put request over the stream and return its response."""
        datareq = ray_client_pb2.DataRequest(put=request, )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(self,
                      request: ray_client_pb2.ReleaseRequest,
                      context=None) -> None:
        """Queue a release request without waiting for a response.

        No request id is assigned, so the request goes out with the
        untracked id that ``_data_main`` skips on the way back.
        """
        datareq = ray_client_pb2.DataRequest(release=request, )
        request_queue = self.request_queue
        if request_queue is None:
            # close() has already torn down the stream. Releases are
            # best-effort, so drop this one rather than raising
            # AttributeError on None.
            return
        request_queue.put(datareq)
|
|
@ -11,12 +11,15 @@ from typing import Any
|
|||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import logging
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientStub
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
|
||||
"""
|
||||
|
@ -26,12 +29,6 @@ class CoreRayAPI(APIImpl):
|
|||
"""
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
if isinstance(vals, list):
|
||||
if isinstance(vals[0], ClientObjectRef):
|
||||
return ray.get(
|
||||
[val._unpack_ref() for val in vals], timeout=timeout)
|
||||
elif isinstance(vals, ClientObjectRef):
|
||||
return ray.get(vals._unpack_ref(), timeout=timeout)
|
||||
return ray.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, vals: Any, *args,
|
||||
|
@ -45,7 +42,8 @@ class CoreRayAPI(APIImpl):
|
|||
return ray.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||
return instance._get_ray_remote_impl().remote(*args, **kwargs)
|
||||
raise NotImplementedError(
|
||||
"Should not attempt execution of a client stub inside the raylet")
|
||||
|
||||
def close(self) -> None:
    """Do nothing; this API has nothing to close."""
    return
|
||||
|
@ -59,6 +57,12 @@ class CoreRayAPI(APIImpl):
|
|||
def is_initialized(self) -> bool:
    """Report the result of ``ray.is_initialized()``."""
    initialized = ray.is_initialized()
    return initialized
|
||||
|
||||
def call_release(self, id: bytes) -> None:
    """Do nothing; this implementation ignores release requests."""
    return
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
    """Do nothing; this implementation ignores retain requests."""
    return
|
||||
|
||||
# Allow for generic fallback to ray.* in remote methods. This allows calls
|
||||
# like ray.nodes() to be run in remote functions even though the client
|
||||
# doesn't currently support them.
|
||||
|
@ -76,26 +80,7 @@ class RayServerAPI(CoreRayAPI):
|
|||
def __init__(self, server_instance):
|
||||
self.server = server_instance
|
||||
|
||||
# Wrap single item into list if needed before calling server put.
|
||||
def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
|
||||
to_put = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_put = vals
|
||||
else:
|
||||
single = True
|
||||
to_put.append(vals)
|
||||
|
||||
out = [self._put(x) for x in to_put]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _put(self, val: Any):
|
||||
resp = self.server._put_and_retain_obj(val)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs):
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes:
|
||||
task = instance._prepare_client_task()
|
||||
ticket = self.server.Schedule(task, prepared_args=args)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
return ticket.return_id
|
||||
|
|
54
python/ray/experimental/client/server/dataservicer.py
Normal file
54
python/ray/experimental/client/server/dataservicer.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
import logging
|
||||
import grpc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """Serves the streaming datapath by delegating to a ``RayletServicer``."""

    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service

    def Datapath(self, request_iterator, context):
        """Answer get/put/release requests for one client's data stream.

        Yields one ``DataResponse`` per request, echoing the request's
        ``req_id``. When the stream ends for any reason, everything held
        for the client is released.

        Args:
            request_iterator: Stream of ``DataRequest`` messages.
            context: gRPC context; its invocation metadata must carry a
                non-empty ``client_id``.
        """
        metadata = {k: v for k, v in context.invocation_metadata()}
        # A missing header is treated like an empty one so the rejection
        # below is reached instead of an unhandled KeyError.
        client_id = metadata.get("client_id", "")
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.info(f"New data connection from client {client_id}")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    # One boolean per requested id, in request order.
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing channel: {e}")
        finally:
            logger.info(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
|
|
@ -1,6 +1,12 @@
|
|||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import base64
|
||||
from collections import defaultdict
|
||||
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
|
||||
from ray import cloudpickle
|
||||
import ray
|
||||
import ray.state
|
||||
|
@ -10,21 +16,26 @@ import time
|
|||
import inspect
|
||||
import json
|
||||
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import encode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.server.server_pickler import convert_from_arg
|
||||
from ray.experimental.client.server.server_pickler import dumps_from_server
|
||||
from ray.experimental.client.server.server_pickler import loads_from_client
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
from ray.experimental.client.server.dataservicer import DataServicer
|
||||
from ray.experimental.client.server.server_stubs import current_func
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self, test_mode=False):
|
||||
self.object_refs = {}
|
||||
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
|
||||
dict)
|
||||
self.function_refs = {}
|
||||
self.actor_refs = {}
|
||||
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
|
||||
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
|
||||
self.registered_actor_classes = {}
|
||||
self._test_mode = test_mode
|
||||
self._current_function_stub = None
|
||||
|
||||
def ClusterInfo(self, request,
|
||||
context=None) -> ray_client_pb2.ClusterInfoResponse:
|
||||
|
@ -61,20 +72,59 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
raise TypeError("Unsupported cluster info type")
|
||||
return json.dumps(data)
|
||||
|
||||
def Terminate(self, request, context=None):
|
||||
if request.WhichOneof("terminate_type") == "task_object":
|
||||
def release(self, client_id: str, id: bytes) -> bool:
    """Drop one object or actor reference held for a client.

    Args:
        client_id: The client the reference is held for.
        id: The id of the object or actor to release.

    Returns:
        True if a matching object or actor was found and released,
        False otherwise.
    """
    # .get() avoids creating an empty entry in the defaultdicts.
    owned_objects = self.object_refs.get(client_id)
    if owned_objects is not None and id in owned_objects:
        logger.debug(f"Releasing object {id.hex()} for {client_id}")
        del owned_objects[id]
        return True

    owned_actors = self.actor_owners.get(client_id)
    if owned_actors is not None and id in owned_actors:
        logger.debug(f"Releasing actor {id.hex()} for {client_id}")
        del self.actor_refs[id]
        owned_actors.remove(id)
        return True

    return False
|
||||
|
||||
def release_all(self, client_id):
    """Release every object, then every actor, held for ``client_id``."""
    for cleanup in (self._release_objects, self._release_actors):
        cleanup(client_id)
|
||||
|
||||
def _release_objects(self, client_id):
    """Forget every object reference held for ``client_id``."""
    held = self.object_refs
    if client_id in held:
        count = len(held[client_id])
        del held[client_id]
        logger.debug(f"Released all {count} objects for client {client_id}")
    else:
        logger.debug(f"Releasing client with no references: {client_id}")
|
||||
|
||||
def _release_actors(self, client_id):
    """Forget every actor handle owned by ``client_id``.

    Removes each owned actor id from ``actor_refs`` and then drops the
    client's entry from ``actor_owners``.
    """
    if client_id not in self.actor_owners:
        logger.debug(f"Releasing client with no actors: {client_id}")
        # Nothing to do. Returning here mirrors _release_objects and
        # avoids indexing actor_owners for a client it does not contain.
        return
    count = 0
    for id_bytes in self.actor_owners[client_id]:
        count += 1
        del self.actor_refs[id_bytes]
    del self.actor_owners[client_id]
    logger.debug(f"Released all {count} actors for client: {client_id}")
|
||||
|
||||
def Terminate(self, req, context=None):
|
||||
if req.WhichOneof("terminate_type") == "task_object":
|
||||
try:
|
||||
object_ref = cloudpickle.loads(request.task_object.handle)
|
||||
object_ref = \
|
||||
self.object_refs[req.client_id][req.task_object.id]
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=request.task_object.force,
|
||||
recursive=request.task_object.recursive)
|
||||
force=req.task_object.force,
|
||||
recursive=req.task_object.recursive)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
elif request.WhichOneof("terminate_type") == "actor":
|
||||
elif req.WhichOneof("terminate_type") == "actor":
|
||||
try:
|
||||
actor_ref = cloudpickle.loads(request.actor.handle)
|
||||
ray.kill(actor_ref, no_restart=request.actor.no_restart)
|
||||
actor_ref = self.actor_refs[req.actor.id]
|
||||
ray.kill(actor_ref, no_restart=req.actor.no_restart)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
else:
|
||||
|
@ -84,61 +134,71 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
return ray_client_pb2.TerminateResponse(ok=True)
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
request_ref = cloudpickle.loads(request.handle)
|
||||
if request_ref.binary() not in self.object_refs:
|
||||
return self._get_object(request, "", context)
|
||||
|
||||
def _get_object(self, request, client_id: str, context=None):
|
||||
if request.id not in self.object_refs[client_id]:
|
||||
return ray_client_pb2.GetResponse(valid=False)
|
||||
objectref = self.object_refs[request_ref.binary()]
|
||||
logger.info("get: %s" % objectref)
|
||||
objectref = self.object_refs[client_id][request.id]
|
||||
logger.debug("get: %s" % objectref)
|
||||
try:
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
item_ser = cloudpickle.dumps(item)
|
||||
return ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
item_ser = dumps_from_server(item, client_id, self)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||
|
||||
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
|
||||
obj = cloudpickle.loads(request.data)
|
||||
objectref = self._put_and_retain_obj(obj)
|
||||
pickled_ref = cloudpickle.dumps(objectref)
|
||||
return ray_client_pb2.PutResponse(
|
||||
ref=make_remote_ref(objectref.binary(), pickled_ref))
|
||||
def PutObject(self, request: ray_client_pb2.PutRequest,
|
||||
context=None) -> ray_client_pb2.PutResponse:
|
||||
"""gRPC entrypoint for unary PutObject
|
||||
"""
|
||||
return self._put_object(request, "", context)
|
||||
|
||||
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
|
||||
def _put_object(self,
|
||||
request: ray_client_pb2.PutRequest,
|
||||
client_id: str,
|
||||
context=None):
|
||||
"""Put an object in the cluster with ray.put() via gRPC.
|
||||
|
||||
Args:
|
||||
request: PutRequest with pickled data.
|
||||
client_id: The client who owns this data, for tracking when to
|
||||
delete this reference.
|
||||
context: gRPC context.
|
||||
"""
|
||||
obj = loads_from_client(request.data, self)
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[objectref.binary()] = objectref
|
||||
logger.info("put: %s" % objectref)
|
||||
return objectref
|
||||
self.object_refs[client_id][objectref.binary()] = objectref
|
||||
logger.debug("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
|
||||
object_refs = []
|
||||
for id in request.object_ids:
|
||||
if id not in self.object_refs[request.client_id]:
|
||||
raise Exception(
|
||||
"Asking for a ref not associated with this client: %s" %
|
||||
str(id))
|
||||
object_refs.append(self.object_refs[request.client_id][id])
|
||||
num_returns = request.num_returns
|
||||
timeout = request.timeout
|
||||
object_refs_ids = []
|
||||
for object_ref in object_refs:
|
||||
if object_ref.binary() not in self.object_refs:
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
object_refs_ids.append(self.object_refs[object_ref.binary()])
|
||||
try:
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs_ids,
|
||||
object_refs,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None)
|
||||
except Exception:
|
||||
# TODO(ameer): improve exception messages.
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
logger.info("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
logger.debug("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
ready_object_ids = [
|
||||
make_remote_ref(
|
||||
id=ready_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(ready_object_ref),
|
||||
) for ready_object_ref in ready_object_refs
|
||||
ready_object_ref.binary() for ready_object_ref in ready_object_refs
|
||||
]
|
||||
remaining_object_ids = [
|
||||
make_remote_ref(
|
||||
id=remaining_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(remaining_object_ref),
|
||||
) for remaining_object_ref in remaining_object_refs
|
||||
remaining_object_ref.binary()
|
||||
for remaining_object_ref in remaining_object_refs
|
||||
]
|
||||
return ray_client_pb2.WaitResponse(
|
||||
valid=True,
|
||||
|
@ -150,16 +210,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context, prepared_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context, prepared_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
|
||||
def _schedule_method(
|
||||
self,
|
||||
|
@ -170,80 +231,67 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
if actor_handle is None:
|
||||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
pickled_ref = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_ref))
|
||||
arglist = self._convert_args(task.args, prepared_args)
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[task.client_id][output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
def _schedule_actor(self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[payload_ref.binary()]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[payload_ref.binary()] = reg_class
|
||||
remote_class = self.registered_actor_classes[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
actorhandle = cloudpickle.dumps(actor)
|
||||
self.actor_refs[actorhandle] = actor
|
||||
if task.payload_id not in self.registered_actor_classes:
|
||||
actor_class_ref = \
|
||||
self.object_refs[task.client_id][task.payload_id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[task.payload_id] = reg_class
|
||||
remote_class = self.registered_actor_classes[task.payload_id]
|
||||
arglist = self._convert_args(task.args, prepared_args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
self.actor_refs[actor._actor_id.binary()] = actor
|
||||
self.actor_owners[task.client_id].add(actor._actor_id.binary())
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
|
||||
return_id=actor._actor_id.binary())
|
||||
|
||||
def _schedule_function(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.function_refs:
|
||||
funcref = self.object_refs[payload_ref.binary()]
|
||||
remote_func = self.lookup_or_register_func(task.payload_id,
|
||||
task.client_id)
|
||||
arglist = self._convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with current_func(remote_func):
|
||||
output = remote_func.remote(*arglist)
|
||||
if output.binary() in self.object_refs[task.client_id]:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[task.client_id][output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
def _convert_args(self, arg_list, prepared_args=None):
|
||||
if prepared_args is not None:
|
||||
return prepared_args
|
||||
out = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg, self)
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
def lookup_or_register_func(self, id: bytes, client_id: str
|
||||
) -> ray.remote_function.RemoteFunction:
|
||||
if id not in self.function_refs:
|
||||
funcref = self.object_refs[client_id][id]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to schedule function that "
|
||||
raise Exception("Attempting to register function that "
|
||||
"isn't a function.")
|
||||
self.function_refs[payload_ref.binary()] = ray.remote(func)
|
||||
remote_func = self.function_refs[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = remote_func.remote(*arglist)
|
||||
if output.binary() in self.object_refs:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[output.binary()] = output
|
||||
pickled_output = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_output))
|
||||
|
||||
|
||||
def _convert_args(arg_list, prepared_args=None):
|
||||
if prepared_args is not None:
|
||||
return prepared_args
|
||||
out = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg)
|
||||
if isinstance(t, ClientObjectRef):
|
||||
out.append(t._unpack_ref())
|
||||
else:
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
|
||||
return ray_client_pb2.RemoteRef(
|
||||
id=id,
|
||||
handle=handle,
|
||||
)
|
||||
self.function_refs[id] = ray.remote(func)
|
||||
return self.function_refs[id]
|
||||
|
||||
|
||||
def return_exception_in_context(err, context):
|
||||
|
@ -252,12 +300,20 @@ def return_exception_in_context(err, context):
|
|||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
|
||||
|
||||
def encode_exception(exception) -> str:
    """Pickle ``exception`` with cloudpickle and return it as base64 text."""
    pickled = cloudpickle.dumps(exception)
    encoded = base64.standard_b64encode(pickled)
    return encoded.decode()
|
||||
|
||||
|
||||
def serve(connection_str, test_mode=False):
    """Create and start the gRPC server that backs the Ray client.

    Args:
        connection_str: Address handed to ``server.add_insecure_port``.
        test_mode: Forwarded to the ``RayletServicer`` constructor.

    Returns:
        The started gRPC server object.
    """
    # Requests are handled on a pool of 10 worker threads.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    task_servicer = RayletServicer(test_mode=test_mode)
    # The data servicer is built around the task servicer instance.
    data_servicer = DataServicer(task_servicer)
    # Publish the task servicer through the module-level server API.
    _set_server_api(RayServerAPI(task_servicer))
    ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
        task_servicer, server)
    ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
        data_servicer, server)
    # Plain-text port: no TLS credentials are configured here.
    server.add_insecure_port(connection_str)
    server.start()
    return server
|
||||
|
|
119
python/ray/experimental/client/server/server_pickler.py
Normal file
119
python/ray/experimental/client/server/server_pickler.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
Implements the server side of the client/server pickling protocol.
|
||||
|
||||
These picklers are aware of the server internals and can find the
|
||||
references held for the client within the server.
|
||||
|
||||
More discussion about the client/server pickling protocol can be found in:
|
||||
|
||||
ray/experimental/client/client_pickler.py
|
||||
|
||||
ServerPickler dumps ray objects from the server into the appropriate stubs.
|
||||
ClientUnpickler loads stubs from the client and finds their associated handle
|
||||
in the server instance.
|
||||
"""
|
||||
import cloudpickle
|
||||
import io
|
||||
import sys
|
||||
import ray
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ray.experimental.client.client_pickler import PickleStub
|
||||
from ray.experimental.client.server.server_stubs import ServerFunctionSentinel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
import pickle5 as pickle # noqa: F401
|
||||
except ImportError:
|
||||
import pickle # noqa: F401
|
||||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
|
||||
class ServerPickler(cloudpickle.CloudPickler):
    """Pickler the server uses for values it sends back to a client.

    Ray object references and actor handles are not pickled by value.
    ``persistent_id`` replaces each of them with a ``PickleStub`` that
    carries only its id, and records the real handle in the server's
    tables so the server keeps holding it on behalf of the client.
    """

    def __init__(self, client_id: str, server: "RayletServicer", *args,
                 **kwargs):
        """
        Args:
            client_id: Id of the client the pickled data is destined for.
            server: Servicer whose ``object_refs``, ``actor_refs`` and
                ``actor_owners`` tables are consulted and updated.
            *args: Forwarded to ``cloudpickle.CloudPickler``.
            **kwargs: Forwarded to ``cloudpickle.CloudPickler``.
        """
        super().__init__(*args, **kwargs)
        self.client_id = client_id
        self.server = server

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for ray handles, ``None`` for anything else.

        Returning ``None`` tells the pickler to serialize ``obj`` normally.
        """
        if isinstance(obj, ray.ObjectRef):
            obj_id = obj.binary()
            if obj_id not in self.server.object_refs[self.client_id]:
                # We're passing back a reference, probably inside a reference.
                # Let's hold onto it.
                self.server.object_refs[self.client_id][obj_id] = obj
            return PickleStub(
                type="Object",
                client_id=self.client_id,
                ref_id=obj_id,
            )
        elif isinstance(obj, ray.actor.ActorHandle):
            actor_id = obj._actor_id.binary()
            # The bookkeeping tables live on the server, not on the pickler;
            # accessing them through ``self`` would raise AttributeError.
            if actor_id not in self.server.actor_refs:
                # We're passing back a handle, probably inside a reference.
                self.server.actor_refs[actor_id] = obj
            if actor_id not in self.server.actor_owners[self.client_id]:
                self.server.actor_owners[self.client_id].add(actor_id)
            return PickleStub(
                type="Actor",
                client_id=self.client_id,
                ref_id=actor_id,
            )
        return None
|
||||
|
||||
|
||||
class ClientUnpickler(pickle.Unpickler):
    """Unpickler for client data that resolves stubs to server handles."""

    def __init__(self, server, *args, **kwargs):
        """Remember ``server``; remaining arguments go to ``pickle.Unpickler``."""
        super().__init__(*args, **kwargs)
        self.server = server

    def persistent_load(self, pid):
        """Map a ``PickleStub`` to the object the server holds for it.

        Raises:
            NotImplementedError: If the stub's ``type`` is not recognised.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Object":
            return self.server.object_refs[pid.client_id][pid.ref_id]
        if kind == "Actor":
            return self.server.actor_refs[pid.ref_id]
        if kind == "RemoteFuncSelfReference":
            return ServerFunctionSentinel()
        if kind == "RemoteFunc":
            return self.server.lookup_or_register_func(
                pid.ref_id, pid.client_id)
        raise NotImplementedError("Uncovered client data type")
|
||||
|
||||
|
||||
def dumps_from_server(obj: Any,
                      client_id: str,
                      server_instance: "RayletServicer",
                      protocol=None) -> bytes:
    """Pickle ``obj`` for ``client_id`` using a ``ServerPickler``.

    Args:
        obj: The value to serialize.
        client_id: Id of the client that will receive the bytes.
        server_instance: Servicer handed to the ``ServerPickler``.
        protocol: Pickle protocol, forwarded to the pickler.

    Returns:
        The pickled bytes.
    """
    with io.BytesIO() as buffer:
        ServerPickler(
            client_id, server_instance, buffer, protocol=protocol).dump(obj)
        return buffer.getvalue()
|
||||
|
||||
|
||||
def loads_from_client(data: bytes,
                      server_instance: "RayletServicer",
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Unpickle bytes received from a client with a ``ClientUnpickler``.

    Args:
        data: Pickled bytes sent by the client.
        server_instance: Servicer used to resolve stubs found in ``data``.
        fix_imports: Forwarded to the unpickler.
        encoding: Forwarded to the unpickler.
        errors: Forwarded to the unpickler.

    Returns:
        The reconstructed object.

    Raises:
        TypeError: If ``data`` is a ``str`` instead of bytes.
    """
    if isinstance(data, str):
        raise TypeError("Can't load pickle from unicode string")
    file = io.BytesIO(data)
    # NOTE(review): this unpickles bytes supplied by a remote client;
    # unpickling can execute arbitrary code, so the channel must be trusted.
    return ClientUnpickler(
        server_instance,
        file,
        fix_imports=fix_imports,
        encoding=encoding,
        # ``errors`` used to be accepted but silently dropped.
        errors=errors).load()
|
||||
|
||||
|
||||
def convert_from_arg(pb: "ray_client_pb2.Arg",
                     server: "RayletServicer") -> Any:
    """Decode the ``data`` blob of an ``Arg`` message via ``loads_from_client``."""
    payload = pb.data
    return loads_from_client(payload, server)
|
29
python/ray/experimental/client/server/server_stubs.py
Normal file
29
python/ray/experimental/client/server/server_stubs.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
_current_remote_func = None
|
||||
|
||||
|
||||
@contextmanager
def current_func(f):
    """Set the module-level current remote function to ``f`` for a ``with`` body.

    The value that was set before entering is put back on exit, including
    when the body raises.
    """
    global _current_remote_func
    previous, _current_remote_func = _current_remote_func, f
    try:
        yield
    finally:
        _current_remote_func = previous
|
||||
|
||||
|
||||
class ServerFunctionSentinel:
    """Placeholder whose pickled form depends on the current remote function."""

    def __init__(self):
        """Nothing to initialise; the sentinel carries no state."""

    def __reduce__(self):
        """Describe how to pickle this sentinel.

        While a current remote function is set, pickling yields a call to
        ``identity`` on that function; otherwise it yields a fresh sentinel.
        """
        if _current_remote_func is not None:
            return (identity, (_current_remote_func, ))
        return (ServerFunctionSentinel, ())
|
||||
|
||||
|
||||
def identity(x):
    """Return ``x`` unchanged."""
    return x
|
|
@ -2,27 +2,32 @@
|
|||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.util.inspect import is_cython
|
||||
import grpc
|
||||
|
||||
from ray.exceptions import TaskCancelledError
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import decode_exception
|
||||
from ray.experimental.client.client_pickler import convert_to_arg
|
||||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.client_pickler import dumps_from_client
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.dataclient import DataClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -31,34 +36,32 @@ class Worker:
|
|||
def __init__(self,
|
||||
conn_str: str = "",
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
metadata: List[Tuple[str, str]] = None):
|
||||
"""Initializes the worker side grpc client.
|
||||
|
||||
Args:
|
||||
stub: custom grpc stub.
|
||||
secure: whether to use SSL secure channel or not.
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
self.channel = None
|
||||
if stub is None:
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
self._client_id = make_client_id()
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.server = stub
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
self.data_client = DataClient(self.channel, self._client_id)
|
||||
self.reference_count: Dict[bytes, int] = defaultdict(int)
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_get = [x.handle for x in vals]
|
||||
to_get = vals
|
||||
elif isinstance(vals, ClientObjectRef):
|
||||
to_get = [vals.handle]
|
||||
to_get = [vals]
|
||||
single = True
|
||||
else:
|
||||
raise Exception("Can't get something that's not a "
|
||||
|
@ -70,15 +73,15 @@ class Worker:
|
|||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, handle: bytes, timeout: float):
|
||||
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
|
||||
def _get(self, ref: ClientObjectRef, timeout: float):
|
||||
req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
|
||||
try:
|
||||
data = self.server.GetObject(req, metadata=self.metadata)
|
||||
data = self.data_client.GetObject(req)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
raise e.details()
|
||||
if not data.valid:
|
||||
raise TaskCancelledError(handle)
|
||||
return cloudpickle.loads(data.data)
|
||||
raise cloudpickle.loads(data.error)
|
||||
return loads_from_server(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
|
@ -95,10 +98,10 @@ class Worker:
|
|||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
data = dumps_from_client(val, self._client_id)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req, metadata=self.metadata)
|
||||
return ClientObjectRef.from_remote_ref(resp.ref)
|
||||
resp = self.data_client.PutObject(req)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
|
@ -110,11 +113,10 @@ class Worker:
|
|||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
data = {
|
||||
"object_handles": [
|
||||
object_ref.handle for object_ref in object_refs
|
||||
],
|
||||
"object_ids": [object_ref.id for object_ref in object_refs],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1
|
||||
"timeout": timeout if timeout else -1,
|
||||
"client_id": self._client_id,
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
resp = self.server.WaitObject(req, metadata=self.metadata)
|
||||
|
@ -122,12 +124,10 @@ class Worker:
|
|||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
client_ready_object_ids = [
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.ready_object_ids
|
||||
ClientObjectRef(ref) for ref in resp.ready_object_ids
|
||||
]
|
||||
client_remaining_object_ids = [
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.remaining_object_ids
|
||||
ClientObjectRef(ref) for ref in resp.remaining_object_ids
|
||||
]
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
@ -144,19 +144,38 @@ class Worker:
|
|||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def call_remote(self, instance, *args, **kwargs):
|
||||
def call_remote(self, instance, *args, **kwargs) -> bytes:
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
pb_arg = convert_to_arg(arg, self._client_id)
|
||||
task.args.append(pb_arg)
|
||||
logging.debug("Scheduling %s" % task)
|
||||
task.client_id = self._client_id
|
||||
logger.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef.from_remote_ref(ticket.return_ref)
|
||||
return ticket.return_id
|
||||
|
||||
def call_release(self, id: bytes) -> None:
    """Drop one local reference to ``id``.

    When the count reaches exactly zero the server is asked to release
    the id and its counter entry is removed.
    """
    remaining = self.reference_count[id] - 1
    self.reference_count[id] = remaining
    if remaining == 0:
        self._release_server(id)
        del self.reference_count[id]
|
||||
|
||||
def _release_server(self, id: bytes) -> None:
|
||||
if self.data_client is not None:
|
||||
logger.debug(f"Releasing {id}")
|
||||
self.data_client.ReleaseObject(
|
||||
ray_client_pb2.ReleaseRequest(ids=[id]))
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
    """Record one more local reference to ``id``."""
    logger.debug(f"Retaining {id}")
    self.reference_count[id] = self.reference_count[id] + 1
|
||||
|
||||
def close(self):
    """Close the data client, forget the stub and close the channel if set."""
    self.data_client.close()
    self.server = None
    channel = self.channel
    if not channel:
        return
    channel.close()
    self.channel = None
|
||||
|
||||
def terminate_actor(self, actor: ClientActorHandle,
|
||||
no_restart: bool) -> None:
|
||||
|
@ -164,10 +183,11 @@ class Worker:
|
|||
raise ValueError("ray.kill() only supported for actors. "
|
||||
"Got: {}.".format(type(actor)))
|
||||
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
|
||||
term_actor.handle = actor.actor_ref.handle
|
||||
term_actor.id = actor.actor_ref.id
|
||||
term_actor.no_restart = no_restart
|
||||
try:
|
||||
term = ray_client_pb2.TerminateRequest(actor=term_actor)
|
||||
term.client_id = self._client_id
|
||||
self.server.Terminate(term)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
|
@ -179,11 +199,12 @@ class Worker:
|
|||
"ray.cancel() only supported for non-actor object refs. "
|
||||
f"Got: {type(obj)}.")
|
||||
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
|
||||
term_object.handle = obj.handle
|
||||
term_object.id = obj.id
|
||||
term_object.force = force
|
||||
term_object.recursive = recursive
|
||||
try:
|
||||
term = ray_client_pb2.TerminateRequest(task_object=term_object)
|
||||
term.client_id = self._client_id
|
||||
self.server.Terminate(term)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
|
@ -201,3 +222,13 @@ class Worker:
|
|||
return self.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
return False
|
||||
|
||||
|
||||
def make_client_id() -> str:
    """Return a fresh random UUID4 as a 32-character hex string."""
    return uuid.uuid4().hex
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
    """Base64-decode ``data`` and unpickle it with ``loads_from_server``."""
    return loads_from_server(base64.standard_b64decode(data))
|
||||
|
|
|
@ -96,6 +96,7 @@ py_test_module_list(
|
|||
"test_debug_tools.py",
|
||||
"test_experimental_client.py",
|
||||
"test_experimental_client_metadata.py",
|
||||
"test_experimental_client_references.py",
|
||||
"test_experimental_client_terminate.py",
|
||||
"test_job.py",
|
||||
"test_memstat.py",
|
||||
|
|
|
@ -142,7 +142,7 @@ def test_function_calling_function(ray_start_regular_shared):
|
|||
|
||||
@ray.remote
|
||||
def f():
|
||||
print(f, f._name, g._name, g)
|
||||
print(f, g)
|
||||
return ray.get(g.remote())
|
||||
|
||||
print(f, type(f))
|
||||
|
|
152
python/ray/tests/test_experimental_client_references.py
Normal file
152
python/ray/tests/test_experimental_client_references.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
from ray.tests.test_experimental_client import ray_start_client_server
|
||||
from ray.test_utils import wait_for_condition
|
||||
import ray as real_ray
|
||||
from ray.core.generated.gcs_pb2 import ActorTableData
|
||||
from ray.experimental.client import _get_server_instance
|
||||
|
||||
|
||||
def server_object_ref_count(n):
    """Build a predicate that checks how many object refs the server holds.

    The predicate looks only at the first client in ``server.object_refs``;
    with no clients at all it is true exactly when ``n`` is 0.
    """
    server = _get_server_instance()
    assert server is not None

    def test_cond():
        refs_by_client = server.object_refs
        if not refs_by_client:
            # No open clients
            return n == 0
        first_client = next(iter(refs_by_client))
        return len(refs_by_client[first_client]) == n

    return test_cond
|
||||
|
||||
|
||||
def server_actor_ref_count(n):
    """Build a predicate that is true when the server holds ``n`` actor refs."""
    server = _get_server_instance()
    assert server is not None

    def test_cond():
        # With no running actors the length is 0, so this also covers n == 0.
        return len(server.actor_refs) == n

    return test_cond
|
||||
|
||||
|
||||
def test_delete_refs_on_disconnect(ray_start_regular):
    """Closing the client should drop every object ref held for it."""
    with ray_start_client_server() as ray:

        @ray.remote
        def f(x):
            return x + 2

        thing1 = f.remote(6)  # noqa
        thing2 = ray.put("Hello World")  # noqa

        # One put, one function -- the function result thing1 is
        # in a different category, according to the raylet.
        assert len(real_ray.objects()) == 2
        # But we're maintaining the reference
        assert server_object_ref_count(3)()
        # And can get the data
        assert ray.get(thing1) == 8

        # Close the client
        ray.close()

        # The server-side table for the client should empty out ...
        wait_for_condition(server_object_ref_count(0), timeout=5)

        def test_cond():
            return len(real_ray.objects()) == 0

        # ... and the cluster itself should report no objects.
        wait_for_condition(test_cond, timeout=5)
|
||||
|
||||
|
||||
def test_delete_ref_on_object_deletion(ray_start_regular):
    """Deleting one of two client refs should leave one ref on the server."""
    with ray_start_client_server() as ray:
        vals = {
            "ref": ray.put("Hello World"),
            "ref2": ray.put("This value stays"),
        }

        # Drop the only client-side handle to the first object.
        del vals["ref"]

        wait_for_condition(server_object_ref_count(1), timeout=5)
|
||||
|
||||
|
||||
def test_delete_actor_on_disconnect(ray_start_regular):
    """Closing the client should remove its actor and leave none alive."""
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

            def get(self):
                return self.acc

        actor = Accumulator.remote()
        actor.inc.remote()

        assert server_actor_ref_count(1)()

        assert ray.get(actor.get.remote()) == 1

        ray.close()

        wait_for_condition(server_actor_ref_count(0), timeout=5)

        def test_cond():
            # Count every actor the cluster does not report as DEAD.
            alive_actors = [
                v for v in real_ray.actors().values()
                if v["State"] != ActorTableData.DEAD
            ]
            return len(alive_actors) == 0

        wait_for_condition(test_cond, timeout=10)
|
||||
|
||||
|
||||
def test_delete_actor(ray_start_regular):
    """Deleting one of two actor handles should leave one actor ref."""
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

        actor = Accumulator.remote()
        actor.inc.remote()
        actor2 = Accumulator.remote()
        actor2.inc.remote()

        assert server_actor_ref_count(2)()

        # Drop the client-side handle to the first actor only.
        del actor

        wait_for_condition(server_actor_ref_count(1), timeout=5)
|
||||
|
||||
|
||||
def test_simple_multiple_references(ray_start_regular):
    """A ref fetched twice stays readable after the other handle is deleted."""
    with ray_start_client_server() as ray:

        @ray.remote
        class A:
            def __init__(self):
                self.x = ray.put("hi")

            def get(self):
                # Returned inside a list so the ref itself is sent back.
                return [self.x]

        a = A.remote()
        # Two client-side handles obtained from the same actor attribute.
        ref1 = ray.get(a.get.remote())[0]
        ref2 = ray.get(a.get.remote())[0]
        del a
        assert ray.get(ref1) == "hi"
        del ref1
        # ref2 must still resolve after ref1 is gone.
        assert ray.get(ref2) == "hi"
        del ref2
|
|
@ -18,17 +18,24 @@ package ray.rpc;
|
|||
|
||||
enum Type { DEFAULT = 0; }
|
||||
|
||||
// An argument to a ClientTask.
|
||||
message Arg {
|
||||
enum Locality {
|
||||
INTERNED = 0;
|
||||
REFERENCE = 1;
|
||||
}
|
||||
|
||||
// The type of argument this is -- whether a data blob or a reference.
|
||||
Locality local = 1;
|
||||
// The reference id, if a reference.
|
||||
bytes reference_id = 2;
|
||||
// A data blob, if passed in-band.
|
||||
bytes data = 3;
|
||||
// How to decode this data blob.
|
||||
Type type = 4;
|
||||
}
|
||||
|
||||
// Represents one unit of work to be executed by the server.
|
||||
message ClientTask {
|
||||
enum RemoteExecType {
|
||||
FUNCTION = 0;
|
||||
|
@ -36,49 +43,69 @@ message ClientTask {
|
|||
METHOD = 2;
|
||||
STATIC_METHOD = 3;
|
||||
}
|
||||
// Which type of work this request represents.
|
||||
RemoteExecType type = 1;
|
||||
// A name parameter, if the payload can be called in more than one way (like a method on
|
||||
// a payload object).
|
||||
string name = 2;
|
||||
// A reference to the payload.
|
||||
bytes payload_id = 3;
|
||||
// The parameters to pass to this call.
|
||||
repeated Arg args = 4;
|
||||
}
|
||||
|
||||
message RemoteRef {
|
||||
bytes id = 1;
|
||||
bytes handle = 2;
|
||||
// The ID of the client namespace associated with the Datapath stream making this
|
||||
// request.
|
||||
string client_id = 5;
|
||||
}
|
||||
|
||||
message ClientTaskTicket {
|
||||
RemoteRef return_ref = 1;
|
||||
// A reference to the returned value from the execution.
|
||||
bytes return_id = 1;
|
||||
}
|
||||
|
||||
// Delivers data to the server
|
||||
message PutRequest {
|
||||
// The data blob for the server to store.
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message PutResponse {
|
||||
RemoteRef ref = 1;
|
||||
// The reference ID for the data that the server has stored.
|
||||
bytes id = 1;
|
||||
}
|
||||
|
||||
// Requests data from the server.
|
||||
message GetRequest {
|
||||
bytes handle = 1;
|
||||
// The reference ID for the requested object data
|
||||
bytes id = 1;
|
||||
// Length of time to wait for data to be available, in seconds. Zero is no timeout.
|
||||
float timeout = 2;
|
||||
}
|
||||
|
||||
message GetResponse {
|
||||
// Whether or not the data was successfully retrieved
|
||||
bool valid = 1;
|
||||
// The data blob, on success
|
||||
bytes data = 2;
|
||||
// An error blob (for example, an exception) on failure.
|
||||
bytes error = 3;
|
||||
}
|
||||
|
||||
// Waits for data to be ready on the server, with a timeout.
// NOTE(review): `object_handles` and `object_ids` below both use field number 1,
// which a single protobuf message cannot contain. Presumably these are the old
// and new forms of the same field shown side by side by the diff view; confirm
// which is current.
|
||||
message WaitRequest {
|
||||
repeated bytes object_handles = 1;
|
||||
// The IDs of the data to wait for ready status.
|
||||
repeated bytes object_ids = 1;
|
||||
// How many of the above ids to wait for before returning.
|
||||
int64 num_returns = 2;
|
||||
// How long to wait for these IDs to become ready.
// NOTE(review): the unit is not stated here — confirm it.
|
||||
double timeout = 3;
|
||||
// The Client namespace associated with the Datapath stream that holds these IDs.
|
||||
string client_id = 4;
|
||||
}
|
||||
|
||||
// Reply to a WaitRequest: a validity flag and two lists of object IDs,
// named "ready" and "remaining".
// NOTE(review): ready_object_ids and remaining_object_ids are each declared
// twice below (once as RemoteRef, once as bytes) with the same field numbers,
// which a single protobuf message cannot contain. Presumably these are the old
// and new forms shown side by side by the diff view; confirm which is current.
message WaitResponse {
|
||||
bool valid = 1;
|
||||
repeated RemoteRef ready_object_ids = 2;
|
||||
repeated RemoteRef remaining_object_ids = 3;
|
||||
repeated bytes ready_object_ids = 2;
|
||||
repeated bytes remaining_object_ids = 3;
|
||||
}
|
||||
|
||||
message ClusterInfoType {
|
||||
|
@ -108,18 +135,19 @@ message ClusterInfoResponse {
|
|||
|
||||
// Requests termination of either an actor or a task object; the target is
// selected through the `terminate_type` oneof.
// NOTE(review): this listing is a rendered diff without +/- markers. Several
// declarations below reuse a field number within the same message (handle/id in
// the nested messages; client_id = 1 alongside actor = 1; both oneof members
// listed twice with different numbers). Presumably old and new lines are shown
// together; confirm the current numbering against the actual .proto file.
message TerminateRequest {
|
||||
// Parameters for terminating an actor.
message ActorTerminate {
|
||||
bytes handle = 1;
|
||||
bytes id = 1;
|
||||
bool no_restart = 2;
|
||||
}
|
||||
// Parameters for terminating a task object.
message TaskObjectTerminate {
|
||||
bytes handle = 1;
|
||||
bytes id = 1;
|
||||
bool force = 2;
|
||||
bool recursive = 3;
|
||||
}
|
||||
|
||||
string client_id = 1;
|
||||
// At most one of the members below is set per request.
oneof terminate_type {
|
||||
ActorTerminate actor = 1;
|
||||
TaskObjectTerminate task_object = 2;
|
||||
ActorTerminate actor = 2;
|
||||
TaskObjectTerminate task_object = 3;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -141,3 +169,40 @@ service RayletDriver {
|
|||
rpc ClusterInfo(ClusterInfoRequest) returns (ClusterInfoResponse) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tells the server that the client no longer holds references to a set of IDs.
message ReleaseRequest {
|
||||
// The IDs to release from the server; the client connected on this stream no
|
||||
// longer holds a reference to them.
|
||||
repeated bytes ids = 1;
|
||||
}
|
||||
|
||||
// Reply to a ReleaseRequest, reporting the outcome for each requested ID.
message ReleaseResponse {
|
||||
// For each requested ID, whether or not it was released.
// NOTE(review): this is the only field shown and it is numbered 2; confirm
// that skipping field number 1 is intentional.
|
||||
repeated bool ok = 2;
|
||||
}
|
||||
|
||||
// Envelope for a single request on the Datapath stream. It pairs a request ID
// with at most one of the get, put, or release payloads.
message DataRequest {
|
||||
// An incrementing counter of request IDs on the Datapath,
|
||||
// to match requests with responses asynchronously.
|
||||
int32 req_id = 1;
|
||||
// The operation being requested.
oneof type {
|
||||
GetRequest get = 2;
|
||||
PutRequest put = 3;
|
||||
ReleaseRequest release = 4;
|
||||
}
|
||||
}
|
||||
|
||||
// Envelope for a single response on the Datapath stream. It echoes the request
// ID and carries at most one of the get, put, or release replies, using the same
// field numbers as the corresponding DataRequest members.
message DataResponse {
|
||||
// The request id that this response matches with.
|
||||
int32 req_id = 1;
|
||||
// The reply payload for the matching request.
oneof type {
|
||||
GetResponse get = 2;
|
||||
PutResponse put = 3;
|
||||
ReleaseResponse release = 4;
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming data service exposing a single RPC, Datapath.
service RayletDataStreamer {
|
||||
// Bidirectional stream: the client sends a stream of DataRequest messages and
// the server returns a stream of DataResponse messages; the two sides are
// matched through the req_id field carried by both message types.
rpc Datapath(stream DataRequest) returns (stream DataResponse) {
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue