[Serve] Batching in Worker Replica (#8709)
Parent: f007bfb4cf
Commit: 6c3062906f

20 changed files with 671 additions and 94 deletions
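For orientation, here is a sketch (not part of the diff) of how a Serve user exercises the worker-side batching path this commit adds. It assumes the serve.accept_batch decorator, which sets the _serve_accept_batch attribute that config.py checks below; the backend and endpoint names are illustrative.

from ray import serve

serve.init()

@serve.accept_batch
def batched_handler(flask_requests):
    # The worker now batches up to max_batch_size queued requests, waiting at
    # most batch_wait_timeout seconds before dispatching a partial batch.
    return ["ok"] * serve.context.batch_size

serve.create_backend(
    "batched", batched_handler,
    config={"max_batch_size": 4, "batch_wait_timeout": 0.2})
serve.create_endpoint("batched", backend="batched", route="/batched")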
python/ray/serve/BUILD

@@ -5,21 +5,101 @@ py_library(
     srcs = glob(["**/*.py"], exclude=["tests/*.py"]),
 )
 
-# This test aggregates all serve tests and run them in a single session
-# similar to `pytest .`
-# Serve tests need to run in a single session because starting and stopping
-# serve cluster take a large chunk of time. All serve tests use a shared
-# cluster.
+serve_tests_srcs = glob(["tests/*.py"],
+    exclude=["tests/test_nonblocking.py",
+             "tests/test_master_crashes.py",
+             "tests/test_serve.py",
+    ])
 
 py_test(
-    name = "test_serve",
+    name = "test_api",
     size = "medium",
-    srcs = glob(["tests/*.py"],
-        exclude=["tests/test_nonblocking.py",
-                 "tests/test_master_crashes.py"]),
+    srcs = serve_tests_srcs,
     tags = ["exclusive"],
     deps = [":serve_lib"],
 )
 
+py_test(
+    name = "test_backend_worker",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_config",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_failure",
+    size = "medium",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_handle",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_kv_store",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_metric",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_persistence",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_router",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
+py_test(
+    name = "test_util",
+    size = "small",
+    srcs = serve_tests_srcs,
+    tags = ["exclusive"],
+    deps = [":serve_lib"],
+)
+
 # Runs test_api and test_failure with injected failures in the master actor.
 # TODO(edoakes): reenable this once we're using GCS actor fault tolerance.
 # py_test(

@@ -97,6 +177,14 @@ py_test(
     deps = [":serve_lib"]
 )
 
+py_test(
+    name = "snippet_model_composition",
+    size = "small",
+    srcs = glob(["examples/doc/*.py"]),
+    tags = ["exclusive"],
+    deps = [":serve_lib"]
+)
+
 # Disable the deployment tutorial test because it requires
 # ray start --head in the background.
 # py_test(
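With the aggregated target split up, individual suites can be run on their own; assuming the usual Bazel workflow in the Ray repo, that looks like:

    bazel test //python/ray/serve:test_backend_worker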
python/ray/serve/api.py

@@ -236,7 +236,8 @@ def create_backend(backend_tag,
 
     replica_config = ReplicaConfig(
         func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
-    backend_config = BackendConfig(config, replica_config.accepts_batches)
+    backend_config = BackendConfig(config, replica_config.accepts_batches,
+                                   replica_config.is_blocking)
 
     ray.get(
         master_actor.create_backend.remote(backend_tag, backend_config,
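What the extra is_blocking argument enables: a coroutine callable is detected as non-blocking, so the router can keep several queries in flight on one replica instead of one at a time. A sketch under that assumption (the backend name and body are illustrative):

from ray import serve

async def proxy_backend(flask_request):
    # awaiting here yields the replica's event loop to other queued requests
    return "ok"

serve.create_backend("proxy", proxy_backend,
                     config={"max_concurrent_queries": 10})
serve.create_endpoint("proxy", backend="proxy", route="/proxy")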
python/ray/serve/backend_worker.py

@@ -1,20 +1,61 @@
+import asyncio
 import traceback
 import inspect
 from collections.abc import Iterable
+from collections import defaultdict
+from itertools import groupby
+from operator import attrgetter
+import time
 
 import ray
+from ray.async_compat import sync_to_async
+
 from ray import serve
 from ray.serve import context as serve_context
 from ray.serve.context import FakeFlaskRequest
-from collections import defaultdict
-from ray.serve.utils import (parse_request_item, _get_logger)
+from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
+                             unpack_future)
 from ray.serve.exceptions import RayServeException
 from ray.serve.metric import MetricClient
-from ray.async_compat import sync_to_async
+from ray.serve.config import BackendConfig
+from ray.serve.router import Query
 
 logger = _get_logger()
 
 
+class WaitableQueue(asyncio.Queue):
+    async def wait_for_batch(self, num_items: int, timeout_s: float):
+        """Wait up to num_items in the queue given timeout_s.
+
+        This method will block indefinitely for the first item. Therefore, it
+        guarantees to return at least one item.
+        """
+        assert num_items >= 1
+        # Wait for the first value without timeout. We will return at least
+        # one item. Additionally this help the caller context switch on empty
+        # queue.
+        start_waiting = time.time()
+        batch = [
+            await self.get(),
+        ]
+
+        # Adjust the timeout to account for the time waiting for first item.
+        time_remaining = timeout_s - (time.time() - start_waiting)
+        time_remaining = max(0, time_remaining)
+
+        # Wait for the remaining batch with the timeout
+        if num_items > 1:
+            done_set, not_done_set = await asyncio.wait(
+                [self.get() for _ in range(num_items - 1)],
+                timeout=time_remaining)
+            for task in done_set:
+                batch.append(task.result())
+            for task in not_done_set:
+                task.cancel()
+        return batch
+
+
 def create_backend_worker(func_or_class):
     """Creates a worker class wrapping the provided function or class."""
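A minimal sketch (not part of the diff) of the WaitableQueue semantics above, on the Python/asyncio versions Ray targeted at the time: the first get blocks indefinitely, and only the rest of the batch is bounded by timeout_s.

import asyncio

async def demo():
    q = WaitableQueue()                      # class from the hunk above
    for i in range(3):
        q.put_nowait(i)
    batch = await q.wait_for_batch(num_items=8, timeout_s=0.1)
    print(len(batch))                        # 3: everything queued, nothing else arrived

asyncio.get_event_loop().run_until_complete(demo())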
@@ -30,8 +71,10 @@ def create_backend_worker(func_or_class):
                      backend_tag,
                      replica_tag,
                      init_args,
+                     backend_config: BackendConfig,
                      instance_name=None):
             serve.init(name=instance_name)
 
             if is_function:
                 _callable = func_or_class
             else:

@@ -42,11 +85,15 @@ def create_backend_worker(func_or_class):
             metric_client = MetricClient(
                 metric_exporter, default_labels={"backend": backend_tag})
             self.backend = RayServeWorker(backend_tag, replica_tag, _callable,
-                                          is_function, metric_client)
+                                          backend_config, is_function,
+                                          metric_client)
 
         async def handle_request(self, request):
             return await self.backend.handle_request(request)
 
+        def update_config(self, new_config: BackendConfig):
+            return self.backend.update_config(new_config)
+
         def ready(self):
             pass

@@ -75,13 +122,16 @@ def ensure_async(func):
 class RayServeWorker:
     """Handles requests with the provided callable."""
 
-    def __init__(self, name, replica_tag, _callable, is_function,
-                 metric_client):
+    def __init__(self, name, replica_tag, _callable,
+                 backend_config: BackendConfig, is_function, metric_client):
         self.name = name
         self.replica_tag = replica_tag
         self.callable = _callable
         self.is_function = is_function
 
+        self.config = backend_config
+        self.query_queue = WaitableQueue()
+
         self.metric_client = metric_client
         self.request_counter = self.metric_client.new_counter(
             "backend_request_counter",

@@ -101,6 +151,9 @@ class RayServeWorker:
 
         self.restart_counter.labels(replica_tag=self.replica_tag).add()
 
+        self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
+        self.config_updated = asyncio.Event()
+
     def get_runner_method(self, request_item):
         method_name = request_item.call_method
         if not hasattr(self.callable, method_name):

@@ -108,6 +161,8 @@ class RayServeWorker:
                 "which is specified in the request. "
                 "The available methods are {}".format(
                     method_name, dir(self.callable)))
+        if self.is_function:
+            return self.callable
         return getattr(self.callable, method_name)
 
     def has_positional_args(self, f):

@@ -124,6 +179,12 @@ class RayServeWorker:
                 return True
         return False
 
+    def _reset_context(self):
+        # NOTE(simon): context management won't work in async mode because
+        # many concurrent queries might be running at the same time.
+        serve_context.web = None
+        serve_context.batch_size = None
+
     async def invoke_single(self, request_item):
         args, kwargs, is_web_context = parse_request_item(request_item)
         serve_context.web = is_web_context

@@ -137,24 +198,12 @@ class RayServeWorker:
         except Exception as e:
             result = wrap_to_ray_error(e)
             self.error_counter.add()
+        finally:
+            self._reset_context()
 
         return result
 
     async def invoke_batch(self, request_item_list):
-        # TODO(alind) : create no-http services. The enqueues
-        # from such services will always be TaskContext.Python.
-
-        # Assumption : all the requests in a bacth
-        # have same serve context.
-
-        # For batching kwargs are modified as follows -
-        # kwargs [Python Context] : key,val
-        # kwargs_list : key, [val1,val2, ... , valn]
-        # or
-        # args[Web Context] : val
-        # args_list : [val1,val2, ...... , valn]
-        # where n (current batch size) <= max_batch_size of a backend
-
         arg_list = []
         kwargs_list = defaultdict(list)
         context_flags = set()
@@ -222,22 +271,53 @@ class RayServeWorker:
                     "results with length equal to the batch size"
                     ".".format(batch_size, len(result_list)))
                 raise RayServeException(error_message)
+            self._reset_context()
             return result_list
         except Exception as e:
             wrapped_exception = wrap_to_ray_error(e)
             self.error_counter.add()
+            self._reset_context()
             return [wrapped_exception for _ in range(batch_size)]
 
-    async def handle_request(self, request):
-        # check if work_item is a list or not
-        # if it is list: then batching supported
-        if not isinstance(request, list):
-            result = await self.invoke_single(request)
-        else:
-            result = await self.invoke_batch(request)
-
-        # re-assign to default values
-        serve_context.web = False
-        serve_context.batch_size = None
-
-        return result
+    async def main_loop(self):
+        while True:
+            # NOTE(simon): There's an issue when user updated batch size and
+            # batch wait timeout during the execution, these values will not be
+            # updated until after the current iteration.
+            batch = await self.query_queue.wait_for_batch(
+                num_items=self.config.max_batch_size or 1,
+                timeout_s=self.config.batch_wait_timeout)
+
+            all_evaluated_futures = []
+
+            if not self.config.accepts_batches:
+                query = batch[0]
+                evaluated = asyncio.ensure_future(self.invoke_single(query))
+                all_evaluated_futures = [evaluated]
+                chain_future(evaluated, query.async_future)
+            else:
+                get_call_method = attrgetter("call_method")
+                sorted_batch = sorted(batch, key=get_call_method)
+                for _, group in groupby(sorted_batch, key=get_call_method):
+                    group = sorted(group)
+                    evaluated = asyncio.ensure_future(self.invoke_batch(group))
+                    all_evaluated_futures.append(evaluated)
+                    result_futures = [q.async_future for q in group]
+                    chain_future(
+                        unpack_future(evaluated, len(group)), result_futures)
+
+            if self.config.is_blocking:
+                # We use asyncio.wait here so if the result is exception,
+                # it will not be raised.
+                await asyncio.wait(all_evaluated_futures)
+
+    def update_config(self, new_config: BackendConfig):
+        self.config = new_config
+        self.config_updated.set()
+
+    async def handle_request(self, request: Query):
+        assert not isinstance(request, list)
+        logger.debug("Worker {} got request {}".format(self.name, request))
+        request.async_future = asyncio.get_event_loop().create_future()
+        self.query_queue.put_nowait(request)
+        return await request.async_future
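Why main_loop sorts the batch before grouping: itertools.groupby only merges adjacent items, so queries must be ordered by call_method first. A self-contained illustration (the names here are made up, not from the diff):

from collections import namedtuple
from itertools import groupby
from operator import attrgetter

Q = namedtuple("Q", "call_method")
batch = [Q("__call__"), Q("other_method"), Q("__call__")]
key = attrgetter("call_method")
groups = {k: len(list(g)) for k, g in groupby(sorted(batch, key=key), key=key)}
print(groups)  # {'__call__': 2, 'other_method': 1}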
python/ray/serve/config.py

@@ -1,5 +1,7 @@
 import inspect
 
+from ray.serve.constants import ASYNC_CONCURRENCY
+
 
 def _callable_accepts_batch(func_or_class):
     if inspect.isfunction(func_or_class):

@@ -8,15 +10,46 @@ def _callable_accepts_batch(func_or_class):
         return hasattr(func_or_class.__call__, "_serve_accept_batch")
 
 
+def _callable_is_blocking(func_or_class):
+    if inspect.isfunction(func_or_class):
+        return not inspect.iscoroutinefunction(func_or_class)
+    elif inspect.isclass(func_or_class):
+        return not inspect.iscoroutinefunction(func_or_class.__call__)
+
+
 class BackendConfig:
-    def __init__(self, config_dict, accepts_batches=False):
+    def __init__(self, config_dict, accepts_batches=False, is_blocking=True):
         assert isinstance(config_dict, dict)
         # Make a copy so that we don't modify the input dict.
         config_dict = config_dict.copy()
 
         self.accepts_batches = accepts_batches
+        self.is_blocking = is_blocking
         self.num_replicas = config_dict.pop("num_replicas", 1)
         self.max_batch_size = config_dict.pop("max_batch_size", None)
+        self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
+        self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
+                                                      None)
+
+        if self.max_concurrent_queries is None:
+            # Model serving mode: if the servable is blocking and the wait
+            # timeout is default zero seconds, then we keep the existing
+            # behavior to allow at most max batch size queries.
+            if self.is_blocking and self.batch_wait_timeout == 0:
+                self.max_concurrent_queries = self.max_batch_size or 1
+
+            # Pipeline/async mode: if the servable is not blocking,
+            # router should just keep pushing queries to the worker
+            # replicas until a high limit.
+            if not self.is_blocking:
+                self.max_concurrent_queries = ASYNC_CONCURRENCY
+
+            # Batch inference mode: user specifies non zero timeout to wait for
+            # full batch. We will use 2*max_batch_size to perform double
+            # buffering to keep the replica busy.
+            if self.max_batch_size is not None and self.batch_wait_timeout > 0:
+                self.max_concurrent_queries = 2 * self.max_batch_size
+
         if len(config_dict) != 0:
             raise ValueError("Unknown options in backend config: {}".format(
                 list(config_dict.keys())))

@@ -64,6 +97,7 @@ class ReplicaConfig:
                  ray_actor_options=None):
         self.func_or_class = func_or_class
         self.accepts_batches = _callable_accepts_batch(func_or_class)
+        self.is_blocking = _callable_is_blocking(func_or_class)
         self.actor_init_args = list(actor_init_args)
         if ray_actor_options is None:
             self.ray_actor_options = {}
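The three defaulting branches above, traced with BackendConfig exactly as defined in this hunk (the printed values follow directly from the code; ASYNC_CONCURRENCY is a large constant imported from ray.serve.constants):

cfg = BackendConfig({"max_batch_size": 8}, accepts_batches=True)
print(cfg.max_concurrent_queries)   # 8  -> blocking servable, zero wait timeout

cfg = BackendConfig({}, is_blocking=False)
print(cfg.max_concurrent_queries)   # ASYNC_CONCURRENCY -> pipeline/async mode

cfg = BackendConfig({"max_batch_size": 8, "batch_wait_timeout": 0.5},
                    accepts_batches=True)
print(cfg.max_concurrent_queries)   # 16 -> 2 * max_batch_size double buffering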
python/ray/serve/examples/doc/snippet_model_composition.py (new file, 57 lines)

@@ -0,0 +1,57 @@
+from random import random
+
+import requests
+
+from ray import serve
+
+serve.init()
+
+
+def model_one(_unused_flask_request, data=None):
+    print("Model 1 called with data ", data)
+    return random()
+
+
+def model_two(_unused_flask_request, data=None):
+    print("Model 2 called with data ", data)
+    return data
+
+
+class ComposedModel:
+    def __init__(self):
+        self.model_one = serve.get_handle("model_one")
+        self.model_two = serve.get_handle("model_two")
+
+    async def __call__(self, flask_request):
+        data = flask_request.data
+
+        score = await self.model_one.remote(data=data)
+        if score > 0.5:
+            result = await self.model_two.remote(data=data)
+            result = {"model_used": 2, "score": score}
+        else:
+            result = {"model_used": 1, "score": score}
+
+        return result
+
+
+serve.create_backend("model_one", model_one)
+serve.create_endpoint("model_one", backend="model_one")
+
+serve.create_backend("model_two", model_two)
+serve.create_endpoint("model_two", backend="model_two")
+
+serve.create_backend(
+    "composed_backend", ComposedModel, config={"max_concurrent_queries": 10})
+serve.create_endpoint(
+    "composed", backend="composed_backend", route="/composed")
+
+for _ in range(5):
+    resp = requests.get("http://127.0.0.1:8000/composed", data="hey!")
+    print(resp.json())
+# Output
+# {'model_used': 2, 'score': 0.6250189863595503}
+# {'model_used': 1, 'score': 0.03146855349621436}
+# {'model_used': 2, 'score': 0.6916977560006987}
+# {'model_used': 2, 'score': 0.8169693450866928}
+# {'model_used': 2, 'score': 0.9540681979573862}
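A note on the snippet above: ComposedModel.__call__ is a coroutine, so ReplicaConfig marks the backend non-blocking, and the explicit max_concurrent_queries of 10 lets the router keep several composed requests in flight on the one replica while each awaits the model_one/model_two handles.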
python/ray/serve/master.py

@@ -258,6 +258,7 @@ class ServeMaster:
         for backend, (_, backend_config, _) in self.backends.items():
             await self.router.set_backend_config.remote(
                 backend, backend_config)
+            await self.broadcast_backend_config(backend)
 
         # Push configuration state to the HTTP proxy.
         await self.http_proxy.set_route_table.remote(self.routes)

@@ -314,6 +315,7 @@ class ServeMaster:
             backend_tag,
             replica_tag,
             replica_config.actor_init_args,
+            backend_config,
             instance_name=self.instance_name)
         # TODO(edoakes): we should probably have a timeout here.
         await worker_handle.ready.remote()

@@ -602,6 +604,7 @@ class ServeMaster:
             # (particularly for max-batch-size).
             await self.router.set_backend_config.remote(
                 backend_tag, backend_config)
+            await self.broadcast_backend_config(backend_tag)
 
     async def delete_backend(self, backend_tag):
         async with self.write_lock:

@@ -664,6 +667,22 @@ class ServeMaster:
             await self._start_pending_replicas()
             await self._stop_pending_replicas()
 
+            await self.broadcast_backend_config(backend_tag)
+
+    async def broadcast_backend_config(self, backend_tag):
+        _, backend_config, _ = self.backends[backend_tag]
+        broadcast_futures = []
+        for replica_tag in self.replicas[backend_tag]:
+            try:
+                replica = ray.get_actor(replica_tag)
+            except ValueError:
+                continue
+
+            future = replica.update_config.remote(backend_config).as_future()
+            broadcast_futures.append(future)
+        if len(broadcast_futures) > 0:
+            await asyncio.gather(*broadcast_futures)
+
     def get_backend_config(self, backend_tag):
         """Get the current config for the specified backend."""
         assert (backend_tag in self.backends
python/ray/serve/router.py

@@ -13,7 +13,7 @@ import ray
 from ray import serve
 from ray.serve.metric import MetricClient
 from ray.serve.policy import RandomEndpointPolicy
-from ray.serve.utils import logger
+from ray.serve.utils import logger, chain_future
 
 
 class Query:

@@ -58,10 +58,6 @@ class Query:
     def __lt__(self, other):
         return self.request_slo_ms < other.request_slo_ms
 
-    def __repr__(self):
-        return "<Query args={} kwargs={}>".format(self.request_args,
-                                                  self.request_kwargs)
-
 
 def _make_future_unwrapper(client_futures: List[asyncio.Future],
                            host_future: asyncio.Future):

@@ -117,6 +113,8 @@ class Router:
         self.backend_info = dict()
         # replica tag -> worker_handle
         self.replicas = dict()
+        # replica_tag -> concurrent queries counter
+        self.queries_counter = defaultdict(lambda: 0)
 
         # -- Synchronization -- #

@@ -126,7 +124,7 @@ class Router:
         # an operation holding the only query and the other flush operation
         # holding the only idle replica. Additionally, allowing only one flush
         # operation at a time simplifies design overhead for custom queuing and
-        # batching polcies.
+        # batching policies.
         self.flush_lock = asyncio.Lock()
 
         # -- State Restoration -- #

@@ -215,11 +213,15 @@ class Router:
             await self.mark_worker_idle(backend_tag, backend_replica_tag)
 
     async def mark_worker_idle(self, backend_tag, backend_replica_tag):
+        logger.debug(
+            "Marking backend with tag {} as idle.".format(backend_replica_tag))
         if backend_replica_tag not in self.replicas:
             return
 
         async with self.flush_lock:
-            self.worker_queues[backend_tag].appendleft(backend_replica_tag)
+            # NOTE(simon): This is a O(n) operation where n=len(worker_queue)
+            if backend_replica_tag not in self.worker_queues[backend_tag]:
+                self.worker_queues[backend_tag].appendleft(backend_replica_tag)
             self.flush_backend_queues([backend_tag])
 
     async def remove_worker(self, backend_tag, replica_tag):

@@ -299,12 +301,11 @@ class Router:
                 "queue size {} and worker queue size {}".format(
                     backend, len(buffer_queue), len(worker_queue)))
 
-            max_batch_size = None
-            if backend in self.backend_info:
-                max_batch_size = self.backend_info[backend].max_batch_size
-
-            self._assign_query_to_worker(backend, buffer_queue, worker_queue,
-                                         max_batch_size)
+            self._assign_query_to_worker(
+                backend,
+                buffer_queue,
+                worker_queue,
+            )
 
     async def _do_query(self, backend, backend_replica_tag, req):
         # If the worker died, this will be a RayActorError. Just return it and

@@ -317,16 +318,13 @@ class Router:
         except RayTaskError as error:
             self.num_error_backend_request.labels(backend=backend).add()
             result = error
+        self.queries_counter[backend_replica_tag] -= 1
         await self.mark_worker_idle(backend, backend_replica_tag)
         logger.debug("Got result in {:.2f}s".format(time.time() - start))
         return result
 
-    def _assign_query_to_worker(self,
-                                backend,
-                                buffer_queue,
-                                worker_queue,
-                                max_batch_size=None):
+    def _assign_query_to_worker(self, backend, buffer_queue, worker_queue):
+        overloaded_replicas = set()
         while len(buffer_queue) and len(worker_queue):
             backend_replica_tag = worker_queue.pop()

@@ -334,27 +332,30 @@ class Router:
             if backend_replica_tag not in self.replicas:
                 continue
 
-            if max_batch_size is None:  # No batching
-                request = buffer_queue.pop(0)
-                future = asyncio.get_event_loop().create_task(
-                    self._do_query(backend, backend_replica_tag, request))
-                # chaining satisfies request.async_future with future result.
-                asyncio.futures._chain_future(future, request.async_future)
-            else:
-                real_batch_size = min(len(buffer_queue), max_batch_size)
-                requests = [
-                    buffer_queue.pop(0) for _ in range(real_batch_size)
-                ]
-
-                # split requests by method type
-                requests_group = defaultdict(list)
-                for request in requests:
-                    requests_group[request.call_method].append(request)
-
-                for group in requests_group.values():
-                    future = asyncio.get_event_loop().create_task(
-                        self._do_query(backend, backend_replica_tag, group))
-                    future.add_done_callback(
-                        _make_future_unwrapper(
-                            client_futures=[req.async_future for req in group],
-                            host_future=future))
+            # We have reached the end of the worker queue where all replicas
+            # are overloaded.
+            if backend_replica_tag in overloaded_replicas:
+                break
+
+            # This replica has too many in flight and processing queries.
+            max_queries = 1
+            if backend in self.backend_info:
+                max_queries = self.backend_info[backend].max_concurrent_queries
+            curr_queries = self.queries_counter[backend_replica_tag]
+            if curr_queries >= max_queries:
+                # Put the worker back to the queue.
+                worker_queue.appendleft(backend_replica_tag)
+                overloaded_replicas.add(backend_replica_tag)
+                logger.debug(
+                    "Skipping backend {} because it has {} in flight "
+                    "requests which exceeded the concurrency limit.".format(
+                        backend, curr_queries))
+                continue
+
+            request = buffer_queue.pop(0)
+            self.queries_counter[backend_replica_tag] += 1
+            future = asyncio.get_event_loop().create_task(
+                self._do_query(backend, backend_replica_tag, request))
+            chain_future(future, request.async_future)
+
+            worker_queue.appendleft(backend_replica_tag)
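A toy illustration (not from the diff) of the per-replica guard _assign_query_to_worker now applies: a replica is skipped and put back in the queue once its in-flight count reaches the backend's max_concurrent_queries.

from collections import defaultdict

queries_counter = defaultdict(lambda: 0)
max_queries = 1

def try_assign(replica_tag):
    if queries_counter[replica_tag] >= max_queries:
        return False            # overloaded: requeue the replica, try later
    queries_counter[replica_tag] += 1
    return True

print(try_assign("replica-1"), try_assign("replica-1"))  # True False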
python/ray/serve/tests/test_api.py

@@ -507,3 +507,8 @@ def test_endpoint_input_validation(serve_instance):
     with pytest.raises(TypeError):
         serve.create_endpoint("endpoint", backend=2)
     serve.create_endpoint("endpoint", backend="backend")
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_backend_worker.py

@@ -15,7 +15,10 @@ from ray.serve.exceptions import RayServeException
 pytestmark = pytest.mark.asyncio
 
 
-def setup_worker(name, func_or_class, init_args=None):
+def setup_worker(name,
+                 func_or_class,
+                 init_args=None,
+                 backend_config=BackendConfig({})):
     if init_args is None:
         init_args = ()

@@ -23,7 +26,7 @@ def setup_worker(name, func_or_class, init_args=None):
     class WorkerActor:
         def __init__(self):
             self.worker = create_backend_worker(func_or_class)(
-                name, name + ":tag", init_args)
+                name, name + ":tag", init_args, backend_config)
 
         def ready(self):
             pass

@@ -31,6 +34,9 @@ def setup_worker(name, func_or_class, init_args=None):
         async def handle_request(self, *args, **kwargs):
            return await self.worker.handle_request(*args, **kwargs)
 
+        def update_config(self, new_config):
+            return self.worker.update_config(new_config)
+
     worker = WorkerActor.remote()
     ray.get(worker.ready.remote())
     return worker

@@ -165,14 +171,16 @@ async def test_task_runner_custom_method_batch(serve_instance):
     CONSUMER_NAME = "runner"
     PRODUCER_NAME = "producer"
 
-    worker = setup_worker(CONSUMER_NAME, Batcher)
+    backend_config = BackendConfig(
+        {
+            "max_batch_size": 4,
+            "batch_wait_timeout": 2
+        }, accepts_batches=True)
+    worker = setup_worker(
+        CONSUMER_NAME, Batcher, backend_config=backend_config)
 
     await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
-    await q.set_backend_config.remote(
-        CONSUMER_NAME,
-        BackendConfig({
-            "max_batch_size": 10
-        }, accepts_batches=True))
+    await q.set_backend_config.remote(CONSUMER_NAME, backend_config)
 
     def make_request_param(call_method):
         return RequestMetadata(

@@ -200,3 +208,77 @@ async def test_task_runner_custom_method_batch(serve_instance):
     np_array = make_request_param("return_np_array")
     result_np_value = await q.enqueue_request.remote(np_array)
     assert isinstance(result_np_value, np.int32)
+
+
+async def test_task_runner_perform_batch(serve_instance):
+    q = ray.remote(Router).remote()
+
+    def batcher(*args, **kwargs):
+        return [serve.context.batch_size] * serve.context.batch_size
+
+    CONSUMER_NAME = "runner"
+    PRODUCER_NAME = "producer"
+
+    config = BackendConfig(
+        {
+            "max_batch_size": 2,
+            "batch_wait_timeout": 10
+        }, accepts_batches=True)
+
+    worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
+    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
+    await q.set_backend_config.remote(CONSUMER_NAME, config)
+    await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
+
+    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
+
+    my_batch_sizes = await asyncio.gather(
+        *[q.enqueue_request.remote(query_param) for _ in range(3)])
+    assert my_batch_sizes == [2, 2, 1]
+
+
+async def test_task_runner_perform_async(serve_instance):
+    q = ray.remote(Router).remote()
+
+    @ray.remote
+    class Barrier:
+        def __init__(self, release_on):
+            self.release_on = release_on
+            self.current_waiters = 0
+            self.event = asyncio.Event()
+
+        async def wait(self):
+            self.current_waiters += 1
+            if self.current_waiters == self.release_on:
+                self.event.set()
+            else:
+                await self.event.wait()
+
+    barrier = Barrier.remote(release_on=10)
+
+    async def wait_and_go(*args, **kwargs):
+        await barrier.wait.remote()
+        return "done!"
+
+    CONSUMER_NAME = "runner"
+    PRODUCER_NAME = "producer"
+
+    config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)
+
+    worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
+    await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
+    await q.set_backend_config.remote(CONSUMER_NAME, config)
+    q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
+
+    query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
+
+    done, not_done = await asyncio.wait(
+        [q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10)
+    assert len(done) == 10
+    for item in done:
+        await item == "done!"
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_config.py

@@ -132,3 +132,8 @@ def test_replica_config_validation():
         ReplicaConfig(Class, ray_actor_options={"detached": None})
     with pytest.raises(ValueError):
         ReplicaConfig(Class, ray_actor_options={"max_restarts": None})
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_failure.py

@@ -241,3 +241,9 @@ def test_worker_replica_failure(serve_instance):
             break
         except TimeoutError:
             time.sleep(0.1)
+
+
+if __name__ == "__main__":
+    import sys
+    import pytest
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_handle.py

@@ -33,3 +33,9 @@ def test_handle_in_endpoint(serve_instance):
         methods=["GET", "POST"])
 
     assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
+
+
+if __name__ == "__main__":
+    import sys
+    import pytest
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_kv_store.py

@@ -38,3 +38,8 @@ def test_ray_internal_kv_collisions(serve_instance):
     kv2.put("1", b"-1")
     assert kv2.get("1") == b"-1"
     assert kv1.get("1") == b"1"
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_metric.py

@@ -195,3 +195,8 @@ async def test_system_metric_endpoints(serve_instance):
             print("Metric not correct, retrying...")
     if not success:
         test_metric_endpoint()
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_nonblocking.py

@@ -20,4 +20,4 @@ def test_nonblocking():
 
 if __name__ == "__main__":
     import pytest
-    sys.exit(pytest.main(["-v", __file__]))
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_persistence.py

@@ -33,3 +33,9 @@ serve.create_endpoint("driver", backend="driver", route="/driver")
 assert ray.get(handle.remote()) == "OK!"
 
 os.remove(path)
+
+
+if __name__ == "__main__":
+    import sys
+    import pytest
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_router.py

@@ -7,6 +7,8 @@ import ray
 from ray.serve.router import Router
 from ray.serve.request_params import RequestMetadata
 from ray.serve.utils import get_random_letters
+from ray.test_utils import SignalActor
+from ray.serve.config import BackendConfig
 
 pytestmark = pytest.mark.asyncio

@@ -172,3 +174,68 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
         calls = await runner.get_all_calls.remote()
         for call in calls:
             assert call.request_args[0] in runner_shard_keys[i]
+
+
+async def test_router_use_max_concurrency(serve_instance):
+    signal = SignalActor.remote()
+
+    @ray.remote
+    class MockWorker:
+        async def handle_request(self, request):
+            await signal.wait.remote()
+            return "DONE"
+
+        def ready(self):
+            pass
+
+    class VisibleRouter(Router):
+        def get_queues(self):
+            return self.queries_counter, self.backend_queues
+
+    worker = MockWorker.remote()
+    q = ray.remote(VisibleRouter).remote()
+    BACKEND_NAME = "max-concurrent-test"
+    config = BackendConfig({"max_concurrent_queries": 1})
+    await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0})
+    await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
+    await q.set_backend_config.remote(BACKEND_NAME, config)
+
+    # We send over two queries
+    first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
+    second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
+
+    # Neither queries should be available
+    with pytest.raises(ray.exceptions.RayTimeoutError):
+        ray.get([first_query, second_query], timeout=0.2)
+
+    # Let's retrieve the router internal state
+    queries_counter, backend_queues = await q.get_queues.remote()
+    # There should be just one inflight request
+    assert queries_counter["max-concurrent-test:replica-tag"] == 1
+    # The second query is buffered
+    assert len(backend_queues["max-concurrent-test"]) == 1
+
+    # Let's unblock the first query
+    await signal.send.remote(clear=True)
+    assert await first_query == "DONE"
+
+    # The internal state of router should have changed.
+    queries_counter, backend_queues = await q.get_queues.remote()
+    # There should still be one inflight request
+    assert queries_counter["max-concurrent-test:replica-tag"] == 1
+    # But there shouldn't be any queries in the queue
+    assert len(backend_queues["max-concurrent-test"]) == 0
+
+    # Unblocking the second query
+    await signal.send.remote(clear=True)
+    assert await second_query == "DONE"
+
+    # Checking the internal state of the router one more time
+    queries_counter, backend_queues = await q.get_queues.remote()
+    assert queries_counter["max-concurrent-test:replica-tag"] == 0
+    assert len(backend_queues["max-concurrent-test"]) == 0
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/tests/test_util.py

@@ -1,8 +1,10 @@
+import asyncio
 import json
 
 import numpy as np
+import pytest
 
-from ray.serve.utils import ServeEncoder
+from ray.serve.utils import ServeEncoder, chain_future, unpack_future
 
 
 def test_bytes_encoder():

@@ -20,3 +22,50 @@ def test_numpy_encoding():
     assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
     assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
     assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data
+
+
+@pytest.mark.asyncio
+async def test_future_chaining():
+    def make():
+        return asyncio.get_event_loop().create_future()
+
+    # Test 1 -> 1 chaining
+    fut1, fut2 = make(), make()
+    chain_future(fut1, fut2)
+    fut1.set_result(1)
+    assert await fut2 == 1
+
+    # Test 1 -> 1 chaining with exception
+    fut1, fut2 = make(), make()
+    chain_future(fut1, fut2)
+    fut1.set_exception(ValueError(""))
+    with pytest.raises(ValueError):
+        await fut2
+
+    # Test many -> many chaining
+    src_futs = [make() for _ in range(4)]
+    dst_futs = [make() for _ in range(4)]
+    chain_future(src_futs, dst_futs)
+    [fut.set_result(i) for i, fut in enumerate(src_futs)]
+    for i, fut in enumerate(dst_futs):
+        assert await fut == i
+
+    # Test 1 -> many unwrapping
+    batched_future = make()
+    single_futures = unpack_future(batched_future, 4)
+    batched_future.set_result(list(range(4)))
+    for i, fut in enumerate(single_futures):
+        assert await fut == i
+
+    # Test 1 -> many unwrapping with exception
+    batched_future = make()
+    single_futures = unpack_future(batched_future, 4)
+    batched_future.set_exception(ValueError(""))
+    for future in single_futures:
+        with pytest.raises(ValueError):
+            await future
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", "-s", __file__]))
python/ray/serve/utils.py

@@ -1,8 +1,11 @@
+import asyncio
+from functools import singledispatch
 import json
 import logging
 import random
 import string
 import time
+from typing import List
 import io
 import os

@@ -114,3 +117,59 @@ def format_actor_name(actor_name, instance_name=None):
         return actor_name
     else:
         return "{}:{}".format(instance_name, actor_name)
+
+
+@singledispatch
+def chain_future(src, dst):
+    """Base method for chaining futures together.
+
+    Chaining futures means the output from source future(s) are written as the
+    results of the destination future(s). This method can work with the
+    following inputs:
+    - src: Future, dst: Future
+    - src: List[Future], dst: List[Future]
+    """
+    raise NotImplementedError()
+
+
+@chain_future.register(asyncio.Future)
+def _chain_future_single(src: asyncio.Future, dst: asyncio.Future):
+    asyncio.futures._chain_future(src, dst)
+
+
+@chain_future.register(list)
+def _chain_future_list(src: List[asyncio.Future], dst: List[asyncio.Future]):
+    if len(src) != len(dst):
+        raise ValueError(
+            "Source and destination list doesn't have the same length. "
+            "Source: {}. Destination: {}.".foramt(len(src), len(dst)))
+
+    for s, d in zip(src, dst):
+        chain_future(s, d)
+
+
+def unpack_future(src: asyncio.Future, num_items: int) -> List[asyncio.Future]:
+    """Unpack the result of source future to num_items futures.
+
+    This function takes in a Future and splits its result into many futures. If
+    the result of the source future is an exception, then all destination
+    futures will have the same exception.
+    """
+    dest_futures = [
+        asyncio.get_event_loop().create_future() for _ in range(num_items)
+    ]
+
+    def unwrap_callback(fut: asyncio.Future):
+        exception = fut.exception()
+        if exception is not None:
+            [f.set_exception(exception) for f in dest_futures]
+            return
+
+        result = fut.result()
+        assert len(result) == num_items
+        for item, future in zip(result, dest_futures):
+            future.set_result(item)
+
+    src.add_done_callback(unwrap_callback)
+
+    return dest_futures
python/ray/test_utils.py

@@ -232,8 +232,10 @@ class SignalActor:
     def __init__(self):
         self.ready_event = asyncio.Event()
 
-    def send(self):
+    def send(self, clear=False):
        self.ready_event.set()
+        if clear:
+            self.ready_event.clear()
 
     async def wait(self, should_wait=True):
         if should_wait:
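A short sketch (illustrative, not from the diff) of what the new clear flag is for, as used by test_router.py above: send(clear=True) wakes the waiters already parked on the event and immediately re-arms it, so the next wait() blocks until another send().

import ray
from ray.test_utils import SignalActor

ray.init()
signal = SignalActor.remote()
blocked = signal.wait.remote()            # parked on the event inside the actor
ray.get(signal.send.remote(clear=True))   # releases current waiters, re-arms
ray.get(blocked)                          # returns
# a later signal.wait.remote() blocks again until the next send()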