[Serve] Improve error message when result is not a list (#8378)

This commit is contained in:
Simon Mo 2020-05-10 17:18:06 -07:00 committed by GitHub
parent 3a25f5f5b4
commit 501b936114
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 6 deletions

View file

@@ -1,5 +1,6 @@
import traceback
import inspect
from collections.abc import Iterable
import ray
from ray import serve
@@ -195,8 +196,18 @@ class RayServeWorker:
self.request_counter.add(batch_size)
result_list = await call_method(*arg_list, **kwargs_list)
-if (not isinstance(result_list,
-        list)) or (len(result_list) != batch_size):
+if not isinstance(result_list, Iterable) or isinstance(
+        result_list, (dict, set)):
error_message = ("RayServe expects an ordered iterable object "
"but the worker returned a {}".format(
type(result_list)))
raise RayServeException(error_message)
# Normalize the result into a list type. This operation is fast
# in Python because it doesn't copy anything.
result_list = list(result_list)
if (len(result_list) != batch_size):
error_message = ("Worker doesn't preserve batch size. The "
                 "input has length {} but the returned list "
                 "has length {}. Please return a list of "

View file

@@ -1,6 +1,7 @@
import asyncio
import pytest
import numpy as np
import ray
from ray import serve
@@ -9,6 +10,7 @@ from ray.serve.policy import RoundRobinPolicyQueueActor
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.request_params import RequestMetadata
from ray.serve.config import BackendConfig
from ray.serve.exceptions import RayServeException
pytestmark = pytest.mark.asyncio
@@ -151,6 +153,15 @@ async def test_task_runner_custom_method_batch(serve_instance):
def b(self, _):
    return ["b-{}".format(i) for i in range(serve.context.batch_size)]
def error_different_size(self, _):
return [""] * (serve.context.batch_size * 2)
def error_non_iterable(self, _):
return 42
def return_np_array(self, _):
return np.array([1] * serve.context.batch_size).astype(np.int32)
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
@@ -163,10 +174,12 @@ async def test_task_runner_custom_method_batch(serve_instance):
"max_batch_size": 10
}, accepts_batches=True))
-a_query_param = RequestMetadata(
-    PRODUCER_NAME, context.TaskContext.Python, call_method="a")
-b_query_param = RequestMetadata(
-    PRODUCER_NAME, context.TaskContext.Python, call_method="b")
+def make_request_param(call_method):
+    return RequestMetadata(
+        PRODUCER_NAME, context.TaskContext.Python, call_method=call_method)
a_query_param = make_request_param("a")
b_query_param = make_request_param("b")
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
@@ -175,3 +188,15 @@ async def test_task_runner_custom_method_batch(serve_instance):
gathered = await asyncio.gather(*futures)
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
with pytest.raises(RayServeException, match="doesn't preserve batch size"):
different_size = make_request_param("error_different_size")
await q.enqueue_request.remote(different_size)
with pytest.raises(RayServeException, match="iterable"):
non_iterable = make_request_param("error_non_iterable")
await q.enqueue_request.remote(non_iterable)
np_array = make_request_param("return_np_array")
result_np_value = await q.enqueue_request.remote(np_array)
assert isinstance(result_np_value, np.int32)

View file

@@ -1,5 +1,7 @@
import json
import numpy as np
from ray.serve.utils import ServeEncoder
@@ -7,3 +9,14 @@ def test_bytes_encoder():
data_before = {"inp": {"nest": b"bytes"}}
data_after = {"inp": {"nest": "bytes"}}
assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after
def test_numpy_encoding():
data = [1, 2]
floats = np.array(data).astype(np.float32)
ints = floats.astype(np.int32)
uints = floats.astype(np.uint32)
assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data
assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data
assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data

View file

@@ -69,6 +69,10 @@ class ServeEncoder(json.JSONEncoder):
if isinstance(o, Exception):
    return str(o)
if isinstance(o, np.ndarray):
if o.dtype.kind == "f": # floats
o = o.astype(float)
if o.dtype.kind in {"i", "u"}: # signed and unsigned integers.
o = o.astype(int)
return o.tolist()
return super().default(o)