diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd
index 4baa8a55b..5c79c3b79 100644
--- a/python/ray/_raylet.pxd
+++ b/python/ray/_raylet.pxd
@@ -91,7 +91,7 @@ cdef class ClientObjectRef(ObjectRef):
     cdef object _id_future
 
     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)
 
 cdef class ActorID(BaseID):
     cdef CActorID data
@@ -105,7 +105,7 @@ cdef class ClientActorRef(ActorID):
     cdef object _id_future
 
    cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)
 
 cdef class CoreWorker:
     cdef:
diff --git a/python/ray/includes/object_ref.pxi b/python/ray/includes/object_ref.pxi
index 7b0ccb22c..7a8ab79a8 100644
--- a/python/ray/includes/object_ref.pxi
+++ b/python/ray/includes/object_ref.pxi
@@ -166,8 +166,20 @@ cdef class ClientObjectRef(ObjectRef):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception in ObjectRef is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)
 
     cdef CObjectID native(self):
         self._wait_for_id()
@@ -244,9 +256,9 @@ cdef class ClientObjectRef(ObjectRef):
             self.data = CObjectID.FromBinary(id)
             client.ray.call_retain(id)
 
-    cdef inline _wait_for_id(self):
+    cdef inline _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None
diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi
index 6b3ef93c1..ed205b94d 100644
--- a/python/ray/includes/unique_ids.pxi
+++ b/python/ray/includes/unique_ids.pxi
@@ -8,6 +8,7 @@ See https://github.com/ray-project/ray/issues/3721.
 # _ID_TYPES list at the bottom of this file.
 
 from concurrent.futures import Future
+import logging
 import os
 
 from ray.includes.unique_ids cimport (
@@ -27,6 +28,8 @@ from ray.includes.unique_ids cimport (
 import ray
 from ray._private.utils import decode
 
+logger = logging.getLogger(__name__)
+
 
 def check_id(b, size=kUniqueIDSize):
     if not isinstance(b, bytes):
@@ -326,8 +329,20 @@ cdef class ClientActorRef(ActorID):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception from actor creation is ignored in destructor. "
" + "To receive this exception in application code, call " + "a method on the actor reference before its destructor " + "is run.") + if not self.data.IsNil(): + client.ray.call_release(self.id) def binary(self): self._wait_for_id() @@ -358,11 +373,11 @@ cdef class ClientActorRef(ActorID): self.data = CActorID.FromBinary(id) client.ray.call_retain(id) - cdef _wait_for_id(self): + cdef _wait_for_id(self, timeout=None): if self._id_future: with self._mutex: if self._id_future: - self._set_id(self._id_future.result()) + self._set_id(self._id_future.result(timeout=timeout)) self._id_future = None diff --git a/python/ray/tests/test_asyncio.py b/python/ray/tests/test_asyncio.py index 8724bf407..93af73f18 100644 --- a/python/ray/tests/test_asyncio.py +++ b/python/ray/tests/test_asyncio.py @@ -303,6 +303,8 @@ async def test_async_obj_unhandled_errors(ray_start_regular_shared): # Test we report unhandled exceptions. ray.worker._unhandled_error_handler = interceptor x1 = f.remote() + # NOTE: Unhandled exception is from waiting for the value of x1's ObjectID + # in x1's destructor, and receiving an exception from f() instead. del x1 wait_for_condition(lambda: num_exceptions == 1) diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index ea907a9ff..de552b1fe 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -135,14 +135,23 @@ def test_get_list(ray_start_regular_shared): assert ray.get([]) == [] assert ray.get([f.remote()]) == ["OK"] + get_count = 0 + get_stub = ray.worker.server.GetObject + + # ray.get() uses unary-unary RPC. Mock the server handler to count + # the number of requests received. + def get(req, metadata=None): + nonlocal get_count + get_count += 1 + return get_stub(req, metadata=metadata) + + ray.worker.server.GetObject = get + refs = [f.remote() for _ in range(100)] - with ray.worker.data_client.lock: - req_id_before = ray.worker.data_client._req_id assert ray.get(refs) == ["OK" for _ in range(100)] + # Only 1 RPC should be sent. - with ray.worker.data_client.lock: - assert ray.worker.data_client._req_id == req_id_before + 1, \ - ray.worker.data_client._req_id + assert get_count == 1 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") @@ -470,6 +479,22 @@ def test_serializing_exceptions(ray_start_regular_shared): ray.get_actor("abc") +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +def test_invalid_task(ray_start_regular_shared): + with ray_start_client_server() as ray: + + @ray.remote(runtime_env="invalid value") + def f(): + return 1 + + # No exception on making the remote call. 
+        ref = f.remote()
+
+        # Exception during scheduling will be raised on ray.get()
+        with pytest.raises(Exception):
+            ray.get(ref)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
 def test_create_remote_before_start(ray_start_regular_shared):
     """Creates remote objects (as though in a library) before
diff --git a/python/ray/tests/test_client_terminate.py b/python/ray/tests/test_client_terminate.py
index c18f3fcdb..211bb40cb 100644
--- a/python/ray/tests/test_client_terminate.py
+++ b/python/ray/tests/test_client_terminate.py
@@ -113,11 +113,12 @@ def test_kill_cancel_metadata(ray_start_regular):
         class MetadataIsCorrectlyPassedException(Exception):
             pass
 
-        def mock_terminate(term, metadata):
-            raise MetadataIsCorrectlyPassedException(metadata[1][0])
+        def mock_terminate(self, term):
+            raise MetadataIsCorrectlyPassedException(self._metadata[1][0])
 
         # Mock stub's Terminate method to raise an exception.
-        ray.get_context().api.worker.server.Terminate = mock_terminate
+        stub = ray.get_context().api.worker.data_client
+        stub.Terminate = mock_terminate.__get__(stub)
 
         # Verify the expected exception is raised with ray.kill.
         # Check that argument of the exception matches "key" from the
diff --git a/python/ray/tests/test_client_warnings.py b/python/ray/tests/test_client_warnings.py
index 8754de076..0f67c2ea7 100644
--- a/python/ray/tests/test_client_warnings.py
+++ b/python/ray/tests/test_client_warnings.py
@@ -1,12 +1,10 @@
 from ray.util.client.ray_client_helpers import ray_start_client_server
-from ray.util.client.worker import TASK_WARNING_THRESHOLD
 from ray.util.debug import _logged
 
 import numpy as np
 
 import pytest
 import unittest
-import warnings
 
 
 @pytest.fixture(autouse=True)
@@ -19,33 +17,6 @@ def reset_debug_logs():
 class LoggerSuite(unittest.TestCase):
     """Test client warnings are raised when many tasks are scheduled"""
 
-    def testManyTasksWarning(self):
-        with ray_start_client_server() as ray:
-
-            @ray.remote
-            def f():
-                return 42
-
-            with self.assertWarns(UserWarning) as cm:
-                for _ in range(TASK_WARNING_THRESHOLD + 1):
-                    f.remote()
-                assert f"More than {TASK_WARNING_THRESHOLD} remote tasks have " \
-                    "been scheduled." in cm.warning.args[0]
-
-    def testNoWarning(self):
-        with ray_start_client_server() as ray:
-
-            @ray.remote
-            def f():
-                return 42
-
-            with warnings.catch_warnings(record=True) as warn_list:
-                for _ in range(TASK_WARNING_THRESHOLD):
-                    f.remote()
-                assert not any(f"More than {TASK_WARNING_THRESHOLD} remote tasks "
-                               "have been scheduled." in str(w.args[0])
-                               for w in warn_list)
-
     def testOutboundMessageSizeWarning(self):
         with ray_start_client_server() as ray:
             large_argument = np.random.rand(100, 100, 100)
diff --git a/python/ray/tests/test_dataclient_disconnect.py b/python/ray/tests/test_dataclient_disconnect.py
index b0f42a4ea..ab8c1bfcd 100644
--- a/python/ray/tests/test_dataclient_disconnect.py
+++ b/python/ray/tests/test_dataclient_disconnect.py
@@ -50,11 +50,24 @@ def test_dataclient_disconnect_before_request():
         # different remote calls.
         ray.worker.data_client.request_queue.put(Mock())
 
-        # The next remote call should error since the data channel has shut
-        # down, which should also disconnect the client.
+        # The following two assertions are relatively brittle. Consider a more
+        # robust mechanism if they fail with code changes or become flaky.
+
+        # The next remote call should error since the data channel will shut
+        # down because of the invalid input above. Two cases can happen:
+        # (1) Data channel shuts down after `f.remote()` finishes.
+        #     Error is raised to `ray.get()`. The next background operation
+        #     will disconnect Ray client.
+        # (2) Data channel shuts down before `f.remote()` is called.
+        #     `f.remote()` will raise the error and disconnect the client.
         with pytest.raises(ConnectionError):
             ray.get(f.remote())
 
+        with pytest.raises(
+                ConnectionError,
+                match="Ray client has already been disconnected"):
+            ray.get(f.remote())
+
         # Client should be disconnected
         assert not ray.is_connected()
 
diff --git a/python/ray/tests/test_placement_group_2.py b/python/ray/tests/test_placement_group_2.py
index 683c6c09a..f2688325b 100644
--- a/python/ray/tests/test_placement_group_2.py
+++ b/python/ray/tests/test_placement_group_2.py
@@ -46,28 +46,18 @@ def test_check_bundle_index(ray_start_cluster, connect_to_client):
         "CPU": 2
     }])
 
-    error_count = 0
-    try:
+    with pytest.raises(ValueError, match="bundle index 3 is invalid"):
         Actor.options(
             placement_group=placement_group,
             placement_group_bundle_index=3).remote()
-    except ValueError:
-        error_count = error_count + 1
-    assert error_count == 1
 
-    try:
+    with pytest.raises(ValueError, match="bundle index -2 is invalid"):
         Actor.options(
             placement_group=placement_group,
             placement_group_bundle_index=-2).remote()
-    except ValueError:
-        error_count = error_count + 1
-    assert error_count == 2
 
-    try:
+    with pytest.raises(ValueError, match="bundle index must be -1"):
         Actor.options(placement_group_bundle_index=0).remote()
-    except ValueError:
-        error_count = error_count + 1
-    assert error_count == 3
 
 
 @pytest.mark.parametrize("connect_to_client", [False, True])
diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py
index a683e9823..c8a77166b 100644
--- a/python/ray/util/client/api.py
+++ b/python/ray/util/client/api.py
@@ -1,6 +1,7 @@
 """This file defines the interface between the ray client worker
 and the overall ray module API.
 """
+from concurrent.futures import Future
 import json
 import logging
 
@@ -85,7 +86,10 @@ class ClientAPI:
             assert len(args) == 0 and len(kwargs) > 0, error_string
             return remote_decorator(options=kwargs)
 
-    def call_remote(self, instance: "ClientStub", *args, **kwargs):
+    # TODO(mwtian): consider adding _internal_ prefix to call_remote /
+    # call_release / call_retain.
+    def call_remote(self, instance: "ClientStub", *args,
+                    **kwargs) -> List[Future]:
         """call_remote is called by stub objects to execute them remotely.
 
         This is used by stub objects in situations where they're called
diff --git a/python/ray/util/client/client_pickler.py b/python/ray/util/client/client_pickler.py
index 8e4dece07..9c1ebef68 100644
--- a/python/ray/util/client/client_pickler.py
+++ b/python/ray/util/client/client_pickler.py
@@ -141,9 +141,9 @@ class ServerUnpickler(pickle.Unpickler):
     def persistent_load(self, pid):
         assert isinstance(pid, PickleStub)
         if pid.type == "Object":
-            return ClientObjectRef(id=pid.ref_id)
+            return ClientObjectRef(pid.ref_id)
         elif pid.type == "Actor":
-            return ClientActorHandle(ClientActorRef(id=pid.ref_id))
+            return ClientActorHandle(ClientActorRef(pid.ref_id))
         else:
             raise NotImplementedError("Being passed back an unknown stub")
 
diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py
index 52757a25a..7403b28cd 100644
--- a/python/ray/util/client/common.py
+++ b/python/ray/util/client/common.py
@@ -6,6 +6,7 @@ from ray.util.client.options import validate_options
 from ray._private.signature import get_signature, extract_signature
 from ray._private.utils import check_oversized_function
 
+import concurrent.futures
 from dataclasses import dataclass
 import grpc
 import os
@@ -209,9 +210,9 @@ class ClientActorClass(ClientStub):
     def remote(self, *args, **kwargs) -> "ClientActorHandle":
         self._init_signature.bind(*args, **kwargs)
         # Actually instantiate the actor
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
-        return ClientActorHandle(ClientActorRef(ref_ids[0]), actor_class=self)
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
+        return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)
 
     def options(self, **kwargs):
         return ActorOptionWrapper(self, kwargs)
@@ -397,13 +398,13 @@ class OptionWrapper:
 class ActorOptionWrapper(OptionWrapper):
     def remote(self, *args, **kwargs):
         self._remote_stub._init_signature.bind(*args, **kwargs)
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
         actor_class = None
         if isinstance(self._remote_stub, ClientActorClass):
             actor_class = self._remote_stub
         return ClientActorHandle(
-            ClientActorRef(ref_ids[0]), actor_class=actor_class)
+            ClientActorRef(futures[0]), actor_class=actor_class)
 
 
 def set_task_options(task: ray_client_pb2.ClientTask,
@@ -423,13 +424,13 @@ def set_task_options(task: ray_client_pb2.ClientTask,
         getattr(task, field).json_options = options_str
 
 
-def return_refs(ids: List[bytes]
+def return_refs(futures: List[concurrent.futures.Future]
                 ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
-    if len(ids) == 1:
-        return ClientObjectRef(ids[0])
-    if len(ids) == 0:
+    if not futures:
         return None
-    return [ClientObjectRef(id) for id in ids]
+    if len(futures) == 1:
+        return ClientObjectRef(futures[0])
+    return [ClientObjectRef(fut) for fut in futures]
 
 
 class InProgressSentinel:
diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py
index c48ff7369..9d351feec 100644
--- a/python/ray/util/client/dataclient.py
+++ b/python/ray/util/client/dataclient.py
@@ -361,3 +361,21 @@ class DataClient:
                      context=None) -> None:
         datareq = ray_client_pb2.DataRequest(release=request, )
         self._async_send(datareq)
+
+    def Schedule(self, request: ray_client_pb2.ClientTask,
+                 callback: ResponseCallable):
+        datareq = ray_client_pb2.DataRequest(task=request)
+        self._async_send(datareq, callback)
+
+    def Terminate(self, request: ray_client_pb2.TerminateRequest
+                  ) -> ray_client_pb2.TerminateResponse:
+        req = ray_client_pb2.DataRequest(terminate=request, )
+        resp = self._blocking_send(req)
+        return resp.terminate
+
+    def ListNamedActors(self,
+                        request: ray_client_pb2.ClientListNamedActorsRequest
+                        ) -> ray_client_pb2.ClientListNamedActorsResponse:
+        req = ray_client_pb2.DataRequest(list_named_actors=request, )
+        resp = self._blocking_send(req)
+        return resp.list_named_actors
diff --git a/python/ray/util/client/options.py b/python/ray/util/client/options.py
index a0d932352..9c9df946d 100644
--- a/python/ray/util/client/options.py
+++ b/python/ray/util/client/options.py
@@ -2,6 +2,9 @@ from typing import Any
 from typing import Dict
 from typing import Optional
 
+from ray.util.placement_group import (PlacementGroup,
+                                      check_placement_group_index)
+
 options = {
     "num_returns": (int, lambda x: x >= 0,
                     "The keyword 'num_returns' only accepts 0 "
@@ -43,6 +46,7 @@ def validate_options(
         return None
     if len(kwargs_dict) == 0:
         return None
+
     out = {}
     for k, v in kwargs_dict.items():
         if k not in options.keys():
@@ -55,4 +59,21 @@ def validate_options(
         if not validator[1](v):
             raise ValueError(validator[2])
         out[k] = v
+
+    # Validate placement setting similar to the logic in ray/actor.py and
+    # ray/remote_function.py. The difference is that when
+    # placement_group = default and placement_group_capture_child_tasks
+    # specified, placement group cannot be resolved at client. So this check
+    # skips this case and relies on server to enforce any condition.
+    bundle_index = out.get("placement_group_bundle_index", None)
+    if bundle_index is not None:
+        pg = out.get("placement_group", None)
+        if pg is None:
+            pg = PlacementGroup.empty()
+        if pg == "default" and (out.get("placement_group_capture_child_tasks",
+                                        None) is None):
+            pg = PlacementGroup.empty()
+        if isinstance(pg, PlacementGroup):
+            check_placement_group_index(pg, bundle_index)
+
     return out
diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py
index b6579e6da..d0b11218e 100644
--- a/python/ray/util/client/server/dataservicer.py
+++ b/python/ray/util/client/server/dataservicer.py
@@ -189,6 +189,23 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
                     # Clean up acknowledged cache entries
                     response_cache.cleanup(req.acknowledge.req_id)
                     continue
+                elif req_type == "task":
+                    with self.clients_lock:
+                        resp_ticket = self.basic_service.Schedule(
+                            req.task, context)
+                        resp = ray_client_pb2.DataResponse(
+                            task_ticket=resp_ticket)
+                elif req_type == "terminate":
+                    with self.clients_lock:
+                        response = self.basic_service.Terminate(
+                            req.terminate, context)
+                        resp = ray_client_pb2.DataResponse(terminate=response)
+                elif req_type == "list_named_actors":
+                    with self.clients_lock:
+                        response = self.basic_service.ListNamedActors(
+                            req.list_named_actors)
+                        resp = ray_client_pb2.DataResponse(
+                            list_named_actors=response)
                 else:
                     raise Exception(f"Unreachable code: Request type "
                                     f"{req_type} not handled in Datapath")
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index c55af6387..664bf9f41 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -11,8 +11,10 @@
 import time
 import uuid
 import warnings
 from collections import defaultdict
+from concurrent.futures import Future
 import tempfile
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING,
+                    Union)
 
 import grpc
 
@@ -53,10 +55,6 @@
 MAX_BLOCKING_OPERATION_TIME_S: float = 2.0
 # the connection began exceeds this value, a warning should be raised
 MESSAGE_SIZE_THRESHOLD = 10 * 2**20  # 10 MB
-# If the number of tasks scheduled on the client side since the connection
-# began exceeds this value, a warning should be raised
-TASK_WARNING_THRESHOLD = 1000
-
 # Links to the Ray Design Pattern doc to use in the task overhead warning
 # message
 DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = \
@@ -138,9 +136,7 @@ class Worker:
 
         self.closed = False
 
-        # Track these values to raise a warning if many tasks are being
-        # scheduled
-        self.total_num_tasks_scheduled = 0
+        # Track this value to raise a warning if a lot of data is transferred.
         self.total_outbound_message_size_bytes = 0
 
         # Used to create unique IDs for RPCs to the RayletServicer
@@ -365,17 +361,17 @@ class Worker:
             req = ray_client_pb2.GetRequest(
                 ids=[r.id for r in ref], timeout=timeout)
         try:
-            data = self.data_client.GetObject(req)
+            resp = self._call_stub("GetObject", req, metadata=self.metadata)
         except grpc.RpcError as e:
             raise decode_exception(e)
-        if not data.valid:
+        if not resp.valid:
             try:
-                err = cloudpickle.loads(data.error)
+                err = cloudpickle.loads(resp.error)
             except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(data.error))
+                logger.exception("Failed to deserialize {}".format(resp.error))
                 raise
             raise err
-        return loads_from_server(data.data)
+        return loads_from_server(resp.data)
 
     def put(self, vals, *, client_ref_id: bytes = None):
         to_put = []
@@ -450,7 +446,7 @@
 
         return (client_ready_object_ids, client_remaining_object_ids)
 
-    def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
+    def call_remote(self, instance, *args, **kwargs) -> List[Future]:
         task = instance._prepare_client_task()
         for arg in args:
             pb_arg = convert_to_arg(arg, self._client_id)
@@ -460,38 +456,47 @@
         return self._call_schedule_for_task(task, instance._num_returns())
 
     def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
-                                num_returns: int) -> List[bytes]:
+                                num_returns: int) -> List[Future]:
         logger.debug("Scheduling %s" % task)
         task.client_id = self._client_id
-        metadata = self._add_ids_to_metadata(self.metadata)
         if num_returns is None:
             num_returns = 1
-        try:
-            ticket = self._call_stub("Schedule", task, metadata=metadata)
-        except grpc.RpcError as e:
-            raise decode_exception(e)
+        id_futures = [Future() for _ in range(num_returns)]
+
+        def populate_ids(
+                resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
+            if isinstance(resp, Exception):
+                if isinstance(resp, grpc.RpcError):
+                    resp = decode_exception(resp)
+                for future in id_futures:
+                    future.set_exception(resp)
+                return
+
+            ticket = resp.task_ticket
+            if not ticket.valid:
+                try:
+                    ex = cloudpickle.loads(ticket.error)
+                except (pickle.UnpicklingError, TypeError) as e_new:
+                    ex = e_new
+                for future in id_futures:
+                    future.set_exception(ex)
+                return
+
+            if len(ticket.return_ids) != num_returns:
+                exc = ValueError(
+                    f"Expected {num_returns} returns but received "
+                    f"{len(ticket.return_ids)}")
+                for future in id_futures:
+                    future.set_exception(exc)
+                return
+
+            for future, raw_id in zip(id_futures, ticket.return_ids):
+                future.set_result(raw_id)
+
+        self.data_client.Schedule(task, populate_ids)
 
-        if not ticket.valid:
-            try:
-                raise cloudpickle.loads(ticket.error)
-            except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(
-                    ticket.error))
-                raise
-        self.total_num_tasks_scheduled += 1
         self.total_outbound_message_size_bytes += task.ByteSize()
-        if self.total_num_tasks_scheduled > TASK_WARNING_THRESHOLD and \
-                log_once("client_communication_overhead_warning"):
-            warnings.warn(
-                f"More than {TASK_WARNING_THRESHOLD} remote tasks have been "
-                "scheduled. This can be slow on Ray Client due to "
-                "communication overhead over the network. If you're running "
-                "many fine-grained tasks, consider running them in a single "
-                "remote function. See the section on \"Too fine-grained "
-                "tasks\" in the Ray Design Patterns document for more "
-                f"details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}",
-                UserWarning)
 
         if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD \
                 and log_once("client_communication_overhead_warning"):
             warnings.warn(
@@ -508,10 +513,7 @@
                 "unserializable object\" section of the Ray Design Patterns "
                 "document, available here: "
                 f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning)
-        if num_returns != len(ticket.return_ids):
-            raise TypeError("Unexpected number of returned values. Expected "
-                            f"{num_returns} actual {ticket.return_ids}")
-        return ticket.return_ids
+        return id_futures
 
     def call_release(self, id: bytes) -> None:
         if self.closed:
@@ -547,9 +549,15 @@
         task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
         task.name = name
         task.namespace = namespace or ""
-        ids = self._call_schedule_for_task(task, num_returns=1)
-        assert len(ids) == 1
-        return ClientActorHandle(ClientActorRef(ids[0]))
+        futures = self._call_schedule_for_task(task, 1)
+        assert len(futures) == 1
+        handle = ClientActorHandle(ClientActorRef(futures[0]))
+        # `actor_ref.is_nil()` waits until the underlying ID is resolved.
+        # This is needed because `get_actor` is often used to check the
+        # existence of an actor.
+        if handle.actor_ref.is_nil():
+            raise ValueError(f"ActorID for {name} is empty")
+        return handle
 
     def terminate_actor(self, actor: ClientActorHandle,
                         no_restart: bool) -> None:
@@ -559,11 +567,10 @@
         term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
         term_actor.id = actor.actor_ref.id
         term_actor.no_restart = no_restart
+        term = ray_client_pb2.TerminateRequest(actor=term_actor)
+        term.client_id = self._client_id
         try:
-            term = ray_client_pb2.TerminateRequest(actor=term_actor)
-            term.client_id = self._client_id
-            metadata = self._add_ids_to_metadata(self.metadata)
-            self._call_stub("Terminate", term, metadata=metadata)
+            self.data_client.Terminate(term)
         except grpc.RpcError as e:
             raise decode_exception(e)
 
@@ -577,19 +584,18 @@
         term_object.id = obj.id
         term_object.force = force
         term_object.recursive = recursive
+        term = ray_client_pb2.TerminateRequest(task_object=term_object)
+        term.client_id = self._client_id
         try:
-            term = ray_client_pb2.TerminateRequest(task_object=term_object)
-            term.client_id = self._client_id
-            metadata = self._add_ids_to_metadata(self.metadata)
-            self._call_stub("Terminate", term, metadata=metadata)
+            self.data_client.Terminate(term)
         except grpc.RpcError as e:
             raise decode_exception(e)
 
     def get_cluster_info(self,
-                         type: ray_client_pb2.ClusterInfoType.TypeEnum,
+                         req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
                          timeout: Optional[float] = None):
         req = ray_client_pb2.ClusterInfoRequest()
-        req.type = type
+        req.type = req_type
         resp = self.server.ClusterInfo(
             req, timeout=timeout, metadata=self.metadata)
         if resp.WhichOneof("response_type") == "resource_table":
@@ -630,9 +636,7 @@
     def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
         req = ray_client_pb2.ClientListNamedActorsRequest(
             all_namespaces=all_namespaces)
-        return json.loads(
-            self._call_stub("ListNamedActors", req,
-                            metadata=self.metadata).actors_json)
+        return json.loads(self.data_client.ListNamedActors(req).actors_json)
 
     def is_initialized(self) -> bool:
         if self.server is not None:
diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto
index 9ccbd2ac7..e207263e5 100644
--- a/src/ray/protobuf/ray_client.proto
+++ b/src/ray/protobuf/ray_client.proto
@@ -367,6 +367,9 @@ message DataRequest {
     PrepRuntimeEnvRequest prep_runtime_env = 7;
     ConnectionCleanupRequest connection_cleanup = 8;
     AcknowledgeRequest acknowledge = 9;
+    ClientTask task = 10;
+    TerminateRequest terminate = 11;
+    ClientListNamedActorsRequest list_named_actors = 12;
   }
 }
 
@@ -381,7 +384,13 @@ message DataResponse {
     InitResponse init = 6;
     PrepRuntimeEnvResponse prep_runtime_env = 7;
     ConnectionCleanupResponse connection_cleanup = 8;
+    ClientTaskTicket task_ticket = 10;
+    TerminateResponse terminate = 11;
+    ClientListNamedActorsResponse list_named_actors = 12;
   }
+  // tag 9 is skipped, since there is no AcknowledgeResponse
+  reserved 9;
+  reserved "acknowledge";
 }
 
 service RayletDataStreamer {
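Reviewer notes on the patterns above follow. The sketches are illustrative, standalone Python, not part of the patch; any name that does not appear in the diff is hypothetical.

The `_wait_for_id(timeout=None)` change in `_raylet.pxd` and the `.pxi` files makes a client-side ref resolve its real ID lazily from a `concurrent.futures.Future` that `call_remote` now produces. A minimal pure-Python sketch of the same double-checked-locking pattern (the `LazyRef` class is a hypothetical stand-in, not the Cython implementation):

    from concurrent.futures import Future
    from threading import Lock


    class LazyRef:
        """Hypothetical sketch of a ref whose binary ID arrives later."""

        def __init__(self, id_or_future):
            self._mutex = Lock()
            self._id = None
            self._id_future = None
            if isinstance(id_or_future, Future):
                self._id_future = id_or_future
            else:
                self._id = id_or_future

        def _wait_for_id(self, timeout=None):
            # Double-checked locking: only the first caller blocks on the
            # Future; later callers reuse the cached ID. Re-raises whatever
            # exception scheduling produced, or TimeoutError on timeout.
            if self._id_future:
                with self._mutex:
                    if self._id_future:
                        self._id = self._id_future.result(timeout=timeout)
                        self._id_future = None

        @property
        def id(self):
            self._wait_for_id()
            return self._id


    fut = Future()
    ref = LazyRef(fut)
    fut.set_result(b"\x01" * 28)
    assert ref.id == b"\x01" * 28

This is also why the destructors in object_ref.pxi and unique_ids.pxi wrap `_wait_for_id()` in try/except: a failed scheduling RPC surfaces there, and a Cython destructor cannot propagate the exception, so it is logged instead.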
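`Worker._call_schedule_for_task` now returns one `Future` per expected return value and fulfills them from the data-channel callback. A self-contained sketch of that fulfillment logic, with `send_async` standing in for `DataClient.Schedule` and a plain object with a `return_ids` list standing in for the `ClientTaskTicket` (both stand-ins are assumptions). Note that the count-mismatch branch must iterate over all futures; zipping with a shorter `return_ids` would leave some futures unresolved forever:

    from concurrent.futures import Future
    from typing import Callable, List, Union


    def schedule(send_async: Callable, num_returns: int) -> List[Future]:
        id_futures = [Future() for _ in range(num_returns)]

        def populate_ids(resp: Union[object, Exception]) -> None:
            if isinstance(resp, Exception):
                for future in id_futures:
                    future.set_exception(resp)
                return
            if len(resp.return_ids) != num_returns:
                exc = ValueError(f"Expected {num_returns} returns but "
                                 f"received {len(resp.return_ids)}")
                # Fail every future so no caller blocks forever.
                for future in id_futures:
                    future.set_exception(exc)
                return
            for future, raw_id in zip(id_futures, resp.return_ids):
                future.set_result(raw_id)

        send_async(populate_ids)
        return id_futures


    class FakeTicket:
        return_ids = [b"a", b"b"]


    # Synchronous "send" for illustration; the real data channel invokes
    # the callback from its background thread.
    futs = schedule(lambda cb: cb(FakeTicket()), 2)
    assert [f.result() for f in futs] == [b"a", b"b"]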
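test_get_list above counts GetObject RPCs by swapping the stub attribute for a counting wrapper that delegates to the original. The same trick generalized (the helper name `count_calls` is hypothetical):

    def count_calls(obj, method_name):
        """Replace obj.method_name with a delegating wrapper and return a
        zero-argument function that reports how many calls were made."""
        calls = 0
        original = getattr(obj, method_name)

        def wrapper(*args, **kwargs):
            nonlocal calls
            calls += 1
            return original(*args, **kwargs)

        setattr(obj, method_name, wrapper)
        return lambda: calls


    class Stub:
        def GetObject(self, req, metadata=None):
            return "ok"


    stub = Stub()
    get_count = count_calls(stub, "GetObject")
    assert stub.GetObject("req") == "ok"
    assert get_count() == 1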
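test_kill_cancel_metadata patches `Terminate` onto the data client with `mock_terminate.__get__(stub)`. Calling `__get__` on a plain function binds it to that instance, so the mock receives the stub as `self` and can read instance state such as `self._metadata`. A standalone illustration (the `FakeStub` class and its metadata are hypothetical):

    class FakeStub:
        def __init__(self):
            self._metadata = [("ray_client_id", "abc"), ("key", "value")]


    def mock_terminate(self, term):
        # `self` is the stub instance, so its attributes are visible here.
        raise RuntimeError(self._metadata[1][0])


    stub = FakeStub()
    stub.Terminate = mock_terminate.__get__(stub)

    try:
        stub.Terminate("ignored")
    except RuntimeError as e:
        assert str(e) == "key"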
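The options.py change validates `placement_group_bundle_index` on the client, which is what the rewritten assertions in test_placement_group_2.py exercise. A sketch of the constraint itself, assuming `check_placement_group_index` enforces roughly the following (the function below and its error text are approximations, not the ray.util.placement_group implementation):

    def check_bundle_index(bundle_count: int, bundle_index: int) -> None:
        # A valid index is -1 (no specific bundle) or a position within
        # the placement group's bundle list.
        if bundle_index < -1 or bundle_index >= bundle_count:
            raise ValueError(f"placement group bundle index {bundle_index} "
                             "is invalid")


    check_bundle_index(2, 1)   # OK: second bundle of two
    check_bundle_index(2, -1)  # OK: no specific bundle requested
    try:
        check_bundle_index(2, 3)
    except ValueError as e:
        assert "bundle index 3 is invalid" in str(e)

Because `placement_group` may also be the string "default" on the client, and the default cannot be resolved without `placement_group_capture_child_tasks`, validate_options only runs the check when it can materialize an actual `PlacementGroup`, leaving the remaining cases to the server.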