[Serve] Batching in Worker Replica (#8709)
parent f007bfb4cf
commit 6c3062906f
20 changed files with 671 additions and 94 deletions
@@ -5,21 +5,101 @@ py_library(
    srcs = glob(["**/*.py"], exclude=["tests/*.py"]),
)

# This test aggregates all serve tests and run them in a single session
# similar to `pytest .`
# Serve tests need to run in a single session because starting and stopping
# serve cluster take a large chunk of time. All serve tests use a shared
# cluster.
serve_tests_srcs = glob(["tests/*.py"],
                        exclude=["tests/test_nonblocking.py",
                                 "tests/test_master_crashes.py",
                                 "tests/test_serve.py",
                        ])

py_test(
    name = "test_serve",
    name = "test_api",
    size = "medium",
    srcs = glob(["tests/*.py"],
                exclude=["tests/test_nonblocking.py",
                         "tests/test_master_crashes.py"]),
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_backend_worker",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_config",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_failure",
    size = "medium",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_handle",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_kv_store",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_metric",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_persistence",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_router",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

py_test(
    name = "test_util",
    size = "small",
    srcs = serve_tests_srcs,
    tags = ["exclusive"],
    deps = [":serve_lib"],
)

# Runs test_api and test_failure with injected failures in the master actor.
# TODO(edoakes): reenable this once we're using GCS actor fault tolerance.
# py_test(

@@ -97,6 +177,14 @@ py_test(
    deps = [":serve_lib"]
)

py_test(
    name = "snippet_model_composition",
    size = "small",
    srcs = glob(["examples/doc/*.py"]),
    tags = ["exclusive"],
    deps = [":serve_lib"]
)

# Disable the deployment tutorial test because it requires
# ray start --head in the background.
# py_test(
@@ -236,7 +236,8 @@ def create_backend(backend_tag,

    replica_config = ReplicaConfig(
        func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
    backend_config = BackendConfig(config, replica_config.accepts_batches)
    backend_config = BackendConfig(config, replica_config.accepts_batches,
                                   replica_config.is_blocking)

    ray.get(
        master_actor.create_backend.remote(backend_tag, backend_config,
@@ -1,20 +1,61 @@
import asyncio
import traceback
import inspect
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
import time

import ray
from ray.async_compat import sync_to_async

from ray import serve
from ray.serve import context as serve_context
from ray.serve.context import FakeFlaskRequest
from collections import defaultdict
from ray.serve.utils import (parse_request_item, _get_logger)
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
                             unpack_future)
from ray.serve.exceptions import RayServeException
from ray.serve.metric import MetricClient
from ray.async_compat import sync_to_async
from ray.serve.config import BackendConfig
from ray.serve.router import Query

logger = _get_logger()


class WaitableQueue(asyncio.Queue):
    async def wait_for_batch(self, num_items: int, timeout_s: float):
        """Wait up to num_items in the queue given timeout_s.

        This method will block indefinitely for the first item. Therefore, it
        guarantees to return at least one item.
        """
        assert num_items >= 1
        # Wait for the first value without timeout. We will return at least
        # one item. Additionally this helps the caller context switch on an
        # empty queue.
        start_waiting = time.time()
        batch = [
            await self.get(),
        ]

        # Adjust the timeout to account for the time waiting for first item.
        time_remaining = timeout_s - (time.time() - start_waiting)
        time_remaining = max(0, time_remaining)

        # Wait for the remaining batch with the timeout
        if num_items > 1:
            done_set, not_done_set = await asyncio.wait(
                [self.get() for _ in range(num_items - 1)],
                timeout=time_remaining)
            for task in done_set:
                batch.append(task.result())
            for task in not_done_set:
                task.cancel()
        return batch


def create_backend_worker(func_or_class):
    """Creates a worker class wrapping the provided function or class."""
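
A minimal usage sketch (not part of the diff) of the WaitableQueue added above: ask for a larger batch than is available and the call returns after the timeout with whatever was queued, never empty-handed. The import path is an assumption based on this PR's module layout (python/ray/serve/backend_worker.py).

import asyncio

from ray.serve.backend_worker import WaitableQueue  # assumed module layout

async def demo():
    queue = WaitableQueue()
    for i in range(3):
        queue.put_nowait(i)
    # Ask for up to 8 items with a 100 ms budget: the first get() blocks until
    # something is available, the remaining gets are bounded by the timeout.
    batch = await queue.wait_for_batch(num_items=8, timeout_s=0.1)
    print(sorted(batch))  # [0, 1, 2]

asyncio.run(demo())
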
@@ -30,8 +71,10 @@ def create_backend_worker(func_or_class):
                     backend_tag,
                     replica_tag,
                     init_args,
                     backend_config: BackendConfig,
                     instance_name=None):
            serve.init(name=instance_name)

            if is_function:
                _callable = func_or_class
            else:

@@ -42,11 +85,15 @@ def create_backend_worker(func_or_class):
            metric_client = MetricClient(
                metric_exporter, default_labels={"backend": backend_tag})
            self.backend = RayServeWorker(backend_tag, replica_tag, _callable,
                                          is_function, metric_client)
                                          backend_config, is_function,
                                          metric_client)

        async def handle_request(self, request):
            return await self.backend.handle_request(request)

        def update_config(self, new_config: BackendConfig):
            return self.backend.update_config(new_config)

        def ready(self):
            pass
@@ -75,13 +122,16 @@ def ensure_async(func):
class RayServeWorker:
    """Handles requests with the provided callable."""

    def __init__(self, name, replica_tag, _callable, is_function,
                 metric_client):
    def __init__(self, name, replica_tag, _callable,
                 backend_config: BackendConfig, is_function, metric_client):
        self.name = name
        self.replica_tag = replica_tag
        self.callable = _callable
        self.is_function = is_function

        self.config = backend_config
        self.query_queue = WaitableQueue()

        self.metric_client = metric_client
        self.request_counter = self.metric_client.new_counter(
            "backend_request_counter",

@@ -101,6 +151,9 @@ class RayServeWorker:

        self.restart_counter.labels(replica_tag=self.replica_tag).add()

        self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
        self.config_updated = asyncio.Event()

    def get_runner_method(self, request_item):
        method_name = request_item.call_method
        if not hasattr(self.callable, method_name):

@@ -108,6 +161,8 @@ class RayServeWorker:
                "which is specified in the request. "
                "The available methods are {}".format(
                    method_name, dir(self.callable)))
        if self.is_function:
            return self.callable
        return getattr(self.callable, method_name)

    def has_positional_args(self, f):

@@ -124,6 +179,12 @@ class RayServeWorker:
                return True
        return False

    def _reset_context(self):
        # NOTE(simon): context management won't work in async mode because
        # many concurrent queries might be running at the same time.
        serve_context.web = None
        serve_context.batch_size = None

    async def invoke_single(self, request_item):
        args, kwargs, is_web_context = parse_request_item(request_item)
        serve_context.web = is_web_context

@@ -137,24 +198,12 @@ class RayServeWorker:
        except Exception as e:
            result = wrap_to_ray_error(e)
            self.error_counter.add()
        finally:
            self._reset_context()

        return result

    async def invoke_batch(self, request_item_list):
        # TODO(alind) : create no-http services. The enqueues
        # from such services will always be TaskContext.Python.

        # Assumption : all the requests in a batch
        # have same serve context.

        # For batching kwargs are modified as follows -
        # kwargs [Python Context] : key,val
        # kwargs_list : key, [val1,val2, ... , valn]
        # or
        # args[Web Context] : val
        # args_list : [val1,val2, ...... , valn]
        # where n (current batch size) <= max_batch_size of a backend

        arg_list = []
        kwargs_list = defaultdict(list)
        context_flags = set()

@@ -222,22 +271,53 @@ class RayServeWorker:
                    "results with length equal to the batch size"
                    ".".format(batch_size, len(result_list)))
                raise RayServeException(error_message)
            self._reset_context()
            return result_list
        except Exception as e:
            wrapped_exception = wrap_to_ray_error(e)
            self.error_counter.add()
            self._reset_context()
            return [wrapped_exception for _ in range(batch_size)]

    async def handle_request(self, request):
        # check if work_item is a list or not
        # if it is list: then batching supported
        if not isinstance(request, list):
            result = await self.invoke_single(request)
        else:
            result = await self.invoke_batch(request)
    async def main_loop(self):
        while True:
            # NOTE(simon): There's an issue when user updated batch size and
            # batch wait timeout during the execution, these values will not be
            # updated until after the current iteration.
            batch = await self.query_queue.wait_for_batch(
                num_items=self.config.max_batch_size or 1,
                timeout_s=self.config.batch_wait_timeout)

            # re-assign to default values
            serve_context.web = False
            serve_context.batch_size = None
            all_evaluated_futures = []

        return result
            if not self.config.accepts_batches:
                query = batch[0]
                evaluated = asyncio.ensure_future(self.invoke_single(query))
                all_evaluated_futures = [evaluated]
                chain_future(evaluated, query.async_future)
            else:
                get_call_method = attrgetter("call_method")
                sorted_batch = sorted(batch, key=get_call_method)
                for _, group in groupby(sorted_batch, key=get_call_method):
                    group = sorted(group)
                    evaluated = asyncio.ensure_future(self.invoke_batch(group))
                    all_evaluated_futures.append(evaluated)
                    result_futures = [q.async_future for q in group]
                    chain_future(
                        unpack_future(evaluated, len(group)), result_futures)

            if self.config.is_blocking:
                # We use asyncio.wait here so if the result is exception,
                # it will not be raised.
                await asyncio.wait(all_evaluated_futures)

    def update_config(self, new_config: BackendConfig):
        self.config = new_config
        self.config_updated.set()

    async def handle_request(self, request: Query):
        assert not isinstance(request, list)
        logger.debug("Worker {} got request {}".format(self.name, request))
        request.async_future = asyncio.get_event_loop().create_future()
        self.query_queue.put_nowait(request)
        return await request.async_future
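
In the new design, handle_request no longer invokes the callable directly: it attaches a fresh future to the Query, enqueues it, and awaits, while main_loop drains the queue via wait_for_batch, groups queries by call_method, and fans batched results back onto the per-query futures with chain_future/unpack_future. Below is a condensed, standalone sketch of that producer/consumer shape (illustrative names only, not the PR's code).

import asyncio

async def replica_sketch():
    queue = asyncio.Queue()

    async def handle_request(payload):
        # Like RayServeWorker.handle_request: park the caller on a future.
        fut = asyncio.get_running_loop().create_future()
        queue.put_nowait((payload, fut))
        return await fut

    async def main_loop(max_batch_size=4):
        while True:
            # Like WaitableQueue.wait_for_batch: block for one, drain the rest.
            batch = [await queue.get()]
            while len(batch) < max_batch_size and not queue.empty():
                batch.append(queue.get_nowait())
            results = [payload * 2 for payload, _ in batch]  # stand-in for invoke_batch
            # Like chain_future/unpack_future: fan the batch result back out.
            for (_, fut), result in zip(batch, results):
                fut.set_result(result)

    loop_task = asyncio.ensure_future(main_loop())
    print(await asyncio.gather(*[handle_request(i) for i in range(6)]))  # [0, 2, 4, 6, 8, 10]
    loop_task.cancel()

asyncio.run(replica_sketch())
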
@@ -1,5 +1,7 @@
import inspect

from ray.serve.constants import ASYNC_CONCURRENCY


def _callable_accepts_batch(func_or_class):
    if inspect.isfunction(func_or_class):

@@ -8,15 +10,46 @@ def _callable_accepts_batch(func_or_class):
        return hasattr(func_or_class.__call__, "_serve_accept_batch")


def _callable_is_blocking(func_or_class):
    if inspect.isfunction(func_or_class):
        return not inspect.iscoroutinefunction(func_or_class)
    elif inspect.isclass(func_or_class):
        return not inspect.iscoroutinefunction(func_or_class.__call__)


class BackendConfig:
    def __init__(self, config_dict, accepts_batches=False):
    def __init__(self, config_dict, accepts_batches=False, is_blocking=True):
        assert isinstance(config_dict, dict)
        # Make a copy so that we don't modify the input dict.
        config_dict = config_dict.copy()

        self.accepts_batches = accepts_batches
        self.is_blocking = is_blocking
        self.num_replicas = config_dict.pop("num_replicas", 1)
        self.max_batch_size = config_dict.pop("max_batch_size", None)
        self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
        self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
                                                      None)

        if self.max_concurrent_queries is None:
            # Model serving mode: if the servable is blocking and the wait
            # timeout is default zero seconds, then we keep the existing
            # behavior to allow at most max batch size queries.
            if self.is_blocking and self.batch_wait_timeout == 0:
                self.max_concurrent_queries = self.max_batch_size or 1

            # Pipeline/async mode: if the servable is not blocking,
            # router should just keep pushing queries to the worker
            # replicas until a high limit.
            if not self.is_blocking:
                self.max_concurrent_queries = ASYNC_CONCURRENCY

            # Batch inference mode: user specifies non zero timeout to wait for
            # full batch. We will use 2*max_batch_size to perform double
            # buffering to keep the replica busy.
            if self.max_batch_size is not None and self.batch_wait_timeout > 0:
                self.max_concurrent_queries = 2 * self.max_batch_size

        if len(config_dict) != 0:
            raise ValueError("Unknown options in backend config: {}".format(
                list(config_dict.keys())))
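
To make the heuristic above concrete, here is a small sketch (not part of the diff) tracing the derived max_concurrent_queries values; the expected numbers follow directly from the branches shown above, and the import path assumes this PR's layout of python/ray/serve/config.py.

from ray.serve.config import BackendConfig  # assumed module layout

# Model serving mode: blocking servable, zero batch wait -> at most one batch in flight.
print(BackendConfig({}).max_concurrent_queries)  # 1
print(BackendConfig({"max_batch_size": 8},
                    accepts_batches=True).max_concurrent_queries)  # 8

# Batch inference mode: non-zero wait timeout -> double buffering (2 * max_batch_size).
print(BackendConfig({"max_batch_size": 8, "batch_wait_timeout": 0.1},
                    accepts_batches=True).max_concurrent_queries)  # 16

# Pipeline/async mode: non-blocking servable -> ASYNC_CONCURRENCY (effectively unbounded).
print(BackendConfig({}, is_blocking=False).max_concurrent_queries)
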
@@ -64,6 +97,7 @@ class ReplicaConfig:
                 ray_actor_options=None):
        self.func_or_class = func_or_class
        self.accepts_batches = _callable_accepts_batch(func_or_class)
        self.is_blocking = _callable_is_blocking(func_or_class)
        self.actor_init_args = list(actor_init_args)
        if ray_actor_options is None:
            self.ray_actor_options = {}
python/ray/serve/examples/doc/snippet_model_composition.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from random import random

import requests

from ray import serve

serve.init()


def model_one(_unused_flask_request, data=None):
    print("Model 1 called with data ", data)
    return random()


def model_two(_unused_flask_request, data=None):
    print("Model 2 called with data ", data)
    return data


class ComposedModel:
    def __init__(self):
        self.model_one = serve.get_handle("model_one")
        self.model_two = serve.get_handle("model_two")

    async def __call__(self, flask_request):
        data = flask_request.data

        score = await self.model_one.remote(data=data)
        if score > 0.5:
            result = await self.model_two.remote(data=data)
            result = {"model_used": 2, "score": score}
        else:
            result = {"model_used": 1, "score": score}

        return result


serve.create_backend("model_one", model_one)
serve.create_endpoint("model_one", backend="model_one")

serve.create_backend("model_two", model_two)
serve.create_endpoint("model_two", backend="model_two")

serve.create_backend(
    "composed_backend", ComposedModel, config={"max_concurrent_queries": 10})
serve.create_endpoint(
    "composed", backend="composed_backend", route="/composed")

for _ in range(5):
    resp = requests.get("http://127.0.0.1:8000/composed", data="hey!")
    print(resp.json())
# Output
# {'model_used': 2, 'score': 0.6250189863595503}
# {'model_used': 1, 'score': 0.03146855349621436}
# {'model_used': 2, 'score': 0.6916977560006987}
# {'model_used': 2, 'score': 0.8169693450866928}
# {'model_used': 2, 'score': 0.9540681979573862}
@@ -258,6 +258,7 @@ class ServeMaster:
            for backend, (_, backend_config, _) in self.backends.items():
                await self.router.set_backend_config.remote(
                    backend, backend_config)
                await self.broadcast_backend_config(backend)

            # Push configuration state to the HTTP proxy.
            await self.http_proxy.set_route_table.remote(self.routes)

@@ -314,6 +315,7 @@ class ServeMaster:
            backend_tag,
            replica_tag,
            replica_config.actor_init_args,
            backend_config,
            instance_name=self.instance_name)
        # TODO(edoakes): we should probably have a timeout here.
        await worker_handle.ready.remote()

@@ -602,6 +604,7 @@ class ServeMaster:
            # (particularly for max-batch-size).
            await self.router.set_backend_config.remote(
                backend_tag, backend_config)
            await self.broadcast_backend_config(backend_tag)

    async def delete_backend(self, backend_tag):
        async with self.write_lock:

@@ -664,6 +667,22 @@ class ServeMaster:
            await self._start_pending_replicas()
            await self._stop_pending_replicas()

            await self.broadcast_backend_config(backend_tag)

    async def broadcast_backend_config(self, backend_tag):
        _, backend_config, _ = self.backends[backend_tag]
        broadcast_futures = []
        for replica_tag in self.replicas[backend_tag]:
            try:
                replica = ray.get_actor(replica_tag)
            except ValueError:
                continue

            future = replica.update_config.remote(backend_config).as_future()
            broadcast_futures.append(future)
        if len(broadcast_futures) > 0:
            await asyncio.gather(*broadcast_futures)

    def get_backend_config(self, backend_tag):
        """Get the current config for the specified backend."""
        assert (backend_tag in self.backends
@@ -13,7 +13,7 @@ import ray
from ray import serve
from ray.serve.metric import MetricClient
from ray.serve.policy import RandomEndpointPolicy
from ray.serve.utils import logger
from ray.serve.utils import logger, chain_future


class Query:

@@ -58,10 +58,6 @@
    def __lt__(self, other):
        return self.request_slo_ms < other.request_slo_ms

    def __repr__(self):
        return "<Query args={} kwargs={}>".format(self.request_args,
                                                  self.request_kwargs)


def _make_future_unwrapper(client_futures: List[asyncio.Future],
                           host_future: asyncio.Future):

@@ -117,6 +113,8 @@ class Router:
        self.backend_info = dict()
        # replica tag -> worker_handle
        self.replicas = dict()
        # replica_tag -> concurrent queries counter
        self.queries_counter = defaultdict(lambda: 0)

        # -- Synchronization -- #

@@ -126,7 +124,7 @@ class Router:
        # an operation holding the only query and the other flush operation
        # holding the only idle replica. Additionally, allowing only one flush
        # operation at a time simplifies design overhead for custom queuing and
        # batching polcies.
        # batching policies.
        self.flush_lock = asyncio.Lock()

        # -- State Restoration -- #

@@ -215,11 +213,15 @@ class Router:
        await self.mark_worker_idle(backend_tag, backend_replica_tag)

    async def mark_worker_idle(self, backend_tag, backend_replica_tag):
        logger.debug(
            "Marking backend with tag {} as idle.".format(backend_replica_tag))
        if backend_replica_tag not in self.replicas:
            return

        async with self.flush_lock:
            self.worker_queues[backend_tag].appendleft(backend_replica_tag)
            # NOTE(simon): This is a O(n) operation where n=len(worker_queue)
            if backend_replica_tag not in self.worker_queues[backend_tag]:
                self.worker_queues[backend_tag].appendleft(backend_replica_tag)
            self.flush_backend_queues([backend_tag])

    async def remove_worker(self, backend_tag, replica_tag):

@@ -299,12 +301,11 @@ class Router:
                "queue size {} and worker queue size {}".format(
                    backend, len(buffer_queue), len(worker_queue)))

            max_batch_size = None
            if backend in self.backend_info:
                max_batch_size = self.backend_info[backend].max_batch_size

            self._assign_query_to_worker(backend, buffer_queue, worker_queue,
                                         max_batch_size)
            self._assign_query_to_worker(
                backend,
                buffer_queue,
                worker_queue,
            )

    async def _do_query(self, backend, backend_replica_tag, req):
        # If the worker died, this will be a RayActorError. Just return it and

@@ -317,16 +318,13 @@ class Router:
        except RayTaskError as error:
            self.num_error_backend_request.labels(backend=backend).add()
            result = error
        self.queries_counter[backend_replica_tag] -= 1
        await self.mark_worker_idle(backend, backend_replica_tag)
        logger.debug("Got result in {:.2f}s".format(time.time() - start))
        return result

    def _assign_query_to_worker(self,
                                backend,
                                buffer_queue,
                                worker_queue,
                                max_batch_size=None):

    def _assign_query_to_worker(self, backend, buffer_queue, worker_queue):
        overloaded_replicas = set()
        while len(buffer_queue) and len(worker_queue):
            backend_replica_tag = worker_queue.pop()

@@ -334,27 +332,30 @@ class Router:
            if backend_replica_tag not in self.replicas:
                continue

            if max_batch_size is None:  # No batching
                request = buffer_queue.pop(0)
                future = asyncio.get_event_loop().create_task(
                    self._do_query(backend, backend_replica_tag, request))
                # chaining satisfies request.async_future with future result.
                asyncio.futures._chain_future(future, request.async_future)
            else:
                real_batch_size = min(len(buffer_queue), max_batch_size)
                requests = [
                    buffer_queue.pop(0) for _ in range(real_batch_size)
                ]
            # We have reached the end of the worker queue where all replicas
            # are overloaded.
            if backend_replica_tag in overloaded_replicas:
                break

                # split requests by method type
                requests_group = defaultdict(list)
                for request in requests:
                    requests_group[request.call_method].append(request)
            # This replica has too many in flight and processing queries.
            max_queries = 1
            if backend in self.backend_info:
                max_queries = self.backend_info[backend].max_concurrent_queries
            curr_queries = self.queries_counter[backend_replica_tag]
            if curr_queries >= max_queries:
                # Put the worker back to the queue.
                worker_queue.appendleft(backend_replica_tag)
                overloaded_replicas.add(backend_replica_tag)
                logger.debug(
                    "Skipping backend {} because it has {} in flight "
                    "requests which exceeded the concurrency limit.".format(
                        backend, curr_queries))
                continue

                for group in requests_group.values():
                    future = asyncio.get_event_loop().create_task(
                        self._do_query(backend, backend_replica_tag, group))
                    future.add_done_callback(
                        _make_future_unwrapper(
                            client_futures=[req.async_future for req in group],
                            host_future=future))
            request = buffer_queue.pop(0)
            self.queries_counter[backend_replica_tag] += 1
            future = asyncio.get_event_loop().create_task(
                self._do_query(backend, backend_replica_tag, request))
            chain_future(future, request.async_future)

            worker_queue.appendleft(backend_replica_tag)
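
The batching branch is gone from the router; what remains is per-replica admission control driven by queries_counter and max_concurrent_queries. Here is a standalone sketch of that rule (illustrative helper names, not the PR's code): skip any replica already at its concurrency limit, and stop once every idle replica has been marked overloaded.

from collections import defaultdict, deque

def assign_sketch(buffer_queue, worker_queue, max_queries, counters=None):
    counters = counters if counters is not None else defaultdict(int)
    assignments = []
    overloaded = set()
    while buffer_queue and worker_queue:
        replica = worker_queue.pop()
        if replica in overloaded:
            break  # every remaining replica is already at its limit
        if counters[replica] >= max_queries:
            worker_queue.appendleft(replica)
            overloaded.add(replica)
            continue
        counters[replica] += 1
        assignments.append((buffer_queue.pop(0), replica))
        worker_queue.appendleft(replica)
    return assignments, counters

queries = list(range(5))
workers = deque(["replica-1", "replica-2"])
done, counters = assign_sketch(queries, workers, max_queries=2)
print(done)     # four queries assigned, two per replica
print(queries)  # one query left buffered until a replica finishes and is marked idle
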
@@ -507,3 +507,8 @@ def test_endpoint_input_validation(serve_instance):
    with pytest.raises(TypeError):
        serve.create_endpoint("endpoint", backend=2)
    serve.create_endpoint("endpoint", backend="backend")


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -15,7 +15,10 @@ from ray.serve.exceptions import RayServeException
pytestmark = pytest.mark.asyncio


def setup_worker(name, func_or_class, init_args=None):
def setup_worker(name,
                 func_or_class,
                 init_args=None,
                 backend_config=BackendConfig({})):
    if init_args is None:
        init_args = ()

@@ -23,7 +26,7 @@ def setup_worker(name, func_or_class, init_args=None):
    class WorkerActor:
        def __init__(self):
            self.worker = create_backend_worker(func_or_class)(
                name, name + ":tag", init_args)
                name, name + ":tag", init_args, backend_config)

        def ready(self):
            pass

@@ -31,6 +34,9 @@ def setup_worker(name, func_or_class, init_args=None):
        async def handle_request(self, *args, **kwargs):
            return await self.worker.handle_request(*args, **kwargs)

        def update_config(self, new_config):
            return self.worker.update_config(new_config)

    worker = WorkerActor.remote()
    ray.get(worker.ready.remote())
    return worker

@@ -165,14 +171,16 @@ async def test_task_runner_custom_method_batch(serve_instance):
    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    worker = setup_worker(CONSUMER_NAME, Batcher)
    backend_config = BackendConfig(
        {
            "max_batch_size": 4,
            "batch_wait_timeout": 2
        }, accepts_batches=True)
    worker = setup_worker(
        CONSUMER_NAME, Batcher, backend_config=backend_config)

    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
    await q.set_backend_config.remote(
        CONSUMER_NAME,
        BackendConfig({
            "max_batch_size": 10
        }, accepts_batches=True))
    await q.set_backend_config.remote(CONSUMER_NAME, backend_config)

    def make_request_param(call_method):
        return RequestMetadata(

@@ -200,3 +208,77 @@ async def test_task_runner_custom_method_batch(serve_instance):
    np_array = make_request_param("return_np_array")
    result_np_value = await q.enqueue_request.remote(np_array)
    assert isinstance(result_np_value, np.int32)


async def test_task_runner_perform_batch(serve_instance):
    q = ray.remote(Router).remote()

    def batcher(*args, **kwargs):
        return [serve.context.batch_size] * serve.context.batch_size

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    config = BackendConfig(
        {
            "max_batch_size": 2,
            "batch_wait_timeout": 10
        }, accepts_batches=True)

    worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    await q.set_backend_config.remote(CONSUMER_NAME, config)
    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)

    my_batch_sizes = await asyncio.gather(
        *[q.enqueue_request.remote(query_param) for _ in range(3)])
    assert my_batch_sizes == [2, 2, 1]


async def test_task_runner_perform_async(serve_instance):
    q = ray.remote(Router).remote()

    @ray.remote
    class Barrier:
        def __init__(self, release_on):
            self.release_on = release_on
            self.current_waiters = 0
            self.event = asyncio.Event()

        async def wait(self):
            self.current_waiters += 1
            if self.current_waiters == self.release_on:
                self.event.set()
            else:
                await self.event.wait()

    barrier = Barrier.remote(release_on=10)

    async def wait_and_go(*args, **kwargs):
        await barrier.wait.remote()
        return "done!"

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)

    worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
    await q.set_backend_config.remote(CONSUMER_NAME, config)
    q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})

    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)

    done, not_done = await asyncio.wait(
        [q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10)
    assert len(done) == 10
    for item in done:
        await item == "done!"


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -132,3 +132,8 @@ def test_replica_config_validation():
        ReplicaConfig(Class, ray_actor_options={"detached": None})
    with pytest.raises(ValueError):
        ReplicaConfig(Class, ray_actor_options={"max_restarts": None})


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -241,3 +241,9 @@ def test_worker_replica_failure(serve_instance):
            break
        except TimeoutError:
            time.sleep(0.1)


if __name__ == "__main__":
    import sys
    import pytest
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -33,3 +33,9 @@ def test_handle_in_endpoint(serve_instance):
        methods=["GET", "POST"])

    assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"


if __name__ == "__main__":
    import sys
    import pytest
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -38,3 +38,8 @@ def test_ray_internal_kv_collisions(serve_instance):
    kv2.put("1", b"-1")
    assert kv2.get("1") == b"-1"
    assert kv1.get("1") == b"1"


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -195,3 +195,8 @@ async def test_system_metric_endpoints(serve_instance):
            print("Metric not correct, retrying...")
    if not success:
        test_metric_endpoint()


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -20,4 +20,4 @@ def test_nonblocking():

if __name__ == "__main__":
    import pytest
    sys.exit(pytest.main(["-v", __file__]))
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -33,3 +33,9 @@ serve.create_endpoint("driver", backend="driver", route="/driver")
assert ray.get(handle.remote()) == "OK!"

os.remove(path)


if __name__ == "__main__":
    import sys
    import pytest
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -7,6 +7,8 @@ import ray
from ray.serve.router import Router
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import get_random_letters
from ray.test_utils import SignalActor
from ray.serve.config import BackendConfig

pytestmark = pytest.mark.asyncio


@@ -172,3 +174,68 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
        calls = await runner.get_all_calls.remote()
        for call in calls:
            assert call.request_args[0] in runner_shard_keys[i]


async def test_router_use_max_concurrency(serve_instance):
    signal = SignalActor.remote()

    @ray.remote
    class MockWorker:
        async def handle_request(self, request):
            await signal.wait.remote()
            return "DONE"

        def ready(self):
            pass

    class VisibleRouter(Router):
        def get_queues(self):
            return self.queries_counter, self.backend_queues

    worker = MockWorker.remote()
    q = ray.remote(VisibleRouter).remote()
    BACKEND_NAME = "max-concurrent-test"
    config = BackendConfig({"max_concurrent_queries": 1})
    await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0})
    await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
    await q.set_backend_config.remote(BACKEND_NAME, config)

    # We send over two queries
    first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
    second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Neither query should be available
    with pytest.raises(ray.exceptions.RayTimeoutError):
        ray.get([first_query, second_query], timeout=0.2)

    # Let's retrieve the router internal state
    queries_counter, backend_queues = await q.get_queues.remote()
    # There should be just one inflight request
    assert queries_counter["max-concurrent-test:replica-tag"] == 1
    # The second query is buffered
    assert len(backend_queues["max-concurrent-test"]) == 1

    # Let's unblock the first query
    await signal.send.remote(clear=True)
    assert await first_query == "DONE"

    # The internal state of router should have changed.
    queries_counter, backend_queues = await q.get_queues.remote()
    # There should still be one inflight request
    assert queries_counter["max-concurrent-test:replica-tag"] == 1
    # But there shouldn't be any queries in the queue
    assert len(backend_queues["max-concurrent-test"]) == 0

    # Unblocking the second query
    await signal.send.remote(clear=True)
    assert await second_query == "DONE"

    # Checking the internal state of the router one more time
    queries_counter, backend_queues = await q.get_queues.remote()
    assert queries_counter["max-concurrent-test:replica-tag"] == 0
    assert len(backend_queues["max-concurrent-test"]) == 0


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -1,8 +1,10 @@
import asyncio
import json

import numpy as np
import pytest

from ray.serve.utils import ServeEncoder
from ray.serve.utils import ServeEncoder, chain_future, unpack_future


def test_bytes_encoder():

@@ -20,3 +22,50 @@ def test_numpy_encoding():
    assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
    assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
    assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data


@pytest.mark.asyncio
async def test_future_chaining():
    def make():
        return asyncio.get_event_loop().create_future()

    # Test 1 -> 1 chaining
    fut1, fut2 = make(), make()
    chain_future(fut1, fut2)
    fut1.set_result(1)
    assert await fut2 == 1

    # Test 1 -> 1 chaining with exception
    fut1, fut2 = make(), make()
    chain_future(fut1, fut2)
    fut1.set_exception(ValueError(""))
    with pytest.raises(ValueError):
        await fut2

    # Test many -> many chaining
    src_futs = [make() for _ in range(4)]
    dst_futs = [make() for _ in range(4)]
    chain_future(src_futs, dst_futs)
    [fut.set_result(i) for i, fut in enumerate(src_futs)]
    for i, fut in enumerate(dst_futs):
        assert await fut == i

    # Test 1 -> many unwrapping
    batched_future = make()
    single_futures = unpack_future(batched_future, 4)
    batched_future.set_result(list(range(4)))
    for i, fut in enumerate(single_futures):
        assert await fut == i

    # Test 1 -> many unwrapping with exception
    batched_future = make()
    single_futures = unpack_future(batched_future, 4)
    batched_future.set_exception(ValueError(""))
    for future in single_futures:
        with pytest.raises(ValueError):
            await future


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -1,8 +1,11 @@
import asyncio
from functools import singledispatch
import json
import logging
import random
import string
import time
from typing import List
import io
import os

@@ -114,3 +117,59 @@ def format_actor_name(actor_name, instance_name=None):
        return actor_name
    else:
        return "{}:{}".format(instance_name, actor_name)


@singledispatch
def chain_future(src, dst):
    """Base method for chaining futures together.

    Chaining futures means the output from source future(s) are written as the
    results of the destination future(s). This method can work with the
    following inputs:
    - src: Future, dst: Future
    - src: List[Future], dst: List[Future]
    """
    raise NotImplementedError()


@chain_future.register(asyncio.Future)
def _chain_future_single(src: asyncio.Future, dst: asyncio.Future):
    asyncio.futures._chain_future(src, dst)


@chain_future.register(list)
def _chain_future_list(src: List[asyncio.Future], dst: List[asyncio.Future]):
    if len(src) != len(dst):
        raise ValueError(
            "Source and destination lists don't have the same length. "
            "Source: {}. Destination: {}.".format(len(src), len(dst)))

    for s, d in zip(src, dst):
        chain_future(s, d)


def unpack_future(src: asyncio.Future, num_items: int) -> List[asyncio.Future]:
    """Unpack the result of source future to num_items futures.

    This function takes in a Future and splits its result into many futures. If
    the result of the source future is an exception, then all destination
    futures will have the same exception.
    """
    dest_futures = [
        asyncio.get_event_loop().create_future() for _ in range(num_items)
    ]

    def unwrap_callback(fut: asyncio.Future):
        exception = fut.exception()
        if exception is not None:
            [f.set_exception(exception) for f in dest_futures]
            return

        result = fut.result()
        assert len(result) == num_items
        for item, future in zip(result, dest_futures):
            future.set_result(item)

    src.add_done_callback(unwrap_callback)

    return dest_futures
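
A brief sketch (not part of the diff) of how the worker's main_loop combines these helpers: the single future holding a batch result is unpacked into one future per query and chained onto the queries' own futures. The import path assumes this PR's python/ray/serve/utils.py; batched_work is a hypothetical stand-in for a batched backend call.

import asyncio

from ray.serve.utils import chain_future, unpack_future  # assumed module layout

async def fan_out_demo():
    loop = asyncio.get_running_loop()

    async def batched_work(items):
        return [item * 10 for item in items]  # stand-in for a batched backend call

    per_query_futures = [loop.create_future() for _ in range(3)]
    batch_future = asyncio.ensure_future(batched_work([1, 2, 3]))
    # One batched result fanned out to three per-query futures.
    chain_future(unpack_future(batch_future, 3), per_query_futures)
    print(await asyncio.gather(*per_query_futures))  # [10, 20, 30]

asyncio.run(fan_out_demo())
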
@@ -232,8 +232,10 @@ class SignalActor:
        def __init__(self):
            self.ready_event = asyncio.Event()

        def send(self):
        def send(self, clear=False):
            self.ready_event.set()
            if clear:
                self.ready_event.clear()

        async def wait(self, should_wait=True):
            if should_wait: