mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Client] Asyncio Client, Sync gRPC Server (#15488)
This commit is contained in:
parent
643cf4c755
commit
4db696d365
11 changed files with 260 additions and 58 deletions
|
@ -212,6 +212,7 @@ py_test_module_list(
|
|||
"test_basic.py",
|
||||
"test_basic_2.py",
|
||||
"test_basic_3.py",
|
||||
"test_asyncio.py"
|
||||
],
|
||||
size = "medium",
|
||||
extra_srcs = SRCS,
|
||||
|
|
|
@ -155,7 +155,10 @@ async def test_asyncio_get(ray_start_regular_shared, event_loop):
|
|||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
await actor.throw_error.remote().as_future()
|
||||
|
||||
kill_actor_and_wait_for_failure(actor)
|
||||
# Wrap in Remote Function to work with Ray client.
|
||||
kill_actor_ref = ray.remote(kill_actor_and_wait_for_failure).remote(actor)
|
||||
ray.get(kill_actor_ref)
|
||||
|
||||
with pytest.raises(ray.exceptions.RayActorError):
|
||||
await actor.echo.remote(1)
|
||||
|
||||
|
@ -256,6 +259,7 @@ def test_async_function_errored(ray_start_regular_shared):
|
|||
ray.get(ref)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_obj_unhandled_errors(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
def f():
|
||||
|
|
|
@ -22,5 +22,23 @@ def test_rllib_integration(ray_start_regular_shared):
|
|||
rock_paper_scissors_multiagent.main()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serve_handle(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
from ray import serve
|
||||
_explicitly_enable_client_mode()
|
||||
serve.start(detached=True)
|
||||
|
||||
def hello(request):
|
||||
return "hello"
|
||||
|
||||
serve.create_backend("my_backend", hello, config={"num_replicas": 1})
|
||||
serve.create_endpoint(
|
||||
"my_endpoint", backend="my_backend", route="/hello")
|
||||
handle = serve.get_handle("my_endpoint")
|
||||
assert ray.get(handle.remote()) == "hello"
|
||||
assert await handle.remote() == "hello"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-sv", __file__]))
|
||||
|
|
|
@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# This version string is incremented to indicate breaking changes in the
|
||||
# protocol that require upgrading the client version.
|
||||
CURRENT_PROTOCOL_VERSION = "2021-04-09"
|
||||
CURRENT_PROTOCOL_VERSION = "2021-04-19"
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
|
|
|
@ -6,7 +6,7 @@ import json
|
|||
import logging
|
||||
|
||||
from ray.util.client.runtime_context import ClientWorkerPropertyAPI
|
||||
from typing import Any, List, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.actor import ActorClass
|
||||
from ray.remote_function import RemoteFunction
|
||||
|
@ -317,3 +317,7 @@ class ClientAPI:
|
|||
"available within Ray remote functions and is not yet "
|
||||
"implemented in the client API.".format(key))
|
||||
return self.__getattribute__(key)
|
||||
|
||||
def _register_callback(self, ref: "ClientObjectRef",
|
||||
callback: Callable[["DataResponse"], None]) -> None:
|
||||
self.worker.register_callback(ref, callback)
|
||||
|
|
|
@ -2,13 +2,16 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
|||
from ray.util.client import ray
|
||||
from ray.util.client.options import validate_options
|
||||
|
||||
import uuid
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
import uuid
|
||||
import inspect
|
||||
from ray.util.inspect import is_cython
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
@ -60,7 +63,50 @@ class ClientBaseRef:
|
|||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
pass
|
||||
def __await__(self):
|
||||
return self.as_future().__await__()
|
||||
|
||||
def as_future(self) -> asyncio.Future:
|
||||
return asyncio.wrap_future(self.future())
|
||||
|
||||
def future(self) -> concurrent.futures.Future:
|
||||
fut = concurrent.futures.Future()
|
||||
|
||||
def set_value(data: Any) -> None:
|
||||
"""Schedules a callback to set the exception or result
|
||||
in the Future."""
|
||||
|
||||
if isinstance(data, Exception):
|
||||
fut.set_exception(data)
|
||||
else:
|
||||
fut.set_result(data)
|
||||
|
||||
self._on_completed(set_value)
|
||||
|
||||
# Prevent this object ref from being released.
|
||||
fut.object_ref = self
|
||||
return fut
|
||||
|
||||
def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
|
||||
"""Register a callback that will be called after Object is ready.
|
||||
If the ObjectRef is already ready, the callback will be called soon.
|
||||
The callback should take the result as the only argument. The result
|
||||
can be an exception object in case of task error.
|
||||
"""
|
||||
from ray.util.client.client_pickler import loads_from_server
|
||||
|
||||
def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
|
||||
"""Converts from a GetResponse proto to a python object."""
|
||||
obj = resp.get
|
||||
data = None
|
||||
if not obj.valid:
|
||||
data = loads_from_server(resp.get.error)
|
||||
else:
|
||||
data = loads_from_server(resp.get.data)
|
||||
|
||||
py_callback(data)
|
||||
|
||||
ray._register_callback(self, deserialize_obj)
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
|
|
|
@ -6,8 +6,7 @@ import queue
|
|||
import threading
|
||||
import grpc
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
@ -18,6 +17,8 @@ logger = logging.getLogger(__name__)
|
|||
# number of simultaneous in-flight requests.
|
||||
INT32_MAX = (2**31) - 1
|
||||
|
||||
ResponseCallable = Callable[[ray_client_pb2.DataResponse], None]
|
||||
|
||||
|
||||
class DataClient:
|
||||
def __init__(self, channel: "grpc._channel.Channel", client_id: str,
|
||||
|
@ -35,6 +36,10 @@ class DataClient:
|
|||
self.ready_data: Dict[int, Any] = {}
|
||||
self.cv = threading.Condition()
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# NOTE: Dictionary insertion is guaranteed to complete before lookup
|
||||
# and/or removal because of synchronization via the request_queue.
|
||||
self.asyncio_waiting_data: Dict[int, ResponseCallable] = {}
|
||||
self._req_id = 0
|
||||
self._client_id = client_id
|
||||
self._metadata = metadata
|
||||
|
@ -66,9 +71,16 @@ class DataClient:
|
|||
# This is not being waited for.
|
||||
logger.debug(f"Got unawaited response {response}")
|
||||
continue
|
||||
with self.cv:
|
||||
self.ready_data[response.req_id] = response
|
||||
self.cv.notify_all()
|
||||
if response.req_id in self.asyncio_waiting_data:
|
||||
callback = self.asyncio_waiting_data.pop(response.req_id)
|
||||
try:
|
||||
callback(response)
|
||||
except Exception:
|
||||
logger.exception("Callback error:")
|
||||
else:
|
||||
with self.cv:
|
||||
self.ready_data[response.req_id] = response
|
||||
self.cv.notify_all()
|
||||
except grpc.RpcError as e:
|
||||
with self.cv:
|
||||
self._in_shutdown = True
|
||||
|
@ -112,9 +124,13 @@ class DataClient:
|
|||
del self.ready_data[req_id]
|
||||
return data
|
||||
|
||||
def _async_send(self, req: ray_client_pb2.DataRequest) -> None:
|
||||
def _async_send(self,
|
||||
req: ray_client_pb2.DataRequest,
|
||||
callback: Optional[ResponseCallable] = None) -> None:
|
||||
req_id = self._next_id()
|
||||
req.req_id = req_id
|
||||
if callback:
|
||||
self.asyncio_waiting_data[req_id] = callback
|
||||
self.request_queue.put(req)
|
||||
|
||||
def Init(self, request: ray_client_pb2.InitRequest,
|
||||
|
@ -143,6 +159,13 @@ class DataClient:
|
|||
resp = self._blocking_send(datareq)
|
||||
return resp.get
|
||||
|
||||
def RegisterGetCallback(self,
|
||||
request: ray_client_pb2.GetRequest,
|
||||
callback: ResponseCallable,
|
||||
context=None) -> None:
|
||||
datareq = ray_client_pb2.DataRequest(get=request, )
|
||||
self._async_send(datareq, callback)
|
||||
|
||||
def PutObject(self, request: ray_client_pb2.PutRequest,
|
||||
context=None) -> ray_client_pb2.PutResponse:
|
||||
datareq = ray_client_pb2.DataRequest(put=request, )
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import ray
|
||||
import logging
|
||||
import grpc
|
||||
from queue import Queue
|
||||
import sys
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from threading import Lock
|
||||
from typing import Any, Iterator, TYPE_CHECKING, Union
|
||||
from threading import Lock, Thread
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
@ -18,6 +19,27 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QUEUE_JOIN_SECONDS = 5
|
||||
|
||||
|
||||
def fill_queue(
|
||||
grpc_input_generator: Iterator[ray_client_pb2.DataRequest],
|
||||
output_queue:
|
||||
"Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]"
|
||||
) -> None:
|
||||
"""
|
||||
Pushes incoming requests to a shared output_queue.
|
||||
"""
|
||||
try:
|
||||
for req in grpc_input_generator:
|
||||
output_queue.put(req)
|
||||
except grpc.RpcError as e:
|
||||
logger.debug("closing dataservicer reader thread "
|
||||
f"grpc error reading request_iterator: {e}")
|
||||
finally:
|
||||
# Set the sentinel value for the output_queue
|
||||
output_queue.put(None)
|
||||
|
||||
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
def __init__(self, basic_service: "RayletServicer"):
|
||||
|
@ -28,53 +50,75 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
def Datapath(self, request_iterator, context):
|
||||
metadata = {k: v for k, v in context.invocation_metadata()}
|
||||
client_id = metadata["client_id"]
|
||||
accepted_connection = False
|
||||
if client_id == "":
|
||||
logger.error("Client connecting with no client_id")
|
||||
return
|
||||
logger.debug(f"New data connection from client {client_id}: ")
|
||||
accepted_connection = self._init(client_id, context)
|
||||
if not accepted_connection:
|
||||
return
|
||||
try:
|
||||
for req in request_iterator:
|
||||
request_queue = Queue()
|
||||
queue_filler_thread = Thread(
|
||||
target=fill_queue,
|
||||
daemon=True,
|
||||
args=(request_iterator, request_queue))
|
||||
queue_filler_thread.start()
|
||||
"""For non `async get` requests, this loop yields immediately
|
||||
For `async get` requests, this loop:
|
||||
1) does not yield, it just continues
|
||||
2) When the result is ready, it yields
|
||||
"""
|
||||
for req in iter(request_queue.get, None):
|
||||
if isinstance(req, ray_client_pb2.DataResponse):
|
||||
# Early shortcut if this is the result of an async get.
|
||||
yield req
|
||||
continue
|
||||
|
||||
assert isinstance(req, ray_client_pb2.DataRequest)
|
||||
resp = None
|
||||
req_type = req.WhichOneof("type")
|
||||
if req_type == "init":
|
||||
resp = self._init(req.init, client_id)
|
||||
if resp is None:
|
||||
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
|
||||
return
|
||||
logger.debug(f"Accepted data connection from {client_id}. "
|
||||
f"Total clients: {self.num_clients}")
|
||||
accepted_connection = True
|
||||
else:
|
||||
assert accepted_connection
|
||||
if req_type == "get":
|
||||
resp_init = self.basic_service.Init(req.init)
|
||||
resp = ray_client_pb2.DataResponse(init=resp_init, )
|
||||
elif req_type == "get":
|
||||
get_resp = None
|
||||
if req.get.asynchronous:
|
||||
get_resp = self.basic_service._async_get_object(
|
||||
req.get, client_id, req.req_id, request_queue)
|
||||
if get_resp is None:
|
||||
# Skip sending a response for this request and
|
||||
# continue to the next requst. The response for
|
||||
# this request will be sent when the object is
|
||||
# ready.
|
||||
continue
|
||||
else:
|
||||
get_resp = self.basic_service._get_object(
|
||||
req.get, client_id)
|
||||
resp = ray_client_pb2.DataResponse(get=get_resp)
|
||||
elif req_type == "put":
|
||||
put_resp = self.basic_service._put_object(
|
||||
req.put, client_id)
|
||||
resp = ray_client_pb2.DataResponse(put=put_resp)
|
||||
elif req_type == "release":
|
||||
released = []
|
||||
for rel_id in req.release.ids:
|
||||
rel = self.basic_service.release(client_id, rel_id)
|
||||
released.append(rel)
|
||||
resp = ray_client_pb2.DataResponse(get=get_resp)
|
||||
elif req_type == "put":
|
||||
put_resp = self.basic_service._put_object(
|
||||
req.put, client_id)
|
||||
resp = ray_client_pb2.DataResponse(put=put_resp)
|
||||
elif req_type == "release":
|
||||
released = []
|
||||
for rel_id in req.release.ids:
|
||||
rel = self.basic_service.release(client_id, rel_id)
|
||||
released.append(rel)
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
release=ray_client_pb2.ReleaseResponse(ok=released))
|
||||
elif req_type == "connection_info":
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
connection_info=self._build_connection_response())
|
||||
elif req_type == "prep_runtime_env":
|
||||
with self.clients_lock:
|
||||
resp_prep = self.basic_service.PrepRuntimeEnv(
|
||||
req.prep_runtime_env)
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
release=ray_client_pb2.ReleaseResponse(
|
||||
ok=released))
|
||||
elif req_type == "connection_info":
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
connection_info=self._build_connection_response())
|
||||
elif req_type == "prep_runtime_env":
|
||||
with self.clients_lock:
|
||||
resp_prep = self.basic_service.PrepRuntimeEnv(
|
||||
req.prep_runtime_env)
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
prep_runtime_env=resp_prep)
|
||||
else:
|
||||
raise Exception(f"Unreachable code: Request type "
|
||||
f"{req_type} not handled in Datapath")
|
||||
prep_runtime_env=resp_prep)
|
||||
else:
|
||||
raise Exception(f"Unreachable code: Request type "
|
||||
f"{req_type} not handled in Datapath")
|
||||
resp.req_id = req.req_id
|
||||
yield resp
|
||||
except grpc.RpcError as e:
|
||||
|
@ -82,12 +126,15 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
finally:
|
||||
logger.debug(f"Lost data connection from client {client_id}")
|
||||
self.basic_service.release_all(client_id)
|
||||
|
||||
queue_filler_thread.join(QUEUE_JOIN_SECONDS)
|
||||
if queue_filler_thread.is_alive():
|
||||
logger.error(
|
||||
"Queue filler thread failed to join before timeout: {}".
|
||||
format(QUEUE_JOIN_SECONDS))
|
||||
with self.clients_lock:
|
||||
if accepted_connection:
|
||||
# Could fail before client accounting happens
|
||||
self.num_clients -= 1
|
||||
logger.debug(f"Removed clients. {self.num_clients}")
|
||||
# Could fail before client accounting happens
|
||||
self.num_clients -= 1
|
||||
logger.debug(f"Removed clients. {self.num_clients}")
|
||||
|
||||
# It's important to keep the Ray shutdown
|
||||
# within this locked context or else Ray could hang.
|
||||
|
@ -96,7 +143,11 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
logger.debug("Shutting down ray.")
|
||||
ray.shutdown()
|
||||
|
||||
def _init(self, req_init, client_id):
|
||||
def _init(self, client_id: str, context: Any):
|
||||
"""
|
||||
Checks if resources allow for another client.
|
||||
Returns a boolean indicating if initialization was successful.
|
||||
"""
|
||||
with self.clients_lock:
|
||||
threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
|
||||
if self.num_clients >= threshold:
|
||||
|
@ -110,10 +161,13 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
"threshold by setting the "
|
||||
"RAY_CLIENT_SERVER_MAX_THREADS env var "
|
||||
f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
|
||||
return None
|
||||
resp_init = self.basic_service.Init(req_init)
|
||||
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
|
||||
return False
|
||||
self.num_clients += 1
|
||||
return ray_client_pb2.DataResponse(init=resp_init, )
|
||||
logger.debug(f"Accepted data connection from {client_id}. "
|
||||
f"Total clients: {self.num_clients}")
|
||||
|
||||
return True
|
||||
|
||||
def _build_connection_response(self):
|
||||
with self.clients_lock:
|
||||
|
|
|
@ -5,6 +5,7 @@ import base64
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
|
||||
import threading
|
||||
|
@ -253,6 +254,48 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
"terminate_type")
|
||||
return ray_client_pb2.TerminateResponse(ok=True)
|
||||
|
||||
def _async_get_object(
|
||||
self,
|
||||
request,
|
||||
client_id: str,
|
||||
req_id: int,
|
||||
result_queue: queue.Queue,
|
||||
context=None) -> Optional[ray_client_pb2.GetResponse]:
|
||||
"""Attempts to schedule a callback to push the GetResponse to the
|
||||
main loop when the desired object is ready. If there is some failure
|
||||
in scheduling, a GetResponse will be immediately returned.
|
||||
"""
|
||||
if request.id not in self.object_refs[client_id]:
|
||||
return ray_client_pb2.GetResponse(valid=False)
|
||||
try:
|
||||
object_ref = self.object_refs[client_id][request.id]
|
||||
logger.debug("async get: %s" % object_ref)
|
||||
with disable_client_hook():
|
||||
|
||||
def send_get_response(result: Any) -> None:
|
||||
"""Pushes a GetResponse to the main DataPath loop to send
|
||||
to the client. This is called when the object is ready
|
||||
on the server side."""
|
||||
try:
|
||||
serialized = dumps_from_server(result, client_id, self)
|
||||
get_resp = ray_client_pb2.GetResponse(
|
||||
valid=True, data=serialized)
|
||||
except Exception as e:
|
||||
get_resp = ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
resp = ray_client_pb2.DataResponse(
|
||||
get=get_resp, req_id=req_id)
|
||||
resp.req_id = req_id
|
||||
|
||||
result_queue.put(resp)
|
||||
|
||||
object_ref._on_completed(send_get_response)
|
||||
return None
|
||||
except Exception as e:
|
||||
return ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
return self._get_object(request, "", context)
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import time
|
|||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
@ -166,6 +167,12 @@ class Worker:
|
|||
"protocol_version": data.protocol_version,
|
||||
}
|
||||
|
||||
def register_callback(
|
||||
self, ref: ClientObjectRef,
|
||||
callback: Callable[[ray_client_pb2.DataResponse], None]) -> None:
|
||||
req = ray_client_pb2.GetRequest(id=ref.id, asynchronous=True)
|
||||
self.data_client.RegisterGetCallback(req, callback)
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
single = False
|
||||
|
|
|
@ -117,6 +117,8 @@ message GetRequest {
|
|||
bytes id = 1;
|
||||
// Length of time to wait for data to be available, in seconds. Zero is no timeout.
|
||||
float timeout = 2;
|
||||
// Whether to schedule this as a callback on the server side.
|
||||
bool asynchronous = 3;
|
||||
}
|
||||
|
||||
message GetResponse {
|
||||
|
|
Loading…
Add table
Reference in a new issue