mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Serve] Improve error message when result is not a list (#8378)
This commit is contained in:
parent
3a25f5f5b4
commit
501b936114
4 changed files with 59 additions and 6 deletions
|
@ -1,5 +1,6 @@
|
|||
import traceback
|
||||
import inspect
|
||||
from collections.abc import Iterable
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
|
@ -195,8 +196,18 @@ class RayServeWorker:
|
|||
self.request_counter.add(batch_size)
|
||||
result_list = await call_method(*arg_list, **kwargs_list)
|
||||
|
||||
if (not isinstance(result_list,
|
||||
list)) or (len(result_list) != batch_size):
|
||||
if not isinstance(result_list, Iterable) or isinstance(
|
||||
result_list, (dict, set)):
|
||||
error_message = ("RayServe expects an ordered iterable object "
|
||||
"but the worker returned a {}".format(
|
||||
type(result_list)))
|
||||
raise RayServeException(error_message)
|
||||
|
||||
# Normalize the result into a list type. This operation is fast
|
||||
# in Python because it doesn't copy anything.
|
||||
result_list = list(result_list)
|
||||
|
||||
if (len(result_list) != batch_size):
|
||||
error_message = ("Worker doesn't preserve batch size. The "
|
||||
"input has length {} but the returned list "
|
||||
"has length {}. Please return a list of "
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
|
@ -9,6 +10,7 @@ from ray.serve.policy import RoundRobinPolicyQueueActor
|
|||
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
@ -151,6 +153,15 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
|||
def b(self, _):
|
||||
return ["b-{}".format(i) for i in range(serve.context.batch_size)]
|
||||
|
||||
def error_different_size(self, _):
|
||||
return [""] * (serve.context.batch_size * 2)
|
||||
|
||||
def error_non_iterable(self, _):
|
||||
return 42
|
||||
|
||||
def return_np_array(self, _):
|
||||
return np.array([1] * serve.context.batch_size).astype(np.int32)
|
||||
|
||||
CONSUMER_NAME = "runner"
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
|
@ -163,10 +174,12 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
|||
"max_batch_size": 10
|
||||
}, accepts_batches=True))
|
||||
|
||||
a_query_param = RequestMetadata(
|
||||
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
|
||||
b_query_param = RequestMetadata(
|
||||
PRODUCER_NAME, context.TaskContext.Python, call_method="b")
|
||||
def make_request_param(call_method):
|
||||
return RequestMetadata(
|
||||
PRODUCER_NAME, context.TaskContext.Python, call_method=call_method)
|
||||
|
||||
a_query_param = make_request_param("a")
|
||||
b_query_param = make_request_param("b")
|
||||
|
||||
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
|
||||
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
|
||||
|
@ -175,3 +188,15 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
|||
|
||||
gathered = await asyncio.gather(*futures)
|
||||
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
|
||||
|
||||
with pytest.raises(RayServeException, match="doesn't preserve batch size"):
|
||||
different_size = make_request_param("error_different_size")
|
||||
await q.enqueue_request.remote(different_size)
|
||||
|
||||
with pytest.raises(RayServeException, match="iterable"):
|
||||
non_iterable = make_request_param("error_non_iterable")
|
||||
await q.enqueue_request.remote(non_iterable)
|
||||
|
||||
np_array = make_request_param("return_np_array")
|
||||
result_np_value = await q.enqueue_request.remote(np_array)
|
||||
assert isinstance(result_np_value, np.int32)
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.serve.utils import ServeEncoder
|
||||
|
||||
|
||||
|
@ -7,3 +9,14 @@ def test_bytes_encoder():
|
|||
data_before = {"inp": {"nest": b"bytes"}}
|
||||
data_after = {"inp": {"nest": "bytes"}}
|
||||
assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after
|
||||
|
||||
|
||||
def test_numpy_encoding():
|
||||
data = [1, 2]
|
||||
floats = np.array(data).astype(np.float32)
|
||||
ints = floats.astype(np.int32)
|
||||
uints = floats.astype(np.uint32)
|
||||
|
||||
assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
|
||||
assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
|
||||
assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data
|
||||
|
|
|
@ -69,6 +69,10 @@ class ServeEncoder(json.JSONEncoder):
|
|||
if isinstance(o, Exception):
|
||||
return str(o)
|
||||
if isinstance(o, np.ndarray):
|
||||
if o.dtype.kind == "f": # floats
|
||||
o = o.astype(float)
|
||||
if o.dtype.kind in {"i", "u"}: # signed and unsigned integers.
|
||||
o = o.astype(int)
|
||||
return o.tolist()
|
||||
return super().default(o)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue