From 9cf328d616325f2c4fb84619349f0726947046f7 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 11 Mar 2021 21:16:08 -0600 Subject: [PATCH] [serve] Application-level batching initial commit (#14610) --- python/ray/serve/__init__.py | 3 +- python/ray/serve/backend_worker.py | 70 +----- python/ray/serve/batching.py | 282 ++++++++++++++++++++++++ python/ray/serve/tests/test_batching.py | 261 ++++++++++++++++++++++ 4 files changed, 548 insertions(+), 68 deletions(-) create mode 100644 python/ray/serve/batching.py diff --git a/python/ray/serve/__init__.py b/python/ray/serve/__init__.py index 20586bbcc..fada301a1 100644 --- a/python/ray/serve/__init__.py +++ b/python/ray/serve/__init__.py @@ -1,5 +1,6 @@ from ray.serve.api import (accept_batch, Client, connect, start, get_replica_context) +from ray.serve.batching import batch from ray.serve.config import BackendConfig, HTTPOptions # Mute the warning because Serve sometimes intentionally calls @@ -8,6 +9,6 @@ import ray.worker ray.worker.blocking_get_inside_async_warned = True __all__ = [ - "accept_batch", "BackendConfig", "connect", "Client", "start", + "accept_batch", "BackendConfig", "batch", "connect", "Client", "start", "HTTPOptions", "get_replica_context" ] diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index b82f8017b..d5c23d65c 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -13,6 +13,7 @@ import ray from ray.actor import ActorHandle from ray._private.async_compat import sync_to_async +from ray.serve.batching import _BatchQueue from ray.serve.utils import (parse_request_item, _get_logger, chain_future, unpack_future, import_attr) from ray.serve.exceptions import RayServeException @@ -30,71 +31,6 @@ from ray.exceptions import RayTaskError logger = _get_logger() -class BatchQueue: - def __init__(self, max_batch_size: int, timeout_s: float) -> None: - self.queue = asyncio.Queue() - self.full_batch_event = asyncio.Event() - self.max_batch_size = max_batch_size - self.timeout_s = timeout_s - - def set_config(self, max_batch_size: int, timeout_s: float) -> None: - self.max_batch_size = max_batch_size - self.timeout_s = timeout_s - - def put(self, request: Query) -> None: - self.queue.put_nowait(request) - # Signal when the full batch is ready. The event will be reset - # in wait_for_batch. - if self.queue.qsize() == self.max_batch_size: - self.full_batch_event.set() - - def qsize(self) -> int: - return self.queue.qsize() - - async def wait_for_batch(self) -> List[Query]: - """Wait for batch respecting self.max_batch_size and self.timeout_s. - - Returns a batch of up to self.max_batch_size items, waiting for up - to self.timeout_s for a full batch. After the timeout, returns as many - items as are ready. - - Always returns a batch with at least one item - will block - indefinitely until an item comes in. - """ - curr_timeout = self.timeout_s - batch = [] - while len(batch) == 0: - loop_start = time.time() - - # If the timeout is 0, wait for any item to be available on the - # queue. - if curr_timeout == 0: - batch.append(await self.queue.get()) - # If the timeout is nonzero, wait for either the timeout to occur - # or the max batch size to be ready. - else: - try: - await asyncio.wait_for(self.full_batch_event.wait(), - curr_timeout) - except asyncio.TimeoutError: - pass - - # Pull up to the max_batch_size requests off the queue. 
- while len(batch) < self.max_batch_size and not self.queue.empty(): - batch.append(self.queue.get_nowait()) - - # Reset the event if there are fewer than max_batch_size requests - # in the queue. - if (self.queue.qsize() < self.max_batch_size - and self.full_batch_event.is_set()): - self.full_batch_event.clear() - - # Adjust the timeout based on the time spent in this iteration. - curr_timeout = max(0, curr_timeout - (time.time() - loop_start)) - - return batch - - def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]): """Creates a replica class wrapping the provided function or class. @@ -185,8 +121,8 @@ class RayServeReplica: self.is_function = is_function self.config = backend_config - self.batch_queue = BatchQueue(self.config.max_batch_size or 1, - self.config.batch_wait_timeout) + self.batch_queue = _BatchQueue(self.config.max_batch_size or 1, + self.config.batch_wait_timeout) self.reconfigure(self.config.user_config) self.num_ongoing_requests = 0 diff --git a/python/ray/serve/batching.py b/python/ray/serve/batching.py new file mode 100644 index 000000000..bae5c7f6a --- /dev/null +++ b/python/ray/serve/batching.py @@ -0,0 +1,282 @@ +import asyncio +from functools import wraps +from inspect import iscoroutinefunction +import time +from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar + +from ray.serve.exceptions import RayServeException + + +class _BatchQueue: + def __init__(self, + max_batch_size: int, + timeout_s: float, + handle_batch_func: Optional[Callable] = None) -> None: + """Async queue that accepts individual items and returns batches. + + Respects max_batch_size and timeout_s; a batch will be returned when + max_batch_size elements are available or the timeout has passed since + the previous get. + + If handle_batch_func is passed in, a background coroutine will run to + poll from the queue and call handle_batch_func on the results. + + Arguments: + max_batch_size (int): max number of elements to return in a batch. + timeout_s (float): time to wait before returning an incomplete + batch. + handle_batch_func(Optional[Callable]): callback to run in the + background to handle batches if provided. + """ + self.queue = asyncio.Queue() + self.full_batch_event = asyncio.Event() + self.max_batch_size = max_batch_size + self.timeout_s = timeout_s + + self._handle_batch_task = None + if handle_batch_func is not None: + self._handle_batch_task = asyncio.get_event_loop().create_task( + self._handle_batches(handle_batch_func)) + + def set_config(self, max_batch_size: int, timeout_s: float) -> None: + self.max_batch_size = max_batch_size + self.timeout_s = timeout_s + + def put(self, request: Tuple[Any, asyncio.Future]) -> None: + self.queue.put_nowait(request) + # Signal when the full batch is ready. The event will be reset + # in wait_for_batch. + if self.queue.qsize() == self.max_batch_size: + self.full_batch_event.set() + + def qsize(self) -> int: + return self.queue.qsize() + + async def wait_for_batch(self) -> List[Any]: + """Wait for batch respecting self.max_batch_size and self.timeout_s. + + Returns a batch of up to self.max_batch_size items, waiting for up + to self.timeout_s for a full batch. After the timeout, returns as many + items as are ready. + + Always returns a batch with at least one item - will block + indefinitely until an item comes in. + """ + curr_timeout = self.timeout_s + batch = [] + while len(batch) == 0: + loop_start = time.time() + + # If the timeout is 0, wait for any item to be available on the + # queue. 
+ if curr_timeout == 0: + batch.append(await self.queue.get()) + # If the timeout is nonzero, wait for either the timeout to occur + # or the max batch size to be ready. + else: + try: + await asyncio.wait_for(self.full_batch_event.wait(), + curr_timeout) + except asyncio.TimeoutError: + pass + + # Pull up to the max_batch_size requests off the queue. + while len(batch) < self.max_batch_size and not self.queue.empty(): + batch.append(self.queue.get_nowait()) + + # Reset the event if there are fewer than max_batch_size requests + # in the queue. + if (self.queue.qsize() < self.max_batch_size + and self.full_batch_event.is_set()): + self.full_batch_event.clear() + + # Adjust the timeout based on the time spent in this iteration. + curr_timeout = max(0, curr_timeout - (time.time() - loop_start)) + + return batch + + async def _handle_batches(self, func): + while True: + batch = await self.wait_for_batch() + assert len(batch) > 0 + self_arg = batch[0][0] + args = [item[1] for item in batch] + futures = [item[2] for item in batch] + + try: + # Method call. + if self_arg is not None: + results = await func(self_arg, args) + # Normal function call. + else: + results = await func(args) + + if len(results) != len(batch): + raise RayServeException( + "Batched function doesn't preserve batch size. " + f"The input list has length {len(batch)} but the " + f"returned list has length {len(results)}.") + + for i, result in enumerate(results): + futures[i].set_result(result) + except Exception as e: + for future in futures: + future.set_exception(e) + + def __del__(self): + if (self._handle_batch_task is None + or not asyncio.get_event_loop().is_running()): + return + + # TODO(edoakes): although we try to gracefully shutdown here, it still + # causes some errors when the process exits due to the asyncio loop + # already being destroyed. + self._handle_batch_task.cancel() + + +def extract_self_if_method_call(args: List[Any], + func: Callable) -> Optional[object]: + """Check if this is a method rather than a function. + + Does this by checking to see if `func` is the attribute of the first + (`self`) argument under `func.__name__`. Unfortunately, this is the most + robust solution to this I was able to find. It would also be preferable + to do this check when the decorator runs, rather than when the method is. + + Returns the `self` object if it's a method call, else None. + + Arguments: + args (List[Any]): arguments to the function/method call. + func (Callable): the unbound function that was called. + """ + if len(args) > 0: + method = getattr(args[0], func.__name__, False) + if method: + wrapped = getattr(method, "__wrapped__", False) + if wrapped and wrapped == func: + return args.pop(0) + + return None + + +T = TypeVar("T") +R = TypeVar("R") +F = TypeVar("F", bound=Callable[[List[T]], List[R]]) +G = TypeVar("G", bound=Callable[[T], R]) + + +# Normal decorator use case (called with no arguments). +@overload +def batch(func: F) -> G: + pass + + +# "Decorator factory" use case (called with arguments). +@overload +def batch(max_batch_size: int = 10, + batch_wait_timeout_s: float = 0.1) -> Callable[[F], G]: + pass + + +def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.1): + """Converts a function to asynchronously handle batches. + + The function can be a standalone function or a class method, and must + take a list of objects as its sole argument and return a list of the + same length. + + When invoked, the caller passes a single object. 
These will be batched + and executed asynchronously once there is a batch of `max_batch_size` + or `batch_wait_timeout_s` has elapsed, whichever occurs first. + + Example: + @serve.batch(max_batch_size=50, batch_wait_timeout_s=0.5) + async def handle_batch(self, batch: List[str]): + return [s.lower() for s in batch] + + async def handle_single(self, s: str): + # Will return s.lower(). + return await handle_batch(s) + + Arguments: + max_batch_size (int): the maximum batch size that will be executed in + one call to the underlying function. + batch_wait_timeout_s (float): the maximum duration to wait for + `max_batch_size` elements before running the underlying function. + """ + # `_func` will be None in the case when the decorator is parametrized. + # See the comment at the end of this function for a detailed explanation. + if _func is not None: + if not callable(_func): + raise TypeError("@serve.batch can only be used to " + "decorate functions or methods.") + + if not iscoroutinefunction(_func): + raise TypeError( + "Functions decorated with @serve.batch must be 'async def'") + + if not isinstance(max_batch_size, int): + if isinstance(max_batch_size, float) and max_batch_size.is_integer(): + max_batch_size = int(max_batch_size) + else: + raise TypeError("max_batch_size must be integer >= 1") + + if max_batch_size < 1: + raise ValueError("max_batch_size must be an integer >= 1") + + if not isinstance(batch_wait_timeout_s, (float, int)): + raise TypeError("batch_wait_timeout_s must be a float >= 0") + + if batch_wait_timeout_s < 0: + raise ValueError("batch_wait_timeout_s must be a float >= 0") + + def _batch_decorator(_func): + @wraps(_func) + async def batch_wrapper(*args, **kwargs): + args = list(args) + self = extract_self_if_method_call(args, _func) + + if len(args) != 1: + raise ValueError("@serve.batch functions can only take a " + "single argument as input") + + if len(kwargs) != 0: + raise ValueError( + "@serve.batch functions do not support kwargs") + + if self is None: + # For functions, inject the batch queue as an + # attribute of the function. + batch_queue_object = _func + else: + # For methods, inject the batch queue as an + # attribute of the object. + batch_queue_object = self + + # The first time the function runs, we lazily construct the batch + # queue and inject it under a custom attribute name. On subsequent + # runs, we just get a reference to the attribute. + batch_queue_attr = f"__serve_batch_queue_{_func.__name__}" + if not hasattr(batch_queue_object, batch_queue_attr): + batch_queue = _BatchQueue(max_batch_size, batch_wait_timeout_s, + _func) + setattr(batch_queue_object, batch_queue_attr, batch_queue) + else: + batch_queue = getattr(batch_queue_object, batch_queue_attr) + + future = asyncio.get_event_loop().create_future() + batch_queue.put((self, args[0], future)) + + # This will raise if the underlying call raised an exception. + return await future + + return batch_wrapper + + # Unfortunately, this is required to handle both non-parametrized + # (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage. + # In the former case, `serve.batch` will be called with the underlying + # function as the sole argument. In the latter case, it will first be + # called with **kwargs, then the result of that call will be called + # with the underlying function as the sole argument (i.e., it must be a + # "decorator factory."). 
+ return _batch_decorator(_func) if callable(_func) else _batch_decorator diff --git a/python/ray/serve/tests/test_batching.py b/python/ray/serve/tests/test_batching.py index d4c646a5c..9d5c4d38d 100644 --- a/python/ray/serve/tests/test_batching.py +++ b/python/ray/serve/tests/test_batching.py @@ -1,3 +1,5 @@ +import asyncio + import pytest import ray @@ -58,6 +60,265 @@ def test_batching_exception(serve_instance): assert ray.get(handle.remote(temp=1)) +def test_app_level_batching(serve_instance): + client = serve_instance + + class BatchingExample: + def __init__(self): + self.count = 0 + + @serve.batch(max_batch_size=5, batch_wait_timeout_s=1) + async def handle_batch(self, requests): + self.count += 1 + batch_size = len(requests) + return [self.count] * batch_size + + async def __call__(self, request): + return await self.handle_batch(request) + + # set the max batch size + client.create_backend("counter:v11", BatchingExample) + client.create_endpoint( + "counter1", backend="counter:v11", route="/increment2") + + future_list = [] + handle = client.get_handle("counter1") + for _ in range(20): + f = handle.remote(temp=1) + future_list.append(f) + + counter_result = ray.get(future_list) + # since count is only updated per batch of queries + # If there atleast one __call__ fn call with batch size greater than 1 + # counter result will always be less than 20 + assert max(counter_result) < 20 + + +def test_app_level_batching_exception(serve_instance): + client = serve_instance + + class NoListReturned: + def __init__(self): + self.count = 0 + + @serve.batch(max_batch_size=5) + async def handle_batch(self, requests): + return len(requests) + + async def __call__(self, request): + return await self.handle_batch(request) + + # Set the max batch size. + client.create_backend("exception:v1", NoListReturned) + client.create_endpoint("exception-test", backend="exception:v1") + + handle = client.get_handle("exception-test") + with pytest.raises(ray.exceptions.RayTaskError): + assert ray.get(handle.remote(temp=1)) + + +@pytest.mark.asyncio +async def test_decorator_validation(): + @serve.batch + async def function(): + pass + + @serve.batch(max_batch_size=10, batch_wait_timeout_s=1.5) + async def function2(): + pass + + class Class(): + @serve.batch + async def method(self): + pass + + class Class2(): + @serve.batch(max_batch_size=10, batch_wait_timeout_s=1.5) + async def method(self): + pass + + with pytest.raises(TypeError, match="async def"): + + @serve.batch + def non_async_function(): + pass + + with pytest.raises(TypeError, match="async def"): + + class NotAsync: + @serve.batch + def method(self, requests): + pass + + with pytest.raises(ValueError): + + class ZeroBatch: + @serve.batch(max_batch_size=0) + async def method(self, requests): + pass + + with pytest.raises(TypeError): + + class FloatNonIntBatch: + @serve.batch(max_batch_size=1.1) + async def method(self, requests): + pass + + class FloatIntegerBatch: + @serve.batch(max_batch_size=1.0) + async def method(self, requests): + pass + + with pytest.raises(ValueError): + + class NegativeTimeout: + @serve.batch(batch_wait_timeout_s=-0.1) + async def method(self, requests): + pass + + class FloatZeroTimeout: + @serve.batch(batch_wait_timeout_s=0.0) + async def method(self, requests): + pass + + class IntZeroTimeout: + @serve.batch(batch_wait_timeout_s=0) + async def method(self, requests): + pass + + with pytest.raises(TypeError): + + class NonTimeout: + @serve.batch(batch_wait_timeout_s="a") + async def method(self, requests): + pass + + 
+@pytest.mark.asyncio +@pytest.mark.parametrize("use_class", [True, False]) +async def test_batch_size_one_long_timeout(use_class): + @serve.batch(max_batch_size=1, batch_wait_timeout_s=1000) + async def long_timeout(requests): + if "raise" in requests: + 1 / 0 + return requests + + class LongTimeout: + @serve.batch(max_batch_size=1, batch_wait_timeout_s=1000) + async def long_timeout(self, requests): + if "raise" in requests: + 1 / 0 + return requests + + cls = LongTimeout() + + async def call(arg): + if use_class: + return await cls.long_timeout(arg) + else: + return await long_timeout(arg) + + assert await call("hi") == "hi" + with pytest.raises(ZeroDivisionError): + await call("raise") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_class", [True, False]) +async def test_batch_size_multiple_zero_timeout(use_class): + @serve.batch(max_batch_size=2, batch_wait_timeout_s=0) + async def zero_timeout(requests): + await asyncio.sleep(1) + if "raise" in requests: + 1 / 0 + return requests + + class ZeroTimeout: + @serve.batch(max_batch_size=2, batch_wait_timeout_s=0) + async def zero_timeout(self, requests): + await asyncio.sleep(1) + if "raise" in requests: + 1 / 0 + return requests + + cls = ZeroTimeout() + + async def call(arg): + if use_class: + return await cls.zero_timeout(arg) + else: + return await zero_timeout(arg) + + assert await call("hi") == "hi" + with pytest.raises(ZeroDivisionError): + await call("raise") + + # Check that 2 requests will be executed together if available. + # The first should cause a size-one batch to be executed, then + # the next two should be executed together (signaled by both + # having the exception). + t1 = asyncio.get_event_loop().create_task(call("hi1")) + await asyncio.sleep(0.5) + t2 = asyncio.get_event_loop().create_task(call("hi2")) + t3 = asyncio.get_event_loop().create_task(call("raise")) + + assert await t1 == "hi1" + + with pytest.raises(ZeroDivisionError): + await t2 + with pytest.raises(ZeroDivisionError): + await t3 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_class", [True, False]) +async def test_batch_size_multiple_long_timeout(use_class): + @serve.batch(max_batch_size=3, batch_wait_timeout_s=1000) + async def long_timeout(requests): + if "raise" in requests: + 1 / 0 + return requests + + class LongTimeout: + @serve.batch(max_batch_size=3, batch_wait_timeout_s=1000) + async def long_timeout(self, requests): + if "raise" in requests: + 1 / 0 + return requests + + cls = LongTimeout() + + async def call(arg): + if use_class: + return await cls.long_timeout(arg) + else: + return await long_timeout(arg) + + t1 = asyncio.get_event_loop().create_task(call("hi1")) + t2 = asyncio.get_event_loop().create_task(call("hi2")) + done, pending = await asyncio.wait([t1, t2], timeout=0.1) + assert len(done) == 0 + t3 = asyncio.get_event_loop().create_task(call("hi3")) + done, pending = await asyncio.wait([t1, t2, t3], timeout=100) + assert set(done) == {t1, t2, t3} + assert [t1.result(), t2.result(), t3.result()] == ["hi1", "hi2", "hi3"] + + t1 = asyncio.get_event_loop().create_task(call("hi1")) + t2 = asyncio.get_event_loop().create_task(call("raise")) + done, pending = await asyncio.wait([t1, t2], timeout=0.1) + assert len(done) == 0 + t3 = asyncio.get_event_loop().create_task(call("hi3")) + done, pending = await asyncio.wait([t1, t2, t3], timeout=100) + assert set(done) == {t1, t2, t3} + assert all(isinstance(t.exception(), ZeroDivisionError) for t in done) + with pytest.raises(ZeroDivisionError): + t1.result() + with 
pytest.raises(ZeroDivisionError): + t2.result() + with pytest.raises(ZeroDivisionError): + t3.result() + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__]))
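
Usage sketch (not part of the patch above): a minimal example of the new `@serve.batch` decorator, adapted from the docstring in python/ray/serve/batching.py. It assumes a Ray build that includes this change; the names `lowercase_batch`, `Lowercaser`, `handle_single`, and `main` are illustrative only.

import asyncio
from typing import List

from ray import serve


# Free-function form: the decorated function takes a list of inputs and
# must return a list of results of the same length.
@serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
async def lowercase_batch(inputs: List[str]) -> List[str]:
    return [s.lower() for s in inputs]


# Method form: extract_self_if_method_call detects the bound `self`, so
# the batch queue is stored per instance. Callers pass a single element
# and await their own result.
class Lowercaser:
    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
    async def handle_batch(self, inputs: List[str]) -> List[str]:
        return [s.lower() for s in inputs]

    async def handle_single(self, s: str) -> str:
        return await self.handle_batch(s)


async def main():
    # Concurrent single-element calls are grouped into one underlying
    # batch call once max_batch_size elements arrive or
    # batch_wait_timeout_s elapses, whichever comes first.
    assert await asyncio.gather(
        lowercase_batch("A"), lowercase_batch("B")) == ["a", "b"]

    lc = Lowercaser()
    assert await asyncio.gather(
        lc.handle_single("X"), lc.handle_single("Y")) == ["x", "y"]


if __name__ == "__main__":
    asyncio.run(main())

With max_batch_size=4, neither gather above fills a full batch, so each call completes after the 0.1 s batch_wait_timeout_s; raising max_batch_size or lowering the timeout trades per-request latency against batch size.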