[client] Chunk PutRequests (#22327)

Why are these changes needed?
Data from PutRequests is chunked into 64 MiB messages over the data stream to stay under gRPC's 2 GiB message size limit. This allows users to transfer objects larger than 2 GiB over the network.
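For a rough sense of the numbers, a minimal sketch of the chunk math (the chunk size constant is the one added to common.py in the diff below):

import math

OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2 ** 20  # 64 MiB per message

# A 2 GiB pickled object is split into 32 chunks, each well under the
# gRPC message size limit.
total_size = 2 * 2 ** 30
total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
assert total_chunks == 32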

Proto changes
Put requests now have three new fields: chunk_id, identifying which chunk the data belongs to; total_chunks, the total number of chunks in the object; and total_size, the total size of the object in bytes (used to warn about very large transfers).
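For illustration, this is roughly how one chunk of a multi-chunk put is filled in. This is only a sketch based on the fields added to ray_client.proto below; serialized_obj, start, end, and the empty client_ref_id are placeholder values.

import math
import ray.core.generated.ray_client_pb2 as ray_client_pb2

serialized_obj = b"\x00" * (200 * 2 ** 20)        # stand-in for a pickled object
chunk_size = 64 * 2 ** 20
total_chunks = math.ceil(len(serialized_obj) / chunk_size)  # 4 chunks for 200 MiB
start = 3 * chunk_size
end = min(len(serialized_obj), 4 * chunk_size)

chunk = ray_client_pb2.PutRequest(
    client_ref_id=b"",                  # existing field, unchanged
    data=serialized_obj[start:end],     # this chunk's slice (at most 64 MiB)
    chunk_id=3,                         # 0-based index of this chunk
    total_chunks=total_chunks,          # number of chunks in the whole object
    total_size=len(serialized_obj),     # full object size in bytes
)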

PutObject is still unary-unary. The dataservicer reassembles the chunks before passing the combined data to the underlying RayletServicer.

Dataclient changes
When a put request is inserted into the request queue, self._requests chunks it lazily. Laziness matters here: materializing all of the chunks onto the request queue up front would roughly double the memory needed to handle a large request. It also guarantees that the chunks of a given put request are contiguous in the stream.
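A minimal sketch of the memory argument (illustrative only; the real implementation is chunk_put in dataclient.py below):

def iter_chunks(data: bytes, chunk_size: int = 64 * 2 ** 20):
    # Each bytes slice is a copy, so yielding slices one at a time keeps the
    # peak overhead at roughly one chunk instead of a second full copy of
    # the object.
    for start in range(0, len(data), chunk_size):
        yield data[start:start + chunk_size]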

Dataservicer changes
The dataservicer now maintains state to track received chunks. Once all chunks for a put request have been received, the combined data is passed to the RayletServicer.
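Schematically, the reassembly looks like the following. This is a simplified sketch; the real logic is the ChunkCollector helper in dataservicer.py below, and put_chunks, raylet_servicer, and client_id here are placeholder names.

buffer = bytearray()
for req in put_chunks:  # chunks of one put request, received in order
    buffer.extend(req.put.data)
    if req.put.chunk_id + 1 == req.put.total_chunks:
        # Last chunk: hand the reassembled payload to the RayletServicer
        resp = raylet_servicer._put_object(buffer, req.put.client_ref_id, client_id)
        buffer = bytearray()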
Chris K. W 2022-02-23 08:21:25 -08:00 committed by GitHub
parent a20748f83a
commit 3371e78d2e
8 changed files with 198 additions and 17 deletions


@@ -7,6 +7,7 @@ import queue
import threading
import _thread
from unittest.mock import patch
import numpy as np

import ray.util.client.server.server as ray_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
@@ -780,5 +781,20 @@ def test_object_ref_release(call_ray_start):
    assert all(v > 0 for v in ref_cnt.values())


def test_empty_objects(ray_start_regular_shared):
    """
    Tests that client works with "empty" objects. Sanity check, since put requests
    will fail if the serialized version of an object consists of zero bytes.
    """
    objects = [0, b"", "", [], np.array(()), {}, set(), None]
    with ray_start_client_server() as ray:
        for obj in objects:
            ref = ray.put(obj)
            if isinstance(obj, np.ndarray):
                assert np.array_equal(ray.get(ref), obj)
            else:
                assert ray.get(ref) == obj


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))


@@ -4,6 +4,7 @@ import os
import threading
import sys
import grpc
import numpy as np
import time
import random
@@ -33,7 +34,9 @@ class MiddlemanDataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    errors between a client and server pair.
    """

    def __init__(
        self, on_response: Optional[Hook] = None, on_request: Optional[Hook] = None
    ):
        """
        Args:
            on_response: Optional hook to inject errors before sending back a
@@ -41,14 +44,21 @@ class MiddlemanDataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
        """
        self.stub = None
        self.on_response = on_response
        self.on_request = on_request

    def set_channel(self, channel: grpc.Channel) -> None:
        self.stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)

    def _requests(self, request_iterator):
        for req in request_iterator:
            if self.on_request:
                self.on_request(req)
            yield req

    def Datapath(self, request_iterator, context):
        try:
            for response in self.stub.Datapath(
                self._requests(request_iterator), metadata=context.invocation_metadata()
            ):
                if self.on_response:
                    self.on_response(response)
@@ -189,6 +199,7 @@ class MiddlemanServer:
        listen_addr: str,
        real_addr,
        on_log_response: Optional[Hook] = None,
        on_data_request: Optional[Hook] = None,
        on_data_response: Optional[Hook] = None,
        on_task_request: Optional[Hook] = None,
        on_task_response: Optional[Hook] = None,
@@ -215,7 +226,9 @@ class MiddlemanServer:
        self.task_servicer = MiddlemanRayletServicer(
            on_response=on_task_response, on_request=on_task_request
        )
        self.data_servicer = MiddlemanDataServicer(
            on_response=on_data_response, on_request=on_data_request
        )
        self.logs_servicer = MiddlemanLogServicer(on_response=on_log_response)
        ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
            self.task_servicer, self.server
@@ -253,6 +266,7 @@ class MiddlemanServer:
@contextlib.contextmanager
def start_middleman_server(
    on_log_response=None,
    on_data_request=None,
    on_data_response=None,
    on_task_request=None,
    on_task_response=None,
@@ -269,9 +283,10 @@ def start_middleman_server(
        listen_addr="localhost:10011",
        real_addr="localhost:50051",
        on_log_response=on_log_response,
        on_data_request=on_data_request,
        on_data_response=on_data_response,
        on_task_request=on_task_request,
        on_task_response=on_task_response,
    )
    middleman.start()
    ray.init("ray://localhost:10011")
@@ -319,6 +334,30 @@ def test_disconnect_during_get():
    disconnect_thread.join()


def test_disconnect_during_large_put():
    """
    Disconnect during a large (multi-chunk) put.
    """
    i = 0
    started = False

    def fail_halfway(_):
        # Inject an error halfway through the object transfer
        nonlocal i, started
        if not started:
            return
        i += 1
        if i == 8:
            raise RuntimeError

    with start_middleman_server(on_data_request=fail_halfway):
        started = True
        objref = ray.put(np.random.random((1024, 1024, 128)))
        assert i > 8  # Check that the failure was injected
        result = ray.get(objref)
        assert result.shape == (1024, 1024, 128)


def test_valid_actor_state():
    """
    Repeatedly inject errors in the middle of mutating actor calls. Check


@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2022-02-22"

class _ClientContext:


@@ -84,6 +84,12 @@ GRPC_OPTIONS = [
CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))

# Large objects are chunked into 64 MiB messages
OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2 ** 20

# Warn the user if the object being transferred is larger than 2 GiB
OBJECT_TRANSFER_WARNING_SIZE = 2 * 2 ** 30


class ClientObjectRef(raylet.ObjectRef):
    def __init__(self, id: Union[bytes, Future]):


@@ -1,9 +1,12 @@
"""This file implements a threaded stream controller to abstract a data stream
back to the ray clientserver.
"""
import math
import logging
import queue
import threading
import warnings
import grpc

from collections import OrderedDict
@@ -11,7 +14,12 @@ from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client.common import (
    INT32_MAX,
    OBJECT_TRANSFER_CHUNK_SIZE,
    OBJECT_TRANSFER_WARNING_SIZE,
)
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.util.client.worker import Worker
@@ -24,6 +32,43 @@ ResponseCallable = Callable[[Union[ray_client_pb2.DataResponse, Exception]], None]
ACKNOWLEDGE_BATCH_SIZE = 32


def chunk_put(req: ray_client_pb2.DataRequest):
    """
    Chunks a put request. Doing this lazily is important for large objects,
    since taking slices of bytes objects does a copy. This means if we
    immediately materialized every chunk of a large object and inserted them
    into the result_queue, we would effectively double the memory needed
    on the client to handle the put.
    """
    total_size = len(req.put.data)
    assert total_size > 0, "Cannot chunk object with missing data"
    if total_size >= OBJECT_TRANSFER_WARNING_SIZE and log_once(
        "client_object_put_size_warning"
    ):
        size_gb = total_size / 2 ** 30
        warnings.warn(
            "Ray Client is attempting to send a "
            f"{size_gb:.2f} GiB object over the network, which may "
            "be slow. Consider serializing the object and using a remote "
            "URI to transfer via S3 or Google Cloud Storage instead. "
            "Documentation for doing this can be found here: "
            "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris",
            UserWarning,
        )
    total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
    for chunk_id in range(0, total_chunks):
        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
        end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
        chunk = ray_client_pb2.PutRequest(
            client_ref_id=req.put.client_ref_id,
            data=req.put.data[start:end],
            chunk_id=chunk_id,
            total_chunks=total_chunks,
            total_size=total_size,
        )
        yield ray_client_pb2.DataRequest(req_id=req.req_id, put=chunk)


class DataClient:
    def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.
@@ -81,6 +126,19 @@ class DataClient:
            daemon=True,
        )

    # A helper that takes requests from the queue. If the request wraps a
    # PutRequest, lazily chunks and yields the request. Otherwise, yields
    # the request directly.
    def _requests(self):
        while True:
            req = self.request_queue.get()
            if req is None:
                # Stop when client signals shutdown.
                return
            if req.WhichOneof("type") == "put":
                yield from chunk_put(req)
            else:
                yield req
    def _data_main(self) -> None:
        reconnecting = False
        try:
@@ -90,7 +148,7 @@ class DataClient:
            )
            metadata = self._metadata + [("reconnecting", str(reconnecting))]
            resp_stream = stub.Datapath(
                self._requests(),
                metadata=metadata,
                wait_for_ready=True,
            )


@@ -53,10 +53,14 @@ def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
    - acks: Repeating acks is idempotent
    - clean up requests: Also idempotent, and client has likely already
      wrapped up the data connection by this point.
    - puts: We should only cache when we receive the final chunk, since
      any earlier chunks won't generate a response
    """
    req_type = req.WhichOneof("type")
    if req_type == "get" and req.get.asynchronous:
        return False
    if req_type == "put":
        return req.put.chunk_id == req.put.total_chunks - 1
    return req_type not in ("acknowledge", "connection_cleanup")
@@ -80,6 +84,44 @@ def fill_queue(
    output_queue.put(None)


class ChunkCollector:
    """
    Helper class for collecting chunks from PutObject calls
    """

    def __init__(self):
        self.curr_req_id = None
        self.last_seen_chunk_id = -1
        self.data = bytearray()

    def add_chunk(self, req: ray_client_pb2.DataRequest):
        if self.curr_req_id is not None and self.curr_req_id != req.req_id:
            raise RuntimeError(
                "Expected to receive a chunk from request with id "
                f"{self.curr_req_id}, but found {req.req_id} instead."
            )
        self.curr_req_id = req.req_id
        chunk = req.put
        next_chunk = self.last_seen_chunk_id + 1
        if chunk.chunk_id < next_chunk:
            # Repeated chunk, ignore
            return
        if chunk.chunk_id > next_chunk:
            raise RuntimeError(
                f"A chunk {chunk.chunk_id} of request {req.req_id} was "
                "received out of order."
            )
        elif chunk.chunk_id == self.last_seen_chunk_id + 1:
            self.data.extend(chunk.data)
            self.last_seen_chunk_id = chunk.chunk_id
        return chunk.chunk_id + 1 == chunk.total_chunks

    def reset(self):
        self.curr_req_id = None
        self.last_seen_chunk_id = -1
        self.data = bytearray()


class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service
@@ -95,6 +137,9 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
        )
        # stopped event, useful for signals that the server is shut down
        self.stopped = Event()
        # Helper for collecting chunks from PutObject calls. Assumes that
        # put requests from different objects aren't interleaved.
        self.chunk_collector = ChunkCollector()

    def Datapath(self, request_iterator, context):
        start_time = time.time()
@@ -168,7 +213,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
                get_resp = self.basic_service._get_object(req.get, client_id)
                resp = ray_client_pb2.DataResponse(get=get_resp)
            elif req_type == "put":
                if not self.chunk_collector.add_chunk(req):
                    # Put request still in progress
                    continue
                put_resp = self.basic_service._put_object(
                    self.chunk_collector.data, req.put.client_ref_id, client_id
                )
                self.chunk_collector.reset()
                resp = ray_client_pb2.DataResponse(put=put_resp)
            elif req_type == "release":
                released = []


@@ -14,6 +14,7 @@ from typing import Dict
from typing import Set
from typing import Optional
from typing import Callable
from typing import Union
from ray import cloudpickle
from ray.job_config import JobConfig
import ray
@@ -440,21 +441,27 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
        self, request: ray_client_pb2.PutRequest, context=None
    ) -> ray_client_pb2.PutResponse:
        """gRPC entrypoint for unary PutObject"""
        return self._put_object(request.data, request.client_ref_id, "", context)

    def _put_object(
        self,
        data: Union[bytes, bytearray],
        client_ref_id: bytes,
        client_id: str,
        context=None,
    ):
        """Put an object in the cluster with ray.put() via gRPC.

        Args:
            data: Pickled data. Can either be bytearray if this is called
                from the dataservicer, or bytes if called from PutObject.
            client_ref_id: The id associated with this object on the client.
            client_id: The client who owns this data, for tracking when to
                delete this reference.
            context: gRPC context.
        """
        try:
            obj = loads_from_client(data, self)
            with disable_client_hook():
                objectref = ray.put(obj)
        except Exception as e:
@@ -464,10 +471,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
            )
        self.object_refs[client_id][objectref.binary()] = objectref
        if len(client_ref_id) > 0:
            self.client_side_ref_map[client_id][client_ref_id] = objectref.binary()
        logger.debug("put: %s" % objectref)
        return ray_client_pb2.PutResponse(id=objectref.binary(), valid=True)


@@ -101,6 +101,12 @@ message PutRequest {
  //
  // Empty if no late binding is possible, as in a normal put().
  bytes client_ref_id = 2;
  // Identifies which chunk the data belongs to
  int32 chunk_id = 3;
  // Total number of chunks
  int32 total_chunks = 4;
  // Total size in bytes of the data being put
  int64 total_size = 5;
}

message PutResponse {