[ray_client]: Implement object retain/release and Data Streaming API (#12818)

This commit is contained in:
Barak Michener 2020-12-18 11:47:38 -08:00 committed by GitHub
parent 55ae567f7a
commit 5cfa1934e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1000 additions and 310 deletions

View file

@ -91,15 +91,22 @@ def _get_client_api() -> APIImpl:
return api
def _get_server_instance():
"""Used inside tests to inspect the running server.
"""
global _server_api
if _server_api is not None:
return _server_api.server
class RayAPIStub:
def connect(self,
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
stub=None) -> None:
from ray.experimental.client.worker import Worker
_client_worker = Worker(
conn_str, secure=secure, metadata=metadata, stub=stub)
_client_worker = Worker(conn_str, secure=secure, metadata=metadata)
_set_client_api(ClientAPI(_client_worker))
def disconnect(self):
@ -113,6 +120,10 @@ class RayAPIStub:
api = _get_client_api()
return getattr(api, key)
def is_connected(self) -> bool:
global _client_api
return _client_api is not None
ray = RayAPIStub()

View file

@ -138,6 +138,31 @@ class APIImpl(ABC):
"""
pass
@abstractmethod
def call_release(self, id: bytes) -> None:
"""
Attempts to release an object reference.
When client references are destructed, they release their reference,
which can opportunistically send a notification through the datachannel
to release the reference being held for that object on the server.
Args:
id: The id of the reference to release on the server side.
"""
@abstractmethod
def call_retain(self, id: bytes) -> None:
"""
Attempts to retain a client object reference.
Increments the reference count on the client side, to prevent
the client worker from attempting to release the server reference.
Args:
id: The id of the reference to retain on the client side.
"""
class ClientAPI(APIImpl):
"""
@ -163,6 +188,12 @@ class ClientAPI(APIImpl):
def call_remote(self, instance: "ClientStub", *args, **kwargs):
return self.worker.call_remote(instance, *args, **kwargs)
def call_release(self, id: bytes) -> None:
return self.worker.call_release(id)
def call_retain(self, id: bytes) -> None:
return self.worker.call_retain(id)
def close(self) -> None:
return self.worker.close()

View file

@ -0,0 +1,123 @@
"""
Implements the client side of the client/server pickling protocol.
All ray client client/server data transfer happens through this pickling
protocol. The model is as follows:
* All Client objects (eg ClientObjectRef) always live on the client and
are never represented in the server
* All Ray objects (eg, ray.ObjectRef) always live on the server and are
never returned to the client
* In order to translate between these two references, PickleStub tuples
are generated as persistent ids in the data blobs during the pickling
and unpickling of these objects.
The PickleStubs have just enough information to find or generate their
associated partner object on either side.
This also has the advantage of avoiding predefined pickle behavior for ray
objects, which may include ray internal reference counting.
ClientPickler dumps things from the client into the appropriate stubs
ServerUnpickler loads stubs from the server into their client counterparts.
"""
import cloudpickle
import io
import sys
from typing import NamedTuple
from typing import Any
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import SelfReferenceSentinel
import ray.core.generated.ray_client_pb2 as ray_client_pb2
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
except ImportError:
import pickle # noqa: F401
else:
import pickle # noqa: F401
PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str),
("ref_id", bytes)])
class ClientPickler(cloudpickle.CloudPickler):
def __init__(self, client_id, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client_id = client_id
def persistent_id(self, obj):
if isinstance(obj, ClientObjectRef):
return PickleStub(
type="Object",
client_id=self.client_id,
ref_id=obj.id,
)
elif isinstance(obj, ClientActorHandle):
return PickleStub(
type="Actor",
client_id=self.client_id,
ref_id=obj._actor_id,
)
elif isinstance(obj, ClientRemoteFunc):
# TODO(barakmich): This is going to have trouble with mutually
# recursive functions that haven't, as yet, been executed. It's
# relatively doable (keep track of intermediate refs in progress
# with ensure_ref and return appropriately) But punting for now.
if obj._ref is None:
obj._ensure_ref()
if type(obj._ref) == SelfReferenceSentinel:
return PickleStub(
type="RemoteFuncSelfReference",
client_id=self.client_id,
ref_id=b"")
return PickleStub(
type="RemoteFunc",
client_id=self.client_id,
ref_id=obj._ref.id)
return None
class ServerUnpickler(pickle.Unpickler):
def persistent_load(self, pid):
assert isinstance(pid, PickleStub)
if pid.type == "Object":
return ClientObjectRef(id=pid.ref_id)
elif pid.type == "Actor":
return ClientActorHandle(ClientActorRef(id=pid.ref_id))
else:
raise NotImplementedError("Being passed back an unknown stub")
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
with io.BytesIO() as file:
cp = ClientPickler(client_id, file, protocol=protocol)
cp.dump(obj)
return file.getvalue()
def loads_from_server(data: bytes,
*,
fix_imports=True,
encoding="ASCII",
errors="strict") -> Any:
if isinstance(data, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(data)
return ServerUnpickler(
file, fix_imports=fix_imports, encoding=encoding,
errors=errors).load()
def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
out = ray_client_pb2.Arg()
out.local = ray_client_pb2.Arg.Locality.INTERNED
out.data = dumps_from_client(val, client_id)
return out

View file

@ -1,16 +1,12 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray
from typing import Any
from typing import Dict
from ray import cloudpickle
import base64
class ClientBaseRef:
def __init__(self, id, handle=None):
self.id = id
self.handle = handle
def __init__(self, id: bytes):
self.id: bytes = id
ray.call_retain(id)
def __repr__(self):
return "%s(%s)" % (
@ -24,14 +20,13 @@ class ClientBaseRef:
def binary(self):
return self.id
@classmethod
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
return cls(id=ref.id, handle=ref.handle)
def __del__(self):
if ray.is_connected():
ray.call_release(self.id)
class ClientObjectRef(ClientBaseRef):
def _unpack_ref(self):
return cloudpickle.loads(self.handle)
pass
class ClientActorRef(ClientBaseRef):
@ -53,50 +48,42 @@ class ClientRemoteFunc(ClientStub):
_func: The actual function to execute remotely
_name: The original name of the function
_ref: The ClientObjectRef of the pickled code of the function, _func
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
for this object
"""
def __init__(self, f):
self._func = f
self._name = f.__name__
self.id = None
# self._ref can be lazily instantiated. Rather than eagerly creating
# function data objects in the server we can put them just before we
# execute the function, especially in cases where many @ray.remote
# functions exist in a library and only a handful are ever executed by
# a user of the library.
#
# TODO(barakmich): This ref might actually be better as a serialized
# ObjectRef. This requires being able to serialize the ref without
# pinning it (as the lifetime of the ref is tied with the server, not
# the client)
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote function cannot be called directly. "
"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
def _get_ray_remote_impl(self):
if self._raylet_remote is None:
self._raylet_remote = ray.remote(self._func)
return self._raylet_remote
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
def _ensure_ref(self):
if self._ref is None:
# While calling ray.put() on our function, if
# our function is recursive, it will attempt to
# encode the ClientRemoteFunc -- itself -- and
# infinitely recurse on _ensure_ref.
#
# So we set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self._func)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
self._ensure_ref()
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.handle
task.payload_id = self._ref.id
return task
@ -109,14 +96,12 @@ class ClientActorClass(ClientStub):
actor_cls: The actual class to execute remotely
_name: The original name of the class
_ref: The ClientObjectRef of the pickled `actor_cls`
_raylet_remote: The Raylet-side ray.ActorClass for this object
"""
def __init__(self, actor_cls):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote actor cannot be instantiated directly. "
@ -135,10 +120,10 @@ class ClientActorClass(ClientStub):
self._name = state["_name"]
self._ref = state["_ref"]
def remote(self, *args, **kwargs):
def remote(self, *args, **kwargs) -> "ClientActorHandle":
# Actually instantiate the actor
ref = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
ref_id = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref_id), self)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
@ -154,7 +139,7 @@ class ClientActorClass(ClientStub):
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.handle
task.payload_id = self._ref.id
return task
@ -177,26 +162,9 @@ class ClientActorHandle(ClientStub):
def __init__(self, actor_ref: ClientActorRef,
actor_class: ClientActorClass):
self.actor_ref = actor_ref
self.actor_class = actor_class
self._real_actor_handle = None
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
return self._real_actor_handle
def __getstate__(self) -> Dict:
state = {
"actor_ref": self.actor_ref,
"actor_class": self.actor_class,
"_real_actor_handle": self._real_actor_handle,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_ref = state["actor_ref"]
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
def __del__(self) -> None:
ray.call_release(self.actor_ref.id)
@property
def _actor_id(self):
@ -226,65 +194,27 @@ class ClientRemoteMethod(ClientStub):
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote method cannot be called directly. "
"Use {self._name}.remote() instead")
def _get_ray_remote_impl(self):
return getattr(self.actor_handle._get_ray_remote_impl(),
self.method_name)
def __getstate__(self) -> Dict:
state = {
"actor_handle": self.actor_handle,
"method_name": self.method_name,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_handle = state["actor_handle"]
self.method_name = state["method_name"]
f"Use {self._name}.remote() instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
def __repr__(self):
name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
return "ClientRemoteMethod(%s, %s)" % (name,
self.actor_handle.actor_id)
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
self.actor_handle)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_ref.handle
task.payload_id = self.actor_handle.actor_ref.id
return task
def convert_from_arg(pb) -> Any:
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return ClientObjectRef(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
raise Exception("convert_from_arg: Uncovered locality enum")
class DataEncodingSentinel:
def __repr__(self) -> str:
return self.__class__.__name__
def convert_to_arg(val):
out = ray_client_pb2.Arg()
if isinstance(val, ClientObjectRef):
out.local = ray_client_pb2.Arg.Locality.REFERENCE
out.reference_id = val.id
else:
out.local = ray_client_pb2.Arg.Locality.INTERNED
out.data = cloudpickle.dumps(val)
return out
def encode_exception(exception) -> str:
data = cloudpickle.dumps(exception)
return base64.standard_b64encode(data).decode()
def decode_exception(data) -> Exception:
data = base64.standard_b64decode(data)
return cloudpickle.loads(data)
class SelfReferenceSentinel(DataEncodingSentinel):
pass

View file

@ -0,0 +1,103 @@
"""
This file implements a threaded stream controller to abstract a data stream
back to the ray clientserver.
"""
import logging
import queue
import threading
import grpc
from typing import Any
from typing import Dict
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
# The maximum field value for request_id -- which is also the maximum
# number of simultaneous in-flight requests.
INT32_MAX = (2**31) - 1
class DataClient:
def __init__(self, channel: "grpc._channel.Channel", client_id: str):
"""Initializes a thread-safe datapath over a Ray Client gRPC channel.
Args:
channel: connected gRPC channel
"""
self.channel = channel
self.request_queue = queue.Queue()
self.data_thread = self._start_datathread()
self.ready_data: Dict[int, Any] = {}
self.cv = threading.Condition()
self._req_id = 0
self._client_id = client_id
self.data_thread.start()
def _next_id(self) -> int:
self._req_id += 1
if self._req_id > INT32_MAX:
self._req_id = 1
# Responses that aren't tracked (like opportunistic releases)
# have req_id=0, so make sure we never mint such an id.
assert self._req_id != 0
return self._req_id
def _start_datathread(self) -> threading.Thread:
return threading.Thread(target=self._data_main, args=(), daemon=True)
def _data_main(self) -> None:
stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
metadata=(("client_id", self._client_id), ))
for response in resp_stream:
if response.req_id == 0:
# This is not being waited for.
logger.debug(f"Got unawaited response {response}")
continue
with self.cv:
self.ready_data[response.req_id] = response
self.cv.notify_all()
def close(self, close_channel: bool = False) -> None:
if self.request_queue is not None:
self.request_queue.put(None)
self.request_queue = None
if self.data_thread is not None:
self.data_thread.join()
self.data_thread = None
if close_channel:
self.channel.close()
def _blocking_send(self, req: ray_client_pb2.DataRequest
) -> ray_client_pb2.DataResponse:
req_id = self._next_id()
req.req_id = req_id
self.request_queue.put(req)
data = None
with self.cv:
self.cv.wait_for(lambda: req_id in self.ready_data)
data = self.ready_data[req_id]
del self.ready_data[req_id]
return data
def GetObject(self, request: ray_client_pb2.GetRequest,
context=None) -> ray_client_pb2.GetResponse:
datareq = ray_client_pb2.DataRequest(get=request, )
resp = self._blocking_send(datareq)
return resp.get
def PutObject(self, request: ray_client_pb2.PutRequest,
context=None) -> ray_client_pb2.PutResponse:
datareq = ray_client_pb2.DataRequest(put=request, )
resp = self._blocking_send(datareq)
return resp.put
def ReleaseObject(self,
request: ray_client_pb2.ReleaseRequest,
context=None) -> None:
datareq = ray_client_pb2.DataRequest(release=request, )
self.request_queue.put(datareq)

View file

@ -11,12 +11,15 @@ from typing import Any
from typing import Optional
from typing import Union
import logging
import ray
from ray.experimental.client.api import APIImpl
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientStub
logger = logging.getLogger(__name__)
class CoreRayAPI(APIImpl):
"""
@ -26,12 +29,6 @@ class CoreRayAPI(APIImpl):
"""
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
if isinstance(vals, list):
if isinstance(vals[0], ClientObjectRef):
return ray.get(
[val._unpack_ref() for val in vals], timeout=timeout)
elif isinstance(vals, ClientObjectRef):
return ray.get(vals._unpack_ref(), timeout=timeout)
return ray.get(vals, timeout=timeout)
def put(self, vals: Any, *args,
@ -45,7 +42,8 @@ class CoreRayAPI(APIImpl):
return ray.remote(*args, **kwargs)
def call_remote(self, instance: ClientStub, *args, **kwargs):
return instance._get_ray_remote_impl().remote(*args, **kwargs)
raise NotImplementedError(
"Should not attempt execution of a client stub inside the raylet")
def close(self) -> None:
return None
@ -59,6 +57,12 @@ class CoreRayAPI(APIImpl):
def is_initialized(self) -> bool:
return ray.is_initialized()
def call_release(self, id: bytes) -> None:
return None
def call_retain(self, id: bytes) -> None:
return None
# Allow for generic fallback to ray.* in remote methods. This allows calls
# like ray.nodes() to be run in remote functions even though the client
# doesn't currently support them.
@ -76,26 +80,7 @@ class RayServerAPI(CoreRayAPI):
def __init__(self, server_instance):
self.server = server_instance
# Wrap single item into list if needed before calling server put.
def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
to_put = []
single = False
if isinstance(vals, list):
to_put = vals
else:
single = True
to_put.append(vals)
out = [self._put(x) for x in to_put]
if single:
out = out[0]
return out
def _put(self, val: Any):
resp = self.server._put_and_retain_obj(val)
return ClientObjectRef(resp.id)
def call_remote(self, instance: ClientStub, *args, **kwargs):
def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes:
task = instance._prepare_client_task()
ticket = self.server.Schedule(task, prepared_args=args)
return ClientObjectRef(ticket.return_id)
return ticket.return_id

View file

@ -0,0 +1,54 @@
import logging
import grpc
from typing import TYPE_CHECKING
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
logger = logging.getLogger(__name__)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
self.basic_service = basic_service
def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
client_id = metadata["client_id"]
if client_id == "":
logger.error("Client connecting with no client_id")
return
logger.info(f"New data connection from client {client_id}")
try:
for req in request_iterator:
resp = None
req_type = req.WhichOneof("type")
if req_type == "get":
get_resp = self.basic_service._get_object(
req.get, client_id)
resp = ray_client_pb2.DataResponse(get=get_resp)
elif req_type == "put":
put_resp = self.basic_service._put_object(
req.put, client_id)
resp = ray_client_pb2.DataResponse(put=put_resp)
elif req_type == "release":
released = []
for rel_id in req.release.ids:
rel = self.basic_service.release(client_id, rel_id)
released.append(rel)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(ok=released))
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
resp.req_id = req.req_id
yield resp
except grpc.RpcError as e:
logger.debug(f"Closing channel: {e}")
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)

