From d0398bf7e183f251730d2312a31b22af4487257a Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 20 Jul 2020 10:10:07 -0700 Subject: [PATCH] [Serve] Serialize Query object directly (#9490) --- python/ray/serve/api.py | 9 --------- python/ray/serve/backend_worker.py | 6 ++++-- python/ray/serve/router.py | 5 +++-- python/ray/serve/tests/test_router.py | 18 +++++++++++++----- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 8c658de9d..a3574b095 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -8,8 +8,6 @@ from ray.serve.handle import RayServeHandle from ray.serve.utils import (block_until_http_ready, format_actor_name) from ray.serve.exceptions import RayServeException from ray.serve.config import BackendConfig, ReplicaConfig -from ray.serve.router import Query -from ray.serve.request_params import RequestMetadata from ray.serve.metric import InMemoryExporter master_actor = None @@ -96,13 +94,6 @@ def init(name=None, except ValueError: pass - # Register serialization context once - ray.register_custom_serializer(Query, Query.ray_serialize, - Query.ray_deserialize) - ray.register_custom_serializer(RequestMetadata, - RequestMetadata.ray_serialize, - RequestMetadata.ray_deserialize) - # TODO(edoakes): for now, always start the HTTP proxy on the node that # serve.init() was run on. We should consider making this configurable # in the future. diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 486b66b60..2c82d89f7 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -5,6 +5,7 @@ from collections.abc import Iterable from collections import defaultdict from itertools import groupby from operator import attrgetter +from typing import Union import time import ray @@ -343,8 +344,9 @@ class RayServeWorker: self.batch_queue.set_config(self.config.max_batch_size or 1, self.config.batch_wait_timeout) - async def handle_request(self, request: Query): - assert not isinstance(request, list) + async def handle_request(self, request: Union[Query, bytes]): + if isinstance(request, bytes): + request = Query.ray_deserialize(request) logger.debug("Worker {} got request {}".format(self.name, request)) request.async_future = asyncio.get_event_loop().create_future() self.batch_queue.put(request) diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 03a15d641..00a076efd 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -316,13 +316,14 @@ class Router: start = time.time() worker = self.replicas[backend_replica_tag] try: + object_ref = worker.handle_request.remote(req.ray_serialize()) if req.is_shadow_query: # No need to actually get the result, but we do need to wait # until the call completes to mark the worker idle. - asyncio.wait([worker.handle_request.remote(req)]) + await asyncio.wait([object_ref]) result = "" else: - result = await worker.handle_request.remote(req) + result = await object_ref except RayTaskError as error: self.num_error_backend_request.labels(backend=backend).add() result = error diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 1954d46de..ddf5cd537 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -5,7 +5,7 @@ import pytest import ray from ray.serve.master import TrafficPolicy -from ray.serve.router import Router +from ray.serve.router import Router, Query from ray.serve.request_params import RequestMetadata from ray.serve.utils import get_random_letters from ray.test_utils import SignalActor @@ -22,6 +22,8 @@ def mock_task_runner(): self.queries = [] async def handle_request(self, request): + if isinstance(request, bytes): + request = Query.ray_deserialize(request) self.query = request self.queries.append(request) return "DONE" @@ -186,6 +188,12 @@ async def test_shard_key(serve_instance, task_runner_mock_actor): async def test_router_use_max_concurrency(serve_instance): + # The VisibleRouter::get_queues method needs to pickle queries + # so we register serializer here. In regular code path, query + # serialization is done by Serve manually for performance. + ray.register_custom_serializer(Query, Query.ray_serialize, + Query.ray_deserialize) + signal = SignalActor.remote() @ray.remote @@ -204,11 +212,11 @@ async def test_router_use_max_concurrency(serve_instance): worker = MockWorker.remote() q = ray.remote(VisibleRouter).remote() await q.setup.remote() - BACKEND_NAME = "max-concurrent-test" + backend_name = "max-concurrent-test" config = BackendConfig({"max_concurrent_queries": 1}) - await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0})) - await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker) - await q.set_backend_config.remote(BACKEND_NAME, config) + await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0})) + await q.add_new_worker.remote(backend_name, "replica-tag", worker) + await q.set_backend_config.remote(backend_name, config) # We send over two queries first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)