[Serve] Batching in Worker Replica (#8709)

Simon Mo 2020-06-09 11:29:16 -07:00 committed by GitHub
parent f007bfb4cf
commit 6c3062906f
20 changed files with 671 additions and 94 deletions


@@ -5,21 +5,101 @@ py_library(
srcs = glob(["**/*.py"], exclude=["tests/*.py"]),
)
serve_tests_srcs = glob(["tests/*.py"],
exclude=["tests/test_nonblocking.py",
"tests/test_master_crashes.py",
"tests/test_serve.py",
])
py_test(
name = "test_api",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_backend_worker",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_config",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_failure",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_handle",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_kv_store",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_metric",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_persistence",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_router",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
py_test(
name = "test_util",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Runs test_api and test_failure with injected failures in the master actor.
# TODO(edoakes): reenable this once we're using GCS actor fault tolerance.
# py_test(
@@ -97,6 +177,14 @@ py_test(
deps = [":serve_lib"]
)
py_test(
name = "snippet_model_composition",
size = "small",
srcs = glob(["examples/doc/*.py"]),
tags = ["exclusive"],
deps = [":serve_lib"]
)
# Disable the deployment tutorial test because it requires
# ray start --head in the background.
# py_test(


@@ -236,7 +236,8 @@ def create_backend(backend_tag,
replica_config = ReplicaConfig(
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
backend_config = BackendConfig(config, replica_config.accepts_batches,
replica_config.is_blocking)
ray.get(
master_actor.create_backend.remote(backend_tag, backend_config,


@@ -1,20 +1,61 @@
import asyncio
import traceback
import inspect
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
import time
import ray
from ray.async_compat import sync_to_async
from ray import serve
from ray.serve import context as serve_context
from ray.serve.context import FakeFlaskRequest
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
unpack_future)
from ray.serve.exceptions import RayServeException
from ray.serve.metric import MetricClient
from ray.serve.config import BackendConfig
from ray.serve.router import Query
logger = _get_logger()
class WaitableQueue(asyncio.Queue):
async def wait_for_batch(self, num_items: int, timeout_s: float):
"""Wait up to num_items in the queue given timeout_s.
This method will block indefinitely for the first item. Therefore, it
guarantees to return at least one item.
"""
assert num_items >= 1
# Wait for the first value without timeout. We will return at least
# one item. Additionally this help the caller context switch on empty
# queue.
start_waiting = time.time()
batch = [
await self.get(),
]
# Adjust the timeout to account for the time waiting for first item.
time_remaining = timeout_s - (time.time() - start_waiting)
time_remaining = max(0, time_remaining)
# Wait for the remaining batch with the timeout
if num_items > 1:
done_set, not_done_set = await asyncio.wait(
[self.get() for _ in range(num_items - 1)],
timeout=time_remaining)
for task in done_set:
batch.append(task.result())
for task in not_done_set:
task.cancel()
return batch
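For illustration only (not part of this commit), a minimal sketch of the intended wait_for_batch semantics; the helper name _demo below is made up:

import asyncio

async def _demo():
    queue = WaitableQueue()
    for i in range(3):
        queue.put_nowait(i)
    # Up to 5 items requested, 3 available: all 3 are returned after at most
    # timeout_s of additional waiting beyond the first item.
    batch = await queue.wait_for_batch(num_items=5, timeout_s=0.1)
    assert sorted(batch) == [0, 1, 2]
    # With num_items=1 the timeout is irrelevant: exactly one item is returned.
    queue.put_nowait("single")
    assert await queue.wait_for_batch(num_items=1, timeout_s=0.0) == ["single"]

asyncio.get_event_loop().run_until_complete(_demo())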
def create_backend_worker(func_or_class):
"""Creates a worker class wrapping the provided function or class."""
@@ -30,8 +71,10 @@ def create_backend_worker(func_or_class):
backend_tag,
replica_tag,
init_args,
backend_config: BackendConfig,
instance_name=None):
serve.init(name=instance_name)
if is_function:
_callable = func_or_class
else:
@@ -42,11 +85,15 @@ def create_backend_worker(func_or_class):
metric_client = MetricClient(
metric_exporter, default_labels={"backend": backend_tag})
self.backend = RayServeWorker(backend_tag, replica_tag, _callable,
backend_config, is_function,
metric_client)
async def handle_request(self, request):
return await self.backend.handle_request(request)
def update_config(self, new_config: BackendConfig):
return self.backend.update_config(new_config)
def ready(self):
pass
@@ -75,13 +122,16 @@ def ensure_async(func):
class RayServeWorker:
"""Handles requests with the provided callable."""
def __init__(self, name, replica_tag, _callable,
backend_config: BackendConfig, is_function, metric_client):
self.name = name
self.replica_tag = replica_tag
self.callable = _callable
self.is_function = is_function
self.config = backend_config
self.query_queue = WaitableQueue()
self.metric_client = metric_client
self.request_counter = self.metric_client.new_counter(
"backend_request_counter",
@@ -101,6 +151,9 @@ class RayServeWorker:
self.restart_counter.labels(replica_tag=self.replica_tag).add()
self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
self.config_updated = asyncio.Event()
def get_runner_method(self, request_item):
method_name = request_item.call_method
if not hasattr(self.callable, method_name):
@@ -108,6 +161,8 @@ class RayServeWorker:
"which is specified in the request. "
"The available methods are {}".format(
method_name, dir(self.callable)))
if self.is_function:
return self.callable
return getattr(self.callable, method_name)
def has_positional_args(self, f):
@@ -124,6 +179,12 @@ class RayServeWorker:
return True
return False
def _reset_context(self):
# NOTE(simon): context management won't work in async mode because
# many concurrent queries might be running at the same time.
serve_context.web = None
serve_context.batch_size = None
async def invoke_single(self, request_item):
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
@@ -137,24 +198,12 @@ class RayServeWorker:
except Exception as e:
result = wrap_to_ray_error(e)
self.error_counter.add()
finally:
self._reset_context()
return result
async def invoke_batch(self, request_item_list):
# TODO(alind) : create no-http services. The enqueues
# from such services will always be TaskContext.Python.
# Assumption: all the requests in a batch
# have same serve context.
# For batching kwargs are modified as follows -
# kwargs [Python Context] : key,val
# kwargs_list : key, [val1,val2, ... , valn]
# or
# args[Web Context] : val
# args_list : [val1,val2, ...... , valn]
# where n (current batch size) <= max_batch_size of a backend
arg_list = []
kwargs_list = defaultdict(list)
context_flags = set()
@@ -222,22 +271,53 @@ class RayServeWorker:
"results with length equal to the batch size"
".".format(batch_size, len(result_list)))
raise RayServeException(error_message)
self._reset_context()
return result_list
except Exception as e:
wrapped_exception = wrap_to_ray_error(e)
self.error_counter.add()
self._reset_context()
return [wrapped_exception for _ in range(batch_size)]
async def main_loop(self):
while True:
# NOTE(simon): If the user updates the batch size or batch wait
# timeout while a batch is executing, the new values will not take
# effect until the next iteration.
batch = await self.query_queue.wait_for_batch(
num_items=self.config.max_batch_size or 1,
timeout_s=self.config.batch_wait_timeout)
all_evaluated_futures = []
if not self.config.accepts_batches:
query = batch[0]
evaluated = asyncio.ensure_future(self.invoke_single(query))
all_evaluated_futures = [evaluated]
chain_future(evaluated, query.async_future)
else:
get_call_method = attrgetter("call_method")
sorted_batch = sorted(batch, key=get_call_method)
for _, group in groupby(sorted_batch, key=get_call_method):
group = sorted(group)
evaluated = asyncio.ensure_future(self.invoke_batch(group))
all_evaluated_futures.append(evaluated)
result_futures = [q.async_future for q in group]
chain_future(
unpack_future(evaluated, len(group)), result_futures)
if self.config.is_blocking:
# We use asyncio.wait here so if the result is exception,
# it will not be raised.
await asyncio.wait(all_evaluated_futures)
def update_config(self, new_config: BackendConfig):
self.config = new_config
self.config_updated.set()
async def handle_request(self, request: Query):
assert not isinstance(request, list)
logger.debug("Worker {} got request {}".format(self.name, request))
request.async_future = asyncio.get_event_loop().create_future()
self.query_queue.put_nowait(request)
return await request.async_future
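As a self-contained sketch (not taken from this diff) of the grouping main_loop performs when accepts_batches is set: queries are grouped by call_method so a single invoke_batch call never mixes methods; the FakeQuery stand-in below is hypothetical.

from collections import namedtuple
from itertools import groupby
from operator import attrgetter

FakeQuery = namedtuple("FakeQuery", ["call_method", "payload"])

batch = [
    FakeQuery("__call__", 1),
    FakeQuery("other_method", 2),
    FakeQuery("__call__", 3),
]
get_call_method = attrgetter("call_method")
groups = [
    [q.payload for q in group]
    for _, group in groupby(sorted(batch, key=get_call_method), key=get_call_method)
]
# Two separate invoke_batch calls would be issued, one per method group.
assert groups == [[1, 3], [2]]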


@@ -1,5 +1,7 @@
import inspect
from ray.serve.constants import ASYNC_CONCURRENCY
def _callable_accepts_batch(func_or_class):
if inspect.isfunction(func_or_class):
@@ -8,15 +10,46 @@ def _callable_accepts_batch(func_or_class):
return hasattr(func_or_class.__call__, "_serve_accept_batch")
def _callable_is_blocking(func_or_class):
if inspect.isfunction(func_or_class):
return not inspect.iscoroutinefunction(func_or_class)
elif inspect.isclass(func_or_class):
return not inspect.iscoroutinefunction(func_or_class.__call__)
class BackendConfig:
def __init__(self, config_dict, accepts_batches=False, is_blocking=True):
assert isinstance(config_dict, dict)
# Make a copy so that we don't modify the input dict.
config_dict = config_dict.copy()
self.accepts_batches = accepts_batches
self.is_blocking = is_blocking
self.num_replicas = config_dict.pop("num_replicas", 1)
self.max_batch_size = config_dict.pop("max_batch_size", None)
self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
None)
if self.max_concurrent_queries is None:
# Model serving mode: if the servable is blocking and the wait
# timeout is the default of zero seconds, keep the existing
# behavior of allowing at most a full batch of queries in flight.
if self.is_blocking and self.batch_wait_timeout == 0:
self.max_concurrent_queries = self.max_batch_size or 1
# Pipeline/async mode: if the servable is not blocking, the
# router should keep pushing queries to the worker replicas
# up to a high limit.
if not self.is_blocking:
self.max_concurrent_queries = ASYNC_CONCURRENCY
# Batch inference mode: the user specified a non-zero timeout to
# wait for a full batch. Use 2 * max_batch_size to double-buffer
# and keep the replica busy.
if self.max_batch_size is not None and self.batch_wait_timeout > 0:
self.max_concurrent_queries = 2 * self.max_batch_size
if len(config_dict) != 0:
raise ValueError("Unknown options in backend config: {}".format(
list(config_dict.keys())))
@@ -64,6 +97,7 @@ class ReplicaConfig:
ray_actor_options=None):
self.func_or_class = func_or_class
self.accepts_batches = _callable_accepts_batch(func_or_class)
self.is_blocking = _callable_is_blocking(func_or_class)
self.actor_init_args = list(actor_init_args)
if ray_actor_options is None:
self.ray_actor_options = {}
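For reference, a small sketch (assuming the constructor logic shown in the hunk above) of the max_concurrent_queries values the three modes produce:

from ray.serve.config import BackendConfig

# Model serving mode: blocking callable, zero batch_wait_timeout ->
# at most one query (or one full batch) in flight per replica.
assert BackendConfig({}).max_concurrent_queries == 1
assert BackendConfig({"max_batch_size": 8},
                     accepts_batches=True).max_concurrent_queries == 8

# Pipeline/async mode: non-blocking callable -> a high limit
# (ASYNC_CONCURRENCY) so the router keeps pushing queries.
assert BackendConfig({}, is_blocking=False).max_concurrent_queries > 1

# Batch inference mode: non-zero batch_wait_timeout -> double buffering
# with two full batches per replica.
assert BackendConfig({"max_batch_size": 8, "batch_wait_timeout": 0.5},
                     accepts_batches=True).max_concurrent_queries == 16

# An explicit user-provided value always wins.
assert BackendConfig({"max_concurrent_queries": 10}).max_concurrent_queries == 10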


@@ -0,0 +1,57 @@
from random import random
import requests
from ray import serve
serve.init()
def model_one(_unused_flask_request, data=None):
print("Model 1 called with data ", data)
return random()
def model_two(_unused_flask_request, data=None):
print("Model 2 called with data ", data)
return data
class ComposedModel:
def __init__(self):
self.model_one = serve.get_handle("model_one")
self.model_two = serve.get_handle("model_two")
async def __call__(self, flask_request):
data = flask_request.data
score = await self.model_one.remote(data=data)
if score > 0.5:
result = await self.model_two.remote(data=data)
result = {"model_used": 2, "score": score}
else:
result = {"model_used": 1, "score": score}
return result
serve.create_backend("model_one", model_one)
serve.create_endpoint("model_one", backend="model_one")
serve.create_backend("model_two", model_two)
serve.create_endpoint("model_two", backend="model_two")
serve.create_backend(
"composed_backend", ComposedModel, config={"max_concurrent_queries": 10})
serve.create_endpoint(
"composed", backend="composed_backend", route="/composed")
for _ in range(5):
resp = requests.get("http://127.0.0.1:8000/composed", data="hey!")
print(resp.json())
# Output
# {'model_used': 2, 'score': 0.6250189863595503}
# {'model_used': 1, 'score': 0.03146855349621436}
# {'model_used': 2, 'score': 0.6916977560006987}
# {'model_used': 2, 'score': 0.8169693450866928}
# {'model_used': 2, 'score': 0.9540681979573862}


@@ -258,6 +258,7 @@ class ServeMaster:
for backend, (_, backend_config, _) in self.backends.items():
await self.router.set_backend_config.remote(
backend, backend_config)
await self.broadcast_backend_config(backend)
# Push configuration state to the HTTP proxy.
await self.http_proxy.set_route_table.remote(self.routes)
@@ -314,6 +315,7 @@ class ServeMaster:
backend_tag,
replica_tag,
replica_config.actor_init_args,
backend_config,
instance_name=self.instance_name)
# TODO(edoakes): we should probably have a timeout here.
await worker_handle.ready.remote()
@@ -602,6 +604,7 @@ class ServeMaster:
# (particularly for max-batch-size).
await self.router.set_backend_config.remote(
backend_tag, backend_config)
await self.broadcast_backend_config(backend_tag)
async def delete_backend(self, backend_tag):
async with self.write_lock:
@@ -664,6 +667,22 @@ class ServeMaster:
await self._start_pending_replicas()
await self._stop_pending_replicas()
await self.broadcast_backend_config(backend_tag)
async def broadcast_backend_config(self, backend_tag):
_, backend_config, _ = self.backends[backend_tag]
broadcast_futures = []
for replica_tag in self.replicas[backend_tag]:
try:
replica = ray.get_actor(replica_tag)
except ValueError:
continue
future = replica.update_config.remote(backend_config).as_future()
broadcast_futures.append(future)
if len(broadcast_futures) > 0:
await asyncio.gather(*broadcast_futures)
def get_backend_config(self, backend_tag):
"""Get the current config for the specified backend."""
assert (backend_tag in self.backends


@@ -13,7 +13,7 @@ import ray
from ray import serve
from ray.serve.metric import MetricClient
from ray.serve.policy import RandomEndpointPolicy
from ray.serve.utils import logger, chain_future
class Query:
@@ -58,10 +58,6 @@ class Query:
def __lt__(self, other):
return self.request_slo_ms < other.request_slo_ms
def __repr__(self):
return "<Query args={} kwargs={}>".format(self.request_args,
self.request_kwargs)
def _make_future_unwrapper(client_futures: List[asyncio.Future],
host_future: asyncio.Future):
@@ -117,6 +113,8 @@ class Router:
self.backend_info = dict()
# replica tag -> worker_handle
self.replicas = dict()
# replica_tag -> concurrent queries counter
self.queries_counter = defaultdict(lambda: 0)
# -- Synchronization -- #
@@ -126,7 +124,7 @@ class Router:
# an operation holding the only query and the other flush operation
# holding the only idle replica. Additionally, allowing only one flush
# operation at a time simplifies design overhead for custom queuing and
# batching policies.
self.flush_lock = asyncio.Lock()
# -- State Restoration -- #
@@ -215,11 +213,15 @@ class Router:
await self.mark_worker_idle(backend_tag, backend_replica_tag)
async def mark_worker_idle(self, backend_tag, backend_replica_tag):
logger.debug(
"Marking backend with tag {} as idle.".format(backend_replica_tag))
if backend_replica_tag not in self.replicas:
return
async with self.flush_lock:
# NOTE(simon): This is an O(n) operation where n=len(worker_queue)
if backend_replica_tag not in self.worker_queues[backend_tag]:
self.worker_queues[backend_tag].appendleft(backend_replica_tag)
self.flush_backend_queues([backend_tag])
async def remove_worker(self, backend_tag, replica_tag):
@@ -299,12 +301,11 @@ class Router:
"queue size {} and worker queue size {}".format(
backend, len(buffer_queue), len(worker_queue)))
self._assign_query_to_worker(
backend,
buffer_queue,
worker_queue,
)
async def _do_query(self, backend, backend_replica_tag, req):
# If the worker died, this will be a RayActorError. Just return it and
@@ -317,16 +318,13 @@ class Router:
except RayTaskError as error:
self.num_error_backend_request.labels(backend=backend).add()
result = error
self.queries_counter[backend_replica_tag] -= 1
await self.mark_worker_idle(backend, backend_replica_tag)
logger.debug("Got result in {:.2f}s".format(time.time() - start))
return result
def _assign_query_to_worker(self, backend, buffer_queue, worker_queue):
overloaded_replicas = set()
while len(buffer_queue) and len(worker_queue):
backend_replica_tag = worker_queue.pop()
@@ -334,27 +332,30 @@ class Router:
if backend_replica_tag not in self.replicas:
continue
# We have reached the end of the worker queue where all replicas
# are overloaded.
if backend_replica_tag in overloaded_replicas:
break
# This replica has too many in flight and processing queries.
max_queries = 1
if backend in self.backend_info:
max_queries = self.backend_info[backend].max_concurrent_queries
curr_queries = self.queries_counter[backend_replica_tag]
if curr_queries >= max_queries:
# Put the worker back to the queue.
worker_queue.appendleft(backend_replica_tag)
overloaded_replicas.add(backend_replica_tag)
logger.debug(
"Skipping backend {} because it has {} in flight "
"requests which exceeded the concurrency limit.".format(
backend, curr_queries))
continue
request = buffer_queue.pop(0)
self.queries_counter[backend_replica_tag] += 1
future = asyncio.get_event_loop().create_task(
self._do_query(backend, backend_replica_tag, request))
chain_future(future, request.async_future)
worker_queue.appendleft(backend_replica_tag)
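A pure-Python sketch (not from this diff) of the per-replica admission rule the router now applies; the can_assign helper is hypothetical:

from collections import defaultdict

queries_counter = defaultdict(int)

def can_assign(replica_tag, max_concurrent_queries):
    """Take a slot if the replica is under its in-flight limit."""
    if queries_counter[replica_tag] >= max_concurrent_queries:
        return False
    queries_counter[replica_tag] += 1
    return True

assert can_assign("replica-1", max_concurrent_queries=1)
assert not can_assign("replica-1", max_concurrent_queries=1)  # over the limit
queries_counter["replica-1"] -= 1  # _do_query decrements once the reply arrives
assert can_assign("replica-1", max_concurrent_queries=1)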


@@ -507,3 +507,8 @@ def test_endpoint_input_validation(serve_instance):
with pytest.raises(TypeError):
serve.create_endpoint("endpoint", backend=2)
serve.create_endpoint("endpoint", backend="backend")
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -15,7 +15,10 @@ from ray.serve.exceptions import RayServeException
pytestmark = pytest.mark.asyncio
def setup_worker(name,
func_or_class,
init_args=None,
backend_config=BackendConfig({})):
if init_args is None:
init_args = ()
@@ -23,7 +26,7 @@ def setup_worker(name, func_or_class, init_args=None):
class WorkerActor:
def __init__(self):
self.worker = create_backend_worker(func_or_class)(
name, name + ":tag", init_args, backend_config)
def ready(self):
pass
@@ -31,6 +34,9 @@ def setup_worker(name, func_or_class, init_args=None):
async def handle_request(self, *args, **kwargs):
return await self.worker.handle_request(*args, **kwargs)
def update_config(self, new_config):
return self.worker.update_config(new_config)
worker = WorkerActor.remote()
ray.get(worker.ready.remote())
return worker
@@ -165,14 +171,16 @@ async def test_task_runner_custom_method_batch(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
backend_config = BackendConfig(
{
"max_batch_size": 4,
"batch_wait_timeout": 2
}, accepts_batches=True)
worker = setup_worker(
CONSUMER_NAME, Batcher, backend_config=backend_config)
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
await q.set_backend_config.remote(CONSUMER_NAME, backend_config)
def make_request_param(call_method):
return RequestMetadata(
@@ -200,3 +208,77 @@ async def test_task_runner_custom_method_batch(serve_instance):
np_array = make_request_param("return_np_array")
result_np_value = await q.enqueue_request.remote(np_array)
assert isinstance(result_np_value, np.int32)
async def test_task_runner_perform_batch(serve_instance):
q = ray.remote(Router).remote()
def batcher(*args, **kwargs):
return [serve.context.batch_size] * serve.context.batch_size
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
config = BackendConfig(
{
"max_batch_size": 2,
"batch_wait_timeout": 10
}, accepts_batches=True)
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
my_batch_sizes = await asyncio.gather(
*[q.enqueue_request.remote(query_param) for _ in range(3)])
assert my_batch_sizes == [2, 2, 1]
async def test_task_runner_perform_async(serve_instance):
q = ray.remote(Router).remote()
@ray.remote
class Barrier:
def __init__(self, release_on):
self.release_on = release_on
self.current_waiters = 0
self.event = asyncio.Event()
async def wait(self):
self.current_waiters += 1
if self.current_waiters == self.release_on:
self.event.set()
else:
await self.event.wait()
barrier = Barrier.remote(release_on=10)
async def wait_and_go(*args, **kwargs):
await barrier.wait.remote()
return "done!"
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
done, not_done = await asyncio.wait(
[q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10)
assert len(done) == 10
for item in done:
await item == "done!"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -132,3 +132,8 @@ def test_replica_config_validation():
ReplicaConfig(Class, ray_actor_options={"detached": None})
with pytest.raises(ValueError):
ReplicaConfig(Class, ray_actor_options={"max_restarts": None})
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -241,3 +241,9 @@ def test_worker_replica_failure(serve_instance):
break
except TimeoutError:
time.sleep(0.1)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -33,3 +33,9 @@ def test_handle_in_endpoint(serve_instance):
methods=["GET", "POST"])
assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -38,3 +38,8 @@ def test_ray_internal_kv_collisions(serve_instance):
kv2.put("1", b"-1")
assert kv2.get("1") == b"-1"
assert kv1.get("1") == b"1"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -195,3 +195,8 @@ async def test_system_metric_endpoints(serve_instance):
print("Metric not correct, retrying...")
if not success:
test_metric_endpoint()
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -20,4 +20,4 @@ def test_nonblocking():
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -33,3 +33,9 @@ serve.create_endpoint("driver", backend="driver", route="/driver")
assert ray.get(handle.remote()) == "OK!"
os.remove(path)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -7,6 +7,8 @@ import ray
from ray.serve.router import Router
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import get_random_letters
from ray.test_utils import SignalActor
from ray.serve.config import BackendConfig
pytestmark = pytest.mark.asyncio
@@ -172,3 +174,68 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
calls = await runner.get_all_calls.remote()
for call in calls:
assert call.request_args[0] in runner_shard_keys[i]
async def test_router_use_max_concurrency(serve_instance):
signal = SignalActor.remote()
@ray.remote
class MockWorker:
async def handle_request(self, request):
await signal.wait.remote()
return "DONE"
def ready(self):
pass
class VisibleRouter(Router):
def get_queues(self):
return self.queries_counter, self.backend_queues
worker = MockWorker.remote()
q = ray.remote(VisibleRouter).remote()
BACKEND_NAME = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0})
await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
await q.set_backend_config.remote(BACKEND_NAME, config)
# We send over two queries
first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
# Neither query should be resolved yet
with pytest.raises(ray.exceptions.RayTimeoutError):
ray.get([first_query, second_query], timeout=0.2)
# Let's retrieve the router internal state
queries_counter, backend_queues = await q.get_queues.remote()
# There should be just one inflight request
assert queries_counter["max-concurrent-test:replica-tag"] == 1
# The second query is buffered
assert len(backend_queues["max-concurrent-test"]) == 1
# Let's unblock the first query
await signal.send.remote(clear=True)
assert await first_query == "DONE"
# The internal state of router should have changed.
queries_counter, backend_queues = await q.get_queues.remote()
# There should still be one inflight request
assert queries_counter["max-concurrent-test:replica-tag"] == 1
# But there shouldn't be any queries in the queue
assert len(backend_queues["max-concurrent-test"]) == 0
# Unblocking the second query
await signal.send.remote(clear=True)
assert await second_query == "DONE"
# Checking the internal state of the router one more time
queries_counter, backend_queues = await q.get_queues.remote()
assert queries_counter["max-concurrent-test:replica-tag"] == 0
assert len(backend_queues["max-concurrent-test"]) == 0
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -1,8 +1,10 @@
import asyncio
import json
import numpy as np
import pytest
from ray.serve.utils import ServeEncoder, chain_future, unpack_future
def test_bytes_encoder():
@@ -20,3 +22,50 @@ def test_numpy_encoding():
assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data
@pytest.mark.asyncio
async def test_future_chaining():
def make():
return asyncio.get_event_loop().create_future()
# Test 1 -> 1 chaining
fut1, fut2 = make(), make()
chain_future(fut1, fut2)
fut1.set_result(1)
assert await fut2 == 1
# Test 1 -> 1 chaining with exception
fut1, fut2 = make(), make()
chain_future(fut1, fut2)
fut1.set_exception(ValueError(""))
with pytest.raises(ValueError):
await fut2
# Test many -> many chaining
src_futs = [make() for _ in range(4)]
dst_futs = [make() for _ in range(4)]
chain_future(src_futs, dst_futs)
[fut.set_result(i) for i, fut in enumerate(src_futs)]
for i, fut in enumerate(dst_futs):
assert await fut == i
# Test 1 -> many unwrapping
batched_future = make()
single_futures = unpack_future(batched_future, 4)
batched_future.set_result(list(range(4)))
for i, fut in enumerate(single_futures):
assert await fut == i
# Test 1 -> many unwrapping with exception
batched_future = make()
single_futures = unpack_future(batched_future, 4)
batched_future.set_exception(ValueError(""))
for future in single_futures:
with pytest.raises(ValueError):
await future
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))


@@ -1,8 +1,11 @@
import asyncio
from functools import singledispatch
import json
import logging
import random
import string
import time
from typing import List
import io
import os
@@ -114,3 +117,59 @@ def format_actor_name(actor_name, instance_name=None):
return actor_name
else:
return "{}:{}".format(instance_name, actor_name)
@singledispatch
def chain_future(src, dst):
"""Base method for chaining futures together.
Chaining futures means the result of the source future(s) is written as the
result of the destination future(s). This method works with the
following inputs:
- src: Future, dst: Future
- src: List[Future], dst: List[Future]
"""
raise NotImplementedError()
@chain_future.register(asyncio.Future)
def _chain_future_single(src: asyncio.Future, dst: asyncio.Future):
asyncio.futures._chain_future(src, dst)
@chain_future.register(list)
def _chain_future_list(src: List[asyncio.Future], dst: List[asyncio.Future]):
if len(src) != len(dst):
raise ValueError(
"Source and destination list doesn't have the same length. "
"Source: {}. Destination: {}.".foramt(len(src), len(dst)))
for s, d in zip(src, dst):
chain_future(s, d)
def unpack_future(src: asyncio.Future, num_items: int) -> List[asyncio.Future]:
"""Unpack the result of source future to num_items futures.
This function takes in a Future and splits its result into many futures. If
the result of the source future is an exception, then all destination
futures will have the same exception.
"""
dest_futures = [
asyncio.get_event_loop().create_future() for _ in range(num_items)
]
def unwrap_callback(fut: asyncio.Future):
exception = fut.exception()
if exception is not None:
[f.set_exception(exception) for f in dest_futures]
return
result = fut.result()
assert len(result) == num_items
for item, future in zip(result, dest_futures):
future.set_result(item)
src.add_done_callback(unwrap_callback)
return dest_futures


@@ -232,8 +232,10 @@ class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait:
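A short usage sketch (not part of this diff) of the new clear flag: send(clear=True) wakes the waiters that are currently blocked and immediately re-arms the event, which is how the router test above releases its two blocked queries one at a time.

import ray
from ray.test_utils import SignalActor

# Assumes ray.init() has already been called.
signal = SignalActor.remote()
waiter = signal.wait.remote()            # blocks until the next send()
ray.get(signal.send.remote(clear=True))  # releases the waiter, then re-arms
ray.get(waiter)                          # completes now
next_waiter = signal.wait.remote()       # blocks again until another send()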