[Serve] Refactor RequestMetadata and Query objects (#10483)

This commit is contained in:
Simon Mo
2020-09-01 18:15:31 -07:00
committed by GitHub
parent 3b10b67a15
commit 65f17f2e14
11 changed files with 67 additions and 136 deletions
+4 -3
View File
@@ -4,7 +4,6 @@ import inspect
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
from typing import Union, List, Any, Callable, Type
import time
@@ -185,7 +184,7 @@ class RayServeWorker:
asyncio.get_event_loop().create_task(self.main_loop())
def get_runner_method(self, request_item: Query) -> Callable:
method_name = request_item.call_method
method_name = request_item.metadata.call_method
if not hasattr(self.callable, method_name):
raise RayServeException("Backend doesn't have method {} "
"which is specified in the request. "
@@ -325,7 +324,9 @@ class RayServeWorker:
all_evaluated_futures = [evaluated]
chain_future(evaluated, query.async_future)
else:
get_call_method = attrgetter("call_method")
get_call_method = (
lambda query: query.metadata.call_method # noqa: E731
)
sorted_batch = sorted(batch, key=get_call_method)
for _, group in groupby(sorted_batch, key=get_call_method):
group = list(group)
+3 -3
View File
@@ -77,10 +77,10 @@ class RandomEndpointPolicy(EndpointPolicy):
assigned_backends = set()
while len(endpoint_queue) > 0:
query = endpoint_queue.pop()
if query.shard_key is None:
if query.metadata.shard_key is None:
rstate = np.random
else:
sha256_seed = sha256(query.shard_key.encode("utf-8"))
sha256_seed = sha256(query.metadata.shard_key.encode("utf-8"))
seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32)
# Note(simon): This constructor takes 100+us, maybe cache this?
rstate = np.random.RandomState(seed)
@@ -93,7 +93,7 @@ class RandomEndpointPolicy(EndpointPolicy):
if len(shadow_backends) > 0:
shadow_query = copy.copy(query)
shadow_query.async_future = None
shadow_query.is_shadow_query = True
shadow_query.metadata.is_shadow_query = True
for shadow_backend in shadow_backends:
assigned_backends.add(shadow_backend)
backend_queues[shadow_backend].appendleft(shadow_query)
+2 -15
View File
@@ -2,28 +2,15 @@
Example service that prints out http context.
"""
import json
import time
from pygments import formatters, highlight, lexers
import requests
from ray import serve
def pformat_color_json(d):
"""Use pygments to pretty format and colorize dictionary"""
formatted_json = json.dumps(d, sort_keys=True, indent=4)
colorful_json = highlight(formatted_json, lexers.JsonLexer(),
formatters.TerminalFormatter())
return colorful_json
def echo(flask_request):
return "hello " + flask_request.args.get("name", "serve!")
return ["hello " + flask_request.args.get("name", "serve!")]
serve.init()
@@ -33,7 +20,7 @@ serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print(resp)
print("...Sleeping for 2 seconds...")
time.sleep(2)
+21 -34
View File
@@ -1,7 +1,8 @@
from typing import Optional
import ray
from ray import serve
from ray.serve.context import TaskContext
from ray.serve.exceptions import RayServeException
from ray.serve.request_params import RequestMetadata
@@ -16,7 +17,6 @@ class RayServeHandle:
>>> handle
RayServeHandle(
Endpoint="my_endpoint",
URL="...",
Traffic=...
)
>>> handle.remote(my_request_content)
@@ -31,61 +31,48 @@ class RayServeHandle:
self,
router_handle,
endpoint_name,
http_method=None,
method_name=None,
shard_key=None,
):
self.router_handle = router_handle
self.endpoint_name = endpoint_name
self.http_method = http_method
self.method_name = method_name
self.shard_key = shard_key
def remote(self, *args, **kwargs):
if len(args) != 0:
raise RayServeException(
if len(args) > 0:
raise ValueError(
"handle.remote must be invoked with keyword arguments.")
method_name = self.method_name
if method_name is None:
method_name = "__call__"
# create RequestMetadata instance
request_in_object = RequestMetadata(
request_metadata = RequestMetadata(
self.endpoint_name,
TaskContext.Python,
call_method=method_name,
http_method=self.http_method or "GET",
call_method=self.method_name or "__call__",
shard_key=self.shard_key,
)
return self.router_handle.enqueue_request.remote(
request_in_object, **kwargs)
def options(self, method_name=None, shard_key=None):
# Don't override existing method
if method_name is None and self.method_name is not None:
method_name = self.method_name
if shard_key is None and self.shard_key is not None:
shard_key = self.shard_key
request_metadata, **kwargs)
def options(self,
method_name: Optional[str] = None,
http_method: Optional[str] = None,
shard_key: Optional[str] = None):
return RayServeHandle(
self.router_handle,
self.endpoint_name,
method_name=method_name,
shard_key=shard_key,
# Don't override existing method
http_method=self.http_method or http_method,
method_name=self.method_name or method_name,
shard_key=self.shard_key or shard_key,
)
def get_traffic_policy(self):
def _get_traffic_policy(self):
controller = serve.api._get_controller()
return ray.get(
controller.get_traffic_policy.remote(self.endpoint_name))
def __repr__(self):
return """
RayServeHandle(
Endpoint="{endpoint_name}",
Traffic={traffic_policy}
)
""".format(
endpoint_name=self.endpoint_name,
traffic_policy=self.get_traffic_policy(),
)
return (f"RayServeHandle(Endpoint='{self.endpoint_name}', "
f"Traffic={self._get_traffic_policy()})")
+1
View File
@@ -111,6 +111,7 @@ class HTTPProxy:
request_metadata = RequestMetadata(
endpoint_name,
TaskContext.Web,
http_method=scope["method"].upper(),
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
)
+11 -26
View File
@@ -1,30 +1,15 @@
import ray.cloudpickle as pickle
from dataclasses import dataclass
from typing import Optional
from ray.serve.context import TaskContext
@dataclass
class RequestMetadata:
"""
Request arguments required for enqueuing a request to the endpoint queue.
endpoint: str
request_context: TaskContext
Args:
endpoint(str): A registered endpoint.
request_context(TaskContext): Context of a request.
"""
def __init__(self,
endpoint,
request_context,
call_method="__call__",
shard_key=None):
self.endpoint = endpoint
self.request_context = request_context
self.call_method = call_method
self.shard_key = shard_key
def ray_serialize(self):
return pickle.dumps(self.__dict__)
@staticmethod
def ray_deserialize(value):
kwargs = pickle.loads(value)
return RequestMetadata(**kwargs)
call_method: str = "__call__"
shard_key: Optional[str] = None
http_method: str = "GET"
is_shadow_query: bool = False
+13 -44
View File
@@ -2,40 +2,31 @@ import asyncio
import copy
from collections import defaultdict, deque
import time
from typing import DefaultDict, List
from typing import DefaultDict, List, Dict, Any, Optional
import pickle
from dataclasses import dataclass
from ray.exceptions import RayTaskError
import ray
from ray import serve
from ray.experimental import metrics
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import RandomEndpointPolicy
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import logger, chain_future
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
@dataclass
class Query:
def __init__(
self,
request_args,
request_kwargs,
request_context,
call_method="__call__",
shard_key=None,
async_future=None,
is_shadow_query=False,
):
self.request_args = request_args
self.request_kwargs = request_kwargs
self.request_context = request_context
args: List[Any]
kwargs: Dict[Any, Any]
context: TaskContext
self.async_future = async_future
self.call_method = call_method
self.shard_key = shard_key
self.is_shadow_query = is_shadow_query
metadata: RequestMetadata
async_future: Optional[asyncio.Future] = None
def __reduce__(self):
return type(self).ray_deserialize, (self.ray_serialize(), )
@@ -56,27 +47,6 @@ class Query:
return Query(**kwargs)
def _make_future_unwrapper(client_futures: List[asyncio.Future],
host_future: asyncio.Future):
"""Distribute the result of host_future to each of client_future"""
for client_future in client_futures:
# Keep a reference to host future so the host future won't get
# garbage collected.
client_future.host_ref = host_future
def unwrap_future(_):
result = host_future.result()
if isinstance(result, list):
for client_future, result_item in zip(client_futures, result):
client_future.set_result(result_item)
else: # Result is an exception.
for client_future in client_futures:
client_future.set_result(result)
return unwrap_future
class Router:
"""A router that routes request to available workers."""
@@ -175,8 +145,7 @@ class Router:
request_args,
request_kwargs,
request_context,
call_method=request_meta.call_method,
shard_key=request_meta.shard_key,
metadata=request_meta,
async_future=asyncio.get_event_loop().create_future())
async with self.flush_lock:
self.endpoint_queues[endpoint].appendleft(query)
@@ -301,7 +270,7 @@ class Router:
worker = self.replicas[backend_replica_tag]
try:
object_ref = worker.handle_request.remote(req.ray_serialize())
if req.is_shadow_query:
if req.metadata.is_shadow_query:
# No need to actually get the result, but we do need to wait
# until the call completes to mark the worker idle.
await asyncio.wait([object_ref])
@@ -351,7 +320,7 @@ class Router:
self._do_query(backend, backend_replica_tag, request))
# For shadow queries, just ignore the result.
if not request.is_shadow_query:
if not request.metadata.is_shadow_query:
chain_future(future, request.async_future)
worker_queue.appendleft(backend_replica_tag)
+7 -7
View File
@@ -61,8 +61,8 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
# Make sure it's the right request
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 1
assert got_work.request_kwargs == {}
assert got_work.args[0] == 1
assert got_work.kwargs == {}
async def test_alter_backend(serve_instance, task_runner_mock_actor):
@@ -74,14 +74,14 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 1
assert got_work.args[0] == 1
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1}))
await q.add_new_worker.remote("backend-alter-2", "replica-1",
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 2
assert got_work.args[0] == 2
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
@@ -106,7 +106,7 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
await runner.get_recent_call.remote()
for runner in (runner_1, runner_2)
]
assert [g.request_args[0] for g in got_work] == [1, 1]
assert [g.args[0] for g in got_work] == [1, 1]
async def test_queue_remove_replicas(serve_instance):
@@ -146,7 +146,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
for i, runner in enumerate(runners):
calls = await runner.get_all_calls.remote()
for call in calls:
runner_shard_keys[i].add(call.request_args[0])
runner_shard_keys[i].add(call.args[0])
await runner.clear_calls.remote()
# Send queries with the same shard keys a second time.
@@ -158,7 +158,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
for i, runner in enumerate(runners):
calls = await runner.get_all_calls.remote()
for call in calls:
assert call.request_args[0] in runner_shard_keys[i]
assert call.args[0] in runner_shard_keys[i]
async def test_router_use_max_concurrency(serve_instance):
+3 -3
View File
@@ -27,9 +27,9 @@ ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
def parse_request_item(request_item):
if request_item.request_context == TaskContext.Web:
if request_item.metadata.request_context == TaskContext.Web:
is_web_context = True
asgi_scope, body_bytes = request_item.request_args
asgi_scope, body_bytes = request_item.args
flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes))
args = (flask_request, )
@@ -37,7 +37,7 @@ def parse_request_item(request_item):
else:
is_web_context = False
args = (FakeFlaskRequest(), )
kwargs = request_item.request_kwargs
kwargs = request_item.kwargs
return args, kwargs, is_web_context
+1
View File
@@ -37,6 +37,7 @@ scipy==1.4.1
tabulate
tensorboardX
uvicorn
dataclasses
# Requirements for running tests
blist; platform_system != "Windows"
+1 -1
View File
@@ -110,7 +110,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on":
# in this directory
extras = {
"debug": [],
"serve": ["uvicorn", "flask", "requests"],
"serve": ["uvicorn", "flask", "requests", "dataclasses"],
"tune": ["tabulate", "tensorboardX", "pandas"]
}