[Serve] Improve error message when result is not a list (#8378)
parent 3a25f5f5b4
commit 501b936114

4 changed files with 59 additions and 6 deletions
backend_worker.py
@@ -1,5 +1,6 @@
 import traceback
 import inspect
+from collections.abc import Iterable
 
 import ray
 from ray import serve
@@ -195,8 +196,18 @@ class RayServeWorker:
         self.request_counter.add(batch_size)
         result_list = await call_method(*arg_list, **kwargs_list)
 
-        if (not isinstance(result_list,
-                           list)) or (len(result_list) != batch_size):
+        if not isinstance(result_list, Iterable) or isinstance(
+                result_list, (dict, set)):
+            error_message = ("RayServe expects an ordered iterable object "
+                             "but the worker returned a {}".format(
+                                 type(result_list)))
+            raise RayServeException(error_message)
+
+        # Normalize the result into a list type. This operation is fast
+        # in Python because it doesn't copy anything.
+        result_list = list(result_list)
+
+        if (len(result_list) != batch_size):
             error_message = ("Worker doesn't preserve batch size. The "
                              "input has length {} but the returned list "
                              "has length {}. Please return a list of "
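Read in isolation, the new validation logic is easy to exercise. Below is a minimal, self-contained sketch of the check added above; validate_batch_result and the stubbed RayServeException are names local to this sketch, not Serve's API:

from collections.abc import Iterable


class RayServeException(Exception):  # stand-in for ray.serve.exceptions
    pass


def validate_batch_result(result_list, batch_size):
    # dict and set satisfy the Iterable check, but their iteration order has
    # no relationship to the input batch order, so reject them explicitly.
    if not isinstance(result_list, Iterable) or isinstance(
            result_list, (dict, set)):
        raise RayServeException(
            "RayServe expects an ordered iterable object but the worker "
            "returned a {}".format(type(result_list)))

    # list() over an iterable stores references; it does not copy elements.
    result_list = list(result_list)

    if len(result_list) != batch_size:
        raise RayServeException(
            "Worker doesn't preserve batch size. The input has length {} "
            "but the returned list has length {}.".format(
                batch_size, len(result_list)))
    return result_list


assert validate_batch_result((1, 2, 3), 3) == [1, 2, 3]  # tuples now pass

try:
    validate_batch_result({"a": 1}, 1)                   # dicts are rejected
except RayServeException as e:
    print(e)

Splitting the check in two also changes behavior: the old isinstance(result_list, list) test rejected tuples, generators, and NumPy arrays outright, while the new version accepts any ordered iterable, normalizes it, and only then verifies the batch size.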
test_backend_worker.py
@@ -1,6 +1,7 @@
 import asyncio
 
 import pytest
+import numpy as np
 
 import ray
 from ray import serve
@@ -9,6 +10,7 @@ from ray.serve.policy import RoundRobinPolicyQueueActor
 from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
 from ray.serve.request_params import RequestMetadata
 from ray.serve.config import BackendConfig
+from ray.serve.exceptions import RayServeException
 
 pytestmark = pytest.mark.asyncio
 
@@ -151,6 +153,15 @@ async def test_task_runner_custom_method_batch(serve_instance):
         def b(self, _):
             return ["b-{}".format(i) for i in range(serve.context.batch_size)]
 
+        def error_different_size(self, _):
+            return [""] * (serve.context.batch_size * 2)
+
+        def error_non_iterable(self, _):
+            return 42
+
+        def return_np_array(self, _):
+            return np.array([1] * serve.context.batch_size).astype(np.int32)
+
     CONSUMER_NAME = "runner"
     PRODUCER_NAME = "producer"
 
@@ -163,10 +174,12 @@ async def test_task_runner_custom_method_batch(serve_instance):
             "max_batch_size": 10
         }, accepts_batches=True))
 
-    a_query_param = RequestMetadata(
-        PRODUCER_NAME, context.TaskContext.Python, call_method="a")
-    b_query_param = RequestMetadata(
-        PRODUCER_NAME, context.TaskContext.Python, call_method="b")
+    def make_request_param(call_method):
+        return RequestMetadata(
+            PRODUCER_NAME, context.TaskContext.Python, call_method=call_method)
+
+    a_query_param = make_request_param("a")
+    b_query_param = make_request_param("b")
 
     futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
     futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
@@ -175,3 +188,15 @@ async def test_task_runner_custom_method_batch(serve_instance):
 
     gathered = await asyncio.gather(*futures)
     assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
+
+    with pytest.raises(RayServeException, match="doesn't preserve batch size"):
+        different_size = make_request_param("error_different_size")
+        await q.enqueue_request.remote(different_size)
+
+    with pytest.raises(RayServeException, match="iterable"):
+        non_iterable = make_request_param("error_non_iterable")
+        await q.enqueue_request.remote(non_iterable)
+
+    np_array = make_request_param("return_np_array")
+    result_np_value = await q.enqueue_request.remote(np_array)
+    assert isinstance(result_np_value, np.int32)
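The return_np_array case pins down why the check was loosened: a NumPy array is an ordered iterable, so it should flow through batching even though it is not a list. A self-contained illustration of what the worker-side validation now permits (this is an illustration, not Serve code):

from collections.abc import Iterable

import numpy as np

batch = np.array([1, 1, 1]).astype(np.int32)
assert isinstance(batch, Iterable)         # ndarray implements __iter__
assert not isinstance(batch, (dict, set))  # and its iteration order is defined
result = list(batch)                       # the normalization step above
assert len(result) == 3                    # batch size is preserved
assert isinstance(result[0], np.int32)     # elements stay NumPy scalars

The last assertion mirrors the test's final check: the per-request value handed back is still an np.int32 rather than a plain int.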
test_util.py
@@ -1,5 +1,7 @@
 import json
 
+import numpy as np
+
 from ray.serve.utils import ServeEncoder
 
 
@@ -7,3 +9,14 @@ def test_bytes_encoder():
     data_before = {"inp": {"nest": b"bytes"}}
     data_after = {"inp": {"nest": "bytes"}}
     assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after
+
+
+def test_numpy_encoding():
+    data = [1, 2]
+    floats = np.array(data).astype(np.float32)
+    ints = floats.astype(np.int32)
+    uints = floats.astype(np.uint32)
+
+    assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
+    assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
+    assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data
utils.py
@@ -69,6 +69,10 @@ class ServeEncoder(json.JSONEncoder):
         if isinstance(o, Exception):
             return str(o)
         if isinstance(o, np.ndarray):
+            if o.dtype.kind == "f":  # floats
+                o = o.astype(float)
+            if o.dtype.kind in {"i", "u"}:  # signed and unsigned integers.
+                o = o.astype(int)
             return o.tolist()
         return super().default(o)
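The dtype handling added to ServeEncoder can be verified outside of Serve. A minimal sketch that mirrors the ndarray branch above; the class name NumpyEncoder is local to this sketch, not part of Serve:

import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    def default(self, o):
        # Mirrors the new ServeEncoder branch: coerce NumPy dtypes to
        # native Python numbers before emitting a JSON list.
        if isinstance(o, np.ndarray):
            if o.dtype.kind == "f":  # float dtypes
                o = o.astype(float)
            if o.dtype.kind in {"i", "u"}:  # signed and unsigned integers
                o = o.astype(int)
            return o.tolist()
        return super().default(o)


assert json.dumps(np.array([1, 2], dtype=np.uint32), cls=NumpyEncoder) == "[1, 2]"
assert json.dumps(np.array([1, 2], dtype=np.float32), cls=NumpyEncoder) == "[1.0, 2.0]"

These assertions are the same round-trips that test_numpy_encoding above performs through json.loads.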