[Client] Asyncio Client, Sync gRPC Server (#15488)

This commit is contained in:
Ian Rodney 2021-04-27 08:41:10 -07:00 committed by GitHub
parent 643cf4c755
commit 4db696d365
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 260 additions and 58 deletions

View file

@ -212,6 +212,7 @@ py_test_module_list(
"test_basic.py",
"test_basic_2.py",
"test_basic_3.py",
"test_asyncio.py"
],
size = "medium",
extra_srcs = SRCS,

View file

@ -155,7 +155,10 @@ async def test_asyncio_get(ray_start_regular_shared, event_loop):
with pytest.raises(ray.exceptions.RayTaskError):
await actor.throw_error.remote().as_future()
kill_actor_and_wait_for_failure(actor)
# Wrap in Remote Function to work with Ray client.
kill_actor_ref = ray.remote(kill_actor_and_wait_for_failure).remote(actor)
ray.get(kill_actor_ref)
with pytest.raises(ray.exceptions.RayActorError):
await actor.echo.remote(1)
@ -256,6 +259,7 @@ def test_async_function_errored(ray_start_regular_shared):
ray.get(ref)
@pytest.mark.asyncio
async def test_async_obj_unhandled_errors(ray_start_regular_shared):
@ray.remote
def f():

View file

@ -22,5 +22,23 @@ def test_rllib_integration(ray_start_regular_shared):
rock_paper_scissors_multiagent.main()
@pytest.mark.asyncio
async def test_serve_handle(ray_start_regular_shared):
    """Check that a Serve handle call can be resolved both with the
    blocking ``ray.get`` and by awaiting the returned value directly."""
    with ray_start_client_server() as ray:
        from ray import serve

        _explicitly_enable_client_mode()
        serve.start(detached=True)

        def hello(request):
            return "hello"

        serve.create_backend(
            "my_backend", hello, config={"num_replicas": 1})
        serve.create_endpoint(
            "my_endpoint", backend="my_backend", route="/hello")
        endpoint_handle = serve.get_handle("my_endpoint")

        # Blocking path: resolve the call through ray.get.
        blocking_reply = ray.get(endpoint_handle.remote())
        assert blocking_reply == "hello"

        # Asyncio path: await the value returned by handle.remote().
        awaited_reply = await endpoint_handle.remote()
        assert awaited_reply == "hello"
if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))

View file

@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2021-04-09"
CURRENT_PROTOCOL_VERSION = "2021-04-19"
class RayAPIStub:

View file

@ -6,7 +6,7 @@ import json
import logging
from ray.util.client.runtime_context import ClientWorkerPropertyAPI
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, Callable, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from ray.actor import ActorClass
from ray.remote_function import RemoteFunction
@ -317,3 +317,7 @@ class ClientAPI:
"available within Ray remote functions and is not yet "
"implemented in the client API.".format(key))
return self.__getattribute__(key)
def _register_callback(self, ref: "ClientObjectRef",
                       callback: Callable[["DataResponse"], None]) -> None:
    """Forward ``ref`` and ``callback`` to ``self.worker.register_callback``.

    Args:
        ref: The client object reference the callback is registered for.
        callback: Callable that accepts a single ``DataResponse``.
    """
    self.worker.register_callback(ref, callback)

View file

@ -2,13 +2,16 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.util.client import ray
from ray.util.client.options import validate_options
import uuid
import asyncio
import concurrent.futures
import os
import uuid
import inspect
from ray.util.inspect import is_cython
import json
import threading
from typing import Any
from typing import Callable
from typing import List
from typing import Dict
from typing import Optional
@ -60,7 +63,50 @@ class ClientBaseRef:
class ClientObjectRef(ClientBaseRef):
    """Client-side object reference that can be awaited or turned into a
    future.

    The value is delivered through a callback registered with
    ``ray._register_callback``.
    """

    def __await__(self):
        """Make the ref awaitable by delegating to :meth:`as_future`."""
        return self.as_future().__await__()

    def as_future(self) -> asyncio.Future:
        """Return :meth:`future` wrapped with ``asyncio.wrap_future``."""
        return asyncio.wrap_future(self.future())

    def future(self) -> concurrent.futures.Future:
        """Return a ``concurrent.futures.Future`` for this object.

        The future is completed with the delivered value, or with an
        exception when the delivered value is an ``Exception`` instance.
        """
        fut = concurrent.futures.Future()

        def set_value(data: Any) -> None:
            """Set the exception or result in the Future."""
            if isinstance(data, Exception):
                fut.set_exception(data)
            else:
                fut.set_result(data)

        self._on_completed(set_value)

        # Prevent this object ref from being released.
        fut.object_ref = self
        return fut

    def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
        """Register a callback that will be called after Object is ready.

        If the ObjectRef is already ready, the callback will be called soon.
        The callback should take the result as the only argument. The result
        can be an exception object in case of task error, or the exception
        raised while deserializing the response.
        """
        from ray.util.client.client_pickler import loads_from_server

        def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
            """Converts from a GetResponse proto to a python object."""
            obj = resp.get
            payload = obj.data if obj.valid else obj.error
            try:
                data = loads_from_server(payload)
            except Exception as e:
                # A payload that cannot be unpickled (for example an invalid
                # response whose error field is empty) must still reach the
                # callback; otherwise a future created by future() would
                # never be completed.
                data = e
            py_callback(data)

        ray._register_callback(self, deserialize_obj)
