From 4db696d36557b551970418385bceaf51210df54f Mon Sep 17 00:00:00 2001 From: Ian Rodney Date: Tue, 27 Apr 2021 08:41:10 -0700 Subject: [PATCH] [Client] Asyncio Client, Sync gRPC Server (#15488) --- python/ray/tests/BUILD | 1 + python/ray/tests/test_asyncio.py | 6 +- .../tests/test_client_library_integration.py | 18 +++ python/ray/util/client/__init__.py | 2 +- python/ray/util/client/api.py | 6 +- python/ray/util/client/common.py | 50 +++++- python/ray/util/client/dataclient.py | 35 ++++- python/ray/util/client/server/dataservicer.py | 148 ++++++++++++------ python/ray/util/client/server/server.py | 43 +++++ python/ray/util/client/worker.py | 7 + src/ray/protobuf/ray_client.proto | 2 + 11 files changed, 260 insertions(+), 58 deletions(-) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 3f81f785a..0f400d7ff 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -212,6 +212,7 @@ py_test_module_list( "test_basic.py", "test_basic_2.py", "test_basic_3.py", + "test_asyncio.py" ], size = "medium", extra_srcs = SRCS, diff --git a/python/ray/tests/test_asyncio.py b/python/ray/tests/test_asyncio.py index c43b7e491..107e9b358 100644 --- a/python/ray/tests/test_asyncio.py +++ b/python/ray/tests/test_asyncio.py @@ -155,7 +155,10 @@ async def test_asyncio_get(ray_start_regular_shared, event_loop): with pytest.raises(ray.exceptions.RayTaskError): await actor.throw_error.remote().as_future() - kill_actor_and_wait_for_failure(actor) + # Wrap in Remote Function to work with Ray client. + kill_actor_ref = ray.remote(kill_actor_and_wait_for_failure).remote(actor) + ray.get(kill_actor_ref) + with pytest.raises(ray.exceptions.RayActorError): await actor.echo.remote(1) @@ -256,6 +259,7 @@ def test_async_function_errored(ray_start_regular_shared): ray.get(ref) +@pytest.mark.asyncio async def test_async_obj_unhandled_errors(ray_start_regular_shared): @ray.remote def f(): diff --git a/python/ray/tests/test_client_library_integration.py b/python/ray/tests/test_client_library_integration.py index 6560d8a5c..0ddaa1d91 100644 --- a/python/ray/tests/test_client_library_integration.py +++ b/python/ray/tests/test_client_library_integration.py @@ -22,5 +22,23 @@ def test_rllib_integration(ray_start_regular_shared): rock_paper_scissors_multiagent.main() +@pytest.mark.asyncio +async def test_serve_handle(ray_start_regular_shared): + with ray_start_client_server() as ray: + from ray import serve + _explicitly_enable_client_mode() + serve.start(detached=True) + + def hello(request): + return "hello" + + serve.create_backend("my_backend", hello, config={"num_replicas": 1}) + serve.create_endpoint( + "my_endpoint", backend="my_backend", route="/hello") + handle = serve.get_handle("my_endpoint") + assert ray.get(handle.remote()) == "hello" + assert await handle.remote() == "hello" + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index fb6ab43ef..ee64b9e38 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) # This version string is incremented to indicate breaking changes in the # protocol that require upgrading the client version. -CURRENT_PROTOCOL_VERSION = "2021-04-09" +CURRENT_PROTOCOL_VERSION = "2021-04-19" class RayAPIStub: diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py index ad5cc4880..3be9b3b18 100644 --- a/python/ray/util/client/api.py +++ b/python/ray/util/client/api.py @@ -6,7 +6,7 @@ import json import logging from ray.util.client.runtime_context import ClientWorkerPropertyAPI -from typing import Any, List, Optional, TYPE_CHECKING +from typing import Any, Callable, List, Optional, TYPE_CHECKING if TYPE_CHECKING: from ray.actor import ActorClass from ray.remote_function import RemoteFunction @@ -317,3 +317,7 @@ class ClientAPI: "available within Ray remote functions and is not yet " "implemented in the client API.".format(key)) return self.__getattribute__(key) + + def _register_callback(self, ref: "ClientObjectRef", + callback: Callable[["DataResponse"], None]) -> None: + self.worker.register_callback(ref, callback) diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py index 16ec11042..ae77b4c08 100644 --- a/python/ray/util/client/common.py +++ b/python/ray/util/client/common.py @@ -2,13 +2,16 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.util.client import ray from ray.util.client.options import validate_options -import uuid +import asyncio +import concurrent.futures import os +import uuid import inspect from ray.util.inspect import is_cython import json import threading from typing import Any +from typing import Callable from typing import List from typing import Dict from typing import Optional @@ -60,7 +63,50 @@ class ClientBaseRef: class ClientObjectRef(ClientBaseRef): - pass + def __await__(self): + return self.as_future().__await__() + + def as_future(self) -> asyncio.Future: + return asyncio.wrap_future(self.future()) + + def future(self) -> concurrent.futures.Future: + fut = concurrent.futures.Future() + + def set_value(data: Any) -> None: + """Schedules a callback to set the exception or result + in the Future.""" + + if isinstance(data, Exception): + fut.set_exception(data) + else: + fut.set_result(data) + + self._on_completed(set_value) + + # Prevent this object ref from being released. + fut.object_ref = self + return fut + + def _on_completed(self, py_callback: Callable[[Any], None]) -> None: + """Register a callback that will be called after Object is ready. + If the ObjectRef is already ready, the callback will be called soon. + The callback should take the result as the only argument. The result + can be an exception object in case of task error. + """ + from ray.util.client.client_pickler import loads_from_server + + def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None: + """Converts from a GetResponse proto to a python object.""" + obj = resp.get + data = None + if not obj.valid: + data = loads_from_server(resp.get.error) + else: + data = loads_from_server(resp.get.data) + + py_callback(data) + + ray._register_callback(self, deserialize_obj) class ClientActorRef(ClientBaseRef): diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index 9fcda555f..14192a4f8 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -6,8 +6,7 @@ import queue import threading import grpc -from typing import Any -from typing import Dict +from typing import Any, Callable, Dict, Optional import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc @@ -18,6 +17,8 @@ logger = logging.getLogger(__name__) # number of simultaneous in-flight requests. INT32_MAX = (2**31) - 1 +ResponseCallable = Callable[[ray_client_pb2.DataResponse], None] + class DataClient: def __init__(self, channel: "grpc._channel.Channel", client_id: str, @@ -35,6 +36,10 @@ class DataClient: self.ready_data: Dict[int, Any] = {} self.cv = threading.Condition() self.lock = threading.RLock() + + # NOTE: Dictionary insertion is guaranteed to complete before lookup + # and/or removal because of synchronization via the request_queue. + self.asyncio_waiting_data: Dict[int, ResponseCallable] = {} self._req_id = 0 self._client_id = client_id self._metadata = metadata @@ -66,9 +71,16 @@ class DataClient: # This is not being waited for. logger.debug(f"Got unawaited response {response}") continue - with self.cv: - self.ready_data[response.req_id] = response - self.cv.notify_all() + if response.req_id in self.asyncio_waiting_data: + callback = self.asyncio_waiting_data.pop(response.req_id) + try: + callback(response) + except Exception: + logger.exception("Callback error:") + else: + with self.cv: + self.ready_data[response.req_id] = response + self.cv.notify_all() except grpc.RpcError as e: with self.cv: self._in_shutdown = True @@ -112,9 +124,13 @@ class DataClient: del self.ready_data[req_id] return data - def _async_send(self, req: ray_client_pb2.DataRequest) -> None: + def _async_send(self, + req: ray_client_pb2.DataRequest, + callback: Optional[ResponseCallable] = None) -> None: req_id = self._next_id() req.req_id = req_id + if callback: + self.asyncio_waiting_data[req_id] = callback self.request_queue.put(req) def Init(self, request: ray_client_pb2.InitRequest, @@ -143,6 +159,13 @@ class DataClient: resp = self._blocking_send(datareq) return resp.get + def RegisterGetCallback(self, + request: ray_client_pb2.GetRequest, + callback: ResponseCallable, + context=None) -> None: + datareq = ray_client_pb2.DataRequest(get=request, ) + self._async_send(datareq, callback) + def PutObject(self, request: ray_client_pb2.PutRequest, context=None) -> ray_client_pb2.PutResponse: datareq = ray_client_pb2.DataRequest(put=request, ) diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index e4b5f90dc..550704986 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -1,10 +1,11 @@ import ray import logging import grpc +from queue import Queue import sys -from typing import TYPE_CHECKING -from threading import Lock +from typing import Any, Iterator, TYPE_CHECKING, Union +from threading import Lock, Thread import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc @@ -18,6 +19,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +QUEUE_JOIN_SECONDS = 5 + + +def fill_queue( + grpc_input_generator: Iterator[ray_client_pb2.DataRequest], + output_queue: + "Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]" +) -> None: + """ + Pushes incoming requests to a shared output_queue. + """ + try: + for req in grpc_input_generator: + output_queue.put(req) + except grpc.RpcError as e: + logger.debug("closing dataservicer reader thread " + f"grpc error reading request_iterator: {e}") + finally: + # Set the sentinel value for the output_queue + output_queue.put(None) + class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): def __init__(self, basic_service: "RayletServicer"): @@ -28,53 +50,75 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): def Datapath(self, request_iterator, context): metadata = {k: v for k, v in context.invocation_metadata()} client_id = metadata["client_id"] - accepted_connection = False if client_id == "": logger.error("Client connecting with no client_id") return logger.debug(f"New data connection from client {client_id}: ") + accepted_connection = self._init(client_id, context) + if not accepted_connection: + return try: - for req in request_iterator: + request_queue = Queue() + queue_filler_thread = Thread( + target=fill_queue, + daemon=True, + args=(request_iterator, request_queue)) + queue_filler_thread.start() + """For non `async get` requests, this loop yields immediately + For `async get` requests, this loop: + 1) does not yield, it just continues + 2) When the result is ready, it yields + """ + for req in iter(request_queue.get, None): + if isinstance(req, ray_client_pb2.DataResponse): + # Early shortcut if this is the result of an async get. + yield req + continue + + assert isinstance(req, ray_client_pb2.DataRequest) resp = None req_type = req.WhichOneof("type") if req_type == "init": - resp = self._init(req.init, client_id) - if resp is None: - context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) - return - logger.debug(f"Accepted data connection from {client_id}. " - f"Total clients: {self.num_clients}") - accepted_connection = True - else: - assert accepted_connection - if req_type == "get": + resp_init = self.basic_service.Init(req.init) + resp = ray_client_pb2.DataResponse(init=resp_init, ) + elif req_type == "get": + get_resp = None + if req.get.asynchronous: + get_resp = self.basic_service._async_get_object( + req.get, client_id, req.req_id, request_queue) + if get_resp is None: + # Skip sending a response for this request and + # continue to the next requst. The response for + # this request will be sent when the object is + # ready. + continue + else: get_resp = self.basic_service._get_object( req.get, client_id) - resp = ray_client_pb2.DataResponse(get=get_resp) - elif req_type == "put": - put_resp = self.basic_service._put_object( - req.put, client_id) - resp = ray_client_pb2.DataResponse(put=put_resp) - elif req_type == "release": - released = [] - for rel_id in req.release.ids: - rel = self.basic_service.release(client_id, rel_id) - released.append(rel) + resp = ray_client_pb2.DataResponse(get=get_resp) + elif req_type == "put": + put_resp = self.basic_service._put_object( + req.put, client_id) + resp = ray_client_pb2.DataResponse(put=put_resp) + elif req_type == "release": + released = [] + for rel_id in req.release.ids: + rel = self.basic_service.release(client_id, rel_id) + released.append(rel) + resp = ray_client_pb2.DataResponse( + release=ray_client_pb2.ReleaseResponse(ok=released)) + elif req_type == "connection_info": + resp = ray_client_pb2.DataResponse( + connection_info=self._build_connection_response()) + elif req_type == "prep_runtime_env": + with self.clients_lock: + resp_prep = self.basic_service.PrepRuntimeEnv( + req.prep_runtime_env) resp = ray_client_pb2.DataResponse( - release=ray_client_pb2.ReleaseResponse( - ok=released)) - elif req_type == "connection_info": - resp = ray_client_pb2.DataResponse( - connection_info=self._build_connection_response()) - elif req_type == "prep_runtime_env": - with self.clients_lock: - resp_prep = self.basic_service.PrepRuntimeEnv( - req.prep_runtime_env) - resp = ray_client_pb2.DataResponse( - prep_runtime_env=resp_prep) - else: - raise Exception(f"Unreachable code: Request type " - f"{req_type} not handled in Datapath") + prep_runtime_env=resp_prep) + else: + raise Exception(f"Unreachable code: Request type " + f"{req_type} not handled in Datapath") resp.req_id = req.req_id yield resp except grpc.RpcError as e: @@ -82,12 +126,15 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): finally: logger.debug(f"Lost data connection from client {client_id}") self.basic_service.release_all(client_id) - + queue_filler_thread.join(QUEUE_JOIN_SECONDS) + if queue_filler_thread.is_alive(): + logger.error( + "Queue filler thread failed to join before timeout: {}". + format(QUEUE_JOIN_SECONDS)) with self.clients_lock: - if accepted_connection: - # Could fail before client accounting happens - self.num_clients -= 1 - logger.debug(f"Removed clients. {self.num_clients}") + # Could fail before client accounting happens + self.num_clients -= 1 + logger.debug(f"Removed clients. {self.num_clients}") # It's important to keep the Ray shutdown # within this locked context or else Ray could hang. @@ -96,7 +143,11 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): logger.debug("Shutting down ray.") ray.shutdown() - def _init(self, req_init, client_id): + def _init(self, client_id: str, context: Any): + """ + Checks if resources allow for another client. + Returns a boolean indicating if initialization was successful. + """ with self.clients_lock: threshold = int(CLIENT_SERVER_MAX_THREADS / 2) if self.num_clients >= threshold: @@ -110,10 +161,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): "threshold by setting the " "RAY_CLIENT_SERVER_MAX_THREADS env var " f"(currently set to {CLIENT_SERVER_MAX_THREADS}).") - return None - resp_init = self.basic_service.Init(req_init) + context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) + return False self.num_clients += 1 - return ray_client_pb2.DataResponse(init=resp_init, ) + logger.debug(f"Accepted data connection from {client_id}. " + f"Total clients: {self.num_clients}") + + return True def _build_connection_response(self): with self.clients_lock: diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index f72933d68..81671ef94 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -5,6 +5,7 @@ import base64 from collections import defaultdict from dataclasses import dataclass import os +import queue import sys import threading @@ -253,6 +254,48 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): "terminate_type") return ray_client_pb2.TerminateResponse(ok=True) + def _async_get_object( + self, + request, + client_id: str, + req_id: int, + result_queue: queue.Queue, + context=None) -> Optional[ray_client_pb2.GetResponse]: + """Attempts to schedule a callback to push the GetResponse to the + main loop when the desired object is ready. If there is some failure + in scheduling, a GetResponse will be immediately returned. + """ + if request.id not in self.object_refs[client_id]: + return ray_client_pb2.GetResponse(valid=False) + try: + object_ref = self.object_refs[client_id][request.id] + logger.debug("async get: %s" % object_ref) + with disable_client_hook(): + + def send_get_response(result: Any) -> None: + """Pushes a GetResponse to the main DataPath loop to send + to the client. This is called when the object is ready + on the server side.""" + try: + serialized = dumps_from_server(result, client_id, self) + get_resp = ray_client_pb2.GetResponse( + valid=True, data=serialized) + except Exception as e: + get_resp = ray_client_pb2.GetResponse( + valid=False, error=cloudpickle.dumps(e)) + + resp = ray_client_pb2.DataResponse( + get=get_resp, req_id=req_id) + resp.req_id = req_id + + result_queue.put(resp) + + object_ref._on_completed(send_get_response) + return None + except Exception as e: + return ray_client_pb2.GetResponse( + valid=False, error=cloudpickle.dumps(e)) + def GetObject(self, request, context=None): return self._get_object(request, "", context) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index ded6657ce..527fd3c98 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -9,6 +9,7 @@ import time import uuid from collections import defaultdict from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Tuple @@ -166,6 +167,12 @@ class Worker: "protocol_version": data.protocol_version, } + def register_callback( + self, ref: ClientObjectRef, + callback: Callable[[ray_client_pb2.DataResponse], None]) -> None: + req = ray_client_pb2.GetRequest(id=ref.id, asynchronous=True) + self.data_client.RegisterGetCallback(req, callback) + def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index e34ebe5a3..22ec99733 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -117,6 +117,8 @@ message GetRequest { bytes id = 1; // Length of time to wait for data to be available, in seconds. Zero is no timeout. float timeout = 2; + // Whether to schedule this as a callback on the server side. + bool asynchronous = 3; } message GetResponse {