[serve] Application-level batching initial commit (#14610)

Edward Oakes 2021-03-11 21:16:08 -06:00 committed by GitHub
parent 7b3102dd32
commit 9cf328d616
4 changed files with 548 additions and 68 deletions


@@ -1,5 +1,6 @@
from ray.serve.api import (accept_batch, Client, connect, start,
get_replica_context)
from ray.serve.batching import batch
from ray.serve.config import BackendConfig, HTTPOptions
# Mute the warning because Serve sometimes intentionally calls
@@ -8,6 +9,6 @@ import ray.worker
ray.worker.blocking_get_inside_async_warned = True
__all__ = [
"accept_batch", "BackendConfig", "connect", "Client", "start",
"accept_batch", "BackendConfig", "batch", "connect", "Client", "start",
"HTTPOptions", "get_replica_context"
]


@@ -13,6 +13,7 @@ import ray
from ray.actor import ActorHandle
from ray._private.async_compat import sync_to_async
from ray.serve.batching import _BatchQueue
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
unpack_future, import_attr)
from ray.serve.exceptions import RayServeException
@@ -30,71 +31,6 @@ from ray.exceptions import RayTaskError
logger = _get_logger()
class BatchQueue:
def __init__(self, max_batch_size: int, timeout_s: float) -> None:
self.queue = asyncio.Queue()
self.full_batch_event = asyncio.Event()
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
def set_config(self, max_batch_size: int, timeout_s: float) -> None:
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
def put(self, request: Query) -> None:
self.queue.put_nowait(request)
# Signal when the full batch is ready. The event will be reset
# in wait_for_batch.
if self.queue.qsize() == self.max_batch_size:
self.full_batch_event.set()
def qsize(self) -> int:
return self.queue.qsize()
async def wait_for_batch(self) -> List[Query]:
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
Returns a batch of up to self.max_batch_size items, waiting for up
to self.timeout_s for a full batch. After the timeout, returns as many
items as are ready.
Always returns a batch with at least one item - will block
indefinitely until an item comes in.
"""
curr_timeout = self.timeout_s
batch = []
while len(batch) == 0:
loop_start = time.time()
# If the timeout is 0, wait for any item to be available on the
# queue.
if curr_timeout == 0:
batch.append(await self.queue.get())
# If the timeout is nonzero, wait for either the timeout to occur
# or the max batch size to be ready.
else:
try:
await asyncio.wait_for(self.full_batch_event.wait(),
curr_timeout)
except asyncio.TimeoutError:
pass
# Pull up to the max_batch_size requests off the queue.
while len(batch) < self.max_batch_size and not self.queue.empty():
batch.append(self.queue.get_nowait())
# Reset the event if there are fewer than max_batch_size requests
# in the queue.
if (self.queue.qsize() < self.max_batch_size
and self.full_batch_event.is_set()):
self.full_batch_event.clear()
# Adjust the timeout based on the time spent in this iteration.
curr_timeout = max(0, curr_timeout - (time.time() - loop_start))
return batch
def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]):
"""Creates a replica class wrapping the provided function or class.
@@ -185,8 +121,8 @@ class RayServeReplica:
self.is_function = is_function
self.config = backend_config
self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
self.batch_queue = _BatchQueue(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
self.reconfigure(self.config.user_config)
self.num_ongoing_requests = 0


@@ -0,0 +1,282 @@
import asyncio
from functools import wraps
from inspect import iscoroutinefunction
import time
from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar
from ray.serve.exceptions import RayServeException
class _BatchQueue:
def __init__(self,
max_batch_size: int,
timeout_s: float,
handle_batch_func: Optional[Callable] = None) -> None:
"""Async queue that accepts individual items and returns batches.
Respects max_batch_size and timeout_s; a batch will be returned when
max_batch_size elements are available or the timeout has passed since
the previous get.
If handle_batch_func is passed in, a background coroutine will run to
poll from the queue and call handle_batch_func on the results.
Arguments:
max_batch_size (int): max number of elements to return in a batch.
timeout_s (float): time to wait before returning an incomplete
batch.
handle_batch_func (Optional[Callable]): callback to run in the
background to handle batches if provided.
"""
self.queue = asyncio.Queue()
self.full_batch_event = asyncio.Event()
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
self._handle_batch_task = None
if handle_batch_func is not None:
self._handle_batch_task = asyncio.get_event_loop().create_task(
self._handle_batches(handle_batch_func))
def set_config(self, max_batch_size: int, timeout_s: float) -> None:
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
def put(self, request: Tuple[Any, asyncio.Future]) -> None:
self.queue.put_nowait(request)
# Signal when the full batch is ready. The event will be reset
# in wait_for_batch.
if self.queue.qsize() == self.max_batch_size:
self.full_batch_event.set()
def qsize(self) -> int:
return self.queue.qsize()
async def wait_for_batch(self) -> List[Any]:
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
Returns a batch of up to self.max_batch_size items, waiting for up
to self.timeout_s for a full batch. After the timeout, returns as many
items as are ready.
Always returns a batch with at least one item - will block
indefinitely until an item comes in.
"""
curr_timeout = self.timeout_s
batch = []
while len(batch) == 0:
loop_start = time.time()
# If the timeout is 0, wait for any item to be available on the
# queue.
if curr_timeout == 0:
batch.append(await self.queue.get())
# If the timeout is nonzero, wait for either the timeout to occur
# or the max batch size to be ready.
else:
try:
await asyncio.wait_for(self.full_batch_event.wait(),
curr_timeout)
except asyncio.TimeoutError:
pass
# Pull up to the max_batch_size requests off the queue.
while len(batch) < self.max_batch_size and not self.queue.empty():
batch.append(self.queue.get_nowait())
# Reset the event if there are fewer than max_batch_size requests
# in the queue.
if (self.queue.qsize() < self.max_batch_size
and self.full_batch_event.is_set()):
self.full_batch_event.clear()
# Adjust the timeout based on the time spent in this iteration.
curr_timeout = max(0, curr_timeout - (time.time() - loop_start))
return batch
async def _handle_batches(self, func):
while True:
batch = await self.wait_for_batch()
assert len(batch) > 0
self_arg = batch[0][0]
args = [item[1] for item in batch]
futures = [item[2] for item in batch]
try:
# Method call.
if self_arg is not None:
results = await func(self_arg, args)
# Normal function call.
else:
results = await func(args)
if len(results) != len(batch):
raise RayServeException(
"Batched function doesn't preserve batch size. "
f"The input list has length {len(batch)} but the "
f"returned list has length {len(results)}.")
for i, result in enumerate(results):
futures[i].set_result(result)
except Exception as e:
for future in futures:
future.set_exception(e)
def __del__(self):
if (self._handle_batch_task is None
or not asyncio.get_event_loop().is_running()):
return
# TODO(edoakes): although we try to gracefully shutdown here, it still
# causes some errors when the process exits due to the asyncio loop
# already being destroyed.
self._handle_batch_task.cancel()
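The queue above is only driven indirectly by the decorator below; as a minimal standalone sketch of its contract (illustrative only, assuming the _BatchQueue defined in this file and no background handler):

import asyncio

async def _sketch():
    queue = _BatchQueue(max_batch_size=2, timeout_s=0.1)
    loop = asyncio.get_running_loop()
    # Each entry pairs a payload with a future; wait_for_batch returns the
    # raw entries, so the consumer decides how to resolve the futures.
    futures = [loop.create_future() for _ in range(2)]
    queue.put(("a", futures[0]))
    queue.put(("b", futures[1]))
    # Both items come back together because max_batch_size was reached
    # before timeout_s elapsed.
    batch = await queue.wait_for_batch()
    assert [item[0] for item in batch] == ["a", "b"]

asyncio.run(_sketch())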
def extract_self_if_method_call(args: List[Any],
func: Callable) -> Optional[object]:
"""Check if this is a method rather than a function.
Does this by checking to see if `func` is the attribute of the first
(`self`) argument under `func.__name__`. Unfortunately, this is the most
robust solution I was able to find. It would be preferable to do this
check when the decorator runs rather than each time the method is called.
Returns the `self` object if it's a method call, else None.
Arguments:
args (List[Any]): arguments to the function/method call.
func (Callable): the unbound function that was called.
"""
if len(args) > 0:
method = getattr(args[0], func.__name__, False)
if method:
wrapped = getattr(method, "__wrapped__", False)
if wrapped and wrapped == func:
return args.pop(0)
return None
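The __wrapped__ comparison works because functools.wraps records the original function on the wrapper; a small self-contained sketch of the same detection trick (the names here are illustrative and not part of this commit):

from functools import wraps

def traced(func):
    @wraps(func)  # records func on wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        args = list(args)
        maybe_method = getattr(args[0], func.__name__, False) if args else False
        is_method = getattr(maybe_method, "__wrapped__", None) is func
        print(f"{func.__name__}: called as method = {is_method}")
        return func(*args, **kwargs)
    return wrapper

class Greeter:
    @traced
    def hello(self):
        return "hello"

@traced
def standalone():
    return "hello"

Greeter().hello()  # called as method = True
standalone()       # called as method = False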
T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[[List[T]], List[R]])
G = TypeVar("G", bound=Callable[[T], R])
# Normal decorator use case (called with no arguments).
@overload
def batch(func: F) -> G:
pass
# "Decorator factory" use case (called with arguments).
@overload
def batch(max_batch_size: int = 10,
batch_wait_timeout_s: float = 0.1) -> Callable[[F], G]:
pass
def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.1):
"""Converts a function to asynchronously handle batches.
The function can be a standalone function or a class method, and must
take a list of objects as its sole argument and return a list of the
same length.
When invoked, the caller passes a single object. These individual calls
are batched and executed asynchronously once `max_batch_size` calls have
accumulated or `batch_wait_timeout_s` has elapsed, whichever occurs first.
Example:
@serve.batch(max_batch_size=50, batch_wait_timeout_s=0.5)
async def handle_batch(self, batch: List[str]):
return [s.lower() for s in batch]
async def handle_single(self, s: str):
# Will return s.lower().
return await self.handle_batch(s)
Arguments:
max_batch_size (int): the maximum batch size that will be executed in
one call to the underlying function.
batch_wait_timeout_s (float): the maximum duration to wait for
`max_batch_size` elements before running the underlying function.
"""
# `_func` will be None in the case when the decorator is parametrized.
# See the comment at the end of this function for a detailed explanation.
if _func is not None:
if not callable(_func):
raise TypeError("@serve.batch can only be used to "
"decorate functions or methods.")
if not iscoroutinefunction(_func):
raise TypeError(
"Functions decorated with @serve.batch must be 'async def'")
if not isinstance(max_batch_size, int):
if isinstance(max_batch_size, float) and max_batch_size.is_integer():
max_batch_size = int(max_batch_size)
else:
raise TypeError("max_batch_size must be integer >= 1")
if max_batch_size < 1:
raise ValueError("max_batch_size must be an integer >= 1")
if not isinstance(batch_wait_timeout_s, (float, int)):
raise TypeError("batch_wait_timeout_s must be a float >= 0")
if batch_wait_timeout_s < 0:
raise ValueError("batch_wait_timeout_s must be a float >= 0")
def _batch_decorator(_func):
@wraps(_func)
async def batch_wrapper(*args, **kwargs):
args = list(args)
self = extract_self_if_method_call(args, _func)
if len(args) != 1:
raise ValueError("@serve.batch functions can only take a "
"single argument as input")
if len(kwargs) != 0:
raise ValueError(
"@serve.batch functions do not support kwargs")
if self is None:
# For functions, inject the batch queue as an
# attribute of the function.
batch_queue_object = _func
else:
# For methods, inject the batch queue as an
# attribute of the object.
batch_queue_object = self
# The first time the function runs, we lazily construct the batch
# queue and inject it under a custom attribute name. On subsequent
# runs, we just get a reference to the attribute.
batch_queue_attr = f"__serve_batch_queue_{_func.__name__}"
if not hasattr(batch_queue_object, batch_queue_attr):
batch_queue = _BatchQueue(max_batch_size, batch_wait_timeout_s,
_func)
setattr(batch_queue_object, batch_queue_attr, batch_queue)
else:
batch_queue = getattr(batch_queue_object, batch_queue_attr)
future = asyncio.get_event_loop().create_future()
batch_queue.put((self, args[0], future))
# This will raise if the underlying call raised an exception.
return await future
return batch_wrapper
# Unfortunately, this is required to handle both non-parametrized
# (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage.
# In the former case, `serve.batch` will be called with the underlying
# function as the sole argument. In the latter case, it will first be
# called with **kwargs, then the result of that call will be called
# with the underlying function as the sole argument (i.e., it must be a
# "decorator factory.").
return _batch_decorator(_func) if callable(_func) else _batch_decorator
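Putting the decorator together end to end, a minimal sketch of the new application-level API on a free function (illustrative only; it runs outside of a replica and relies only on this module):

import asyncio
from typing import List

from ray import serve

@serve.batch(max_batch_size=4, batch_wait_timeout_s=0.05)
async def double(batch: List[int]) -> List[int]:
    # Receives up to 4 pending inputs at once; must return one result
    # per input, in order.
    return [x * 2 for x in batch]

async def main():
    # Each caller passes a single item; concurrent calls are grouped into
    # shared invocations of double().
    results = await asyncio.gather(*[double(i) for i in range(8)])
    assert results == [i * 2 for i in range(8)]

asyncio.run(main())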


@@ -1,3 +1,5 @@
import asyncio
import pytest
import ray
@@ -58,6 +60,265 @@ def test_batching_exception(serve_instance):
assert ray.get(handle.remote(temp=1))
def test_app_level_batching(serve_instance):
client = serve_instance
class BatchingExample:
def __init__(self):
self.count = 0
@serve.batch(max_batch_size=5, batch_wait_timeout_s=1)
async def handle_batch(self, requests):
self.count += 1
batch_size = len(requests)
return [self.count] * batch_size
async def __call__(self, request):
return await self.handle_batch(request)
# Create the backend and endpoint.
client.create_backend("counter:v11", BatchingExample)
client.create_endpoint(
"counter1", backend="counter:v11", route="/increment2")
future_list = []
handle = client.get_handle("counter1")
for _ in range(20):
f = handle.remote(temp=1)
future_list.append(f)
counter_result = ray.get(future_list)
# Since count is only updated once per batch of queries, if at least
# one __call__ was served in a batch of size greater than 1, the max
# counter result will be less than 20.
assert max(counter_result) < 20
def test_app_level_batching_exception(serve_instance):
client = serve_instance
class NoListReturned:
def __init__(self):
self.count = 0
@serve.batch(max_batch_size=5)
async def handle_batch(self, requests):
return len(requests)
async def __call__(self, request):
return await self.handle_batch(request)
# Create the backend and endpoint.
client.create_backend("exception:v1", NoListReturned)
client.create_endpoint("exception-test", backend="exception:v1")
handle = client.get_handle("exception-test")
with pytest.raises(ray.exceptions.RayTaskError):
assert ray.get(handle.remote(temp=1))
@pytest.mark.asyncio
async def test_decorator_validation():
@serve.batch
async def function():
pass
@serve.batch(max_batch_size=10, batch_wait_timeout_s=1.5)
async def function2():
pass
class Class():
@serve.batch
async def method(self):
pass
class Class2():
@serve.batch(max_batch_size=10, batch_wait_timeout_s=1.5)
async def method(self):
pass
with pytest.raises(TypeError, match="async def"):
@serve.batch
def non_async_function():
pass
with pytest.raises(TypeError, match="async def"):
class NotAsync:
@serve.batch
def method(self, requests):
pass
with pytest.raises(ValueError):
class ZeroBatch:
@serve.batch(max_batch_size=0)
async def method(self, requests):
pass
with pytest.raises(TypeError):
class FloatNonIntBatch:
@serve.batch(max_batch_size=1.1)
async def method(self, requests):
pass
class FloatIntegerBatch:
@serve.batch(max_batch_size=1.0)
async def method(self, requests):
pass
with pytest.raises(ValueError):
class NegativeTimeout:
@serve.batch(batch_wait_timeout_s=-0.1)
async def method(self, requests):
pass
class FloatZeroTimeout:
@serve.batch(batch_wait_timeout_s=0.0)
async def method(self, requests):
pass
class IntZeroTimeout:
@serve.batch(batch_wait_timeout_s=0)
async def method(self, requests):
pass
with pytest.raises(TypeError):
class NonTimeout:
@serve.batch(batch_wait_timeout_s="a")
async def method(self, requests):
pass
@pytest.mark.asyncio
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_size_one_long_timeout(use_class):
@serve.batch(max_batch_size=1, batch_wait_timeout_s=1000)
async def long_timeout(requests):
if "raise" in requests:
1 / 0
return requests
class LongTimeout:
@serve.batch(max_batch_size=1, batch_wait_timeout_s=1000)
async def long_timeout(self, requests):
if "raise" in requests:
1 / 0
return requests
cls = LongTimeout()
async def call(arg):
if use_class:
return await cls.long_timeout(arg)
else:
return await long_timeout(arg)
assert await call("hi") == "hi"
with pytest.raises(ZeroDivisionError):
await call("raise")
@pytest.mark.asyncio
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_size_multiple_zero_timeout(use_class):
@serve.batch(max_batch_size=2, batch_wait_timeout_s=0)
async def zero_timeout(requests):
await asyncio.sleep(1)
if "raise" in requests:
1 / 0
return requests
class ZeroTimeout:
@serve.batch(max_batch_size=2, batch_wait_timeout_s=0)
async def zero_timeout(self, requests):
await asyncio.sleep(1)
if "raise" in requests:
1 / 0
return requests
cls = ZeroTimeout()
async def call(arg):
if use_class:
return await cls.zero_timeout(arg)
else:
return await zero_timeout(arg)
assert await call("hi") == "hi"
with pytest.raises(ZeroDivisionError):
await call("raise")
# Check that 2 requests will be executed together if available.
# The first should cause a size-one batch to be executed, then
# the next two should be executed together (signaled by both
# having the exception).
t1 = asyncio.get_event_loop().create_task(call("hi1"))
await asyncio.sleep(0.5)
t2 = asyncio.get_event_loop().create_task(call("hi2"))
t3 = asyncio.get_event_loop().create_task(call("raise"))
assert await t1 == "hi1"
with pytest.raises(ZeroDivisionError):
await t2
with pytest.raises(ZeroDivisionError):
await t3
@pytest.mark.asyncio
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_size_multiple_long_timeout(use_class):
@serve.batch(max_batch_size=3, batch_wait_timeout_s=1000)
async def long_timeout(requests):
if "raise" in requests:
1 / 0
return requests
class LongTimeout:
@serve.batch(max_batch_size=3, batch_wait_timeout_s=1000)
async def long_timeout(self, requests):
if "raise" in requests:
1 / 0
return requests
cls = LongTimeout()
async def call(arg):
if use_class:
return await cls.long_timeout(arg)
else:
return await long_timeout(arg)
t1 = asyncio.get_event_loop().create_task(call("hi1"))
t2 = asyncio.get_event_loop().create_task(call("hi2"))
done, pending = await asyncio.wait([t1, t2], timeout=0.1)
assert len(done) == 0
t3 = asyncio.get_event_loop().create_task(call("hi3"))
done, pending = await asyncio.wait([t1, t2, t3], timeout=100)
assert set(done) == {t1, t2, t3}
assert [t1.result(), t2.result(), t3.result()] == ["hi1", "hi2", "hi3"]
t1 = asyncio.get_event_loop().create_task(call("hi1"))
t2 = asyncio.get_event_loop().create_task(call("raise"))
done, pending = await asyncio.wait([t1, t2], timeout=0.1)
assert len(done) == 0
t3 = asyncio.get_event_loop().create_task(call("hi3"))
done, pending = await asyncio.wait([t1, t2, t3], timeout=100)
assert set(done) == {t1, t2, t3}
assert all(isinstance(t.exception(), ZeroDivisionError) for t in done)
with pytest.raises(ZeroDivisionError):
t1.result()
with pytest.raises(ZeroDivisionError):
t2.result()
with pytest.raises(ZeroDivisionError):
t3.result()
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))