Revert "Revert "[Client] chunked get requests (#22455)"" (#23261)

* revert revertchunkedgets

* exit early if all chunks received, tighter exception handler for stream in proxy
This commit is contained in:
Chris K. W 2022-03-17 16:24:30 -07:00 committed by GitHub
parent f74ad24901
commit 6416c65505
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 303 additions and 35 deletions

View file

@ -1,4 +1,5 @@
from concurrent import futures
import asyncio
import contextlib
import os
import threading
@ -135,7 +136,8 @@ class MiddlemanRayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
context.set_code(e.code())
context.set_details(e.details())
raise
if self.on_response:
if self.on_response and method != "GetObject":
# GetObject streams response, handle on_response separately
self.on_response(response)
return response
@ -169,7 +171,10 @@ class MiddlemanRayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
return self._call_inner_function(req, context, "Terminate")
def GetObject(self, request, context=None):
return self._call_inner_function(request, context, "GetObject")
for response in self._call_inner_function(request, context, "GetObject"):
if self.on_response:
self.on_response(response)
yield response
def PutObject(
self, request: ray_client_pb2.PutRequest, context=None
@ -334,6 +339,73 @@ def test_disconnect_during_get():
disconnect_thread.join()
def test_disconnects_during_large_get():
    """
    Inject repeated failures while a large (multi-chunk) object is being
    fetched with a synchronous ``ray.get`` and check the full result
    still arrives.
    """
    responses_seen = 0
    started = False

    def fail_every_three(_):
        # Once injection is armed, blow up on every third invocation.
        nonlocal responses_seen
        if started:
            responses_seen += 1
            if responses_seen % 3 == 0:
                raise RuntimeError

    @ray.remote
    def large_result():
        # A 1024x1024x128 float64 matrix is 1024 MiB. At a 64MiB chunk
        # size that is at least 16 chunks, and a failure is injected every
        # 3 chunks, so the transfer can only complete if each retry
        # resumes from the last received chunk rather than starting over.
        return np.random.random((1024, 1024, 128))

    expected_shape = (1024, 1024, 128)
    with start_middleman_server(on_task_response=fail_every_three):
        # Arm the failure injection only after the middleman is up.
        started = True
        ref = large_result.remote()
        result = ray.get(ref)
        assert result.shape == expected_shape
def test_disconnects_during_large_async_get():
    """
    Inject repeated failures while a large (multi-chunk) object is being
    fetched by awaiting its reference and check the full result still
    arrives.
    """
    responses_seen = 0
    started = False

    def fail_every_three(_):
        # Once injection is armed, blow up on every third invocation.
        nonlocal responses_seen
        if started:
            responses_seen += 1
            if responses_seen % 3 == 0:
                raise RuntimeError

    @ray.remote
    def large_result():
        # A 1024x1024x128 float64 matrix is 1024 MiB. At a 64MiB chunk
        # size that is at least 16 chunks, and a failure is injected every
        # 3 chunks, so the transfer can only complete if each retry
        # resumes from the last received chunk rather than starting over.
        return np.random.random((1024, 1024, 128))

    expected_shape = (1024, 1024, 128)
    with start_middleman_server(on_data_response=fail_every_three):
        # Arm the failure injection only after the middleman is up.
        started = True

        async def get_large_result():
            return await large_result.remote()

        pending = get_large_result()
        result = asyncio.get_event_loop().run_until_complete(pending)
        assert result.shape == expected_shape
def test_disconnect_during_large_put():
"""
Disconnect during a large (multi-chunk) put.

View file

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2022-02-22"
CURRENT_PROTOCOL_VERSION = "2022-03-16"
class _ClientContext:

View file

@ -172,6 +172,8 @@ class ClientObjectRef(raylet.ObjectRef):
if isinstance(resp, Exception):
data = resp
elif isinstance(resp, bytearray):
data = loads_from_server(resp)
else:
obj = resp.get
data = None

View file

@ -6,7 +6,6 @@ import logging
import queue
import threading
import warnings
import grpc
from collections import OrderedDict
@ -69,6 +68,83 @@ def chunk_put(req: ray_client_pb2.DataRequest):
yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)
class ChunkCollector:
    """
    Accumulates the chunks of an async get request and hands the result to
    a callback.

    Each response is fed in through ``__call__``. The wrapped callback is
    invoked exactly when the transfer ends: with the assembled bytearray
    once the final chunk arrives, or with the failing response/exception if
    something goes wrong along the way.

    Synchronous gets do not use this class (they talk to the raylet
    servicer directly rather than going through the datapath).

    ``__call__`` returns True once the wrapped callback has been called.
    """

    def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
        # The request that started the transfer. Its get.start_chunk_id is
        # advanced as chunks arrive so that a retry does not ask again for
        # chunks that were already received.
        self.request = request
        # Invoked once the transfer completes or fails
        self.callback = callback
        # Id of the most recent in-order chunk, -1 before the first one
        self.last_seen_chunk = -1
        # Payload bytes received so far
        self.data = bytearray()

    def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
        # Failures (a raised exception or an invalid response) are passed
        # through to the callback untouched and end the transfer.
        if isinstance(response, Exception) or not response.get.valid:
            self.callback(response)
            return True

        payload = response.get
        if payload.total_size > OBJECT_TRANSFER_WARNING_SIZE:
            if log_once("client_object_transfer_size_warning"):
                size_gb = payload.total_size / 2 ** 30
                warnings.warn(
                    "Ray Client is attempting to retrieve a "
                    f"{size_gb:.2f} GiB object over the network, which may "
                    "be slow. Consider serializing the object to a file and "
                    "using rsync or S3 instead.",
                    UserWarning,
                )

        chunk_id = payload.chunk_id
        if chunk_id > self.last_seen_chunk + 1:
            # A chunk was skipped. This shouldn't happen in practice since
            # grpc guarantees that chunks will arrive in order.
            msg = (
                f"Received chunk {chunk_id} when we expected "
                f"{self.last_seen_chunk + 1} for request {response.req_id}"
            )
            logger.warning(msg)
            self.callback(RuntimeError(msg))
            return True

        if chunk_id == self.last_seen_chunk + 1:
            # The next chunk in sequence: store it and move the resume
            # point forward, so a reconnect restarts at the first chunk
            # that has not been seen yet.
            self.data.extend(payload.data)
            self.last_seen_chunk = chunk_id
            self.request.get.start_chunk_id = self.last_seen_chunk + 1
        else:
            # A chunk we have already seen. Its bytes are already part of
            # self.data, so it is only logged.
            logger.debug(
                f"Received a repeated chunk {chunk_id} "
                f"from request {response.req_id}."
            )

        finished = payload.chunk_id == payload.total_chunks - 1
        if finished:
            self.callback(self.data)
        return finished
class DataClient:
def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
"""Initializes a thread-safe datapath over a Ray Client gRPC channel.
@ -177,20 +253,25 @@ class DataClient:
logger.debug(f"Got unawaited response {response}")
return
if response.req_id in self.asyncio_waiting_data:
can_remove = True
try:
# NOTE: calling self.asyncio_waiting_data.pop() results
# in the destructor of ClientObjectRef running, which
# calls ReleaseObject(). So self.asyncio_waiting_data
# is accessed without holding self.lock. Holding the
# lock shouldn't be necessary either.
callback = self.asyncio_waiting_data.pop(response.req_id)
if callback:
callback = self.asyncio_waiting_data[response.req_id]
if isinstance(callback, ChunkCollector):
can_remove = callback(response)
elif callback:
callback(response)
if can_remove:
# NOTE: calling del self.asyncio_waiting_data results
# in the destructor of ClientObjectRef running, which
# calls ReleaseObject(). So self.asyncio_waiting_data
# is accessed without holding self.lock. Holding the
# lock shouldn't be necessary either.
del self.asyncio_waiting_data[response.req_id]
except Exception:
logger.exception("Callback error:")
with self.lock:
# Update outstanding requests
if response.req_id in self.outstanding_requests:
if response.req_id in self.outstanding_requests and can_remove:
del self.outstanding_requests[response.req_id]
# Acknowledge response
self._acknowledge(response.req_id)
@ -428,7 +509,8 @@ class DataClient:
datareq = ray_client_pb2.DataRequest(
get=request,
)
self._async_send(datareq, callback)
collector = ChunkCollector(callback=callback, request=datareq)
self._async_send(datareq, collector)
# TODO: convert PutObject to async
def PutObject(

View file

@ -525,7 +525,12 @@ class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
return self._call_inner_function(req, context, "Terminate")
def GetObject(self, request, context=None):
return self._call_inner_function(request, context, "GetObject")
try:
yield from self._call_inner_function(request, context, "GetObject")
except Exception as e:
# Error while iterating over response from GetObject stream
logger.exception("Proxying call to GetObject failed!")
_propagate_error_in_context(e, context)
def PutObject(
self, request: ray_client_pb2.PutRequest, context=None

View file

@ -5,6 +5,7 @@ import grpc
import base64
from collections import defaultdict
import functools
import math
import queue
import pickle
@ -28,6 +29,7 @@ from ray.util.client.common import (
ClientServerHandle,
GRPC_OPTIONS,
CLIENT_SERVER_MAX_THREADS,
OBJECT_TRANSFER_CHUNK_SIZE,
ResponseCache,
)
from ray import ray_constants
@ -379,20 +381,38 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
with disable_client_hook():
def send_get_response(result: Any) -> None:
"""Pushes a GetResponse to the main DataPath loop to send
"""Pushes GetResponses to the main DataPath loop to send
to the client. This is called when the object is ready
on the server side."""
try:
serialized = dumps_from_server(result, client_id, self)
get_resp = ray_client_pb2.GetResponse(
valid=True, data=serialized
total_size = len(serialized)
assert total_size > 0, "Serialized object cannot be zero bytes"
total_chunks = math.ceil(
total_size / OBJECT_TRANSFER_CHUNK_SIZE
)
for chunk_id in range(request.start_chunk_id, total_chunks):
start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
end = min(
total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE
)
get_resp = ray_client_pb2.GetResponse(
valid=True,
data=serialized[start:end],
chunk_id=chunk_id,
total_chunks=total_chunks,
total_size=total_size,
)
chunk_resp = ray_client_pb2.DataResponse(
get=get_resp, req_id=req_id
)
result_queue.put(chunk_resp)
except Exception as exc:
get_resp = ray_client_pb2.GetResponse(
valid=False, error=cloudpickle.dumps(exc)
)
resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id)
result_queue.put(resp)
resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id)
result_queue.put(resp)
ref._on_completed(send_get_response)
return None
@ -404,13 +424,14 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
metadata = {k: v for k, v in context.invocation_metadata()}
client_id = metadata.get("client_id")
if client_id is None:
return ray_client_pb2.GetResponse(
yield ray_client_pb2.GetResponse(
valid=False,
error=cloudpickle.dumps(
ValueError("client_id is not specified in request metadata")
),
)
return self._get_object(request, client_id)
else:
yield from self._get_object(request, client_id)
def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str):
objectrefs = []
@ -419,7 +440,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
if ref:
objectrefs.append(ref)
else:
return ray_client_pb2.GetResponse(
yield ray_client_pb2.GetResponse(
valid=False,
error=cloudpickle.dumps(
ValueError(
@ -428,14 +449,28 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
)
),
)
return
try:
logger.debug("get: %s" % objectrefs)
with disable_client_hook():
items = ray.get(objectrefs, timeout=request.timeout)
except Exception as e:
return ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
yield ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
return
serialized = dumps_from_server(items, client_id, self)
return ray_client_pb2.GetResponse(valid=True, data=serialized)
total_size = len(serialized)
assert total_size > 0, "Serialized object cannot be zero bytes"
total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
for chunk_id in range(request.start_chunk_id, total_chunks):
start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
yield ray_client_pb2.GetResponse(
valid=True,
data=serialized[start:end],
chunk_id=chunk_id,
total_chunks=total_chunks,
total_size=total_size,
)
def PutObject(
self, request: ray_client_pb2.PutRequest, context=None

View file

@ -41,6 +41,7 @@ from ray.util.client.common import (
GRPC_OPTIONS,
GRPC_UNRECOVERABLE_ERRORS,
INT32_MAX,
OBJECT_TRANSFER_WARNING_SIZE,
)
from ray.util.client.dataclient import DataClient
from ray.util.client.logsclient import LogstreamClient
@ -309,6 +310,54 @@ class Worker:
continue
raise ConnectionError("Client is shutting down.")
def _get_object_iterator(
self, req: ray_client_pb2.GetRequest, *args, **kwargs
) -> Any:
"""
Calls the stub for GetObject on the underlying server stub. If a
recoverable error occurs while streaming the response, attempts
to retry the get starting from the first chunk that hasn't been
received.
"""
last_seen_chunk = -1
while not self._in_shutdown:
# If we disconnect partway through, restart the get request
# at the first chunk we haven't seen
req.start_chunk_id = last_seen_chunk + 1
try:
for chunk in self.server.GetObject(req, *args, **kwargs):
if chunk.chunk_id <= last_seen_chunk:
# Ignore repeat chunks
logger.debug(
f"Received a repeated chunk {chunk.chunk_id} "
f"from request {req.req_id}."
)
continue
if last_seen_chunk + 1 != chunk.chunk_id:
raise RuntimeError(
f"Received chunk {chunk.chunk_id} when we expected "
f"{self.last_seen_chunk + 1}"
)
last_seen_chunk = chunk.chunk_id
yield chunk
if last_seen_chunk == chunk.total_chunks - 1:
# We've yielded the last chunk, exit early
return
return
except grpc.RpcError as e:
if self._can_reconnect(e):
time.sleep(0.5)
continue
raise
except ValueError:
# Trying to use the stub on a cancelled channel will raise
# ValueError. This should only happen when the data client
# is attempting to reset the connection -- sleep and try
# again.
time.sleep(0.5)
continue
raise ConnectionError("Client is shutting down.")
def _add_ids_to_metadata(self, metadata: Any):
"""
Adds a unique req_id and the current thread's identifier to the
@ -399,18 +448,33 @@ class Worker:
def _get(self, ref: List[ClientObjectRef], timeout: float):
req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout)
data = bytearray()
try:
resp = self._call_stub("GetObject", req, metadata=self.metadata)
resp = self._get_object_iterator(req, metadata=self.metadata)
for chunk in resp:
if not chunk.valid:
try:
err = cloudpickle.loads(chunk.error)
except (pickle.UnpicklingError, TypeError):
logger.exception("Failed to deserialize {}".format(chunk.error))
raise
raise err
if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
"client_object_transfer_size_warning"
):
size_gb = chunk.total_size / 2 ** 30
warnings.warn(
"Ray Client is attempting to retrieve a "
f"{size_gb:.2f} GiB object over the network, which may "
"be slow. Consider serializing the object to a file "
"and using S3 or rsync instead.",
UserWarning,
stacklevel=5,
)
data.extend(chunk.data)
except grpc.RpcError as e:
raise decode_exception(e)
if not resp.valid:
try:
err = cloudpickle.loads(resp.error)
except (pickle.UnpicklingError, TypeError):
logger.exception("Failed to deserialize {}".format(resp.error))
raise
raise err
return loads_from_server(resp.data)
return loads_from_server(data)
def put(self, val, *, client_ref_id: bytes = None):
if isinstance(val, ClientObjectRef):

View file

@ -128,6 +128,9 @@ message GetRequest {
float timeout = 2;
// Whether to schedule this as a callback on the server side.
bool asynchronous = 3;
// The chunk_id to start retrieving data from, in case the request is interrupted
// after partial retrieval by a disconnect
int32 start_chunk_id = 5;
// Deprecated fields.
bytes id = 1 [deprecated = true];
@ -141,6 +144,12 @@ message GetResponse {
bytes data = 2;
// An error blob (for example, an exception) on failure.
bytes error = 3;
// Identifies which chunk the data belongs to
int32 chunk_id = 4;
// Total number of chunks
int32 total_chunks = 5;
// Total size in bytes of the data being retrieved
uint64 total_size = 6;
}
// Waits for data to be ready on the server, with a timeout.
@ -293,8 +302,7 @@ service RayletDriver {
}
rpc PrepRuntimeEnv(PrepRuntimeEnvRequest) returns (PrepRuntimeEnvResponse) {
}
rpc GetObject(GetRequest) returns (GetResponse) {
}
rpc GetObject(GetRequest) returns (stream GetResponse) {}
rpc PutObject(PutRequest) returns (PutResponse) {
}
rpc WaitObject(WaitRequest) returns (WaitResponse) {