View file

@ -1,6 +1,12 @@
import logging
from concurrent import futures
import grpc
import base64
from collections import defaultdict
from typing import Dict
from typing import Set
from ray import cloudpickle
import ray
import ray.state
@ -10,21 +16,26 @@ import time
import inspect
import json
from ray.experimental.client import stash_api_for_tests, _set_server_api
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import encode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.server.server_pickler import convert_from_arg
from ray.experimental.client.server.server_pickler import dumps_from_server
from ray.experimental.client.server.server_pickler import loads_from_client
from ray.experimental.client.server.core_ray_api import RayServerAPI
from ray.experimental.client.server.dataservicer import DataServicer
from ray.experimental.client.server.server_stubs import current_func
logger = logging.getLogger(__name__)
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, test_mode=False):
self.object_refs = {}
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
dict)
self.function_refs = {}
self.actor_refs = {}
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
self.registered_actor_classes = {}
self._test_mode = test_mode
self._current_function_stub = None
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
@ -61,20 +72,59 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
def Terminate(self, request, context=None):
if request.WhichOneof("terminate_type") == "task_object":
def release(self, client_id: str, id: bytes) -> bool:
if client_id in self.object_refs:
if id in self.object_refs[client_id]:
logger.debug(f"Releasing object {id.hex()} for {client_id}")
del self.object_refs[client_id][id]
return True
if client_id in self.actor_owners:
if id in self.actor_owners[client_id]:
logger.debug(f"Releasing actor {id.hex()} for {client_id}")
del self.actor_refs[id]
self.actor_owners[client_id].remove(id)
return True
return False
def release_all(self, client_id):
self._release_objects(client_id)
self._release_actors(client_id)
def _release_objects(self, client_id):
if client_id not in self.object_refs:
logger.debug(f"Releasing client with no references: {client_id}")
return
count = len(self.object_refs[client_id])
del self.object_refs[client_id]
logger.debug(f"Released all {count} objects for client {client_id}")
def _release_actors(self, client_id):
if client_id not in self.actor_owners:
logger.debug(f"Releasing client with no actors: {client_id}")
count = 0
for id_bytes in self.actor_owners[client_id]:
count += 1
del self.actor_refs[id_bytes]
del self.actor_owners[client_id]
logger.debug(f"Released all {count} actors for client: {client_id}")
def Terminate(self, req, context=None):
if req.WhichOneof("terminate_type") == "task_object":
try:
object_ref = cloudpickle.loads(request.task_object.handle)
object_ref = \
self.object_refs[req.client_id][req.task_object.id]
ray.cancel(
object_ref,
force=request.task_object.force,
recursive=request.task_object.recursive)
force=req.task_object.force,
recursive=req.task_object.recursive)
except Exception as e:
return_exception_in_context(e, context)
elif request.WhichOneof("terminate_type") == "actor":
elif req.WhichOneof("terminate_type") == "actor":
try:
actor_ref = cloudpickle.loads(request.actor.handle)
ray.kill(actor_ref, no_restart=request.actor.no_restart)
actor_ref = self.actor_refs[req.actor.id]
ray.kill(actor_ref, no_restart=req.actor.no_restart)
except Exception as e:
return_exception_in_context(e, context)
else:
@ -84,61 +134,71 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
return ray_client_pb2.TerminateResponse(ok=True)
def GetObject(self, request, context=None):
request_ref = cloudpickle.loads(request.handle)
if request_ref.binary() not in self.object_refs:
return self._get_object(request, "", context)
def _get_object(self, request, client_id: str, context=None):
if request.id not in self.object_refs[client_id]:
return ray_client_pb2.GetResponse(valid=False)
objectref = self.object_refs[request_ref.binary()]
logger.info("get: %s" % objectref)
objectref = self.object_refs[client_id][request.id]
logger.debug("get: %s" % objectref)
try:
item = ray.get(objectref, timeout=request.timeout)
except Exception as e:
return_exception_in_context(e, context)
item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(
valid=False, error=cloudpickle.dumps(e))
item_ser = dumps_from_server(item, client_id, self)
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
obj = cloudpickle.loads(request.data)
objectref = self._put_and_retain_obj(obj)
pickled_ref = cloudpickle.dumps(objectref)
return ray_client_pb2.PutResponse(
ref=make_remote_ref(objectref.binary(), pickled_ref))
def PutObject(self, request: ray_client_pb2.PutRequest,
context=None) -> ray_client_pb2.PutResponse:
"""gRPC entrypoint for unary PutObject
"""
return self._put_object(request, "", context)
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
def _put_object(self,
request: ray_client_pb2.PutRequest,
client_id: str,
context=None):
"""Put an object in the cluster with ray.put() via gRPC.
Args:
request: PutRequest with pickled data.
client_id: The client who owns this data, for tracking when to
delete this reference.
context: gRPC context.
"""
obj = loads_from_client(request.data, self)
objectref = ray.put(obj)
self.object_refs[objectref.binary()] = objectref
logger.info("put: %s" % objectref)
return objectref
self.object_refs[client_id][objectref.binary()] = objectref
logger.debug("put: %s" % objectref)
return ray_client_pb2.PutResponse(id=objectref.binary())
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
object_refs = []
for id in request.object_ids:
if id not in self.object_refs[request.client_id]:
raise Exception(
"Asking for a ref not associated with this client: %s" %
str(id))
object_refs.append(self.object_refs[request.client_id][id])
num_returns = request.num_returns
timeout = request.timeout
object_refs_ids = []
for object_ref in object_refs:
if object_ref.binary() not in self.object_refs:
return ray_client_pb2.WaitResponse(valid=False)
object_refs_ids.append(self.object_refs[object_ref.binary()])
try:
ready_object_refs, remaining_object_refs = ray.wait(
object_refs_ids,
object_refs,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None)
except Exception:
# TODO(ameer): improve exception messages.
return ray_client_pb2.WaitResponse(valid=False)
logger.info("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
logger.debug("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
ready_object_ids = [
make_remote_ref(
id=ready_object_ref.binary(),
handle=cloudpickle.dumps(ready_object_ref),
) for ready_object_ref in ready_object_refs
ready_object_ref.binary() for ready_object_ref in ready_object_refs
]
remaining_object_ids = [
make_remote_ref(
id=remaining_object_ref.binary(),
handle=cloudpickle.dumps(remaining_object_ref),
) for remaining_object_ref in remaining_object_refs
remaining_object_ref.binary()
for remaining_object_ref in remaining_object_refs
]
return ray_client_pb2.WaitResponse(
valid=True,
@ -150,16 +210,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
logger.info("schedule: %s %s" %
(task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context, prepared_args)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
with stash_api_for_tests(self._test_mode):
if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context, prepared_args)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
def _schedule_method(
self,
@ -170,80 +231,67 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
if actor_handle is None:
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist = _convert_args(task.args, prepared_args)
with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output
pickled_ref = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_ref))
arglist = self._convert_args(task.args, prepared_args)
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _schedule_actor(self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode):
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.registered_actor_classes:
actor_class_ref = self.object_refs[payload_ref.binary()]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[payload_ref.binary()] = reg_class
remote_class = self.registered_actor_classes[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
actorhandle = cloudpickle.dumps(actor)
self.actor_refs[actorhandle] = actor
if task.payload_id not in self.registered_actor_classes:
actor_class_ref = \
self.object_refs[task.client_id][task.payload_id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[task.payload_id] = reg_class
remote_class = self.registered_actor_classes[task.payload_id]
arglist = self._convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
return_id=actor._actor_id.binary())
def _schedule_function(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.function_refs:
funcref = self.object_refs[payload_ref.binary()]
remote_func = self.lookup_or_register_func(task.payload_id,
task.client_id)
arglist = self._convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with current_func(remote_func):
output = remote_func.remote(*arglist)
if output.binary() in self.object_refs[task.client_id]:
raise Exception("already found it")
self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
def _convert_args(self, arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = []
for arg in arg_list:
t = convert_from_arg(arg, self)
out.append(t)
return out
def lookup_or_register_func(self, id: bytes, client_id: str
) -> ray.remote_function.RemoteFunction:
if id not in self.function_refs:
funcref = self.object_refs[client_id][id]
func = ray.get(funcref)
if not inspect.isfunction(func):
raise Exception("Attempting to schedule function that "
raise Exception("Attempting to register function that "
"isn't a function.")
self.function_refs[payload_ref.binary()] = ray.remote(func)
remote_func = self.function_refs[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with stash_api_for_tests(self._test_mode):
output = remote_func.remote(*arglist)
if output.binary() in self.object_refs:
raise Exception("already found it")
self.object_refs[output.binary()] = output
pickled_output = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_output))
def _convert_args(arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = []
for arg in arg_list:
t = convert_from_arg(arg)
if isinstance(t, ClientObjectRef):
out.append(t._unpack_ref())
else:
out.append(t)
return out
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
return ray_client_pb2.RemoteRef(
id=id,
handle=handle,
)
self.function_refs[id] = ray.remote(func)
return self.function_refs[id]
def return_exception_in_context(err, context):
@ -252,12 +300,20 @@ def return_exception_in_context(err, context):
context.set_code(grpc.StatusCode.INTERNAL)
def encode_exception(exception) -> str:
data = cloudpickle.dumps(exception)
return base64.standard_b64encode(data).decode()
def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)
data_servicer = DataServicer(task_servicer)
_set_server_api(RayServerAPI(task_servicer))
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
data_servicer, server)
server.add_insecure_port(connection_str)
server.start()
return server

View file

@ -0,0 +1,119 @@
"""
Implements the client side of the client/server pickling protocol.
These picklers are aware of the server internals and can find the
references held for the client within the server.
More discussion about the client/server pickling protocol can be found in:
ray/experimental/client/client_pickler.py
ServerPickler dumps ray objects from the server into the appropriate stubs.
ClientUnpickler loads stubs from the client and finds their associated handle
in the server instance.
"""
import cloudpickle
import io
import sys
import ray
from typing import Any
from typing import TYPE_CHECKING
from ray.experimental.client.client_pickler import PickleStub
from ray.experimental.client.server.server_stubs import ServerFunctionSentinel
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
import ray.core.generated.ray_client_pb2 as ray_client_pb2
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
except ImportError:
import pickle # noqa: F401
else:
import pickle # noqa: F401
class ServerPickler(cloudpickle.CloudPickler):
def __init__(self, client_id: str, server: "RayletServicer", *args,
**kwargs):
super().__init__(*args, **kwargs)
self.client_id = client_id
self.server = server
def persistent_id(self, obj):
if isinstance(obj, ray.ObjectRef):
obj_id = obj.binary()
if obj_id not in self.server.object_refs[self.client_id]:
# We're passing back a reference, probably inside a reference.
# Let's hold onto it.
self.server.object_refs[self.client_id][obj_id] = obj
return PickleStub(
type="Object",
client_id=self.client_id,
ref_id=obj_id,
)
elif isinstance(obj, ray.actor.ActorHandle):
actor_id = obj._actor_id.binary()
if actor_id not in self.server.actor_refs:
# We're passing back a handle, probably inside a reference.
self.actor_refs[actor_id] = obj
if actor_id not in self.actor_owners[self.client_id]:
self.actor_owners[self.client_id].add(actor_id)
return PickleStub(
type="Actor",
client_id=self.client_id,
ref_id=obj._actor_id.binary(),
)
return None
class ClientUnpickler(pickle.Unpickler):
def __init__(self, server, *args, **kwargs):
super().__init__(*args, **kwargs)
self.server = server
def persistent_load(self, pid):
assert isinstance(pid, PickleStub)
if pid.type == "Object":
return self.server.object_refs[pid.client_id][pid.ref_id]
elif pid.type == "Actor":
return self.server.actor_refs[pid.ref_id]
elif pid.type == "RemoteFuncSelfReference":
return ServerFunctionSentinel()
elif pid.type == "RemoteFunc":
return self.server.lookup_or_register_func(pid.ref_id,
pid.client_id)
else:
raise NotImplementedError("Uncovered client data type")
def dumps_from_server(obj: Any,
client_id: str,
server_instance: "RayletServicer",
protocol=None) -> bytes:
with io.BytesIO() as file:
sp = ServerPickler(client_id, server_instance, file, protocol=protocol)
sp.dump(obj)
return file.getvalue()
def loads_from_client(data: bytes,
server_instance: "RayletServicer",
*,
fix_imports=True,
encoding="ASCII",
errors="strict") -> Any:
if isinstance(data, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(data)
return ClientUnpickler(
server_instance, file, fix_imports=fix_imports,
encoding=encoding).load()
def convert_from_arg(pb: "ray_client_pb2.Arg",
server: "RayletServicer") -> Any:
return loads_from_client(pb.data, server)

View file

@ -0,0 +1,29 @@
from contextlib import contextmanager
_current_remote_func = None
@contextmanager
def current_func(f):
global _current_remote_func
remote_func = _current_remote_func
_current_remote_func = f
try:
yield
finally:
_current_remote_func = remote_func
class ServerFunctionSentinel:
def __init__(self):
pass
def __reduce__(self):
global _current_remote_func
if _current_remote_func is None:
return (ServerFunctionSentinel, tuple())
return (identity, (_current_remote_func, ))
def identity(x):
return x

View file

@ -2,27 +2,32 @@
It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
import base64
import inspect
import json
import logging
import uuid
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Optional
import ray.cloudpickle as cloudpickle
from ray.util.inspect import is_cython
import grpc
from ray.exceptions import TaskCancelledError
import ray.cloudpickle as cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import decode_exception
from ray.experimental.client.client_pickler import convert_to_arg
from ray.experimental.client.client_pickler import loads_from_server
from ray.experimental.client.client_pickler import dumps_from_client
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.dataclient import DataClient
logger = logging.getLogger(__name__)
@ -31,34 +36,32 @@ class Worker:
def __init__(self,
conn_str: str = "",
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
metadata: List[Tuple[str, str]] = None):
"""Initializes the worker side grpc client.
Args:
stub: custom grpc stub.
secure: whether to use SSL secure channel or not.
metadata: additional metadata passed in the grpc request headers.
"""
self.metadata = metadata
self.channel = None
if stub is None:
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self._client_id = make_client_id()
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.server = stub
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self.data_client = DataClient(self.channel, self._client_id)
self.reference_count: Dict[bytes, int] = defaultdict(int)
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
if isinstance(vals, list):
to_get = [x.handle for x in vals]
to_get = vals
elif isinstance(vals, ClientObjectRef):
to_get = [vals.handle]
to_get = [vals]
single = True
else:
raise Exception("Can't get something that's not a "
@ -70,15 +73,15 @@ class Worker:
out = out[0]
return out
def _get(self, handle: bytes, timeout: float):
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
def _get(self, ref: ClientObjectRef, timeout: float):
req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
try:
data = self.server.GetObject(req, metadata=self.metadata)
data = self.data_client.GetObject(req)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise e.details()
if not data.valid:
raise TaskCancelledError(handle)
return cloudpickle.loads(data.data)
raise cloudpickle.loads(data.error)
return loads_from_server(data.data)
def put(self, vals):
to_put = []
@ -95,10 +98,10 @@ class Worker:
return out
def _put(self, val):
data = cloudpickle.dumps(val)
data = dumps_from_client(val, self._client_id)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req, metadata=self.metadata)
return ClientObjectRef.from_remote_ref(resp.ref)
resp = self.data_client.PutObject(req)
return ClientObjectRef(resp.id)
def wait(self,
object_refs: List[ClientObjectRef],
@ -110,11 +113,10 @@ class Worker:
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
data = {
"object_handles": [
object_ref.handle for object_ref in object_refs
],
"object_ids": [object_ref.id for object_ref in object_refs],
"num_returns": num_returns,
"timeout": timeout if timeout else -1
"timeout": timeout if timeout else -1,
"client_id": self._client_id,
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req, metadata=self.metadata)
@ -122,12 +124,10 @@ class Worker:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
client_ready_object_ids = [
ClientObjectRef.from_remote_ref(ref)
for ref in resp.ready_object_ids
ClientObjectRef(ref) for ref in resp.ready_object_ids
]
client_remaining_object_ids = [
ClientObjectRef.from_remote_ref(ref)
for ref in resp.remaining_object_ids
ClientObjectRef(ref) for ref in resp.remaining_object_ids
]
return (client_ready_object_ids, client_remaining_object_ids)
@ -144,19 +144,38 @@ class Worker:
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def call_remote(self, instance, *args, **kwargs):
def call_remote(self, instance, *args, **kwargs) -> bytes:
task = instance._prepare_client_task()
for arg in args:
pb_arg = convert_to_arg(arg)
pb_arg = convert_to_arg(arg, self._client_id)
task.args.append(pb_arg)
logging.debug("Scheduling %s" % task)
task.client_id = self._client_id
logger.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef.from_remote_ref(ticket.return_ref)
return ticket.return_id
def call_release(self, id: bytes) -> None:
self.reference_count[id] -= 1
if self.reference_count[id] == 0:
self._release_server(id)
del self.reference_count[id]
def _release_server(self, id: bytes) -> None:
if self.data_client is not None:
logger.debug(f"Releasing {id}")
self.data_client.ReleaseObject(
ray_client_pb2.ReleaseRequest(ids=[id]))
def call_retain(self, id: bytes) -> None:
logger.debug(f"Retaining {id}")
self.reference_count[id] += 1
def close(self):
self.data_client.close()
self.server = None
if self.channel:
self.channel.close()
self.channel = None
def terminate_actor(self, actor: ClientActorHandle,
no_restart: bool) -> None:
@ -164,10 +183,11 @@ class Worker:
raise ValueError("ray.kill() only supported for actors. "
"Got: {}.".format(type(actor)))
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
term_actor.handle = actor.actor_ref.handle
term_actor.id = actor.actor_ref.id
term_actor.no_restart = no_restart
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
term.client_id = self._client_id
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
@ -179,11 +199,12 @@ class Worker:
"ray.cancel() only supported for non-actor object refs. "
f"Got: {type(obj)}.")
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
term_object.handle = obj.handle
term_object.id = obj.id
term_object.force = force
term_object.recursive = recursive
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
term.client_id = self._client_id
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
@ -201,3 +222,13 @@ class Worker:
return self.get_cluster_info(
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False
def make_client_id() -> str:
id = uuid.uuid4()
return id.hex
def decode_exception(data) -> Exception:
data = base64.standard_b64decode(data)
return loads_from_server(data)

View file

@ -96,6 +96,7 @@ py_test_module_list(
"test_debug_tools.py",
"test_experimental_client.py",
"test_experimental_client_metadata.py",
"test_experimental_client_references.py",
"test_experimental_client_terminate.py",
"test_job.py",
"test_memstat.py",

View file

@ -142,7 +142,7 @@ def test_function_calling_function(ray_start_regular_shared):
@ray.remote
def f():
print(f, f._name, g._name, g)
print(f, g)
return ray.get(g.remote())
print(f, type(f))

View file

@ -0,0 +1,152 @@
from ray.tests.test_experimental_client import ray_start_client_server
from ray.test_utils import wait_for_condition
import ray as real_ray
from ray.core.generated.gcs_pb2 import ActorTableData
from ray.experimental.client import _get_server_instance
def server_object_ref_count(n):
server = _get_server_instance()
assert server is not None
def test_cond():
if len(server.object_refs) == 0:
# No open clients
return n == 0
client_id = list(server.object_refs.keys())[0]
return len(server.object_refs[client_id]) == n
return test_cond
def server_actor_ref_count(n):
server = _get_server_instance()
assert server is not None
def test_cond():
if len(server.actor_refs) == 0:
# No running actors
return n == 0
return len(server.actor_refs) == n
return test_cond
def test_delete_refs_on_disconnect(ray_start_regular):
with ray_start_client_server() as ray:
@ray.remote
def f(x):
return x + 2
thing1 = f.remote(6) # noqa
thing2 = ray.put("Hello World") # noqa
# One put, one function -- the function result thing1 is
# in a different category, according to the raylet.
assert len(real_ray.objects()) == 2
# But we're maintaining the reference
assert server_object_ref_count(3)()
# And can get the data
assert ray.get(thing1) == 8
# Close the client
ray.close()
wait_for_condition(server_object_ref_count(0), timeout=5)
def test_cond():
return len(real_ray.objects()) == 0
wait_for_condition(test_cond, timeout=5)
def test_delete_ref_on_object_deletion(ray_start_regular):
with ray_start_client_server() as ray:
vals = {
"ref": ray.put("Hello World"),
"ref2": ray.put("This value stays"),
}
del vals["ref"]
wait_for_condition(server_object_ref_count(1), timeout=5)
def test_delete_actor_on_disconnect(ray_start_regular):
with ray_start_client_server() as ray:
@ray.remote
class Accumulator:
def __init__(self):
self.acc = 0
def inc(self):
self.acc += 1
def get(self):
return self.acc
actor = Accumulator.remote()
actor.inc.remote()
assert server_actor_ref_count(1)()
assert ray.get(actor.get.remote()) == 1
ray.close()
wait_for_condition(server_actor_ref_count(0), timeout=5)
def test_cond():
alive_actors = [
v for v in real_ray.actors().values()
if v["State"] != ActorTableData.DEAD
]
return len(alive_actors) == 0
wait_for_condition(test_cond, timeout=10)
def test_delete_actor(ray_start_regular):
with ray_start_client_server() as ray:
@ray.remote
class Accumulator:
def __init__(self):
self.acc = 0
def inc(self):
self.acc += 1
actor = Accumulator.remote()
actor.inc.remote()
actor2 = Accumulator.remote()
actor2.inc.remote()
assert server_actor_ref_count(2)()
del actor
wait_for_condition(server_actor_ref_count(1), timeout=5)
def test_simple_multiple_references(ray_start_regular):
with ray_start_client_server() as ray:
@ray.remote
class A:
def __init__(self):
self.x = ray.put("hi")
def get(self):
return [self.x]
a = A.remote()
ref1 = ray.get(a.get.remote())[0]
ref2 = ray.get(a.get.remote())[0]
del a
assert ray.get(ref1) == "hi"
del ref1
assert ray.get(ref2) == "hi"
del ref2

View file

@ -18,17 +18,24 @@ package ray.rpc;
enum Type { DEFAULT = 0; }
// An argument to a ClientTask.
message Arg {
enum Locality {
INTERNED = 0;
REFERENCE = 1;
}
// The type of argument this is -- whether a data blob or a reference.
Locality local = 1;
// The reference id, if a reference.
bytes reference_id = 2;
// A data blob, if passed in-band.
bytes data = 3;
// How to decode this data blob.
Type type = 4;
}
// Represents one unit of work to be executed by the server.
message ClientTask {
enum RemoteExecType {
FUNCTION = 0;
@ -36,49 +43,69 @@ message ClientTask {
METHOD = 2;
STATIC_METHOD = 3;
}
// Which type of work this request represents.
RemoteExecType type = 1;
// A name parameter, if the payload can be called in more than one way (like a method on
// a payload object).
string name = 2;
// A reference to the payload.
bytes payload_id = 3;
// The parameters to pass to this call.
repeated Arg args = 4;
}
message RemoteRef {
bytes id = 1;
bytes handle = 2;
// The ID of the client namespace associated with the Datapath stream making this
// request.
string client_id = 5;
}
message ClientTaskTicket {
RemoteRef return_ref = 1;
// A reference to the returned value from the execution.
bytes return_id = 1;
}
// Delivers data to the server
message PutRequest {
// The data blob for the server to store.
bytes data = 1;
}
message PutResponse {
RemoteRef ref = 1;
// The reference ID for the data that the server has stored.
bytes id = 1;
}
// Requests data from the server.
message GetRequest {
bytes handle = 1;
// The reference ID for the requested object data
bytes id = 1;
// Length of time to wait for data to be available, in seconds. Zero is no timeout.
float timeout = 2;
}
message GetResponse {
// Whether or not the data was successfully retrieved
bool valid = 1;
// The data blob, on success
bytes data = 2;
// An error blob (for example, an exception) on failure.
bytes error = 3;
}
// Waits for data to be ready on the server, with a timeout.
message WaitRequest {
repeated bytes object_handles = 1;
// The IDs of the data to wait for ready status.
repeated bytes object_ids = 1;
// How many of the above ids to wait for before returning.
int64 num_returns = 2;
// How long to wait for these IDs to become ready.
double timeout = 3;
// The Client namespace associated with the Datapath stream that holds these IDs.
string client_id = 4;
}
message WaitResponse {
bool valid = 1;
repeated RemoteRef ready_object_ids = 2;
repeated RemoteRef remaining_object_ids = 3;
repeated bytes ready_object_ids = 2;
repeated bytes remaining_object_ids = 3;
}
message ClusterInfoType {
@ -108,18 +135,19 @@ message ClusterInfoResponse {
message TerminateRequest {
message ActorTerminate {
bytes handle = 1;
bytes id = 1;
bool no_restart = 2;
}
message TaskObjectTerminate {
bytes handle = 1;
bytes id = 1;
bool force = 2;
bool recursive = 3;
}
string client_id = 1;
oneof terminate_type {
ActorTerminate actor = 1;
TaskObjectTerminate task_object = 2;
ActorTerminate actor = 2;
TaskObjectTerminate task_object = 3;
}
}
@ -141,3 +169,40 @@ service RayletDriver {
rpc ClusterInfo(ClusterInfoRequest) returns (ClusterInfoResponse) {
}
}
message ReleaseRequest {
// The IDs to release from the server; the client connected on this stream no
// longer holds a reference to them.
repeated bytes ids = 1;
}
message ReleaseResponse {
// For each requested ID, whether or not it was released.
repeated bool ok = 2;
}
message DataRequest {
// An incrementing counter of request IDs on the Datapath,
// to match requests with responses asynchronously.
int32 req_id = 1;
oneof type {
GetRequest get = 2;
PutRequest put = 3;
ReleaseRequest release = 4;
}
}
message DataResponse {
// The request id that this response matches with.
int32 req_id = 1;
oneof type {
GetResponse get = 2;
PutResponse put = 3;
ReleaseResponse release = 4;
}
}
service RayletDataStreamer {
rpc Datapath(stream DataRequest) returns (stream DataResponse) {
}
}