[Client] Use async rpc for remote call and actor creation (#18298)
* Use async rpc for remote calls, task and actor creations.
* fix
* check placement
* check placement group. wait for id in destructor
* fix
* fix exception in destructor
* Add test
* revert change
* Fix comment
* fix
parent
8dd3057644
commit
e41109a5e7
17 changed files with 238 additions and 135 deletions
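At a high level, the change routes task scheduling and actor termination over the existing client data stream and hands object IDs back through concurrent.futures.Future objects that ClientObjectRef / ClientActorRef resolve lazily in _wait_for_id. The following is a minimal, self-contained sketch of that pattern only; LazyRef, schedule, and send_async are illustrative names for this note and are not Ray APIs.

from concurrent.futures import Future
from threading import Lock


class LazyRef:
    """Sketch only: a client-side ref whose binary ID may arrive later."""

    def __init__(self, id_or_future):
        self._mutex = Lock()
        self._id = None
        self._id_future = None
        if isinstance(id_or_future, Future):
            self._id_future = id_or_future  # filled in by an RPC callback
        else:
            self._id = id_or_future  # ID already known

    def _wait_for_id(self, timeout=None):
        # Double-checked locking, mirroring ClientObjectRef._wait_for_id below.
        if self._id_future:
            with self._mutex:
                if self._id_future:
                    self._id = self._id_future.result(timeout=timeout)
                    self._id_future = None

    @property
    def id(self):
        self._wait_for_id()
        return self._id


def schedule(send_async, task, num_returns=1):
    """Sketch of the new scheduling path: send the task asynchronously and
    hand back futures instead of blocking on the reply."""
    futures = [Future() for _ in range(num_returns)]

    def on_reply(resp):
        # resp is either a response message or an Exception.
        if isinstance(resp, Exception):
            for f in futures:
                f.set_exception(resp)
            return
        for f, raw_id in zip(futures, resp.task_ticket.return_ids):
            f.set_result(raw_id)

    send_async(task, on_reply)  # returns immediately
    return [LazyRef(f) for f in futures]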
@@ -91,7 +91,7 @@ cdef class ClientObjectRef(ObjectRef):
     cdef object _id_future

     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)

 cdef class ActorID(BaseID):
     cdef CActorID data
@@ -105,7 +105,7 @@ cdef class ClientActorRef(ActorID):
     cdef object _id_future

     cdef _set_id(self, id)
-    cdef inline _wait_for_id(self)
+    cdef inline _wait_for_id(self, timeout=None)

 cdef class CoreWorker:
     cdef:
@@ -166,8 +166,20 @@ cdef class ClientObjectRef(ObjectRef):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception in ObjectRef is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)

     cdef CObjectID native(self):
         self._wait_for_id()
@@ -244,9 +256,9 @@ cdef class ClientObjectRef(ObjectRef):
         self.data = CObjectID.FromBinary(<c_string>id)
         client.ray.call_retain(id)

-    cdef inline _wait_for_id(self):
+    cdef inline _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None
@@ -8,6 +8,7 @@ See https://github.com/ray-project/ray/issues/3721.
 # _ID_TYPES list at the bottom of this file.

 from concurrent.futures import Future
+import logging
 import os

 from ray.includes.unique_ids cimport (
@@ -27,6 +28,8 @@ from ray.includes.unique_ids cimport (
 import ray
 from ray._private.utils import decode

+logger = logging.getLogger(__name__)
+

 def check_id(b, size=kUniqueIDSize):
     if not isinstance(b, bytes):
@@ -326,8 +329,20 @@ cdef class ClientActorRef(ActorID):
             # call_release in this case, since the client should have already
             # disconnected at this point.
             return
-        if client.ray.is_connected() and not self.data.IsNil():
-            client.ray.call_release(self.id)
+        if client.ray.is_connected():
+            try:
+                self._wait_for_id()
+            # cython would suppress this exception as well, but it tries to
+            # print out the exception which may crash. Log a simpler message
+            # instead.
+            except Exception:
+                logger.info(
+                    "Exception from actor creation is ignored in destructor. "
+                    "To receive this exception in application code, call "
+                    "a method on the actor reference before its destructor "
+                    "is run.")
+            if not self.data.IsNil():
+                client.ray.call_release(self.id)

     def binary(self):
         self._wait_for_id()
@@ -358,11 +373,11 @@ cdef class ClientActorRef(ActorID):
         self.data = CActorID.FromBinary(<c_string>id)
         client.ray.call_retain(id)

-    cdef _wait_for_id(self):
+    cdef _wait_for_id(self, timeout=None):
         if self._id_future:
             with self._mutex:
                 if self._id_future:
-                    self._set_id(self._id_future.result())
+                    self._set_id(self._id_future.result(timeout=timeout))
                     self._id_future = None

@@ -303,6 +303,8 @@ async def test_async_obj_unhandled_errors(ray_start_regular_shared):
     # Test we report unhandled exceptions.
     ray.worker._unhandled_error_handler = interceptor
     x1 = f.remote()
+    # NOTE: Unhandled exception is from waiting for the value of x1's ObjectID
+    # in x1's destructor, and receiving an exception from f() instead.
     del x1
     wait_for_condition(lambda: num_exceptions == 1)

@@ -135,14 +135,23 @@ def test_get_list(ray_start_regular_shared):
         assert ray.get([]) == []
         assert ray.get([f.remote()]) == ["OK"]

+        get_count = 0
+        get_stub = ray.worker.server.GetObject
+
+        # ray.get() uses unary-unary RPC. Mock the server handler to count
+        # the number of requests received.
+        def get(req, metadata=None):
+            nonlocal get_count
+            get_count += 1
+            return get_stub(req, metadata=metadata)
+
+        ray.worker.server.GetObject = get
+
         refs = [f.remote() for _ in range(100)]
-        with ray.worker.data_client.lock:
-            req_id_before = ray.worker.data_client._req_id
         assert ray.get(refs) == ["OK" for _ in range(100)]

         # Only 1 RPC should be sent.
-        with ray.worker.data_client.lock:
-            assert ray.worker.data_client._req_id == req_id_before + 1, \
-                ray.worker.data_client._req_id
+        assert get_count == 1


 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
@@ -470,6 +479,22 @@ def test_serializing_exceptions(ray_start_regular_shared):
             ray.get_actor("abc")


+@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
+def test_invalid_task(ray_start_regular_shared):
+    with ray_start_client_server() as ray:
+
+        @ray.remote(runtime_env="invalid value")
+        def f():
+            return 1
+
+        # No exception on making the remote call.
+        ref = f.remote()
+
+        # Exception during scheduling will be raised on ray.get()
+        with pytest.raises(Exception):
+            ray.get(ref)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
 def test_create_remote_before_start(ray_start_regular_shared):
     """Creates remote objects (as though in a library) before
@@ -113,11 +113,12 @@ def test_kill_cancel_metadata(ray_start_regular):
         class MetadataIsCorrectlyPassedException(Exception):
             pass

-        def mock_terminate(term, metadata):
-            raise MetadataIsCorrectlyPassedException(metadata[1][0])
+        def mock_terminate(self, term):
+            raise MetadataIsCorrectlyPassedException(self._metadata[1][0])

         # Mock stub's Terminate method to raise an exception.
-        ray.get_context().api.worker.server.Terminate = mock_terminate
+        stub = ray.get_context().api.worker.data_client
+        stub.Terminate = mock_terminate.__get__(stub)

         # Verify the expected exception is raised with ray.kill.
         # Check that argument of the exception matches "key" from the
@@ -1,12 +1,10 @@
 from ray.util.client.ray_client_helpers import ray_start_client_server
-from ray.util.client.worker import TASK_WARNING_THRESHOLD
 from ray.util.debug import _logged

 import numpy as np
 import pytest

 import unittest
-import warnings


 @pytest.fixture(autouse=True)
@@ -19,33 +17,6 @@ def reset_debug_logs():
 class LoggerSuite(unittest.TestCase):
     """Test client warnings are raised when many tasks are scheduled"""

-    def testManyTasksWarning(self):
-        with ray_start_client_server() as ray:
-
-            @ray.remote
-            def f():
-                return 42
-
-            with self.assertWarns(UserWarning) as cm:
-                for _ in range(TASK_WARNING_THRESHOLD + 1):
-                    f.remote()
-            assert f"More than {TASK_WARNING_THRESHOLD} remote tasks have " \
-                "been scheduled." in cm.warning.args[0]
-
-    def testNoWarning(self):
-        with ray_start_client_server() as ray:
-
-            @ray.remote
-            def f():
-                return 42
-
-            with warnings.catch_warnings(record=True) as warn_list:
-                for _ in range(TASK_WARNING_THRESHOLD):
-                    f.remote()
-            assert not any(f"More than {TASK_WARNING_THRESHOLD} remote tasks "
-                           "have been scheduled." in str(w.args[0])
-                           for w in warn_list)
-
     def testOutboundMessageSizeWarning(self):
         with ray_start_client_server() as ray:
             large_argument = np.random.rand(100, 100, 100)
@@ -50,11 +50,24 @@ def test_dataclient_disconnect_before_request():
         # different remote calls.
         ray.worker.data_client.request_queue.put(Mock())

-        # The next remote call should error since the data channel has shut
-        # down, which should also disconnect the client.
+        # The following two assertions are relatively brittle. Consider a more
+        # robust mechanism if they fail with code changes or become flaky.
+
+        # The next remote call should error since the data channel will shut
+        # down because of the invalid input above. Two cases can happen:
+        # (1) Data channel shuts down after `f.remote()` finishes.
+        #     error is raised to `ray.get()`. The next background operation
+        #     will disconnect Ray client.
+        # (2) Data channel shuts down before `f.remote()` is called.
+        #     `f.remote()` will raise the error and disconnect the client.
         with pytest.raises(ConnectionError):
             ray.get(f.remote())

+        with pytest.raises(
+                ConnectionError,
+                match="Ray client has already been disconnected"):
+            ray.get(f.remote())
+
         # Client should be disconnected
         assert not ray.is_connected()

@@ -46,28 +46,18 @@ def test_check_bundle_index(ray_start_cluster, connect_to_client):
                 "CPU": 2
             }])

-        error_count = 0
-        try:
+        with pytest.raises(ValueError, match="bundle index 3 is invalid"):
             Actor.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=3).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 1

-        try:
+        with pytest.raises(ValueError, match="bundle index -2 is invalid"):
             Actor.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=-2).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 2

-        try:
+        with pytest.raises(ValueError, match="bundle index must be -1"):
             Actor.options(placement_group_bundle_index=0).remote()
-        except ValueError:
-            error_count = error_count + 1
-        assert error_count == 3


 @pytest.mark.parametrize("connect_to_client", [False, True])
@@ -1,6 +1,7 @@
 """This file defines the interface between the ray client worker
 and the overall ray module API.
 """
+from concurrent.futures import Future
 import json
 import logging

@@ -85,7 +86,10 @@ class ClientAPI:
         assert len(args) == 0 and len(kwargs) > 0, error_string
         return remote_decorator(options=kwargs)

-    def call_remote(self, instance: "ClientStub", *args, **kwargs):
+    # TODO(mwtian): consider adding _internal_ prefix to call_remote /
+    # call_release / call_retain.
+    def call_remote(self, instance: "ClientStub", *args,
+                    **kwargs) -> List[Future]:
         """call_remote is called by stub objects to execute them remotely.

         This is used by stub objects in situations where they're called
@@ -141,9 +141,9 @@ class ServerUnpickler(pickle.Unpickler):
     def persistent_load(self, pid):
         assert isinstance(pid, PickleStub)
         if pid.type == "Object":
-            return ClientObjectRef(id=pid.ref_id)
+            return ClientObjectRef(pid.ref_id)
         elif pid.type == "Actor":
-            return ClientActorHandle(ClientActorRef(id=pid.ref_id))
+            return ClientActorHandle(ClientActorRef(pid.ref_id))
         else:
             raise NotImplementedError("Being passed back an unknown stub")

@@ -6,6 +6,7 @@ from ray.util.client.options import validate_options
 from ray._private.signature import get_signature, extract_signature
 from ray._private.utils import check_oversized_function

+import concurrent
 from dataclasses import dataclass
 import grpc
 import os
@@ -209,9 +210,9 @@ class ClientActorClass(ClientStub):
     def remote(self, *args, **kwargs) -> "ClientActorHandle":
         self._init_signature.bind(*args, **kwargs)
         # Actually instantiate the actor
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
-        return ClientActorHandle(ClientActorRef(ref_ids[0]), actor_class=self)
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
+        return ClientActorHandle(ClientActorRef(futures[0]), actor_class=self)

     def options(self, **kwargs):
         return ActorOptionWrapper(self, kwargs)
@@ -397,13 +398,13 @@ class OptionWrapper:
 class ActorOptionWrapper(OptionWrapper):
     def remote(self, *args, **kwargs):
         self._remote_stub._init_signature.bind(*args, **kwargs)
-        ref_ids = ray.call_remote(self, *args, **kwargs)
-        assert len(ref_ids) == 1
+        futures = ray.call_remote(self, *args, **kwargs)
+        assert len(futures) == 1
         actor_class = None
         if isinstance(self._remote_stub, ClientActorClass):
             actor_class = self._remote_stub
         return ClientActorHandle(
-            ClientActorRef(ref_ids[0]), actor_class=actor_class)
+            ClientActorRef(futures[0]), actor_class=actor_class)


 def set_task_options(task: ray_client_pb2.ClientTask,
@@ -423,13 +424,13 @@ def set_task_options(task: ray_client_pb2.ClientTask,
     getattr(task, field).json_options = options_str


-def return_refs(ids: List[bytes]
+def return_refs(futures: List[concurrent.futures.Future]
                 ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
-    if len(ids) == 1:
-        return ClientObjectRef(ids[0])
-    if len(ids) == 0:
+    if not futures:
         return None
-    return [ClientObjectRef(id) for id in ids]
+    if len(futures) == 1:
+        return ClientObjectRef(futures[0])
+    return [ClientObjectRef(fut) for fut in futures]


 class InProgressSentinel:
@@ -361,3 +361,21 @@ class DataClient:
                 context=None) -> None:
         datareq = ray_client_pb2.DataRequest(release=request, )
         self._async_send(datareq)
+
+    def Schedule(self, request: ray_client_pb2.ClientTask,
+                 callback: ResponseCallable):
+        datareq = ray_client_pb2.DataRequest(task=request)
+        self._async_send(datareq, callback)
+
+    def Terminate(self, request: ray_client_pb2.TerminateRequest
+                  ) -> ray_client_pb2.TerminateResponse:
+        req = ray_client_pb2.DataRequest(terminate=request, )
+        resp = self._blocking_send(req)
+        return resp.terminate
+
+    def ListNamedActors(self,
+                        request: ray_client_pb2.ClientListNamedActorsRequest
+                        ) -> ray_client_pb2.ClientListNamedActorsResponse:
+        req = ray_client_pb2.DataRequest(list_named_actors=request, )
+        resp = self._blocking_send(req)
+        return resp.list_named_actors
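The DataClient methods added above split into an asynchronous send (Schedule, driven by a callback) and blocking sends (Terminate, ListNamedActors). A hedged usage sketch follows; `data_client` is assumed to be a connected DataClient and `task` a prepared ray_client_pb2.ClientTask, neither of which is shown in this diff.

from concurrent.futures import Future


def schedule_and_wait(data_client, task, timeout=30):
    """Sketch only: submit a ClientTask over the data stream and block for
    its ticket."""
    result_future = Future()

    def on_response(resp):
        # Per the callback contract used in worker.py below, resp is either a
        # DataResponse or an Exception raised by the data channel.
        if isinstance(resp, Exception):
            result_future.set_exception(resp)
        else:
            result_future.set_result(resp.task_ticket)

    data_client.Schedule(task, on_response)  # asynchronous: queued on the stream
    return result_future.result(timeout=timeout)  # caller decides when to block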
@@ -2,6 +2,9 @@ from typing import Any
 from typing import Dict
 from typing import Optional

+from ray.util.placement_group import (PlacementGroup,
+                                      check_placement_group_index)
+
 options = {
     "num_returns": (int, lambda x: x >= 0,
                     "The keyword 'num_returns' only accepts 0 "
@@ -43,6 +46,7 @@ def validate_options(
         return None
     if len(kwargs_dict) == 0:
         return None
+
     out = {}
     for k, v in kwargs_dict.items():
         if k not in options.keys():
|
||||||
if not validator[1](v):
|
if not validator[1](v):
|
||||||
raise ValueError(validator[2])
|
raise ValueError(validator[2])
|
||||||
out[k] = v
|
out[k] = v
|
||||||
|
|
||||||
|
# Validate placement setting similar to the logic in ray/actor.py and
|
||||||
|
# ray/remote_function.py. The difference is that when
|
||||||
|
# placement_group = default and placement_group_capture_child_tasks
|
||||||
|
# specified, placement group cannot be resolved at client. So this check
|
||||||
|
# skips this case and relies on server to enforce any condition.
|
||||||
|
bundle_index = out.get("placement_group_bundle_index", None)
|
||||||
|
if bundle_index is not None:
|
||||||
|
pg = out.get("placement_group", None)
|
||||||
|
if pg is None:
|
||||||
|
pg = PlacementGroup.empty()
|
||||||
|
if pg == "default" and (out.get("placement_group_capture_child_tasks",
|
||||||
|
None) is None):
|
||||||
|
pg = PlacementGroup.empty()
|
||||||
|
if isinstance(pg, PlacementGroup):
|
||||||
|
check_placement_group_index(pg, bundle_index)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
|
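For context, a small sketch of how the validation added above is expected to behave, based on the expectations in the placement-group test earlier in this diff (an illustration under those assumptions, not part of the change):

from ray.util.client.options import validate_options

# With no placement group in the options, a non-negative bundle index should
# be rejected on the client; the test above expects the message
# "bundle index must be -1".
try:
    validate_options({"placement_group_bundle_index": 0})
except ValueError as exc:
    print(exc)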
@@ -189,6 +189,23 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
                     # Clean up acknowledged cache entries
                     response_cache.cleanup(req.acknowledge.req_id)
                     continue
+                elif req_type == "task":
+                    with self.clients_lock:
+                        resp_ticket = self.basic_service.Schedule(
+                            req.task, context)
+                    resp = ray_client_pb2.DataResponse(
+                        task_ticket=resp_ticket)
+                elif req_type == "terminate":
+                    with self.clients_lock:
+                        response = self.basic_service.Terminate(
+                            req.terminate, context)
+                    resp = ray_client_pb2.DataResponse(terminate=response)
+                elif req_type == "list_named_actors":
+                    with self.clients_lock:
+                        response = self.basic_service.ListNamedActors(
+                            req.list_named_actors)
+                    resp = ray_client_pb2.DataResponse(
+                        list_named_actors=response)
                 else:
                     raise Exception(f"Unreachable code: Request type "
                                     f"{req_type} not handled in Datapath")
@@ -11,8 +11,10 @@ import time
 import uuid
 import warnings
 from collections import defaultdict
+from concurrent.futures import Future
 import tempfile
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING,
+                    Union)

 import grpc

@@ -53,10 +55,6 @@ MAX_BLOCKING_OPERATION_TIME_S: float = 2.0
 # the connection began exceeds this value, a warning should be raised
 MESSAGE_SIZE_THRESHOLD = 10 * 2**20  # 10 MB

-# If the number of tasks scheduled on the client side since the connection
-# began exceeds this value, a warning should be raised
-TASK_WARNING_THRESHOLD = 1000
-
 # Links to the Ray Design Pattern doc to use in the task overhead warning
 # message
 DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK = \
@@ -138,9 +136,7 @@ class Worker:

         self.closed = False

-        # Track these values to raise a warning if many tasks are being
-        # scheduled
-        self.total_num_tasks_scheduled = 0
+        # Track this value to raise a warning if a lot of data are transferred.
         self.total_outbound_message_size_bytes = 0

         # Used to create unique IDs for RPCs to the RayletServicer
@@ -365,17 +361,17 @@ class Worker:
         req = ray_client_pb2.GetRequest(
             ids=[r.id for r in ref], timeout=timeout)
         try:
-            data = self.data_client.GetObject(req)
+            resp = self._call_stub("GetObject", req, metadata=self.metadata)
         except grpc.RpcError as e:
             raise decode_exception(e)
-        if not data.valid:
+        if not resp.valid:
             try:
-                err = cloudpickle.loads(data.error)
+                err = cloudpickle.loads(resp.error)
             except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(data.error))
+                logger.exception("Failed to deserialize {}".format(resp.error))
                 raise
             raise err
-        return loads_from_server(data.data)
+        return loads_from_server(resp.data)

     def put(self, vals, *, client_ref_id: bytes = None):
         to_put = []
@@ -450,7 +446,7 @@ class Worker:

         return (client_ready_object_ids, client_remaining_object_ids)

-    def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
+    def call_remote(self, instance, *args, **kwargs) -> List[Future]:
         task = instance._prepare_client_task()
         for arg in args:
             pb_arg = convert_to_arg(arg, self._client_id)
@@ -460,38 +456,47 @@ class Worker:
         return self._call_schedule_for_task(task, instance._num_returns())

     def _call_schedule_for_task(self, task: ray_client_pb2.ClientTask,
-                                num_returns: int) -> List[bytes]:
+                                num_returns: int) -> List[Future]:
         logger.debug("Scheduling %s" % task)
         task.client_id = self._client_id
-        metadata = self._add_ids_to_metadata(self.metadata)
         if num_returns is None:
             num_returns = 1

-        try:
-            ticket = self._call_stub("Schedule", task, metadata=metadata)
-        except grpc.RpcError as e:
-            raise decode_exception(e)
+        id_futures = [Future() for _ in range(num_returns)]
+
+        def populate_ids(
+                resp: Union[ray_client_pb2.DataResponse, Exception]) -> None:
+            if isinstance(resp, Exception):
+                if isinstance(resp, grpc.RpcError):
+                    resp = decode_exception(resp)
+                for future in id_futures:
+                    future.set_exception(resp)
+                return
+
+            ticket = resp.task_ticket
+            if not ticket.valid:
+                try:
+                    ex = cloudpickle.loads(ticket.error)
+                except (pickle.UnpicklingError, TypeError) as e_new:
+                    ex = e_new
+                for future in id_futures:
+                    future.set_exception(ex)
+                return
+
+            if len(ticket.return_ids) != num_returns:
+                exc = ValueError(
+                    f"Expected {num_returns} returns but received "
+                    f"{len(ticket.return_ids)}")
+                for future, raw_id in zip(id_futures, ticket.return_ids):
+                    future.set_exception(exc)
+                return
+
+            for future, raw_id in zip(id_futures, ticket.return_ids):
+                future.set_result(raw_id)
+
+        self.data_client.Schedule(task, populate_ids)

-        if not ticket.valid:
-            try:
-                raise cloudpickle.loads(ticket.error)
-            except (pickle.UnpicklingError, TypeError):
-                logger.exception("Failed to deserialize {}".format(
-                    ticket.error))
-                raise
-        self.total_num_tasks_scheduled += 1
         self.total_outbound_message_size_bytes += task.ByteSize()
-        if self.total_num_tasks_scheduled > TASK_WARNING_THRESHOLD and \
-                log_once("client_communication_overhead_warning"):
-            warnings.warn(
-                f"More than {TASK_WARNING_THRESHOLD} remote tasks have been "
-                "scheduled. This can be slow on Ray Client due to "
-                "communication overhead over the network. If you're running "
-                "many fine-grained tasks, consider running them in a single "
-                "remote function. See the section on \"Too fine-grained "
-                "tasks\" in the Ray Design Patterns document for more "
-                f"details: {DESIGN_PATTERN_FINE_GRAIN_TASKS_LINK}",
-                UserWarning)
         if self.total_outbound_message_size_bytes > MESSAGE_SIZE_THRESHOLD \
                 and log_once("client_communication_overhead_warning"):
             warnings.warn(
@@ -508,10 +513,7 @@ class Worker:
                 "unserializable object\" section of the Ray Design Patterns "
                 "document, available here: "
                 f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning)
-        if num_returns != len(ticket.return_ids):
-            raise TypeError("Unexpected number of returned values. Expected "
-                            f"{num_returns} actual {ticket.return_ids}")
-        return ticket.return_ids
+        return id_futures

     def call_release(self, id: bytes) -> None:
         if self.closed:
|
||||||
task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
|
task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
|
||||||
task.name = name
|
task.name = name
|
||||||
task.namespace = namespace or ""
|
task.namespace = namespace or ""
|
||||||
ids = self._call_schedule_for_task(task, num_returns=1)
|
futures = self._call_schedule_for_task(task, 1)
|
||||||
assert len(ids) == 1
|
assert len(futures) == 1
|
||||||
return ClientActorHandle(ClientActorRef(ids[0]))
|
handle = ClientActorHandle(ClientActorRef(futures[0]))
|
||||||
|
# `actor_ref.is_nil()` waits until the underlying ID is resolved.
|
||||||
|
# This is needed because `get_actor` is often used to check the
|
||||||
|
# existence of an actor.
|
||||||
|
if handle.actor_ref.is_nil():
|
||||||
|
raise ValueError(f"ActorID for {name} is empty")
|
||||||
|
return handle
|
||||||
|
|
||||||
def terminate_actor(self, actor: ClientActorHandle,
|
def terminate_actor(self, actor: ClientActorHandle,
|
||||||
no_restart: bool) -> None:
|
no_restart: bool) -> None:
|
||||||
|
@ -559,11 +567,10 @@ class Worker:
|
||||||
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
|
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
|
||||||
term_actor.id = actor.actor_ref.id
|
term_actor.id = actor.actor_ref.id
|
||||||
term_actor.no_restart = no_restart
|
term_actor.no_restart = no_restart
|
||||||
|
term = ray_client_pb2.TerminateRequest(actor=term_actor)
|
||||||
|
term.client_id = self._client_id
|
||||||
try:
|
try:
|
||||||
term = ray_client_pb2.TerminateRequest(actor=term_actor)
|
self.data_client.Terminate(term)
|
||||||
term.client_id = self._client_id
|
|
||||||
metadata = self._add_ids_to_metadata(self.metadata)
|
|
||||||
self._call_stub("Terminate", term, metadata=metadata)
|
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
raise decode_exception(e)
|
raise decode_exception(e)
|
||||||
|
|
||||||
|
@ -577,19 +584,18 @@ class Worker:
|
||||||
term_object.id = obj.id
|
term_object.id = obj.id
|
||||||
term_object.force = force
|
term_object.force = force
|
||||||
term_object.recursive = recursive
|
term_object.recursive = recursive
|
||||||
|
term = ray_client_pb2.TerminateRequest(task_object=term_object)
|
||||||
|
term.client_id = self._client_id
|
||||||
try:
|
try:
|
||||||
term = ray_client_pb2.TerminateRequest(task_object=term_object)
|
self.data_client.Terminate(term)
|
||||||
term.client_id = self._client_id
|
|
||||||
metadata = self._add_ids_to_metadata(self.metadata)
|
|
||||||
self._call_stub("Terminate", term, metadata=metadata)
|
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
raise decode_exception(e)
|
raise decode_exception(e)
|
||||||
|
|
||||||
def get_cluster_info(self,
|
def get_cluster_info(self,
|
||||||
type: ray_client_pb2.ClusterInfoType.TypeEnum,
|
req_type: ray_client_pb2.ClusterInfoType.TypeEnum,
|
||||||
timeout: Optional[float] = None):
|
timeout: Optional[float] = None):
|
||||||
req = ray_client_pb2.ClusterInfoRequest()
|
req = ray_client_pb2.ClusterInfoRequest()
|
||||||
req.type = type
|
req.type = req_type
|
||||||
resp = self.server.ClusterInfo(
|
resp = self.server.ClusterInfo(
|
||||||
req, timeout=timeout, metadata=self.metadata)
|
req, timeout=timeout, metadata=self.metadata)
|
||||||
if resp.WhichOneof("response_type") == "resource_table":
|
if resp.WhichOneof("response_type") == "resource_table":
|
||||||
|
@ -630,9 +636,7 @@ class Worker:
|
||||||
def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
|
def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
|
||||||
req = ray_client_pb2.ClientListNamedActorsRequest(
|
req = ray_client_pb2.ClientListNamedActorsRequest(
|
||||||
all_namespaces=all_namespaces)
|
all_namespaces=all_namespaces)
|
||||||
return json.loads(
|
return json.loads(self.data_client.ListNamedActors(req).actors_json)
|
||||||
self._call_stub("ListNamedActors", req,
|
|
||||||
metadata=self.metadata).actors_json)
|
|
||||||
|
|
||||||
def is_initialized(self) -> bool:
|
def is_initialized(self) -> bool:
|
||||||
if self.server is not None:
|
if self.server is not None:
|
||||||
|
|
|
@@ -367,6 +367,9 @@ message DataRequest {
     PrepRuntimeEnvRequest prep_runtime_env = 7;
     ConnectionCleanupRequest connection_cleanup = 8;
     AcknowledgeRequest acknowledge = 9;
+    ClientTask task = 10;
+    TerminateRequest terminate = 11;
+    ClientListNamedActorsRequest list_named_actors = 12;
   }
 }

@@ -381,7 +384,13 @@ message DataResponse {
     InitResponse init = 6;
     PrepRuntimeEnvResponse prep_runtime_env = 7;
     ConnectionCleanupResponse connection_cleanup = 8;
+    ClientTaskTicket task_ticket = 10;
+    TerminateResponse terminate = 11;
+    ClientListNamedActorsResponse list_named_actors = 12;
   }
+  // tag 9 is skipped, since there is no AcknowledgeResponse
+  reserved 9;
+  reserved "acknowledge";
 }

 service RayletDataStreamer {
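The proto additions above carry task scheduling, termination, and named-actor listing over the streaming DataRequest/DataResponse oneof. A hedged sketch of populating such a request, assuming the generated ray_client_pb2 bindings are importable from ray.core.generated (the same constructor-keyword style used by DataClient earlier in this diff):

from ray.core.generated import ray_client_pb2

# Build a streaming request that carries a task, matching the new
# `ClientTask task = 10;` field added above.
task = ray_client_pb2.ClientTask()
req = ray_client_pb2.DataRequest(task=task)
print(req.HasField("task"))  # True once the oneof member is set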