mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 21:08:50 +08:00
[Serve] Refactor RequestMetadata and Query objects (#10483)
This commit is contained in:
@@ -4,7 +4,6 @@ import inspect
|
||||
from collections.abc import Iterable
|
||||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
from operator import attrgetter
|
||||
from typing import Union, List, Any, Callable, Type
|
||||
import time
|
||||
|
||||
@@ -185,7 +184,7 @@ class RayServeWorker:
|
||||
asyncio.get_event_loop().create_task(self.main_loop())
|
||||
|
||||
def get_runner_method(self, request_item: Query) -> Callable:
|
||||
method_name = request_item.call_method
|
||||
method_name = request_item.metadata.call_method
|
||||
if not hasattr(self.callable, method_name):
|
||||
raise RayServeException("Backend doesn't have method {} "
|
||||
"which is specified in the request. "
|
||||
@@ -325,7 +324,9 @@ class RayServeWorker:
|
||||
all_evaluated_futures = [evaluated]
|
||||
chain_future(evaluated, query.async_future)
|
||||
else:
|
||||
get_call_method = attrgetter("call_method")
|
||||
get_call_method = (
|
||||
lambda query: query.metadata.call_method # noqa: E731
|
||||
)
|
||||
sorted_batch = sorted(batch, key=get_call_method)
|
||||
for _, group in groupby(sorted_batch, key=get_call_method):
|
||||
group = list(group)
|
||||
|
||||
@@ -77,10 +77,10 @@ class RandomEndpointPolicy(EndpointPolicy):
|
||||
assigned_backends = set()
|
||||
while len(endpoint_queue) > 0:
|
||||
query = endpoint_queue.pop()
|
||||
if query.shard_key is None:
|
||||
if query.metadata.shard_key is None:
|
||||
rstate = np.random
|
||||
else:
|
||||
sha256_seed = sha256(query.shard_key.encode("utf-8"))
|
||||
sha256_seed = sha256(query.metadata.shard_key.encode("utf-8"))
|
||||
seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32)
|
||||
# Note(simon): This constructor takes 100+us, maybe cache this?
|
||||
rstate = np.random.RandomState(seed)
|
||||
@@ -93,7 +93,7 @@ class RandomEndpointPolicy(EndpointPolicy):
|
||||
if len(shadow_backends) > 0:
|
||||
shadow_query = copy.copy(query)
|
||||
shadow_query.async_future = None
|
||||
shadow_query.is_shadow_query = True
|
||||
shadow_query.metadata.is_shadow_query = True
|
||||
for shadow_backend in shadow_backends:
|
||||
assigned_backends.add(shadow_backend)
|
||||
backend_queues[shadow_backend].appendleft(shadow_query)
|
||||
|
||||
@@ -2,28 +2,15 @@
|
||||
Example service that prints out http context.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from pygments import formatters, highlight, lexers
|
||||
|
||||
import requests
|
||||
|
||||
from ray import serve
|
||||
|
||||
|
||||
def pformat_color_json(d):
|
||||
"""Use pygments to pretty format and colorize dictionary"""
|
||||
formatted_json = json.dumps(d, sort_keys=True, indent=4)
|
||||
|
||||
colorful_json = highlight(formatted_json, lexers.JsonLexer(),
|
||||
formatters.TerminalFormatter())
|
||||
|
||||
return colorful_json
|
||||
|
||||
|
||||
def echo(flask_request):
|
||||
return "hello " + flask_request.args.get("name", "serve!")
|
||||
return ["hello " + flask_request.args.get("name", "serve!")]
|
||||
|
||||
|
||||
serve.init()
|
||||
@@ -33,7 +20,7 @@ serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
|
||||
while True:
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
print(pformat_color_json(resp))
|
||||
print(resp)
|
||||
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
|
||||
+21
-34
@@ -1,7 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
|
||||
|
||||
@@ -16,7 +17,6 @@ class RayServeHandle:
|
||||
>>> handle
|
||||
RayServeHandle(
|
||||
Endpoint="my_endpoint",
|
||||
URL="...",
|
||||
Traffic=...
|
||||
)
|
||||
>>> handle.remote(my_request_content)
|
||||
@@ -31,61 +31,48 @@ class RayServeHandle:
|
||||
self,
|
||||
router_handle,
|
||||
endpoint_name,
|
||||
http_method=None,
|
||||
method_name=None,
|
||||
shard_key=None,
|
||||
):
|
||||
self.router_handle = router_handle
|
||||
self.endpoint_name = endpoint_name
|
||||
self.http_method = http_method
|
||||
self.method_name = method_name
|
||||
self.shard_key = shard_key
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
if len(args) != 0:
|
||||
raise RayServeException(
|
||||
if len(args) > 0:
|
||||
raise ValueError(
|
||||
"handle.remote must be invoked with keyword arguments.")
|
||||
|
||||
method_name = self.method_name
|
||||
if method_name is None:
|
||||
method_name = "__call__"
|
||||
|
||||
# create RequestMetadata instance
|
||||
request_in_object = RequestMetadata(
|
||||
request_metadata = RequestMetadata(
|
||||
self.endpoint_name,
|
||||
TaskContext.Python,
|
||||
call_method=method_name,
|
||||
http_method=self.http_method or "GET",
|
||||
call_method=self.method_name or "__call__",
|
||||
shard_key=self.shard_key,
|
||||
)
|
||||
return self.router_handle.enqueue_request.remote(
|
||||
request_in_object, **kwargs)
|
||||
|
||||
def options(self, method_name=None, shard_key=None):
|
||||
|
||||
# Don't override existing method
|
||||
if method_name is None and self.method_name is not None:
|
||||
method_name = self.method_name
|
||||
|
||||
if shard_key is None and self.shard_key is not None:
|
||||
shard_key = self.shard_key
|
||||
request_metadata, **kwargs)
|
||||
|
||||
def options(self,
|
||||
method_name: Optional[str] = None,
|
||||
http_method: Optional[str] = None,
|
||||
shard_key: Optional[str] = None):
|
||||
return RayServeHandle(
|
||||
self.router_handle,
|
||||
self.endpoint_name,
|
||||
method_name=method_name,
|
||||
shard_key=shard_key,
|
||||
# Don't override existing method
|
||||
http_method=self.http_method or http_method,
|
||||
method_name=self.method_name or method_name,
|
||||
shard_key=self.shard_key or shard_key,
|
||||
)
|
||||
|
||||
def get_traffic_policy(self):
|
||||
def _get_traffic_policy(self):
|
||||
controller = serve.api._get_controller()
|
||||
return ray.get(
|
||||
controller.get_traffic_policy.remote(self.endpoint_name))
|
||||
|
||||
def __repr__(self):
|
||||
return """
|
||||
RayServeHandle(
|
||||
Endpoint="{endpoint_name}",
|
||||
Traffic={traffic_policy}
|
||||
)
|
||||
""".format(
|
||||
endpoint_name=self.endpoint_name,
|
||||
traffic_policy=self.get_traffic_policy(),
|
||||
)
|
||||
return (f"RayServeHandle(Endpoint='{self.endpoint_name}', "
|
||||
f"Traffic={self._get_traffic_policy()})")
|
||||
|
||||
@@ -111,6 +111,7 @@ class HTTPProxy:
|
||||
request_metadata = RequestMetadata(
|
||||
endpoint_name,
|
||||
TaskContext.Web,
|
||||
http_method=scope["method"].upper(),
|
||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
|
||||
)
|
||||
|
||||
@@ -1,30 +1,15 @@
|
||||
import ray.cloudpickle as pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMetadata:
|
||||
"""
|
||||
Request arguments required for enqueuing a request to the endpoint queue.
|
||||
endpoint: str
|
||||
request_context: TaskContext
|
||||
|
||||
Args:
|
||||
endpoint(str): A registered endpoint.
|
||||
request_context(TaskContext): Context of a request.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
endpoint,
|
||||
request_context,
|
||||
call_method="__call__",
|
||||
shard_key=None):
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.request_context = request_context
|
||||
self.call_method = call_method
|
||||
self.shard_key = shard_key
|
||||
|
||||
def ray_serialize(self):
|
||||
return pickle.dumps(self.__dict__)
|
||||
|
||||
@staticmethod
|
||||
def ray_deserialize(value):
|
||||
kwargs = pickle.loads(value)
|
||||
return RequestMetadata(**kwargs)
|
||||
call_method: str = "__call__"
|
||||
shard_key: Optional[str] = None
|
||||
http_method: str = "GET"
|
||||
is_shadow_query: bool = False
|
||||
|
||||
+13
-44
@@ -2,40 +2,31 @@ import asyncio
|
||||
import copy
|
||||
from collections import defaultdict, deque
|
||||
import time
|
||||
from typing import DefaultDict, List
|
||||
from typing import DefaultDict, List, Dict, Any, Optional
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.experimental import metrics
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.endpoint_policy import RandomEndpointPolicy
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.utils import logger, chain_future
|
||||
|
||||
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Query:
|
||||
def __init__(
|
||||
self,
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
call_method="__call__",
|
||||
shard_key=None,
|
||||
async_future=None,
|
||||
is_shadow_query=False,
|
||||
):
|
||||
self.request_args = request_args
|
||||
self.request_kwargs = request_kwargs
|
||||
self.request_context = request_context
|
||||
args: List[Any]
|
||||
kwargs: Dict[Any, Any]
|
||||
context: TaskContext
|
||||
|
||||
self.async_future = async_future
|
||||
|
||||
self.call_method = call_method
|
||||
self.shard_key = shard_key
|
||||
self.is_shadow_query = is_shadow_query
|
||||
metadata: RequestMetadata
|
||||
async_future: Optional[asyncio.Future] = None
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self).ray_deserialize, (self.ray_serialize(), )
|
||||
@@ -56,27 +47,6 @@ class Query:
|
||||
return Query(**kwargs)
|
||||
|
||||
|
||||
def _make_future_unwrapper(client_futures: List[asyncio.Future],
|
||||
host_future: asyncio.Future):
|
||||
"""Distribute the result of host_future to each of client_future"""
|
||||
for client_future in client_futures:
|
||||
# Keep a reference to host future so the host future won't get
|
||||
# garbage collected.
|
||||
client_future.host_ref = host_future
|
||||
|
||||
def unwrap_future(_):
|
||||
result = host_future.result()
|
||||
|
||||
if isinstance(result, list):
|
||||
for client_future, result_item in zip(client_futures, result):
|
||||
client_future.set_result(result_item)
|
||||
else: # Result is an exception.
|
||||
for client_future in client_futures:
|
||||
client_future.set_result(result)
|
||||
|
||||
return unwrap_future
|
||||
|
||||
|
||||
class Router:
|
||||
"""A router that routes request to available workers."""
|
||||
|
||||
@@ -175,8 +145,7 @@ class Router:
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
call_method=request_meta.call_method,
|
||||
shard_key=request_meta.shard_key,
|
||||
metadata=request_meta,
|
||||
async_future=asyncio.get_event_loop().create_future())
|
||||
async with self.flush_lock:
|
||||
self.endpoint_queues[endpoint].appendleft(query)
|
||||
@@ -301,7 +270,7 @@ class Router:
|
||||
worker = self.replicas[backend_replica_tag]
|
||||
try:
|
||||
object_ref = worker.handle_request.remote(req.ray_serialize())
|
||||
if req.is_shadow_query:
|
||||
if req.metadata.is_shadow_query:
|
||||
# No need to actually get the result, but we do need to wait
|
||||
# until the call completes to mark the worker idle.
|
||||
await asyncio.wait([object_ref])
|
||||
@@ -351,7 +320,7 @@ class Router:
|
||||
self._do_query(backend, backend_replica_tag, request))
|
||||
|
||||
# For shadow queries, just ignore the result.
|
||||
if not request.is_shadow_query:
|
||||
if not request.metadata.is_shadow_query:
|
||||
chain_future(future, request.async_future)
|
||||
|
||||
worker_queue.appendleft(backend_replica_tag)
|
||||
|
||||
@@ -61,8 +61,8 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
# Make sure it's the right request
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args[0] == 1
|
||||
assert got_work.request_kwargs == {}
|
||||
assert got_work.args[0] == 1
|
||||
assert got_work.kwargs == {}
|
||||
|
||||
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
@@ -74,14 +74,14 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
task_runner_mock_actor)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args[0] == 1
|
||||
assert got_work.args[0] == 1
|
||||
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1}))
|
||||
await q.add_new_worker.remote("backend-alter-2", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args[0] == 2
|
||||
assert got_work.args[0] == 2
|
||||
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
@@ -106,7 +106,7 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
await runner.get_recent_call.remote()
|
||||
for runner in (runner_1, runner_2)
|
||||
]
|
||||
assert [g.request_args[0] for g in got_work] == [1, 1]
|
||||
assert [g.args[0] for g in got_work] == [1, 1]
|
||||
|
||||
|
||||
async def test_queue_remove_replicas(serve_instance):
|
||||
@@ -146,7 +146,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
for i, runner in enumerate(runners):
|
||||
calls = await runner.get_all_calls.remote()
|
||||
for call in calls:
|
||||
runner_shard_keys[i].add(call.request_args[0])
|
||||
runner_shard_keys[i].add(call.args[0])
|
||||
await runner.clear_calls.remote()
|
||||
|
||||
# Send queries with the same shard keys a second time.
|
||||
@@ -158,7 +158,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
for i, runner in enumerate(runners):
|
||||
calls = await runner.get_all_calls.remote()
|
||||
for call in calls:
|
||||
assert call.request_args[0] in runner_shard_keys[i]
|
||||
assert call.args[0] in runner_shard_keys[i]
|
||||
|
||||
|
||||
async def test_router_use_max_concurrency(serve_instance):
|
||||
|
||||
@@ -27,9 +27,9 @@ ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
||||
def parse_request_item(request_item):
|
||||
if request_item.request_context == TaskContext.Web:
|
||||
if request_item.metadata.request_context == TaskContext.Web:
|
||||
is_web_context = True
|
||||
asgi_scope, body_bytes = request_item.request_args
|
||||
asgi_scope, body_bytes = request_item.args
|
||||
|
||||
flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes))
|
||||
args = (flask_request, )
|
||||
@@ -37,7 +37,7 @@ def parse_request_item(request_item):
|
||||
else:
|
||||
is_web_context = False
|
||||
args = (FakeFlaskRequest(), )
|
||||
kwargs = request_item.request_kwargs
|
||||
kwargs = request_item.kwargs
|
||||
|
||||
return args, kwargs, is_web_context
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ scipy==1.4.1
|
||||
tabulate
|
||||
tensorboardX
|
||||
uvicorn
|
||||
dataclasses
|
||||
|
||||
# Requirements for running tests
|
||||
blist; platform_system != "Windows"
|
||||
|
||||
+1
-1
@@ -110,7 +110,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on":
|
||||
# in this directory
|
||||
extras = {
|
||||
"debug": [],
|
||||
"serve": ["uvicorn", "flask", "requests"],
|
||||
"serve": ["uvicorn", "flask", "requests", "dataclasses"],
|
||||
"tune": ["tabulate", "tensorboardX", "pandas"]
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user