[Serve] [1/3 Wrappers] Allow @serve.batch to accept args and kwargs (#22913)

Simon Mo 2022-03-09 09:15:57 -08:00 committed by GitHub
parent 15601ed79b
commit 77ead01b65
2 changed files with 82 additions and 20 deletions

python/ray/serve/batching.py

@@ -2,11 +2,46 @@ import asyncio
 from functools import wraps
 from inspect import iscoroutinefunction
 import time
-from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TypeVar
+from dataclasses import dataclass
+
+from ray._private.signature import extract_signature, flatten_args, recover_args
 from ray.serve.exceptions import RayServeException
+
+
+@dataclass
+class SingleRequest:
+    self_arg: Optional[Any]
+    flattened_args: List[Any]
+    future: asyncio.Future
+
+
+def _batch_args_kwargs(
+    list_of_flattened_args: List[List[Any]],
+) -> Tuple[Tuple[Any], Dict[Any, Any]]:
+    """Batch a list of flattened args and return regular args and kwargs."""
+    # Ray's flattened arg format is a list of alternating keys and values,
+    # e.g. args=(1, 2), kwargs={"key": "val"} is turned into
+    # [None, 1, None, 2, "key", "val"].
+    arg_lengths = {len(args) for args in list_of_flattened_args}
+    assert (
+        len(arg_lengths) == 1
+    ), "All batch requests should have the same number of parameters."
+    arg_length = arg_lengths.pop()
+
+    batched_flattened_args = []
+    for idx in range(arg_length):
+        if idx % 2 == 0:
+            batched_flattened_args.append(list_of_flattened_args[0][idx])
+        else:
+            batched_flattened_args.append(
+                [item[idx] for item in list_of_flattened_args]
+            )
+
+    return recover_args(batched_flattened_args)
+
+
 class _BatchQueue:
     def __init__(
         self,
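Note: a standalone sketch of the flattening scheme the new helper relies on. toy_recover_args below is a hypothetical stand-in for ray._private.signature.recover_args, shown only to make the round trip concrete:

from typing import Any, Dict, List, Tuple

def toy_recover_args(flattened: List[Any]) -> Tuple[Tuple, Dict]:
    # Even indices hold keys (None for positional args), odd indices values.
    args, kwargs = [], {}
    for key, value in zip(flattened[::2], flattened[1::2]):
        if key is None:
            args.append(value)
        else:
            kwargs[key] = value
    return tuple(args), kwargs

# Two calls, func("hi1", key2="hi2") and func("hi3", key2="hi4"),
# each already flattened into the alternating key/value format.
requests = [
    [None, "hi1", "key2", "hi2"],
    [None, "hi3", "key2", "hi4"],
]

# Keep keys from the first request; collect each value position across
# requests into a list (mirroring _batch_args_kwargs above).
batched = []
for idx in range(len(requests[0])):
    if idx % 2 == 0:
        batched.append(requests[0][idx])
    else:
        batched.append([req[idx] for req in requests])

print(toy_recover_args(batched))
# ((['hi1', 'hi3'],), {'key2': ['hi2', 'hi4']})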
@@ -30,7 +65,7 @@ class _BatchQueue:
             handle_batch_func(Optional[Callable]): callback to run in the
                 background to handle batches if provided.
         """
-        self.queue = asyncio.Queue()
+        self.queue: asyncio.Queue[SingleRequest] = asyncio.Queue()
         self.full_batch_event = asyncio.Event()
         self.max_batch_size = max_batch_size
         self.timeout_s = timeout_s
@@ -41,7 +76,7 @@ class _BatchQueue:
             self._handle_batches(handle_batch_func)
         )
 
-    def put(self, request: Tuple[Any, asyncio.Future]) -> None:
+    def put(self, request: Tuple[SingleRequest, asyncio.Future]) -> None:
         self.queue.put_nowait(request)
         # Signal when the full batch is ready. The event will be reset
         # in wait_for_batch.
@@ -94,19 +129,19 @@ class _BatchQueue:
     async def _handle_batches(self, func):
         while True:
-            batch = await self.wait_for_batch()
+            batch: List[SingleRequest] = await self.wait_for_batch()
             assert len(batch) > 0
-            self_arg = batch[0][0]
-            args = [item[1] for item in batch]
-            futures = [item[2] for item in batch]
+            self_arg = batch[0].self_arg
+            args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
+            futures = [item.future for item in batch]
 
             try:
                 # Method call.
                 if self_arg is not None:
-                    results = await func(self_arg, args)
+                    results = await func(self_arg, *args, **kwargs)
                 # Normal function call.
                 else:
-                    results = await func(args)
+                    results = await func(*args, **kwargs)
 
                 if len(results) != len(batch):
                     raise RayServeException(
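Note: with this hunk, a batched callable receives each parameter as a list holding one element per queued request, and it must return exactly one result per request. A minimal illustrative handler (hypothetical name and values, not part of this diff):

from ray import serve

@serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
async def add(lefts, rights):
    # For two concurrent calls add(1, 2) and add(3, rights=4),
    # lefts == [1, 3] and rights == [2, 4].
    return [left + right for left, right in zip(lefts, rights)]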
@@ -150,7 +185,7 @@ def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
     if method:
         wrapped = getattr(method, "__wrapped__", False)
         if wrapped and wrapped == func:
-            return args.pop(0)
+            return args[0]
 
     return None
@@ -230,16 +265,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
     def _batch_decorator(_func):
         @wraps(_func)
         async def batch_wrapper(*args, **kwargs):
-            args = list(args)
             self = extract_self_if_method_call(args, _func)
+            flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
 
-            if len(args) != 1:
-                raise ValueError(
-                    "@serve.batch functions can only take a "
-                    "single argument as input"
-                )
-
-            if len(kwargs) != 0:
-                raise ValueError("@serve.batch functions do not support kwargs")
 
             if self is None:
                 # For functions, inject the batch queue as an
@@ -249,6 +276,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
                 # For methods, inject the batch queue as an
                 # attribute of the object.
                 batch_queue_object = self
+                # Trim the self argument from methods.
+                flattened_args = flattened_args[2:]
 
             # The first time the function runs, we lazily construct the batch
             # queue and inject it under a custom attribute name. On subsequent
@@ -261,7 +290,7 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
             batch_queue = getattr(batch_queue_object, batch_queue_attr)
 
             future = asyncio.get_event_loop().create_future()
-            batch_queue.put((self, args[0], future))
+            batch_queue.put(SingleRequest(self, flattened_args, future))
 
             # This will raise if the underlying call raised an exception.
             return await future
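Note: batch_wrapper now flattens each call's args/kwargs, enqueues a SingleRequest carrying a per-call future, and awaits that future until _handle_batches fulfills it. A minimal sketch of that queue-and-future pattern in plain asyncio (illustrative only; Ray's _BatchQueue adds batching and timeouts on top):

import asyncio

async def worker(queue: asyncio.Queue) -> None:
    while True:
        value, future = await queue.get()
        future.set_result(value * 2)  # stand-in for running the batched function

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(worker(queue))
    future = asyncio.get_running_loop().create_future()
    queue.put_nowait((21, future))
    print(await future)  # 42

asyncio.run(main())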

python/ray/serve/tests/test_batching.py

@@ -259,6 +259,39 @@ async def test_batch_size_multiple_long_timeout(use_class):
     t3.result()
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("mode", ["args", "kwargs", "mixed", "out-of-order"])
+@pytest.mark.parametrize("use_class", [True, False])
+async def test_batch_args_kwargs(mode, use_class):
+    if use_class:
+
+        class MultipleArgs:
+            @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
+            async def method(self, key1, key2):
+                return [(key1[i], key2[i]) for i in range(len(key1))]
+
+        instance = MultipleArgs()
+        func = instance.method
+
+    else:
+
+        @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
+        async def func(key1, key2):
+            return [(key1[i], key2[i]) for i in range(len(key1))]
+
+    if mode == "args":
+        coros = [func("hi1", "hi2"), func("hi3", "hi4")]
+    elif mode == "kwargs":
+        coros = [func(key1="hi1", key2="hi2"), func(key1="hi3", key2="hi4")]
+    elif mode == "mixed":
+        coros = [func("hi1", key2="hi2"), func("hi3", key2="hi4")]
+    elif mode == "out-of-order":
+        coros = [func(key2="hi2", key1="hi1"), func(key2="hi4", key1="hi3")]
+
+    result = await asyncio.gather(*coros)
+    assert result == [("hi1", "hi2"), ("hi3", "hi4")]
+
+
 if __name__ == "__main__":
     import sys