diff --git a/python/ray/serve/batching.py b/python/ray/serve/batching.py
index 4d2797d5b..7d2324cd7 100644
--- a/python/ray/serve/batching.py
+++ b/python/ray/serve/batching.py
@@ -2,11 +2,46 @@ import asyncio
 from functools import wraps
 from inspect import iscoroutinefunction
 import time
-from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TypeVar
+from dataclasses import dataclass
+
+from ray._private.signature import extract_signature, flatten_args, recover_args
 
 from ray.serve.exceptions import RayServeException
 
 
+@dataclass
+class SingleRequest:
+    self_arg: Optional[Any]
+    flattened_args: List[Any]
+    future: asyncio.Future
+
+
+def _batch_args_kwargs(
+    list_of_flattened_args: List[List[Any]],
+) -> Tuple[Tuple[Any], Dict[Any, Any]]:
+    """Batch a list of flattened args and return regular args and kwargs."""
+    # Ray's flattened arg format is a list of alternating keys and values,
+    # e.g. args=(1, 2), kwargs={"key": "val"} is flattened into
+    # [None, 1, None, 2, "key", "val"].
+    arg_lengths = {len(args) for args in list_of_flattened_args}
+    assert (
+        len(arg_lengths) == 1
+    ), "All batch requests should have the same number of parameters."
+    arg_length = arg_lengths.pop()
+
+    batched_flattened_args = []
+    for idx in range(arg_length):
+        if idx % 2 == 0:
+            batched_flattened_args.append(list_of_flattened_args[0][idx])
+        else:
+            batched_flattened_args.append(
+                [item[idx] for item in list_of_flattened_args]
+            )
+
+    return recover_args(batched_flattened_args)
+
+
 class _BatchQueue:
     def __init__(
         self,
@@ -30,7 +65,7 @@ class _BatchQueue:
             handle_batch_func(Optional[Callable]): callback to run in the
                 background to handle batches if provided.
         """
-        self.queue = asyncio.Queue()
+        self.queue: asyncio.Queue[SingleRequest] = asyncio.Queue()
         self.full_batch_event = asyncio.Event()
         self.max_batch_size = max_batch_size
         self.timeout_s = timeout_s
@@ -41,7 +76,7 @@ class _BatchQueue:
                 self._handle_batches(handle_batch_func)
             )
 
-    def put(self, request: Tuple[Any, asyncio.Future]) -> None:
+    def put(self, request: SingleRequest) -> None:
         self.queue.put_nowait(request)
         # Signal when the full batch is ready. The event will be reset
         # in wait_for_batch.
@@ -94,19 +129,19 @@
 
     async def _handle_batches(self, func):
         while True:
-            batch = await self.wait_for_batch()
+            batch: List[SingleRequest] = await self.wait_for_batch()
             assert len(batch) > 0
-            self_arg = batch[0][0]
-            args = [item[1] for item in batch]
-            futures = [item[2] for item in batch]
+            self_arg = batch[0].self_arg
+            args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
+            futures = [item.future for item in batch]
 
             try:
                 # Method call.
                 if self_arg is not None:
-                    results = await func(self_arg, args)
+                    results = await func(self_arg, *args, **kwargs)
                 # Normal function call.
                 else:
-                    results = await func(args)
+                    results = await func(*args, **kwargs)
 
                 if len(results) != len(batch):
                     raise RayServeException(
@@ -150,7 +185,7 @@ def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[obj
 
     if method:
         wrapped = getattr(method, "__wrapped__", False)
         if wrapped and wrapped == func:
-            return args.pop(0)
+            return args[0]
 
     return None
@@ -230,16 +265,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
     def _batch_decorator(_func):
         @wraps(_func)
         async def batch_wrapper(*args, **kwargs):
-            args = list(args)
             self = extract_self_if_method_call(args, _func)
-
-            if len(args) != 1:
-                raise ValueError(
-                    "@serve.batch functions can only take a " "single argument as input"
-                )
-
-            if len(kwargs) != 0:
-                raise ValueError("@serve.batch functions do not support kwargs")
+            flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
 
             if self is None:
                 # For functions, inject the batch queue as an
@@ -249,6 +276,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
                 # For methods, inject the batch queue as an
                 # attribute of the object.
                 batch_queue_object = self
+                # Trim the self argument from the flattened args for methods.
+                flattened_args = flattened_args[2:]
 
             # The first time the function runs, we lazily construct the batch
             # queue and inject it under a custom attribute name. On subsequent
@@ -261,7 +290,7 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
             batch_queue = getattr(batch_queue_object, batch_queue_attr)
 
             future = asyncio.get_event_loop().create_future()
-            batch_queue.put((self, args[0], future))
+            batch_queue.put(SingleRequest(self, flattened_args, future))
 
             # This will raise if the underlying call raised an exception.
             return await future
diff --git a/python/ray/serve/tests/test_batching.py b/python/ray/serve/tests/test_batching.py
index 4bf351cb7..20a005b90 100644
--- a/python/ray/serve/tests/test_batching.py
+++ b/python/ray/serve/tests/test_batching.py
@@ -259,6 +259,39 @@ async def test_batch_size_multiple_long_timeout(use_class):
     t3.result()
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("mode", ["args", "kwargs", "mixed", "out-of-order"])
+@pytest.mark.parametrize("use_class", [True, False])
+async def test_batch_args_kwargs(mode, use_class):
+    if use_class:
+
+        class MultipleArgs:
+            @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
+            async def method(self, key1, key2):
+                return [(key1[i], key2[i]) for i in range(len(key1))]
+
+        instance = MultipleArgs()
+        func = instance.method
+
+    else:
+
+        @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
+        async def func(key1, key2):
+            return [(key1[i], key2[i]) for i in range(len(key1))]
+
+    if mode == "args":
+        coros = [func("hi1", "hi2"), func("hi3", "hi4")]
+    elif mode == "kwargs":
+        coros = [func(key1="hi1", key2="hi2"), func(key1="hi3", key2="hi4")]
+    elif mode == "mixed":
+        coros = [func("hi1", key2="hi2"), func("hi3", key2="hi4")]
+    elif mode == "out-of-order":
+        coros = [func(key2="hi2", key1="hi1"), func(key2="hi4", key1="hi3")]
+
+    result = await asyncio.gather(*coros)
+    assert result == [("hi1", "hi2"), ("hi3", "hi4")]
+
+
 if __name__ == "__main__":
     import sys
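
Reviewer note: the transposition that _batch_args_kwargs performs is easier to see as a concrete round trip. The sketch below is a self-contained approximation, not the code under review: _flatten, _recover, and batch_args_kwargs are hypothetical stand-ins for ray._private.signature.flatten_args / recover_args and the _batch_args_kwargs added above (the real helpers additionally validate the call against the wrapped function's signature).

    from typing import Any, Dict, List, Tuple


    def _flatten(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> List[Any]:
        # Positional args get a None key; keyword args keep their string key.
        flat: List[Any] = []
        for value in args:
            flat += [None, value]
        for key, value in kwargs.items():
            flat += [key, value]
        return flat


    def _recover(flat: List[Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # Inverse of _flatten: even indices are keys, odd indices are values.
        args, kwargs = [], {}
        for key, value in zip(flat[::2], flat[1::2]):
            if key is None:
                args.append(value)
            else:
                kwargs[key] = value
        return tuple(args), kwargs


    def batch_args_kwargs(flat_list: List[List[Any]]):
        # Keys (even indices) are identical across requests, so they are taken
        # from the first request; values (odd indices) are transposed into one
        # list per parameter. This is the core of _batch_args_kwargs above.
        batched = [
            flat_list[0][i] if i % 2 == 0 else [flat[i] for flat in flat_list]
            for i in range(len(flat_list[0]))
        ]
        return _recover(batched)


    flat1 = _flatten(("hi1",), {"key2": "hi2"})  # [None, "hi1", "key2", "hi2"]
    flat2 = _flatten(("hi3",), {"key2": "hi4"})  # [None, "hi3", "key2", "hi4"]
    args, kwargs = batch_args_kwargs([flat1, flat2])
    assert args == (["hi1", "hi3"],) and kwargs == {"key2": ["hi2", "hi4"]}

With two queued calls func("hi1", key2="hi2") and func("hi3", key2="hi4"), the batched function is invoked once with key1=["hi1", "hi3"] and key2=["hi2", "hi4"], one list per parameter, which is exactly what test_batch_args_kwargs asserts across the args/kwargs/mixed/out-of-order modes.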