Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[Client] Use async rpc for remote call and actor creation (#18298)
* Use async rpc for remote calls, task and actor creations.
* fix
* check placement
* check placement group. wait for id in destructor
* fix
* fix exception in destructor
* Add test
* revert change
* Fix comment
* fix
This commit is contained in:
parent 8dd3057644
commit e41109a5e7

17 changed files with 238 additions and 135 deletions
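The core of the change is easier to see in isolation than across the hunks below. The following is a minimal, self-contained sketch (hypothetical FakeDataChannel and call_schedule names, not Ray's actual client code) of what "async rpc for remote call" means here: submitting a task enqueues an asynchronous request and immediately returns concurrent.futures.Future objects, which the data channel's response callback later resolves with the returned object IDs (error handling omitted for brevity).

import threading
from concurrent.futures import Future
from typing import Any, Callable, Dict, List


class FakeDataChannel:
    # Stand-in for the client's background data stream (assumption, not
    # Ray's DataClient): it "sends" the request and replies on another thread.
    def schedule(self, task: str, callback: Callable[[Dict[str, Any]], None]):
        def reply():
            callback({"return_ids": [b"id-for-%s" % task.encode()]})
        threading.Thread(target=reply).start()


def call_schedule(channel: FakeDataChannel, task: str,
                  num_returns: int = 1) -> List[Future]:
    # Create one Future per expected return value and hand them out
    # immediately; nothing blocks on the network here.
    id_futures = [Future() for _ in range(num_returns)]

    def populate_ids(resp: Dict[str, Any]) -> None:
        for fut, raw_id in zip(id_futures, resp["return_ids"]):
            fut.set_result(raw_id)

    channel.schedule(task, populate_ids)
    return id_futures


if __name__ == "__main__":
    futures = call_schedule(FakeDataChannel(), "f")
    # The caller blocks only when it actually needs the ID.
    print(futures[0].result(timeout=5))  # b'id-for-f'
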
@@ -91,7 +91,7 @@ cdef class ClientObjectRef(ObjectRef):
     cdef object _id_future

     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)

 cdef class ActorID(BaseID):
     cdef CActorID data

@@ -105,7 +105,7 @@ cdef class ClientActorRef(ActorID):
     cdef object _id_future

     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)

 cdef class CoreWorker:
     cdef:

@@ -166,8 +166,20 @@ cdef class ClientObjectRef(ObjectRef):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception in ObjectRef is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)

     cdef CObjectID native(self):
         self._wait_for_id()

@@ -244,9 +256,9 @@ cdef class ClientObjectRef(ObjectRef):
         self.data = CObjectID.FromBinary(<c_string>id)
         client.ray.call_retain(id)

-    cdef inline _wait_for_id(self):
+    cdef inline _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None

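Both ClientObjectRef and ClientActorRef resolve their IDs this way: the ID sits in a Future until first use, guarded by double-checked locking so it is assigned at most once, and the new timeout parameter is simply forwarded to Future.result(). A standalone sketch of the pattern in plain Python (hypothetical LazyRef class, not the Cython type above):

import threading
from concurrent.futures import Future
from typing import Optional


class LazyRef:
    """Illustrative stand-in for a reference whose ID arrives asynchronously."""

    def __init__(self, id_future: Future):
        self._mutex = threading.Lock()
        self._id_future: Optional[Future] = id_future
        self._id: Optional[bytes] = None

    def _wait_for_id(self, timeout: Optional[float] = None) -> None:
        # Double-checked locking: the fast path skips the lock once resolved.
        if self._id_future:
            with self._mutex:
                if self._id_future:
                    self._id = self._id_future.result(timeout=timeout)
                    self._id_future = None

    @property
    def id(self) -> bytes:
        self._wait_for_id()
        return self._id


# Usage: the ID becomes available only when the RPC callback resolves it.
fut = Future()
ref = LazyRef(fut)
fut.set_result(b"\x01" * 28)
assert ref.id == b"\x01" * 28
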
@@ -8,6 +8,7 @@ See https://github.com/ray-project/ray/issues/3721.
 # _ID_TYPES list at the bottom of this file.

 from concurrent.futures import Future
+import logging
 import os

 from ray.includes.unique_ids cimport (

@@ -27,6 +28,8 @@ from ray.includes.unique_ids cimport (
 import ray
 from ray._private.utils import decode

+logger = logging.getLogger(__name__)
+

 def check_id(b, size=kUniqueIDSize):
     if not isinstance(b, bytes):

@@ -326,8 +329,20 @@ cdef class ClientActorRef(ActorID):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception from actor creation is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)

     def binary(self):
         self._wait_for_id()

@@ -358,11 +373,11 @@ cdef class ClientActorRef(ActorID):
         self.data = CActorID.FromBinary(<c_string>id)
         client.ray.call_retain(id)

-    cdef _wait_for_id(self):
+    cdef _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None

@@ -303,6 +303,8 @@ async def test_async_obj_unhandled_errors(ray_start_regular_shared):
     # Test we report unhandled exceptions.
     ray.worker._unhandled_error_handler = interceptor
     x1 = f.remote()
+    # NOTE: Unhandled exception is from waiting for the value of x1's ObjectID
+    # in x1's destructor, and receiving an exception from f() instead.
     del x1
     wait_for_condition(lambda: num_exceptions == 1)

@@ -135,14 +135,23 @@ def test_get_list(ray_start_regular_shared):
         assert ray.get([]) == []
         assert ray.get([f.remote()]) == ["OK"]

+        get_count = 0
+        get_stub = ray.worker.server.GetObject
+
+        # ray.get() uses unary-unary RPC. Mock the server handler to count
+        # the number of requests received.
+        def get(req, metadata=None):
+            nonlocal get_count
+            get_count += 1
+            return get_stub(req, metadata=metadata)
+
+        ray.worker.server.GetObject = get
+
         refs = [f.remote() for _ in range(100)]
-        with ray.worker.data_client.lock:
-            req_id_before = ray.worker.data_client._req_id
         assert ray.get(refs) == ["OK" for _ in range(100)]

         # Only 1 RPC should be sent.
-        with ray.worker.data_client.lock:
-            assert ray.worker.data_client._req_id == req_id_before + 1, \
-                ray.worker.data_client._req_id
+        assert get_count == 1


 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")

@@ -470,6 +479,22 @@ def test_serializing_exceptions(ray_start_regular_shared):
         ray.get_actor("abc")


+@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
+def test_invalid_task(ray_start_regular_shared):
+    with ray_start_client_server() as ray:
+
+        @ray.remote(runtime_env="invalid value")
+        def f():
+            return 1
+
+        # No exception on making the remote call.
+        ref = f.remote()
+
+        # Exception during scheduling will be raised on ray.get()
+        with pytest.raises(Exception):
+            ray.get(ref)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
 def test_create_remote_before_start(ray_start_regular_shared):
     """Creates remote objects (as though in a library) before

@@ -113,11 +113,12 @@ def test_kill_cancel_metadata(ray_start_regular):
     class MetadataIsCorrectlyPassedException(Exception):
         pass

-    def mock_terminate(term, metadata):
-        raise MetadataIsCorrectlyPassedException(metadata[1][0])
+    def mock_terminate(self, term):
+        raise MetadataIsCorrectlyPassedException(self._metadata[1][0])

     # Mock stub's Terminate method to raise an exception.
-    ray.get_context().api.worker.server.Terminate = mock_terminate
+    stub = ray.get_context().api.worker.data_client
+    stub.Terminate = mock_terminate.__get__(stub)

     # Verify the expected exception is raised with ray.kill.
     # Check that argument of the exception matches "key" from the

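The mock_terminate.__get__(stub) call above relies on the descriptor protocol: calling __get__ on a plain function binds it to an object, so it can be attached to the data client stub and receive that stub as self. A tiny standalone illustration (hypothetical Stub and fake_terminate names):

class Stub:
    def __init__(self, metadata):
        self._metadata = metadata


def fake_terminate(self, request):
    # `self` is the Stub instance once the function is bound via __get__.
    return ("terminated", self._metadata, request)


stub = Stub(metadata=[("a", "b")])
# Descriptor protocol: function.__get__(obj) returns a method bound to obj.
stub.Terminate = fake_terminate.__get__(stub)
print(stub.Terminate("req-1"))  # ('terminated', [('a', 'b')], 'req-1')
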
@@ -1,12 +1,10 @@
 from ray.util.client.ray_client_helpers import ray_start_client_server
-from ray.util.client.worker import TASK_WARNING_THRESHOLD
 from ray.util.debug import _logged

 import numpy as np
 import pytest

 import unittest
-import warnings


 @pytest.fixture(autouse=True)

@@ -19,33 +17,6 @@ def reset_debug_logs():
 class LoggerSuite(unittest.TestCase):
     """Test client warnings are raised when many tasks are scheduled"""

-    def testManyTasksWarning(self):
-        with ray_start_client_server() as ray:
-
-            @ray.remote
-            def f():
-                return 42
-
-            with self.assertWarns(UserWarning) as cm:
-                for _ in range(TASK_WARNING_THRESHOLD + 1):
-                    f.remote()
-            assert f"More than {TASK_WARNING_THRESHOLD} remote tasks have " \
-                "been scheduled." in cm.warning.args[0]
-
-    def testNoWarning(self):
-        with ray_start_client_server() as ray:
-
-            @ray.remote
-            def f():
-                return 42
-
-            with warnings.catch_warnings(record=True) as warn_list:
-                for _ in range(TASK_WARNING_THRESHOLD):
-                    f.remote()
-            assert not any(f"More than {TASK_WARNING_THRESHOLD} remote tasks "
-                           "have been scheduled." in str(w.args[0])
-                           for w in warn_list)
-
     def testOutboundMessageSizeWarning(self):
         with ray_start_client_server() as ray:
             large_argument = np.random.rand(100, 100, 100)

@@ -50,11 +50,24 @@ def test_dataclient_disconnect_before_request():
         # different remote calls.
         ray.worker.data_client.request_queue.put(Mock())

-        # The next remote call should error since the data channel has shut
-        # down, which should also disconnect the client.
+        # The following two assertions are relatively brittle. Consider a more
+        # robust mechanism if they fail with code changes or become flaky.
+
+        # The next remote call should error since the data channel will shut
+        # down because of the invalid input above. Two cases can happen:
+        # (1) Data channel shuts down after `f.remote()` finishes.
+        #     error is raised to `ray.get()`. The next background operation
+        #     will disconnect Ray client.
+        # (2) Data channel shuts down before `f.remote()` is called.
+        #     `f.remote()` will raise the error and disconnect the client.
         with pytest.raises(ConnectionError):
             ray.get(f.remote())

+        with pytest.raises(
+                ConnectionError,
+                match="Ray client has already been disconnected"):
+            ray.get(f.remote())
+
         # Client should be disconnected
         assert not ray.is_connected()

@@ -46,28 +46,18 @@ def test_check_bundle_index(ray_start_cluster, connect_to_client):
             "CPU": 2
         }])

-        error_count = 0
-        try:
+        with pytest.raises(ValueError, match="bundle index 3 is invalid"):
             Actor.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=3).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 1

-        try:
+        with pytest.raises(ValueError, match="bundle index -2 is invalid"):
             Actor.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=-2).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 2

-        try:
+        with pytest.raises(ValueError, match="bundle index must be -1"):
             Actor.options(placement_group_bundle_index=0).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 3


 @pytest.mark.parametrize("connect_to_client", [False, True])

@@ -1,6 +1,7 @@
 """This file defines the interface between the ray client worker
 and the overall ray module API.
 """
+from concurrent.futures import Future
 import json
 import logging

@@ -85,7 +86,10 @@ class ClientAPI:
         assert len(args) == 0 and len(kwargs) > 0, error_string
         return remote_decorator(options=kwargs)

-    def call_remote(self, instance: "ClientStub", *args, **kwargs):
+    # TODO(mwtian): consider adding _internal_ prefix to call_remote /
+    # call_release / call_retain.
+    def call_remote(self, instance: "ClientStub", *args,
+                    **kwargs) -> List[Future]:
         """call_remote is called by stub objects to execute them remotely.

         This is used by stub objects in situations where they're called

@@ -141,9 +141,9 @@ class ServerUnpickler(pickle.Unpickler):
     def persistent_load(self, pid):
         assert isinstance(pid, PickleStub)
         if pid.type == "Object":
-            return ClientObjectRef(id=pid.ref_id)
+            return ClientObjectRef(pid.ref_id)
         elif pid.type == "Actor":
-            return ClientActorHandle(ClientActorRef(id=pid.ref_id))
+            return ClientActorHandle(ClientActorRef(pid.ref_id))
         else:
             raise NotImplementedError("Being passed back an unknown stub")

@@ -6,6 +6,7 @@ from ray.util.client.options import validate_options
 from ray._private.signature import get_signature, extract_signature
 from ray._private.utils import check_oversized_function

+import concurrent
 from dataclasses import dataclass
 import grpc
 import os

@@ -209,9 +210,9 @@ class ClientActorClass(ClientStub):
     def remote(self, *args, **kwargs) -> "ClientActorHandle":
         self._init_signature.bind(*args, **kwargs)
         # Actually instantiate the actor
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
-        return ClientActorHandle(ClientActorRef(ref_ids[0]), actor_class=self)
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
+        return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)

     def options(self, **kwargs):
         return ActorOptionWrapper(self, kwargs)

@@ -397,13 +398,13 @@ class OptionWrapper:
 class ActorOptionWrapper(OptionWrapper):
     def remote(self, *args, **kwargs):
         self._remote_stub._init_signature.bind(*args, **kwargs)
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
         actor_class = None
         if isinstance(self._remote_stub, ClientActorClass):
            actor_class = self._remote_stub
         return ClientActorHandle(
-            ClientActorRef(ref_ids[0]), actor_class=actor_class)
+            ClientActorRef(futures[0]), actor_class=actor_class)


 def set_task_options(task: ray_client_pb2.ClientTask,

@@ -423,13 +424,13 @@ def set_task_options(task: ray_client_pb2.ClientTask,
     getattr(task, field).json_options = options_str


-def return_refs(ids: List[bytes]
+def return_refs(futures: List[concurrent.futures.Future]
                 ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
-    if len(ids) == 1:
-        return ClientObjectRef(ids[0])
-    if len(ids) == 0:
+    if not futures:
         return None
-    return [ClientObjectRef(id) for id in ids]
+    if len(futures) == 1:
+        return ClientObjectRef(futures[0])
+    return [ClientObjectRef(fut) for fut in futures]


 class InProgressSentinel:

@@ -361,3 +361,21 @@ class DataClient:
                       context=None) -> None:
         datareq = ray_client_pb2.DataRequest(release=request, )
         self._async_send(datareq)
+
+    def Schedule(self, request: ray_client_pb2.ClientTask,
+                 callback: ResponseCallable):
+        datareq = ray_client_pb2.DataRequest(task=request)
+        self._async_send(datareq, callback)
+
+    def Terminate(self, request: ray_client_pb2.TerminateRequest
+                  ) -> ray_client_pb2.TerminateResponse:
+        req = ray_client_pb2.DataRequest(terminate=request, )
+        resp = self._blocking_send(req)
+        return resp.terminate
+
+    def ListNamedActors(self,
+                        request: ray_client_pb2.ClientListNamedActorsRequest
+                        ) -> ray_client_pb2.ClientListNamedActorsResponse:
+        req = ray_client_pb2.DataRequest(list_named_actors=request, )
+        resp = self._blocking_send(req)
+        return resp.list_named_actors

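The new DataClient methods above route everything over the one streaming data channel, but in two styles: Schedule is fire-and-forget with a response callback, while Terminate and ListNamedActors block until the reply arrives. A toy version of that split (hypothetical TinyDataChannel, not Ray's DataClient internals) might look like this:

import queue
import threading
from concurrent.futures import Future
from typing import Any, Callable, Optional


class TinyDataChannel:
    """Illustrative only: queue-backed channel with async and blocking sends."""

    def __init__(self, handler: Callable[[Any], Any]):
        self._handler = handler            # produces a response for a request
        self._requests = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self) -> None:
        while True:
            req, callback = self._requests.get()
            if req is None:                # shutdown sentinel
                return
            resp = self._handler(req)
            if callback is not None:
                callback(resp)             # async completion, like Schedule()

    def async_send(self, req: Any,
                   callback: Optional[Callable[[Any], None]] = None) -> None:
        self._requests.put((req, callback))

    def blocking_send(self, req: Any) -> Any:
        fut: Future = Future()
        self.async_send(req, fut.set_result)
        return fut.result()                # like Terminate()/ListNamedActors()

    def close(self) -> None:
        self._requests.put((None, None))


channel = TinyDataChannel(handler=lambda req: {"echo": req})
channel.async_send("schedule-task", callback=print)  # fire-and-forget + callback
print(channel.blocking_send("terminate-actor"))      # waits for the reply
channel.close()
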
@@ -2,6 +2,9 @@ from typing import Any
 from typing import Dict
 from typing import Optional

+from ray.util.placement_group import (PlacementGroup,
+                                      check_placement_group_index)
+
 options = {
     "num_returns": (int, lambda x: x >= 0,
                     "The keyword 'num_returns' only accepts 0 "

@@ -43,6 +46,7 @@ def validate_options(
         return None
     if len(kwargs_dict) == 0:
         return None
+
     out = {}
     for k, v in kwargs_dict.items():
         if k not in options.keys():

@@ -55,4 +59,21 @@ def validate_options(
         if not validator[1](v):
             raise ValueError(validator[2])
         out[k] = v
+
+    # Validate placement setting similar to the logic in ray/actor.py and
+    # ray/remote_function.py. The difference is that when
+    # placement_group = default and placement_group_capture_child_tasks
+    # specified, placement group cannot be resolved at client. So this check
+    # skips this case and relies on server to enforce any condition.
+    bundle_index = out.get("placement_group_bundle_index", None)
+    if bundle_index is not None:
+        pg = out.get("placement_group", None)
+        if pg is None:
+            pg = PlacementGroup.empty()
+        if pg == "default" and (out.get("placement_group_capture_child_tasks",
+                                        None) is None):
+            pg = PlacementGroup.empty()
+        if isinstance(pg, PlacementGroup):
+            check_placement_group_index(pg, bundle_index)
+
     return out

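The validation block above applies the same bundle-index rules the placement group test exercises: -1 means no specific bundle, and any other index must address an existing bundle of the group. A standalone restatement of that rule (hypothetical check_bundle_index helper, not the actual ray.util.placement_group.check_placement_group_index):

def check_bundle_index(bundle_count: int, bundle_index: int) -> None:
    # Hypothetical re-statement of the rule the tests above rely on:
    # -1 means "no specific bundle"; otherwise the index must address an
    # existing bundle of the placement group.
    if bundle_count > 0:
        if not (-1 <= bundle_index < bundle_count):
            raise ValueError(
                f"placement group bundle index {bundle_index} is invalid; "
                f"valid values are -1 or 0..{bundle_count - 1}")
    elif bundle_index != -1:
        raise ValueError(
            "placement group bundle index must be -1 when no placement "
            "group is specified")


check_bundle_index(2, 0)         # OK
check_bundle_index(2, -1)        # OK: no specific bundle requested
try:
    check_bundle_index(2, 3)     # matches the "bundle index 3 is invalid" test
except ValueError as e:
    print(e)
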
@@ -189,6 +189,23 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
                     # Clean up acknowledged cache entries
                     response_cache.cleanup(req.acknowledge.req_id)
                     continue
+                elif req_type == "task":
+                    with self.clients_lock:
+                        resp_ticket = self.basic_service.Schedule(
+                            req.task, context)
+                    resp = ray_client_pb2.DataResponse(
+                        task_ticket=resp_ticket)
+                elif req_type == "terminate":
+                    with self.clients_lock:
+                        response = self.basic_service.Terminate(
+                            req.terminate, context)
+                    resp = ray_client_pb2.DataResponse(terminate=response)
+                elif req_type == "list_named_actors":
+                    with self.clients_lock:
+                        response = self.basic_service.ListNamedActors(
+                            req.list_named_actors)
+                    resp = ray_client_pb2.DataResponse(
+                        list_named_actors=response)
                 else:
                     raise Exception(f"Unreachable code: Request type "
                                     f"{req_type} not handled in Datapath")

@@ -11,8 +11,10 @@ import time
 import uuid
 import warnings
 from collections import defaultdict
+from concurrent.futures import Future
 import tempfile
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING,
+                    Union)

 import grpc

@@ -53,10 +55,6 @@ MAX_BLOCKING_OPERATION_TIME_S: float = 2.0
 # the connection began exceeds this value, a warning should be raised
 MESSAGE_SIZE_THRESHOLD = 10 * 2**20  # 10 MB

-# If the number of tasks scheduled on the client side since the connection
-# began exceeds this value, a warning should be raised
-TASK_WARNING_THRESHOLD = 1000
-
 # Links to the Ray Design Pattern doc to use in the task overhead warning
 # message
 DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = \

@@ -138,9 +136,7 @@ class Worker:

         self.closed = False

-        # Track these values to raise a warning if many tasks are being
-        # scheduled
-        self.total_num_tasks_scheduled = 0
+        # Track this value to raise a warning if a lot of data are transferred.
         self.total_outbound_message_size_bytes = 0

         # Used to create unique IDs for RPCs to the RayletServicer

@@ -365,17 +361,17 @@ class Worker:
         req = ray_client_pb2.GetRequest(
             ids=[r.id for r in ref], timeout=timeout)
         try:
-            data = self.data_client.GetObject(req)
+            resp = self._call_stub("GetObject", req, metadata=self.metadata)
         except grpc.RpcError as e:
             raise decode_exception(e)
-        if not data.valid:
+        if not resp.valid:
             try:
-                err = cloudpickle.loads(data.error)
+                err = cloudpickle.loads(resp.error)
             except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(data.error))
+                logger.exception("Failed to deserialize {}".format(resp.error))
                 raise
             raise err
-        return loads_from_server(data.data)
+        return loads_from_server(resp.data)

     def put(self, vals, *, client_ref_id: bytes = None):
         to_put = []

@@ -450,7 +446,7 @@ class Worker:

         return (client_ready_object_ids, client_remaining_object_ids)

-    def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
+    def call_remote(self, instance, *args, **kwargs) -> List[Future]:
         task = instance._prepare_client_task()
         for arg in args:
             pb_arg = convert_to_arg(arg, self._client_id)

@@ -460,38 +456,47 @@ class Worker:
         return self._call_schedule_for_task(task, instance._num_returns())

     def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
-                                num_returns: int) -> List[bytes]:
+                                num_returns: int) -> List[Future]:
         logger.debug("Scheduling %s" % task)
         task.client_id = self._client_id
-        metadata = self._add_ids_to_metadata(self.metadata)
+        if num_returns is None:
+            num_returns = 1
+
-        try:
-            ticket = self._call_stub("Schedule", task, metadata=metadata)
-        except grpc.RpcError as e:
-            raise decode_exception(e)
+        id_futures = [Future() for _ in range(num_returns)]
+
+        def populate_ids(
+                resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
+            if isinstance(resp, Exception):
+                if isinstance(resp, grpc.RpcError):
+                    resp = decode_exception(resp)
+                for future in id_futures:
+                    future.set_exception(resp)
+                return
+
+            ticket = resp.task_ticket
+            if not ticket.valid:
+                try:
+                    ex = cloudpickle.loads(ticket.error)
+                except (pickle.UnpicklingError, TypeError) as e_new:
+                    ex = e_new
+                for future in id_futures:
+                    future.set_exception(ex)
+                return
+
+            if len(ticket.return_ids) != num_returns:
+                exc = ValueError(
+                    f"Expected {num_returns} returns but received "
+                    f"{len(ticket.return_ids)}")
+                for future, raw_id in zip(id_futures, ticket.return_ids):
+                    future.set_exception(exc)
+                return
+
+            for future, raw_id in zip(id_futures, ticket.return_ids):
+                future.set_result(raw_id)
+
+        self.data_client.Schedule(task, populate_ids)

-        if not ticket.valid:
-            try:
-                raise cloudpickle.loads(ticket.error)
-            except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(
-                    ticket.error))
-                raise
-        self.total_num_tasks_scheduled += 1
         self.total_outbound_message_size_bytes += task.ByteSize()
-        if self.total_num_tasks_scheduled > TASK_WARNING_THRESHOLD and \
-                log_once("client_communication_overhead_warning"):
-            warnings.warn(
-                f"More than {TASK_WARNING_THRESHOLD} remote tasks have been "
-                "scheduled. This can be slow on Ray Client due to "
-                "communication overhead over the network. If you're running "
-                "many fine-grained tasks, consider running them in a single "
-                "remote function. See the section on \"Too fine-grained "
-                "tasks\" in the Ray Design Patterns document for more "
-                f"details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}",
-                UserWarning)
         if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD \
                 and log_once("client_communication_overhead_warning"):
             warnings.warn(

@@ -508,10 +513,7 @@ class Worker:
                 "unserializable object\" section of the Ray Design Patterns "
                 "document, available here: "
                 f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning)
-        if num_returns != len(ticket.return_ids):
-            raise TypeError("Unexpected number of returned values. Expected "
-                            f"{num_returns} actual {ticket.return_ids}")
-        return ticket.return_ids
+        return id_futures

     def call_release(self, id: bytes) -> None:
         if self.closed:

@@ -547,9 +549,15 @@ class Worker:
         task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
         task.name = name
         task.namespace = namespace or ""
-        ids = self._call_schedule_for_task(task, num_returns=1)
-        assert len(ids) == 1
-        return ClientActorHandle(ClientActorRef(ids[0]))
+        futures = self._call_schedule_for_task(task, 1)
+        assert len(futures) == 1
+        handle = ClientActorHandle(ClientActorRef(futures[0]))
+        # `actor_ref.is_nil()` waits until the underlying ID is resolved.
+        # This is needed because `get_actor` is often used to check the
+        # existence of an actor.
+        if handle.actor_ref.is_nil():
+            raise ValueError(f"ActorID for {name} is empty")
+        return handle

     def terminate_actor(self, actor: ClientActorHandle,
                         no_restart: bool) -> None:

@@ -559,11 +567,10 @@ class Worker:
         term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
         term_actor.id = actor.actor_ref.id
         term_actor.no_restart = no_restart
+        term = ray_client_pb2.TerminateRequest(actor=term_actor)
+        term.client_id = self._client_id
         try:
-            term = ray_client_pb2.TerminateRequest(actor=term_actor)
-            term.client_id = self._client_id
-            metadata = self._add_ids_to_metadata(self.metadata)
-            self._call_stub("Terminate", term, metadata=metadata)
+            self.data_client.Terminate(term)
         except grpc.RpcError as e:
             raise decode_exception(e)

@@ -577,19 +584,18 @@ class Worker:
         term_object.id = obj.id
         term_object.force = force
         term_object.recursive = recursive
+        term = ray_client_pb2.TerminateRequest(task_object=term_object)
+        term.client_id = self._client_id
         try:
-            term = ray_client_pb2.TerminateRequest(task_object=term_object)
-            term.client_id = self._client_id
-            metadata = self._add_ids_to_metadata(self.metadata)
-            self._call_stub("Terminate", term, metadata=metadata)
+            self.data_client.Terminate(term)
         except grpc.RpcError as e:
             raise decode_exception(e)

     def get_cluster_info(self,
-                         type: ray_client_pb2.ClusterInfoType.TypeEnum,
+                         req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
                          timeout: Optional[float] = None):
         req = ray_client_pb2.ClusterInfoRequest()
-        req.type = type
+        req.type = req_type
         resp = self.server.ClusterInfo(
             req, timeout=timeout, metadata=self.metadata)
         if resp.WhichOneof("response_type") == "resource_table":

@@ -630,9 +636,7 @@ class Worker:
     def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
         req = ray_client_pb2.ClientListNamedActorsRequest(
             all_namespaces=all_namespaces)
-        return json.loads(
-            self._call_stub("ListNamedActors", req,
-                            metadata=self.metadata).actors_json)
+        return json.loads(self.data_client.ListNamedActors(req).actors_json)

     def is_initialized(self) -> bool:
         if self.server is not None:

@@ -367,6 +367,9 @@ message DataRequest {
     PrepRuntimeEnvRequest prep_runtime_env = 7;
     ConnectionCleanupRequest connection_cleanup = 8;
     AcknowledgeRequest acknowledge = 9;
+    ClientTask task = 10;
+    TerminateRequest terminate = 11;
+    ClientListNamedActorsRequest list_named_actors = 12;
   }
 }

@@ -381,7 +384,13 @@ message DataResponse {
     InitResponse init = 6;
     PrepRuntimeEnvResponse prep_runtime_env = 7;
     ConnectionCleanupResponse connection_cleanup = 8;
+    ClientTaskTicket task_ticket = 10;
+    TerminateResponse terminate = 11;
+    ClientListNamedActorsResponse list_named_actors = 12;
   }
+  // tag 9 is skipped, since there is no AcknowledgeResponse
+  reserved 9;
+  reserved "acknowledge";
 }

 service RayletDataStreamer {
