[Serve] Batching in Worker Replica (#8709)

Simon Mo 2020-06-09 11:29:16 -07:00 committed by GitHub
parent f007bfb4cf
commit 6c3062906f
20 changed files with 671 additions and 94 deletions

python/ray/serve/BUILD

@ -5,21 +5,101 @@ py_library(
srcs = glob(["**/*.py"], exclude=["tests/*.py"]),
)
# This test aggregates all serve tests and runs them in a single session,
# similar to `pytest .`.
# Serve tests need to run in a single session because starting and stopping
# the serve cluster takes a large chunk of time. All serve tests use a shared
# cluster.
serve_tests_srcs = glob(["tests/*.py"],
exclude=["tests/test_nonblocking.py",
"tests/test_master_crashes.py",
"tests/test_serve.py",
])
py_test(
name = "test_serve",
name = "test_api",
size = "medium",
srcs = glob(["tests/*.py"],
exclude=["tests/test_nonblocking.py",
"tests/test_master_crashes.py"]),
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_backend_worker",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_config",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_failure",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_handle",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_kv_store",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_metric",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_persistence",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_router",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_util",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Runs test_api and test_failure with injected failures in the master actor.
# TODO(edoakes): reenable this once we're using GCS actor fault tolerance.
# py_test(
@ -97,6 +177,14 @@ py_test(
deps = [":serve_lib"]
)
py_test(
name = "snippet_model_composition",
size = "small",
srcs = glob(["examples/doc/*.py"]),
tags = ["exclusive"],
deps = [":serve_lib"]
)
# Disable the deployment tutorial test because it requires
# ray start --head in the background.
# py_test(

python/ray/serve/api.py

@ -236,7 +236,8 @@ def create_backend(backend_tag,
replica_config = ReplicaConfig(
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
backend_config = BackendConfig(config, replica_config.accepts_batches)
backend_config = BackendConfig(config, replica_config.accepts_batches,
replica_config.is_blocking)
ray.get(
master_actor.create_backend.remote(backend_tag, backend_config,

python/ray/serve/backend_worker.py

@ -1,20 +1,61 @@
import asyncio
import traceback
import inspect
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
import time
import ray
from ray.async_compat import sync_to_async
from ray import serve
from ray.serve import context as serve_context
from ray.serve.context import FakeFlaskRequest
from collections import defaultdict
from ray.serve.utils import (parse_request_item, _get_logger)
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
unpack_future)
from ray.serve.exceptions import RayServeException
from ray.serve.metric import MetricClient
from ray.async_compat import sync_to_async
from ray.serve.config import BackendConfig
from ray.serve.router import Query
logger = _get_logger()
class WaitableQueue(asyncio.Queue):
async def wait_for_batch(self, num_items: int, timeout_s: float):
"""Wait up to num_items in the queue given timeout_s.
This method will block indefinitely for the first item. Therefore, it
guarantees to return at least one item.
"""
assert num_items >= 1
# Wait for the first value without a timeout. We will return at least
# one item. Additionally, this helps the caller context switch on an
# empty queue.
start_waiting = time.time()
batch = [
await self.get(),
]
# Adjust the timeout to account for the time spent waiting for the first item.
time_remaining = timeout_s - (time.time() - start_waiting)
time_remaining = max(0, time_remaining)
# Wait for the remaining batch with the timeout
if num_items > 1:
done_set, not_done_set = await asyncio.wait(
[self.get() for _ in range(num_items - 1)],
timeout=time_remaining)
for task in done_set:
batch.append(task.result())
for task in not_done_set:
task.cancel()
return batch
def create_backend_worker(func_or_class):
"""Creates a worker class wrapping the provided function or class."""
@ -30,8 +71,10 @@ def create_backend_worker(func_or_class):
backend_tag,
replica_tag,
init_args,
backend_config: BackendConfig,
instance_name=None):
serve.init(name=instance_name)
if is_function:
_callable = func_or_class
else:
@ -42,11 +85,15 @@ def create_backend_worker(func_or_class):
metric_client = MetricClient(
metric_exporter, default_labels={"backend": backend_tag})
self.backend = RayServeWorker(backend_tag, replica_tag, _callable,
is_function, metric_client)
backend_config, is_function,
metric_client)
async def handle_request(self, request):
return await self.backend.handle_request(request)
def update_config(self, new_config: BackendConfig):
return self.backend.update_config(new_config)
def ready(self):
pass
@ -75,13 +122,16 @@ def ensure_async(func):
class RayServeWorker:
"""Handles requests with the provided callable."""
def __init__(self, name, replica_tag, _callable, is_function,
metric_client):
def __init__(self, name, replica_tag, _callable,
backend_config: BackendConfig, is_function, metric_client):
self.name = name
self.replica_tag = replica_tag
self.callable = _callable
self.is_function = is_function
self.config = backend_config
self.query_queue = WaitableQueue()
self.metric_client = metric_client
self.request_counter = self.metric_client.new_counter(
"backend_request_counter",
@ -101,6 +151,9 @@ class RayServeWorker:
self.restart_counter.labels(replica_tag=self.replica_tag).add()
self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
self.config_updated = asyncio.Event()
def get_runner_method(self, request_item):
method_name = request_item.call_method
if not hasattr(self.callable, method_name):
@ -108,6 +161,8 @@ class RayServeWorker:
"which is specified in the request. "
"The available methods are {}".format(
method_name, dir(self.callable)))
if self.is_function:
return self.callable
return getattr(self.callable, method_name)
def has_positional_args(self, f):
@ -124,6 +179,12 @@ class RayServeWorker:
return True
return False
def _reset_context(self):
# NOTE(simon): context management won't work in async mode because
# many concurrent queries might be running at the same time.
serve_context.web = None
serve_context.batch_size = None
async def invoke_single(self, request_item):
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
@ -137,24 +198,12 @@ class RayServeWorker:
except Exception as e:
result = wrap_to_ray_error(e)
self.error_counter.add()
finally:
self._reset_context()
return result
async def invoke_batch(self, request_item_list):
# TODO(alind) : create no-http services. The enqueues
# from such services will always be TaskContext.Python.
# Assumption: all the requests in a batch
# have the same serve context.
# For batching kwargs are modified as follows -
# kwargs [Python Context] : key,val
# kwargs_list : key, [val1,val2, ... , valn]
# or
# args[Web Context] : val
# args_list : [val1,val2, ...... , valn]
# where n (current batch size) <= max_batch_size of a backend
arg_list = []
kwargs_list = defaultdict(list)
context_flags = set()
@ -222,22 +271,53 @@ class RayServeWorker:
"results with length equal to the batch size"
".".format(batch_size, len(result_list)))
raise RayServeException(error_message)
self._reset_context()
return result_list
except Exception as e:
wrapped_exception = wrap_to_ray_error(e)
self.error_counter.add()
self._reset_context()
return [wrapped_exception for _ in range(batch_size)]
async def handle_request(self, request):
# check if work_item is a list or not
# if it is list: then batching supported
if not isinstance(request, list):
result = await self.invoke_single(request)
else:
result = await self.invoke_batch(request)
async def main_loop(self):
while True:
# NOTE(simon): There's an issue when the user updates the batch size and
# batch wait timeout during execution: these values will not be picked
# up until after the current iteration.
batch = await self.query_queue.wait_for_batch(
num_items=self.config.max_batch_size or 1,
timeout_s=self.config.batch_wait_timeout)
# re-assign to default values
serve_context.web = False
serve_context.batch_size = None
all_evaluated_futures = []
return result
if not self.config.accepts_batches:
query = batch[0]
evaluated = asyncio.ensure_future(self.invoke_single(query))
all_evaluated_futures = [evaluated]
chain_future(evaluated, query.async_future)
else:
get_call_method = attrgetter("call_method")
sorted_batch = sorted(batch, key=get_call_method)
for _, group in groupby(sorted_batch, key=get_call_method):
group = sorted(group)
evaluated = asyncio.ensure_future(self.invoke_batch(group))
all_evaluated_futures.append(evaluated)
result_futures = [q.async_future for q in group]
chain_future(
unpack_future(evaluated, len(group)), result_futures)
if self.config.is_blocking:
# We use asyncio.wait here so that if the result is an exception,
# it will not be raised here.
await asyncio.wait(all_evaluated_futures)
def update_config(self, new_config: BackendConfig):
self.config = new_config
self.config_updated.set()
async def handle_request(self, request: Query):
assert not isinstance(request, list)
logger.debug("Worker {} got request {}".format(self.name, request))
request.async_future = asyncio.get_event_loop().create_future()
self.query_queue.put_nowait(request)
return await request.async_future
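The worker now follows a producer/consumer pattern: handle_request attaches a fresh future to each query and enqueues it, while main_loop drains the queue in batches via WaitableQueue.wait_for_batch and resolves those futures. The standalone sketch below (not part of the commit; the fake requests and the doubling "handler" are illustrative only) mirrors that pattern outside of Serve:

import asyncio
import time


class WaitableQueue(asyncio.Queue):
    # Mirrors the class added above: block for the first item, then gather
    # up to num_items total within timeout_s, cancelling leftover gets.
    async def wait_for_batch(self, num_items: int, timeout_s: float):
        start = time.time()
        batch = [await self.get()]
        time_remaining = max(0, timeout_s - (time.time() - start))
        if num_items > 1:
            done, not_done = await asyncio.wait(
                [asyncio.ensure_future(self.get())
                 for _ in range(num_items - 1)],
                timeout=time_remaining)
            batch.extend(task.result() for task in done)
            for task in not_done:
                task.cancel()
        return batch


async def main():
    loop = asyncio.get_running_loop()
    queue = WaitableQueue()

    # Producer: each "request" carries its own future, like Query.async_future.
    requests = [(i, loop.create_future()) for i in range(5)]
    for request in requests:
        queue.put_nowait(request)

    # Consumer: drain the queue in batches of at most 4, echoing main_loop.
    while any(not fut.done() for _, fut in requests):
        batch = await queue.wait_for_batch(num_items=4, timeout_s=0.1)
        for value, fut in batch:
            fut.set_result(value * 2)

    print([fut.result() for _, fut in requests])  # [0, 2, 4, 6, 8]


asyncio.run(main())

Unlike the diff above, the sketch wraps each get() coroutine in ensure_future, since newer Python versions no longer accept bare coroutines in asyncio.wait.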

python/ray/serve/config.py

@ -1,5 +1,7 @@
import inspect
from ray.serve.constants import ASYNC_CONCURRENCY
def _callable_accepts_batch(func_or_class):
if inspect.isfunction(func_or_class):
@ -8,15 +10,46 @@ def _callable_accepts_batch(func_or_class):
return hasattr(func_or_class.__call__, "_serve_accept_batch")
def _callable_is_blocking(func_or_class):
if inspect.isfunction(func_or_class):
return not inspect.iscoroutinefunction(func_or_class)
elif inspect.isclass(func_or_class):
return not inspect.iscoroutinefunction(func_or_class.__call__)
class BackendConfig:
def __init__(self, config_dict, accepts_batches=False):
def __init__(self, config_dict, accepts_batches=False, is_blocking=True):
assert isinstance(config_dict, dict)
# Make a copy so that we don't modify the input dict.
config_dict = config_dict.copy()
self.accepts_batches = accepts_batches
self.is_blocking = is_blocking
self.num_replicas = config_dict.pop("num_replicas", 1)
self.max_batch_size = config_dict.pop("max_batch_size", None)
self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
None)
if self.max_concurrent_queries is None:
# Model serving mode: if the servable is blocking and the wait
# timeout is the default of zero seconds, then we keep the existing
# behavior of allowing at most max_batch_size queries.
if self.is_blocking and self.batch_wait_timeout == 0:
self.max_concurrent_queries = self.max_batch_size or 1
# Pipeline/async mode: if the servable is not blocking, the
# router should just keep pushing queries to the worker
# replicas up to a high limit.
if not self.is_blocking:
self.max_concurrent_queries = ASYNC_CONCURRENCY
# Batch inference mode: the user specifies a non-zero timeout to wait
# for a full batch. We will use 2 * max_batch_size to perform double
# buffering to keep the replica busy.
if self.max_batch_size is not None and self.batch_wait_timeout > 0:
self.max_concurrent_queries = 2 * self.max_batch_size
if len(config_dict) != 0:
raise ValueError("Unknown options in backend config: {}".format(
list(config_dict.keys())))
@ -64,6 +97,7 @@ class ReplicaConfig:
ray_actor_options=None):
self.func_or_class = func_or_class
self.accepts_batches = _callable_accepts_batch(func_or_class)
self.is_blocking = _callable_is_blocking(func_or_class)
self.actor_init_args = list(actor_init_args)
if ray_actor_options is None:
self.ray_actor_options = {}
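To make the concurrency defaults above concrete, here is a small, hedged sketch of how max_concurrent_queries resolves in the three modes. It assumes BackendConfig and ASYNC_CONCURRENCY can be imported directly, as the tests below do; the numeric values are illustrative only.

from ray.serve.config import BackendConfig
from ray.serve.constants import ASYNC_CONCURRENCY

# Model serving mode: blocking servable, zero batch wait timeout ->
# at most max_batch_size (or 1) queries in flight per replica.
assert BackendConfig({}).max_concurrent_queries == 1
assert BackendConfig(
    {"max_batch_size": 8}, accepts_batches=True).max_concurrent_queries == 8

# Pipeline/async mode: non-blocking (async def) servable ->
# the router keeps pushing queries up to ASYNC_CONCURRENCY.
assert BackendConfig(
    {}, is_blocking=False).max_concurrent_queries == ASYNC_CONCURRENCY

# Batch inference mode: non-zero batch_wait_timeout ->
# double buffering with 2 * max_batch_size.
assert BackendConfig(
    {"max_batch_size": 4, "batch_wait_timeout": 0.5},
    accepts_batches=True).max_concurrent_queries == 8

# An explicit value always takes precedence.
assert BackendConfig(
    {"max_concurrent_queries": 10}).max_concurrent_queries == 10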

python/ray/serve/examples/doc/snippet_model_composition.py

@ -0,0 +1,57 @@
from random import random
import requests
from ray import serve
serve.init()
def model_one(_unused_flask_request, data=None):
print("Model 1 called with data ", data)
return random()
def model_two(_unused_flask_request, data=None):
print("Model 2 called with data ", data)
return data
class ComposedModel:
def __init__(self):
self.model_one = serve.get_handle("model_one")
self.model_two = serve.get_handle("model_two")
async def __call__(self, flask_request):
data = flask_request.data
score = await self.model_one.remote(data=data)
if score > 0.5:
result = await self.model_two.remote(data=data)
result = {"model_used": 2, "score": score}
else:
result = {"model_used": 1, "score": score}
return result
serve.create_backend("model_one", model_one)
serve.create_endpoint("model_one", backend="model_one")
serve.create_backend("model_two", model_two)
serve.create_endpoint("model_two", backend="model_two")
serve.create_backend(
"composed_backend", ComposedModel, config={"max_concurrent_queries": 10})
serve.create_endpoint(
"composed", backend="composed_backend", route="/composed")
for _ in range(5):
resp = requests.get("http://127.0.0.1:8000/composed", data="hey!")
print(resp.json())
# Output
# {'model_used': 2, 'score': 0.6250189863595503}
# {'model_used': 1, 'score': 0.03146855349621436}
# {'model_used': 2, 'score': 0.6916977560006987}
# {'model_used': 2, 'score': 0.8169693450866928}
# {'model_used': 2, 'score': 0.9540681979573862}

python/ray/serve/master.py

@ -258,6 +258,7 @@ class ServeMaster:
for backend, (_, backend_config, _) in self.backends.items():
await self.router.set_backend_config.remote(
backend, backend_config)
await self.broadcast_backend_config(backend)
# Push configuration state to the HTTP proxy.
await self.http_proxy.set_route_table.remote(self.routes)
@ -314,6 +315,7 @@ class ServeMaster:
backend_tag,
replica_tag,
replica_config.actor_init_args,
backend_config,
instance_name=self.instance_name)
# TODO(edoakes): we should probably have a timeout here.
await worker_handle.ready.remote()
@ -602,6 +604,7 @@ class ServeMaster:
# (particularly for max-batch-size).
await self.router.set_backend_config.remote(
backend_tag, backend_config)
await self.broadcast_backend_config(backend_tag)
async def delete_backend(self, backend_tag):
async with self.write_lock:
@ -664,6 +667,22 @@ class ServeMaster:
await self._start_pending_replicas()
await self._stop_pending_replicas()
await self.broadcast_backend_config(backend_tag)
async def broadcast_backend_config(self, backend_tag):
_, backend_config, _ = self.backends[backend_tag]
broadcast_futures = []
for replica_tag in self.replicas[backend_tag]:
try:
replica = ray.get_actor(replica_tag)
except ValueError:
continue
future = replica.update_config.remote(backend_config).as_future()
broadcast_futures.append(future)
if len(broadcast_futures) > 0:
await asyncio.gather(*broadcast_futures)
def get_backend_config(self, backend_tag):
"""Get the current config for the specified backend."""
assert (backend_tag in self.backends

python/ray/serve/router.py

@ -13,7 +13,7 @@ import ray
from ray import serve
from ray.serve.metric import MetricClient
from ray.serve.policy import RandomEndpointPolicy
from ray.serve.utils import logger
from ray.serve.utils import logger, chain_future
class Query:
@ -58,10 +58,6 @@ class Query:
def __lt__(self, other):
return self.request_slo_ms < other.request_slo_ms
def __repr__(self):
return "<Query args={} kwargs={}>".format(self.request_args,
self.request_kwargs)
def _make_future_unwrapper(client_futures: List[asyncio.Future],
host_future: asyncio.Future):
@ -117,6 +113,8 @@ class Router:
self.backend_info = dict()
# replica tag -> worker_handle
self.replicas = dict()
# replica_tag -> concurrent queries counter
self.queries_counter = defaultdict(lambda: 0)
# -- Synchronization -- #
@ -126,7 +124,7 @@ class Router:
# an operation holding the only query and the other flush operation
# holding the only idle replica. Additionally, allowing only one flush
# operation at a time simplifies design overhead for custom queuing and
# batching polcies.
# batching policies.
self.flush_lock = asyncio.Lock()
# -- State Restoration -- #
@ -215,11 +213,15 @@ class Router:
await self.mark_worker_idle(backend_tag, backend_replica_tag)
async def mark_worker_idle(self, backend_tag, backend_replica_tag):
logger.debug(
"Marking backend with tag {} as idle.".format(backend_replica_tag))
if backend_replica_tag not in self.replicas:
return
async with self.flush_lock:
self.worker_queues[backend_tag].appendleft(backend_replica_tag)
# NOTE(simon): This is an O(n) operation, where n = len(worker_queue).
if backend_replica_tag not in self.worker_queues[backend_tag]:
self.worker_queues[backend_tag].appendleft(backend_replica_tag)
self.flush_backend_queues([backend_tag])
async def remove_worker(self, backend_tag, replica_tag):
@ -299,12 +301,11 @@ class Router:
"queue size {} and worker queue size {}".format(
backend, len(buffer_queue), len(worker_queue)))
max_batch_size = None
if backend in self.backend_info:
max_batch_size = self.backend_info[backend].max_batch_size
self._assign_query_to_worker(backend, buffer_queue, worker_queue,
max_batch_size)
self._assign_query_to_worker(
backend,
buffer_queue,
worker_queue,
)
async def _do_query(self, backend, backend_replica_tag, req):
# If the worker died, this will be a RayActorError. Just return it and
@ -317,16 +318,13 @@ class Router:
except RayTaskError as error:
self.num_error_backend_request.labels(backend=backend).add()
result = error
self.queries_counter[backend_replica_tag] -= 1
await self.mark_worker_idle(backend, backend_replica_tag)
logger.debug("Got result in {:.2f}s".format(time.time() - start))
return result
def _assign_query_to_worker(self,
backend,
buffer_queue,
worker_queue,
max_batch_size=None):
def _assign_query_to_worker(self, backend, buffer_queue, worker_queue):
overloaded_replicas = set()
while len(buffer_queue) and len(worker_queue):
backend_replica_tag = worker_queue.pop()
@ -334,27 +332,30 @@ class Router:
if backend_replica_tag not in self.replicas:
continue
if max_batch_size is None: # No batching
request = buffer_queue.pop(0)
future = asyncio.get_event_loop().create_task(
self._do_query(backend, backend_replica_tag, request))
# chaining satisfies request.async_future with future result.
asyncio.futures._chain_future(future, request.async_future)
else:
real_batch_size = min(len(buffer_queue), max_batch_size)
requests = [
buffer_queue.pop(0) for _ in range(real_batch_size)
]
# We have reached the end of the worker queue where all replicas
# are overloaded.
if backend_replica_tag in overloaded_replicas:
break
# split requests by method type
requests_group = defaultdict(list)
for request in requests:
requests_group[request.call_method].append(request)
# This replica already has too many queries in flight or being processed.
max_queries = 1
if backend in self.backend_info:
max_queries = self.backend_info[backend].max_concurrent_queries
curr_queries = self.queries_counter[backend_replica_tag]
if curr_queries >= max_queries:
# Put the worker back to the queue.
worker_queue.appendleft(backend_replica_tag)
overloaded_replicas.add(backend_replica_tag)
logger.debug(
"Skipping backend {} because it has {} in flight "
"requests which exceeded the concurrency limit.".format(
backend, curr_queries))
continue
for group in requests_group.values():
future = asyncio.get_event_loop().create_task(
self._do_query(backend, backend_replica_tag, group))
future.add_done_callback(
_make_future_unwrapper(
client_futures=[req.async_future for req in group],
host_future=future))
request = buffer_queue.pop(0)
self.queries_counter[backend_replica_tag] += 1
future = asyncio.get_event_loop().create_task(
self._do_query(backend, backend_replica_tag, request))
chain_future(future, request.async_future)
worker_queue.appendleft(backend_replica_tag)

python/ray/serve/tests/test_api.py

@ -507,3 +507,8 @@ def test_endpoint_input_validation(serve_instance):
with pytest.raises(TypeError):
serve.create_endpoint("endpoint", backend=2)
serve.create_endpoint("endpoint", backend="backend")
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_backend_worker.py

@ -15,7 +15,10 @@ from ray.serve.exceptions import RayServeException
pytestmark = pytest.mark.asyncio
def setup_worker(name, func_or_class, init_args=None):
def setup_worker(name,
func_or_class,
init_args=None,
backend_config=BackendConfig({})):
if init_args is None:
init_args = ()
@ -23,7 +26,7 @@ def setup_worker(name, func_or_class, init_args=None):
class WorkerActor:
def __init__(self):
self.worker = create_backend_worker(func_or_class)(
name, name + ":tag", init_args)
name, name + ":tag", init_args, backend_config)
def ready(self):
pass
@ -31,6 +34,9 @@ def setup_worker(name, func_or_class, init_args=None):
async def handle_request(self, *args, **kwargs):
return await self.worker.handle_request(*args, **kwargs)
def update_config(self, new_config):
return self.worker.update_config(new_config)
worker = WorkerActor.remote()
ray.get(worker.ready.remote())
return worker
@ -165,14 +171,16 @@ async def test_task_runner_custom_method_batch(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
worker = setup_worker(CONSUMER_NAME, Batcher)
backend_config = BackendConfig(
{
"max_batch_size": 4,
"batch_wait_timeout": 2
}, accepts_batches=True)
worker = setup_worker(
CONSUMER_NAME, Batcher, backend_config=backend_config)
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
await q.set_backend_config.remote(
CONSUMER_NAME,
BackendConfig({
"max_batch_size": 10
}, accepts_batches=True))
await q.set_backend_config.remote(CONSUMER_NAME, backend_config)
def make_request_param(call_method):
return RequestMetadata(
@ -200,3 +208,77 @@ async def test_task_runner_custom_method_batch(serve_instance):
np_array = make_request_param("return_np_array")
result_np_value = await q.enqueue_request.remote(np_array)
assert isinstance(result_np_value, np.int32)
async def test_task_runner_perform_batch(serve_instance):
q = ray.remote(Router).remote()
def batcher(*args, **kwargs):
return [serve.context.batch_size] * serve.context.batch_size
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
config = BackendConfig(
{
"max_batch_size": 2,
"batch_wait_timeout": 10
}, accepts_batches=True)
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
my_batch_sizes = await asyncio.gather(
*[q.enqueue_request.remote(query_param) for _ in range(3)])
assert my_batch_sizes == [2, 2, 1]
async def test_task_runner_perform_async(serve_instance):
q = ray.remote(Router).remote()
@ray.remote
class Barrier:
def __init__(self, release_on):
self.release_on = release_on
self.current_waiters = 0
self.event = asyncio.Event()
async def wait(self):
self.current_waiters += 1
if self.current_waiters == self.release_on:
self.event.set()
else:
await self.event.wait()
barrier = Barrier.remote(release_on=10)
async def wait_and_go(*args, **kwargs):
await barrier.wait.remote()
return "done!"
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
done, not_done = await asyncio.wait(
[q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10)
assert len(done) == 10
for item in done:
assert await item == "done!"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_config.py

@ -132,3 +132,8 @@ def test_replica_config_validation():
ReplicaConfig(Class, ray_actor_options={"detached": None})
with pytest.raises(ValueError):
ReplicaConfig(Class, ray_actor_options={"max_restarts": None})
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_failure.py

@ -241,3 +241,9 @@ def test_worker_replica_failure(serve_instance):
break
except TimeoutError:
time.sleep(0.1)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_handle.py

@ -33,3 +33,9 @@ def test_handle_in_endpoint(serve_instance):
methods=["GET", "POST"])
assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_kv_store.py

@ -38,3 +38,8 @@ def test_ray_internal_kv_collisions(serve_instance):
kv2.put("1", b"-1")
assert kv2.get("1") == b"-1"
assert kv1.get("1") == b"1"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_metric.py

@ -195,3 +195,8 @@ async def test_system_metric_endpoints(serve_instance):
print("Metric not correct, retrying...")
if not success:
test_metric_endpoint()
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_nonblocking.py

@ -20,4 +20,4 @@ def test_nonblocking():
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_persistence.py

@ -33,3 +33,9 @@ serve.create_endpoint("driver", backend="driver", route="/driver")
assert ray.get(handle.remote()) == "OK!"
os.remove(path)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_router.py

@ -7,6 +7,8 @@ import ray
from ray.serve.router import Router
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import get_random_letters
from ray.test_utils import SignalActor
from ray.serve.config import BackendConfig
pytestmark = pytest.mark.asyncio
@ -172,3 +174,68 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
calls = await runner.get_all_calls.remote()
for call in calls:
assert call.request_args[0] in runner_shard_keys[i]
async def test_router_use_max_concurrency(serve_instance):
signal = SignalActor.remote()
@ray.remote
class MockWorker:
async def handle_request(self, request):
await signal.wait.remote()
return "DONE"
def ready(self):
pass
class VisibleRouter(Router):
def get_queues(self):
return self.queries_counter, self.backend_queues
worker = MockWorker.remote()
q = ray.remote(VisibleRouter).remote()
BACKEND_NAME = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0})
await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
await q.set_backend_config.remote(BACKEND_NAME, config)
# We send over two queries
first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
# Neither query should have resolved yet
with pytest.raises(ray.exceptions.RayTimeoutError):
ray.get([first_query, second_query], timeout=0.2)
# Let's retrieve the router internal state
queries_counter, backend_queues = await q.get_queues.remote()
# There should be just one inflight request
assert queries_counter["max-concurrent-test:replica-tag"] == 1
# The second query is buffered
assert len(backend_queues["max-concurrent-test"]) == 1
# Let's unblock the first query
await signal.send.remote(clear=True)
assert await first_query == "DONE"
# The internal state of router should have changed.
queries_counter, backend_queues = await q.get_queues.remote()
# There should still be one inflight request
assert queries_counter["max-concurrent-test:replica-tag"] == 1
# But there shouldn't be any queries in the queue
assert len(backend_queues["max-concurrent-test"]) == 0
# Unblocking the second query
await signal.send.remote(clear=True)
assert await second_query == "DONE"
# Checking the internal state of the router one more time
queries_counter, backend_queues = await q.get_queues.remote()
assert queries_counter["max-concurrent-test:replica-tag"] == 0
assert len(backend_queues["max-concurrent-test"]) == 0
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/tests/test_util.py

@ -1,8 +1,10 @@
import asyncio
import json
import numpy as np
import pytest
from ray.serve.utils import ServeEncoder
from ray.serve.utils import ServeEncoder, chain_future, unpack_future
def test_bytes_encoder():
@ -20,3 +22,50 @@ def test_numpy_encoding():
assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data
@pytest.mark.asyncio
async def test_future_chaining():
def make():
return asyncio.get_event_loop().create_future()
# Test 1 -> 1 chaining
fut1, fut2 = make(), make()
chain_future(fut1, fut2)
fut1.set_result(1)
assert await fut2 == 1
# Test 1 -> 1 chaining with exception
fut1, fut2 = make(), make()
chain_future(fut1, fut2)
fut1.set_exception(ValueError(""))
with pytest.raises(ValueError):
await fut2
# Test many -> many chaining
src_futs = [make() for _ in range(4)]
dst_futs = [make() for _ in range(4)]
chain_future(src_futs, dst_futs)
[fut.set_result(i) for i, fut in enumerate(src_futs)]
for i, fut in enumerate(dst_futs):
assert await fut == i
# Test 1 -> many unwrapping
batched_future = make()
single_futures = unpack_future(batched_future, 4)
batched_future.set_result(list(range(4)))
for i, fut in enumerate(single_futures):
assert await fut == i
# Test 1 -> many unwrapping with exception
batched_future = make()
single_futures = unpack_future(batched_future, 4)
batched_future.set_exception(ValueError(""))
for future in single_futures:
with pytest.raises(ValueError):
await future
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

python/ray/serve/utils.py

@ -1,8 +1,11 @@
import asyncio
from functools import singledispatch
import json
import logging
import random
import string
import time
from typing import List
import io
import os
@ -114,3 +117,59 @@ def format_actor_name(actor_name, instance_name=None):
return actor_name
else:
return "{}:{}".format(instance_name, actor_name)
@singledispatch
def chain_future(src, dst):
"""Base method for chaining futures together.
Chaining futures means the output of the source future(s) is written as the
result(s) of the destination future(s). This method can work with the
following inputs:
- src: Future, dst: Future
- src: List[Future], dst: List[Future]
"""
raise NotImplementedError()
@chain_future.register(asyncio.Future)
def _chain_future_single(src: asyncio.Future, dst: asyncio.Future):
asyncio.futures._chain_future(src, dst)
@chain_future.register(list)
def _chain_future_list(src: List[asyncio.Future], dst: List[asyncio.Future]):
if len(src) != len(dst):
raise ValueError(
"Source and destination list doesn't have the same length. "
"Source: {}. Destination: {}.".foramt(len(src), len(dst)))
for s, d in zip(src, dst):
chain_future(s, d)
def unpack_future(src: asyncio.Future, num_items: int) -> List[asyncio.Future]:
"""Unpack the result of source future to num_items futures.
This function takes in a Future and splits its result into many futures. If
the result of the source future is an exception, then all destination
futures will have the same exception.
"""
dest_futures = [
asyncio.get_event_loop().create_future() for _ in range(num_items)
]
def unwrap_callback(fut: asyncio.Future):
exception = fut.exception()
if exception is not None:
[f.set_exception(exception) for f in dest_futures]
return
result = fut.result()
assert len(result) == num_items
for item, future in zip(result, dest_futures):
future.set_result(item)
src.add_done_callback(unwrap_callback)
return dest_futures
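These two helpers are what let the worker's main_loop above fan a single batched result back out to the per-query futures created in handle_request: the batch future is split with unpack_future and each piece is chained onto a caller-facing future. Below is a minimal standalone sketch of that pattern; the fake_invoke_batch stand-in is illustrative, not Serve code.

import asyncio

from ray.serve.utils import chain_future, unpack_future


async def fake_invoke_batch(queries):
    # Stand-in for RayServeWorker.invoke_batch: one result per query.
    return [q * 10 for q in queries]


async def main():
    loop = asyncio.get_running_loop()
    queries = [1, 2, 3]
    # One future per enqueued query, like Query.async_future.
    result_futures = [loop.create_future() for _ in queries]

    batched = asyncio.ensure_future(fake_invoke_batch(queries))
    # Split the batched result into per-query futures, then forward each
    # piece onto the corresponding caller-facing future.
    chain_future(unpack_future(batched, len(queries)), result_futures)

    print(await asyncio.gather(*result_futures))  # [10, 20, 30]


asyncio.run(main())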

python/ray/test_utils.py

@ -232,8 +232,10 @@ class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self):
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait: