mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[client] let ray client reconnect on grpc failures (#18329)
* wip
* client tests working again
* extra prints
* start reconnect logic for proxier
* local proxy more wip
* delay cleanup logic working on proxy
* Fix up dataservicer logic
* lint + fix proxy data servicer exit logic
* hmmm
* delay cleanup always in dataservicer
* fix last_seen check
* cancel channel on error
* explicitly request cleanup
* cleanup request fixes
* fix dataclient proxy
* start idempotence logic
* change default channel state
* add backoff logic
* move connection logic back into worker.__init__
* add logic for replay cache case where request was received but response hasn't been fully resolved
* new proto entries for data stream caching
* start replay_cache logic, increase cleanup delay
* hardcode retries
* Let data channel attempt reconnects
* manually reset queue, remove replay_cache logic
* reduce cleanup delay to 5 minutes
* fix local tests
* Remove async cache logic
* retry async requests
* simplify backoff logic
* Fix ray client proto
* Configurable reconnect grace period
* Basic logsclient fix?
* Configure grace through environment variable
* Use stopped event to force faster datapath cleanup
* Better connect+reconnect logic
* fix reconnect_grace_period default
* init fixes for reconnect_grace_period
* cleanup
* fix _get_client_id_from_context call
* add logic for pathological cache cases
* less intrusive data channel error message
* fix tests
* Make stuff less painful to read
* add ordered replay cache for dataservicer, replay cache tests
* fix ordering import, start_reconnect test
* add middleman testing logic
* enforce ordering of dataclient requests
* retry wheels
* grace period through env only, restore test_dataclient_disconnect
* minor fixes
* force rerun
* less intrusive error msgs
* address review
* replay->response cache
* remove unneeded sleep
* typing
* extra response cache test
* fix error msg
* remove TODO
* add _reconnect_channel
* add grace period test
* store thread_id and req_id in metadata
* Revert "store thread_id and req_id in metadata" This reverts commit 12bc05cc0ceb0b764e2279353ba003fca16c3181.
* Revert "Revert "store thread_id and req_id in metadata"" This reverts commit 67874cf3a207fed49e6070c7e955a640f0094d19.
* fix metadata check
* remove comment
* removed unused cv
* cast back to int
* refactor Datapath for readability
* Revert refactor This reverts commit f789bad473c953eebabefe7eb6aa891e5b8a8f13.
* fix comment
* merge fixes
* refactor _shutdown
* address reviews
* log errors in both cases
* add comments
* address reviews
* move reconnect test to medium
* Always propagate error to callbacks
* readability
* formatting
* Faster cleanup on uncaught dataservicer errors
* delete tmp file
* offset commit
* refactor
* propagate data servicer error message
* Stricter handling/propagation of errors
* remove tmp file
* better docs
* forward reconnecting metadata
* add annotation
* fix invalidate + add test
* fix docstrings and types
* disable retries and caching if reconnect grace period is set to 0
* update comments
* address review, increase ack batch size and skip acks if reconnect isn't enabled
* Don't terminate data stream on missing reconnecting metadata
This commit is contained in:
parent
ffe7108eae
commit
8858489e2f
14 changed files with 1571 additions and 156 deletions
@ -67,6 +67,9 @@ DEFAULT_ACTOR_CREATION_CPU_SPECIFIED = 1
# Default number of return values for each actor method.
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1

# Wait 30 seconds for the client to reconnect after an unexpected disconnection
DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD = 30

# If a remote function or actor (or some other export) has serialized size
# greater than this quantity, print a warning.
FUNCTION_SIZE_WARN_THRESHOLD = 10**7
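The new default above can be overridden through the RAY_CLIENT_RECONNECT_GRACE_PERIOD environment variable exercised later in this change; setting it to 0 disables retries and response caching per this commit. A minimal sketch, assuming a Ray client server is already listening on the default port:

import os
# Must be set before connecting so the value is picked up at init time.
os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"] = "60"

import ray
ray.init("ray://127.0.0.1:10001")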
@ -49,12 +49,13 @@ py_test_module_list(
        "test_client.py",
        "test_client_builder.py",
        "test_client_init.py",
        "test_client_multi.py",
        "test_client_proxy.py",
        "test_client_server.py",
        "test_client_references.py",
        "test_client_warnings.py",
        "test_client_library_integration.py",
        "test_client_reconnect.py",
    ],
    size = "medium",
    extra_srcs = SRCS,

@ -138,6 +139,7 @@ py_test_module_list(
        "test_dataclient_disconnect.py",
        "test_k8s_operator_unit_tests.py",
        "test_monitor.py",
        "test_response_cache.py",
    ],
    size = "small",
    extra_srcs = SRCS,
413
python/ray/tests/test_client_reconnect.py
Normal file
@ -0,0 +1,413 @@
|
|||
from concurrent import futures
|
||||
import contextlib
|
||||
import os
|
||||
import threading
|
||||
import sys
|
||||
from ray.util.client.common import CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS
|
||||
import grpc
|
||||
|
||||
import time
|
||||
import random
|
||||
import pytest
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
import ray
|
||||
|
||||
# At a high level, these tests rely on an extra RPC server sitting
|
||||
# between the client and the real Ray server to inject errors, drop responses
|
||||
# and drop requests; the resulting topology is:
|
||||
# Ray Client <-> Middleman Server <-> Proxy Server
|
||||
|
||||
# Type for middleman hooks used to inject errors
|
||||
Hook = Callable[[Any], None]
|
||||
|
||||
|
||||
class MiddlemanDataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
"""
|
||||
Forwards all requests to the real data servicer. Useful for injecting
|
||||
errors between a client and server pair.
|
||||
"""
|
||||
|
||||
def __init__(self, on_response: Optional[Hook] = None):
|
||||
"""
|
||||
Args:
|
||||
on_response: Optional hook to inject errors before sending back a
|
||||
response
|
||||
"""
|
||||
self.stub = None
|
||||
self.on_response = on_response
|
||||
|
||||
def set_channel(self, channel: grpc.Channel) -> None:
|
||||
self.stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
|
||||
|
||||
def Datapath(self, request_iterator, context):
|
||||
try:
|
||||
for response in self.stub.Datapath(
|
||||
request_iterator, metadata=context.invocation_metadata()):
|
||||
if self.on_response:
|
||||
self.on_response(response)
|
||||
yield response
|
||||
except grpc.RpcError as e:
|
||||
context.set_code(e.code())
|
||||
context.set_details(e.details())
|
||||
|
||||
|
||||
class MiddlemanLogServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
|
||||
"""
|
||||
Forwards all requests to the real log servicer. Useful for injecting
|
||||
errors between a client and server pair.
|
||||
"""
|
||||
|
||||
def __init__(self, on_response: Optional[Hook] = None):
|
||||
"""
|
||||
Args:
|
||||
on_response: Optional hook to inject errors before sending back a
|
||||
response
|
||||
"""
|
||||
self.stub = None
|
||||
self.on_response = on_response
|
||||
|
||||
def set_channel(self, channel: grpc.Channel) -> None:
|
||||
self.stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
|
||||
|
||||
def Logstream(self, request_iterator, context):
|
||||
try:
|
||||
for response in self.stub.Logstream(
|
||||
request_iterator, metadata=context.invocation_metadata()):
|
||||
if self.on_response:
|
||||
self.on_response(response)
|
||||
yield response
|
||||
except grpc.RpcError as e:
|
||||
context.set_code(e.code())
|
||||
context.set_details(e.details())
|
||||
|
||||
|
||||
class MiddlemanRayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
"""
|
||||
Forwards all requests to the raylet driver servicer. Useful for injecting
|
||||
errors between a client and server pair.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
on_request: Optional[Hook] = None,
|
||||
on_response: Optional[Hook] = None):
|
||||
"""
|
||||
Args:
|
||||
on_request: Optional hook to inject errors before forwarding a
|
||||
request
|
||||
on_response: Optional hook to inject errors before sending back a
|
||||
response
|
||||
"""
|
||||
self.stub = None
|
||||
self.on_request = on_request
|
||||
self.on_response = on_response
|
||||
|
||||
def set_channel(self, channel: grpc.Channel) -> None:
|
||||
self.stub = ray_client_pb2_grpc.RayletDriverStub(channel)
|
||||
|
||||
def _call_inner_function(
|
||||
self, request: Any, context,
|
||||
method: str) -> Optional[ray_client_pb2_grpc.RayletDriverStub]:
|
||||
if self.on_request:
|
||||
self.on_request(request)
|
||||
try:
|
||||
response = getattr(self.stub, method)(
|
||||
request, metadata=context.invocation_metadata())
|
||||
except grpc.RpcError as e:
|
||||
context.set_code(e.code())
|
||||
context.set_details(e.details())
|
||||
raise
|
||||
if self.on_response:
|
||||
self.on_response(response)
|
||||
return response
|
||||
|
||||
def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
|
||||
return self._call_inner_function(request, context, "Init")
|
||||
|
||||
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
|
||||
return self._call_inner_function(request, context, "KVPut")
|
||||
|
||||
def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
|
||||
return self._call_inner_function(request, context, "KVGet")
|
||||
|
||||
def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
|
||||
return self._call_inner_function(request, context, "KVDel")
|
||||
|
||||
def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
|
||||
return self._call_inner_function(request, context, "KVList")
|
||||
|
||||
def KVExists(self, request,
|
||||
context=None) -> ray_client_pb2.KVExistsResponse:
|
||||
return self._call_inner_function(request, context, "KVExists")
|
||||
|
||||
def ListNamedActors(self, request, context=None
|
||||
) -> ray_client_pb2.ClientListNamedActorsResponse:
|
||||
return self._call_inner_function(request, context, "ListNamedActors")
|
||||
|
||||
def ClusterInfo(self, request,
|
||||
context=None) -> ray_client_pb2.ClusterInfoResponse:
|
||||
return self._call_inner_function(request, context, "ClusterInfo")
|
||||
|
||||
def Terminate(self, req, context=None):
|
||||
return self._call_inner_function(req, context, "Terminate")
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
return self._call_inner_function(request, context, "GetObject")
|
||||
|
||||
def PutObject(self, request: ray_client_pb2.PutRequest,
|
||||
context=None) -> ray_client_pb2.PutResponse:
|
||||
return self._call_inner_function(request, context, "PutObject")
|
||||
|
||||
def WaitObject(self, request: ray_client_pb2.WaitRequest,
|
||||
context=None) -> ray_client_pb2.WaitResponse:
|
||||
return self._call_inner_function(request, context, "WaitObject")
|
||||
|
||||
def Schedule(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
return self._call_inner_function(task, context, "Schedule")
|
||||
|
||||
|
||||
class MiddlemanServer:
|
||||
"""
|
||||
Helper class that wraps the RPC server that middlemans the connection
|
||||
between the client and the real ray server. Useful for injecting
|
||||
errors between a client and server pair.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
listen_addr: str,
|
||||
real_addr,
|
||||
on_log_response: Optional[Hook] = None,
|
||||
on_data_response: Optional[Hook] = None,
|
||||
on_task_request: Optional[Hook] = None,
|
||||
on_task_response: Optional[Hook] = None):
|
||||
"""
|
||||
Args:
|
||||
listen_addr: The address the middleman server will listen on
|
||||
real_addr: The address of the real ray server
|
||||
on_log_response: Optional hook to inject errors before sending back
|
||||
a log response
|
||||
on_data_response: Optional hook to inject errors before sending
|
||||
back a data response
|
||||
on_task_request: Optional hook to inject errors before forwarding
|
||||
a raylet driver request
|
||||
on_task_response: Optional hook to inject errors before sending
|
||||
back a raylet driver response
|
||||
"""
|
||||
self.listen_addr = listen_addr
|
||||
self.real_addr = real_addr
|
||||
self.server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
|
||||
options=GRPC_OPTIONS)
|
||||
self.task_servicer = MiddlemanRayletServicer(
|
||||
on_response=on_task_response, on_request=on_task_request)
|
||||
self.data_servicer = MiddlemanDataServicer(
|
||||
on_response=on_data_response)
|
||||
self.logs_servicer = MiddlemanLogServicer(on_response=on_log_response)
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
self.task_servicer, self.server)
|
||||
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
|
||||
self.data_servicer, self.server)
|
||||
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
|
||||
self.logs_servicer, self.server)
|
||||
self.server.add_insecure_port(self.listen_addr)
|
||||
self.channel = None
|
||||
self.reset_channel()
|
||||
|
||||
def reset_channel(self) -> None:
|
||||
"""
|
||||
Manually close and reopen the channel to the real ray server. This
|
||||
simulates a disconnection between the client and the server.
|
||||
"""
|
||||
if self.channel:
|
||||
self.channel.close()
|
||||
self.channel = grpc.insecure_channel(
|
||||
self.real_addr, options=GRPC_OPTIONS)
|
||||
grpc.channel_ready_future(self.channel)
|
||||
self.task_servicer.set_channel(self.channel)
|
||||
self.data_servicer.set_channel(self.channel)
|
||||
self.logs_servicer.set_channel(self.channel)
|
||||
|
||||
def start(self) -> None:
|
||||
self.server.start()
|
||||
|
||||
def stop(self, grace: int) -> None:
|
||||
self.server.stop(grace)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def start_middleman_server(on_log_response=None,
|
||||
on_data_response=None,
|
||||
on_task_request=None,
|
||||
on_task_response=None):
|
||||
"""
|
||||
Helper context that starts a middleman server listening on port 10011,
|
||||
and a ray client server on port 50051.
|
||||
"""
|
||||
ray._inside_client_test = True
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
middleman = None
|
||||
try:
|
||||
middleman = MiddlemanServer(
|
||||
listen_addr="localhost:10011",
|
||||
real_addr="localhost:50051",
|
||||
on_log_response=on_log_response,
|
||||
on_data_response=on_data_response,
|
||||
on_task_request=on_task_request,
on_task_response=on_task_response)
|
||||
middleman.start()
|
||||
ray.init("ray://localhost:10011")
|
||||
yield middleman, server
|
||||
finally:
|
||||
ray._inside_client_test = False
|
||||
ray.util.disconnect()
|
||||
server.stop(0)
|
||||
if middleman:
|
||||
middleman.stop(0)
|
||||
|
||||
|
||||
def test_disconnect_during_get():
|
||||
"""
|
||||
Disconnect the proxy and the client in the middle of a long running get
|
||||
"""
|
||||
|
||||
@ray.remote
|
||||
def slow_result():
|
||||
time.sleep(20)
|
||||
return 12345
|
||||
|
||||
def disconnect(middleman):
|
||||
time.sleep(3)
|
||||
middleman.reset_channel()
|
||||
|
||||
with start_middleman_server() as (middleman, _):
|
||||
disconnect_thread = threading.Thread(
|
||||
target=disconnect, args=(middleman, ))
|
||||
disconnect_thread.start()
|
||||
result = ray.get(slow_result.remote())
|
||||
assert result == 12345
|
||||
disconnect_thread.join()
|
||||
|
||||
|
||||
def test_valid_actor_state():
|
||||
"""
|
||||
Repeatedly inject errors in the middle of mutating actor calls. Check
|
||||
at the end that the final state of the actor is consistent with what
|
||||
we would expect had the disconnects not occurred.
|
||||
"""
|
||||
|
||||
@ray.remote
|
||||
class IncrActor:
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
|
||||
def incr(self):
|
||||
self.val += 1
|
||||
return self.val
|
||||
|
||||
i = 0
|
||||
|
||||
def fail_every_seven(_):
|
||||
# Inject an error every seventh time this method is called
|
||||
nonlocal i
|
||||
i += 1
|
||||
if i % 7 == 0:
|
||||
raise RuntimeError
|
||||
|
||||
with start_middleman_server(
|
||||
on_data_response=fail_every_seven,
|
||||
on_task_request=fail_every_seven,
|
||||
on_task_response=fail_every_seven):
|
||||
actor = IncrActor.remote()
|
||||
for _ in range(100):
|
||||
ref = actor.incr.remote()
|
||||
assert ray.get(ref) == 100
|
||||
|
||||
|
||||
def test_valid_actor_state_2():
|
||||
"""
|
||||
Do a full disconnect (cancel channel) every 11 requests. Failure
happens:
- before the request is sent: the request never reaches the server
- before the response is received: the response never reaches the client
- while gets are being processed
|
||||
"""
|
||||
|
||||
@ray.remote
|
||||
class IncrActor:
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
|
||||
def incr(self):
|
||||
self.val += 1
|
||||
return self.val
|
||||
|
||||
i = 0
|
||||
|
||||
with start_middleman_server() as (middleman, _):
|
||||
|
||||
def fail_every_eleven(_):
|
||||
nonlocal i
|
||||
i += 1
|
||||
if i % 11 == 0:
|
||||
middleman.reset_channel()
|
||||
|
||||
middleman.data_servicer.on_response = fail_every_eleven
|
||||
middleman.task_servicer.on_request = fail_every_eleven
|
||||
middleman.task_servicer.on_response = fail_every_eleven
|
||||
|
||||
actor = IncrActor.remote()
|
||||
for _ in range(100):
|
||||
ref = actor.incr.remote()
|
||||
assert ray.get(ref) == 100
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows")
|
||||
def test_noisy_puts():
|
||||
"""
|
||||
Randomly kills the data channel with a 10% chance when a response is
received (the request made it to the server, but the response is dropped)
and checks that the final results are still consistent
|
||||
"""
|
||||
random.seed(12345)
|
||||
with start_middleman_server() as (middleman, _):
|
||||
|
||||
def fail_randomly(response: ray_client_pb2.DataResponse):
|
||||
if random.random() < 0.1:
|
||||
raise RuntimeError
|
||||
|
||||
middleman.data_servicer.on_response = fail_randomly
|
||||
|
||||
refs = [ray.put(i * 123) for i in range(500)]
|
||||
results = ray.get(refs)
|
||||
for i, result in enumerate(results):
|
||||
assert result == i * 123
|
||||
|
||||
|
||||
def test_client_reconnect_grace_period():
|
||||
"""
|
||||
Tests that the client gives up attempting to reconnect the channel
|
||||
after the grace period expires.
|
||||
"""
|
||||
# Lower grace period to 5 seconds to save time
|
||||
with patch.dict(os.environ, {"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "5"}), \
|
||||
start_middleman_server() as (middleman, _):
|
||||
assert ray.get(ray.put(42)) == 42
|
||||
# Close channel
|
||||
middleman.channel.close()
|
||||
start_time = time.time()
|
||||
with pytest.raises(ConnectionError):
|
||||
ray.get(ray.put(42))
|
||||
# Connection error should have been raised within a reasonable
|
||||
# amount of time. Set to significantly higher than 5 seconds
|
||||
# to account for reconnect backoff timing
|
||||
assert time.time() - start_time < 20
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
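The Hook type defined near the top of this file is just a callable that can raise to simulate a failure. An illustrative hook (not part of the test file) that could be passed, for example, as on_data_response to start_middleman_server:

from typing import Any

def drop_every_third(_msg: Any) -> None:
    # Raising from a hook aborts the middleman's forwarding of the current
    # message, which the client observes as a dropped request or response.
    drop_every_third.count = getattr(drop_every_third, "count", 0) + 1
    if drop_every_third.count % 3 == 0:
        raise RuntimeError("injected failure")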
@ -1,10 +1,16 @@
|
|||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
import os
|
||||
import time
|
||||
|
||||
|
||||
def test_dataclient_disconnect_on_request():
|
||||
with ray_start_client_server() as ray:
|
||||
# Client can't signal graceful shutdown to server after unrecoverable
|
||||
# error. Lower grace period so we don't have to sleep as long before
|
||||
# checking new connection data.
|
||||
with patch.dict(os.environ, {"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "5"}), \
|
||||
ray_start_client_server() as ray:
|
||||
assert ray.is_connected()
|
||||
|
||||
@ray.remote
|
||||
|
@ -20,13 +26,18 @@ def test_dataclient_disconnect_on_request():
|
|||
assert not ray.is_connected()
|
||||
|
||||
# Test that a new connection can be made
|
||||
time.sleep(5) # Give server time to clean up old connection
|
||||
connection_data = ray.connect("localhost:50051")
|
||||
assert connection_data["num_clients"] == 1
|
||||
assert ray.get(f.remote()) == 42
|
||||
|
||||
|
||||
def test_dataclient_disconnect_before_request():
|
||||
with ray_start_client_server() as ray:
|
||||
# Client can't signal graceful shutdown to server after unrecoverable
|
||||
# error. Lower grace period so we don't have to sleep as long before
|
||||
# checking new connection data.
|
||||
with patch.dict(os.environ, {"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "5"}), \
|
||||
ray_start_client_server() as ray:
|
||||
assert ray.is_connected()
|
||||
|
||||
@ray.remote
|
||||
|
@ -48,6 +59,7 @@ def test_dataclient_disconnect_before_request():
|
|||
assert not ray.is_connected()
|
||||
|
||||
# Test that a new connection can be made
|
||||
time.sleep(5) # Give server time to clean up old connection
|
||||
connection_data = ray.connect("localhost:50051")
|
||||
assert connection_data["num_clients"] == 1
|
||||
assert ray.get(f.remote()) == 42
|
||||
|
|
218
python/ray/tests/test_response_cache.py
Normal file
@ -0,0 +1,218 @@
|
|||
from ray.util.client.common import (_id_is_newer, ResponseCache,
|
||||
OrderedResponseCache, INT32_MAX)
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_id_is_newer():
|
||||
"""
|
||||
Sanity checks the logic for ID is newer. In general, we would expect
|
||||
that higher IDs are newer than lower IDs, for example 25 can be assumed
|
||||
to be newer than 24.
|
||||
Since IDs roll over at INT32_MAX (~2**31), we should check for weird
|
||||
behavior there. In particular, we would expect an ID like `11` to be
|
||||
newer than the ID `2**31` since it's likely that the counter rolled
|
||||
over.
|
||||
"""
|
||||
# Common cases -- higher IDs normally considered newer
|
||||
assert _id_is_newer(30, 29)
|
||||
assert _id_is_newer(12345, 12344)
|
||||
assert not _id_is_newer(12344, 12345)
|
||||
assert not _id_is_newer(5678, 5678)
|
||||
|
||||
# Check behavior near max int boundary
|
||||
assert _id_is_newer(INT32_MAX, INT32_MAX - 1)
|
||||
assert _id_is_newer(INT32_MAX - 1, INT32_MAX - 2)
|
||||
|
||||
# Low IDs are assumed newer than higher ones if it looks like rollover has
|
||||
# occurred
|
||||
assert _id_is_newer(0, INT32_MAX - 4)
|
||||
assert _id_is_newer(1001, INT32_MAX - 123)
|
||||
assert not _id_is_newer(INT32_MAX, 123)
|
||||
|
||||
|
||||
def test_response_cache_complete_response():
|
||||
"""
|
||||
Test basic check/update logic of cache, and that nothing blocks
|
||||
"""
|
||||
cache = ResponseCache()
|
||||
cache.check_cache(123, 15) # shouldn't block
|
||||
cache.update_cache(123, 15, "abcdef")
|
||||
assert cache.check_cache(123, 15) == "abcdef"
|
||||
|
||||
|
||||
def test_ordered_response_cache_complete_response():
|
||||
"""
|
||||
Test basic check/update logic of ordered cache, and that nothing blocks
|
||||
"""
|
||||
cache = OrderedResponseCache()
|
||||
cache.check_cache(15) # shouldn't block
|
||||
cache.update_cache(15, "vwxyz")
|
||||
assert cache.check_cache(15) == "vwxyz"
|
||||
|
||||
|
||||
def test_response_cache_incomplete_response():
|
||||
"""
|
||||
Tests case where a cache entry is populated after a long time. Any new
|
||||
threads attempting to access that entry should sleep until the response
|
||||
is ready.
|
||||
"""
|
||||
cache = ResponseCache()
|
||||
|
||||
def populate_cache():
|
||||
time.sleep(2)
|
||||
cache.update_cache(123, 15, "abcdef")
|
||||
|
||||
cache.check_cache(123, 15) # shouldn't block
|
||||
t = threading.Thread(target=populate_cache, args=())
|
||||
t.start()
|
||||
# Should block until other thread populates cache
|
||||
assert cache.check_cache(123, 15) == "abcdef"
|
||||
t.join()
|
||||
|
||||
|
||||
def test_ordered_response_cache_incomplete_response():
|
||||
"""
|
||||
Tests case where an ordered cache entry is populated after a long time. Any
|
||||
new threads attempting to access that entry should sleep until the response
|
||||
is ready.
|
||||
"""
|
||||
cache = OrderedResponseCache()
|
||||
|
||||
def populate_cache():
|
||||
time.sleep(2)
|
||||
cache.update_cache(15, "vwxyz")
|
||||
|
||||
cache.check_cache(15) # shouldn't block
|
||||
t = threading.Thread(target=populate_cache, args=())
|
||||
t.start()
|
||||
# Should block until other thread populates cache
|
||||
assert cache.check_cache(15) == "vwxyz"
|
||||
t.join()
|
||||
|
||||
|
||||
def test_ordered_response_cache_cleanup():
|
||||
"""
|
||||
Tests that the cleanup method of ordered cache works as expected, in
|
||||
particular that all entries <= the passed ID are cleared from the cache.
|
||||
"""
|
||||
cache = OrderedResponseCache()
|
||||
|
||||
for i in range(1, 21):
|
||||
assert cache.check_cache(i) is None
|
||||
cache.update_cache(i, str(i))
|
||||
|
||||
assert len(cache.cache) == 20
|
||||
for i in range(1, 21):
|
||||
assert cache.check_cache(i) == str(i)
|
||||
|
||||
# Expected: clean up all entries up to and including entry 10
|
||||
cache.cleanup(10)
|
||||
assert len(cache.cache) == 10
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# Attempting to access value that has already been cleaned up
|
||||
cache.check_cache(10)
|
||||
|
||||
for i in range(21, 31):
|
||||
# Check that more entries can be inserted
|
||||
assert cache.check_cache(i) is None
|
||||
cache.update_cache(i, str(i))
|
||||
|
||||
# Cleanup everything
|
||||
cache.cleanup(30)
|
||||
assert len(cache.cache) == 0
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
cache.check_cache(30)
|
||||
|
||||
# Cleanup requests received out of order are tolerated
|
||||
cache.cleanup(27)
|
||||
cache.cleanup(23)
|
||||
|
||||
|
||||
def test_response_cache_update_while_waiting():
|
||||
"""
|
||||
Tests that an error is thrown when a cache entry is updated with the
|
||||
response for a different request than what was originally being
|
||||
checked for.
|
||||
"""
|
||||
# Error when awaiting cache to update, but entry is cleaned up
|
||||
cache = ResponseCache()
|
||||
assert cache.check_cache(16, 123) is None
|
||||
|
||||
def cleanup_cache():
|
||||
time.sleep(2)
|
||||
cache.check_cache(16, 124)
|
||||
cache.update_cache(16, 124, "asdf")
|
||||
|
||||
t = threading.Thread(target=cleanup_cache, args=())
|
||||
t.start()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
cache.check_cache(16, 123)
|
||||
t.join()
|
||||
|
||||
|
||||
def test_ordered_response_cache_cleanup_while_waiting():
|
||||
"""
|
||||
Tests that an error is thrown when an ordered cache entry is cleaned up
while another thread is still waiting on it.
|
||||
"""
|
||||
# Error when awaiting cache to update, but entry is cleaned up
|
||||
cache = OrderedResponseCache()
|
||||
assert cache.check_cache(123) is None
|
||||
|
||||
def cleanup_cache():
|
||||
time.sleep(2)
|
||||
cache.cleanup(123)
|
||||
|
||||
t = threading.Thread(target=cleanup_cache, args=())
|
||||
t.start()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
cache.check_cache(123)
|
||||
t.join()
|
||||
|
||||
|
||||
def test_response_cache_cleanup():
|
||||
"""
|
||||
Checks that the response cache replaces old entries for a given thread
|
||||
with new entries as they come in, instead of creating new entries
|
||||
(possibly wasting memory on unneeded entries)
|
||||
"""
|
||||
# Check that the response cache cleans up previous entries for a given
|
||||
# thread properly.
|
||||
cache = ResponseCache()
|
||||
cache.check_cache(16, 123)
|
||||
cache.update_cache(16, 123, "Some response")
|
||||
assert len(cache.cache) == 1
|
||||
|
||||
cache.check_cache(16, 124)
|
||||
cache.update_cache(16, 124, "Second response")
|
||||
assert len(cache.cache) == 1 # Should reuse entry for thread 16
|
||||
assert cache.check_cache(16, 124) == "Second response"
|
||||
|
||||
|
||||
def test_response_cache_invalidate():
|
||||
"""
|
||||
Check that ordered response cache invalidate works as expected
|
||||
"""
|
||||
cache = OrderedResponseCache()
|
||||
e = RuntimeError("SomeError")
|
||||
# No pending entries, cache should be valid
|
||||
assert not cache.invalidate(e)
|
||||
# No entry for 123 yet
|
||||
assert cache.check_cache(123) is None
|
||||
# this should invalidate the entry for 123
|
||||
assert cache.invalidate(e)
|
||||
assert cache.check_cache(123) == e
|
||||
assert cache.invalidate(e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# This version string is incremented to indicate breaking changes in the
|
||||
# protocol that require upgrading the client version.
|
||||
CURRENT_PROTOCOL_VERSION = "2021-08-26"
|
||||
CURRENT_PROTOCOL_VERSION = "2021-09-02"
|
||||
|
||||
|
||||
class _ClientContext:
|
||||
|
|
|
@ -13,10 +13,12 @@ from ray.util.inspect import is_cython, is_function_or_method
|
|||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -25,6 +27,21 @@ logger = logging.getLogger(__name__)
|
|||
# number of simultaneous in-flight requests.
|
||||
INT32_MAX = (2**31) - 1
|
||||
|
||||
# gRPC status codes that the client shouldn't attempt to recover from
|
||||
# Resource exhausted: Server is low on resources, or has hit the max number
|
||||
# of client connections
|
||||
# Invalid argument: Reserved for application errors
|
||||
# Not found: Set if the client is attempting to reconnect to a session that
|
||||
# does not exist
|
||||
# Failed precondition: Reserved for application errors
|
||||
# Aborted: Set when an error is serialized into the details of the context,
|
||||
# signals that error should be deserialized on the client side
|
||||
GRPC_UNRECOVERABLE_ERRORS = (grpc.StatusCode.RESOURCE_EXHAUSTED,
|
||||
grpc.StatusCode.INVALID_ARGUMENT,
|
||||
grpc.StatusCode.NOT_FOUND,
|
||||
grpc.StatusCode.FAILED_PRECONDITION,
|
||||
grpc.StatusCode.ABORTED)
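A sketch of how a caller can use this tuple to decide whether a retry is worthwhile, assuming the constant is importable from ray.util.client.common like the other names imported by the new tests:

import grpc
from ray.util.client.common import GRPC_UNRECOVERABLE_ERRORS

def _is_recoverable(e: grpc.RpcError) -> bool:
    # Codes outside the list above (e.g. UNAVAILABLE) may be transient
    # network failures and are worth a reconnect attempt.
    return e.code() not in GRPC_UNRECOVERABLE_ERRORS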
|
||||
|
||||
# TODO: Instead of just making the max message size large, the right thing to
|
||||
# do is to split up the bytes representation of serialized data into multiple
|
||||
# messages and reconstruct them on either end. That said, since clients are
|
||||
|
@ -385,6 +402,13 @@ class ClientServerHandle:
|
|||
logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
|
||||
grpc_server: grpc.Server
|
||||
|
||||
def stop(self, grace: int) -> None:
|
||||
# The data servicer might be sleeping while waiting for clients to
|
||||
# reconnect. Signal that they no longer have to sleep and can exit
|
||||
# immediately, since the RPC server is stopped.
|
||||
self.grpc_server.stop(grace)
|
||||
self.data_servicer.stopped.set()
|
||||
|
||||
# Add a hook for all the cases that previously
|
||||
# expected simply a gRPC server
|
||||
def __getattr__(self, attr):
|
||||
|
@ -402,3 +426,241 @@ def _get_client_id_from_context(context: Any) -> str:
|
|||
logger.error("Client connecting with no client_id")
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
return client_id
|
||||
|
||||
|
||||
def _propagate_error_in_context(e: Exception, context: Any) -> bool:
|
||||
"""
|
||||
Encode an error into the context of an RPC response. Returns True
if the error can be recovered from, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if isinstance(e, grpc.RpcError):
|
||||
# RPC error, propagate directly by copying details into context
|
||||
context.set_code(e.code())
|
||||
context.set_details(e.details())
|
||||
return e.code() not in GRPC_UNRECOVERABLE_ERRORS
|
||||
except Exception:
|
||||
# Extra precaution -- if encoding the RPC directly fails fallback
|
||||
# to treating it as a regular error
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details(str(e))
|
||||
return False
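A hedged sketch of how a streaming handler might combine this helper with the response cache defined below; _handle_request is a hypothetical per-request handler and response_cache is assumed to be an OrderedResponseCache, neither of which is part of this hunk:

def _handle_one_request(self, req, context):
    cached = self.response_cache.check_cache(req.req_id)
    if cached is not None:
        return cached  # retried request: replay the stored response
    try:
        resp = self._handle_request(req)  # hypothetical helper
    except Exception as e:
        recoverable = _propagate_error_in_context(e, context)
        # Unblock client threads waiting on partially filled cache entries.
        self.response_cache.invalidate(e)
        if not recoverable:
            raise
        return None
    self.response_cache.update_cache(req.req_id, resp)
    return resp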
|
||||
|
||||
|
||||
def _id_is_newer(id1: int, id2: int) -> bool:
|
||||
"""
|
||||
We should only replace cache entries with the responses for newer IDs.
|
||||
Most of the time newer IDs will be the ones with higher value, except when
|
||||
the req_id counter rolls over. We check for this case by checking the
|
||||
distance between the two IDs. If the distance is significant, then it's
|
||||
likely that the req_id counter rolled over, and the smaller id should
|
||||
still be used to replace the one in cache.
|
||||
"""
|
||||
diff = abs(id2 - id1)
|
||||
if diff > (INT32_MAX // 2):
|
||||
# Rollover likely occurred. In this case the smaller ID is newer
|
||||
return id1 < id2
|
||||
return id1 > id2
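Concretely, mirroring the new test_id_is_newer test added in this commit:

from ray.util.client.common import INT32_MAX, _id_is_newer

assert _id_is_newer(30, 29)              # ordinary case: larger ID is newer
assert _id_is_newer(5, INT32_MAX - 10)   # huge gap: assume req_id rolled over
assert not _id_is_newer(INT32_MAX, 123)  # same reasoning in the other direction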
|
||||
|
||||
|
||||
class ResponseCache:
|
||||
"""
|
||||
Cache for blocking method calls. Needed to prevent retried requests from
|
||||
being applied multiple times on the server, for example when the client
|
||||
disconnects. This is used to cache requests/responses sent through
|
||||
unary-unary RPCs to the RayletServicer.
|
||||
|
||||
Note that no clean up logic is used, the last response for each thread
|
||||
will always be remembered, so at most the cache will hold N entries,
|
||||
where N is the number of threads on the client side. This relies on the
|
||||
assumption that a thread will not make a new blocking request until it has
|
||||
received a response for a previous one, at which point it's safe to
|
||||
overwrite the old response.
|
||||
|
||||
The high level logic is:
|
||||
|
||||
1. Before making a call, check the cache for the current thread.
|
||||
2. If present in the cache, check the request id of the cached
|
||||
response.
|
||||
a. If it matches the current request_id, then the request has been
|
||||
received before and we shouldn't re-attempt the logic. Wait for
|
||||
the response to become available in the cache, and then return it
|
||||
b. If it doesn't match, then this is a new request and we can
|
||||
proceed with calling the real stub. While the response is still
|
||||
being generated, temporarily keep (req_id, None) in the cache.
|
||||
Once the call is finished, update the cache entry with the
|
||||
new (req_id, response) pair. Notify other threads that may
|
||||
have been waiting for the response to be prepared.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cv = threading.Condition()
|
||||
self.cache: Dict[int, Tuple[int, Any]] = {}
|
||||
|
||||
def check_cache(self, thread_id: int, request_id: int) -> Optional[Any]:
|
||||
"""
|
||||
Check the cache for a given thread, and see if the entry in the cache
|
||||
matches the current request_id. Returns None if the request_id has
|
||||
not been seen yet, otherwise returns the cached result.
|
||||
|
||||
Throws an error if the placeholder in the cache doesn't match the
|
||||
request_id -- this means that a new request evicted the old value in
|
||||
the cache, and that the RPC for `request_id` is redundant and the
|
||||
result can be discarded, i.e.:
|
||||
|
||||
1. Request A is sent (A1)
|
||||
2. Channel disconnects
|
||||
3. Request A is resent (A2)
|
||||
4. A1 is received
|
||||
5. A2 is received, waits for A1 to finish
|
||||
6. A1 finishes and is sent back to client
|
||||
7. Request B is sent
|
||||
8. Request B overwrites cache entry
|
||||
9. A2 wakes up extremely late, but cache is now invalid
|
||||
|
||||
In practice this is VERY unlikely to happen, but the error can at
|
||||
least serve as a sanity check or catch invalid request IDs.
|
||||
"""
|
||||
with self.cv:
|
||||
if thread_id in self.cache:
|
||||
cached_request_id, cached_resp = self.cache[thread_id]
|
||||
if cached_request_id == request_id:
|
||||
while cached_resp is None:
|
||||
# The call was started, but the response hasn't yet
|
||||
# been added to the cache. Let go of the lock and
|
||||
# wait until the response is ready.
|
||||
self.cv.wait()
|
||||
cached_request_id, cached_resp = self.cache[thread_id]
|
||||
if cached_request_id != request_id:
|
||||
raise RuntimeError(
|
||||
"Cached response doesn't match the id of the "
|
||||
"original request. This might happen if this "
|
||||
"request was received out of order. The "
|
||||
"result of the caller is no longer needed. "
|
||||
f"({request_id} != {cached_request_id})")
|
||||
return cached_resp
|
||||
if not _id_is_newer(request_id, cached_request_id):
|
||||
raise RuntimeError(
|
||||
"Attempting to replace newer cache entry with older "
|
||||
"one. This might happen if this request was received "
|
||||
"out of order. The result of the caller is no "
|
||||
f"longer needed. ({request_id} != {cached_request_id}")
|
||||
self.cache[thread_id] = (request_id, None)
|
||||
return None
|
||||
|
||||
def update_cache(self, thread_id: int, request_id: int,
|
||||
response: Any) -> None:
|
||||
"""
|
||||
Inserts `response` into the cache for `request_id`.
|
||||
"""
|
||||
with self.cv:
|
||||
cached_request_id, cached_resp = self.cache[thread_id]
|
||||
if cached_request_id != request_id or cached_resp is not None:
|
||||
# The cache was overwritten by a newer requester between
|
||||
# our call to check_cache and our call to update it.
|
||||
# This shouldn't happen as long as the cached requests are all
# blocking on the client side, so if you encounter this, check
# whether any async requests are being cached.
|
||||
raise RuntimeError(
|
||||
"Attempting to update the cache, but placeholder's "
|
||||
"do not match the current request_id. This might happen "
|
||||
"if this request was received out of order. The result "
|
||||
f"of the caller is no longer needed. ({request_id} != "
|
||||
f"{cached_request_id})")
|
||||
self.cache[thread_id] = (request_id, response)
|
||||
self.cv.notify_all()
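A hedged sketch of the intended server-side pattern for a unary RPC handler; the metadata key names and the _schedule_internal helper are assumptions and are not part of this hunk:

def Schedule(self, task, context):
    meta = dict(context.invocation_metadata())
    thread_id, req_id = int(meta["thread_id"]), int(meta["req_id"])
    cached = self.response_cache.check_cache(thread_id, req_id)
    if cached is not None:
        # A retried request: return the stored response instead of
        # scheduling the task a second time.
        return cached
    resp = self._schedule_internal(task, context)  # hypothetical helper
    self.response_cache.update_cache(thread_id, req_id, resp)
    return resp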
|
||||
|
||||
|
||||
class OrderedResponseCache:
|
||||
"""
|
||||
Cache for streaming RPCs, i.e. the DataServicer. Relies on explicit
|
||||
acks from the client to determine when it can clean up cache entries.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_received = 0
|
||||
self.cv = threading.Condition()
|
||||
self.cache: Dict[int, Any] = OrderedDict()
|
||||
|
||||
def check_cache(self, req_id: int) -> Optional[Any]:
|
||||
"""
|
||||
Check the cache for the given request ID. Returns None if the
request_id has not been seen yet, otherwise returns the cached result,
blocking until the result is available if necessary.
|
||||
"""
|
||||
with self.cv:
|
||||
if _id_is_newer(self.last_received,
|
||||
req_id) or self.last_received == req_id:
|
||||
# Request is for an id that has already been cleared from
|
||||
# cache/acknowledged.
|
||||
raise RuntimeError(
|
||||
"Attempting to accesss a cache entry that has already "
|
||||
"cleaned up. The client has already acknowledged "
|
||||
f"receiving this response. ({req_id}, "
|
||||
f"{self.last_received})")
|
||||
if req_id in self.cache:
|
||||
cached_resp = self.cache[req_id]
|
||||
while cached_resp is None:
|
||||
# The call was started, but the response hasn't yet been
|
||||
# added to the cache. Let go of the lock and wait until
|
||||
# the response is ready
|
||||
self.cv.wait()
|
||||
if req_id not in self.cache:
|
||||
raise RuntimeError(
|
||||
"Cache entry was removed. This likely means that "
|
||||
"the result of this call is no longer needed.")
|
||||
cached_resp = self.cache[req_id]
|
||||
return cached_resp
|
||||
self.cache[req_id] = None
|
||||
return None
|
||||
|
||||
def update_cache(self, req_id: int, resp: Any) -> None:
|
||||
"""
|
||||
Inserts `response` into the cache for `request_id`.
|
||||
"""
|
||||
with self.cv:
|
||||
self.cv.notify_all()
|
||||
if req_id not in self.cache:
|
||||
raise RuntimeError(
|
||||
"Attempting to update the cache, but placeholder is "
|
||||
"missing. This might happen on a redundant call to "
|
||||
f"update_cache. ({req_id})")
|
||||
self.cache[req_id] = resp
|
||||
|
||||
def invalidate(self, e: Exception) -> bool:
|
||||
"""
|
||||
Invalidate any partially populated cache entries, replacing their
|
||||
placeholders with the passed in exception. Useful to prevent a thread
|
||||
from waiting indefinitely on a failed call.
|
||||
|
||||
Returns True if the cache contains an error, False otherwise
|
||||
"""
|
||||
with self.cv:
|
||||
invalid = False
|
||||
for req_id in self.cache:
|
||||
if self.cache[req_id] is None:
|
||||
self.cache[req_id] = e
|
||||
if isinstance(self.cache[req_id], Exception):
|
||||
invalid = True
|
||||
self.cv.notify_all()
|
||||
return invalid
|
||||
|
||||
def cleanup(self, last_received: int) -> None:
|
||||
"""
|
||||
Cleanup all of the cached requests up to last_received. Assumes that
|
||||
the cache entries were inserted in ascending order.
|
||||
"""
|
||||
with self.cv:
|
||||
if _id_is_newer(last_received, self.last_received):
|
||||
self.last_received = last_received
|
||||
to_remove = []
|
||||
for req_id in self.cache:
|
||||
if _id_is_newer(last_received,
|
||||
req_id) or last_received == req_id:
|
||||
to_remove.append(req_id)
|
||||
else:
|
||||
break
|
||||
for req_id in to_remove:
|
||||
del self.cache[req_id]
|
||||
self.cv.notify_all()
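Its contract, restated as a minimal standalone example (the same behaviour is exercised by the new test_response_cache.py):

from ray.util.client.common import OrderedResponseCache

cache = OrderedResponseCache()
assert cache.check_cache(1) is None          # first sight: placeholder stored
cache.update_cache(1, "response-1")          # fill in the real response
assert cache.check_cache(1) == "response-1"  # a retry replays the cached value
cache.cleanup(1)                             # client acked req_id 1; entry dropped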
|
||||
|
|
|
@ -6,6 +6,7 @@ import queue
|
|||
import threading
|
||||
import grpc
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
@ -20,6 +21,9 @@ logger = logging.getLogger(__name__)
|
|||
ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]],
|
||||
None]
|
||||
|
||||
# Send an acknowledge on every 32nd response received
|
||||
ACKNOWLEDGE_BATCH_SIZE = 32
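Batching keeps the acknowledgement overhead small: roughly one ack is sent per 32 responses, e.g.:

ACKNOWLEDGE_BATCH_SIZE = 32
responses_received = 500
print(responses_received // ACKNOWLEDGE_BATCH_SIZE)  # ~15 acknowledge requests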
|
||||
|
||||
|
||||
class DataClient:
|
||||
def __init__(self, client_worker: "Worker", client_id: str,
|
||||
|
@ -36,9 +40,13 @@ class DataClient:
|
|||
self._metadata = metadata
|
||||
self.data_thread = self._start_datathread()
|
||||
|
||||
# Track outstanding requests to resend in case of disconnection
|
||||
self.outstanding_requests: Dict[int, Any] = OrderedDict()
|
||||
|
||||
# Serialize access to all mutable internal states: self.request_queue,
|
||||
# self.ready_data, self.asyncio_waiting_data,
|
||||
# self._in_shutdown, self._req_id and calling self._next_id().
|
||||
# self._in_shutdown, self._req_id, self.outstanding_requests and
|
||||
# calling self._next_id()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# Waiting for response or shutdown.
|
||||
|
@ -51,6 +59,8 @@ class DataClient:
|
|||
self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
|
||||
self._in_shutdown = False
|
||||
self._req_id = 0
|
||||
self._last_exception = None
|
||||
self._acknowledge_counter = 0
|
||||
|
||||
self.data_thread.start()
|
||||
|
||||
|
@ -73,17 +83,32 @@ class DataClient:
|
|||
daemon=True)
|
||||
|
||||
def _data_main(self) -> None:
|
||||
stub = ray_client_pb2_grpc.RayletDataStreamerStub(
|
||||
self.client_worker.channel)
|
||||
resp_stream = stub.Datapath(
|
||||
iter(self.request_queue.get, None),
|
||||
metadata=self._metadata,
|
||||
wait_for_ready=True)
|
||||
reconnecting = False
|
||||
try:
|
||||
for response in resp_stream:
|
||||
self._process_response(response)
|
||||
except grpc.RpcError as e:
|
||||
self._process_rpc_error(e)
|
||||
while not self.client_worker._in_shutdown:
|
||||
stub = ray_client_pb2_grpc.RayletDataStreamerStub(
|
||||
self.client_worker.channel)
|
||||
metadata = self._metadata + \
|
||||
[("reconnecting", str(reconnecting))]
|
||||
resp_stream = stub.Datapath(
|
||||
iter(self.request_queue.get, None),
|
||||
metadata=metadata,
|
||||
wait_for_ready=True)
|
||||
try:
|
||||
for response in resp_stream:
|
||||
self._process_response(response)
|
||||
return
|
||||
except grpc.RpcError as e:
|
||||
reconnecting = self._can_reconnect(e)
|
||||
if not reconnecting:
|
||||
self._last_exception = e
|
||||
return
|
||||
self._reconnect_channel()
|
||||
except Exception as e:
|
||||
self._last_exception = e
|
||||
finally:
|
||||
logger.info("Shutting down data channel")
|
||||
self._shutdown()
|
||||
|
||||
def _process_response(self, response: Any) -> None:
|
||||
"""
|
||||
|
@ -101,54 +126,114 @@ class DataClient:
|
|||
# is accessed without holding self.lock. Holding the
|
||||
# lock shouldn't be necessary either.
|
||||
callback = self.asyncio_waiting_data.pop(response.req_id)
|
||||
callback(response)
|
||||
if callback:
|
||||
callback(response)
|
||||
except Exception:
|
||||
logger.exception("Callback error:")
|
||||
with self.lock:
|
||||
# Update outstanding requests
|
||||
if response.req_id in self.outstanding_requests:
|
||||
del self.outstanding_requests[response.req_id]
|
||||
# Acknowledge response
|
||||
self._acknowledge(response.req_id)
|
||||
else:
|
||||
with self.lock:
|
||||
self.ready_data[response.req_id] = response
|
||||
self.cv.notify_all()
|
||||
|
||||
def _process_rpc_error(self, e: grpc.RpcError):
|
||||
def _can_reconnect(self, e: grpc.RpcError) -> bool:
|
||||
"""
|
||||
Processes RPC errors that occur while reading from data stream.
|
||||
Returns True if the error can be recovered from, False otherwise.
|
||||
"""
|
||||
self._shutdown(e)
|
||||
if not self.client_worker._can_reconnect(e):
|
||||
logger.info("Unrecoverable error in data channel.")
|
||||
logger.debug(e)
|
||||
return False
|
||||
logger.debug("Recoverable error in data channel.")
|
||||
logger.debug(e)
|
||||
return True
|
||||
|
||||
if e.code() == grpc.StatusCode.CANCELLED:
|
||||
# Gracefully shutting down
|
||||
logger.info("Cancelling data channel")
|
||||
elif e.code() in (grpc.StatusCode.UNAVAILABLE,
|
||||
grpc.StatusCode.RESOURCE_EXHAUSTED):
|
||||
# TODO(barakmich): The server may have
|
||||
# dropped. In theory, we can retry, as per
|
||||
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
|
||||
# in practice we may need to think about the correct semantics
|
||||
# here.
|
||||
logger.info("Server disconnected from data channel")
|
||||
else:
|
||||
logger.exception("Got Error from data channel -- shutting down:")
|
||||
|
||||
def _shutdown(self, e: grpc.RpcError) -> None:
|
||||
def _shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the data channel
|
||||
"""
|
||||
with self.lock:
|
||||
self._in_shutdown = True
|
||||
self._last_exception = e
|
||||
self.cv.notify_all()
|
||||
|
||||
callbacks = self.asyncio_waiting_data.values()
|
||||
self.asyncio_waiting_data = {}
|
||||
|
||||
# Abort async requests with the error.
|
||||
err = ConnectionError("Failed during this or a previous request. "
|
||||
f"Exception that broke the connection: {e}")
|
||||
if self._last_exception:
|
||||
# Abort async requests with the error.
|
||||
err = ConnectionError(
|
||||
"Failed during this or a previous request. Exception that "
|
||||
f"broke the connection: {self._last_exception}")
|
||||
else:
|
||||
err = ConnectionError(
|
||||
"Request cannot be fulfilled because the data client has "
|
||||
"disconnected.")
|
||||
for callback in callbacks:
|
||||
callback(err)
|
||||
if callback:
|
||||
callback(err)
|
||||
# Since self._in_shutdown is set to True, no new item
|
||||
# will be added to self.asyncio_waiting_data
|
||||
|
||||
def _acknowledge(self, req_id: int) -> None:
|
||||
"""
|
||||
Puts an acknowledge request on the request queue periodically.
|
||||
Lock should be held before calling this. Used when an async or
|
||||
blocking response is received.
|
||||
"""
|
||||
if not self.client_worker._reconnect_enabled:
|
||||
# Skip ACKs if reconnect isn't enabled
|
||||
return
|
||||
assert self.lock.locked()
|
||||
self._acknowledge_counter += 1
|
||||
if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
|
||||
self.request_queue.put(
|
||||
ray_client_pb2.DataRequest(
|
||||
acknowledge=ray_client_pb2.AcknowledgeRequest(
|
||||
req_id=req_id)))
|
||||
|
||||
def _reconnect_channel(self) -> None:
|
||||
"""
|
||||
Attempts to reconnect the gRPC channel and resend outstanding
|
||||
requests. First, the server is pinged to see if the current channel
|
||||
still works. If the ping fails, then the current channel is closed
|
||||
and replaced with a new one.
|
||||
|
||||
Once a working channel is available, a new request queue is made
|
||||
and filled with any outstanding requests to be resent to the server.
|
||||
"""
|
||||
try:
|
||||
# Ping the server to see if the current channel is reusable, for
|
||||
# example if gRPC reconnected the channel on its own or if the
|
||||
# RPC error was transient and the channel is still open
|
||||
ping_succeeded = self.client_worker.ping_server(timeout=5)
|
||||
except grpc.RpcError:
|
||||
ping_succeeded = False
|
||||
|
||||
if not ping_succeeded:
|
||||
# Ping failed, try refreshing the data channel
|
||||
logger.warning(
|
||||
"Encountered connection issues in the data channel. "
|
||||
"Attempting to reconnect.")
|
||||
try:
|
||||
self.client_worker._connect_channel(reconnecting=True)
|
||||
except ConnectionError:
|
||||
logger.warning("Failed to reconnect the data channel")
|
||||
raise
|
||||
logger.info("Reconnection succeeded!")
|
||||
|
||||
# Recreate the request queue, and resend outstanding requests
|
||||
with self.lock:
|
||||
self.request_queue = queue.Queue()
|
||||
for request in self.outstanding_requests.values():
|
||||
# Resend outstanding requests
|
||||
self.request_queue.put(request)
|
||||
|
||||
def close(self) -> None:
|
||||
thread = None
|
||||
with self.lock:
|
||||
|
@ -157,6 +242,12 @@ class DataClient:
|
|||
self.cv.notify_all()
|
||||
# Add sentinel to terminate streaming RPC.
|
||||
if self.request_queue is not None:
|
||||
# Intentional shutdown, tell server it can clean up the
|
||||
# connection immediately and ignore the reconnect grace period.
|
||||
cleanup_request = ray_client_pb2.DataRequest(
|
||||
connection_cleanup=ray_client_pb2.ConnectionCleanupRequest(
|
||||
))
|
||||
self.request_queue.put(cleanup_request)
|
||||
self.request_queue.put(None)
|
||||
if self.data_thread is not None:
|
||||
thread = self.data_thread
|
||||
|
@ -171,6 +262,7 @@ class DataClient:
|
|||
req_id = self._next_id()
|
||||
req.req_id = req_id
|
||||
self.request_queue.put(req)
|
||||
self.outstanding_requests[req_id] = req
|
||||
|
||||
self.cv.wait_for(
|
||||
lambda: req_id in self.ready_data or self._in_shutdown)
|
||||
|
@ -178,6 +270,8 @@ class DataClient:
|
|||
|
||||
data = self.ready_data[req_id]
|
||||
del self.ready_data[req_id]
|
||||
del self.outstanding_requests[req_id]
|
||||
self._acknowledge(req_id)
|
||||
|
||||
return data
|
||||
|
||||
|
@ -188,8 +282,8 @@ class DataClient:
|
|||
self._check_shutdown()
|
||||
req_id = self._next_id()
|
||||
req.req_id = req_id
|
||||
if callback:
|
||||
self.asyncio_waiting_data[req_id] = callback
|
||||
self.asyncio_waiting_data[req_id] = callback
|
||||
self.outstanding_requests[req_id] = req
|
||||
self.request_queue.put(req)
|
||||
|
||||
# Must hold self.lock when calling this function.
|
||||
|
@ -210,11 +304,10 @@ class DataClient:
|
|||
|
||||
self.lock.acquire()
|
||||
|
||||
last_exception = getattr(self, "_last_exception", None)
|
||||
if last_exception is not None:
|
||||
if self._last_exception is not None:
|
||||
msg = ("Request can't be sent because the Ray client has already "
|
||||
"been disconnected due to an error. Last exception: "
|
||||
f"{last_exception}")
|
||||
f"{self._last_exception}")
|
||||
else:
|
||||
msg = ("Request can't be sent because the Ray client has already "
|
||||
"been disconnected.")
|
||||
|
|
|
@ -5,6 +5,7 @@ import sys
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import grpc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -12,6 +13,8 @@ from typing import TYPE_CHECKING
|
|||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
from ray.util.debug import log_once
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.client.worker import Worker
|
||||
|
||||
|
@ -35,42 +38,62 @@ class LogstreamClient:
|
|||
self.request_queue = queue.Queue()
|
||||
self.log_thread = self._start_logthread()
|
||||
self.log_thread.start()
|
||||
self.last_req = None
|
||||
|
||||
def _start_logthread(self) -> threading.Thread:
|
||||
return threading.Thread(target=self._log_main, args=(), daemon=True)
|
||||
|
||||
def _log_main(self) -> None:
|
||||
stub = ray_client_pb2_grpc.RayletLogStreamerStub(
|
||||
self.client_worker.channel)
|
||||
log_stream = stub.Logstream(
|
||||
iter(self.request_queue.get, None), metadata=self._metadata)
|
||||
try:
|
||||
for record in log_stream:
|
||||
if record.level < 0:
|
||||
self.stdstream(level=record.level, msg=record.msg)
|
||||
self.log(level=record.level, msg=record.msg)
|
||||
except grpc.RpcError as e:
|
||||
self._process_rpc_error(e)
|
||||
reconnecting = False
|
||||
while not self.client_worker._in_shutdown:
|
||||
if reconnecting:
|
||||
# Refresh queue and retry last request
|
||||
self.request_queue = queue.Queue()
|
||||
if self.last_req:
|
||||
self.request_queue.put(self.last_req)
|
||||
stub = ray_client_pb2_grpc.RayletLogStreamerStub(
|
||||
self.client_worker.channel)
|
||||
try:
|
||||
log_stream = stub.Logstream(
|
||||
iter(self.request_queue.get, None),
|
||||
metadata=self._metadata)
|
||||
except ValueError:
|
||||
# Trying to use the stub on a cancelled channel will raise
|
||||
# ValueError. This should only happen when the data client
|
||||
# is attempting to reset the connection -- sleep and try
|
||||
# again.
|
||||
time.sleep(.5)
|
||||
continue
|
||||
try:
|
||||
for record in log_stream:
|
||||
if record.level < 0:
|
||||
self.stdstream(level=record.level, msg=record.msg)
|
||||
self.log(level=record.level, msg=record.msg)
|
||||
return
|
||||
except grpc.RpcError as e:
|
||||
reconnecting = self._process_rpc_error(e)
|
||||
if not reconnecting:
|
||||
return
|
||||
|
||||
def _process_rpc_error(self, e: grpc.RpcError):
|
||||
def _process_rpc_error(self, e: grpc.RpcError) -> bool:
|
||||
"""
|
||||
Processes RPC errors that occur while reading from the log stream.
|
||||
Returns True if the error can be recovered from, False otherwise.
|
||||
"""
|
||||
if e.code() == grpc.StatusCode.CANCELLED:
|
||||
# Graceful shutdown. We've cancelled our own connection.
|
||||
logger.info("Cancelling logs channel")
|
||||
elif e.code() in (grpc.StatusCode.UNAVAILABLE,
|
||||
grpc.StatusCode.RESOURCE_EXHAUSTED):
|
||||
# TODO(barakmich): The server may have
|
||||
# dropped. In theory, we can retry, as per
|
||||
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
|
||||
# in practice we may need to think about the correct semantics
|
||||
# here.
|
||||
logger.info("Server disconnected from logs channel")
|
||||
else:
|
||||
# Some other, unhandled, gRPC error
|
||||
logger.exception(
|
||||
f"Got Error from logger channel -- shutting down: {e}")
|
||||
if self.client_worker._can_reconnect(e):
|
||||
if log_once("lost_reconnect_logs"):
|
||||
logger.warning(
|
||||
"Log channel is reconnecting. Logs produced while "
|
||||
"the connection was down can be found on the head "
|
||||
"node of the cluster in "
|
||||
"`ray_client_server_[port].out`")
|
||||
logger.info("Log channel dropped, retrying.")
|
||||
time.sleep(.5)
|
||||
return True
|
||||
logger.info("Shutting down log channel.")
|
||||
if not self.client_worker._in_shutdown:
|
||||
logger.exception("Unexpected exception:")
|
||||
return False
|
||||
|
||||
def log(self, level: int, msg: str):
|
||||
"""Log the message from the log stream.
|
||||
|
@ -99,6 +122,7 @@ class LogstreamClient:
|
|||
req.enabled = True
|
||||
req.loglevel = level
|
||||
self.request_queue.put(req)
|
||||
self.last_req = req
|
||||
|
||||
def close(self) -> None:
|
||||
self.request_queue.put(None)
|
||||
|
@ -109,3 +133,4 @@ class LogstreamClient:
|
|||
req = ray_client_pb2.LogSettingsRequest()
|
||||
req.enabled = False
|
||||
self.request_queue.put(req)
|
||||
self.last_req = req
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
from collections import defaultdict
|
||||
import threading
|
||||
import ray
|
||||
import logging
|
||||
import grpc
|
||||
from queue import Queue
|
||||
import sys
|
||||
|
||||
from typing import Any, Iterator, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Iterator, TYPE_CHECKING, Union
|
||||
from threading import Lock, Thread
|
||||
import time
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.util.client.common import CLIENT_SERVER_MAX_THREADS
|
||||
from ray.util.client.common import (CLIENT_SERVER_MAX_THREADS,
|
||||
_propagate_error_in_context,
|
||||
OrderedResponseCache)
|
||||
from ray.util.client import CURRENT_PROTOCOL_VERSION
|
||||
from ray.util.debug import log_once
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
@ -22,6 +27,37 @@ logger = logging.getLogger(__name__)
|
|||
QUEUE_JOIN_SECONDS = 10
|
||||
|
||||
|
||||
def _get_reconnecting_from_context(context: Any) -> bool:
|
||||
"""
|
||||
Get `reconnecting` from gRPC metadata, or False if missing.
|
||||
"""
|
||||
metadata = {k: v for k, v in context.invocation_metadata()}
|
||||
val = metadata.get("reconnecting")
|
||||
if val is None or val not in ("True", "False"):
|
||||
logger.error(
|
||||
f'Client connecting with invalid value for "reconnecting": {val}, '
|
||||
"This may be because you have a mismatched client and server "
|
||||
"version.")
|
||||
return False
|
||||
return val == "True"
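
For orientation, a minimal sketch (values are illustrative only) of the metadata this helper parses; the proxier further down in this diff attaches the same ("reconnecting", str(reconnecting)) pair when it opens a data stream:

# Client/proxy side: the reconnect flag travels as string-valued gRPC metadata.
metadata = [("client_id", "example-client-id"), ("reconnecting", str(True))]
# Server side: context.invocation_metadata() yields these pairs, so
# _get_reconnecting_from_context(context) would return True.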


def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
"""
Returns True if the response to the given request should be cached,
False otherwise. At the moment the only requests we do not cache are:
- asynchronous gets: These arrive out of order. Skipping caching here
is fine, since repeating an async get is idempotent
- acks: Repeating acks is idempotent
- clean up requests: Also idempotent, and client has likely already
wrapped up the data connection by this point.
"""
req_type = req.WhichOneof("type")
if req_type == "get" and req.get.asynchronous:
return False
return req_type not in ("acknowledge", "connection_cleanup")
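
A small hedged sketch of the distinction, using only messages that appear later in this diff (InitRequest and AcknowledgeRequest); whichever oneof field is set determines cacheability:

import ray.core.generated.ray_client_pb2 as ray_client_pb2

init_req = ray_client_pb2.DataRequest(
    req_id=1, init=ray_client_pb2.InitRequest())
ack_req = ray_client_pb2.DataRequest(
    req_id=2, acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=1))

assert _should_cache(init_req)      # normal requests are cached for replay
assert not _should_cache(ack_req)   # acks are idempotent, never cached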


def fill_queue(
grpc_input_generator: Iterator[ray_client_pb2.DataRequest],
output_queue:
@@ -46,15 +82,30 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
self.basic_service = basic_service
self.clients_lock = Lock()
self.num_clients = 0 # guarded by self.clients_lock
# dictionary mapping client_id's to the last time they connected
self.client_last_seen: Dict[str, float] = {}
# dictionary mapping client_id's to their reconnect grace periods
self.reconnect_grace_periods: Dict[str, float] = {}
# dictionary mapping client_id's to their response cache
self.response_caches: Dict[str, OrderedResponseCache] = defaultdict(
OrderedResponseCache)
# stopped event, useful for signaling that the server is shut down
self.stopped = threading.Event()

def Datapath(self, request_iterator, context):
start_time = time.time()
# set to True if client shuts down gracefully
cleanup_requested = False
metadata = {k: v for k, v in context.invocation_metadata()}
client_id = metadata.get("client_id")
if client_id is None:
logger.error("Client connecting with no client_id")
return
logger.debug(f"New data connection from client {client_id}: ")
accepted_connection = self._init(client_id, context)
accepted_connection = self._init(client_id, context, start_time)
response_cache = self.response_caches[client_id]
# Set to False if client requests a reconnect grace period of 0
reconnect_enabled = True
if not accepted_connection:
return
try:
@@ -76,11 +127,26 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
continue

assert isinstance(req, ray_client_pb2.DataRequest)
if _should_cache(req) and reconnect_enabled:
cached_resp = response_cache.check_cache(req.req_id)
if isinstance(cached_resp, Exception):
# Cache state is invalid, raise exception
raise cached_resp
if cached_resp is not None:
yield cached_resp
continue

resp = None
req_type = req.WhichOneof("type")
if req_type == "init":
resp_init = self.basic_service.Init(req.init)
resp = ray_client_pb2.DataResponse(init=resp_init, )
with self.clients_lock:
self.reconnect_grace_periods[client_id] = \
req.init.reconnect_grace_period
if req.init.reconnect_grace_period == 0:
reconnect_enabled = False

elif req_type == "get":
if req.get.asynchronous:
get_resp = self.basic_service._async_get_object(
@@ -115,23 +181,69 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
req.prep_runtime_env)
resp = ray_client_pb2.DataResponse(
prep_runtime_env=resp_prep)
elif req_type == "connection_cleanup":
cleanup_requested = True
cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
resp = ray_client_pb2.DataResponse(
connection_cleanup=cleanup_resp)
elif req_type == "acknowledge":
# Clean up acknowledged cache entries
response_cache.cleanup(req.acknowledge.req_id)
continue
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
resp.req_id = req.req_id
if _should_cache(req) and reconnect_enabled:
response_cache.update_cache(req.req_id, resp)
yield resp
except grpc.RpcError as e:
logger.debug(f"Closing data channel: {e}")
except Exception as e:
logger.exception("Error in data channel:")
recoverable = _propagate_error_in_context(e, context)
invalid_cache = response_cache.invalidate(e)
if not recoverable or invalid_cache:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
# Connection isn't recoverable, skip cleanup
cleanup_requested = True
finally:
logger.debug(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
queue_filler_thread.join(QUEUE_JOIN_SECONDS)
if queue_filler_thread.is_alive():
logger.error(
"Queue filler thread failed to join before timeout: {}".
format(QUEUE_JOIN_SECONDS))
cleanup_delay = self.reconnect_grace_periods.get(client_id)
if not cleanup_requested and cleanup_delay is not None:
logger.debug("Cleanup wasn't requested, delaying cleanup by "
f"{cleanup_delay} seconds.")
# Delay cleanup, since client may attempt a reconnect
# Wait on the "stopped" event in case the grpc server is
# stopped and we can clean up earlier.
self.stopped.wait(timeout=cleanup_delay)
else:
logger.debug("Cleanup was requested, cleaning up immediately.")
with self.clients_lock:
# Could fail before client accounting happens
if client_id not in self.client_last_seen:
logger.debug("Connection already cleaned up.")
# Some other connection has already cleaned up this
# client's session. This can happen if the client
# reconnects and then gracefully shuts down immediately.
return
last_seen = self.client_last_seen[client_id]
if last_seen > start_time:
# The client successfully reconnected and updated
# last seen some time during the grace period
logger.debug("Client reconnected, skipping cleanup")
return
# Either the client shut down gracefully, or the client
# failed to reconnect within the grace period. Clean up
# the connection.
self.basic_service.release_all(client_id)
del self.client_last_seen[client_id]
if client_id in self.reconnect_grace_periods:
del self.reconnect_grace_periods[client_id]
if client_id in self.response_caches:
del self.response_caches[client_id]
self.num_clients -= 1
logger.debug(f"Removed clients. {self.num_clients}")

@@ -142,12 +254,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
logger.debug("Shutting down ray.")
ray.shutdown()

def _init(self, client_id: str, context: Any):
def _init(self, client_id: str, context: Any, start_time: float):
"""
Checks if resources allow for another client.
Returns a boolean indicating if initialization was successful.
"""
with self.clients_lock:
reconnecting = _get_reconnecting_from_context(context)
threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
if self.num_clients >= threshold:
logger.warning(
@@ -162,10 +275,21 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
return False
self.num_clients += 1
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")

if reconnecting and client_id not in self.client_last_seen:
# Client took too long to reconnect, session has been
# cleaned up.
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(
"Attempted to reconnect to a session that has already "
"been cleaned up.")
return False
if client_id in self.client_last_seen:
logger.debug(f"Client {client_id} has reconnected.")
else:
self.num_clients += 1
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
self.client_last_seen[client_id] = start_time
return True

def _build_connection_response(self):

@@ -7,7 +7,7 @@ from itertools import chain
import json
import socket
import sys
from threading import Lock, Thread, RLock
from threading import Event, Lock, Thread, RLock
import time
import traceback
from typing import Callable, Dict, List, Optional, Tuple
@@ -21,9 +21,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2
import ray.core.generated.runtime_env_agent_pb2_grpc as runtime_env_agent_pb2_grpc # noqa: E501
from ray.util.client.common import (_get_client_id_from_context,
ClientServerHandle,
CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
from ray.util.client.common import (
_get_client_id_from_context, ClientServerHandle, CLIENT_SERVER_MAX_THREADS,
GRPC_OPTIONS, _propagate_error_in_context)
from ray.util.client.server.dataservicer import _get_reconnecting_from_context
from ray._private.client_mode_hook import disable_client_hook
from ray._private.parameter import RayParams
from ray._private.runtime_env import RuntimeEnvContext
@@ -384,10 +385,14 @@ class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):

stub = ray_client_pb2_grpc.RayletDriverStub(chan)
try:
return getattr(stub, method)(
request, metadata=[("client_id", client_id)])
except Exception:
metadata = [("client_id", client_id)]
if context:
metadata = context.invocation_metadata()
return getattr(stub, method)(request, metadata=metadata)
except Exception as e:
# Error while proxying -- propagate the error's context to user
logger.exception(f"Proxying call to {method} failed!")
_propagate_error_in_context(e, context)

def _has_channel_for_request(self, context):
client_id = _get_client_id_from_context(context)
@@ -531,7 +536,8 @@ def prepare_runtime_init_req(init_request: ray_client_pb2.DataRequest
new_job_config = ray_client_server_env_prep(job_config)
modified_init_req = ray_client_pb2.InitRequest(
job_config=pickle.dumps(new_job_config),
ray_init_kwargs=init_request.init.ray_init_kwargs)
ray_init_kwargs=init_request.init.ray_init_kwargs,
reconnect_grace_period=init_request.init.reconnect_grace_period)

init_request.init.CopyFrom(modified_init_req)
return (init_request, new_job_config)
@@ -540,8 +546,12 @@ def prepare_runtime_init_req(init_request: ray_client_pb2.DataRequest
class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, proxy_manager: ProxyManager):
self.num_clients = 0
# dictionary mapping client_id's to the last time they connected
self.clients_last_seen: Dict[str, float] = {}
self.reconnect_grace_periods: Dict[str, float] = {}
self.clients_lock = Lock()
self.proxy_manager = proxy_manager
self.stopped = Event()

def modify_connection_info_resp(self,
init_resp: ray_client_pb2.DataResponse
@@ -560,61 +570,117 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
return modified_resp

def Datapath(self, request_iterator, context):
cleanup_requested = False
start_time = time.time()
client_id = _get_client_id_from_context(context)
if client_id == "":
return
reconnecting = _get_reconnecting_from_context(context)

# Create Placeholder *before* reading the first request.
server = self.proxy_manager.create_specific_server(client_id)
try:
if reconnecting:
with self.clients_lock:
if client_id not in self.clients_last_seen:
# Client took too long to reconnect, session has already
# been cleaned up
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(
"Attempted to reconnect a session that has already "
"been cleaned up")
return
self.clients_last_seen[client_id] = start_time
server = self.proxy_manager._get_server_for_client(client_id)
channel = self.proxy_manager.get_channel(client_id)
# iterator doesn't need modification on reconnect
new_iter = request_iterator
else:
# Create Placeholder *before* reading the first request.
server = self.proxy_manager.create_specific_server(client_id)
with self.clients_lock:
self.clients_last_seen[client_id] = start_time
self.num_clients += 1

logger.info(f"New data connection from client {client_id}: ")
init_req = next(request_iterator)
try:
modified_init_req, job_config = prepare_runtime_init_req(
init_req)
if not self.proxy_manager.start_specific_server(
client_id, job_config):
logger.error(
f"Server startup failed for client: {client_id}, "
f"using JobConfig: {job_config}!")
raise RuntimeError(
"Starting Ray client server failed. See "
f"ray_client_server_{server.port}.err for detailed "
"logs.")
channel = self.proxy_manager.get_channel(client_id)
if channel is None:
logger.error(f"Channel not found for {client_id}")
raise RuntimeError(
"Proxy failed to Connect to backend! Check "
"`ray_client_server.err` and "
f"`ray_client_server_{server.port}.err` on the head "
"node of the cluster for the relevant logs. "
"By default these are located at "
"/tmp/ray/session_latest/logs.")
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
except Exception:
init_resp = ray_client_pb2.DataResponse(
init=ray_client_pb2.InitResponse(
ok=False, msg=traceback.format_exc()))
init_resp.req_id = init_req.req_id
yield init_resp
return None
try:
if not reconnecting:
logger.info(f"New data connection from client {client_id}: ")
init_req = next(request_iterator)
with self.clients_lock:
self.reconnect_grace_periods[client_id] = \
init_req.init.reconnect_grace_period
try:
modified_init_req, job_config = prepare_runtime_init_req(
init_req)
if not self.proxy_manager.start_specific_server(
client_id, job_config):
logger.error(
f"Server startup failed for client: {client_id}, "
f"using JobConfig: {job_config}!")
raise RuntimeError(
"Starting Ray client server failed. See "
f"ray_client_server_{server.port}.err for "
"detailed logs.")
channel = self.proxy_manager.get_channel(client_id)
if channel is None:
logger.error(f"Channel not found for {client_id}")
raise RuntimeError(
"Proxy failed to Connect to backend! Check "
"`ray_client_server.err` and "
f"`ray_client_server_{server.port}.err` on the "
"head node of the cluster for the relevant logs. "
"By default these are located at "
"/tmp/ray/session_latest/logs.")
except Exception:
init_resp = ray_client_pb2.DataResponse(
init=ray_client_pb2.InitResponse(
ok=False, msg=traceback.format_exc()))
init_resp.req_id = init_req.req_id
yield init_resp
return None

new_iter = chain([modified_init_req], request_iterator)
resp_stream = stub.Datapath(
new_iter, metadata=[("client_id", client_id)])
new_iter = chain([modified_init_req], request_iterator)

stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
metadata = [("client_id", client_id), ("reconnecting",
str(reconnecting))]
resp_stream = stub.Datapath(new_iter, metadata=metadata)
for resp in resp_stream:
resp_type = resp.WhichOneof("type")
if resp_type == "connection_cleanup":
# Specific server is skipping cleanup, proxier should too
cleanup_requested = True
yield self.modify_connection_info_resp(resp)
except Exception:
except Exception as e:
logger.exception("Proxying Datapath failed!")
# Propagate error through context
recoverable = _propagate_error_in_context(e, context)
if not recoverable:
# Client shouldn't attempt to recover, clean up connection
cleanup_requested = True
finally:
server.set_result(None)
cleanup_delay = self.reconnect_grace_periods.get(client_id)
if not cleanup_requested and cleanup_delay is not None:
# Delay cleanup, since client may attempt a reconnect
# Wait on stopped event in case the server closes and we
# can clean up earlier
self.stopped.wait(timeout=cleanup_delay)
with self.clients_lock:
if client_id not in self.clients_last_seen:
logger.info(f"{client_id} not found. Skipping clean up.")
# Connection has already been cleaned up
return
last_seen = self.clients_last_seen[client_id]
logger.info(
f"{client_id} last started stream at {last_seen}. Current "
f"stream started at {start_time}.")
if last_seen > start_time:
logger.info("Client reconnected. Skipping cleanup.")
# Client has reconnected, don't clean up
return
logger.debug(f"Client detached: {client_id}")
self.num_clients -= 1
del self.clients_last_seen[client_id]
if client_id in self.reconnect_grace_periods:
del self.reconnect_grace_periods[client_id]
server.set_result(None)


class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
@@ -641,7 +707,10 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)

if channel is None:
context.set_code(grpc.StatusCode.UNAVAILABLE)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(
"Logstream proxy failed to connect. Channel for client "
f"{client_id} not found.")
return None

stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)

@@ -3,6 +3,7 @@ from concurrent import futures
import grpc
import base64
from collections import defaultdict
import functools
import queue
import pickle

@@ -22,7 +23,7 @@ import time
import inspect
import json
from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS,
CLIENT_SERVER_MAX_THREADS)
CLIENT_SERVER_MAX_THREADS, ResponseCache)
from ray.util.client.server.proxier import serve_proxier
from ray.util.client.server.server_pickler import convert_from_arg
from ray.util.client.server.server_pickler import dumps_from_server
@@ -40,6 +41,53 @@ TIMEOUT_FOR_SPECIFIC_SERVER_S = env_integer("TIMEOUT_FOR_SPECIFIC_SERVER_S",
30)


def _use_response_cache(func):
"""
Decorator for gRPC stubs. Before calling the real stub, checks if there's
an existing entry in the cache. If there is, returns the cached
entry. Otherwise, calls the real function and caches the result.
"""

@functools.wraps(func)
def wrapper(self, request, context):
metadata = {k: v for k, v in context.invocation_metadata()}
expected_ids = ("client_id", "thread_id", "req_id")
if any(i not in metadata for i in expected_ids):
# Missing IDs, skip caching and call underlying stub directly
return func(self, request, context)

# Get relevant IDs to check cache
client_id = metadata["client_id"]
thread_id = metadata["thread_id"]
req_id = int(metadata["req_id"])

# Check if response already cached
response_cache = self.response_caches[client_id]
cached_entry = response_cache.check_cache(thread_id, req_id)
if cached_entry is not None:
if isinstance(cached_entry, Exception):
# Original call errored, propagate error
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details(str(cached_entry))
raise cached_entry
return cached_entry

try:
# Response wasn't cached, call underlying stub and cache result
resp = func(self, request, context)
except Exception as e:
# Unexpected error in underlying stub -- update cache and
# propagate to user through context
response_cache.update_cache(thread_id, req_id, e)
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details(str(e))
raise
response_cache.update_cache(thread_id, req_id, resp)
return resp

return wrapper
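
A toy sketch of the dedup behavior this decorator gives the unary servicer methods below. The fake context and toy servicer are illustrative only, and it assumes ResponseCache.check_cache returns None on a first call and the value previously stored via update_cache for a repeated (thread_id, req_id) pair:

from collections import defaultdict
from ray.util.client.common import ResponseCache

class _FakeContext:
    def __init__(self, metadata):
        self._metadata = metadata
    def invocation_metadata(self):
        return list(self._metadata.items())
    def set_code(self, code):
        pass
    def set_details(self, details):
        pass

class _ToyServicer:
    def __init__(self):
        self.response_caches = defaultdict(ResponseCache)
        self.calls = 0

    @_use_response_cache
    def Increment(self, request, context):
        self.calls += 1
        return self.calls

ctx = _FakeContext({"client_id": "c1", "thread_id": "t1", "req_id": "1"})
servicer = _ToyServicer()
assert servicer.Increment(None, ctx) == 1
# A replayed request with the same req_id (e.g. after a client reconnect)
# is answered from the cache instead of re-running the mutation:
assert servicer.Increment(None, ctx) == 1
assert servicer.calls == 1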


class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, ray_connect_handler: Callable):
"""Construct a raylet service
@@ -60,6 +108,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.named_actors = set()
self.state_lock = threading.Lock()
self.ray_connect_handler = ray_connect_handler
self.response_caches: Dict[str, ResponseCache] = defaultdict(
ResponseCache)

def Init(self, request: ray_client_pb2.InitRequest,
context=None) -> ray_client_pb2.InitResponse:
@@ -105,6 +155,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
f"current one {current_job_config.runtime_env.uris}")
return ray_client_pb2.InitResponse(ok=True)

@_use_response_cache
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
with disable_client_hook():
already_exists = ray.experimental.internal_kv._internal_kv_put(
@@ -116,6 +167,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
value = ray.experimental.internal_kv._internal_kv_get(request.key)
return ray_client_pb2.KVGetResponse(value=value)

@_use_response_cache
def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
with disable_client_hook():
ray.experimental.internal_kv._internal_kv_del(request.key)
@@ -235,6 +287,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
del self.object_refs[client_id]
if client_id in self.client_side_ref_map:
del self.client_side_ref_map[client_id]
if client_id in self.response_caches:
del self.response_caches[client_id]
logger.debug(f"Released all {count} objects for client {client_id}")

def _release_actors(self, client_id):
@@ -252,6 +306,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):

logger.debug(f"Released all {count} actors for client: {client_id}")

@_use_response_cache
def Terminate(self, req, context=None):
if req.WhichOneof("terminate_type") == "task_object":
try:
@@ -428,6 +483,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)

@_use_response_cache
def Schedule(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
logger.debug(

@@ -5,7 +5,9 @@ to the server.
import base64
import json
import logging
import os
import time
import threading
import uuid
import warnings
from collections import defaultdict
@@ -21,11 +23,13 @@ from ray.cloudpickle.compat import pickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.exceptions import GetTimeoutError
from ray.ray_constants import DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
from ray.util.client.client_pickler import (convert_to_arg, dumps_from_client,
loads_from_server)
from ray.util.client.common import (ClientActorClass, ClientActorHandle,
ClientActorRef, ClientObjectRef,
ClientRemoteFunc, ClientStub, GRPC_OPTIONS)
ClientRemoteFunc, ClientStub, GRPC_OPTIONS,
GRPC_UNRECOVERABLE_ERRORS, INT32_MAX)
from ray.util.client.dataclient import DataClient
from ray.util.client.logsclient import LogstreamClient
from ray.util.debug import log_once
@@ -108,7 +112,22 @@ class Worker:
else:
self._credentials = None

self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
if "RAY_CLIENT_RECONNECT_GRACE_PERIOD" in os.environ:
# Use value in environment variable if available
self._reconnect_grace_period = \
int(os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"])
# Disable retries if grace period is set to 0
self._reconnect_enabled = self._reconnect_grace_period != 0

# Set to True when the connection cannot be recovered and reconnect
# attempts should be stopped
self._in_shutdown = False
# Set to True after initial connection succeeds
self._has_connected = False

self._connect_channel()
self._has_connected = True

# Initialize the streams to finish protocol negotiation.
self.data_client = DataClient(self, self._client_id, self.metadata)
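
For reference, a hedged usage sketch of the new environment variable (the address below is a placeholder): 0 disables retries and response caching entirely, while a positive value is the number of seconds the client and the server-side session wait for a reconnect:

import os
os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"] = "30"  # or "0" to disable

import ray
ray.init("ray://<head-node-ip>:10001")  # placeholder address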
@@ -124,10 +143,19 @@ class Worker:
self.total_num_tasks_scheduled = 0
self.total_outbound_message_size_bytes = 0

def _connect_channel(self) -> None:
# Used to create unique IDs for RPCs to the RayletServicer
self._req_id_lock = threading.Lock()
self._req_id = 0

def _connect_channel(self, reconnecting=False) -> None:
"""
Attempts to connect to the server specified by conn_str.
Attempts to connect to the server specified by conn_str. If
reconnecting after an RPC error, cleans up the old channel and
continues to attempt to connect until the grace period is over.
"""
if self.channel is not None:
self.channel.unsubscribe(self._on_channel_state_change)
self.channel.close()

if self._secure:
if self._credentials is not None:
@@ -144,11 +172,21 @@ class Worker:

# Retry the connection until the channel responds to something
# looking like a gRPC connection, though it may be a proxy.
start_time = time.time()
conn_attempts = 0
timeout = INITIAL_TIMEOUT_SEC
service_ready = False
while conn_attempts < max(self._connection_retries, 1):
while conn_attempts < max(self._connection_retries, 1) or reconnecting:
conn_attempts += 1
if self._in_shutdown:
# User manually closed the worker before connection finished
break
elapsed_time = time.time() - start_time
if reconnecting and elapsed_time > self._reconnect_grace_period:
self._in_shutdown = True
raise ConnectionError(
"Failed to reconnect within the reconnection grace period "
f"({self._reconnect_grace_period}s)")
try:
# Let gRPC wait for us to see if the channel becomes ready.
# If it throws, we couldn't connect.
@@ -176,12 +214,17 @@ class Worker:
# Fallthrough, backoff, and retry at the top of the loop
logger.info("Waiting for Ray to become ready on the server, "
f"retry in {timeout}s...")
timeout = backoff(timeout)
if not reconnecting:
# Don't increase backoff when trying to reconnect --
# we already know the server exists, attempt to reconnect
# as soon as we can
timeout = backoff(timeout)

# If we made it through the loop without service_ready
# it means we've used up our retries and
# should error back to the user.
if not service_ready:
self._in_shutdown = True
if log_once("ray_client_security_groups"):
warnings.warn(
"Ray Client connection timed out. Ensure that "
@@ -191,6 +234,73 @@ class Worker:
"more information.")
raise ConnectionError("ray client connection timeout")

def _can_reconnect(self, e: grpc.RpcError) -> bool:
"""
Returns True if the RPC error can be recovered from and a retry is
appropriate, False otherwise.
"""
if not self._reconnect_enabled:
return False
if self._in_shutdown:
# Channel is being shut down, don't try to reconnect
return False
if e.code() in GRPC_UNRECOVERABLE_ERRORS:
# Unrecoverable error -- These errors are specifically raised
# by the server's application logic
return False
if e.code() == grpc.StatusCode.INTERNAL:
details = e.details()
if details == "Exception serializing request!":
# The client tried to send a bad request (for example,
# passing "None" instead of a valid grpc message). Don't
# try to reconnect/retry.
return False
# All other errors can be treated as recoverable
return True

def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
"""
Calls the stub specified by stub_name (Schedule, WaitObject, etc...).
If a recoverable error occurs while calling the stub, attempts to
retry the RPC.
"""
while not self._in_shutdown:
try:
return getattr(self.server, stub_name)(*args, **kwargs)
except grpc.RpcError as e:
if self._can_reconnect(e):
time.sleep(.5)
continue
raise
except ValueError:
# Trying to use the stub on a cancelled channel will raise
# ValueError. This should only happen when the data client
# is attempting to reset the connection -- sleep and try
# again.
time.sleep(.5)
continue
raise ConnectionError("Client is shutting down.")

def _add_ids_to_metadata(self, metadata: Any):
"""
Adds a unique req_id and the current thread's identifier to the
metadata. These values are useful for preventing mutating operations
from being replayed on the server side in the event that the client
must retry a request.
Args:
metadata - the gRPC metadata to append the IDs to
"""
if not self._reconnect_enabled:
# IDs not needed if reconnects are disabled
return metadata
thread_id = str(threading.get_ident())
with self._req_id_lock:
self._req_id += 1
if self._req_id > INT32_MAX:
self._req_id = 1
req_id = str(self._req_id)
return metadata + [("thread_id", thread_id), ("req_id", req_id)]

def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
logger.debug(f"client gRPC channel state change: {conn_state}")
self._conn_state = conn_state
@@ -324,7 +434,7 @@ class Worker:
"client_id": self._client_id,
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req, metadata=self.metadata)
resp = self._call_stub("WaitObject", req, metadata=self.metadata)
if not resp.valid:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
@@ -350,11 +460,11 @@ class Worker:
self, task: ray_client_pb2.ClientTask) -> List[bytes]:
logger.debug("Scheduling %s" % task)
task.client_id = self._client_id
metadata = self._add_ids_to_metadata(self.metadata)
try:
ticket = self.server.Schedule(task, metadata=self.metadata)
ticket = self._call_stub("Schedule", task, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)

if not ticket.valid:
try:
raise cloudpickle.loads(ticket.error)
@@ -412,6 +522,7 @@ class Worker:
self.reference_count[id] += 1

def close(self):
self._in_shutdown = True
self.data_client.close()
self.log_client.close()
if self.channel:
@@ -441,7 +552,8 @@ class Worker:
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
term.client_id = self._client_id
self.server.Terminate(term, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("Terminate", term, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)

@@ -458,14 +570,18 @@ class Worker:
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
term.client_id = self._client_id
self.server.Terminate(term, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("Terminate", term, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)

def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
def get_cluster_info(self,
type: ray_client_pb2.ClusterInfoType.TypeEnum,
timeout: Optional[float] = None):
req = ray_client_pb2.ClusterInfoRequest()
req.type = type
resp = self.server.ClusterInfo(req, metadata=self.metadata)
resp = self.server.ClusterInfo(
req, timeout=timeout, metadata=self.metadata)
if resp.WhichOneof("response_type") == "resource_table":
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}
@@ -476,35 +592,37 @@ class Worker:

def internal_kv_get(self, key: bytes) -> bytes:
req = ray_client_pb2.KVGetRequest(key=key)
resp = self.server.KVGet(req, metadata=self.metadata)
resp = self._call_stub("KVGet", req, metadata=self.metadata)
return resp.value

def internal_kv_exists(self, key: bytes) -> bytes:
req = ray_client_pb2.KVGetRequest(key=key)
resp = self.server.KVGet(req, metadata=self.metadata)
resp = self._call_stub("KVGet", req, metadata=self.metadata)
return resp.value

def internal_kv_put(self, key: bytes, value: bytes,
overwrite: bool) -> bool:
req = ray_client_pb2.KVPutRequest(
key=key, value=value, overwrite=overwrite)
resp = self.server.KVPut(req, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
resp = self._call_stub("KVPut", req, metadata=metadata)
return resp.already_exists

def internal_kv_del(self, key: bytes) -> None:
req = ray_client_pb2.KVDelRequest(key=key)
self.server.KVDel(req, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("KVDel", req, metadata=metadata)

def internal_kv_list(self, prefix: bytes) -> bytes:
req = ray_client_pb2.KVListRequest(prefix=prefix)
return self.server.KVList(req, metadata=self.metadata).keys
return self._call_stub("KVList", req, metadata=self.metadata).keys

def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
req = ray_client_pb2.ClientListNamedActorsRequest(
all_namespaces=all_namespaces)
return json.loads(
self.server.ListNamedActors(req,
metadata=self.metadata).actors_json)
self._call_stub("ListNamedActors", req,
metadata=self.metadata).actors_json)

def is_initialized(self) -> bool:
if self.server is not None:
@@ -512,7 +630,7 @@ class Worker:
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False

def ping_server(self) -> bool:
def ping_server(self, timeout=None) -> bool:
"""Simple health check.

Piggybacks the IS_INITIALIZED call to check if the server provides
@@ -520,12 +638,13 @@ class Worker:
"""
if self.server is not None:
logger.debug("Pinging server.")
result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING)
result = self.get_cluster_info(
ray_client_pb2.ClusterInfoType.PING, timeout=timeout)
return result is not None
return False

def is_connected(self) -> bool:
return self._conn_state == grpc.ChannelConnectivity.READY
return not self._in_shutdown and self._has_connected

def _server_init(self,
job_config: JobConfig,
@@ -549,7 +668,8 @@ class Worker:
response = self.data_client.Init(
ray_client_pb2.InitRequest(
job_config=serialized_job_config,
ray_init_kwargs=json.dumps(ray_init_kwargs)))
ray_init_kwargs=json.dumps(ray_init_kwargs),
reconnect_grace_period=self._reconnect_grace_period))
if not response.ok:
raise ConnectionAbortedError(
f"Initialization failure from server:\n{response.msg}")

@@ -259,6 +259,7 @@ message InitRequest {
// job_config of ray.init
bytes job_config = 1;
string ray_init_kwargs = 2;
int32 reconnect_grace_period = 3;
}

message InitResponse {
@@ -339,6 +340,20 @@ message ConnectionInfoResponse {
string protocol_version = 5;
}

message ConnectionCleanupRequest {
// Explicitly request that connection is cleaned up for graceful shutdown
}

message ConnectionCleanupResponse {
// Acknowledge cleanup request
}

message AcknowledgeRequest {
// Used to acknowledge that all requests up to the given req_id have been
// received
int32 req_id = 1;
}
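
A hedged sketch of how the data client can use the new message (the actual batching and ack logic lives in the data client, which is not part of this excerpt):

import ray.core.generated.ray_client_pb2 as ray_client_pb2

# Tell the server that every response up to req_id 42 has been received, so
# the corresponding entries can be dropped from its response cache.
ack = ray_client_pb2.DataRequest(
    acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=42))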

message DataRequest {
// An incrementing counter of request IDs on the Datapath,
// to match requests with responses asynchronously.
@@ -350,6 +365,8 @@ message DataRequest {
ConnectionInfoRequest connection_info = 5;
InitRequest init = 6;
PrepRuntimeEnvRequest prep_runtime_env = 7;
ConnectionCleanupRequest connection_cleanup = 8;
AcknowledgeRequest acknowledge = 9;
}
}

@@ -363,6 +380,7 @@ message DataResponse {
ConnectionInfoResponse connection_info = 5;
InitResponse init = 6;
PrepRuntimeEnvResponse prep_runtime_env = 7;
ConnectionCleanupResponse connection_cleanup = 8;
}
}