[client] let ray client reconnect on grpc failures (#18329)

* wip

* client tests working again

* extra prints

* start reconnect logic for proxier

* local proxy more wip

* delay cleanup logic working on proxy

* Fix up dataservicer logic

* lint + fix proxy data servicer exit logic

* hmmm

* delay cleanup always in dataservicer

* fix last_seen check

* cancel channel on error

* explicitly request cleanup

* cleanup request fixes

* fix dataclient proxy

* start idempotence logic

* change default channel state

* add backoff logic

* move connection logic back into worker.__init__

* add logic for replay cache case where request was received but response hasn't been fully resolved

* new proto entries for data stream caching

* start replay_cache logic, increase cleanup delay

* hardcode retries

* Let data channel attempt reconnects

* manually reset queue, remove replay_cache logic

* reduce cleanup delay to 5 minutes

* fix local tests

* Remove async cache logic

* retry async requests

* simplify backoff logic

* Fix ray client proto

* Configurable reconnect grace period

* Basic logsclient fix?

* Configure grace through environment variable

* Use stopped event to force faster datapath cleanup

* Better connect+reconnect logic

* fix reconnect_grace_period default

* init fixes for reconnect_grace_period

* cleanup

* fix _get_client_id_from_context call

* add logic for pathological cache cases

* less intrusive data channel error message

* fix tests

* Make stuff less painful to read

* add ordered replay cache for dataservicer, replay cache tests

* fix ordering import, start_reconnect test

* add middleman testing logic

* enforce ordering of dataclient requests

* retry wheels

* grace period through env only, restore test_dataclient_disconnect

* minor fixes

* force rerun

* less intrusive error msgs

* address review

* replay->response cache

* remove unneeded sleep

* typing

* extra response cache test

* fix error msg

* remove TODO

* add _reconnect_channel

* add grace period test

* store thread_id and req_id in metadata

* Revert "store thread_id and req_id in metadata"

This reverts commit 12bc05cc0ceb0b764e2279353ba003fca16c3181.

* Revert "Revert "store thread_id and req_id in metadata""

This reverts commit 67874cf3a207fed49e6070c7e955a640f0094d19.

* fix metadata check

* remove comment

* removed unused cv

* cast back to int

* refactor Datapath for readability

* Revert refactor

This reverts commit f789bad473c953eebabefe7eb6aa891e5b8a8f13.

* fix comment

* merge fixes

* refactor _shutdown

* address reviews

* log errors in both cases

* add comments

* address reviews

* move reconnect test to medium

* Always propagate error to callbacks

* readability

* formatting

* Faster cleanup on uncaught dataservicer errors

* delete tmp file

* offset commit

* refactor

* propagate data servicer error message

* Stricter handling/propagation of errors

* remove tmp file

* better docs

* forward reconnecting metadata

* add annotation

* fix invalidate + add test

* fix docstrings and types

* disable retries and caching if reconnect grace period is set to 0

* update comments

* address review, increase ack batch size and skip ack's if reconnect isn't enabled

* Don't terminate data stream on missing reconnecting metadata
Chris K. W 2021-09-17 15:11:00 -07:00 committed by GitHub
parent ffe7108eae
commit 8858489e2f
14 changed files with 1571 additions and 156 deletions


@ -67,6 +67,9 @@ DEFAULT_ACTOR_CREATION_CPU_SPECIFIED = 1
# Default number of return values for each actor method.
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
# Wait 30 seconds for client to reconnect after unexpected disconnection
DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD = 30
# If a remote function or actor (or some other export) has serialized size
# greater than this quantity, print a warning.
FUNCTION_SIZE_WARN_THRESHOLD = 10**7
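
As an aside, a minimal sketch of how a client can override the 30-second default above through the RAY_CLIENT_RECONNECT_GRACE_PERIOD environment variable wired up later in this diff (the cluster address is a placeholder; per this change, a value of 0 disables retries and response caching entirely):

    import os
    import ray

    # Allow up to 60 seconds of reconnect attempts after an unexpected
    # disconnect instead of the 30-second default.
    os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"] = "60"
    ray.init("ray://localhost:10001")  # placeholder cluster address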


@ -49,12 +49,13 @@ py_test_module_list(
"test_client.py",
"test_client_builder.py",
"test_client_init.py",
"test_client_multi.py",
"test_client_multi.py",
"test_client_proxy.py",
"test_client_server.py",
"test_client_references.py",
"test_client_warnings.py",
"test_client_library_integration.py",
"test_client_reconnect.py",
],
size = "medium",
extra_srcs = SRCS,
@ -138,6 +139,7 @@ py_test_module_list(
"test_dataclient_disconnect.py",
"test_k8s_operator_unit_tests.py",
"test_monitor.py",
"test_response_cache.py",
],
size = "small",
extra_srcs = SRCS,


@ -0,0 +1,413 @@
from concurrent import futures
import contextlib
import os
import threading
import sys
from ray.util.client.common import CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS
import grpc
import time
import random
import pytest
from typing import Any, Callable, Optional
from unittest.mock import patch
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import ray.util.client.server.server as ray_client_server
import ray
# At a high level, these tests rely on an extra RPC server sitting
# between the client and the real Ray server to inject errors, drop responses
# and drop requests:
# Ray Client <-> Middleman Server <-> Proxy Server
# Type for middleman hooks used to inject errors
Hook = Callable[[Any], None]
class MiddlemanDataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
"""
Forwards all requests to the real data servicer. Useful for injecting
errors between a client and server pair.
"""
def __init__(self, on_response: Optional[Hook] = None):
"""
Args:
on_response: Optional hook to inject errors before sending back a
response
"""
self.stub = None
self.on_response = on_response
def set_channel(self, channel: grpc.Channel) -> None:
self.stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
def Datapath(self, request_iterator, context):
try:
for response in self.stub.Datapath(
request_iterator, metadata=context.invocation_metadata()):
if self.on_response:
self.on_response(response)
yield response
except grpc.RpcError as e:
context.set_code(e.code())
context.set_details(e.details())
class MiddlemanLogServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
"""
Forwards all requests to the real log servicer. Useful for injecting
errors between a client and server pair.
"""
def __init__(self, on_response: Optional[Hook] = None):
"""
Args:
on_response: Optional hook to inject errors before sending back a
response
"""
self.stub = None
self.on_response = on_response
def set_channel(self, channel: grpc.Channel) -> None:
self.stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
def Logstream(self, request_iterator, context):
try:
for response in self.stub.Logstream(
request_iterator, metadata=context.invocation_metadata()):
if self.on_response:
self.on_response(response)
yield response
except grpc.RpcError as e:
context.set_code(e.code())
context.set_details(e.details())
class MiddlemanRayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
"""
Forwards all requests to the raylet driver servicer. Useful for injecting
errors between a client and server pair.
"""
def __init__(self,
on_request: Optional[Hook] = None,
on_response: Optional[Hook] = None):
"""
Args:
on_request: Optional hook to inject errors before forwarding a
request
on_response: Optional hook to inject errors before sending back a
response
"""
self.stub = None
self.on_request = on_request
self.on_response = on_response
def set_channel(self, channel: grpc.Channel) -> None:
self.stub = ray_client_pb2_grpc.RayletDriverStub(channel)
def _call_inner_function(
self, request: Any, context,
method: str) -> Any:
if self.on_request:
self.on_request(request)
try:
response = getattr(self.stub, method)(
request, metadata=context.invocation_metadata())
except grpc.RpcError as e:
context.set_code(e.code())
context.set_details(e.details())
raise
if self.on_response:
self.on_response(response)
return response
def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
return self._call_inner_function(request, context, "Init")
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
return self._call_inner_function(request, context, "KVPut")
def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
return self._call_inner_function(request, context, "KVGet")
def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
return self._call_inner_function(request, context, "KVDel")
def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
return self._call_inner_function(request, context, "KVList")
def KVExists(self, request,
context=None) -> ray_client_pb2.KVExistsResponse:
return self._call_inner_function(request, context, "KVExists")
def ListNamedActors(self, request, context=None
) -> ray_client_pb2.ClientListNamedActorsResponse:
return self._call_inner_function(request, context, "ListNamedActors")
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
return self._call_inner_function(request, context, "ClusterInfo")
def Terminate(self, req, context=None):
return self._call_inner_function(req, context, "Terminate")
def GetObject(self, request, context=None):
return self._call_inner_function(request, context, "GetObject")
def PutObject(self, request: ray_client_pb2.PutRequest,
context=None) -> ray_client_pb2.PutResponse:
return self._call_inner_function(request, context, "PutObject")
def WaitObject(self, request: ray_client_pb2.WaitRequest,
context=None) -> ray_client_pb2.WaitResponse:
return self._call_inner_function(request, context, "WaitObject")
def Schedule(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
return self._call_inner_function(task, context, "Schedule")
class MiddlemanServer:
"""
Helper class that wraps the RPC server that middlemans the connection
between the client and the real ray server. Useful for injecting
errors between a client and server pair.
"""
def __init__(self,
listen_addr: str,
real_addr,
on_log_response: Optional[Hook] = None,
on_data_response: Optional[Hook] = None,
on_task_request: Optional[Hook] = None,
on_task_response: Optional[Hook] = None):
"""
Args:
listen_addr: The address the middleman server will listen on
real_addr: The address of the real ray server
on_log_response: Optional hook to inject errors before sending back
a log response
on_data_response: Optional hook to inject errors before sending
back a data response
on_task_request: Optional hook to inject errors before forwarding
a raylet driver request
on_task_response: Optional hook to inject errors before sending
back a raylet driver response
"""
self.listen_addr = listen_addr
self.real_addr = real_addr
self.server = grpc.server(
futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
options=GRPC_OPTIONS)
self.task_servicer = MiddlemanRayletServicer(
on_response=on_task_response, on_request=on_task_request)
self.data_servicer = MiddlemanDataServicer(
on_response=on_data_response)
self.logs_servicer = MiddlemanLogServicer(on_response=on_log_response)
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
self.task_servicer, self.server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
self.data_servicer, self.server)
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
self.logs_servicer, self.server)
self.server.add_insecure_port(self.listen_addr)
self.channel = None
self.reset_channel()
def reset_channel(self) -> None:
"""
Manually close and reopen the channel to the real ray server. This
simulates a disconnection between the client and the server.
"""
if self.channel:
self.channel.close()
self.channel = grpc.insecure_channel(
self.real_addr, options=GRPC_OPTIONS)
grpc.channel_ready_future(self.channel)
self.task_servicer.set_channel(self.channel)
self.data_servicer.set_channel(self.channel)
self.logs_servicer.set_channel(self.channel)
def start(self) -> None:
self.server.start()
def stop(self, grace: int) -> None:
self.server.stop(grace)
@contextlib.contextmanager
def start_middleman_server(on_log_response=None,
on_data_response=None,
on_task_request=None,
on_task_response=None):
"""
Helper context that starts a middleman server listening on port 10011,
and a ray client server on port 50051.
"""
ray._inside_client_test = True
server = ray_client_server.serve("localhost:50051")
middleman = None
try:
middleman = MiddlemanServer(
listen_addr="localhost:10011",
real_addr="localhost:50051",
on_log_response=on_log_response,
on_data_response=on_data_response,
on_task_request=on_task_request,
on_task_response=on_task_response)
middleman.start()
ray.init("ray://localhost:10011")
yield middleman, server
finally:
ray._inside_client_test = False
ray.util.disconnect()
server.stop(0)
if middleman:
middleman.stop(0)
def test_disconnect_during_get():
"""
Disconnect the proxy and the client in the middle of a long running get
"""
@ray.remote
def slow_result():
time.sleep(20)
return 12345
def disconnect(middleman):
time.sleep(3)
middleman.reset_channel()
with start_middleman_server() as (middleman, _):
disconnect_thread = threading.Thread(
target=disconnect, args=(middleman, ))
disconnect_thread.start()
result = ray.get(slow_result.remote())
assert result == 12345
disconnect_thread.join()
def test_valid_actor_state():
"""
Repeatedly inject errors in the middle of mutating actor calls. Check
at the end that the final state of the actor is consistent with what
we would expect had the disconnects not occurred.
"""
@ray.remote
class IncrActor:
def __init__(self):
self.val = 0
def incr(self):
self.val += 1
return self.val
i = 0
def fail_every_seven(_):
# Inject an error every seventh time this method is called
nonlocal i
i += 1
if i % 7 == 0:
raise RuntimeError
with start_middleman_server(
on_data_response=fail_every_seven,
on_task_request=fail_every_seven,
on_task_response=fail_every_seven):
actor = IncrActor.remote()
for _ in range(100):
ref = actor.incr.remote()
assert ray.get(ref) == 100
def test_valid_actor_state_2():
"""
Do a full disconnect (cancel channel) every 11 requests. Failure
happens:
- before request sent: request never reaches server
- before response received: response never reaches client
- while gets are being processed
"""
@ray.remote
class IncrActor:
def __init__(self):
self.val = 0
def incr(self):
self.val += 1
return self.val
i = 0
with start_middleman_server() as (middleman, _):
def fail_every_eleven(_):
nonlocal i
i += 1
if i % 11 == 0:
middleman.reset_channel()
middleman.data_servicer.on_response = fail_every_eleven
middleman.task_servicer.on_request = fail_every_eleven
middleman.task_servicer.on_response = fail_every_eleven
actor = IncrActor.remote()
for _ in range(100):
ref = actor.incr.remote()
assert ray.get(ref) == 100
@pytest.mark.skipif(sys.platform == "win32", reason="Flaky on windows")
def test_noisy_puts():
"""
Randomly kills the data channel with 10% chance when receiving response
(requests made it to server, responses dropped) and checks that final
result is still consistent
"""
random.seed(12345)
with start_middleman_server() as (middleman, _):
def fail_randomly(response: ray_client_pb2.DataResponse):
if random.random() < 0.1:
raise RuntimeError
middleman.data_servicer.on_response = fail_randomly
refs = [ray.put(i * 123) for i in range(500)]
results = ray.get(refs)
for i, result in enumerate(results):
assert result == i * 123
def test_client_reconnect_grace_period():
"""
Tests that the client gives up attempting to reconnect the channel
after the grace period expires.
"""
# Lower grace period to 5 seconds to save time
with patch.dict(os.environ, {"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "5"}), \
start_middleman_server() as (middleman, _):
assert ray.get(ray.put(42)) == 42
# Close channel
middleman.channel.close()
start_time = time.time()
with pytest.raises(ConnectionError):
ray.get(ray.put(42))
# Connection error should have been raised within a reasonable
# amount of time. Set to significantly higher than 5 seconds
# to account for reconnect backoff timing
assert time.time() - start_time < 20
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))


@ -1,10 +1,16 @@
from ray.util.client.ray_client_helpers import ray_start_client_server
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
import os
import time
def test_dataclient_disconnect_on_request():
with ray_start_client_server() as ray:
# Client can't signal graceful shutdown to server after unrecoverable
# error. Lower grace period so we don't have to sleep as long before
# checking new connection data.
with patch.dict(os.environ, {"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "5"}), \
ray_start_client_server() as ray:
assert ray.is_connected()
@ray.remote
@ -20,13 +26,18 @@ def test_dataclient_disconnect_on_request():
assert not ray.is_connected()
# Test that a new connection can be made
time.sleep(5) # Give server time to clean up old connection
connection_data = ray.connect("localhost:50051")
assert connection_data["num_clients"] == 1
assert ray.get(f.remote()) == 42
def test_dataclient_disconnect_before_request():
with ray_start_client_server() as ray:
# Client can't signal graceful shutdown to server after unrecoverable
# error. Lower grace period so we don't have to sleep as long before
# checking new connection data.
with patch.dict(os.environ, {"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "5"}), \
ray_start_client_server() as ray:
assert ray.is_connected()
@ray.remote
@ -48,6 +59,7 @@ def test_dataclient_disconnect_before_request():
assert not ray.is_connected()
# Test that a new connection can be made
time.sleep(5) # Give server time to clean up old connection
connection_data = ray.connect("localhost:50051")
assert connection_data["num_clients"] == 1
assert ray.get(f.remote()) == 42


@ -0,0 +1,218 @@
from ray.util.client.common import (_id_is_newer, ResponseCache,
OrderedResponseCache, INT32_MAX)
import threading
import time
import pytest
def test_id_is_newer():
"""
Sanity checks the logic for ID is newer. In general, we would expect
that higher IDs are newer than lower IDs, for example 25 can be assumed
to be newer than 24.
Since IDs roll over at INT32_MAX (~2**31), we should check for weird
behavior there. In particular, we would expect an ID like `11` to be
newer than the ID `2**31` since it's likely that the counter rolled
over.
"""
# Common cases -- higher IDs normally considered newer
assert _id_is_newer(30, 29)
assert _id_is_newer(12345, 12344)
assert not _id_is_newer(12344, 12345)
assert not _id_is_newer(5678, 5678)
# Check behavior near max int boundary
assert _id_is_newer(INT32_MAX, INT32_MAX - 1)
assert _id_is_newer(INT32_MAX - 1, INT32_MAX - 2)
# Low IDs are assumed newer than higher ones if it looks like rollover has
# occurred
assert _id_is_newer(0, INT32_MAX - 4)
assert _id_is_newer(1001, INT32_MAX - 123)
assert not _id_is_newer(INT32_MAX, 123)
def test_response_cache_complete_response():
"""
Test basic check/update logic of cache, and that nothing blocks
"""
cache = ResponseCache()
cache.check_cache(123, 15) # shouldn't block
cache.update_cache(123, 15, "abcdef")
assert cache.check_cache(123, 15) == "abcdef"
def test_ordered_response_cache_complete_response():
"""
Test basic check/update logic of ordered cache, and that nothing blocks
"""
cache = OrderedResponseCache()
cache.check_cache(15) # shouldn't block
cache.update_cache(15, "vwxyz")
assert cache.check_cache(15) == "vwxyz"
def test_response_cache_incomplete_response():
"""
Tests case where a cache entry is populated after a long time. Any new
threads attempting to access that entry should sleep until the response
is ready.
"""
cache = ResponseCache()
def populate_cache():
time.sleep(2)
cache.update_cache(123, 15, "abcdef")
cache.check_cache(123, 15) # shouldn't block
t = threading.Thread(target=populate_cache, args=())
t.start()
# Should block until other thread populates cache
assert cache.check_cache(123, 15) == "abcdef"
t.join()
def test_ordered_response_cache_incomplete_response():
"""
Tests case where an ordered cache entry is populated after a long time. Any
new threads attempting to access that entry should sleep until the response
is ready.
"""
cache = OrderedResponseCache()
def populate_cache():
time.sleep(2)
cache.update_cache(15, "vwxyz")
cache.check_cache(15) # shouldn't block
t = threading.Thread(target=populate_cache, args=())
t.start()
# Should block until other thread populates cache
assert cache.check_cache(15) == "vwxyz"
t.join()
def test_ordered_response_cache_cleanup():
"""
Tests that the cleanup method of ordered cache works as expected, in
particular that all entries <= the passed ID are cleared from the cache.
"""
cache = OrderedResponseCache()
for i in range(1, 21):
assert cache.check_cache(i) is None
cache.update_cache(i, str(i))
assert len(cache.cache) == 20
for i in range(1, 21):
assert cache.check_cache(i) == str(i)
# Expected: clean up all entries up to and including entry 10
cache.cleanup(10)
assert len(cache.cache) == 10
with pytest.raises(RuntimeError):
# Attempting to access value that has already been cleaned up
cache.check_cache(10)
for i in range(21, 31):
# Check that more entries can be inserted
assert cache.check_cache(i) is None
cache.update_cache(i, str(i))
# Cleanup everything
cache.cleanup(30)
assert len(cache.cache) == 0
with pytest.raises(RuntimeError):
cache.check_cache(30)
# Cleanup requests received out of order are tolerated
cache.cleanup(27)
cache.cleanup(23)
def test_response_cache_update_while_waiting():
"""
Tests that an error is thrown when a cache entry is updated with the
response for a different request than what was originally being
checked for.
"""
# Error when awaiting cache to update, but entry is overwritten by a newer request
cache = ResponseCache()
assert cache.check_cache(16, 123) is None
def cleanup_cache():
time.sleep(2)
cache.check_cache(16, 124)
cache.update_cache(16, 124, "asdf")
t = threading.Thread(target=cleanup_cache, args=())
t.start()
with pytest.raises(RuntimeError):
cache.check_cache(16, 123)
t.join()
def test_ordered_response_cache_cleanup_while_waiting():
"""
Tests that an error is thrown when an ordered cache entry is cleaned up
while another thread is still waiting for its response.
"""
# Error when awaiting cache to update, but entry is cleaned up
cache = OrderedResponseCache()
assert cache.check_cache(123) is None
def cleanup_cache():
time.sleep(2)
cache.cleanup(123)
t = threading.Thread(target=cleanup_cache, args=())
t.start()
with pytest.raises(RuntimeError):
cache.check_cache(123)
t.join()
def test_response_cache_cleanup():
"""
Checks that the response cache replaces old entries for a given thread
with new entries as they come in, instead of creating new entries
(possibly wasting memory on unneeded entries)
"""
# Check that the response cache cleans up previous entries for a given
# thread properly.
cache = ResponseCache()
cache.check_cache(16, 123)
cache.update_cache(16, 123, "Some response")
assert len(cache.cache) == 1
cache.check_cache(16, 124)
cache.update_cache(16, 124, "Second response")
assert len(cache.cache) == 1 # Should reuse entry for thread 16
assert cache.check_cache(16, 124) == "Second response"
def test_response_cache_invalidate():
"""
Check that ordered response cache invalidate works as expected
"""
cache = OrderedResponseCache()
e = RuntimeError("SomeError")
# No pending entries, cache should be valid
assert not cache.invalidate(e)
# No entry for 123 yet
assert cache.check_cache(123) is None
# this should invalidate the entry for 123
assert cache.invalidate(e)
assert cache.check_cache(123) == e
assert cache.invalidate(e)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2021-08-26"
CURRENT_PROTOCOL_VERSION = "2021-09-02"
class _ClientContext:


@ -13,10 +13,12 @@ from ray.util.inspect import is_cython, is_function_or_method
import json
import logging
import threading
from collections import OrderedDict
from typing import Any
from typing import List
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
logger = logging.getLogger(__name__)
@ -25,6 +27,21 @@ logger = logging.getLogger(__name__)
# number of simultaneous in-flight requests.
INT32_MAX = (2**31) - 1
# gRPC status codes that the client shouldn't attempt to recover from
# Resource exhausted: Server is low on resources, or has hit the max number
# of client connections
# Invalid argument: Reserved for application errors
# Not found: Set if the client is attempting to reconnect to a session that
# does not exist
# Failed precondition: Reserved for application errors
# Aborted: Set when an error is serialized into the details of the context,
# signals that error should be deserialized on the client side
GRPC_UNRECOVERABLE_ERRORS = (grpc.StatusCode.RESOURCE_EXHAUSTED,
grpc.StatusCode.INVALID_ARGUMENT,
grpc.StatusCode.NOT_FOUND,
grpc.StatusCode.FAILED_PRECONDITION,
grpc.StatusCode.ABORTED)
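As a rough illustration of how this tuple is meant to be consulted (the client-side check actually lives on the Worker as _can_reconnect, whose body is not part of this diff), a sketch:

    def can_reconnect(e: grpc.RpcError) -> bool:
        # Sketch only: status codes in the unrecoverable set abort the
        # connection; anything else is treated as transient, and the client
        # may keep retrying within its reconnect grace period.
        return e.code() not in GRPC_UNRECOVERABLE_ERRORS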
# TODO: Instead of just making the max message size large, the right thing to
# do is to split up the bytes representation of serialized data into multiple
# messages and reconstruct them on either end. That said, since clients are
@ -385,6 +402,13 @@ class ClientServerHandle:
logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
grpc_server: grpc.Server
def stop(self, grace: int) -> None:
# The data servicer might be sleeping while waiting for clients to
# reconnect. Signal that they no longer have to sleep and can exit
# immediately, since the RPC server is stopped.
self.grpc_server.stop(grace)
self.data_servicer.stopped.set()
# Add a hook for all the cases that previously
# expected simply a gRPC server
def __getattr__(self, attr):
@ -402,3 +426,241 @@ def _get_client_id_from_context(context: Any) -> str:
logger.error("Client connecting with no client_id")
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
return client_id
def _propagate_error_in_context(e: Exception, context: Any) -> bool:
"""
Encode an error into the context of an RPC response. Returns True
if the error can be recovered from, false otherwise
"""
try:
if isinstance(e, grpc.RpcError):
# RPC error, propagate directly by copying details into context
context.set_code(e.code())
context.set_details(e.details())
return e.code() not in GRPC_UNRECOVERABLE_ERRORS
except Exception:
# Extra precaution -- if encoding the RPC directly fails, fall back
# to treating it as a regular error
pass
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details(str(e))
return False
def _id_is_newer(id1: int, id2: int) -> bool:
"""
We should only replace cache entries with the responses for newer IDs.
Most of the time newer IDs will be the ones with higher value, except when
the req_id counter rolls over. We check for this case by checking the
distance between the two IDs. If the distance is significant, then it's
likely that the req_id counter rolled over, and the smaller id should
still be used to replace the one in cache.
"""
diff = abs(id2 - id1)
if diff > (INT32_MAX // 2):
# Rollover likely occurred. In this case the smaller ID is newer
return id1 < id2
return id1 > id2
class ResponseCache:
"""
Cache for blocking method calls. Needed to prevent retried requests from
being applied multiple times on the server, for example when the client
disconnects. This is used to cache requests/responses sent through
unary-unary RPCs to the RayletServicer.
Note that no clean up logic is used, the last response for each thread
will always be remembered, so at most the cache will hold N entries,
where N is the number of threads on the client side. This relies on the
assumption that a thread will not make a new blocking request until it has
received a response for a previous one, at which point it's safe to
overwrite the old response.
The high level logic is:
1. Before making a call, check the cache for the current thread.
2. If present in the cache, check the request id of the cached
response.
a. If it matches the current request_id, then the request has been
received before and we shouldn't re-attempt the logic. Wait for
the response to become available in the cache, and then return it
b. If it doesn't match, then this is a new request and we can
proceed with calling the real stub. While the response is still
being generated, temporarily keep (req_id, None) in the cache.
Once the call is finished, update the cache entry with the
new (req_id, response) pair. Notify other threads that may
have been waiting for the response to be prepared.
"""
def __init__(self):
self.cv = threading.Condition()
self.cache: Dict[int, Tuple[int, Any]] = {}
def check_cache(self, thread_id: int, request_id: int) -> Optional[Any]:
"""
Check the cache for a given thread, and see if the entry in the cache
matches the current request_id. Returns None if the request_id has
not been seen yet, otherwise returns the cached result.
Throws an error if the placeholder in the cache doesn't match the
request_id -- this means that a new request evicted the old value in
the cache, and that the RPC for `request_id` is redundant and the
result can be discarded, i.e.:
1. Request A is sent (A1)
2. Channel disconnects
3. Request A is resent (A2)
4. A1 is received
5. A2 is received, waits for A1 to finish
6. A1 finishes and is sent back to client
7. Request B is sent
8. Request B overwrites cache entry
9. A2 wakes up extremely late, but cache is now invalid
In practice this is VERY unlikely to happen, but the error can at
least serve as a sanity check or catch invalid request id's.
"""
with self.cv:
if thread_id in self.cache:
cached_request_id, cached_resp = self.cache[thread_id]
if cached_request_id == request_id:
while cached_resp is None:
# The call was started, but the response hasn't yet
# been added to the cache. Let go of the lock and
# wait until the response is ready.
self.cv.wait()
cached_request_id, cached_resp = self.cache[thread_id]
if cached_request_id != request_id:
raise RuntimeError(
"Cached response doesn't match the id of the "
"original request. This might happen if this "
"request was received out of order. The "
"result of the caller is no longer needed. "
f"({request_id} != {cached_request_id})")
return cached_resp
if not _id_is_newer(request_id, cached_request_id):
raise RuntimeError(
"Attempting to replace newer cache entry with older "
"one. This might happen if this request was received "
"out of order. The result of the caller is no "
f"longer needed. ({request_id} != {cached_request_id}")
self.cache[thread_id] = (request_id, None)
return None
def update_cache(self, thread_id: int, request_id: int,
response: Any) -> None:
"""
Inserts `response` into the cache for `request_id`.
"""
with self.cv:
cached_request_id, cached_resp = self.cache[thread_id]
if cached_request_id != request_id or cached_resp is not None:
# The cache was overwritten by a newer requester between
# our call to check_cache and our call to update it.
# This shouldn't happen as long as the cached requests are
# all blocking on the client side, so if you encounter this,
# check if any async requests are being cached.
raise RuntimeError(
"Attempting to update the cache, but the placeholder "
"does not match the current request_id. This might happen "
"if this request was received out of order. The result "
f"of the caller is no longer needed. ({request_id} != "
f"{cached_request_id})")
self.cache[thread_id] = (request_id, response)
self.cv.notify_all()
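To make the check/update flow described above concrete, a minimal sketch of how a unary handler might wrap its work with this cache (the handler and helper names are hypothetical; thread_id and req_id are assumed to arrive with the request, e.g. via gRPC metadata):

    def handle_with_cache(servicer, request, thread_id: int, req_id: int):
        # Replay the stored response if this exact request was already
        # executed before a disconnect; otherwise run it once and cache it.
        cached = servicer.response_cache.check_cache(thread_id, req_id)
        if cached is not None:
            return cached
        response = servicer.do_real_work(request)  # hypothetical real handler
        servicer.response_cache.update_cache(thread_id, req_id, response)
        return response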
class OrderedResponseCache:
"""
Cache for streaming RPCs, i.e. the DataServicer. Relies on explicit
ack's from the client to determine when it can clean up cache entries.
"""
def __init__(self):
self.last_received = 0
self.cv = threading.Condition()
self.cache: Dict[int, Any] = OrderedDict()
def check_cache(self, req_id: int) -> Optional[Any]:
"""
Check the cache for a response to the given req_id. Returns None if the
req_id has not been seen yet, otherwise returns the cached result.
"""
with self.cv:
if _id_is_newer(self.last_received,
req_id) or self.last_received == req_id:
# Request is for an id that has already been cleared from
# cache/acknowledged.
raise RuntimeError(
"Attempting to accesss a cache entry that has already "
"cleaned up. The client has already acknowledged "
f"receiving this response. ({req_id}, "
f"{self.last_received})")
if req_id in self.cache:
cached_resp = self.cache[req_id]
while cached_resp is None:
# The call was started, but the response hasn't yet been
# added to the cache. Let go of the lock and wait until
# the response is ready
self.cv.wait()
if req_id not in self.cache:
raise RuntimeError(
"Cache entry was removed. This likely means that "
"the result of this call is no longer needed.")
cached_resp = self.cache[req_id]
return cached_resp
self.cache[req_id] = None
return None
def update_cache(self, req_id: int, resp: Any) -> None:
"""
Inserts `resp` into the cache for `req_id`.
"""
with self.cv:
self.cv.notify_all()
if req_id not in self.cache:
raise RuntimeError(
"Attempting to update the cache, but placeholder is "
"missing. This might happen on a redundant call to "
f"update_cache. ({req_id})")
self.cache[req_id] = resp
def invalidate(self, e: Exception) -> bool:
"""
Invalidate any partially populated cache entries, replacing their
placeholders with the passed in exception. Useful to prevent a thread
from waiting indefinitely on a failed call.
Returns True if the cache contains an error, False otherwise
"""
with self.cv:
invalid = False
for req_id in self.cache:
if self.cache[req_id] is None:
self.cache[req_id] = e
if isinstance(self.cache[req_id], Exception):
invalid = True
self.cv.notify_all()
return invalid
def cleanup(self, last_received: int) -> None:
"""
Cleanup all of the cached requests up to last_received. Assumes that
the cache entries were inserted in ascending order.
"""
with self.cv:
if _id_is_newer(last_received, self.last_received):
self.last_received = last_received
to_remove = []
for req_id in self.cache:
if _id_is_newer(last_received,
req_id) or last_received == req_id:
to_remove.append(req_id)
else:
break
for req_id in to_remove:
del self.cache[req_id]
self.cv.notify_all()
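
A small, self-contained sketch of the ack-driven lifecycle this class supports, mirroring how the DataServicer below uses it (the handle function is a hypothetical stand-in for the real per-request work):

    from ray.util.client.common import OrderedResponseCache

    cache = OrderedResponseCache()

    def handle(req_id: int) -> str:
        # Hypothetical stand-in for the real per-request work.
        return f"response-{req_id}"

    # Replay the cached response when a req_id shows up again (a retry after
    # a disconnect); otherwise do the work once and remember the result.
    for req_id in (1, 2, 3, 2):  # "2" arrives twice, simulating a retry
        resp = cache.check_cache(req_id)
        if resp is None:
            resp = handle(req_id)
            cache.update_cache(req_id, resp)

    # Once the client acknowledges everything up to req_id 3, the cached
    # entries at or below 3 are no longer needed and can be dropped.
    cache.cleanup(3)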


@ -6,6 +6,7 @@ import queue
import threading
import grpc
from collections import OrderedDict
from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union
import ray.core.generated.ray_client_pb2 as ray_client_pb2
@ -20,6 +21,9 @@ logger = logging.getLogger(__name__)
ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]],
None]
# Send an acknowledge on every 32nd response received
ACKNOWLEDGE_BATCH_SIZE = 32
class DataClient:
def __init__(self, client_worker: "Worker", client_id: str,
@ -36,9 +40,13 @@ class DataClient:
self._metadata = metadata
self.data_thread = self._start_datathread()
# Track outstanding requests to resend in case of disconnection
self.outstanding_requests: Dict[int, Any] = OrderedDict()
# Serialize access to all mutable internal states: self.request_queue,
# self.ready_data, self.asyncio_waiting_data,
# self._in_shutdown, self._req_id and calling self._next_id().
# self._in_shutdown, self._req_id, self.outstanding_requests and
# calling self._next_id()
self.lock = threading.Lock()
# Waiting for response or shutdown.
@ -51,6 +59,8 @@ class DataClient:
self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
self._in_shutdown = False
self._req_id = 0
self._last_exception = None
self._acknowledge_counter = 0
self.data_thread.start()
@ -73,17 +83,32 @@ class DataClient:
daemon=True)
def _data_main(self) -> None:
stub = ray_client_pb2_grpc.RayletDataStreamerStub(
self.client_worker.channel)
resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
metadata=self._metadata,
wait_for_ready=True)
reconnecting = False
try:
for response in resp_stream:
self._process_response(response)
except grpc.RpcError as e:
self._process_rpc_error(e)
while not self.client_worker._in_shutdown:
stub = ray_client_pb2_grpc.RayletDataStreamerStub(
self.client_worker.channel)
metadata = self._metadata + \
[("reconnecting", str(reconnecting))]
resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
metadata=metadata,
wait_for_ready=True)
try:
for response in resp_stream:
self._process_response(response)
return
except grpc.RpcError as e:
reconnecting = self._can_reconnect(e)
if not reconnecting:
self._last_exception = e
return
self._reconnect_channel()
except Exception as e:
self._last_exception = e
finally:
logger.info("Shutting down data channel")
self._shutdown()
def _process_response(self, response: Any) -> None:
"""
@ -101,54 +126,114 @@ class DataClient:
# is accessed without holding self.lock. Holding the
# lock shouldn't be necessary either.
callback = self.asyncio_waiting_data.pop(response.req_id)
callback(response)
if callback:
callback(response)
except Exception:
logger.exception("Callback error:")
with self.lock:
# Update outstanding requests
if response.req_id in self.outstanding_requests:
del self.outstanding_requests[response.req_id]
# Acknowledge response
self._acknowledge(response.req_id)
else:
with self.lock:
self.ready_data[response.req_id] = response
self.cv.notify_all()
def _process_rpc_error(self, e: grpc.RpcError):
def _can_reconnect(self, e: grpc.RpcError) -> bool:
"""
Processes RPC errors that occur while reading from data stream.
Returns True if the error can be recovered from, False otherwise.
"""
self._shutdown(e)
if not self.client_worker._can_reconnect(e):
logger.info("Unrecoverable error in data channel.")
logger.debug(e)
return False
logger.debug("Recoverable error in data channel.")
logger.debug(e)
return True
if e.code() == grpc.StatusCode.CANCELLED:
# Gracefully shutting down
logger.info("Cancelling data channel")
elif e.code() in (grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.RESOURCE_EXHAUSTED):
# TODO(barakmich): The server may have
# dropped. In theory, we can retry, as per
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
# in practice we may need to think about the correct semantics
# here.
logger.info("Server disconnected from data channel")
else:
logger.exception("Got Error from data channel -- shutting down:")
def _shutdown(self, e: grpc.RpcError) -> None:
def _shutdown(self) -> None:
"""
Shutdown the data channel
"""
with self.lock:
self._in_shutdown = True
self._last_exception = e
self.cv.notify_all()
callbacks = self.asyncio_waiting_data.values()
self.asyncio_waiting_data = {}
# Abort async requests with the error.
err = ConnectionError("Failed during this or a previous request. "
f"Exception that broke the connection: {e}")
if self._last_exception:
# Abort async requests with the error.
err = ConnectionError(
"Failed during this or a previous request. Exception that "
f"broke the connection: {self._last_exception}")
else:
err = ConnectionError(
"Request cannot be fulfilled because the data client has "
"disconnected.")
for callback in callbacks:
callback(err)
if callback:
callback(err)
# Since self._in_shutdown is set to True, no new item
# will be added to self.asyncio_waiting_data
def _acknowledge(self, req_id: int) -> None:
"""
Puts an acknowledge request on the request queue periodically.
Lock should be held before calling this. Used when an async or
blocking response is received.
"""
if not self.client_worker._reconnect_enabled:
# Skip ACKs if reconnect isn't enabled
return
assert self.lock.locked()
self._acknowledge_counter += 1
if self._acknowledge_counter % ACKNOWLEDGE_BATCH_SIZE == 0:
self.request_queue.put(
ray_client_pb2.DataRequest(
acknowledge=ray_client_pb2.AcknowledgeRequest(
req_id=req_id)))
def _reconnect_channel(self) -> None:
"""
Attempts to reconnect the gRPC channel and resend outstanding
requests. First, the server is pinged to see if the current channel
still works. If the ping fails, then the current channel is closed
and replaced with a new one.
Once a working channel is available, a new request queue is made
and filled with any outstanding requests to be resent to the server.
"""
try:
# Ping the server to see if the current channel is reusable, for
# example if gRPC reconnected the channel on its own or if the
# RPC error was transient and the channel is still open
ping_succeeded = self.client_worker.ping_server(timeout=5)
except grpc.RpcError:
ping_succeeded = False
if not ping_succeeded:
# Ping failed, try refreshing the data channel
logger.warning(
"Encountered connection issues in the data channel. "
"Attempting to reconnect.")
try:
self.client_worker._connect_channel(reconnecting=True)
except ConnectionError:
logger.warning("Failed to reconnect the data channel")
raise
logger.info("Reconnection succeeded!")
# Recreate the request queue, and resend outstanding requests
with self.lock:
self.request_queue = queue.Queue()
for request in self.outstanding_requests.values():
# Resend outstanding requests
self.request_queue.put(request)
def close(self) -> None:
thread = None
with self.lock:
@ -157,6 +242,12 @@ class DataClient:
self.cv.notify_all()
# Add sentinel to terminate streaming RPC.
if self.request_queue is not None:
# Intentional shutdown, tell server it can clean up the
# connection immediately and ignore the reconnect grace period.
cleanup_request = ray_client_pb2.DataRequest(
connection_cleanup=ray_client_pb2.ConnectionCleanupRequest(
))
self.request_queue.put(cleanup_request)
self.request_queue.put(None)
if self.data_thread is not None:
thread = self.data_thread
@ -171,6 +262,7 @@ class DataClient:
req_id = self._next_id()
req.req_id = req_id
self.request_queue.put(req)
self.outstanding_requests[req_id] = req
self.cv.wait_for(
lambda: req_id in self.ready_data or self._in_shutdown)
@ -178,6 +270,8 @@ class DataClient:
data = self.ready_data[req_id]
del self.ready_data[req_id]
del self.outstanding_requests[req_id]
self._acknowledge(req_id)
return data
@ -188,8 +282,8 @@ class DataClient:
self._check_shutdown()
req_id = self._next_id()
req.req_id = req_id
if callback:
self.asyncio_waiting_data[req_id] = callback
self.asyncio_waiting_data[req_id] = callback
self.outstanding_requests[req_id] = req
self.request_queue.put(req)
# Must hold self.lock when calling this function.
@ -210,11 +304,10 @@ class DataClient:
self.lock.acquire()
last_exception = getattr(self, "_last_exception", None)
if last_exception is not None:
if self._last_exception is not None:
msg = ("Request can't be sent because the Ray client has already "
"been disconnected due to an error. Last exception: "
f"{last_exception}")
f"{self._last_exception}")
else:
msg = ("Request can't be sent because the Ray client has already "
"been disconnected.")


@ -5,6 +5,7 @@ import sys
import logging
import queue
import threading
import time
import grpc
from typing import TYPE_CHECKING
@ -12,6 +13,8 @@ from typing import TYPE_CHECKING
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.debug import log_once
if TYPE_CHECKING:
from ray.util.client.worker import Worker
@ -35,42 +38,62 @@ class LogstreamClient:
self.request_queue = queue.Queue()
self.log_thread = self._start_logthread()
self.log_thread.start()
self.last_req = None
def _start_logthread(self) -> threading.Thread:
return threading.Thread(target=self._log_main, args=(), daemon=True)
def _log_main(self) -> None:
stub = ray_client_pb2_grpc.RayletLogStreamerStub(
self.client_worker.channel)
log_stream = stub.Logstream(
iter(self.request_queue.get, None), metadata=self._metadata)
try:
for record in log_stream:
if record.level < 0:
self.stdstream(level=record.level, msg=record.msg)
self.log(level=record.level, msg=record.msg)
except grpc.RpcError as e:
self._process_rpc_error(e)
reconnecting = False
while not self.client_worker._in_shutdown:
if reconnecting:
# Refresh queue and retry last request
self.request_queue = queue.Queue()
if self.last_req:
self.request_queue.put(self.last_req)
stub = ray_client_pb2_grpc.RayletLogStreamerStub(
self.client_worker.channel)
try:
log_stream = stub.Logstream(
iter(self.request_queue.get, None),
metadata=self._metadata)
except ValueError:
# Trying to use the stub on a cancelled channel will raise
# ValueError. This should only happen when the data client
# is attempting to reset the connection -- sleep and try
# again.
time.sleep(.5)
continue
try:
for record in log_stream:
if record.level < 0:
self.stdstream(level=record.level, msg=record.msg)
self.log(level=record.level, msg=record.msg)
return
except grpc.RpcError as e:
reconnecting = self._process_rpc_error(e)
if not reconnecting:
return
def _process_rpc_error(self, e: grpc.RpcError):
def _process_rpc_error(self, e: grpc.RpcError) -> bool:
"""
Processes RPC errors that occur while reading from data stream.
Returns True if the error can be recovered from, False otherwise.
"""
if e.code() == grpc.StatusCode.CANCELLED:
# Graceful shutdown. We've cancelled our own connection.
logger.info("Cancelling logs channel")
elif e.code() in (grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.RESOURCE_EXHAUSTED):
# TODO(barakmich): The server may have
# dropped. In theory, we can retry, as per
# https://grpc.github.io/grpc/core/md_doc_statuscodes.html but
# in practice we may need to think about the correct semantics
# here.
logger.info("Server disconnected from logs channel")
else:
# Some other, unhandled, gRPC error
logger.exception(
f"Got Error from logger channel -- shutting down: {e}")
if self.client_worker._can_reconnect(e):
if log_once("lost_reconnect_logs"):
logger.warning(
"Log channel is reconnecting. Logs produced while "
"the connection was down can be found on the head "
"node of the cluster in "
"`ray_client_server_[port].out`")
logger.info("Log channel dropped, retrying.")
time.sleep(.5)
return True
logger.info("Shutting down log channel.")
if not self.client_worker._in_shutdown:
logger.exception("Unexpected exception:")
return False
def log(self, level: int, msg: str):
"""Log the message from the log stream.
@ -99,6 +122,7 @@ class LogstreamClient:
req.enabled = True
req.loglevel = level
self.request_queue.put(req)
self.last_req = req
def close(self) -> None:
self.request_queue.put(None)
@ -109,3 +133,4 @@ class LogstreamClient:
req = ray_client_pb2.LogSettingsRequest()
req.enabled = False
self.request_queue.put(req)
self.last_req = req


@ -1,15 +1,20 @@
from collections import defaultdict
import threading
import ray
import logging
import grpc
from queue import Queue
import sys
from typing import Any, Iterator, TYPE_CHECKING, Union
from typing import Any, Dict, Iterator, TYPE_CHECKING, Union
from threading import Lock, Thread
import time
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client.common import CLIENT_SERVER_MAX_THREADS
from ray.util.client.common import (CLIENT_SERVER_MAX_THREADS,
_propagate_error_in_context,
OrderedResponseCache)
from ray.util.client import CURRENT_PROTOCOL_VERSION
from ray.util.debug import log_once
from ray._private.client_mode_hook import disable_client_hook
@ -22,6 +27,37 @@ logger = logging.getLogger(__name__)
QUEUE_JOIN_SECONDS = 10
def _get_reconnecting_from_context(context: Any) -> bool:
"""
Get `reconnecting` from gRPC metadata, or False if missing.
"""
metadata = {k: v for k, v in context.invocation_metadata()}
val = metadata.get("reconnecting")
if val is None or val not in ("True", "False"):
logger.error(
f'Client connecting with invalid value for "reconnecting": {val}. '
"This may be because you have a mismatched client and server "
"version.")
return False
return val == "True"
def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
"""
Returns True if the response to the given request should be cached,
False otherwise. At the moment the only requests we do not cache are:
- asynchronous gets: These arrive out of order. Skipping caching here
is fine, since repeating an async get is idempotent
- acks: Repeating acks is idempotent
- clean up requests: Also idempotent, and client has likely already
wrapped up the data connection by this point.
"""
req_type = req.WhichOneof("type")
if req_type == "get" and req.get.asynchronous:
return False
return req_type not in ("acknowledge", "connection_cleanup")
def fill_queue(
grpc_input_generator: Iterator[ray_client_pb2.DataRequest],
output_queue:
@ -46,15 +82,30 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
self.basic_service = basic_service
self.clients_lock = Lock()
self.num_clients = 0 # guarded by self.clients_lock
# dictionary mapping client_id's to the last time they connected
self.client_last_seen: Dict[str, float] = {}
# dictionary mapping client_id's to their reconnect grace periods
self.reconnect_grace_periods: Dict[str, float] = {}
# dictionary mapping client_id's to their response cache
self.response_caches: Dict[str, OrderedResponseCache] = defaultdict(
OrderedResponseCache)
# stopped event, useful for signals that the server is shut down
self.stopped = threading.Event()
def Datapath(self, request_iterator, context):
start_time = time.time()
# set to True if client shuts down gracefully
cleanup_requested = False
metadata = {k: v for k, v in context.invocation_metadata()}
client_id = metadata.get("client_id")
if client_id is None:
logger.error("Client connecting with no client_id")
return
logger.debug(f"New data connection from client {client_id}: ")
accepted_connection = self._init(client_id, context)
accepted_connection = self._init(client_id, context, start_time)
response_cache = self.response_caches[client_id]
# Set to False if client requests a reconnect grace period of 0
reconnect_enabled = True
if not accepted_connection:
return
try:
@ -76,11 +127,26 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
continue
assert isinstance(req, ray_client_pb2.DataRequest)
if _should_cache(req) and reconnect_enabled:
cached_resp = response_cache.check_cache(req.req_id)
if isinstance(cached_resp, Exception):
# Cache state is invalid, raise exception
raise cached_resp
if cached_resp is not None:
yield cached_resp
continue
resp = None
req_type = req.WhichOneof("type")
if req_type == "init":
resp_init = self.basic_service.Init(req.init)
resp = ray_client_pb2.DataResponse(init=resp_init, )
with self.clients_lock:
self.reconnect_grace_periods[client_id] = \
req.init.reconnect_grace_period
if req.init.reconnect_grace_period == 0:
reconnect_enabled = False
elif req_type == "get":
if req.get.asynchronous:
get_resp = self.basic_service._async_get_object(
@ -115,23 +181,69 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
req.prep_runtime_env)
resp = ray_client_pb2.DataResponse(
prep_runtime_env=resp_prep)
elif req_type == "connection_cleanup":
cleanup_requested = True
cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
resp = ray_client_pb2.DataResponse(
connection_cleanup=cleanup_resp)
elif req_type == "acknowledge":
# Clean up acknowledged cache entries
response_cache.cleanup(req.acknowledge.req_id)
continue
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
resp.req_id = req.req_id
if _should_cache(req) and reconnect_enabled:
response_cache.update_cache(req.req_id, resp)
yield resp
except grpc.RpcError as e:
logger.debug(f"Closing data channel: {e}")
except Exception as e:
logger.exception("Error in data channel:")
recoverable = _propagate_error_in_context(e, context)
invalid_cache = response_cache.invalidate(e)
if not recoverable or invalid_cache:
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
# Connection isn't recoverable, skip cleanup
cleanup_requested = True
finally:
logger.debug(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
queue_filler_thread.join(QUEUE_JOIN_SECONDS)
if queue_filler_thread.is_alive():
logger.error(
"Queue filler thread failed to join before timeout: {}".
format(QUEUE_JOIN_SECONDS))
cleanup_delay = self.reconnect_grace_periods.get(client_id)
if not cleanup_requested and cleanup_delay is not None:
logger.debug("Cleanup wasn't requested, delaying cleanup by"
f"{cleanup_delay} seconds.")
# Delay cleanup, since client may attempt a reconnect
# Wait on the "stopped" event in case the grpc server is
# stopped and we can clean up earlier.
self.stopped.wait(timeout=cleanup_delay)
else:
logger.debug("Cleanup was requested, cleaning up immediately.")
with self.clients_lock:
# Could fail before client accounting happens
if client_id not in self.client_last_seen:
logger.debug("Connection already cleaned up.")
# Some other connection has already cleaned up this
# client's session. This can happen if the client
# reconnects and then gracefully shuts down immediately.
return
last_seen = self.client_last_seen[client_id]
if last_seen > start_time:
# The client successfully reconnected and updated
# last seen some time during the grace period
logger.debug("Client reconnected, skipping cleanup")
return
# Either the client shut down gracefully, or the client
# failed to reconnect within the grace period. Clean up
# the connection.
self.basic_service.release_all(client_id)
del self.client_last_seen[client_id]
if client_id in self.reconnect_grace_periods:
del self.reconnect_grace_periods[client_id]
if client_id in self.response_caches:
del self.response_caches[client_id]
self.num_clients -= 1
logger.debug(f"Removed clients. {self.num_clients}")
@ -142,12 +254,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
logger.debug("Shutting down ray.")
ray.shutdown()
def _init(self, client_id: str, context: Any):
def _init(self, client_id: str, context: Any, start_time: float):
"""
Checks if resources allow for another client.
Returns a boolean indicating if initialization was successful.
"""
with self.clients_lock:
reconnecting = _get_reconnecting_from_context(context)
threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
if self.num_clients >= threshold:
logger.warning(
@ -162,10 +275,21 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
return False
self.num_clients += 1
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
if reconnecting and client_id not in self.client_last_seen:
# Client took too long to reconnect, session has been
# cleaned up.
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(
"Attempted to reconnect to a session that has already "
"been cleaned up.")
return False
if client_id in self.client_last_seen:
logger.debug(f"Client {client_id} has reconnected.")
else:
self.num_clients += 1
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
self.client_last_seen[client_id] = start_time
return True
def _build_connection_response(self):


@ -7,7 +7,7 @@ from itertools import chain
import json
import socket
import sys
from threading import Lock, Thread, RLock
from threading import Event, Lock, Thread, RLock
import time
import traceback
from typing import Callable, Dict, List, Optional, Tuple
@ -21,9 +21,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import ray.core.generated.runtime_env_agent_pb2 as runtime_env_agent_pb2
import ray.core.generated.runtime_env_agent_pb2_grpc as runtime_env_agent_pb2_grpc # noqa: E501
from ray.util.client.common import (_get_client_id_from_context,
ClientServerHandle,
CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
from ray.util.client.common import (
_get_client_id_from_context, ClientServerHandle, CLIENT_SERVER_MAX_THREADS,
GRPC_OPTIONS, _propagate_error_in_context)
from ray.util.client.server.dataservicer import _get_reconnecting_from_context
from ray._private.client_mode_hook import disable_client_hook
from ray._private.parameter import RayParams
from ray._private.runtime_env import RuntimeEnvContext
@ -384,10 +385,14 @@ class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
stub = ray_client_pb2_grpc.RayletDriverStub(chan)
try:
return getattr(stub, method)(
request, metadata=[("client_id", client_id)])
except Exception:
metadata = [("client_id", client_id)]
if context:
metadata = context.invocation_metadata()
return getattr(stub, method)(request, metadata=metadata)
except Exception as e:
# Error while proxying -- propagate the error's context to user
logger.exception(f"Proxying call to {method} failed!")
_propagate_error_in_context(e, context)
def _has_channel_for_request(self, context):
client_id = _get_client_id_from_context(context)
@ -531,7 +536,8 @@ def prepare_runtime_init_req(init_request: ray_client_pb2.DataRequest
new_job_config = ray_client_server_env_prep(job_config)
modified_init_req = ray_client_pb2.InitRequest(
job_config=pickle.dumps(new_job_config),
ray_init_kwargs=init_request.init.ray_init_kwargs)
ray_init_kwargs=init_request.init.ray_init_kwargs,
reconnect_grace_period=init_request.init.reconnect_grace_period)
init_request.init.CopyFrom(modified_init_req)
return (init_request, new_job_config)
@ -540,8 +546,12 @@ def prepare_runtime_init_req(init_request: ray_client_pb2.DataRequest
class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, proxy_manager: ProxyManager):
self.num_clients = 0
# dictionary mapping client_id's to the last time they connected
self.clients_last_seen: Dict[str, float] = {}
self.reconnect_grace_periods: Dict[str, float] = {}
self.clients_lock = Lock()
self.proxy_manager = proxy_manager
self.stopped = Event()
def modify_connection_info_resp(self,
init_resp: ray_client_pb2.DataResponse
@ -560,61 +570,117 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
return modified_resp
def Datapath(self, request_iterator, context):
cleanup_requested = False
start_time = time.time()
client_id = _get_client_id_from_context(context)
if client_id == "":
return
reconnecting = _get_reconnecting_from_context(context)
# Create Placeholder *before* reading the first request.
server = self.proxy_manager.create_specific_server(client_id)
try:
if reconnecting:
with self.clients_lock:
if client_id not in self.clients_last_seen:
# Client took too long to reconnect, session has already
# been cleaned up
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(
"Attempted to reconnect a session that has already "
"been cleaned up")
return
self.clients_last_seen[client_id] = start_time
server = self.proxy_manager._get_server_for_client(client_id)
channel = self.proxy_manager.get_channel(client_id)
# iterator doesn't need modification on reconnect
new_iter = request_iterator
else:
# Create Placeholder *before* reading the first request.
server = self.proxy_manager.create_specific_server(client_id)
with self.clients_lock:
self.clients_last_seen[client_id] = start_time
self.num_clients += 1
logger.info(f"New data connection from client {client_id}: ")
init_req = next(request_iterator)
try:
modified_init_req, job_config = prepare_runtime_init_req(
init_req)
if not self.proxy_manager.start_specific_server(
client_id, job_config):
logger.error(
f"Server startup failed for client: {client_id}, "
f"using JobConfig: {job_config}!")
raise RuntimeError(
"Starting Ray client server failed. See "
f"ray_client_server_{server.port}.err for detailed "
"logs.")
channel = self.proxy_manager.get_channel(client_id)
if channel is None:
logger.error(f"Channel not found for {client_id}")
raise RuntimeError(
"Proxy failed to Connect to backend! Check "
"`ray_client_server.err` and "
f"`ray_client_server_{server.port}.err` on the head "
"node of the cluster for the relevant logs. "
"By default these are located at "
"/tmp/ray/session_latest/logs.")
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
except Exception:
init_resp = ray_client_pb2.DataResponse(
init=ray_client_pb2.InitResponse(
ok=False, msg=traceback.format_exc()))
init_resp.req_id = init_req.req_id
yield init_resp
return None
try:
if not reconnecting:
logger.info(f"New data connection from client {client_id}: ")
init_req = next(request_iterator)
with self.clients_lock:
self.reconnect_grace_periods[client_id] = \
init_req.init.reconnect_grace_period
try:
modified_init_req, job_config = prepare_runtime_init_req(
init_req)
if not self.proxy_manager.start_specific_server(
client_id, job_config):
logger.error(
f"Server startup failed for client: {client_id}, "
f"using JobConfig: {job_config}!")
raise RuntimeError(
"Starting Ray client server failed. See "
f"ray_client_server_{server.port}.err for "
"detailed logs.")
channel = self.proxy_manager.get_channel(client_id)
if channel is None:
logger.error(f"Channel not found for {client_id}")
raise RuntimeError(
"Proxy failed to Connect to backend! Check "
"`ray_client_server.err` and "
f"`ray_client_server_{server.port}.err` on the "
"head node of the cluster for the relevant logs. "
"By default these are located at "
"/tmp/ray/session_latest/logs.")
except Exception:
init_resp = ray_client_pb2.DataResponse(
init=ray_client_pb2.InitResponse(
ok=False, msg=traceback.format_exc()))
init_resp.req_id = init_req.req_id
yield init_resp
return None
new_iter = chain([modified_init_req], request_iterator)
resp_stream = stub.Datapath(
new_iter, metadata=[("client_id", client_id)])
new_iter = chain([modified_init_req], request_iterator)
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
metadata = [("client_id", client_id), ("reconnecting",
str(reconnecting))]
resp_stream = stub.Datapath(new_iter, metadata=metadata)
for resp in resp_stream:
resp_type = resp.WhichOneof("type")
if resp_type == "connection_cleanup":
# Specific server is skipping cleanup, proxier should too
cleanup_requested = True
yield self.modify_connection_info_resp(resp)
except Exception:
except Exception as e:
logger.exception("Proxying Datapath failed!")
# Propagate error through context
recoverable = _propagate_error_in_context(e, context)
if not recoverable:
# Client shouldn't attempt to recover, clean up connection
cleanup_requested = True
finally:
server.set_result(None)
cleanup_delay = self.reconnect_grace_periods.get(client_id)
if not cleanup_requested and cleanup_delay is not None:
# Delay cleanup, since client may attempt a reconnect
# Wait on stopped event in case the server closes and we
# can clean up earlier
self.stopped.wait(timeout=cleanup_delay)
with self.clients_lock:
if client_id not in self.clients_last_seen:
logger.info(f"{client_id} not found. Skipping clean up.")
# Connection has already been cleaned up
return
last_seen = self.clients_last_seen[client_id]
logger.info(
f"{client_id} last started stream at {last_seen}. Current "
f"stream started at {start_time}.")
if last_seen > start_time:
logger.info("Client reconnected. Skipping cleanup.")
# Client has reconnected, don't clean up
return
logger.debug(f"Client detached: {client_id}")
self.num_clients -= 1
del self.clients_last_seen[client_id]
if client_id in self.reconnect_grace_periods:
del self.reconnect_grace_periods[client_id]
server.set_result(None)
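The delayed-cleanup idiom above (wait on the stopped event for the grace period, then compare the stream's start time against clients_last_seen) can be summarized in a standalone sketch; the names here are hypothetical and the stream-proxying body is elided:

import threading
import time


class DelayedCleanupSketch:
    def __init__(self):
        self.stopped = threading.Event()  # set when the proxy shuts down
        self.last_seen = {}               # client_id -> latest stream start time
        self.lock = threading.Lock()

    def handle_stream(self, client_id: str, grace_period: float) -> None:
        start = time.time()
        with self.lock:
            self.last_seen[client_id] = start
        try:
            pass  # ... proxy requests/responses until the stream ends or errors ...
        finally:
            # Give the client a chance to reconnect; wake early on shutdown.
            self.stopped.wait(timeout=grace_period)
            with self.lock:
                last = self.last_seen.get(client_id)
                if last is None or last > start:
                    return  # already cleaned up, or a newer stream took over
                del self.last_seen[client_id]
                # ... release the per-client server and bookkeeping here ...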
class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
@ -641,7 +707,10 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
if channel is None:
context.set_code(grpc.StatusCode.UNAVAILABLE)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(
"Logstream proxy failed to connect. Channel for client "
f"{client_id} not found.")
return None
stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)

View file

@ -3,6 +3,7 @@ from concurrent import futures
import grpc
import base64
from collections import defaultdict
import functools
import queue
import pickle
@ -22,7 +23,7 @@ import time
import inspect
import json
from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS,
CLIENT_SERVER_MAX_THREADS)
CLIENT_SERVER_MAX_THREADS, ResponseCache)
from ray.util.client.server.proxier import serve_proxier
from ray.util.client.server.server_pickler import convert_from_arg
from ray.util.client.server.server_pickler import dumps_from_server
@ -40,6 +41,53 @@ TIMEOUT_FOR_SPECIFIC_SERVER_S = env_integer("TIMEOUT_FOR_SPECIFIC_SERVER_S",
30)
def _use_response_cache(func):
"""
Decorator for gRPC stubs. Before calling the real stub, checks if there is
an existing entry in the response cache. If there is, returns the cached
entry. Otherwise, calls the real function and caches the result.
"""
@functools.wraps(func)
def wrapper(self, request, context):
metadata = {k: v for k, v in context.invocation_metadata()}
expected_ids = ("client_id", "thread_id", "req_id")
if any(i not in metadata for i in expected_ids):
# Missing IDs, skip caching and call underlying stub directly
return func(self, request, context)
# Get relevant IDs to check cache
client_id = metadata["client_id"]
thread_id = metadata["thread_id"]
req_id = int(metadata["req_id"])
# Check if response already cached
response_cache = self.response_caches[client_id]
cached_entry = response_cache.check_cache(thread_id, req_id)
if cached_entry is not None:
if isinstance(cached_entry, Exception):
# Original call errored, propagate error
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details(str(cached_entry))
raise cached_entry
return cached_entry
try:
# Response wasn't cached, call underlying stub and cache result
resp = func(self, request, context)
except Exception as e:
# Unexpected error in underlying stub -- update cache and
# propagate to user through context
response_cache.update_cache(thread_id, req_id, e)
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
context.set_details(str(e))
raise
response_cache.update_cache(thread_id, req_id, resp)
return resp
return wrapper
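The decorator depends on ResponseCache from ray.util.client.common, which is not shown in this diff. As a rough illustration of the contract it relies on (check_cache/update_cache keyed by thread and request id), a simplified cache might look like the sketch below; the real class also handles ordering and in-flight entries, so this is not the actual implementation:

import threading
from typing import Any, Dict, Optional, Tuple


class SimpleResponseCache:
    """Illustrative stand-in: remembers the latest response per client thread."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # thread_id -> (req_id, cached response or exception)
        self._entries: Dict[str, Tuple[int, Any]] = {}

    def check_cache(self, thread_id: str, req_id: int) -> Optional[Any]:
        """Return the cached entry if this exact request was already handled."""
        with self._lock:
            entry = self._entries.get(thread_id)
            if entry is not None and entry[0] == req_id:
                return entry[1]
            return None

    def update_cache(self, thread_id: str, req_id: int, resp: Any) -> None:
        """Store the result so a retried request can be replayed, not re-run."""
        with self._lock:
            self._entries[thread_id] = (req_id, resp)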
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, ray_connect_handler: Callable):
"""Construct a raylet service
@ -60,6 +108,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.named_actors = set()
self.state_lock = threading.Lock()
self.ray_connect_handler = ray_connect_handler
self.response_caches: Dict[str, ResponseCache] = defaultdict(
ResponseCache)
def Init(self, request: ray_client_pb2.InitRequest,
context=None) -> ray_client_pb2.InitResponse:
@ -105,6 +155,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
f"current one {current_job_config.runtime_env.uris}")
return ray_client_pb2.InitResponse(ok=True)
@_use_response_cache
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
with disable_client_hook():
already_exists = ray.experimental.internal_kv._internal_kv_put(
@ -116,6 +167,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
value = ray.experimental.internal_kv._internal_kv_get(request.key)
return ray_client_pb2.KVGetResponse(value=value)
@_use_response_cache
def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
with disable_client_hook():
ray.experimental.internal_kv._internal_kv_del(request.key)
@ -235,6 +287,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
del self.object_refs[client_id]
if client_id in self.client_side_ref_map:
del self.client_side_ref_map[client_id]
if client_id in self.response_caches:
del self.response_caches[client_id]
logger.debug(f"Released all {count} objects for client {client_id}")
def _release_actors(self, client_id):
@ -252,6 +306,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
logger.debug(f"Released all {count} actors for client: {client_id}")
@_use_response_cache
def Terminate(self, req, context=None):
if req.WhichOneof("terminate_type") == "task_object":
try:
@ -428,6 +483,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
@_use_response_cache
def Schedule(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
logger.debug(

View file

@ -5,7 +5,9 @@ to the server.
import base64
import json
import logging
import os
import time
import threading
import uuid
import warnings
from collections import defaultdict
@ -21,11 +23,13 @@ from ray.cloudpickle.compat import pickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.exceptions import GetTimeoutError
from ray.ray_constants import DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
from ray.util.client.client_pickler import (convert_to_arg, dumps_from_client,
loads_from_server)
from ray.util.client.common import (ClientActorClass, ClientActorHandle,
ClientActorRef, ClientObjectRef,
ClientRemoteFunc, ClientStub, GRPC_OPTIONS)
ClientRemoteFunc, ClientStub, GRPC_OPTIONS,
GRPC_UNRECOVERABLE_ERRORS, INT32_MAX)
from ray.util.client.dataclient import DataClient
from ray.util.client.logsclient import LogstreamClient
from ray.util.debug import log_once
@ -108,7 +112,22 @@ class Worker:
else:
self._credentials = None
self._reconnect_grace_period = DEFAULT_CLIENT_RECONNECT_GRACE_PERIOD
if "RAY_CLIENT_RECONNECT_GRACE_PERIOD" in os.environ:
# Use value in environment variable if available
self._reconnect_grace_period = \
int(os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"])
# Disable retries if grace period is set to 0
self._reconnect_enabled = self._reconnect_grace_period != 0
# Set to True when the connection cannot be recovered and reconnect
# attempts should be stopped
self._in_shutdown = False
# Set to True after initial connection succeeds
self._has_connected = False
self._connect_channel()
self._has_connected = True
# Initialize the streams to finish protocol negotiation.
self.data_client = DataClient(self, self._client_id, self.metadata)
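For reference, the grace period is read from RAY_CLIENT_RECONNECT_GRACE_PERIOD at Worker construction time, so it can be tuned (or reconnects disabled with 0) without code changes; the value 600 below is only an example:

import os

# Allow up to 10 minutes of reconnect attempts after a gRPC failure.
os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"] = "600"

# Or disable reconnect retries altogether:
# os.environ["RAY_CLIENT_RECONNECT_GRACE_PERIOD"] = "0"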
@ -124,10 +143,19 @@ class Worker:
self.total_num_tasks_scheduled = 0
self.total_outbound_message_size_bytes = 0
def _connect_channel(self) -> None:
# Used to create unique IDs for RPCs to the RayletServicer
self._req_id_lock = threading.Lock()
self._req_id = 0
def _connect_channel(self, reconnecting=False) -> None:
"""
Attempts to connect to the server specified by conn_str.
Attempts to connect to the server specified by conn_str. If
reconnecting after an RPC error, cleans up the old channel and
continues to attempt to connect until the grace period is over.
"""
if self.channel is not None:
self.channel.unsubscribe(self._on_channel_state_change)
self.channel.close()
if self._secure:
if self._credentials is not None:
@ -144,11 +172,21 @@ class Worker:
# Retry the connection until the channel responds to something
# looking like a gRPC connection, though it may be a proxy.
start_time = time.time()
conn_attempts = 0
timeout = INITIAL_TIMEOUT_SEC
service_ready = False
while conn_attempts < max(self._connection_retries, 1):
while conn_attempts < max(self._connection_retries, 1) or reconnecting:
conn_attempts += 1
if self._in_shutdown:
# User manually closed the worker before connection finished
break
elapsed_time = time.time() - start_time
if reconnecting and elapsed_time > self._reconnect_grace_period:
self._in_shutdown = True
raise ConnectionError(
"Failed to reconnect within the reconnection grace period "
f"({self._reconnect_grace_period}s)")
try:
# Let gRPC wait for us to see if the channel becomes ready.
# If it throws, we couldn't connect.
@ -176,12 +214,17 @@ class Worker:
# Fallthrough, backoff, and retry at the top of the loop
logger.info("Waiting for Ray to become ready on the server, "
f"retry in {timeout}s...")
timeout = backoff(timeout)
if not reconnecting:
# Don't increase backoff when trying to reconnect --
# we already know the server exists, so attempt to
# reconnect as soon as we can
timeout = backoff(timeout)
# If we made it through the loop without service_ready
# it means we've used up our retries and
# should error back to the user.
if not service_ready:
self._in_shutdown = True
if log_once("ray_client_security_groups"):
warnings.warn(
"Ray Client connection timed out. Ensure that "
@ -191,6 +234,73 @@ class Worker:
"more information.")
raise ConnectionError("ray client connection timeout")
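The backoff() helper used in the loop above is defined elsewhere in the client; a plausible capped version, shown purely for illustration, is:

def backoff_sketch(timeout: float, max_timeout: float = 30.0) -> float:
    """Grow the retry timeout, but never beyond a fixed ceiling."""
    return min(timeout * 2, max_timeout)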
def _can_reconnect(self, e: grpc.RpcError) -> bool:
"""
Returns True if the RPC error can be recovered from and a retry is
appropriate, False otherwise.
"""
if not self._reconnect_enabled:
return False
if self._in_shutdown:
# Channel is being shutdown, don't try to reconnect
return False
if e.code() in GRPC_UNRECOVERABLE_ERRORS:
# Unrecoverable error -- These errors are specifically raised
# by the server's application logic
return False
if e.code() == grpc.StatusCode.INTERNAL:
details = e.details()
if details == "Exception serializing request!":
# The client tried to send a bad request (for example,
# passing "None" instead of a valid grpc message). Don't
# try to reconnect/retry.
return False
# All other errors can be treated as recoverable
return True
def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
"""
Calls the stub specified by stub_name (Schedule, WaitObject, etc.).
If a recoverable error occurs while calling the stub, attempts to
retry the RPC.
"""
while not self._in_shutdown:
try:
return getattr(self.server, stub_name)(*args, **kwargs)
except grpc.RpcError as e:
if self._can_reconnect(e):
time.sleep(.5)
continue
raise
except ValueError:
# Trying to use the stub on a cancelled channel will raise
# ValueError. This should only happen when the data client
# is attempting to reset the connection -- sleep and try
# again.
time.sleep(.5)
continue
raise ConnectionError("Client is shutting down.")
def _add_ids_to_metadata(self, metadata: Any):
"""
Adds a unique req_id and the current thread's identifier to the
metadata. These values are useful for preventing mutating operations
from being replayed on the server side in the event that the client
must retry a request.
Args:
metadata - the gRPC metadata to append the IDs to
"""
if not self._reconnect_enabled:
# IDs aren't needed if reconnects are disabled
return metadata
thread_id = str(threading.get_ident())
with self._req_id_lock:
self._req_id += 1
if self._req_id > INT32_MAX:
self._req_id = 1
req_id = str(self._req_id)
return metadata + [("thread_id", thread_id), ("req_id", req_id)]
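Taken together with _call_stub, the IDs are attached once per logical call and then reused across retries, which is what lets the server recognize a retried mutating RPC as a replay; the wrappers below follow this pattern, simplified here:

def example_mutating_call(worker, request):
    # Build the thread_id/req_id metadata once, *before* the retry loop, so
    # every retry of this logical request carries the same identifiers and the
    # server-side response cache can deduplicate it.
    metadata = worker._add_ids_to_metadata(worker.metadata)
    return worker._call_stub("Schedule", request, metadata=metadata)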
def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
logger.debug(f"client gRPC channel state change: {conn_state}")
self._conn_state = conn_state
@ -324,7 +434,7 @@ class Worker:
"client_id": self._client_id,
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req, metadata=self.metadata)
resp = self._call_stub("WaitObject", req, metadata=self.metadata)
if not resp.valid:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
@ -350,11 +460,11 @@ class Worker:
self, task: ray_client_pb2.ClientTask) -> List[bytes]:
logger.debug("Scheduling %s" % task)
task.client_id = self._client_id
metadata = self._add_ids_to_metadata(self.metadata)
try:
ticket = self.server.Schedule(task, metadata=self.metadata)
ticket = self._call_stub("Schedule", task, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)
if not ticket.valid:
try:
raise cloudpickle.loads(ticket.error)
@ -412,6 +522,7 @@ class Worker:
self.reference_count[id] += 1
def close(self):
self._in_shutdown = True
self.data_client.close()
self.log_client.close()
if self.channel:
@ -441,7 +552,8 @@ class Worker:
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
term.client_id = self._client_id
self.server.Terminate(term, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("Terminate", term, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)
@ -458,14 +570,18 @@ class Worker:
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
term.client_id = self._client_id
self.server.Terminate(term, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("Terminate", term, metadata=metadata)
except grpc.RpcError as e:
raise decode_exception(e)
def get_cluster_info(self, type: ray_client_pb2.ClusterInfoType.TypeEnum):
def get_cluster_info(self,
type: ray_client_pb2.ClusterInfoType.TypeEnum,
timeout: Optional[float] = None):
req = ray_client_pb2.ClusterInfoRequest()
req.type = type
resp = self.server.ClusterInfo(req, metadata=self.metadata)
resp = self.server.ClusterInfo(
req, timeout=timeout, metadata=self.metadata)
if resp.WhichOneof("response_type") == "resource_table":
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}
@ -476,35 +592,37 @@ class Worker:
def internal_kv_get(self, key: bytes) -> bytes:
req = ray_client_pb2.KVGetRequest(key=key)
resp = self.server.KVGet(req, metadata=self.metadata)
resp = self._call_stub("KVGet", req, metadata=self.metadata)
return resp.value
def internal_kv_exists(self, key: bytes) -> bytes:
req = ray_client_pb2.KVGetRequest(key=key)
resp = self.server.KVGet(req, metadata=self.metadata)
resp = self._call_stub("KVGet", req, metadata=self.metadata)
return resp.value
def internal_kv_put(self, key: bytes, value: bytes,
overwrite: bool) -> bool:
req = ray_client_pb2.KVPutRequest(
key=key, value=value, overwrite=overwrite)
resp = self.server.KVPut(req, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
resp = self._call_stub("KVPut", req, metadata=metadata)
return resp.already_exists
def internal_kv_del(self, key: bytes) -> None:
req = ray_client_pb2.KVDelRequest(key=key)
self.server.KVDel(req, metadata=self.metadata)
metadata = self._add_ids_to_metadata(self.metadata)
self._call_stub("KVDel", req, metadata=metadata)
def internal_kv_list(self, prefix: bytes) -> bytes:
req = ray_client_pb2.KVListRequest(prefix=prefix)
return self.server.KVList(req, metadata=self.metadata).keys
return self._call_stub("KVList", req, metadata=self.metadata).keys
def list_named_actors(self, all_namespaces: bool) -> List[Dict[str, str]]:
req = ray_client_pb2.ClientListNamedActorsRequest(
all_namespaces=all_namespaces)
return json.loads(
self.server.ListNamedActors(req,
metadata=self.metadata).actors_json)
self._call_stub("ListNamedActors", req,
metadata=self.metadata).actors_json)
def is_initialized(self) -> bool:
if self.server is not None:
@ -512,7 +630,7 @@ class Worker:
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False
def ping_server(self) -> bool:
def ping_server(self, timeout=None) -> bool:
"""Simple health check.
Piggybacks the IS_INITIALIZED call to check if the server provides
@ -520,12 +638,13 @@ class Worker:
"""
if self.server is not None:
logger.debug("Pinging server.")
result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING)
result = self.get_cluster_info(
ray_client_pb2.ClusterInfoType.PING, timeout=timeout)
return result is not None
return False
def is_connected(self) -> bool:
return self._conn_state == grpc.ChannelConnectivity.READY
return not self._in_shutdown and self._has_connected
def _server_init(self,
job_config: JobConfig,
@ -549,7 +668,8 @@ class Worker:
response = self.data_client.Init(
ray_client_pb2.InitRequest(
job_config=serialized_job_config,
ray_init_kwargs=json.dumps(ray_init_kwargs)))
ray_init_kwargs=json.dumps(ray_init_kwargs),
reconnect_grace_period=self._reconnect_grace_period))
if not response.ok:
raise ConnectionAbortedError(
f"Initialization failure from server:\n{response.msg}")

View file

@ -259,6 +259,7 @@ message InitRequest {
// job_config of ray.init
bytes job_config = 1;
string ray_init_kwargs = 2;
int32 reconnect_grace_period = 3;
}
message InitResponse {
@ -339,6 +340,20 @@ message ConnectionInfoResponse {
string protocol_version = 5;
}
message ConnectionCleanupRequest {
// Explicitly request that the connection be cleaned up for a graceful shutdown
}
message ConnectionCleanupResponse {
// Acknowledge cleanup request
}
message AcknowledgeRequest {
// Used to acknowledge that all requests up to the given req_id have been
// received
int32 req_id = 1;
}
message DataRequest {
// An incrementing counter of request IDs on the Datapath,
// to match requests with responses asynchronously.
@ -350,6 +365,8 @@ message DataRequest {
ConnectionInfoRequest connection_info = 5;
InitRequest init = 6;
PrepRuntimeEnvRequest prep_runtime_env = 7;
ConnectionCleanupRequest connection_cleanup = 8;
AcknowledgeRequest acknowledge = 9;
}
}
@ -363,6 +380,7 @@ message DataResponse {
ConnectionInfoResponse connection_info = 5;
InitResponse init = 6;
PrepRuntimeEnvResponse prep_runtime_env = 7;
ConnectionCleanupResponse connection_cleanup = 8;
}
}
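With the generated ray_client_pb2 module, the new messages compose into a DataRequest like any other oneof member; for example (a sketch, with a hypothetical req_id value):

import ray.core.generated.ray_client_pb2 as ray_client_pb2

# Acknowledge that all data requests up to req_id 42 have been received.
ack = ray_client_pb2.DataRequest(
    acknowledge=ray_client_pb2.AcknowledgeRequest(req_id=42))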