From 55b6c19d98bab621f5e96fda5271b6e27f953f01 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Fri, 4 Sep 2020 15:50:56 -0700 Subject: [PATCH] [Serve] Implement ServeHandle refactoring (#10527) --- doc/source/serve/advanced.rst | 63 +++++ doc/source/serve/package-ref.rst | 6 + doc/source/serve/tutorials/batch.rst | 32 +-- python/ray/serve/backend_worker.py | 88 ++----- python/ray/serve/context.py | 27 -- .../examples/doc/snippet_model_composition.py | 10 +- .../ray/serve/examples/doc/tutorial_batch.py | 15 +- python/ray/serve/examples/echo_full.py | 7 +- python/ray/serve/handle.py | 42 +-- python/ray/serve/http_proxy.py | 3 +- python/ray/serve/request_params.py | 15 -- python/ray/serve/router.py | 21 +- python/ray/serve/tests/conftest.py | 2 + python/ray/serve/tests/test_api.py | 48 ++-- python/ray/serve/tests/test_backend_worker.py | 240 +++++++----------- python/ray/serve/tests/test_config.py | 4 +- python/ray/serve/tests/test_failure.py | 14 +- python/ray/serve/tests/test_handle.py | 54 +++- python/ray/serve/tests/test_regression.py | 4 +- python/ray/serve/tests/test_router.py | 3 +- python/ray/serve/utils.py | 82 ++++-- 21 files changed, 396 insertions(+), 384 deletions(-) delete mode 100644 python/ray/serve/request_params.py diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 63a1bd37a..74762055c 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -348,3 +348,66 @@ You can follow the same pattern for other Starlette middlewares. Middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) ]) + + +.. _serve-handle-explainer: + +How do ``ServeHandle`` and ``ServeRequest`` work? +--------------------------------------------------- + +Ray Serve enables you to query models both from HTTP and Python. This feature +enables seamless :ref:`model composition`. You can +get a ``ServeHandle`` corresponding to an ``endpoint``, similar how you can +reach an endpoint through HTTP via a specific route. 
When you issue a request +to an endpoint through ``ServeHandle``, the request goes through the same code +path as an HTTP request would: choosing backends through :ref:`traffic +policies `, finding the next available replica, and +batching requests together. + +When the request arrives in the model, you can access the data similarly to how +you would with an HTTP request. Here are some examples of how ServeRequest mirrors Flask.Request: + +.. list-table:: + :header-rows: 1 + + * - HTTP + - ServeHandle + - | Request + | (Flask.Request and ServeRequest) + * - ``requests.get(..., headers={...})`` + - ``handle.options(http_headers={...})`` + - ``request.headers`` + * - ``requests.post(...)`` + - ``handle.options(http_method="POST")`` + - ``request.method`` + * - ``requests.get(..., json={...})`` + - ``handle.remote({...})`` + - ``request.json`` + * - ``requests.get(..., form={...})`` + - ``handle.remote({...})`` + - ``request.form`` + * - ``requests.get(..., params={"a":"b"})`` + - ``handle.remote(a="b")`` + - ``request.args`` + * - ``requests.get(..., data="long string")`` + - ``handle.remote("long string")`` + - ``request.data`` + * - ``N/A`` + - ``handle.remote(python_object)`` + - ``request.data`` + +.. note:: + + You might have noticed that the last row of the table shows that ServeRequest supports + passing Python objects through the handle. This is not possible in HTTP. If you + need to distinguish if the origin of the request is from Python or HTTP, you can do an ``isinstance`` + check: + + .. code-block:: python + + import flask + + if isinstance(request, flask.Request): + print("Request coming from web!") + elif isinstance(request, ServeRequest): + print("Request coming from Python!") diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 03c89da27..7041f3b40 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -20,6 +20,12 @@ Handle API .. 
autoclass:: ray.serve.handle.RayServeHandle :members: remote, options +When calling from Python, the backend implementation will receive ``ServeRequest`` +objects instead of Flask requests. + +.. autoclass:: ray.serve.utils.ServeRequest + :members: + Batching Requests ----------------- .. autofunction:: ray.serve.accept_batch diff --git a/doc/source/serve/tutorials/batch.rst b/doc/source/serve/tutorials/batch.rst index 152d6d192..d8fc13a3d 100644 --- a/doc/source/serve/tutorials/batch.rst +++ b/doc/source/serve/tutorials/batch.rst @@ -30,28 +30,24 @@ You can use the ``@serve.accept_batch`` decorator to annotate a function or a cl This annotation is needed because batched backends have different APIs compared to single request backends. In a batched backend, the inputs are a list of values. -For single query backend, the input types are single flask request or Python -argument: +For single query backend, the input type is a single Flask request or +:mod:`ServeRequest `: .. code-block:: python def single_request( - flask_request: Flask.Request, - *, - python_arg: int = 0 + request: Union[Flask.Request, ServeRequest], ): pass -For batched backend, the inputs types are converted to list of their original +For batched backends, the input types are converted to list of their original types: .. code-block:: python @serve.accept_batch def batched_request( - flask_request: List[Flask.Request], - *, - python_arg: List[int] + request: List[Union[Flask.Request, ServeRequest]], ): pass @@ -70,6 +66,8 @@ configuration option limits the maximum possible batch size send to the backend. Ray Serve performs *opportunistic batching*. When a worker is free to evaluate the next batch, Ray Serve will look at the pending queries and take ``max(number_of_pending_queries, max_batch_size)`` queries to form a batch. + You can provide :mod:`batch_wait_timeout ` to override + this behavior to wait for a full batch to arrive before executing (under a timeout). .. 
literalinclude:: ../../../../python/ray/serve/examples/doc/tutorial_batch.py :start-after: __doc_deploy_begin__ @@ -85,17 +83,9 @@ Ray Serve was able to evaluate them in batches. :end-before: __doc_query_end__ What if you want to evaluate a whole batch in Python? Ray Serve allows you to send -queries via the Python API. You can use the boolean value ``serve.context.web`` to -distinguish the origin of the queries. A batch of queries can either come from -the web server or the Python API. Ray Serve will guarantee there won't be queries -with mixed origins. - -When the batch of requests comes from the web API, Ray Serve will fill the first -argument ``flask_requests`` with a list of ``Flask.Request`` objects and set -``serve.context.web = True``. When the batch of requests comes from the Python API, -Ray Serve will fill ``flask_requests`` arguments with placeholders, and directly inject -Python objects into the keyword arguments. In this case, the ``numbers`` argument -will be a list of Python integers. +queries via the Python API. A batch of queries can either come from the web server +or the Python API. Requests coming from the Python API will have a similar API +to Flask.Request. See more on the API :ref:`here <serve-handle-explainer>`. .. literalinclude:: ../../../../python/ray/serve/examples/doc/tutorial_batch.py :start-after: __doc_define_servable_v1_begin__ @@ -110,7 +100,7 @@ Let's deploy the new version to the same endpoint. Don't forget to set To query the backend via Python API, we can use ``serve.get_handle`` to receive a handle to the corresponding "endpoint". To enqueue a query, you can call -``handle.remote(argument_name=argument_value)``. This call returns immediately +``handle.remote(data, argument_name=argument_value)``. This call returns immediately with a :ref:`Ray ObjectRef`. You can call `ray.get` to retrieve the result. 
diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 75db22bcd..c8e188f15 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -2,7 +2,6 @@ import asyncio import traceback import inspect from collections.abc import Iterable -from collections import defaultdict from itertools import groupby from typing import Union, List, Any, Callable, Type import time @@ -10,8 +9,6 @@ import time import ray from ray.async_compat import sync_to_async -from ray.serve import context as serve_context -from ray.serve.context import FakeFlaskRequest from ray.serve.utils import (parse_request_item, _get_logger, chain_future, unpack_future) from ray.serve.exceptions import RayServeException @@ -212,43 +209,18 @@ class RayServeWorker: return self.callable return getattr(self.callable, method_name) - def has_positional_args(self, f: Callable) -> bool: - # NOTE: - # In the case of simple functions, not actors, the f will be - # function.__call__, but we need to inspect the function itself. - if self.is_function: - f = self.callable - - signature = inspect.signature(f) - for param in signature.parameters.values(): - if (param.kind == param.POSITIONAL_OR_KEYWORD - and param.default is param.empty): - return True - return False - - def _reset_context(self) -> None: - # NOTE(simon): context management won't work in async mode because - # many concurrent queries might be running at the same time. 
- serve_context.web = None - serve_context.batch_size = None - async def invoke_single(self, request_item: Query) -> Any: - args, kwargs, is_web_context = parse_request_item(request_item) - serve_context.web = is_web_context - - method_to_call = self.get_runner_method(request_item) - args = args if self.has_positional_args(method_to_call) else [] - method_to_call = ensure_async(method_to_call) + method_to_call = ensure_async(self.get_runner_method(request_item)) + arg = parse_request_item(request_item) start = time.time() try: - result = await method_to_call(*args, **kwargs) + result = await method_to_call(arg) self.request_counter.record(1, {"backend": self.backend_tag}) except Exception as e: result = wrap_to_ray_error(e) self.error_counter.record(1, {"backend": self.backend_tag}) - finally: - self._reset_context() + self.processing_latency_tracker.record( (time.time() - start) * 1000, { "backend": self.backend_tag, @@ -259,56 +231,28 @@ class RayServeWorker: return result async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]: - arg_list = [] - kwargs_list = defaultdict(list) - context_flags = set() - batch_size = len(request_item_list) + args = [] call_methods = set() + batch_size = len(request_item_list) + # Construct the batch of requests for item in request_item_list: - args, kwargs, is_web_context = parse_request_item(item) - context_flags.add(is_web_context) - - call_method = self.get_runner_method(item) - call_methods.add(call_method) - - if is_web_context: - # Python context only have kwargs - flask_request = args[0] - arg_list.append(flask_request) - else: - # Web context only have one positional argument - for k, v in kwargs.items(): - kwargs_list[k].append(v) - - # Set the flask request as a list to conform - # with batching semantics: when in batching - # mode, each argument is turned into list. 
- if self.has_positional_args(call_method): - arg_list.append(FakeFlaskRequest()) + args.append(parse_request_item(item)) + call_methods.add(self.get_runner_method(item)) timing_start = time.time() try: - # Check mixing of query context (unified context needed). - if len(context_flags) != 1: - raise RayServeException( - "Batched queries contain mixed context. Please only send " - "the same type of requests in batching mode.") - serve_context.web = context_flags.pop() - if len(call_methods) != 1: raise RayServeException( - "Queries contain mixed calling methods. Please only send " - "the same type of requests in batching mode.") - call_method = ensure_async(call_methods.pop()) - - serve_context.batch_size = batch_size - # Flask requests are passed to __call__ as a list - arg_list = [arg_list] + f"Queries contain mixed calling methods: {call_methods}. " + "Please only send the same type of requests in batching " + "mode.") self.request_counter.record(batch_size, {"backend": self.backend_tag}) - result_list = await call_method(*arg_list, **kwargs_list) + + call_method = ensure_async(call_methods.pop()) + result_list = await call_method(args) if not isinstance(result_list, Iterable) or isinstance( result_list, (dict, set)): @@ -328,11 +272,9 @@ class RayServeWorker: "results with length equal to the batch size" ".".format(batch_size, len(result_list))) raise RayServeException(error_message) - self._reset_context() except Exception as e: wrapped_exception = wrap_to_ray_error(e) self.error_counter.record(1, {"backend": self.backend_tag}) - self._reset_context() result_list = [wrapped_exception for _ in range(batch_size)] self.processing_latency_tracker.record( diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index 1069bbe21..5d21c40e7 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -1,34 +1,7 @@ from enum import IntEnum -from ray.serve.exceptions import RayServeException - class TaskContext(IntEnum): """TaskContext 
constants for queue.enqueue method""" Web = 1 Python = 2 - - -# Global variable will be modified in worker -# web == True: currrently processing a request from web server -# web == False: currently processing a request from python -web = False - -# batching information in serve context -# batch_size == None : the backend doesn't support batching -# batch_size(int) : the number of elements of input list -batch_size = None - -_not_in_web_context_error = """ -Accessing the request object outside of the web context. Please use -"serve.context.web" to determine when the function is called within -a web context. -""" - - -class FakeFlaskRequest: - def __getattribute__(self, name): - raise RayServeException(_not_in_web_context_error) - - def __setattr__(self, name, value): - raise RayServeException(_not_in_web_context_error) diff --git a/python/ray/serve/examples/doc/snippet_model_composition.py b/python/ray/serve/examples/doc/snippet_model_composition.py index 739228584..25705a7f8 100644 --- a/python/ray/serve/examples/doc/snippet_model_composition.py +++ b/python/ray/serve/examples/doc/snippet_model_composition.py @@ -15,14 +15,14 @@ client = serve.start() # Let's define two models that just print out the data they received. 
-def model_one(_unused_flask_request, data=None): - print("Model 1 called with data ", data) +def model_one(request): + print("Model 1 called with data ", request.args.get("data")) return random() -def model_two(_unused_flask_request, data=None): - print("Model 2 called with data ", data) - return data +def model_two(request): + print("Model 2 called with data ", request.args.get("data")) + return request.args.get("data") class ComposedModel: diff --git a/python/ray/serve/examples/doc/tutorial_batch.py b/python/ray/serve/examples/doc/tutorial_batch.py index dfe24931f..271e69afa 100644 --- a/python/ray/serve/examples/doc/tutorial_batch.py +++ b/python/ray/serve/examples/doc/tutorial_batch.py @@ -57,17 +57,8 @@ print("Result returned:", results) # __doc_define_servable_v1_begin__ @serve.accept_batch -def batch_adder_v1(flask_requests: List, *, numbers: List = []): - # Depending on request context, we process the input data differently. - print("Current context is", "web" if serve.context.web else "python") - if serve.context.web: - # If the requests come from web request, we parse the flask request - # to numbers - numbers = [int(request.args["number"]) for request in flask_requests] - else: - # Otherwise, we are processing requests invoked directly from Python. - numbers = numbers - +def batch_adder_v1(requests: List): + numbers = [int(request.args["number"]) for request in requests] input_array = np.array(numbers) print("Our input array has shape:", input_array.shape) # Sleep for 200ms, this could be performing CPU intensive computation @@ -97,7 +88,7 @@ input_batch = list(range(9)) print("Input batch is", input_batch) # Input batch is [0, 1, 2, 3, 4, 5, 6, 7, 8] -result_batch = ray.get([handle.remote(numbers=i) for i in input_batch]) +result_batch = ray.get([handle.remote(number=i) for i in input_batch]) # Output # (pid=...) Current context is python # (pid=...) 
Our input array has shape: (1,) diff --git a/python/ray/serve/examples/echo_full.py b/python/ray/serve/examples/echo_full.py index 391015fad..9639f1a25 100644 --- a/python/ray/serve/examples/echo_full.py +++ b/python/ray/serve/examples/echo_full.py @@ -12,9 +12,8 @@ client = serve.start() # a backend can be a function or class. # it can be made to be invoked from web as well as python. -def echo_v1(flask_request, response="hello from python!"): - if serve.context.web: - response = flask_request.url +def echo_v1(flask_request): + response = flask_request.args.get("response", "web") return response @@ -46,7 +45,7 @@ client.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) # Observe requests are now split between two backends. for _ in range(10): print(requests.get("http://127.0.0.1:8000/echo").text) - time.sleep(0.5) + time.sleep(0.2) # You can also change number of replicas for each backend independently. client.update_backend_config("echo:v1", {"num_replicas": 2}) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index de0cacaef..7943d7faa 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,7 +1,7 @@ -from typing import Optional +from typing import Optional, Dict, Any, Union from ray.serve.context import TaskContext -from ray.serve.request_params import RequestMetadata +from ray.serve.router import RequestMetadata class RayServeHandle: @@ -29,42 +29,53 @@ class RayServeHandle: self, router_handle, endpoint_name, - http_method=None, + *, method_name=None, shard_key=None, + http_method=None, + http_headers=None, ): self.router_handle = router_handle self.endpoint_name = endpoint_name - self.http_method = http_method + self.method_name = method_name self.shard_key = shard_key + self.http_method = http_method + self.http_headers = http_headers - def remote(self, *args, **kwargs): - """Invoke a request on the endpoint. 
+ def remote(self, request_data: Optional[Union[Dict, Any]] = None, + **kwargs): + """Issue an asynchronous request to the endpoint. - Returns a Ray ObjectRef whose result can be waited for or retrieved - using `ray.wait` or `ray.get`, respectively. + Returns a Ray ObjectRef whose results can be waited for or retrieved + using ray.wait or ray.get, respectively. Returns: ray.ObjectRef + Input: + request_data(dict, Any): If it's a dictionary, the data will be + available in ``request.json()`` or ``request.form()``. Otherwise, + it will be available in ``request.data``. + ``**kwargs``: All keyword arguments will be available in + ``request.args``. """ - if len(args) > 0: - raise ValueError( - "handle.remote must be invoked with keyword arguments.") request_metadata = RequestMetadata( self.endpoint_name, TaskContext.Python, - http_method=self.http_method or "GET", call_method=self.method_name or "__call__", shard_key=self.shard_key, + http_method=self.http_method or "GET", + http_headers=self.http_headers or dict(), ) return self.router_handle.enqueue_request.remote( - request_metadata, **kwargs) + request_metadata, request_data, **kwargs) def options(self, method_name: Optional[str] = None, + *, + shard_key: Optional[str] = None, http_method: Optional[str] = None, - shard_key: Optional[str] = None): + http_headers: Optional[Dict[str, str]] = None): """Set options for this handle. 
Args: @@ -77,9 +88,10 @@ class RayServeHandle: self.router_handle, self.endpoint_name, # Don't override existing method - http_method=self.http_method or http_method, method_name=self.method_name or method_name, shard_key=self.shard_key or shard_key, + http_method=self.http_method or http_method, + http_headers=self.http_headers or http_headers, ) def __repr__(self): diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index b275df65d..2d0e2265f 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -8,9 +8,8 @@ import ray from ray.exceptions import RayTaskError from ray.serve.context import TaskContext from ray.experimental import metrics -from ray.serve.request_params import RequestMetadata from ray.serve.http_util import Response -from ray.serve.router import Router +from ray.serve.router import Router, RequestMetadata # The maximum number of times to retry a request due to actor failure. # TODO(edoakes): this should probably be configurable. 
diff --git a/python/ray/serve/request_params.py b/python/ray/serve/request_params.py deleted file mode 100644 index bd8265e47..000000000 --- a/python/ray/serve/request_params.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -from ray.serve.context import TaskContext - - -@dataclass -class RequestMetadata: - endpoint: str - request_context: TaskContext - - call_method: str = "__call__" - shard_key: Optional[str] = None - http_method: str = "GET" - is_shadow_query: bool = False diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index ba20e1db5..aac4b8eac 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -4,7 +4,7 @@ from collections import defaultdict, deque import time from typing import DefaultDict, List, Dict, Any, Optional import pickle -from dataclasses import dataclass +from dataclasses import dataclass, field from ray.exceptions import RayTaskError @@ -12,12 +12,29 @@ import ray from ray.experimental import metrics from ray.serve.context import TaskContext from ray.serve.endpoint_policy import RandomEndpointPolicy -from ray.serve.request_params import RequestMetadata from ray.serve.utils import logger, chain_future REPORT_QUEUE_LENGTH_PERIOD_S = 1.0 +@dataclass +class RequestMetadata: + endpoint: str + request_context: TaskContext + + call_method: str = "__call__" + shard_key: Optional[str] = None + + http_method: str = "GET" + http_headers: Dict[str, str] = field(default_factory=dict) + + is_shadow_query: bool = False + + def __post_init__(self): + self.http_headers.setdefault("X-Serve-Call-Method", self.call_method) + self.http_headers.setdefault("X-Serve-Shard-Key", self.shard_key) + + @dataclass class Query: args: List[Any] diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index fc1d53ae7..be3e39019 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -11,6 +11,8 @@ if 
os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False): @pytest.fixture(scope="session") def _shared_serve_instance(): + # Uncomment the line below to turn on debug log for tests. + # os.environ["SERVE_LOG_DEBUG"] = "1" ray.init( num_cpus=36, _metrics_export_port=9999, diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 3759f32aa..26a88102b 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -242,9 +242,9 @@ def test_batching(serve_instance): self.count = 0 @serve.accept_batch - def __call__(self, flask_request, temp=None): + def __call__(self, requests): self.count += 1 - batch_size = serve.context.batch_size + batch_size = len(requests) return [self.count] * batch_size # set the max batch size @@ -281,10 +281,9 @@ def test_batching_legacy(serve_instance): self.count = 0 @serve.accept_batch - def __call__(self, flask_request, temp=None): + def __call__(self, request): self.count += 1 - batch_size = serve.context.batch_size - return [self.count] * batch_size + return [self.count] * len(request) # set the max batch size client.create_backend( @@ -305,7 +304,7 @@ def test_batching_legacy(serve_instance): future_list = [] handle = client.get_handle("counter1") for _ in range(20): - f = handle.remote(temp=1) + f = handle.remote() future_list.append(f) counter_result = ray.get(future_list) @@ -323,9 +322,8 @@ def test_batching_exception(serve_instance): self.count = 0 @serve.accept_batch - def __call__(self, flask_request, temp=None): - batch_size = serve.context.batch_size - return batch_size + def __call__(self, requests): + return len(requests) # set the max batch size client.create_backend( @@ -369,9 +367,8 @@ def test_updating_config(serve_instance): self.count = 0 @serve.accept_batch - def __call__(self, flask_request, temp=None): - batch_size = serve.context.batch_size - return [1] * batch_size + def __call__(self, request): + return [1] * len(request) client.create_backend( 
"bsimple:v1", @@ -405,9 +402,8 @@ def test_updating_config_legacy(serve_instance): self.count = 0 @serve.accept_batch - def __call__(self, flask_request, temp=None): - batch_size = serve.context.batch_size - return [1] * batch_size + def __call__(self, request): + return [1] * len(request) client.create_backend( "bsimple:v1", @@ -439,7 +435,7 @@ def test_updating_config_legacy(serve_instance): def test_delete_backend(serve_instance): client = serve_instance - def function(): + def function(_): return "hello" client.create_backend("delete:v1", function) @@ -466,7 +462,7 @@ def test_delete_backend(serve_instance): with pytest.raises(ValueError): client.set_traffic("delete_backend", {"delete:v1": 1.0}) - def function2(): + def function2(_): return "olleh" # Check that we can now reuse the previously delete backend's tag. @@ -480,7 +476,7 @@ def test_delete_backend(serve_instance): def test_delete_endpoint(serve_instance, route): client = serve_instance - def function(): + def function(_): return "hello" backend_name = "delete-endpoint:v1" @@ -521,7 +517,7 @@ def test_shard_key(serve_instance, route): traffic_dict = {} for i in range(num_backends): - def function(): + def function(_): return i backend_name = "backend-split-" + str(i) @@ -560,7 +556,7 @@ def test_multiple_instances(): client1 = serve.start(http_port=8001) - def function(): + def function(_): return "hello1" client1.create_backend(backend, function) @@ -572,7 +568,7 @@ def test_multiple_instances(): # the same names and check that they don't collide. 
client2 = serve.start(http_port=8002) - def function(): + def function(_): return "hello2" client2.create_backend(backend, function) @@ -889,19 +885,19 @@ def test_shadow_traffic(serve_instance): counter = RequestCounter.remote() - def f(): + def f(_): ray.get(counter.record.remote("backend1")) return "hello" - def f_shadow_1(): + def f_shadow_1(_): ray.get(counter.record.remote("backend2")) return "oops" - def f_shadow_2(): + def f_shadow_2(_): ray.get(counter.record.remote("backend3")) return "oops" - def f_shadow_3(): + def f_shadow_3(_): ray.get(counter.record.remote("backend4")) return "oops" @@ -950,7 +946,7 @@ def test_connect(serve_instance): # Check that you can call serve.connect() from within a backend for both # detached and non-detached instances. - def connect_in_backend(): + def connect_in_backend(_): client = serve.connect() client.create_backend("backend-ception", connect_in_backend) return client._controller_name diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 9ca336d8f..b0bd6763a 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -8,8 +8,7 @@ from ray import serve import ray.serve.context as context from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error from ray.serve.controller import TrafficPolicy -from ray.serve.request_params import RequestMetadata -from ray.serve.router import Router +from ray.serve.router import Router, RequestMetadata from ray.serve.config import BackendConfig, BackendMetadata from ray.serve.exceptions import RayServeException @@ -45,85 +44,66 @@ def setup_worker(name, return worker +async def add_servable_to_router(servable, router, **kwargs): + worker = setup_worker("backend", servable, **kwargs) + await router.add_new_worker.remote("backend", "replica", worker) + await router.set_traffic.remote("endpoint", TrafficPolicy({ + "backend": 1.0 + })) + + if "backend_config" in 
kwargs: + await router.set_backend_config.remote("backend", + kwargs["backend_config"]) + return worker + + +def make_request_param(call_method="__call__"): + return RequestMetadata( + "endpoint", context.TaskContext.Python, call_method=call_method) + + +@pytest.fixture +def router(serve_instance): + q = ray.remote(Router).remote() + ray.get(q.setup.remote("", serve_instance._controller_name)) + yield q + ray.kill(q) + + async def test_runner_wraps_error(): wrapped = wrap_to_ray_error(Exception()) assert isinstance(wrapped, ray.exceptions.RayTaskError) -async def test_runner_actor(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) +async def test_servable_function(serve_instance, router): + def echo(request): + return request.args["i"] - def echo(flask_request, i=None): - return i - - CONSUMER_NAME = "runner" - PRODUCER_NAME = "prod" - - worker = setup_worker(CONSUMER_NAME, echo) - await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - - q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) + _ = await add_servable_to_router(echo, router) for query in [333, 444, 555]: - query_param = RequestMetadata(PRODUCER_NAME, - context.TaskContext.Python) - result = await q.enqueue_request.remote(query_param, i=query) + query_param = make_request_param() + result = await router.enqueue_request.remote(query_param, i=query) assert result == query -async def test_ray_serve_mixin(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) - - CONSUMER_NAME = "runner-cls" - PRODUCER_NAME = "prod-cls" - +async def test_servable_class(serve_instance, router): class MyAdder: def __init__(self, inc): self.increment = inc - def __call__(self, flask_request, i=None): - return i + self.increment + def __call__(self, request): + return request.args["i"] + self.increment - worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, )) - await 
q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - - q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) + _ = await add_servable_to_router(MyAdder, router, init_args=(3, )) for query in [333, 444, 555]: - query_param = RequestMetadata(PRODUCER_NAME, - context.TaskContext.Python) - result = await q.enqueue_request.remote(query_param, i=query) + query_param = make_request_param() + result = await router.enqueue_request.remote(query_param, i=query) assert result == query + 3 -async def test_task_runner_check_context(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) - - def echo(flask_request, i=None): - # Accessing the flask_request without web context should throw. - return flask_request.args["i"] - - CONSUMER_NAME = "runner" - PRODUCER_NAME = "producer" - - worker = setup_worker(CONSUMER_NAME, echo) - await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - - q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) - query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) - result_oid = q.enqueue_request.remote(query_param, i=42) - - with pytest.raises(ray.exceptions.RayTaskError): - await result_oid - - -async def test_task_runner_custom_method_single(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) - +async def test_task_runner_custom_method_single(serve_instance, router): class NonBatcher: def a(self, _): return "a" @@ -131,129 +111,97 @@ async def test_task_runner_custom_method_single(serve_instance): def b(self, _): return "b" - CONSUMER_NAME = "runner" - PRODUCER_NAME = "producer" + _ = await add_servable_to_router(NonBatcher, router) - worker = setup_worker(CONSUMER_NAME, NonBatcher) - await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - - q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) - - query_param = RequestMetadata( - 
PRODUCER_NAME, context.TaskContext.Python, call_method="a") - a_result = await q.enqueue_request.remote(query_param) + query_param = make_request_param("a") + a_result = await router.enqueue_request.remote(query_param) assert a_result == "a" - query_param = RequestMetadata( - PRODUCER_NAME, context.TaskContext.Python, call_method="b") - b_result = await q.enqueue_request.remote(query_param) + query_param = make_request_param("b") + b_result = await router.enqueue_request.remote(query_param) assert b_result == "b" - query_param = RequestMetadata( - PRODUCER_NAME, context.TaskContext.Python, call_method="non_exist") + query_param = make_request_param("non_exist") with pytest.raises(ray.exceptions.RayTaskError): - await q.enqueue_request.remote(query_param) + await router.enqueue_request.remote(query_param) -async def test_task_runner_custom_method_batch(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) - +async def test_task_runner_custom_method_batch(serve_instance, router): @serve.accept_batch class Batcher: - def a(self, _): - return ["a-{}".format(i) for i in range(serve.context.batch_size)] + def a(self, requests): + return ["a-{}".format(i) for i in range(len(requests))] - def b(self, _): - return ["b-{}".format(i) for i in range(serve.context.batch_size)] - - def error_different_size(self, _): - return [""] * (serve.context.batch_size * 2) - - def error_non_iterable(self, _): - return 42 - - def return_np_array(self, _): - return np.array([1] * serve.context.batch_size).astype(np.int32) - - CONSUMER_NAME = "runner" - PRODUCER_NAME = "producer" + def b(self, requests): + return ["b-{}".format(i) for i in range(len(requests))] backend_config = BackendConfig( max_batch_size=4, - batch_wait_timeout=2, + batch_wait_timeout=10, internal_metadata=BackendMetadata(accepts_batches=True)) - worker = setup_worker( - CONSUMER_NAME, Batcher, backend_config=backend_config) - - await 
q.set_traffic.remote(PRODUCER_NAME, - TrafficPolicy({ - CONSUMER_NAME: 1.0 - })) - await q.set_backend_config.remote(CONSUMER_NAME, backend_config) - - def make_request_param(call_method): - return RequestMetadata( - PRODUCER_NAME, context.TaskContext.Python, call_method=call_method) + _ = await add_servable_to_router( + Batcher, router, backend_config=backend_config) a_query_param = make_request_param("a") b_query_param = make_request_param("b") - futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)] - futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)] - - await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) + futures = [router.enqueue_request.remote(a_query_param) for _ in range(2)] + futures += [router.enqueue_request.remote(b_query_param) for _ in range(2)] gathered = await asyncio.gather(*futures) assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"} + +async def test_servable_batch_error(serve_instance, router): + @serve.accept_batch + class ErrorBatcher: + def error_different_size(self, requests): + return [""] * (len(requests) + 10) + + def error_non_iterable(self, _): + return 42 + + def return_np_array(self, requests): + return np.array([1] * len(requests)).astype(np.int32) + + backend_config = BackendConfig( + max_batch_size=4, + internal_metadata=BackendMetadata(accepts_batches=True)) + _ = await add_servable_to_router( + ErrorBatcher, router, backend_config=backend_config) + with pytest.raises(RayServeException, match="doesn't preserve batch size"): different_size = make_request_param("error_different_size") - await q.enqueue_request.remote(different_size) + await router.enqueue_request.remote(different_size) with pytest.raises(RayServeException, match="iterable"): non_iterable = make_request_param("error_non_iterable") - await q.enqueue_request.remote(non_iterable) + await router.enqueue_request.remote(non_iterable) np_array = make_request_param("return_np_array") - result_np_value = await 
q.enqueue_request.remote(np_array) + result_np_value = await router.enqueue_request.remote(np_array) assert isinstance(result_np_value, np.int32) -async def test_task_runner_perform_batch(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) - - def batcher(*args, **kwargs): - return [serve.context.batch_size] * serve.context.batch_size - - CONSUMER_NAME = "runner" - PRODUCER_NAME = "producer" +async def test_task_runner_perform_batch(serve_instance, router): + def batcher(requests): + batch_size = len(requests) + return [batch_size] * batch_size config = BackendConfig( max_batch_size=2, batch_wait_timeout=10, internal_metadata=BackendMetadata(accepts_batches=True)) - worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config) - await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - await q.set_backend_config.remote(CONSUMER_NAME, config) - await q.set_traffic.remote(PRODUCER_NAME, - TrafficPolicy({ - CONSUMER_NAME: 1.0 - })) - - query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) + _ = await add_servable_to_router(batcher, router, backend_config=config) + query_param = make_request_param() my_batch_sizes = await asyncio.gather( - *[q.enqueue_request.remote(query_param) for _ in range(3)]) + *[router.enqueue_request.remote(query_param) for _ in range(3)]) assert my_batch_sizes == [2, 2, 1] -async def test_task_runner_perform_async(serve_instance): - q = ray.remote(Router).remote() - await q.setup.remote("", serve_instance._controller_name) - +async def test_task_runner_perform_async(serve_instance, router): @ray.remote class Barrier: def __init__(self, release_on): @@ -274,22 +222,18 @@ async def test_task_runner_perform_async(serve_instance): await barrier.wait.remote() return "done!" 
- CONSUMER_NAME = "runner" - PRODUCER_NAME = "producer" - config = BackendConfig( max_concurrent_queries=10, internal_metadata=BackendMetadata(is_blocking=False)) - worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config) - await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - await q.set_backend_config.remote(CONSUMER_NAME, config) - q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) + _ = await add_servable_to_router( + wait_and_go, router, backend_config=config) - query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) + query_param = make_request_param() done, not_done = await asyncio.wait( - [q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10) + [router.enqueue_request.remote(query_param) for _ in range(10)], + timeout=10) assert len(done) == 10 for item in done: await item == "done!" diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 00dc6fa2f..3b07148c9 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -81,11 +81,11 @@ def test_replica_config_validation(): def __call__(self): pass - def function(): + def function(_): pass @serve.accept_batch - def batch_function(): + def batch_function(_): pass ReplicaConfig(Class) diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 79fd0ea7e..92b89c7d5 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -23,7 +23,7 @@ def request_with_retries(endpoint, timeout=30): def test_controller_failure(serve_instance): client = serve_instance - def function(): + def function(_): return "hello1" client.create_backend("controller_failure:v1", function) @@ -45,7 +45,7 @@ def test_controller_failure(serve_instance): response = request_with_retries("/controller_failure", timeout=30) assert response.text == "hello1" - def function(): + def function(_): return "hello2" 
ray.kill(client._controller, no_restart=False) @@ -57,7 +57,7 @@ def test_controller_failure(serve_instance): response = request_with_retries("/controller_failure", timeout=30) assert response.text == "hello2" - def function(): + def function(_): return "hello3" ray.kill(client._controller, no_restart=False) @@ -85,7 +85,7 @@ def _kill_routers(client): def test_http_proxy_failure(serve_instance): client = serve_instance - def function(): + def function(_): return "hello1" client.create_backend("proxy_failure:v1", function) @@ -100,7 +100,7 @@ def test_http_proxy_failure(serve_instance): _kill_routers(client) - def function(): + def function(_): return "hello2" client.create_backend("proxy_failure:v2", function) @@ -213,7 +213,7 @@ def test_worker_replica_failure(serve_instance): def test_create_backend_idempotent(serve_instance): client = serve_instance - def f(): + def f(_): return "hello" controller = client._controller @@ -236,7 +236,7 @@ def test_create_backend_idempotent(serve_instance): def test_create_endpoint_idempotent(serve_instance): client = serve_instance - def f(): + def f(_): return "hello" client.create_backend("my_backend", f) diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index d070b6a8e..c0be51a06 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -1,8 +1,8 @@ +import requests + import ray from ray import serve -import requests - def test_handle_in_endpoint(serve_instance): client = serve_instance @@ -16,7 +16,7 @@ def test_handle_in_endpoint(serve_instance): client = serve.connect() self.handle = client.get_handle("endpoint1") - def __call__(self): + def __call__(self, _): return ray.get(self.handle.remote()) client.create_backend("endpoint1:v0", Endpoint1) @@ -36,6 +36,54 @@ def test_handle_in_endpoint(serve_instance): assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello" +def test_handle_http_args(serve_instance): + client = 
serve_instance + + class Endpoint: + def __call__(self, request): + return { + "args": dict(request.args), + "headers": dict(request.headers), + "method": request.method, + "json": request.json + } + + client.create_backend("backend", Endpoint) + client.create_endpoint( + "endpoint", backend="backend", route="/endpoint", methods=["POST"]) + + ground_truth = { + "args": { + "arg1": "1", + "arg2": "2" + }, + "headers": { + "X-Custom-Header": "value" + }, + "method": "POST", + "json": { + "json_key": "json_val" + } + } + + resp_web = requests.post( + "http://127.0.0.1:8000/endpoint?arg1=1&arg2=2", + headers=ground_truth["headers"], + json=ground_truth["json"]).json() + + handle = client.get_handle("endpoint") + resp_handle = ray.get( + handle.options( + http_method=ground_truth["method"], + http_headers=ground_truth["headers"]).remote( + ground_truth["json"], **ground_truth["args"])) + + for resp in [resp_web, resp_handle]: + for field in ["args", "method", "json"]: + assert resp[field] == ground_truth[field] + resp["headers"]["X-Custom-Header"] == "value" + + if __name__ == "__main__": import sys import pytest diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py index 3b4be4ee6..010478211 100644 --- a/python/ray/serve/tests/test_regression.py +++ b/python/ray/serve/tests/test_regression.py @@ -11,8 +11,8 @@ def test_np_in_composed_model(serve_instance): # AttributeError: 'bytes' object has no attribute 'readonly' # in cloudpickle _from_numpy_buffer - def sum_model(_request, data=None): - return np.sum(data) + def sum_model(request): + return np.sum(request.args["data"]) class ComposedModel: def __init__(self): diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index cc896425e..e05c38f91 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -4,8 +4,7 @@ import pytest import ray from ray.serve.controller import TrafficPolicy -from 
ray.serve.router import Router, Query -from ray.serve.request_params import RequestMetadata +from ray.serve.router import Router, Query, RequestMetadata from ray.serve.utils import get_random_letters from ray.test_utils import SignalActor from ray.serve.config import BackendConfig diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 7586fae96..1699d41de 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -9,37 +9,83 @@ import time from typing import List import io import os +from ray.serve.exceptions import RayServeException import requests +import numpy as np +import pydantic +from werkzeug.datastructures import ImmutableMultiDict import ray from ray.serve.constants import HTTP_PROXY_TIMEOUT -from ray.serve.context import FakeFlaskRequest, TaskContext +from ray.serve.context import TaskContext from ray.serve.http_util import build_flask_request -import numpy as np - -try: - import pydantic -except ImportError: - pydantic = None ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 +class ServeRequest: + """The request object used in Python context. + + ServeRequest is built to have similar API as Flask.Request. You only need + to write your model serving code once; it can be queried by both HTTP and + Python. + """ + + def __init__(self, data, kwargs, headers, method): + self._data = data + self._kwargs = kwargs + self._headers = headers + self._method = method + + @property + def headers(self): + """The HTTP headers from ``handle.option(http_headers=...)``.""" + return self._headers + + @property + def method(self): + """The HTTP method data from ``handle.option(http_method=...)``.""" + return self._method + + @property + def args(self): + """The keyword arguments from ``handle.remote(**kwargs)``.""" + return ImmutableMultiDict(self._kwargs) + + @property + def json(self): + """The request dictionary, from ``handle.remote(dict)``.""" + if not isinstance(self._data, dict): + raise RayServeException("Request data is not a dictionary. 
" + f"It is {type(self._data)}.") + return self._data + + @property + def form(self): + """The request dictionary, from ``handle.remote(dict)``.""" + if not isinstance(self._data, dict): + raise RayServeException("Request data is not a dictionary. " + f"It is {type(self._data)}.") + return self._data + + @property + def data(self): + """The request data from ``handle.remote(obj)``.""" + return self._data + + def parse_request_item(request_item): if request_item.metadata.request_context == TaskContext.Web: - is_web_context = True asgi_scope, body_bytes = request_item.args - - flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes)) - args = (flask_request, ) - kwargs = {} + return build_flask_request(asgi_scope, io.BytesIO(body_bytes)) else: - is_web_context = False - args = (FakeFlaskRequest(), ) - kwargs = request_item.kwargs - - return args, kwargs, is_web_context + return ServeRequest( + request_item.args[0] if len(request_item.args) == 1 else None, + request_item.kwargs, + headers=request_item.metadata.http_headers, + method=request_item.metadata.http_method, + ) def _get_logger(): @@ -66,7 +112,7 @@ class ServeEncoder(json.JSONEncoder): def default(self, o): # pylint: disable=E0202 if isinstance(o, bytes): return o.decode("utf-8") - if pydantic is not None and isinstance(o, pydantic.BaseModel): + if isinstance(o, pydantic.BaseModel): return o.dict() if isinstance(o, Exception): return str(o)