[serve] Application-level batching initial commit (#14610)
parent 7b3102dd32
commit 9cf328d616
4 changed files with 548 additions and 68 deletions
@@ -1,5 +1,6 @@
 from ray.serve.api import (accept_batch, Client, connect, start,
                            get_replica_context)
+from ray.serve.batching import batch
 from ray.serve.config import BackendConfig, HTTPOptions

 # Mute the warning because Serve sometimes intentionally calls
@@ -8,6 +9,6 @@ import ray.worker
 ray.worker.blocking_get_inside_async_warned = True

 __all__ = [
-    "accept_batch", "BackendConfig", "connect", "Client", "start",
+    "accept_batch", "BackendConfig", "batch", "connect", "Client", "start",
     "HTTPOptions", "get_replica_context"
 ]
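The new export makes the decorator reachable from the public ray.serve namespace. A minimal usage sketch (illustration only, not part of the diff; assumes a Serve build that contains this commit):

from typing import List

from ray import serve


@serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
async def lowercase(strings: List[str]) -> List[str]:
    # Receives a list of up to 4 pending inputs; must return a list of the
    # same length, one result per input.
    return [s.lower() for s in strings]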
@@ -13,6 +13,7 @@ import ray
 from ray.actor import ActorHandle
 from ray._private.async_compat import sync_to_async

+from ray.serve.batching import _BatchQueue
 from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
                              unpack_future, import_attr)
 from ray.serve.exceptions import RayServeException
@@ -30,71 +31,6 @@ from ray.exceptions import RayTaskError
 logger = _get_logger()


-class BatchQueue:
-    def __init__(self, max_batch_size: int, timeout_s: float) -> None:
-        self.queue = asyncio.Queue()
-        self.full_batch_event = asyncio.Event()
-        self.max_batch_size = max_batch_size
-        self.timeout_s = timeout_s
-
-    def set_config(self, max_batch_size: int, timeout_s: float) -> None:
-        self.max_batch_size = max_batch_size
-        self.timeout_s = timeout_s
-
-    def put(self, request: Query) -> None:
-        self.queue.put_nowait(request)
-        # Signal when the full batch is ready. The event will be reset
-        # in wait_for_batch.
-        if self.queue.qsize() == self.max_batch_size:
-            self.full_batch_event.set()
-
-    def qsize(self) -> int:
-        return self.queue.qsize()
-
-    async def wait_for_batch(self) -> List[Query]:
-        """Wait for batch respecting self.max_batch_size and self.timeout_s.
-
-        Returns a batch of up to self.max_batch_size items, waiting for up
-        to self.timeout_s for a full batch. After the timeout, returns as many
-        items as are ready.
-
-        Always returns a batch with at least one item - will block
-        indefinitely until an item comes in.
-        """
-        curr_timeout = self.timeout_s
-        batch = []
-        while len(batch) == 0:
-            loop_start = time.time()
-
-            # If the timeout is 0, wait for any item to be available on the
-            # queue.
-            if curr_timeout == 0:
-                batch.append(await self.queue.get())
-            # If the timeout is nonzero, wait for either the timeout to occur
-            # or the max batch size to be ready.
-            else:
-                try:
-                    await asyncio.wait_for(self.full_batch_event.wait(),
-                                           curr_timeout)
-                except asyncio.TimeoutError:
-                    pass
-
-            # Pull up to the max_batch_size requests off the queue.
-            while len(batch) < self.max_batch_size and not self.queue.empty():
-                batch.append(self.queue.get_nowait())
-
-            # Reset the event if there are fewer than max_batch_size requests
-            # in the queue.
-            if (self.queue.qsize() < self.max_batch_size
-                    and self.full_batch_event.is_set()):
-                self.full_batch_event.clear()
-
-            # Adjust the timeout based on the time spent in this iteration.
-            curr_timeout = max(0, curr_timeout - (time.time() - loop_start))
-
-        return batch
-
-
 def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]):
     """Creates a replica class wrapping the provided function or class.

@@ -185,8 +121,8 @@ class RayServeReplica:
         self.is_function = is_function

         self.config = backend_config
-        self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
-                                      self.config.batch_wait_timeout)
+        self.batch_queue = _BatchQueue(self.config.max_batch_size or 1,
+                                       self.config.batch_wait_timeout)
         self.reconfigure(self.config.user_config)

         self.num_ongoing_requests = 0
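The replica now delegates to the shared _BatchQueue from ray.serve.batching (full listing below) instead of its own private copy. The following standalone sketch is illustration only (not part of the commit) and pokes at the internal class directly to show its semantics: a partial batch is returned once timeout_s expires, while a full batch is returned immediately.

import asyncio

from ray.serve.batching import _BatchQueue


async def main():
    # Up to 3 items per batch; wait at most 0.5s for a full batch.
    queue = _BatchQueue(max_batch_size=3, timeout_s=0.5)

    # Only two items are enqueued, fewer than max_batch_size...
    queue.put("a")
    queue.put("b")

    # ...so this returns after ~0.5s with the partial batch ["a", "b"].
    print(await queue.wait_for_batch())

    # Three items reach max_batch_size, so the next batch returns immediately.
    for item in ("c", "d", "e"):
        queue.put(item)
    print(await queue.wait_for_batch())  # ["c", "d", "e"]


asyncio.run(main())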
python/ray/serve/batching.py  (new file, 282 lines)
@@ -0,0 +1,282 @@
import asyncio
from functools import wraps
from inspect import iscoroutinefunction
import time
from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar

from ray.serve.exceptions import RayServeException


class _BatchQueue:
    def __init__(self,
                 max_batch_size: int,
                 timeout_s: float,
                 handle_batch_func: Optional[Callable] = None) -> None:
        """Async queue that accepts individual items and returns batches.

        Respects max_batch_size and timeout_s; a batch will be returned when
        max_batch_size elements are available or the timeout has passed since
        the previous get.

        If handle_batch_func is passed in, a background coroutine will run to
        poll from the queue and call handle_batch_func on the results.

        Arguments:
            max_batch_size (int): max number of elements to return in a batch.
            timeout_s (float): time to wait before returning an incomplete
                batch.
            handle_batch_func (Optional[Callable]): callback to run in the
                background to handle batches if provided.
        """
        self.queue = asyncio.Queue()
        self.full_batch_event = asyncio.Event()
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s

        self._handle_batch_task = None
        if handle_batch_func is not None:
            self._handle_batch_task = asyncio.get_event_loop().create_task(
                self._handle_batches(handle_batch_func))

    def set_config(self, max_batch_size: int, timeout_s: float) -> None:
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s

    def put(self, request: Tuple[Any, asyncio.Future]) -> None:
        self.queue.put_nowait(request)
        # Signal when the full batch is ready. The event will be reset
        # in wait_for_batch.
        if self.queue.qsize() == self.max_batch_size:
            self.full_batch_event.set()

    def qsize(self) -> int:
        return self.queue.qsize()

    async def wait_for_batch(self) -> List[Any]:
        """Wait for batch respecting self.max_batch_size and self.timeout_s.

        Returns a batch of up to self.max_batch_size items, waiting for up
        to self.timeout_s for a full batch. After the timeout, returns as many
        items as are ready.

        Always returns a batch with at least one item - will block
        indefinitely until an item comes in.
        """
        curr_timeout = self.timeout_s
        batch = []
        while len(batch) == 0:
            loop_start = time.time()

            # If the timeout is 0, wait for any item to be available on the
            # queue.
            if curr_timeout == 0:
                batch.append(await self.queue.get())
            # If the timeout is nonzero, wait for either the timeout to occur
            # or the max batch size to be ready.
            else:
                try:
                    await asyncio.wait_for(self.full_batch_event.wait(),
                                           curr_timeout)
                except asyncio.TimeoutError:
                    pass

            # Pull up to the max_batch_size requests off the queue.
            while len(batch) < self.max_batch_size and not self.queue.empty():
                batch.append(self.queue.get_nowait())

            # Reset the event if there are fewer than max_batch_size requests
            # in the queue.
            if (self.queue.qsize() < self.max_batch_size
                    and self.full_batch_event.is_set()):
                self.full_batch_event.clear()

            # Adjust the timeout based on the time spent in this iteration.
            curr_timeout = max(0, curr_timeout - (time.time() - loop_start))

        return batch

    async def _handle_batches(self, func):
        while True:
            batch = await self.wait_for_batch()
            assert len(batch) > 0
            self_arg = batch[0][0]
            args = [item[1] for item in batch]
            futures = [item[2] for item in batch]

            try:
                # Method call.
                if self_arg is not None:
                    results = await func(self_arg, args)
                # Normal function call.
                else:
                    results = await func(args)

                if len(results) != len(batch):
                    raise RayServeException(
                        "Batched function doesn't preserve batch size. "
                        f"The input list has length {len(batch)} but the "
                        f"returned list has length {len(results)}.")

                for i, result in enumerate(results):
                    futures[i].set_result(result)
            except Exception as e:
                for future in futures:
                    future.set_exception(e)

    def __del__(self):
        if (self._handle_batch_task is None
                or not asyncio.get_event_loop().is_running()):
            return

        # TODO(edoakes): although we try to gracefully shutdown here, it still
        # causes some errors when the process exits due to the asyncio loop
        # already being destroyed.
        self._handle_batch_task.cancel()


def extract_self_if_method_call(args: List[Any],
                                func: Callable) -> Optional[object]:
    """Check if this is a method rather than a function.

    Does this by checking to see if `func` is the attribute of the first
    (`self`) argument under `func.__name__`. Unfortunately, this is the most
    robust solution to this I was able to find. It would also be preferable
    to do this check when the decorator runs, rather than when the method is.

    Returns the `self` object if it's a method call, else None.

    Arguments:
        args (List[Any]): arguments to the function/method call.
        func (Callable): the unbound function that was called.
    """
    if len(args) > 0:
        method = getattr(args[0], func.__name__, False)
        if method:
            wrapped = getattr(method, "__wrapped__", False)
            if wrapped and wrapped == func:
                return args.pop(0)

    return None


T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[[List[T]], List[R]])
G = TypeVar("G", bound=Callable[[T], R])


# Normal decorator use case (called with no arguments).
@overload
def batch(func: F) -> G:
    pass


# "Decorator factory" use case (called with arguments).
@overload
def batch(max_batch_size: int = 10,
          batch_wait_timeout_s: float = 0.1) -> Callable[[F], G]:
    pass


def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.1):
    """Converts a function to asynchronously handle batches.

    The function can be a standalone function or a class method, and must
    take a list of objects as its sole argument and return a list of the
    same length.

    When invoked, the caller passes a single object. These will be batched
    and executed asynchronously once there is a batch of `max_batch_size`
    or `batch_wait_timeout_s` has elapsed, whichever occurs first.

    Example:
        @serve.batch(max_batch_size=50, batch_wait_timeout_s=0.5)
        async def handle_batch(self, batch: List[str]):
            return [s.lower() for s in batch]

        async def handle_single(self, s: str):
            # Will return s.lower().
            return await handle_batch(s)

    Arguments:
        max_batch_size (int): the maximum batch size that will be executed in
            one call to the underlying function.
        batch_wait_timeout_s (float): the maximum duration to wait for
            `max_batch_size` elements before running the underlying function.
    """
    # `_func` will be None in the case when the decorator is parametrized.
    # See the comment at the end of this function for a detailed explanation.
    if _func is not None:
        if not callable(_func):
            raise TypeError("@serve.batch can only be used to "
                            "decorate functions or methods.")

        if not iscoroutinefunction(_func):
            raise TypeError(
                "Functions decorated with @serve.batch must be 'async def'")

    if not isinstance(max_batch_size, int):
        if isinstance(max_batch_size, float) and max_batch_size.is_integer():
            max_batch_size = int(max_batch_size)
        else:
            raise TypeError("max_batch_size must be integer >= 1")

    if max_batch_size < 1:
        raise ValueError("max_batch_size must be an integer >= 1")

    if not isinstance(batch_wait_timeout_s, (float, int)):
        raise TypeError("batch_wait_timeout_s must be a float >= 0")

    if batch_wait_timeout_s < 0:
        raise ValueError("batch_wait_timeout_s must be a float >= 0")

    def _batch_decorator(_func):
        @wraps(_func)
        async def batch_wrapper(*args, **kwargs):
            args = list(args)
            self = extract_self_if_method_call(args, _func)

            if len(args) != 1:
                raise ValueError("@serve.batch functions can only take a "
                                 "single argument as input")

            if len(kwargs) != 0:
                raise ValueError(
                    "@serve.batch functions do not support kwargs")

            if self is None:
                # For functions, inject the batch queue as an
                # attribute of the function.
                batch_queue_object = _func
            else:
                # For methods, inject the batch queue as an
                # attribute of the object.
                batch_queue_object = self

            # The first time the function runs, we lazily construct the batch
            # queue and inject it under a custom attribute name. On subsequent
            # runs, we just get a reference to the attribute.
            batch_queue_attr = f"__serve_batch_queue_{_func.__name__}"
            if not hasattr(batch_queue_object, batch_queue_attr):
                batch_queue = _BatchQueue(max_batch_size, batch_wait_timeout_s,
                                          _func)
                setattr(batch_queue_object, batch_queue_attr, batch_queue)
            else:
                batch_queue = getattr(batch_queue_object, batch_queue_attr)

            future = asyncio.get_event_loop().create_future()
            batch_queue.put((self, args[0], future))

            # This will raise if the underlying call raised an exception.
            return await future

        return batch_wrapper

    # Unfortunately, this is required to handle both non-parametrized
    # (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage.
    # In the former case, `serve.batch` will be called with the underlying
    # function as the sole argument. In the latter case, it will first be
    # called with **kwargs, then the result of that call will be called
    # with the underlying function as the sole argument (i.e., it must be a
    # "decorator factory.").
    return _batch_decorator(_func) if callable(_func) else _batch_decorator
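The method-detection trick documented in extract_self_if_method_call above works because functools.wraps stores the undecorated function on its wrapper as __wrapped__, so a bound method can be traced back to the function the decorator originally saw. A standalone sketch of that mechanism (illustration only):

from functools import wraps


def record_wrapped(func):
    @wraps(func)  # Sets wrapper.__wrapped__ = func, as @wraps(_func) does above.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


class Greeter:
    @record_wrapped
    def greet(self, name):
        return f"hello {name}"


g = Greeter()
bound = getattr(g, "greet")
# The bound method exposes the wrapper's __wrapped__, i.e. the original
# undecorated function, which is how the first positional argument can be
# recognized as `self` and popped off.
print(bound.__wrapped__ is Greeter.greet.__wrapped__)  # True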
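Both decorator forms handled at the end of batch above can be exercised without a running Serve instance; concurrent callers that fill max_batch_size are executed in a single call to the underlying function. A minimal sketch (illustration only; assumes serve.batch from a build containing this commit):

import asyncio
from typing import List

from ray import serve


@serve.batch  # Bare form: defaults of max_batch_size=10, batch_wait_timeout_s=0.1.
async def double(nums: List[int]) -> List[int]:
    return [n * 2 for n in nums]


@serve.batch(max_batch_size=2, batch_wait_timeout_s=5)  # Decorator-factory form.
async def tag(items: List[str]) -> List[str]:
    # Both concurrent callers below are handled in this single invocation.
    return [f"{item}:batch_of_{len(items)}" for item in items]


async def main():
    # A lone call is returned as a batch of one after the 0.1s timeout.
    print(await double(21))  # 42

    # Two tasks submitted together fill max_batch_size=2, so they are executed
    # in one call to the underlying function.
    t1 = asyncio.create_task(tag("a"))
    t2 = asyncio.create_task(tag("b"))
    print(await t1, await t2)  # a:batch_of_2 b:batch_of_2


asyncio.run(main())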
@@ -1,3 +1,5 @@
+import asyncio
+
 import pytest

 import ray
@@ -58,6 +60,265 @@ def test_batching_exception(serve_instance):
        assert ray.get(handle.remote(temp=1))


def test_app_level_batching(serve_instance):
    client = serve_instance

    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.batch(max_batch_size=5, batch_wait_timeout_s=1)
        async def handle_batch(self, requests):
            self.count += 1
            batch_size = len(requests)
            return [self.count] * batch_size

        async def __call__(self, request):
            return await self.handle_batch(request)

    # Set the max batch size.
    client.create_backend("counter:v11", BatchingExample)
    client.create_endpoint(
        "counter1", backend="counter:v11", route="/increment2")

    future_list = []
    handle = client.get_handle("counter1")
    for _ in range(20):
        f = handle.remote(temp=1)
        future_list.append(f)

    counter_result = ray.get(future_list)
    # `count` is only incremented once per batch of queries, so if at least
    # one __call__ was served a batch of size greater than 1, the maximum
    # counter value will be less than 20.
    assert max(counter_result) < 20


def test_app_level_batching_exception(serve_instance):
    client = serve_instance

    class NoListReturned:
        def __init__(self):
            self.count = 0

        @serve.batch(max_batch_size=5)
        async def handle_batch(self, requests):
            return len(requests)

        async def __call__(self, request):
            return await self.handle_batch(request)

    # Set the max batch size.
    client.create_backend("exception:v1", NoListReturned)
    client.create_endpoint("exception-test", backend="exception:v1")

    handle = client.get_handle("exception-test")
    with pytest.raises(ray.exceptions.RayTaskError):
        assert ray.get(handle.remote(temp=1))


@pytest.mark.asyncio
async def test_decorator_validation():
    @serve.batch
    async def function():
        pass

    @serve.batch(max_batch_size=10, batch_wait_timeout_s=1.5)
    async def function2():
        pass

    class Class():
        @serve.batch
        async def method(self):
            pass

    class Class2():
        @serve.batch(max_batch_size=10, batch_wait_timeout_s=1.5)
        async def method(self):
            pass

    with pytest.raises(TypeError, match="async def"):

        @serve.batch
        def non_async_function():
            pass

    with pytest.raises(TypeError, match="async def"):

        class NotAsync:
            @serve.batch
            def method(self, requests):
                pass

    with pytest.raises(ValueError):

        class ZeroBatch:
            @serve.batch(max_batch_size=0)
            async def method(self, requests):
                pass

    with pytest.raises(TypeError):

        class FloatNonIntBatch:
            @serve.batch(max_batch_size=1.1)
            async def method(self, requests):
                pass

    class FloatIntegerBatch:
        @serve.batch(max_batch_size=1.0)
        async def method(self, requests):
            pass

    with pytest.raises(ValueError):

        class NegativeTimeout:
            @serve.batch(batch_wait_timeout_s=-0.1)
            async def method(self, requests):
                pass

    class FloatZeroTimeout:
        @serve.batch(batch_wait_timeout_s=0.0)
        async def method(self, requests):
            pass

    class IntZeroTimeout:
        @serve.batch(batch_wait_timeout_s=0)
        async def method(self, requests):
            pass

    with pytest.raises(TypeError):

        class NonTimeout:
            @serve.batch(batch_wait_timeout_s="a")
            async def method(self, requests):
                pass


@pytest.mark.asyncio
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_size_one_long_timeout(use_class):
    @serve.batch(max_batch_size=1, batch_wait_timeout_s=1000)
    async def long_timeout(requests):
        if "raise" in requests:
            1 / 0
        return requests

    class LongTimeout:
        @serve.batch(max_batch_size=1, batch_wait_timeout_s=1000)
        async def long_timeout(self, requests):
            if "raise" in requests:
                1 / 0
            return requests

    cls = LongTimeout()

    async def call(arg):
        if use_class:
            return await cls.long_timeout(arg)
        else:
            return await long_timeout(arg)

    assert await call("hi") == "hi"
    with pytest.raises(ZeroDivisionError):
        await call("raise")


@pytest.mark.asyncio
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_size_multiple_zero_timeout(use_class):
    @serve.batch(max_batch_size=2, batch_wait_timeout_s=0)
    async def zero_timeout(requests):
        await asyncio.sleep(1)
        if "raise" in requests:
            1 / 0
        return requests

    class ZeroTimeout:
        @serve.batch(max_batch_size=2, batch_wait_timeout_s=0)
        async def zero_timeout(self, requests):
            await asyncio.sleep(1)
            if "raise" in requests:
                1 / 0
            return requests

    cls = ZeroTimeout()

    async def call(arg):
        if use_class:
            return await cls.zero_timeout(arg)
        else:
            return await zero_timeout(arg)

    assert await call("hi") == "hi"
    with pytest.raises(ZeroDivisionError):
        await call("raise")

    # Check that 2 requests will be executed together if available.
    # The first should cause a size-one batch to be executed, then
    # the next two should be executed together (signaled by both
    # having the exception).
    t1 = asyncio.get_event_loop().create_task(call("hi1"))
    await asyncio.sleep(0.5)
    t2 = asyncio.get_event_loop().create_task(call("hi2"))
    t3 = asyncio.get_event_loop().create_task(call("raise"))

    assert await t1 == "hi1"

    with pytest.raises(ZeroDivisionError):
        await t2
    with pytest.raises(ZeroDivisionError):
        await t3


@pytest.mark.asyncio
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_size_multiple_long_timeout(use_class):
    @serve.batch(max_batch_size=3, batch_wait_timeout_s=1000)
    async def long_timeout(requests):
        if "raise" in requests:
            1 / 0
        return requests

    class LongTimeout:
        @serve.batch(max_batch_size=3, batch_wait_timeout_s=1000)
        async def long_timeout(self, requests):
            if "raise" in requests:
                1 / 0
            return requests

    cls = LongTimeout()

    async def call(arg):
        if use_class:
            return await cls.long_timeout(arg)
        else:
            return await long_timeout(arg)

    t1 = asyncio.get_event_loop().create_task(call("hi1"))
    t2 = asyncio.get_event_loop().create_task(call("hi2"))
    done, pending = await asyncio.wait([t1, t2], timeout=0.1)
    assert len(done) == 0
    t3 = asyncio.get_event_loop().create_task(call("hi3"))
    done, pending = await asyncio.wait([t1, t2, t3], timeout=100)
    assert set(done) == {t1, t2, t3}
    assert [t1.result(), t2.result(), t3.result()] == ["hi1", "hi2", "hi3"]

    t1 = asyncio.get_event_loop().create_task(call("hi1"))
    t2 = asyncio.get_event_loop().create_task(call("raise"))
    done, pending = await asyncio.wait([t1, t2], timeout=0.1)
    assert len(done) == 0
    t3 = asyncio.get_event_loop().create_task(call("hi3"))
    done, pending = await asyncio.wait([t1, t2, t3], timeout=100)
    assert set(done) == {t1, t2, t3}
    assert all(isinstance(t.exception(), ZeroDivisionError) for t in done)
    with pytest.raises(ZeroDivisionError):
        t1.result()
    with pytest.raises(ZeroDivisionError):
        t2.result()
    with pytest.raises(ZeroDivisionError):
        t3.result()


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", "-s", __file__]))