mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Serve] [1/3 Wrappers] Allow @serve.batch
to accept args and kwargs (#22913)
This commit is contained in:
parent
15601ed79b
commit
77ead01b65
2 changed files with 82 additions and 20 deletions
|
@ -2,11 +2,46 @@ import asyncio
|
|||
from functools import wraps
|
||||
from inspect import iscoroutinefunction
|
||||
import time
|
||||
from typing import Any, Callable, List, Optional, overload, Tuple, TypeVar
|
||||
from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TypeVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
from ray._private.signature import extract_signature, flatten_args, recover_args
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
|
||||
@dataclass
class SingleRequest:
    """One enqueued call waiting to be executed as part of a batch."""

    # The bound instance when the batched callable is a method; None for
    # plain functions.
    self_arg: Optional[Any]
    # The call's arguments in Ray's flattened [key, value, ...] format
    # (as produced by ray._private.signature.flatten_args).
    flattened_args: List[Any]
    # Resolved with this request's slice of the batched result (or the
    # batch handler's exception).
    future: asyncio.Future
|
||||
|
||||
|
||||
def _batch_args_kwargs(
    list_of_flattened_args: List[List[Any]],
) -> Tuple[Tuple[Any], Dict[Any, Any]]:
    """Batch a list of flatten args and returns regular args and kwargs"""
    # Ray's flattened-arg format is a list alternating keys and values,
    # e.g. args=(1, 2), kwargs={"key": "val"} becomes
    # [None, 1, None, 2, "key", "val"].
    unique_lengths = {len(flat) for flat in list_of_flattened_args}
    assert (
        len(unique_lengths) == 1
    ), "All batch requests should have the same number of parameters."
    num_slots = unique_lengths.pop()

    # Even slots hold the key, which is identical across requests, so it is
    # copied from the first one; odd slots hold the value, which is gathered
    # across the whole batch into a list.
    first_request = list_of_flattened_args[0]
    batched_flattened_args = [
        first_request[slot]
        if slot % 2 == 0
        else [flat[slot] for flat in list_of_flattened_args]
        for slot in range(num_slots)
    ]

    return recover_args(batched_flattened_args)
|
||||
|
||||
|
||||
class _BatchQueue:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -30,7 +65,7 @@ class _BatchQueue:
|
|||
handle_batch_func(Optional[Callable]): callback to run in the
|
||||
background to handle batches if provided.
|
||||
"""
|
||||
self.queue = asyncio.Queue()
|
||||
self.queue: asyncio.Queue[SingleRequest] = asyncio.Queue()
|
||||
self.full_batch_event = asyncio.Event()
|
||||
self.max_batch_size = max_batch_size
|
||||
self.timeout_s = timeout_s
|
||||
|
@ -41,7 +76,7 @@ class _BatchQueue:
|
|||
self._handle_batches(handle_batch_func)
|
||||
)
|
||||
|
||||
def put(self, request: Tuple[Any, asyncio.Future]) -> None:
|
||||
def put(self, request: Tuple[SingleRequest, asyncio.Future]) -> None:
|
||||
self.queue.put_nowait(request)
|
||||
# Signal when the full batch is ready. The event will be reset
|
||||
# in wait_for_batch.
|
||||
|
@ -94,19 +129,19 @@ class _BatchQueue:
|
|||
|
||||
async def _handle_batches(self, func):
|
||||
while True:
|
||||
batch = await self.wait_for_batch()
|
||||
batch: List[SingleRequest] = await self.wait_for_batch()
|
||||
assert len(batch) > 0
|
||||
self_arg = batch[0][0]
|
||||
args = [item[1] for item in batch]
|
||||
futures = [item[2] for item in batch]
|
||||
self_arg = batch[0].self_arg
|
||||
args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
|
||||
futures = [item.future for item in batch]
|
||||
|
||||
try:
|
||||
# Method call.
|
||||
if self_arg is not None:
|
||||
results = await func(self_arg, args)
|
||||
results = await func(self_arg, *args, **kwargs)
|
||||
# Normal function call.
|
||||
else:
|
||||
results = await func(args)
|
||||
results = await func(*args, **kwargs)
|
||||
|
||||
if len(results) != len(batch):
|
||||
raise RayServeException(
|
||||
|
@ -150,7 +185,7 @@ def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[obj
|
|||
if method:
|
||||
wrapped = getattr(method, "__wrapped__", False)
|
||||
if wrapped and wrapped == func:
|
||||
return args.pop(0)
|
||||
return args[0]
|
||||
|
||||
return None
|
||||
|
||||
|
@ -230,16 +265,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
|
|||
def _batch_decorator(_func):
|
||||
@wraps(_func)
|
||||
async def batch_wrapper(*args, **kwargs):
|
||||
args = list(args)
|
||||
self = extract_self_if_method_call(args, _func)
|
||||
|
||||
if len(args) != 1:
|
||||
raise ValueError(
|
||||
"@serve.batch functions can only take a " "single argument as input"
|
||||
)
|
||||
|
||||
if len(kwargs) != 0:
|
||||
raise ValueError("@serve.batch functions do not support kwargs")
|
||||
flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)
|
||||
|
||||
if self is None:
|
||||
# For functions, inject the batch queue as an
|
||||
|
@ -249,6 +276,8 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
|
|||
# For methods, inject the batch queue as an
|
||||
# attribute of the object.
|
||||
batch_queue_object = self
|
||||
# Trim the self argument from methods
|
||||
flattened_args = flattened_args[2:]
|
||||
|
||||
# The first time the function runs, we lazily construct the batch
|
||||
# queue and inject it under a custom attribute name. On subsequent
|
||||
|
@ -261,7 +290,7 @@ def batch(_func=None, max_batch_size=10, batch_wait_timeout_s=0.0):
|
|||
batch_queue = getattr(batch_queue_object, batch_queue_attr)
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
batch_queue.put((self, args[0], future))
|
||||
batch_queue.put(SingleRequest(self, flattened_args, future))
|
||||
|
||||
# This will raise if the underlying call raised an exception.
|
||||
return await future
|
||||
|
|
|
@ -259,6 +259,39 @@ async def test_batch_size_multiple_long_timeout(use_class):
|
|||
t3.result()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("mode", ["args", "kwargs", "mixed", "out-of-order"])
@pytest.mark.parametrize("use_class", [True, False])
async def test_batch_args_kwargs(mode, use_class):
    """@serve.batch should accept positional and keyword args in any mix."""
    if use_class:

        class MultipleArgs:
            @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
            async def method(self, key1, key2):
                return [(key1[i], key2[i]) for i in range(len(key1))]

        instance = MultipleArgs()
        func = instance.method

    else:

        @serve.batch(max_batch_size=2, batch_wait_timeout_s=1000)
        async def func(key1, key2):
            return [(key1[i], key2[i]) for i in range(len(key1))]

    # Each mode supplies the same two logical calls, expressed as
    # (positional_args, keyword_args) pairs.
    calls_by_mode = {
        "args": [(("hi1", "hi2"), {}), (("hi3", "hi4"), {})],
        "kwargs": [
            ((), {"key1": "hi1", "key2": "hi2"}),
            ((), {"key1": "hi3", "key2": "hi4"}),
        ],
        "mixed": [
            (("hi1",), {"key2": "hi2"}),
            (("hi3",), {"key2": "hi4"}),
        ],
        "out-of-order": [
            ((), {"key2": "hi2", "key1": "hi1"}),
            ((), {"key2": "hi4", "key1": "hi3"}),
        ],
    }
    coros = [func(*call_args, **call_kwargs) for call_args, call_kwargs in calls_by_mode[mode]]

    result = await asyncio.gather(*coros)
    assert result == [("hi1", "hi2"), ("hi3", "hi4")]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue