[Serve] Serialize Query object directly (#9490)

This commit is contained in:
Simon Mo
2020-07-20 10:10:07 -07:00
committed by GitHub
parent bc842a7888
commit d0398bf7e1
4 changed files with 20 additions and 18 deletions
-9
View File
@@ -8,8 +8,6 @@ from ray.serve.handle import RayServeHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.router import Query
from ray.serve.request_params import RequestMetadata
from ray.serve.metric import InMemoryExporter
master_actor = None
@@ -96,13 +94,6 @@ def init(name=None,
except ValueError:
pass
# Register serialization context once
ray.register_custom_serializer(Query, Query.ray_serialize,
Query.ray_deserialize)
ray.register_custom_serializer(RequestMetadata,
RequestMetadata.ray_serialize,
RequestMetadata.ray_deserialize)
# TODO(edoakes): for now, always start the HTTP proxy on the node that
# serve.init() was run on. We should consider making this configurable
# in the future.
+4 -2
View File
@@ -5,6 +5,7 @@ from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
from typing import Union
import time
import ray
@@ -343,8 +344,9 @@ class RayServeWorker:
self.batch_queue.set_config(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
async def handle_request(self, request: Query):
assert not isinstance(request, list)
async def handle_request(self, request: Union[Query, bytes]):
if isinstance(request, bytes):
request = Query.ray_deserialize(request)
logger.debug("Worker {} got request {}".format(self.name, request))
request.async_future = asyncio.get_event_loop().create_future()
self.batch_queue.put(request)
+3 -2
View File
@@ -316,13 +316,14 @@ class Router:
start = time.time()
worker = self.replicas[backend_replica_tag]
try:
object_ref = worker.handle_request.remote(req.ray_serialize())
if req.is_shadow_query:
# No need to actually get the result, but we do need to wait
# until the call completes to mark the worker idle.
asyncio.wait([worker.handle_request.remote(req)])
await asyncio.wait([object_ref])
result = ""
else:
result = await worker.handle_request.remote(req)
result = await object_ref
except RayTaskError as error:
self.num_error_backend_request.labels(backend=backend).add()
result = error
+13 -5
View File
@@ -5,7 +5,7 @@ import pytest
import ray
from ray.serve.master import TrafficPolicy
from ray.serve.router import Router
from ray.serve.router import Router, Query
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import get_random_letters
from ray.test_utils import SignalActor
@@ -22,6 +22,8 @@ def mock_task_runner():
self.queries = []
async def handle_request(self, request):
if isinstance(request, bytes):
request = Query.ray_deserialize(request)
self.query = request
self.queries.append(request)
return "DONE"
@@ -186,6 +188,12 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
async def test_router_use_max_concurrency(serve_instance):
# The VisibleRouter::get_queues method needs to pickle queries,
# so we register the serializer here. In the regular code path, query
# serialization is done manually by Serve for performance.
ray.register_custom_serializer(Query, Query.ray_serialize,
Query.ray_deserialize)
signal = SignalActor.remote()
@ray.remote
@@ -204,11 +212,11 @@ async def test_router_use_max_concurrency(serve_instance):
worker = MockWorker.remote()
q = ray.remote(VisibleRouter).remote()
await q.setup.remote()
BACKEND_NAME = "max-concurrent-test"
backend_name = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
await q.set_backend_config.remote(BACKEND_NAME, config)
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
await q.add_new_worker.remote(backend_name, "replica-tag", worker)
await q.set_backend_config.remote(backend_name, config)
# We send over two queries
first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)