From 65f17f2e148ffe7c89376ed2e8f483961ddcece4 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 1 Sep 2020 18:15:31 -0700 Subject: [PATCH] [Serve] Refactor RequestMetadata and Query objects (#10483) --- python/ray/serve/backend_worker.py | 7 ++-- python/ray/serve/endpoint_policy.py | 6 +-- python/ray/serve/examples/echo.py | 17 +------- python/ray/serve/handle.py | 55 ++++++++++---------------- python/ray/serve/http_proxy.py | 1 + python/ray/serve/request_params.py | 37 ++++++----------- python/ray/serve/router.py | 57 ++++++--------------------- python/ray/serve/tests/test_router.py | 14 +++---- python/ray/serve/utils.py | 6 +-- python/requirements.txt | 1 + python/setup.py | 2 +- 11 files changed, 67 insertions(+), 136 deletions(-) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 11cd5bdc7..e5e9bba2e 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -4,7 +4,6 @@ import inspect from collections.abc import Iterable from collections import defaultdict from itertools import groupby -from operator import attrgetter from typing import Union, List, Any, Callable, Type import time @@ -185,7 +184,7 @@ class RayServeWorker: asyncio.get_event_loop().create_task(self.main_loop()) def get_runner_method(self, request_item: Query) -> Callable: - method_name = request_item.call_method + method_name = request_item.metadata.call_method if not hasattr(self.callable, method_name): raise RayServeException("Backend doesn't have method {} " "which is specified in the request. " @@ -325,7 +324,9 @@ class RayServeWorker: all_evaluated_futures = [evaluated] chain_future(evaluated, query.async_future) else: - get_call_method = attrgetter("call_method") + get_call_method = ( + lambda query: query.metadata.call_method # noqa: E731 + ) sorted_batch = sorted(batch, key=get_call_method) for _, group in groupby(sorted_batch, key=get_call_method): group = list(group) diff --git a/python/ray/serve/endpoint_policy.py b/python/ray/serve/endpoint_policy.py index c345ba30c..86728e68f 100644 --- a/python/ray/serve/endpoint_policy.py +++ b/python/ray/serve/endpoint_policy.py @@ -77,10 +77,10 @@ class RandomEndpointPolicy(EndpointPolicy): assigned_backends = set() while len(endpoint_queue) > 0: query = endpoint_queue.pop() - if query.shard_key is None: + if query.metadata.shard_key is None: rstate = np.random else: - sha256_seed = sha256(query.shard_key.encode("utf-8")) + sha256_seed = sha256(query.metadata.shard_key.encode("utf-8")) seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32) # Note(simon): This constructor takes 100+us, maybe cache this? rstate = np.random.RandomState(seed) @@ -93,7 +93,7 @@ class RandomEndpointPolicy(EndpointPolicy): if len(shadow_backends) > 0: shadow_query = copy.copy(query) shadow_query.async_future = None - shadow_query.is_shadow_query = True + shadow_query.metadata.is_shadow_query = True for shadow_backend in shadow_backends: assigned_backends.add(shadow_backend) backend_queues[shadow_backend].appendleft(shadow_query) diff --git a/python/ray/serve/examples/echo.py b/python/ray/serve/examples/echo.py index eb2f37358..40fcdc08e 100644 --- a/python/ray/serve/examples/echo.py +++ b/python/ray/serve/examples/echo.py @@ -2,28 +2,15 @@ Example service that prints out http context. """ -import json import time -from pygments import formatters, highlight, lexers - import requests from ray import serve -def pformat_color_json(d): - """Use pygments to pretty format and colorize dictionary""" - formatted_json = json.dumps(d, sort_keys=True, indent=4) - - colorful_json = highlight(formatted_json, lexers.JsonLexer(), - formatters.TerminalFormatter()) - - return colorful_json - - def echo(flask_request): - return "hello " + flask_request.args.get("name", "serve!") + return ["hello " + flask_request.args.get("name", "serve!")] serve.init() @@ -33,7 +20,7 @@ serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo") while True: resp = requests.get("http://127.0.0.1:8000/echo").json() - print(pformat_color_json(resp)) + print(resp) print("...Sleeping for 2 seconds...") time.sleep(2) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 615f344dd..a665fa6d4 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,7 +1,8 @@ +from typing import Optional + import ray from ray import serve from ray.serve.context import TaskContext -from ray.serve.exceptions import RayServeException from ray.serve.request_params import RequestMetadata @@ -16,7 +17,6 @@ class RayServeHandle: >>> handle RayServeHandle( Endpoint="my_endpoint", - URL="...", Traffic=... ) >>> handle.remote(my_request_content) @@ -31,61 +31,48 @@ class RayServeHandle: self, router_handle, endpoint_name, + http_method=None, method_name=None, shard_key=None, ): self.router_handle = router_handle self.endpoint_name = endpoint_name + self.http_method = http_method self.method_name = method_name self.shard_key = shard_key def remote(self, *args, **kwargs): - if len(args) != 0: - raise RayServeException( + if len(args) > 0: + raise ValueError( "handle.remote must be invoked with keyword arguments.") - - method_name = self.method_name - if method_name is None: - method_name = "__call__" - - # create RequestMetadata instance - request_in_object = RequestMetadata( + request_metadata = RequestMetadata( self.endpoint_name, TaskContext.Python, - call_method=method_name, + http_method=self.http_method or "GET", + call_method=self.method_name or "__call__", shard_key=self.shard_key, ) return self.router_handle.enqueue_request.remote( - request_in_object, **kwargs) - - def options(self, method_name=None, shard_key=None): - - # Don't override existing method - if method_name is None and self.method_name is not None: - method_name = self.method_name - - if shard_key is None and self.shard_key is not None: - shard_key = self.shard_key + request_metadata, **kwargs) + def options(self, + method_name: Optional[str] = None, + http_method: Optional[str] = None, + shard_key: Optional[str] = None): return RayServeHandle( self.router_handle, self.endpoint_name, - method_name=method_name, - shard_key=shard_key, + # Don't override existing method + http_method=self.http_method or http_method, + method_name=self.method_name or method_name, + shard_key=self.shard_key or shard_key, ) - def get_traffic_policy(self): + def _get_traffic_policy(self): controller = serve.api._get_controller() return ray.get( controller.get_traffic_policy.remote(self.endpoint_name)) def __repr__(self): - return """ -RayServeHandle( - Endpoint="{endpoint_name}", - Traffic={traffic_policy} -) -""".format( - endpoint_name=self.endpoint_name, - traffic_policy=self.get_traffic_policy(), - ) + return (f"RayServeHandle(Endpoint='{self.endpoint_name}', " + f"Traffic={self._get_traffic_policy()})") diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 243b79a26..b1a7672ad 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -111,6 +111,7 @@ class HTTPProxy: request_metadata = RequestMetadata( endpoint_name, TaskContext.Web, + http_method=scope["method"].upper(), call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"), shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None), ) diff --git a/python/ray/serve/request_params.py b/python/ray/serve/request_params.py index 76da18e12..bd8265e47 100644 --- a/python/ray/serve/request_params.py +++ b/python/ray/serve/request_params.py @@ -1,30 +1,15 @@ -import ray.cloudpickle as pickle +from dataclasses import dataclass +from typing import Optional + +from ray.serve.context import TaskContext +@dataclass class RequestMetadata: - """ - Request arguments required for enqueuing a request to the endpoint queue. + endpoint: str + request_context: TaskContext - Args: - endpoint(str): A registered endpoint. - request_context(TaskContext): Context of a request. - """ - - def __init__(self, - endpoint, - request_context, - call_method="__call__", - shard_key=None): - - self.endpoint = endpoint - self.request_context = request_context - self.call_method = call_method - self.shard_key = shard_key - - def ray_serialize(self): - return pickle.dumps(self.__dict__) - - @staticmethod - def ray_deserialize(value): - kwargs = pickle.loads(value) - return RequestMetadata(**kwargs) + call_method: str = "__call__" + shard_key: Optional[str] = None + http_method: str = "GET" + is_shadow_query: bool = False diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 825182d7d..b058044c8 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -2,40 +2,31 @@ import asyncio import copy from collections import defaultdict, deque import time -from typing import DefaultDict, List +from typing import DefaultDict, List, Dict, Any, Optional import pickle +from dataclasses import dataclass from ray.exceptions import RayTaskError import ray from ray import serve from ray.experimental import metrics +from ray.serve.context import TaskContext from ray.serve.endpoint_policy import RandomEndpointPolicy +from ray.serve.request_params import RequestMetadata from ray.serve.utils import logger, chain_future REPORT_QUEUE_LENGTH_PERIOD_S = 1.0 +@dataclass class Query: - def __init__( - self, - request_args, - request_kwargs, - request_context, - call_method="__call__", - shard_key=None, - async_future=None, - is_shadow_query=False, - ): - self.request_args = request_args - self.request_kwargs = request_kwargs - self.request_context = request_context + args: List[Any] + kwargs: Dict[Any, Any] + context: TaskContext - self.async_future = async_future - - self.call_method = call_method - self.shard_key = shard_key - self.is_shadow_query = is_shadow_query + metadata: RequestMetadata + async_future: Optional[asyncio.Future] = None def __reduce__(self): return type(self).ray_deserialize, (self.ray_serialize(), ) @@ -56,27 +47,6 @@ class Query: return Query(**kwargs) -def _make_future_unwrapper(client_futures: List[asyncio.Future], - host_future: asyncio.Future): - """Distribute the result of host_future to each of client_future""" - for client_future in client_futures: - # Keep a reference to host future so the host future won't get - # garbage collected. - client_future.host_ref = host_future - - def unwrap_future(_): - result = host_future.result() - - if isinstance(result, list): - for client_future, result_item in zip(client_futures, result): - client_future.set_result(result_item) - else: # Result is an exception. - for client_future in client_futures: - client_future.set_result(result) - - return unwrap_future - - class Router: """A router that routes request to available workers.""" @@ -175,8 +145,7 @@ class Router: request_args, request_kwargs, request_context, - call_method=request_meta.call_method, - shard_key=request_meta.shard_key, + metadata=request_meta, async_future=asyncio.get_event_loop().create_future()) async with self.flush_lock: self.endpoint_queues[endpoint].appendleft(query) @@ -301,7 +270,7 @@ class Router: worker = self.replicas[backend_replica_tag] try: object_ref = worker.handle_request.remote(req.ray_serialize()) - if req.is_shadow_query: + if req.metadata.is_shadow_query: # No need to actually get the result, but we do need to wait # until the call completes to mark the worker idle. await asyncio.wait([object_ref]) @@ -351,7 +320,7 @@ class Router: self._do_query(backend, backend_replica_tag, request)) # For shadow queries, just ignore the result. - if not request.is_shadow_query: + if not request.metadata.is_shadow_query: chain_future(future, request.async_future) worker_queue.appendleft(backend_replica_tag) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index a7e06a1fd..67cdc6749 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -61,8 +61,8 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): # Make sure it's the right request got_work = await task_runner_mock_actor.get_recent_call.remote() - assert got_work.request_args[0] == 1 - assert got_work.request_kwargs == {} + assert got_work.args[0] == 1 + assert got_work.kwargs == {} async def test_alter_backend(serve_instance, task_runner_mock_actor): @@ -74,14 +74,14 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor): task_runner_mock_actor) await q.enqueue_request.remote(RequestMetadata("svc", None), 1) got_work = await task_runner_mock_actor.get_recent_call.remote() - assert got_work.request_args[0] == 1 + assert got_work.args[0] == 1 await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1})) await q.add_new_worker.remote("backend-alter-2", "replica-1", task_runner_mock_actor) await q.enqueue_request.remote(RequestMetadata("svc", None), 2) got_work = await task_runner_mock_actor.get_recent_call.remote() - assert got_work.request_args[0] == 2 + assert got_work.args[0] == 2 async def test_split_traffic_random(serve_instance, task_runner_mock_actor): @@ -106,7 +106,7 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor): await runner.get_recent_call.remote() for runner in (runner_1, runner_2) ] - assert [g.request_args[0] for g in got_work] == [1, 1] + assert [g.args[0] for g in got_work] == [1, 1] async def test_queue_remove_replicas(serve_instance): @@ -146,7 +146,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor): for i, runner in enumerate(runners): calls = await runner.get_all_calls.remote() for call in calls: - runner_shard_keys[i].add(call.request_args[0]) + runner_shard_keys[i].add(call.args[0]) await runner.clear_calls.remote() # Send queries with the same shard keys a second time. @@ -158,7 +158,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor): for i, runner in enumerate(runners): calls = await runner.get_all_calls.remote() for call in calls: - assert call.request_args[0] in runner_shard_keys[i] + assert call.args[0] in runner_shard_keys[i] async def test_router_use_max_concurrency(serve_instance): diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 4cf46e2a7..a563ac87a 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -27,9 +27,9 @@ ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 def parse_request_item(request_item): - if request_item.request_context == TaskContext.Web: + if request_item.metadata.request_context == TaskContext.Web: is_web_context = True - asgi_scope, body_bytes = request_item.request_args + asgi_scope, body_bytes = request_item.args flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes)) args = (flask_request, ) @@ -37,7 +37,7 @@ def parse_request_item(request_item): else: is_web_context = False args = (FakeFlaskRequest(), ) - kwargs = request_item.request_kwargs + kwargs = request_item.kwargs return args, kwargs, is_web_context diff --git a/python/requirements.txt b/python/requirements.txt index 86589b3bc..799fd78ed 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -37,6 +37,7 @@ scipy==1.4.1 tabulate tensorboardX uvicorn +dataclasses # Requirements for running tests blist; platform_system != "Windows" diff --git a/python/setup.py b/python/setup.py index bbe2a0be4..ced73786a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -110,7 +110,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on": # in this directory extras = { "debug": [], - "serve": ["uvicorn", "flask", "requests"], + "serve": ["uvicorn", "flask", "requests", "dataclasses"], "tune": ["tabulate", "tensorboardX", "pandas"] }