[Serve] [1/3 Wrappers] Allow @serve.batch to accept args and kwargs (#22913)

This commit is contained in:
Simon Mo 2022-03-09 09:15:57 -08:00 committed by GitHub
parent 15601ed79b
commit 77ead01b65
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 20 deletions

View file

@ -2,11 +2,46 @@ import asyncio
from functools import wraps
from inspect import iscoroutinefunction
import time
from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar
from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TypeVar
from dataclasses import dataclass
from ray._private.signature import extract_signature, flatten_args, recover_args
from ray.serve.exceptions import RayServeException
@dataclass
class SingleRequest:
    """One enqueued call waiting to be executed as part of a batch."""

    # Bound instance when the wrapped callable is a method; None for plain
    # functions (checked via `self_arg is not None` by the batch handler).
    self_arg: Optional[Any]
    # Call arguments in Ray's flattened alternating key/value form, e.g.
    # args=(1, 2), kwargs={"key": "val"} -> [None, 1, None, 2, "key", "val"].
    flattened_args: List[Any]
    # Resolved with this request's individual result (or exception) once the
    # whole batch has been processed.
    future: asyncio.Future
def _batch_args_kwargs(
    list_of_flattened_args: List[List[Any]],
) -> Tuple[Tuple[Any], Dict[Any, Any]]:
    """Merge many flattened-arg lists into a single batched (args, kwargs).

    Ray's flattened argument format alternates keys and values, e.g.
    args=(1, 2), kwargs={"key": "val"} is represented as
    [None, 1, None, 2, "key", "val"]. Even positions hold keys, which are
    identical across requests in a batch, so the first request's key is
    reused; odd positions hold values, which are collected into one list
    per parameter slot (one entry per request).
    """
    lengths = {len(flattened) for flattened in list_of_flattened_args}
    assert (
        len(lengths) == 1
    ), "All batch requests should have the same number of parameters."
    first_request = list_of_flattened_args[0]
    merged = [
        first_request[pos]
        if pos % 2 == 0
        else [request[pos] for request in list_of_flattened_args]
        for pos in range(lengths.pop())
    ]
    return recover_args(merged)
class _BatchQueue:
def __init__(
self,
@ -30,7 +65,7 @@ class _BatchQueue:
handle_batch_func(Optional[Callable]): callback to run in the
background to handle batches if provided.
"""
self.queue = asyncio.Queue()
self.queue: asyncio.Queue[SingleRequest] = asyncio.Queue()
self.full_batch_event = asyncio.Event()
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
@ -41,7 +76,7 @@ class _BatchQueue:
self._handle_batches(handle_batch_func)
)
def put(self, request: Tuple[Any, asyncio.Future]) -> None:
def put(self, request: Tuple[SingleRequest, asyncio.Future]) -> None:
self.queue.put_nowait(request)
# Signal when the full batch is ready. The event will be reset
# in wait_for_batch.
@ -94,19 +129,19 @@ class _BatchQueue:
async def _handle_batches(self, func):
while True:
batch = await self.wait_for_batch()
batch: List[SingleRequest] = await self.wait_for_batch()
assert len(batch) > 0
self_arg = batch[0][0]
args = [item[1] for item in batch]
futures = [item[2] for item in batch]
self_arg = batch[0].self_arg
args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
futures = [item.future for item in batch]
try:
# Method call.
if self_arg is not None:
results = await func(self_arg, args)
results = await func(self_arg, *args, **kwargs)
# Normal function call.
else:
results = await func(args)
results = await func(*args, **kwargs)
if len(results) != len(batch):
raise RayServeException(
@ -150,7 +185,7 @@ def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[obj
if method:
wrapped = getattr(method, "__wrapped__", False)
if wrapped and wrapped == func:
return args.pop(0)
return args[0]
return None
@ -230,16 +265,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
def _batch_decorator(_func):
@wraps(_func)
async def batch_wrapper(*args, **kwargs):
args = list(args)
self = extract_self_if_method_call(args, _func)
if len(args) != 1:
raise ValueError(
"@serve.batch functions can only take a " "single argument as input"
)
if len(kwargs) != 0:
raise ValueError("@serve.batch functions do not support kwargs")
flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
if self is None:
# For functions, inject the batch queue as an
@ -249,6 +276,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
# For methods, inject the batch queue as an
# attribute of the object.
batch_queue_object = self
# Trim the self argument from methods
flattened_args = flattened_args[2:]
# The first time the function runs, we lazily construct the batch
# queue and inject it under a custom attribute name. On subsequent
@ -261,7 +290,7 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
batch_queue = getattr(batch_queue_object, batch_queue_attr)
future = asyncio.get_event_loop().create_future()
batch_queue.put((self, args[0], future))
batch_queue.put(SingleRequest(self, flattened_args, future))
# This will raise if the underlying call raised an exception.
return await future

View file

@ -259,6 +259,39 @@ async def test_batch_size_multiple_long_timeout(use_class):
t3.result()
@pytest.mark.asyncio
@pytest.mark.parametrize("mode", ["args", "kwargs", "mixed", "out-of-order"])
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_args_kwargs(mode, use_class):
    """@serve.batch accepts positional args, kwargs, and any mix of the two.

    Covers both the free-function and the method form of the decorator, and
    verifies that keyword ordering does not affect which values land in
    which parameter.
    """
    if use_class:
        # Method form: decorator applied to an instance method.
        class MultipleArgs:
            @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
            async def method(self, key1, key2):
                # Inside the batched callable each parameter is a list with
                # one entry per request in the batch.
                return [(key1[i], key2[i]) for i in range(len(key1))]

        instance = MultipleArgs()
        func = instance.method
    else:
        # Function form of the decorator.
        @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
        async def func(key1, key2):
            return [(key1[i], key2[i]) for i in range(len(key1))]

    # Two concurrent calls exactly fill max_batch_size=2, so the batch fires
    # immediately and the very long batch_wait_timeout_s never elapses.
    if mode == "args":
        coros = [func("hi1", "hi2"), func("hi3", "hi4")]
    elif mode == "kwargs":
        coros = [func(key1="hi1", key2="hi2"), func(key1="hi3", key2="hi4")]
    elif mode == "mixed":
        coros = [func("hi1", key2="hi2"), func("hi3", key2="hi4")]
    elif mode == "out-of-order":
        coros = [func(key2="hi2", key1="hi1"), func(key2="hi4", key1="hi3")]

    result = await asyncio.gather(*coros)
    # Each caller gets back its own (key1, key2) pair, in call order.
    assert result == [("hi1", "hi2"), ("hi3", "hi4")]
if __name__ == "__main__":
import sys