mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
* revert revertchunkedgets * exit early if all chunks received, tighter exception handler for stream in proxy
This commit is contained in:
parent
f74ad24901
commit
6416c65505
8 changed files with 303 additions and 35 deletions
|
@ -1,4 +1,5 @@
|
|||
from concurrent import futures
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import threading
|
||||
|
@ -135,7 +136,8 @@ class MiddlemanRayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
context.set_code(e.code())
|
||||
context.set_details(e.details())
|
||||
raise
|
||||
if self.on_response:
|
||||
if self.on_response and method != "GetObject":
|
||||
# GetObject streams response, handle on_response separately
|
||||
self.on_response(response)
|
||||
return response
|
||||
|
||||
|
@ -169,7 +171,10 @@ class MiddlemanRayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
return self._call_inner_function(req, context, "Terminate")
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
return self._call_inner_function(request, context, "GetObject")
|
||||
for response in self._call_inner_function(request, context, "GetObject"):
|
||||
if self.on_response:
|
||||
self.on_response(response)
|
||||
yield response
|
||||
|
||||
def PutObject(
|
||||
self, request: ray_client_pb2.PutRequest, context=None
|
||||
|
@ -334,6 +339,73 @@ def test_disconnect_during_get():
|
|||
disconnect_thread.join()
|
||||
|
||||
|
||||
def test_disconnects_during_large_get():
|
||||
"""
|
||||
Disconnect repeatedly during a large (multi-chunk) get.
|
||||
"""
|
||||
i = 0
|
||||
started = False
|
||||
|
||||
def fail_every_three(_):
|
||||
# Inject an error every third time this method is called
|
||||
nonlocal i, started
|
||||
if not started:
|
||||
return
|
||||
i += 1
|
||||
if i % 3 == 0:
|
||||
raise RuntimeError
|
||||
|
||||
@ray.remote
|
||||
def large_result():
|
||||
# 1024x1024x128 float64 matrix (1024 MiB). With 64MiB chunk size,
|
||||
# it will take at least 16 chunks to transfer this object. Since
|
||||
# the failure is injected every 3 chunks, this transfer can only
|
||||
# work if the chunked get request retries at the last received chunk
|
||||
# (instead of starting from the beginning each retry)
|
||||
return np.random.random((1024, 1024, 128))
|
||||
|
||||
with start_middleman_server(on_task_response=fail_every_three):
|
||||
started = True
|
||||
result = ray.get(large_result.remote())
|
||||
assert result.shape == (1024, 1024, 128)
|
||||
|
||||
|
||||
def test_disconnects_during_large_async_get():
|
||||
"""
|
||||
Disconnect repeatedly during a large (multi-chunk) async get.
|
||||
"""
|
||||
i = 0
|
||||
started = False
|
||||
|
||||
def fail_every_three(_):
|
||||
# Inject an error every third time this method is called
|
||||
nonlocal i, started
|
||||
if not started:
|
||||
return
|
||||
i += 1
|
||||
if i % 3 == 0:
|
||||
raise RuntimeError
|
||||
|
||||
@ray.remote
|
||||
def large_result():
|
||||
# 1024x1024x128 float64 matrix (1024 MiB). With 64MiB chunk size,
|
||||
# it will take at least 16 chunks to transfer this object. Since
|
||||
# the failure is injected every 3 chunks, this transfer can only
|
||||
# work if the chunked get request retries at the last received chunk
|
||||
# (instead of starting from the beginning each retry)
|
||||
return np.random.random((1024, 1024, 128))
|
||||
|
||||
with start_middleman_server(on_data_response=fail_every_three):
|
||||
started = True
|
||||
|
||||
async def get_large_result():
|
||||
return await large_result.remote()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
result = loop.run_until_complete(get_large_result())
|
||||
assert result.shape == (1024, 1024, 128)
|
||||
|
||||
|
||||
def test_disconnect_during_large_put():
|
||||
"""
|
||||
Disconnect during a large (multi-chunk) put.
|
||||
|
|
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# This version string is incremented to indicate breaking changes in the
|
||||
# protocol that require upgrading the client version.
|
||||
CURRENT_PROTOCOL_VERSION = "2022-02-22"
|
||||
CURRENT_PROTOCOL_VERSION = "2022-03-16"
|
||||
|
||||
|
||||
class _ClientContext:
|
||||
|
|
|
@ -172,6 +172,8 @@ class ClientObjectRef(raylet.ObjectRef):
|
|||
|
||||
if isinstance(resp, Exception):
|
||||
data = resp
|
||||
elif isinstance(resp, bytearray):
|
||||
data = loads_from_server(resp)
|
||||
else:
|
||||
obj = resp.get
|
||||
data = None
|
||||
|
|
|
@ -6,7 +6,6 @@ import logging
|
|||
import queue
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import grpc
|
||||
|
||||
from collections import OrderedDict
|
||||
|
@ -69,6 +68,83 @@ def chunk_put(req: ray_client_pb2.DataRequest):
|
|||
yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)
|
||||
|
||||
|
||||
class ChunkCollector:
|
||||
"""
|
||||
This object collects chunks from async get requests via __call__, and
|
||||
calls the underlying callback when the object is fully received, or if an
|
||||
exception while retrieving the object occurs.
|
||||
|
||||
This is not used in synchronous gets (synchronous gets interact with the
|
||||
raylet servicer directly, not through the datapath).
|
||||
|
||||
__call__ returns true once the underlying call back has been called.
|
||||
"""
|
||||
|
||||
def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
|
||||
# Bytearray containing data received so far
|
||||
self.data = bytearray()
|
||||
# The callback that will be called once all data is received
|
||||
self.callback = callback
|
||||
# The id of the last chunk we've received, or -1 if haven't seen any yet
|
||||
self.last_seen_chunk = -1
|
||||
# The GetRequest that initiated the transfer. start_chunk_id will be
|
||||
# updated as chunks are received to avoid re-requesting chunks that
|
||||
# we've already received.
|
||||
self.request = request
|
||||
|
||||
def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
|
||||
if isinstance(response, Exception):
|
||||
self.callback(response)
|
||||
return True
|
||||
get_resp = response.get
|
||||
if not get_resp.valid:
|
||||
self.callback(response)
|
||||
return True
|
||||
if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
|
||||
"client_object_transfer_size_warning"
|
||||
):
|
||||
size_gb = get_resp.total_size / 2 ** 30
|
||||
warnings.warn(
|
||||
"Ray Client is attempting to retrieve a "
|
||||
f"{size_gb:.2f} GiB object over the network, which may "
|
||||
"be slow. Consider serializing the object to a file and "
|
||||
"using rsync or S3 instead.",
|
||||
UserWarning,
|
||||
)
|
||||
chunk_data = get_resp.data
|
||||
chunk_id = get_resp.chunk_id
|
||||
if chunk_id == self.last_seen_chunk + 1:
|
||||
self.data.extend(chunk_data)
|
||||
self.last_seen_chunk = chunk_id
|
||||
# If we disconnect partway through, restart the get request
|
||||
# at the first chunk we haven't seen
|
||||
self.request.get.start_chunk_id = self.last_seen_chunk + 1
|
||||
elif chunk_id > self.last_seen_chunk + 1:
|
||||
# A chunk was skipped. This shouldn't happen in practice since
|
||||
# grpc guarantees that chunks will arrive in order.
|
||||
msg = (
|
||||
f"Received chunk {chunk_id} when we expected "
|
||||
f"{self.last_seen_chunk + 1} for request {response.req_id}"
|
||||
)
|
||||
logger.warning(msg)
|
||||
self.callback(RuntimeError(msg))
|
||||
return True
|
||||
else:
|
||||
# We received a chunk that've already seen before. Ignore, since
|
||||
# it should already be appended to self.data.
|
||||
logger.debug(
|
||||
f"Received a repeated chunk {chunk_id} "
|
||||
f"from request {response.req_id}."
|
||||
)
|
||||
|
||||
if get_resp.chunk_id == get_resp.total_chunks - 1:
|
||||
self.callback(self.data)
|
||||
return True
|
||||
else:
|
||||
# Not done yet
|
||||
return False
|
||||
|
||||
|
||||
class DataClient:
|
||||
def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
|
||||
"""Initializes a thread-safe datapath over a Ray Client gRPC channel.
|
||||
|
@ -177,20 +253,25 @@ class DataClient:
|
|||
logger.debug(f"Got unawaited response {response}")
|
||||
return
|
||||
if response.req_id in self.asyncio_waiting_data:
|
||||
can_remove = True
|
||||
try:
|
||||
# NOTE: calling self.asyncio_waiting_data.pop() results
|
||||
callback = self.asyncio_waiting_data[response.req_id]
|
||||
if isinstance(callback, ChunkCollector):
|
||||
can_remove = callback(response)
|
||||
elif callback:
|
||||
callback(response)
|
||||
if can_remove:
|
||||
# NOTE: calling del self.asyncio_waiting_data results
|
||||
# in the destructor of ClientObjectRef running, which
|
||||
# calls ReleaseObject(). So self.asyncio_waiting_data
|
||||
# is accessed without holding self.lock. Holding the
|
||||
# lock shouldn't be necessary either.
|
||||
callback = self.asyncio_waiting_data.pop(response.req_id)
|
||||
if callback:
|
||||
callback(response)
|
||||
del self.asyncio_waiting_data[response.req_id]
|
||||
except Exception:
|
||||
logger.exception("Callback error:")
|
||||
with self.lock:
|
||||
# Update outstanding requests
|
||||
if response.req_id in self.outstanding_requests:
|
||||
if response.req_id in self.outstanding_requests and can_remove:
|
||||
del self.outstanding_requests[response.req_id]
|
||||
# Acknowledge response
|
||||
self._acknowledge(response.req_id)
|
||||
|
@ -428,7 +509,8 @@ class DataClient:
|
|||
datareq = ray_client_pb2.DataRequest(
|
||||
get=request,
|
||||
)
|
||||
self._async_send(datareq, callback)
|
||||
collector = ChunkCollector(callback=callback, request=datareq)
|
||||
self._async_send(datareq, collector)
|
||||
|
||||
# TODO: convert PutObject to async
|
||||
def PutObject(
|
||||
|
|
|
@ -525,7 +525,12 @@ class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
return self._call_inner_function(req, context, "Terminate")
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
return self._call_inner_function(request, context, "GetObject")
|
||||
try:
|
||||
yield from self._call_inner_function(request, context, "GetObject")
|
||||
except Exception as e:
|
||||
# Error while iterating over response from GetObject stream
|
||||
logger.exception("Proxying call to GetObject failed!")
|
||||
_propagate_error_in_context(e, context)
|
||||
|
||||
def PutObject(
|
||||
self, request: ray_client_pb2.PutRequest, context=None
|
||||
|
|
|
@ -5,6 +5,7 @@ import grpc
|
|||
import base64
|
||||
from collections import defaultdict
|
||||
import functools
|
||||
import math
|
||||
import queue
|
||||
import pickle
|
||||
|
||||
|
@ -28,6 +29,7 @@ from ray.util.client.common import (
|
|||
ClientServerHandle,
|
||||
GRPC_OPTIONS,
|
||||
CLIENT_SERVER_MAX_THREADS,
|
||||
OBJECT_TRANSFER_CHUNK_SIZE,
|
||||
ResponseCache,
|
||||
)
|
||||
from ray import ray_constants
|
||||
|
@ -379,14 +381,32 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
with disable_client_hook():
|
||||
|
||||
def send_get_response(result: Any) -> None:
|
||||
"""Pushes a GetResponse to the main DataPath loop to send
|
||||
"""Pushes GetResponses to the main DataPath loop to send
|
||||
to the client. This is called when the object is ready
|
||||
on the server side."""
|
||||
try:
|
||||
serialized = dumps_from_server(result, client_id, self)
|
||||
get_resp = ray_client_pb2.GetResponse(
|
||||
valid=True, data=serialized
|
||||
total_size = len(serialized)
|
||||
assert total_size > 0, "Serialized object cannot be zero bytes"
|
||||
total_chunks = math.ceil(
|
||||
total_size / OBJECT_TRANSFER_CHUNK_SIZE
|
||||
)
|
||||
for chunk_id in range(request.start_chunk_id, total_chunks):
|
||||
start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
|
||||
end = min(
|
||||
total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE
|
||||
)
|
||||
get_resp = ray_client_pb2.GetResponse(
|
||||
valid=True,
|
||||
data=serialized[start:end],
|
||||
chunk_id=chunk_id,
|
||||
total_chunks=total_chunks,
|
||||
total_size=total_size,
|
||||
)
|
||||
chunk_resp = ray_client_pb2.DataResponse(
|
||||
get=get_resp, req_id=req_id
|
||||
)
|
||||
result_queue.put(chunk_resp)
|
||||
except Exception as exc:
|
||||
get_resp = ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(exc)
|
||||
|
@ -404,13 +424,14 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
metadata = {k: v for k, v in context.invocation_metadata()}
|
||||
client_id = metadata.get("client_id")
|
||||
if client_id is None:
|
||||
return ray_client_pb2.GetResponse(
|
||||
yield ray_client_pb2.GetResponse(
|
||||
valid=False,
|
||||
error=cloudpickle.dumps(
|
||||
ValueError("client_id is not specified in request metadata")
|
||||
),
|
||||
)
|
||||
return self._get_object(request, client_id)
|
||||
else:
|
||||
yield from self._get_object(request, client_id)
|
||||
|
||||
def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str):
|
||||
objectrefs = []
|
||||
|
@ -419,7 +440,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
if ref:
|
||||
objectrefs.append(ref)
|
||||
else:
|
||||
return ray_client_pb2.GetResponse(
|
||||
yield ray_client_pb2.GetResponse(
|
||||
valid=False,
|
||||
error=cloudpickle.dumps(
|
||||
ValueError(
|
||||
|
@ -428,14 +449,28 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
)
|
||||
),
|
||||
)
|
||||
return
|
||||
try:
|
||||
logger.debug("get: %s" % objectrefs)
|
||||
with disable_client_hook():
|
||||
items = ray.get(objectrefs, timeout=request.timeout)
|
||||
except Exception as e:
|
||||
return ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
|
||||
yield ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
|
||||
return
|
||||
serialized = dumps_from_server(items, client_id, self)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=serialized)
|
||||
total_size = len(serialized)
|
||||
assert total_size > 0, "Serialized object cannot be zero bytes"
|
||||
total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
|
||||
for chunk_id in range(request.start_chunk_id, total_chunks):
|
||||
start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
|
||||
end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
|
||||
yield ray_client_pb2.GetResponse(
|
||||
valid=True,
|
||||
data=serialized[start:end],
|
||||
chunk_id=chunk_id,
|
||||
total_chunks=total_chunks,
|
||||
total_size=total_size,
|
||||
)
|
||||
|
||||
def PutObject(
|
||||
self, request: ray_client_pb2.PutRequest, context=None
|
||||
|
|
|
@ -41,6 +41,7 @@ from ray.util.client.common import (
|
|||
GRPC_OPTIONS,
|
||||
GRPC_UNRECOVERABLE_ERRORS,
|
||||
INT32_MAX,
|
||||
OBJECT_TRANSFER_WARNING_SIZE,
|
||||
)
|
||||
from ray.util.client.dataclient import DataClient
|
||||
from ray.util.client.logsclient import LogstreamClient
|
||||
|
@ -309,6 +310,54 @@ class Worker:
|
|||
continue
|
||||
raise ConnectionError("Client is shutting down.")
|
||||
|
||||
def _get_object_iterator(
|
||||
self, req: ray_client_pb2.GetRequest, *args, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Calls the stub for GetObject on the underlying server stub. If a
|
||||
recoverable error occurs while streaming the response, attempts
|
||||
to retry the get starting from the first chunk that hasn't been
|
||||
received.
|
||||
"""
|
||||
last_seen_chunk = -1
|
||||
while not self._in_shutdown:
|
||||
# If we disconnect partway through, restart the get request
|
||||
# at the first chunk we haven't seen
|
||||
req.start_chunk_id = last_seen_chunk + 1
|
||||
try:
|
||||
for chunk in self.server.GetObject(req, *args, **kwargs):
|
||||
if chunk.chunk_id <= last_seen_chunk:
|
||||
# Ignore repeat chunks
|
||||
logger.debug(
|
||||
f"Received a repeated chunk {chunk.chunk_id} "
|
||||
f"from request {req.req_id}."
|
||||
)
|
||||
continue
|
||||
if last_seen_chunk + 1 != chunk.chunk_id:
|
||||
raise RuntimeError(
|
||||
f"Received chunk {chunk.chunk_id} when we expected "
|
||||
f"{self.last_seen_chunk + 1}"
|
||||
)
|
||||
last_seen_chunk = chunk.chunk_id
|
||||
yield chunk
|
||||
if last_seen_chunk == chunk.total_chunks - 1:
|
||||
# We've yielded the last chunk, exit early
|
||||
return
|
||||
return
|
||||
except grpc.RpcError as e:
|
||||
if self._can_reconnect(e):
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
raise
|
||||
except ValueError:
|
||||
# Trying to use the stub on a cancelled channel will raise
|
||||
# ValueError. This should only happen when the data client
|
||||
# is attempting to reset the connection -- sleep and try
|
||||
# again.
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
raise ConnectionError("Client is shutting down.")
|
||||
|
||||
def _add_ids_to_metadata(self, metadata: Any):
|
||||
"""
|
||||
Adds a unique req_id and the current thread's identifier to the
|
||||
|
@ -399,18 +448,33 @@ class Worker:
|
|||
|
||||
def _get(self, ref: List[ClientObjectRef], timeout: float):
|
||||
req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout)
|
||||
data = bytearray()
|
||||
try:
|
||||
resp = self._call_stub("GetObject", req, metadata=self.metadata)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e)
|
||||
if not resp.valid:
|
||||
resp = self._get_object_iterator(req, metadata=self.metadata)
|
||||
for chunk in resp:
|
||||
if not chunk.valid:
|
||||
try:
|
||||
err = cloudpickle.loads(resp.error)
|
||||
err = cloudpickle.loads(chunk.error)
|
||||
except (pickle.UnpicklingError, TypeError):
|
||||
logger.exception("Failed to deserialize {}".format(resp.error))
|
||||
logger.exception("Failed to deserialize {}".format(chunk.error))
|
||||
raise
|
||||
raise err
|
||||
return loads_from_server(resp.data)
|
||||
if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
|
||||
"client_object_transfer_size_warning"
|
||||
):
|
||||
size_gb = chunk.total_size / 2 ** 30
|
||||
warnings.warn(
|
||||
"Ray Client is attempting to retrieve a "
|
||||
f"{size_gb:.2f} GiB object over the network, which may "
|
||||
"be slow. Consider serializing the object to a file "
|
||||
"and using S3 or rsync instead.",
|
||||
UserWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
data.extend(chunk.data)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e)
|
||||
return loads_from_server(data)
|
||||
|
||||
def put(self, val, *, client_ref_id: bytes = None):
|
||||
if isinstance(val, ClientObjectRef):
|
||||
|
|
|
@ -128,6 +128,9 @@ message GetRequest {
|
|||
float timeout = 2;
|
||||
// Whether to schedule this as a callback on the server side.
|
||||
bool asynchronous = 3;
|
||||
// The chunk_id to start retrieving data from, in case the request is interrupted
|
||||
// after partial retrieval by a disconnect
|
||||
int32 start_chunk_id = 5;
|
||||
|
||||
// Deprecated fields.
|
||||
bytes id = 1 [deprecated = true];
|
||||
|
@ -141,6 +144,12 @@ message GetResponse {
|
|||
bytes data = 2;
|
||||
// An error blob (for example, an exception) on failure.
|
||||
bytes error = 3;
|
||||
// Identifies which chunk the data belongs to
|
||||
int32 chunk_id = 4;
|
||||
// Total number of chunks
|
||||
int32 total_chunks = 5;
|
||||
// Total size in bytes of the data being retrieved
|
||||
uint64 total_size = 6;
|
||||
}
|
||||
|
||||
// Waits for data to be ready on the server, with a timeout.
|
||||
|
@ -293,8 +302,7 @@ service RayletDriver {
|
|||
}
|
||||
rpc PrepRuntimeEnv(PrepRuntimeEnvRequest) returns (PrepRuntimeEnvResponse) {
|
||||
}
|
||||
rpc GetObject(GetRequest) returns (GetResponse) {
|
||||
}
|
||||
rpc GetObject(GetRequest) returns (stream GetResponse) {}
|
||||
rpc PutObject(PutRequest) returns (PutResponse) {
|
||||
}
|
||||
rpc WaitObject(WaitRequest) returns (WaitResponse) {
|
||||
|
|
Loading…
Add table
Reference in a new issue