diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 4cbc537e3..0d115eddb 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -1,5 +1,6 @@ import traceback import inspect +from collections.abc import Iterable import ray from ray import serve @@ -195,8 +196,18 @@ class RayServeWorker: self.request_counter.add(batch_size) result_list = await call_method(*arg_list, **kwargs_list) - if (not isinstance(result_list, - list)) or (len(result_list) != batch_size): + if not isinstance(result_list, Iterable) or isinstance( + result_list, (dict, set)): + error_message = ("RayServe expects an ordered iterable object " + "but the worker returned a {}".format( + type(result_list))) + raise RayServeException(error_message) + + # Normalize the result into a list type. list() iterates the + # result once and builds a new list without deep-copying items. + result_list = list(result_list) + + if (len(result_list) != batch_size): error_message = ("Worker doesn't preserve batch size. The " "input has length {} but the returned list " "has length {}. 
Please return a list of " diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 53f515939..48a9128e5 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -1,6 +1,7 @@ import asyncio import pytest +import numpy as np import ray from ray import serve @@ -9,6 +10,7 @@ from ray.serve.policy import RoundRobinPolicyQueueActor from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error from ray.serve.request_params import RequestMetadata from ray.serve.config import BackendConfig +from ray.serve.exceptions import RayServeException pytestmark = pytest.mark.asyncio @@ -151,6 +153,15 @@ async def test_task_runner_custom_method_batch(serve_instance): def b(self, _): return ["b-{}".format(i) for i in range(serve.context.batch_size)] + def error_different_size(self, _): + return [""] * (serve.context.batch_size * 2) + + def error_non_iterable(self, _): + return 42 + + def return_np_array(self, _): + return np.array([1] * serve.context.batch_size).astype(np.int32) + CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" @@ -163,10 +174,12 @@ async def test_task_runner_custom_method_batch(serve_instance): "max_batch_size": 10 }, accepts_batches=True)) - a_query_param = RequestMetadata( - PRODUCER_NAME, context.TaskContext.Python, call_method="a") - b_query_param = RequestMetadata( - PRODUCER_NAME, context.TaskContext.Python, call_method="b") + def make_request_param(call_method): + return RequestMetadata( + PRODUCER_NAME, context.TaskContext.Python, call_method=call_method) + + a_query_param = make_request_param("a") + b_query_param = make_request_param("b") futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)] futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)] @@ -175,3 +188,15 @@ async def test_task_runner_custom_method_batch(serve_instance): gathered = await asyncio.gather(*futures) assert set(gathered) == 
{"a-0", "a-1", "b-0", "b-1"} + + with pytest.raises(RayServeException, match="doesn't preserve batch size"): + different_size = make_request_param("error_different_size") + await q.enqueue_request.remote(different_size) + + with pytest.raises(RayServeException, match="iterable"): + non_iterable = make_request_param("error_non_iterable") + await q.enqueue_request.remote(non_iterable) + + np_array = make_request_param("return_np_array") + result_np_value = await q.enqueue_request.remote(np_array) + assert isinstance(result_np_value, np.int32) diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index 62d63565c..b1c2254c6 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -1,5 +1,7 @@ import json +import numpy as np + from ray.serve.utils import ServeEncoder @@ -7,3 +9,14 @@ def test_bytes_encoder(): data_before = {"inp": {"nest": b"bytes"}} data_after = {"inp": {"nest": "bytes"}} assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after + + +def test_numpy_encoding(): + data = [1, 2] + floats = np.array(data).astype(np.float32) + ints = floats.astype(np.int32) + uints = floats.astype(np.uint32) + + assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data + assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data + assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 511676a24..adf7d815e 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -69,6 +69,10 @@ class ServeEncoder(json.JSONEncoder): if isinstance(o, Exception): return str(o) if isinstance(o, np.ndarray): + if o.dtype.kind == "f": # floats + o = o.astype(float) + if o.dtype.kind in {"i", "u"}: # signed and unsigned integers. + o = o.astype(int) return o.tolist() return super().default(o)