[Client] Use async rpc for remote call and actor creation (#18298)

* Use async rpc for remote calls, task and actor creations.

* fix

* check placement

* check placement group. wait for id in destructor

* fix

* fix exception in destructor

* Add test

* revert change

* Fix comment

* fix
mwtian 2021-09-22 18:30:50 -07:00 committed by GitHub
parent 8dd3057644
commit e41109a5e7
17 changed files with 238 additions and 135 deletions
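
The core pattern in this change: `f.remote()` and actor creation no longer block on a unary Schedule RPC; they return immediately with `concurrent.futures.Future` objects that a response callback on the data channel later resolves to object IDs, or to the scheduling error. A minimal self-contained sketch of that pattern, with a hypothetical `send_async` transport standing in for the Ray Client data channel:

from concurrent.futures import Future
from typing import Callable, List

def schedule_async(send_async: Callable, task: bytes,
                   num_returns: int = 1) -> List[Future]:
    """Return ID futures immediately; a response callback resolves them."""
    id_futures = [Future() for _ in range(num_returns)]

    def on_response(resp) -> None:
        if isinstance(resp, Exception):
            # A failed schedule fails every pending future instead of
            # raising at the f.remote() call site.
            for fut in id_futures:
                fut.set_exception(resp)
            return
        for fut, raw_id in zip(id_futures, resp):
            fut.set_result(raw_id)

    send_async(task, on_response)  # non-blocking send on the data channel
    return id_futures

# Demo with a fake transport that responds synchronously:
futures = schedule_async(lambda t, cb: cb([b"object-id-0"]), b"task")
assert futures[0].result() == b"object-id-0"

Deferring errors into the futures is what keeps `f.remote()` non-blocking while still surfacing scheduling failures later, at `ray.get()` time (see `test_invalid_task` below).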

View file

@@ -91,7 +91,7 @@ cdef class ClientObjectRef(ObjectRef):
     cdef object _id_future
     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)

 cdef class ActorID(BaseID):
     cdef CActorID data
@@ -105,7 +105,7 @@ cdef class ClientActorRef(ActorID):
     cdef object _id_future
     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)

 cdef class CoreWorker:
     cdef:

View file

@@ -166,8 +166,20 @@ cdef class ClientObjectRef(ObjectRef):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception in ObjectRef is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)

     cdef CObjectID native(self):
         self._wait_for_id()
@@ -244,9 +256,9 @@ cdef class ClientObjectRef(ObjectRef):
         self.data = CObjectID.FromBinary(<c_string>id)
         client.ray.call_retain(id)

-    cdef inline _wait_for_id(self):
+    cdef inline _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None

View file

@@ -8,6 +8,7 @@ See https://github.com/ray-project/ray/issues/3721.
 # _ID_TYPES list at the bottom of this file.

 from concurrent.futures import Future
+import logging
 import os

 from ray.includes.unique_ids cimport (
@@ -27,6 +28,8 @@ from ray.includes.unique_ids cimport (
 import ray
 from ray._private.utils import decode

+logger = logging.getLogger(__name__)
+

 def check_id(b, size=kUniqueIDSize):
     if not isinstance(b, bytes):
@@ -326,8 +329,20 @@ cdef class ClientActorRef(ActorID):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception from actor creation is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)

     def binary(self):
         self._wait_for_id()
@@ -358,11 +373,11 @@ cdef class ClientActorRef(ActorID):
         self.data = CActorID.FromBinary(<c_string>id)
         client.ray.call_retain(id)

-    cdef _wait_for_id(self):
+    cdef _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None

View file

@@ -303,6 +303,8 @@ async def test_async_obj_unhandled_errors(ray_start_regular_shared):
     # Test we report unhandled exceptions.
     ray.worker._unhandled_error_handler = interceptor
     x1 = f.remote()
+    # NOTE: Unhandled exception is from waiting for the value of x1's ObjectID
+    # in x1's destructor, and receiving an exception from f() instead.
     del x1
     wait_for_condition(lambda: num_exceptions == 1)

View file

@@ -135,14 +135,23 @@ def test_get_list(ray_start_regular_shared):
         assert ray.get([]) == []
         assert ray.get([f.remote()]) == ["OK"]

+        get_count = 0
+        get_stub = ray.worker.server.GetObject
+
+        # ray.get() uses unary-unary RPC. Mock the server handler to count
+        # the number of requests received.
+        def get(req, metadata=None):
+            nonlocal get_count
+            get_count += 1
+            return get_stub(req, metadata=metadata)
+
+        ray.worker.server.GetObject = get
         refs = [f.remote() for _ in range(100)]
-        with ray.worker.data_client.lock:
-            req_id_before = ray.worker.data_client._req_id
         assert ray.get(refs) == ["OK" for _ in range(100)]
         # Only 1 RPC should be sent.
-        with ray.worker.data_client.lock:
-            assert ray.worker.data_client._req_id == req_id_before + 1, \
-                ray.worker.data_client._req_id
+        assert get_count == 1


 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
@@ -470,6 +479,22 @@ def test_serializing_exceptions(ray_start_regular_shared):
         ray.get_actor("abc")


+@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
+def test_invalid_task(ray_start_regular_shared):
+    with ray_start_client_server() as ray:
+
+        @ray.remote(runtime_env="invalid value")
+        def f():
+            return 1
+
+        # No exception on making the remote call.
+        ref = f.remote()
+
+        # Exception during scheduling will be raised on ray.get()
+        with pytest.raises(Exception):
+            ray.get(ref)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
 def test_create_remote_before_start(ray_start_regular_shared):
     """Creates remote objects (as though in a library) before

View file

@@ -113,11 +113,12 @@ def test_kill_cancel_metadata(ray_start_regular):
     class MetadataIsCorrectlyPassedException(Exception):
         pass

-    def mock_terminate(term, metadata):
-        raise MetadataIsCorrectlyPassedException(metadata[1][0])
+    def mock_terminate(self, term):
+        raise MetadataIsCorrectlyPassedException(self._metadata[1][0])

     # Mock stub's Terminate method to raise an exception.
-    ray.get_context().api.worker.server.Terminate = mock_terminate
+    stub = ray.get_context().api.worker.data_client
+    stub.Terminate = mock_terminate.__get__(stub)

     # Verify the expected exception is raised with ray.kill.
     # Check that argument of the exception matches "key" from the

View file

@@ -1,12 +1,10 @@
 from ray.util.client.ray_client_helpers import ray_start_client_server
-from ray.util.client.worker import TASK_WARNING_THRESHOLD
 from ray.util.debug import _logged

 import numpy as np
 import pytest
 import unittest
-import warnings


 @pytest.fixture(autouse=True)
@@ -19,33 +17,6 @@ def reset_debug_logs():
 class LoggerSuite(unittest.TestCase):
     """Test client warnings are raised when many tasks are scheduled"""

-    def testManyTasksWarning(self):
-        with ray_start_client_server() as ray:
-            @ray.remote
-            def f():
-                return 42
-
-            with self.assertWarns(UserWarning) as cm:
-                for _ in range(TASK_WARNING_THRESHOLD + 1):
-                    f.remote()
-            assert f"More than {TASK_WARNING_THRESHOLD} remote tasks have " \
-                "been scheduled." in cm.warning.args[0]
-
-    def testNoWarning(self):
-        with ray_start_client_server() as ray:
-            @ray.remote
-            def f():
-                return 42
-
-            with warnings.catch_warnings(record=True) as warn_list:
-                for _ in range(TASK_WARNING_THRESHOLD):
-                    f.remote()
-            assert not any(f"More than {TASK_WARNING_THRESHOLD} remote tasks "
-                           "have been scheduled." in str(w.args[0])
-                           for w in warn_list)
-
     def testOutboundMessageSizeWarning(self):
         with ray_start_client_server() as ray:
             large_argument = np.random.rand(100, 100, 100)

View file

@@ -50,11 +50,24 @@ def test_dataclient_disconnect_before_request():
         # different remote calls.
         ray.worker.data_client.request_queue.put(Mock())

-        # The next remote call should error since the data channel has shut
-        # down, which should also disconnect the client.
+        # The following two assertions are relatively brittle. Consider a more
+        # robust mechanism if they fail with code changes or become flaky.
+
+        # The next remote call should error since the data channel will shut
+        # down because of the invalid input above. Two cases can happen:
+        # (1) Data channel shuts down after `f.remote()` finishes.
+        #     error is raised to `ray.get()`. The next background operation
+        #     will disconnect Ray client.
+        # (2) Data channel shuts down before `f.remote()` is called.
+        #     `f.remote()` will raise the error and disconnect the client.
         with pytest.raises(ConnectionError):
             ray.get(f.remote())

+        with pytest.raises(
+                ConnectionError,
+                match="Ray client has already been disconnected"):
+            ray.get(f.remote())
+
         # Client should be disconnected
         assert not ray.is_connected()

View file

@@ -46,28 +46,18 @@ def test_check_bundle_index(ray_start_cluster, connect_to_client):
             "CPU": 2
         }])

-        error_count = 0
-        try:
+        with pytest.raises(ValueError, match="bundle index 3 is invalid"):
             Actor.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=3).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 1

-        try:
+        with pytest.raises(ValueError, match="bundle index -2 is invalid"):
             Actor.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=-2).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 2

-        try:
+        with pytest.raises(ValueError, match="bundle index must be -1"):
             Actor.options(placement_group_bundle_index=0).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 3


 @pytest.mark.parametrize("connect_to_client", [False, True])

View file

@@ -1,6 +1,7 @@
 """This file defines the interface between the ray client worker
 and the overall ray module API.
 """
+from concurrent.futures import Future

 import json
 import logging
@@ -85,7 +86,10 @@ class ClientAPI:
             assert len(args) == 0 and len(kwargs) > 0, error_string
             return remote_decorator(options=kwargs)

-    def call_remote(self, instance: "ClientStub", *args, **kwargs):
+    # TODO(mwtian): consider adding _internal_ prefix to call_remote /
+    # call_release / call_retain.
+    def call_remote(self, instance: "ClientStub", *args,
+                    **kwargs) -> List[Future]:
         """call_remote is called by stub objects to execute them remotely.

         This is used by stub objects in situations where they're called

View file

@@ -141,9 +141,9 @@ class ServerUnpickler(pickle.Unpickler):
     def persistent_load(self, pid):
         assert isinstance(pid, PickleStub)
         if pid.type == "Object":
-            return ClientObjectRef(id=pid.ref_id)
+            return ClientObjectRef(pid.ref_id)
         elif pid.type == "Actor":
-            return ClientActorHandle(ClientActorRef(id=pid.ref_id))
+            return ClientActorHandle(ClientActorRef(pid.ref_id))
         else:
             raise NotImplementedError("Being passed back an unknown stub")

View file

@@ -6,6 +6,7 @@ from ray.util.client.options import validate_options
 from ray._private.signature import get_signature, extract_signature
 from ray._private.utils import check_oversized_function

+import concurrent
 from dataclasses import dataclass
 import grpc
 import os
@@ -209,9 +210,9 @@ class ClientActorClass(ClientStub):
     def remote(self, *args, **kwargs) -> "ClientActorHandle":
         self._init_signature.bind(*args, **kwargs)
         # Actually instantiate the actor
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
-        return ClientActorHandle(ClientActorRef(ref_ids[0]), actor_class=self)
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
+        return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)

     def options(self, **kwargs):
         return ActorOptionWrapper(self, kwargs)
@@ -397,13 +398,13 @@ class OptionWrapper:
 class ActorOptionWrapper(OptionWrapper):
     def remote(self, *args, **kwargs):
         self._remote_stub._init_signature.bind(*args, **kwargs)
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
         actor_class = None
         if isinstance(self._remote_stub, ClientActorClass):
             actor_class = self._remote_stub
         return ClientActorHandle(
-            ClientActorRef(ref_ids[0]), actor_class=actor_class)
+            ClientActorRef(futures[0]), actor_class=actor_class)


 def set_task_options(task: ray_client_pb2.ClientTask,
@@ -423,13 +424,13 @@ def set_task_options(task: ray_client_pb2.ClientTask,
         getattr(task, field).json_options = options_str


-def return_refs(ids: List[bytes]
+def return_refs(futures: List[concurrent.futures.Future]
                 ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
-    if len(ids) == 1:
-        return ClientObjectRef(ids[0])
-    if len(ids) == 0:
+    if not futures:
         return None
-    return [ClientObjectRef(id) for id in ids]
+    if len(futures) == 1:
+        return ClientObjectRef(futures[0])
+    return [ClientObjectRef(fut) for fut in futures]


 class InProgressSentinel:

View file

@@ -361,3 +361,21 @@ class DataClient:
                      context=None) -> None:
         datareq = ray_client_pb2.DataRequest(release=request, )
         self._async_send(datareq)
+
+    def Schedule(self, request: ray_client_pb2.ClientTask,
+                 callback: ResponseCallable):
+        datareq = ray_client_pb2.DataRequest(task=request)
+        self._async_send(datareq, callback)
+
+    def Terminate(self, request: ray_client_pb2.TerminateRequest
+                  ) -> ray_client_pb2.TerminateResponse:
+        req = ray_client_pb2.DataRequest(terminate=request, )
+        resp = self._blocking_send(req)
+        return resp.terminate
+
+    def ListNamedActors(self,
+                        request: ray_client_pb2.ClientListNamedActorsRequest
+                        ) -> ray_client_pb2.ClientListNamedActorsResponse:
+        req = ray_client_pb2.DataRequest(list_named_actors=request, )
+        resp = self._blocking_send(req)
+        return resp.list_named_actors
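
`Terminate` and `ListNamedActors` above are synchronous from the caller's point of view but travel over the same streaming data channel as `Schedule`; only the send differs (`_blocking_send` vs. `_async_send` with a callback). A rough sketch of how a blocking send can be layered on an async one, assuming a background reader thread pairs responses to requests by ID (hypothetical class, not Ray's actual DataClient):

import queue
from concurrent.futures import Future

class DuplexClient:
    """Hypothetical client: blocking sends layered on callback sends."""

    def __init__(self):
        self.request_queue = queue.Queue()
        self._callbacks = {}  # req_id -> callback, fired by a reader thread
        self._next_id = 0

    def _async_send(self, request, callback=None):
        self._next_id += 1
        if callback is not None:
            self._callbacks[self._next_id] = callback
        self.request_queue.put((self._next_id, request))

    def _blocking_send(self, request):
        fut = Future()
        self._async_send(request, fut.set_result)
        return fut.result()  # blocks until the reader fires the callback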

View file

@@ -2,6 +2,9 @@ from typing import Any
 from typing import Dict
 from typing import Optional

+from ray.util.placement_group import (PlacementGroup,
+                                      check_placement_group_index)
+
 options = {
     "num_returns": (int, lambda x: x >= 0,
                     "The keyword 'num_returns' only accepts 0 "
@@ -43,6 +46,7 @@ def validate_options(
         return None
     if len(kwargs_dict) == 0:
         return None
+
     out = {}
     for k, v in kwargs_dict.items():
         if k not in options.keys():
@@ -55,4 +59,21 @@ def validate_options(
         if not validator[1](v):
             raise ValueError(validator[2])
         out[k] = v
+
+    # Validate placement setting similar to the logic in ray/actor.py and
+    # ray/remote_function.py. The difference is that when
+    # placement_group = default and placement_group_capture_child_tasks
+    # specified, placement group cannot be resolved at client. So this check
+    # skips this case and relies on server to enforce any condition.
+    bundle_index = out.get("placement_group_bundle_index", None)
+    if bundle_index is not None:
+        pg = out.get("placement_group", None)
+        if pg is None:
+            pg = PlacementGroup.empty()
+        if pg == "default" and (out.get("placement_group_capture_child_tasks",
+                                        None) is None):
+            pg = PlacementGroup.empty()
+        if isinstance(pg, PlacementGroup):
+            check_placement_group_index(pg, bundle_index)
+
     return out
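
For reference, the range rule that `check_placement_group_index` enforces: a group with `n` bundles accepts indices `0..n-1`, or `-1` for "any bundle". A standalone sketch of that rule (illustrative only; the real function takes a `PlacementGroup` and words its error differently):

def check_bundle_index(bundle_count: int, bundle_index: int) -> None:
    """Illustrative version of the bundle index range check."""
    if bundle_index >= bundle_count or bundle_index < -1:
        raise ValueError(
            f"placement group bundle index {bundle_index} is invalid; "
            f"valid indexes are 0..{bundle_count - 1}, or -1 for any bundle")

check_bundle_index(2, 1)    # ok
check_bundle_index(2, -1)   # ok: -1 means "any bundle"
# check_bundle_index(2, 3)  # raises ValueError, as in the tests above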

View file

@@ -189,6 +189,23 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
                     # Clean up acknowledged cache entries
                     response_cache.cleanup(req.acknowledge.req_id)
                     continue
+                elif req_type == "task":
+                    with self.clients_lock:
+                        resp_ticket = self.basic_service.Schedule(
+                            req.task, context)
+                    resp = ray_client_pb2.DataResponse(
+                        task_ticket=resp_ticket)
+                elif req_type == "terminate":
+                    with self.clients_lock:
+                        response = self.basic_service.Terminate(
+                            req.terminate, context)
+                    resp = ray_client_pb2.DataResponse(terminate=response)
+                elif req_type == "list_named_actors":
+                    with self.clients_lock:
+                        response = self.basic_service.ListNamedActors(
+                            req.list_named_actors)
+                    resp = ray_client_pb2.DataResponse(
+                        list_named_actors=response)
                 else:
                     raise Exception(f"Unreachable code: Request type "
                                     f"{req_type} not handled in Datapath")

View file

@@ -11,8 +11,10 @@ import time
 import uuid
 import warnings
 from collections import defaultdict
+from concurrent.futures import Future
 import tempfile
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING,
+                    Union)

 import grpc
@@ -53,10 +55,6 @@ MAX_BLOCKING_OPERATION_TIME_S: float = 2.0
 # the connection began exceeds this value, a warning should be raised
 MESSAGE_SIZE_THRESHOLD = 10 * 2**20  # 10 MB

-# If the number of tasks scheduled on the client side since the connection
-# began exceeds this value, a warning should be raised
-TASK_WARNING_THRESHOLD = 1000
-
 # Links to the Ray Design Pattern doc to use in the task overhead warning
 # message
 DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = \
@@ -138,9 +136,7 @@ class Worker:
         self.closed = False

-        # Track these values to raise a warning if many tasks are being
-        # scheduled
-        self.total_num_tasks_scheduled = 0
+        # Track this value to raise a warning if a lot of data are transferred.
         self.total_outbound_message_size_bytes = 0

         # Used to create unique IDs for RPCs to the RayletServicer
@@ -365,17 +361,17 @@
         req = ray_client_pb2.GetRequest(
             ids=[r.id for r in ref], timeout=timeout)
         try:
-            data = self.data_client.GetObject(req)
+            resp = self._call_stub("GetObject", req, metadata=self.metadata)
         except grpc.RpcError as e:
             raise decode_exception(e)
-        if not data.valid:
+        if not resp.valid:
             try:
-                err = cloudpickle.loads(data.error)
+                err = cloudpickle.loads(resp.error)
             except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(data.error))
+                logger.exception("Failed to deserialize {}".format(resp.error))
                 raise
             raise err
-        return loads_from_server(data.data)
+        return loads_from_server(resp.data)

     def put(self, vals, *, client_ref_id: bytes = None):
         to_put = []
@@ -450,7 +446,7 @@
         return (client_ready_object_ids, client_remaining_object_ids)

-    def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
+    def call_remote(self, instance, *args, **kwargs) -> List[Future]:
         task = instance._prepare_client_task()
         for arg in args:
             pb_arg = convert_to_arg(arg, self._client_id)
@@ -460,38 +456,47 @@
         return self._call_schedule_for_task(task, instance._num_returns())

     def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
-                                num_returns: int) -> List[bytes]:
+                                num_returns: int) -> List[Future]:
         logger.debug("Scheduling %s" % task)
         task.client_id = self._client_id
-        metadata = self._add_ids_to_metadata(self.metadata)
         if num_returns is None:
             num_returns = 1
-        try:
-            ticket = self._call_stub("Schedule", task, metadata=metadata)
-        except grpc.RpcError as e:
-            raise decode_exception(e)
-        if not ticket.valid:
-            try:
-                raise cloudpickle.loads(ticket.error)
-            except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(
-                    ticket.error))
-                raise
-        self.total_num_tasks_scheduled += 1
+        id_futures = [Future() for _ in range(num_returns)]
+
+        def populate_ids(
+                resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
+            if isinstance(resp, Exception):
+                if isinstance(resp, grpc.RpcError):
+                    resp = decode_exception(resp)
+                for future in id_futures:
+                    future.set_exception(resp)
+                return
+
+            ticket = resp.task_ticket
+            if not ticket.valid:
+                try:
+                    ex = cloudpickle.loads(ticket.error)
+                except (pickle.UnpicklingError, TypeError) as e_new:
+                    ex = e_new
+                for future in id_futures:
+                    future.set_exception(ex)
+                return
+
+            if len(ticket.return_ids) != num_returns:
+                exc = ValueError(
+                    f"Expected {num_returns} returns but received "
+                    f"{len(ticket.return_ids)}")
+                for future, raw_id in zip(id_futures, ticket.return_ids):
+                    future.set_exception(exc)
+                return
+
+            for future, raw_id in zip(id_futures, ticket.return_ids):
+                future.set_result(raw_id)
+
+        self.data_client.Schedule(task, populate_ids)
+
         self.total_outbound_message_size_bytes += task.ByteSize()
-        if self.total_num_tasks_scheduled > TASK_WARNING_THRESHOLD and \
-                log_once("client_communication_overhead_warning"):
-            warnings.warn(
-                f"More than {TASK_WARNING_THRESHOLD} remote tasks have been "
-                "scheduled. This can be slow on Ray Client due to "
-                "communication overhead over the network. If you're running "
-                "many fine-grained tasks, consider running them in a single "
-                "remote function. See the section on \"Too fine-grained "
-                "tasks\" in the Ray Design Patterns document for more "
-                f"details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}",
-                UserWarning)
         if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD \
                 and log_once("client_communication_overhead_warning"):
             warnings.warn(
@@ -508,10 +513,7 @@
                 "unserializable object\" section of the Ray Design Patterns "
                 "document, available here: "
                 f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning)
-        if num_returns != len(ticket.return_ids):
-            raise TypeError("Unexpected number of returned values. Expected "
-                            f"{num_returns} actual {ticket.return_ids}")
-        return ticket.return_ids
+        return id_futures

     def call_release(self, id: bytes) -> None:
         if self.closed:
@@ -547,9 +549,15 @@ class Worker:
         task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
         task.name = name
         task.namespace = namespace or ""
-        ids = self._call_schedule_for_task(task, num_returns=1)
-        assert len(ids) == 1
-        return ClientActorHandle(ClientActorRef(ids[0]))
+        futures = self._call_schedule_for_task(task, 1)
+        assert len(futures) == 1
+        handle = ClientActorHandle(ClientActorRef(futures[0]))
+        # `actor_ref.is_nil()` waits until the underlying ID is resolved.
+        # This is needed because `get_actor` is often used to check the
+        # existence of an actor.
+        if handle.actor_ref.is_nil():
+            raise ValueError(f"ActorID for {name} is empty")
+        return handle

     def terminate_actor(self, actor: ClientActorHandle,
                         no_restart: bool) -> None:
@@ -559,11 +567,10 @@
         term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
         term_actor.id = actor.actor_ref.id
         term_actor.no_restart = no_restart
+        term = ray_client_pb2.TerminateRequest(actor=term_actor)
+        term.client_id = self._client_id
         try:
-            term = ray_client_pb2.TerminateRequest(actor=term_actor)
-            term.client_id = self._client_id
-            metadata = self._add_ids_to_metadata(self.metadata)
-            self._call_stub("Terminate", term, metadata=metadata)
+            self.data_client.Terminate(term)
         except grpc.RpcError as e:
             raise decode_exception(e)

@@ -577,19 +584,18 @@
         term_object.id = obj.id
         term_object.force = force
         term_object.recursive = recursive
+        term = ray_client_pb2.TerminateRequest(task_object=term_object)
+        term.client_id = self._client_id
         try:
-            term = ray_client_pb2.TerminateRequest(task_object=term_object)
-            term.client_id = self._client_id
-            metadata = self._add_ids_to_metadata(self.metadata)
-            self._call_stub("Terminate", term, metadata=metadata)
+            self.data_client.Terminate(term)
         except grpc.RpcError as e:
             raise decode_exception(e)

     def get_cluster_info(self,
-                         type: ray_client_pb2.ClusterInfoType.TypeEnum,
+                         req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
                          timeout: Optional[float] = None):
         req = ray_client_pb2.ClusterInfoRequest()
-        req.type = type
+        req.type = req_type
         resp = self.server.ClusterInfo(
             req, timeout=timeout, metadata=self.metadata)
         if resp.WhichOneof("response_type") == "resource_table":
@@ -630,9 +636,7 @@
     def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
         req = ray_client_pb2.ClientListNamedActorsRequest(
             all_namespaces=all_namespaces)
-        return json.loads(
-            self._call_stub("ListNamedActors", req,
-                            metadata=self.metadata).actors_json)
+        return json.loads(self.data_client.ListNamedActors(req).actors_json)

     def is_initialized(self) -> bool:
         if self.server is not None:
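
The upshot of the `populate_ids` callback above: a scheduling failure is stored in the ID futures rather than raised at call time, so it surfaces only when the ID is actually awaited (`ray.get`, `is_nil()`, or the ref's destructor). A tiny demonstration of that deferral with a bare `Future`:

from concurrent.futures import Future

# A failed schedule puts the error into the future (as populate_ids does)...
fut = Future()
fut.set_exception(ValueError("runtime_env is invalid"))

# ...and it re-raises only when the ID is waited on, as in _wait_for_id.
try:
    fut.result(timeout=0)
except ValueError as e:
    print(f"raised at get-time: {e}")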

View file

@@ -367,6 +367,9 @@ message DataRequest {
     PrepRuntimeEnvRequest prep_runtime_env = 7;
     ConnectionCleanupRequest connection_cleanup = 8;
     AcknowledgeRequest acknowledge = 9;
+    ClientTask task = 10;
+    TerminateRequest terminate = 11;
+    ClientListNamedActorsRequest list_named_actors = 12;
   }
 }
@@ -381,7 +384,13 @@ message DataResponse {
     InitResponse init = 6;
     PrepRuntimeEnvResponse prep_runtime_env = 7;
     ConnectionCleanupResponse connection_cleanup = 8;
+    ClientTaskTicket task_ticket = 10;
+    TerminateResponse terminate = 11;
+    ClientListNamedActorsResponse list_named_actors = 12;
   }
+  // tag 9 is skipped, since there is no AcknowledgeResponse
+  reserved 9;
+  reserved "acknowledge";
 }

 service RayletDataStreamer {
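
Reserving tag 9 and the name `acknowledge` in `DataResponse` prevents future fields from reusing them, which could otherwise misinterpret previously serialized messages. A quick check of the new envelope fields from Python, assuming the generated `ray_client_pb2` module path used elsewhere in Ray:

from ray.core.generated import ray_client_pb2

# Wrap a ClientTask in the data-channel envelope; the server reads the
# oneof tag to route it (see the Datapath dispatch earlier in this diff).
req = ray_client_pb2.DataRequest(task=ray_client_pb2.ClientTask())
assert req.WhichOneof("type") == "task"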