class ClientActorRef(ClientBaseRef):

View file

@ -6,8 +6,7 @@ import queue
import threading
import grpc
from typing import Any
from typing import Dict
from typing import Any, Callable, Dict, Optional
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
@ -18,6 +17,8 @@ logger = logging.getLogger(__name__)
# number of simultaneous in-flight requests.
INT32_MAX = (2**31) - 1
ResponseCallable = Callable[[ray_client_pb2.DataResponse], None]
class DataClient:
def __init__(self, channel: "grpc._channel.Channel", client_id: str,
@ -35,6 +36,10 @@ class DataClient:
self.ready_data: Dict[int, Any] = {}
self.cv = threading.Condition()
self.lock = threading.RLock()
# NOTE: Dictionary insertion is guaranteed to complete before lookup
# and/or removal because of synchronization via the request_queue.
self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
self._req_id = 0
self._client_id = client_id
self._metadata = metadata
@ -66,9 +71,16 @@ class DataClient:
# This is not being waited for.
logger.debug(f"Got unawaited response {response}")
continue
with self.cv:
self.ready_data[response.req_id] = response
self.cv.notify_all()
if response.req_id in self.asyncio_waiting_data:
callback = self.asyncio_waiting_data.pop(response.req_id)
try:
callback(response)
except Exception:
logger.exception("Callback error:")
else:
with self.cv:
self.ready_data[response.req_id] = response
self.cv.notify_all()
except grpc.RpcError as e:
with self.cv:
self._in_shutdown = True
@ -112,9 +124,13 @@ class DataClient:
del self.ready_data[req_id]
return data
def _async_send(self, req: ray_client_pb2.DataRequest) -> None:
def _async_send(self,
                req: ray_client_pb2.DataRequest,
                callback: Optional[ResponseCallable] = None) -> None:
    """Stamp ``req`` with a fresh request id and enqueue it without waiting.

    Args:
        req: Request to send; its ``req_id`` field is overwritten.
        callback: Optional callable stored in ``asyncio_waiting_data``
            under the new request id.
    """
    assigned_id = self._next_id()
    req.req_id = assigned_id
    if callback:
        # Store the callback before the request is queued, so the entry
        # exists by the time a response with this id is looked up.
        self.asyncio_waiting_data[assigned_id] = callback
    self.request_queue.put(req)
def Init(self, request: ray_client_pb2.InitRequest,
@ -143,6 +159,13 @@ class DataClient:
resp = self._blocking_send(datareq)
return resp.get
def RegisterGetCallback(self,
                        request: ray_client_pb2.GetRequest,
                        callback: ResponseCallable,
                        context=None) -> None:
    """Send a get request without blocking and register ``callback`` for
    its response.

    Args:
        request: The ``GetRequest`` to wrap in a ``DataRequest``.
        callback: Callable handed to ``_async_send`` together with the
            wrapped request.
        context: Not used by this method.
    """
    datareq = ray_client_pb2.DataRequest(get=request, )
    self._async_send(datareq, callback)
def PutObject(self, request: ray_client_pb2.PutRequest,
context=None) -> ray_client_pb2.PutResponse:
datareq = ray_client_pb2.DataRequest(put=request, )

View file

@ -1,10 +1,11 @@
import ray
import logging
import grpc
from queue import Queue
import sys
from typing import TYPE_CHECKING
from threading import Lock
from typing import Any, Iterator, TYPE_CHECKING, Union
from threading import Lock, Thread
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
@ -18,6 +19,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
QUEUE_JOIN_SECONDS = 5
def fill_queue(
        grpc_input_generator: Iterator[ray_client_pb2.DataRequest],
        output_queue:
        "Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]"
) -> None:
    """Drain the incoming request stream into the shared ``output_queue``.

    Each request is put on the queue as it is read. A ``None`` sentinel is
    always put last, whether the stream ended normally or iteration raised.
    """
    try:
        for incoming in grpc_input_generator:
            output_queue.put(incoming)
    except grpc.RpcError as e:
        logger.debug("closing dataservicer reader thread "
                     f"grpc error reading request_iterator: {e}")
    finally:
        # The None sentinel tells the queue's consumer to stop reading.
        output_queue.put(None)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):
@ -28,53 +50,75 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
client_id = metadata["client_id"]
accepted_connection = False
if client_id == "":
logger.error("Client connecting with no client_id")
return
logger.debug(f"New data connection from client {client_id}: ")
accepted_connection = self._init(client_id, context)
if not accepted_connection:
return
try:
for req in request_iterator:
request_queue = Queue()
queue_filler_thread = Thread(
target=fill_queue,
daemon=True,
args=(request_iterator, request_queue))
queue_filler_thread.start()
"""For non `async get` requests, this loop yields immediately
For `async get` requests, this loop:
1) does not yield, it just continues
2) When the result is ready, it yields
"""
for req in iter(request_queue.get, None):
if isinstance(req, ray_client_pb2.DataResponse):
# Early shortcut if this is the result of an async get.
yield req
continue
assert isinstance(req, ray_client_pb2.DataRequest)
resp = None
req_type = req.WhichOneof("type")
if req_type == "init":
resp = self._init(req.init, client_id)
if resp is None:
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
return
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
accepted_connection = True
else:
assert accepted_connection
if req_type == "get":
resp_init = self.basic_service.Init(req.init)
resp = ray_client_pb2.DataResponse(init=resp_init, )
elif req_type == "get":
get_resp = None
if req.get.asynchronous:
get_resp = self.basic_service._async_get_object(
req.get, client_id, req.req_id, request_queue)
if get_resp is None:
# Skip sending a response for this request and
# continue to the next requst. The response for
# this request will be sent when the object is
# ready.
continue
else:
get_resp = self.basic_service._get_object(
req.get, client_id)
resp = ray_client_pb2.DataResponse(get=get_resp)
elif req_type == "put":
put_resp = self.basic_service._put_object(
req.put, client_id)
resp = ray_client_pb2.DataResponse(put=put_resp)
elif req_type == "release":
released = []
for rel_id in req.release.ids:
rel = self.basic_service.release(client_id, rel_id)
released.append(rel)
resp = ray_client_pb2.DataResponse(get=get_resp)
elif req_type == "put":
put_resp = self.basic_service._put_object(
req.put, client_id)
resp = ray_client_pb2.DataResponse(put=put_resp)
elif req_type == "release":
released = []
for rel_id in req.release.ids:
rel = self.basic_service.release(client_id, rel_id)
released.append(rel)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(ok=released))
elif req_type == "connection_info":
resp = ray_client_pb2.DataResponse(
connection_info=self._build_connection_response())
elif req_type == "prep_runtime_env":
with self.clients_lock:
resp_prep = self.basic_service.PrepRuntimeEnv(
req.prep_runtime_env)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(
ok=released))
elif req_type == "connection_info":
resp = ray_client_pb2.DataResponse(
connection_info=self._build_connection_response())
elif req_type == "prep_runtime_env":
with self.clients_lock:
resp_prep = self.basic_service.PrepRuntimeEnv(
req.prep_runtime_env)
resp = ray_client_pb2.DataResponse(
prep_runtime_env=resp_prep)
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
prep_runtime_env=resp_prep)
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
resp.req_id = req.req_id
yield resp
except grpc.RpcError as e:
@ -82,12 +126,15 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
finally:
logger.debug(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
queue_filler_thread.join(QUEUE_JOIN_SECONDS)
if queue_filler_thread.is_alive():
logger.error(
"Queue filler thread failed to join before timeout: {}".
format(QUEUE_JOIN_SECONDS))
with self.clients_lock:
if accepted_connection:
# Could fail before client accounting happens
self.num_clients -= 1
logger.debug(f"Removed clients. {self.num_clients}")
# Could fail before client accounting happens
self.num_clients -= 1
logger.debug(f"Removed clients. {self.num_clients}")
# It's important to keep the Ray shutdown
# within this locked context or else Ray could hang.
@ -96,7 +143,11 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
logger.debug("Shutting down ray.")
ray.shutdown()
def _init(self, req_init, client_id):
def _init(self, client_id: str, context: Any):
"""
Checks if resources allow for another client.
Returns a boolean indicating if initialization was successful.
"""
with self.clients_lock:
threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
if self.num_clients >= threshold:
@ -110,10 +161,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
"threshold by setting the "
"RAY_CLIENT_SERVER_MAX_THREADS env var "
f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
return None
resp_init = self.basic_service.Init(req_init)
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
return False
self.num_clients += 1
return ray_client_pb2.DataResponse(init=resp_init, )
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
return True
def _build_connection_response(self):
with self.clients_lock:

View file

@ -5,6 +5,7 @@ import base64
from collections import defaultdict
from dataclasses import dataclass
import os
import queue
import sys
import threading
@ -253,6 +254,48 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
"terminate_type")
return ray_client_pb2.TerminateResponse(ok=True)
def _async_get_object(
        self,
        request,
        client_id: str,
        req_id: int,
        result_queue: queue.Queue,
        context=None) -> Optional[ray_client_pb2.GetResponse]:
    """Attempts to schedule a callback to push the GetResponse to the
    main loop when the desired object is ready. If there is some failure
    in scheduling, a GetResponse will be immediately returned.

    Args:
        request: Get request; ``request.id`` selects the object ref.
        client_id: Key into ``self.object_refs`` for the requesting client.
        req_id: Request id copied onto the ``DataResponse`` pushed later.
        result_queue: Queue that receives the ``DataResponse`` once the
            object is ready.
        context: Not used by this method.

    Returns:
        ``None`` when the callback was scheduled, otherwise an invalid
        ``GetResponse`` carrying a pickled exception.
    """
    if request.id not in self.object_refs[client_id]:
        # Always attach an error payload: an invalid response with an empty
        # error field cannot be unpickled by the receiver.
        missing = ValueError(
            f"ObjectRef with ID {request.id.hex()} was not found for "
            f"client {client_id}")
        return ray_client_pb2.GetResponse(
            valid=False, error=cloudpickle.dumps(missing))
    try:
        object_ref = self.object_refs[client_id][request.id]
        logger.debug("async get: %s", object_ref)

        def send_get_response(result: Any) -> None:
            """Pushes a GetResponse to the main DataPath loop to send
            to the client. This is called when the object is ready
            on the server side."""
            try:
                serialized = dumps_from_server(result, client_id, self)
                get_resp = ray_client_pb2.GetResponse(
                    valid=True, data=serialized)
            except Exception as e:
                get_resp = ray_client_pb2.GetResponse(
                    valid=False, error=cloudpickle.dumps(e))
            result_queue.put(
                ray_client_pb2.DataResponse(get=get_resp, req_id=req_id))

        with disable_client_hook():
            object_ref._on_completed(send_get_response)
        return None
    except Exception as e:
        return ray_client_pb2.GetResponse(
            valid=False, error=cloudpickle.dumps(e))
def GetObject(self, request, context=None):
return self._get_object(request, "", context)

View file

@ -9,6 +9,7 @@ import time
import uuid
from collections import defaultdict
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
@ -166,6 +167,12 @@ class Worker:
"protocol_version": data.protocol_version,
}
def register_callback(
        self, ref: ClientObjectRef,
        callback: Callable[[ray_client_pb2.DataResponse], None]) -> None:
    """Request ``ref`` asynchronously and hand the response to ``callback``.

    Builds a ``GetRequest`` for ``ref.id`` with ``asynchronous=True`` and
    passes it, together with ``callback``, to
    ``self.data_client.RegisterGetCallback``.

    Args:
        ref: Object reference whose ``id`` is requested.
        callback: Callable accepting a single ``DataResponse``.
    """
    req = ray_client_pb2.GetRequest(id=ref.id, asynchronous=True)
    self.data_client.RegisterGetCallback(req, callback)
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False

View file

@ -117,6 +117,8 @@ message GetRequest {
bytes id = 1;
// Length of time to wait for data to be available, in seconds. Zero is no timeout.
float timeout = 2;
// Whether to schedule this as a callback on the server side.
bool asynchronous = 3;
}
message GetResponse {