mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 05:52:54 +08:00
[Serve] Serialize Query object directly (#9490)
This commit is contained in:
@@ -8,8 +8,6 @@ from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.router import Query
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.metric import InMemoryExporter
|
||||
|
||||
master_actor = None
|
||||
@@ -96,13 +94,6 @@ def init(name=None,
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Register serialization context once
|
||||
ray.register_custom_serializer(Query, Query.ray_serialize,
|
||||
Query.ray_deserialize)
|
||||
ray.register_custom_serializer(RequestMetadata,
|
||||
RequestMetadata.ray_serialize,
|
||||
RequestMetadata.ray_deserialize)
|
||||
|
||||
# TODO(edoakes): for now, always start the HTTP proxy on the node that
|
||||
# serve.init() was run on. We should consider making this configurable
|
||||
# in the future.
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Iterable
|
||||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
from operator import attrgetter
|
||||
from typing import Union
|
||||
import time
|
||||
|
||||
import ray
|
||||
@@ -343,8 +344,9 @@ class RayServeWorker:
|
||||
self.batch_queue.set_config(self.config.max_batch_size or 1,
|
||||
self.config.batch_wait_timeout)
|
||||
|
||||
async def handle_request(self, request: Query):
|
||||
assert not isinstance(request, list)
|
||||
async def handle_request(self, request: Union[Query, bytes]):
|
||||
if isinstance(request, bytes):
|
||||
request = Query.ray_deserialize(request)
|
||||
logger.debug("Worker {} got request {}".format(self.name, request))
|
||||
request.async_future = asyncio.get_event_loop().create_future()
|
||||
self.batch_queue.put(request)
|
||||
|
||||
@@ -316,13 +316,14 @@ class Router:
|
||||
start = time.time()
|
||||
worker = self.replicas[backend_replica_tag]
|
||||
try:
|
||||
object_ref = worker.handle_request.remote(req.ray_serialize())
|
||||
if req.is_shadow_query:
|
||||
# No need to actually get the result, but we do need to wait
|
||||
# until the call completes to mark the worker idle.
|
||||
asyncio.wait([worker.handle_request.remote(req)])
|
||||
await asyncio.wait([object_ref])
|
||||
result = ""
|
||||
else:
|
||||
result = await worker.handle_request.remote(req)
|
||||
result = await object_ref
|
||||
except RayTaskError as error:
|
||||
self.num_error_backend_request.labels(backend=backend).add()
|
||||
result = error
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
import ray
|
||||
|
||||
from ray.serve.master import TrafficPolicy
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.router import Router, Query
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.utils import get_random_letters
|
||||
from ray.test_utils import SignalActor
|
||||
@@ -22,6 +22,8 @@ def mock_task_runner():
|
||||
self.queries = []
|
||||
|
||||
async def handle_request(self, request):
|
||||
if isinstance(request, bytes):
|
||||
request = Query.ray_deserialize(request)
|
||||
self.query = request
|
||||
self.queries.append(request)
|
||||
return "DONE"
|
||||
@@ -186,6 +188,12 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
|
||||
|
||||
async def test_router_use_max_concurrency(serve_instance):
|
||||
# The VisibleRouter::get_queues method needs to pickle queries
|
||||
# so we register serializer here. In regular code path, query
|
||||
# serialization is done by Serve manually for performance.
|
||||
ray.register_custom_serializer(Query, Query.ray_serialize,
|
||||
Query.ray_deserialize)
|
||||
|
||||
signal = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
@@ -204,11 +212,11 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
worker = MockWorker.remote()
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
await q.setup.remote()
|
||||
BACKEND_NAME = "max-concurrent-test"
|
||||
backend_name = "max-concurrent-test"
|
||||
config = BackendConfig({"max_concurrent_queries": 1})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
|
||||
await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
|
||||
await q.set_backend_config.remote(BACKEND_NAME, config)
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
|
||||
await q.add_new_worker.remote(backend_name, "replica-tag", worker)
|
||||
await q.set_backend_config.remote(backend_name, config)
|
||||
|
||||
# We send over two queries
|
||||
first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
|
||||
Reference in New Issue
Block a user