diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 87ebaa076..b8456dfa3 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -9,15 +9,15 @@ import numpy as np import ray from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, SERVE_NURSERY_NAME) -from ray.serve.global_state import (GlobalState, start_initial_state) +from ray.serve.global_state import GlobalState, start_initial_state from ray.serve.kv_store_service import SQLiteKVStore from ray.serve.task_runner import RayServeMixin, TaskRunnerActor -from ray.serve.utils import (block_until_http_ready, get_random_letters, - expand) +from ray.serve.utils import block_until_http_ready, get_random_letters, expand from ray.serve.exceptions import RayServeException, batch_annotation_not_found from ray.serve.backend_config import BackendConfig from ray.serve.policy import RoutePolicy from ray.serve.queues import Query + global_state = None @@ -61,19 +61,21 @@ def accept_batch(f): return f -def init(kv_store_connector=None, - kv_store_path=None, - blocking=False, - start_server=True, - http_host=DEFAULT_HTTP_HOST, - http_port=DEFAULT_HTTP_PORT, - ray_init_kwargs={ - "object_store_memory": int(1e8), - "num_cpus": max(cpu_count(), 8) - }, - gc_window_seconds=3600, - queueing_policy=RoutePolicy.Random, - policy_kwargs={}): +def init( + kv_store_connector=None, + kv_store_path=None, + blocking=False, + start_server=True, + http_host=DEFAULT_HTTP_HOST, + http_port=DEFAULT_HTTP_PORT, + ray_init_kwargs={ + "object_store_memory": int(1e8), + "num_cpus": max(cpu_count(), 8) + }, + gc_window_seconds=3600, + queueing_policy=RoutePolicy.Random, + policy_kwargs={}, +): """Initialize a serve cluster. If serve cluster has already initialized, this function will just return. 
@@ -127,7 +129,7 @@ def init(kv_store_connector=None, _, kv_store_path = mkstemp() # Serve has not been initialized, perform init sequence - # Todo, move the db to session_dir + # TODO move the db to session_dir # ray.worker._global_node.address_info["session_dir"] def kv_store_connector(namespace): return SQLiteKVStore(namespace, db_path=kv_store_path) @@ -143,11 +145,12 @@ def init(kv_store_connector=None, gc_window_seconds=gc_window_seconds) if start_server and blocking: - block_until_http_ready("http://{}:{}".format(http_host, http_port)) + block_until_http_ready("http://{}:{}/-/routes".format( + http_host, http_port)) @_ensure_connected -def create_endpoint(endpoint_name, route=None, blocking=True): +def create_endpoint(endpoint_name, route=None, methods=["GET"]): """Create a service endpoint given route_expression. Args: @@ -158,7 +161,9 @@ def create_endpoint(endpoint_name, route=None, blocking=True): blocking (bool): If true, the function will wait for service to be registered before returning """ - global_state.route_table.register_service(route, endpoint_name) + methods = [m.upper() for m in methods] + global_state.route_table.register_service( + route, endpoint_name, methods=methods) @_ensure_connected @@ -169,8 +174,8 @@ def set_backend_config(backend_tag, backend_config): backend_tag(str): A registered backend. backend_config(BackendConfig) : Desired backend configuration. """ - assert backend_tag in global_state.backend_table.list_backends(), ( - "Backend {} is not registered.".format(backend_tag)) + assert (backend_tag in global_state.backend_table.list_backends() + ), "Backend {} is not registered.".format(backend_tag) assert isinstance(backend_config, BackendConfig), ("backend_config must be" " of instance BackendConfig") @@ -211,8 +216,8 @@ def get_backend_config(backend_tag): Args: backend_tag(str): A registered backend. 
""" - assert backend_tag in global_state.backend_table.list_backends(), ( - "Backend {} is not registered.".format(backend_tag)) + assert (backend_tag in global_state.backend_table.list_backends() + ), "Backend {} is not registered.".format(backend_tag) backend_config_dict = global_state.backend_table.get_info(backend_tag) return BackendConfig(**backend_config_dict) @@ -249,8 +254,7 @@ def create_backend(func_or_class, " of instance BackendConfig") # Make sure the batch size is correct - should_accept_batch = (True if backend_config.max_batch_size is not None - else False) + should_accept_batch = backend_config.max_batch_size is not None if should_accept_batch and not _backend_accept_batch(func_or_class): raise batch_annotation_not_found if _backend_accept_batch(func_or_class): @@ -267,7 +271,10 @@ def create_backend(func_or_class, # on the left to make sure its methods are not overriden. @ray.remote class CustomActor(RayServeMixin, func_or_class): - pass + @wraps(func_or_class.__init__) + def __init__(self, *args, **kwargs): + init() # serve init + super().__init__(*args, **kwargs) arg_list = actor_init_args # ignore lint on lambda expression @@ -296,8 +303,8 @@ def create_backend(func_or_class, def _start_replica(backend_tag): - assert backend_tag in global_state.backend_table.list_backends(), ( - "Backend {} is not registered.".format(backend_tag)) + assert (backend_tag in global_state.backend_table.list_backends() + ), "Backend {} is not registered.".format(backend_tag) replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6)) @@ -327,11 +334,12 @@ def _start_replica(backend_tag): def _remove_replica(backend_tag): - assert backend_tag in global_state.backend_table.list_backends(), ( - "Backend {} is not registered.".format(backend_tag)) - assert len(global_state.backend_table.list_replicas(backend_tag)) > 0, ( - "Backend {} does not have enough replicas to be removed.".format( - backend_tag)) + assert (backend_tag in 
global_state.backend_table.list_backends() + ), "Backend {} is not registered.".format(backend_tag) + assert ( + len(global_state.backend_table.list_replicas(backend_tag)) > + 0), "Backend {} does not have enough replicas to be removed.".format( + backend_tag) replica_tag = global_state.backend_table.remove_replica(backend_tag) [replica_handle] = ray.get( @@ -347,8 +355,9 @@ def _remove_replica(backend_tag): # Remove the replica from router. # This will also destory the actor handle. - ray.get(global_state.init_or_get_router() - .remove_and_destory_replica.remote(backend_tag, replica_handle)) + ray.get( + global_state.init_or_get_router().remove_and_destory_replica.remote( + backend_tag, replica_handle)) @_ensure_connected @@ -359,8 +368,8 @@ def _scale(backend_tag, num_replicas): backend_tag (str): A registered backend. num_replicas (int): Desired number of replicas """ - assert backend_tag in global_state.backend_table.list_backends(), ( - "Backend {} is not registered.".format(backend_tag)) + assert (backend_tag in global_state.backend_table.list_backends() + ), "Backend {} is not registered.".format(backend_tag) assert num_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") @@ -430,7 +439,10 @@ def split(endpoint_name, traffic_policy_dictionary): @_ensure_connected -def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None): +def get_handle(endpoint_name, + relative_slo_ms=None, + absolute_slo_ms=None, + missing_ok=False): """Retrieve RayServeHandle for service endpoint to invoke it from Python. Args: @@ -439,18 +451,26 @@ def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None): queries fired using this handle. (Default: None) absolute_slo_ms(float): Specify absolute deadline in milliseconds for queries fired using this handle. (Default: None) + missing_ok (bool): If true, skip the check for the endpoint existence. + It can be useful when the endpoint has not been registered. 
Returns: RayServeHandle """ - assert endpoint_name in expand( - global_state.route_table.list_service(include_headless=True).values()) + if not missing_ok: + assert endpoint_name in expand( + global_state.route_table.list_service( + include_headless=True).values()) # Delay import due to it's dependency on global_state from ray.serve.handle import RayServeHandle - return RayServeHandle(global_state.init_or_get_router(), endpoint_name, - relative_slo_ms, absolute_slo_ms) + return RayServeHandle( + global_state.init_or_get_router(), + endpoint_name, + relative_slo_ms, + absolute_slo_ms, + ) @_ensure_connected diff --git a/python/ray/serve/examples/echo.py b/python/ray/serve/examples/echo.py index 578a630fb..9a5464ea9 100644 --- a/python/ray/serve/examples/echo.py +++ b/python/ray/serve/examples/echo.py @@ -16,7 +16,7 @@ def echo(flask_request): serve.init(blocking=True) -serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_endpoint("my_endpoint", "/echo") serve.create_backend(echo, "echo:v1") serve.link("my_endpoint", "echo:v1") diff --git a/python/ray/serve/examples/echo_actor.py b/python/ray/serve/examples/echo_actor.py index dfc165513..28f782708 100644 --- a/python/ray/serve/examples/echo_actor.py +++ b/python/ray/serve/examples/echo_actor.py @@ -25,7 +25,7 @@ class MagicCounter: serve.init(blocking=True) -serve.create_endpoint("magic_counter", "/counter", blocking=True) +serve.create_endpoint("magic_counter", "/counter") serve.create_backend(MagicCounter, "counter:v1", 42) # increment=42 serve.link("magic_counter", "counter:v1") diff --git a/python/ray/serve/examples/echo_actor_batch.py b/python/ray/serve/examples/echo_actor_batch.py index 5420e60ad..da1c13fd0 100644 --- a/python/ray/serve/examples/echo_actor_batch.py +++ b/python/ray/serve/examples/echo_actor_batch.py @@ -37,7 +37,7 @@ class MagicCounter: serve.init(blocking=True) -serve.create_endpoint("magic_counter", "/counter", blocking=True) +serve.create_endpoint("magic_counter", 
"/counter") b_config = BackendConfig(max_batch_size=5) serve.create_backend( MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42 diff --git a/python/ray/serve/examples/echo_batching.py b/python/ray/serve/examples/echo_batching.py index 6bea509de..79568b443 100644 --- a/python/ray/serve/examples/echo_batching.py +++ b/python/ray/serve/examples/echo_batching.py @@ -28,7 +28,7 @@ class MagicCounter: serve.init(blocking=True) -serve.create_endpoint("magic_counter", "/counter", blocking=True) +serve.create_endpoint("magic_counter", "/counter") # specify max_batch_size in BackendConfig b_config = BackendConfig(max_batch_size=5) serve.create_backend( diff --git a/python/ray/serve/examples/echo_error.py b/python/ray/serve/examples/echo_error.py index b286dc31d..0cb25ee10 100644 --- a/python/ray/serve/examples/echo_error.py +++ b/python/ray/serve/examples/echo_error.py @@ -28,7 +28,7 @@ def echo(_): serve.init(blocking=True) -serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_endpoint("my_endpoint", "/echo") serve.create_backend(echo, "echo:v1") serve.link("my_endpoint", "echo:v1") diff --git a/python/ray/serve/examples/echo_fixed_packing.py b/python/ray/serve/examples/echo_fixed_packing.py index fa0258a4b..1594771b6 100644 --- a/python/ray/serve/examples/echo_fixed_packing.py +++ b/python/ray/serve/examples/echo_fixed_packing.py @@ -29,7 +29,7 @@ serve.init( policy_kwargs={"packing_num": 5}) # create a service -serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_endpoint("my_endpoint", "/echo") # create first backend serve.create_backend(echo_v1, "echo:v1") diff --git a/python/ray/serve/examples/echo_full.py b/python/ray/serve/examples/echo_full.py index 23b9c64ae..395d6fab0 100644 --- a/python/ray/serve/examples/echo_full.py +++ b/python/ray/serve/examples/echo_full.py @@ -33,7 +33,7 @@ backend_config_v1 = serve.get_backend_config("echo:v1") # goes to my_endpoint will now goes to echo:v1 backend. 
serve.link("my_endpoint", "echo:v1") -print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).json()) +print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).text) # The service will be reachable from http print(ray.get(serve.get_handle("my_endpoint").remote(response="hello"))) @@ -55,7 +55,7 @@ serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) # Observe requests are now split between two backends. for _ in range(10): - print(requests.get("http://127.0.0.1:8000/echo").json()) + print(requests.get("http://127.0.0.1:8000/echo").text) time.sleep(0.5) # You can also change number of replicas diff --git a/python/ray/serve/examples/echo_round_robin.py b/python/ray/serve/examples/echo_round_robin.py index d17924df5..05e674d8b 100644 --- a/python/ray/serve/examples/echo_round_robin.py +++ b/python/ray/serve/examples/echo_round_robin.py @@ -22,7 +22,7 @@ def echo_v2(_): serve.init(blocking=True, queueing_policy=serve.RoutePolicy.RoundRobin) # create a service -serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_endpoint("my_endpoint", "/echo") # create first backend serve.create_backend(echo_v1, "echo:v1") diff --git a/python/ray/serve/examples/echo_split.py b/python/ray/serve/examples/echo_split.py index a9a7dabe9..dfb406065 100644 --- a/python/ray/serve/examples/echo_split.py +++ b/python/ray/serve/examples/echo_split.py @@ -20,7 +20,7 @@ def echo_v2(_): serve.init(blocking=True) -serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_endpoint("my_endpoint", "/echo") serve.create_backend(echo_v1, "echo:v1") serve.link("my_endpoint", "echo:v1") diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 937fbf002..173f7345c 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -27,19 +27,23 @@ class RayServeHandle: # raises RayTaskError Exception """ - def __init__(self, - router_handle, - endpoint_name, - relative_slo_ms=None, - absolute_slo_ms=None): + def __init__( + 
self, + router_handle, + endpoint_name, + relative_slo_ms=None, + absolute_slo_ms=None, + method_name=None, + ): self.router_handle = router_handle self.endpoint_name = endpoint_name - assert (relative_slo_ms is None - or absolute_slo_ms is None), ("Can't specify both " - "relative and absolute " - "slo's together!") + assert relative_slo_ms is None or absolute_slo_ms is None, ( + "Can't specify both " + "relative and absolute " + "slo's together!") self.relative_slo_ms = self._check_slo_ms(relative_slo_ms) self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms) + self.method_name = method_name def _check_slo_ms(self, slo_value): if slo_value is not None: @@ -59,23 +63,44 @@ class RayServeHandle: raise RayServeException( "handle.remote must be invoked with keyword arguments.") + method_name = self.method_name + if method_name is None: + method_name = "__call__" + # create RequestMetadata instance request_in_object = RequestMetadata( - self.endpoint_name, TaskContext.Python, self.relative_slo_ms, - self.absolute_slo_ms) + self.endpoint_name, + TaskContext.Python, + self.relative_slo_ms, + self.absolute_slo_ms, + call_method=method_name, + ) return self.router_handle.enqueue_request.remote( request_in_object, **kwargs) - def options(self, relative_slo_ms=None, absolute_slo_ms=None): + def options(self, + method_name=None, + relative_slo_ms=None, + absolute_slo_ms=None): # If both the slo's are None then then we use a high default # value so other queries can be prioritize and put in front of these # queries. 
- assert (relative_slo_ms is None - or absolute_slo_ms is None), ("Can't specify both " - "relative and absolute " - "slo's together!") - return RayServeHandle(self.router_handle, self.endpoint_name, - relative_slo_ms, absolute_slo_ms) + assert (relative_slo_ms is None + or absolute_slo_ms is None), ("Can't specify both " + "relative and absolute " + "slo's together!") + + # Don't override existing method + if method_name is None and self.method_name is not None: + method_name = self.method_name + + return RayServeHandle( + self.router_handle, + self.endpoint_name, + relative_slo_ms, + absolute_slo_ms, + method_name=method_name, + ) def get_traffic_policy(self): policy_table = serve.api._get_global_state().policy_table @@ -94,9 +119,9 @@ class RayServeHandle: backends = set(traffic_policy.keys()) return backends.pop() else: - assert backend_tag in traffic_policy, ( - "Backend {} not found in avaiable backends: {}.".format( - backend_tag, list(traffic_policy.keys()))) + assert (backend_tag in traffic_policy + ), "Backend {} not found in avaiable backends: {}.".format( + backend_tag, list(traffic_policy.keys())) return backend_tag def scale(self, new_num_replicas, backend_tag=None): @@ -118,9 +143,11 @@ RayServeHandle( URL="{http_endpoint}/{endpoint_name}", Traffic={traffic_policy} ) -""".format(endpoint_name=self.endpoint_name, - http_endpoint=self.get_http_endpoint(), - traffic_policy=self.get_traffic_policy()) +""".format( + endpoint_name=self.endpoint_name, + http_endpoint=self.get_http_endpoint(), + traffic_policy=self.get_traffic_policy(), + ) # TODO(simon): a convenience function that dumps equivalent requests # code for a given call. 
diff --git a/python/ray/serve/http_util.py b/python/ray/serve/http_util.py index 16c877152..177a89f4b 100644 --- a/python/ray/serve/http_util.py +++ b/python/ray/serve/http_util.py @@ -1,4 +1,5 @@ import io +import json import flask @@ -67,3 +68,58 @@ def build_wsgi_environ(scope, body): environ[corrected_name] = value return environ + + +class Response: + """ASGI compliant response class. + + It is expected to be called in async context and pass along + `scope, receive, send` as in ASGI spec. + + >>> await Response({"k": "v"}).send(scope, receive, send) + """ + + def __init__(self, content=None, status_code=200): + """Construct a HTTP Response based on input type. + + Args: + content (optional): Any JSON serializable object. + status_code (int, optional): Default status code is 200. + """ + self.status_code = status_code + self.raw_headers = [] + + if content is None: + self.body = b"" + self.set_content_type("text") + elif isinstance(content, bytes): + self.body = content + self.set_content_type("text") + elif isinstance(content, str): + self.body = content.encode("utf-8") + self.set_content_type("text-utf8") + else: + # Delayed import since utils depends on http_util + from ray.serve.utils import ServeEncoder + self.body = json.dumps( + content, cls=ServeEncoder, indent=2).encode() + self.set_content_type("json") + + def set_content_type(self, content_type): + if content_type == "text": + self.raw_headers.append([b"content-type", b"text/plain"]) + elif content_type == "text-utf8": + self.raw_headers.append( + [b"content-type", b"text/plain; charset=utf-8"]) + elif content_type == "json": + self.raw_headers.append([b"content-type", b"application/json"]) + else: + raise ValueError("Invalid content type {}".format(content_type)) + + async def send(self, scope, receive, send): + await send({ + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + }) + await send({"type": "http.response.body", "body": self.body}) diff --git 
a/python/ray/serve/kv_store_service.py b/python/ray/serve/kv_store_service.py index f9287a92a..b90a8ab09 100644 --- a/python/ray/serve/kv_store_service.py +++ b/python/ray/serve/kv_store_service.py @@ -1,7 +1,7 @@ import json import sqlite3 from abc import ABC -from typing import Union +from typing import Union, List from ray import cloudpickle as pickle import ray.experimental.internal_kv as ray_kv @@ -173,9 +173,11 @@ class SQLiteKVStore(NamespacedKVStore): class RoutingTable: def __init__(self, kv_connector): self.routing_table = kv_connector("routing_table") + self.methods_table = kv_connector("methods_table") self.request_count = 0 - def register_service(self, route: Union[str, None], service: str): + def register_service(self, route: Union[str, None], service: str, + methods: List[str]): """Create an entry in the routing table Args: @@ -183,8 +185,9 @@ class RoutingTable: service: service name. This is the name http actor will push the request to. """ - logger.debug("[KV] Registering route {} to service {}.".format( - route, service)) + logger.debug( + "[KV] Registering route {} to service {} with methods {}.".format( + route, service, methods)) # put no route services in default key if route is None: @@ -194,14 +197,23 @@ class RoutingTable: self.routing_table.put(NO_ROUTE_KEY, json.dumps(no_http_services)) else: self.routing_table.put(route, service) + self.methods_table.put(route, json.dumps(methods)) - def list_service(self, include_headless=False): + def list_service(self, include_headless=False, include_methods=False): """Returns the routing table. Args: include_headless: If True, returns a no route services (headless) services with normal services. (Default: False) + include_methods: If True, returns a mapping include the methods + list for each route. 
""" table = self.routing_table.as_dict() + if include_methods: + methods_table = self.methods_table.as_dict() + for route, methods in methods_table.items(): + if route in table: + table[route] = (table[route], json.loads(methods)) + if include_headless: table[NO_ROUTE_KEY] = json.loads(table.get(NO_ROUTE_KEY, "[]")) else: diff --git a/python/ray/serve/server.py b/python/ray/serve/server.py index 514530bd8..c19067788 100644 --- a/python/ray/serve/server.py +++ b/python/ray/serve/server.py @@ -1,5 +1,4 @@ import asyncio -import json import uvicorn @@ -7,48 +6,12 @@ import ray from ray.experimental.async_api import _async_init from ray.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S from ray.serve.context import TaskContext -from ray.serve.utils import BytesEncoder from ray.serve.request_params import RequestMetadata +from ray.serve.http_util import Response from urllib.parse import parse_qs -class JSONResponse: - """ASGI compliant response class. - - It is expected to be called in async context and pass along - `scope, receive, send` as in ASGI spec. - - >>> await JSONResponse({"k": "v"})(scope, receive, send) - """ - - def __init__(self, content=None, status_code=200): - """Construct a JSON HTTP Response. - - Args: - content (optional): Any JSON serializable object. - status_code (int, optional): Default status code is 200. - """ - self.body = self.render(content) - self.status_code = status_code - self.raw_headers = [[b"content-type", b"application/json"]] - - def render(self, content): - if content is None: - return b"" - if isinstance(content, bytes): - return content - return json.dumps(content, cls=BytesEncoder, indent=2).encode() - - async def __call__(self, scope, receive, send): - await send({ - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - }) - await send({"type": "http.response.body", "body": self.body}) - - class HTTPProxy: """ This class should be instantiated and ran by ASGI server. 
@@ -63,6 +26,7 @@ class HTTPProxy: # Delay import due to GlobalState depends on HTTP actor from ray.serve.global_state import GlobalState + self.serve_global_state = GlobalState() self.route_table_cache = dict() @@ -75,7 +39,8 @@ class HTTPProxy: return self.route_table_cache = ( - self.serve_global_state.route_table.list_service()) + self.serve_global_state.route_table.list_service( + include_methods=True, include_headless=False)) await asyncio.sleep(interval) @@ -105,19 +70,38 @@ class HTTPProxy: return b"".join(body_buffer) - def _check_slo_ms(self, request_slo_ms): - if request_slo_ms is not None: - if len(request_slo_ms) != 1: - raise ValueError( - "Multiple SLO specified, please specific only one.") - request_slo_ms = request_slo_ms[0] - request_slo_ms = float(request_slo_ms) - if request_slo_ms < 0: - raise ValueError( - "Request SLO must be positive, it is {}".format( - request_slo_ms)) - return request_slo_ms - return None + def _parse_latency_slo(self, scope): + query_string = scope["query_string"].decode("ascii") + query_kwargs = parse_qs(query_string) + + relative_slo_ms = query_kwargs.pop("relative_slo_ms", None) + absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None) + relative_slo_ms = self._validate_slo_ms(relative_slo_ms) + absolute_slo_ms = self._validate_slo_ms(absolute_slo_ms) + if relative_slo_ms is not None and absolute_slo_ms is not None: + raise ValueError("Both relative and absolute slo's" + "cannot be specified.") + return relative_slo_ms, absolute_slo_ms + + def _validate_slo_ms(self, request_slo_ms): + if request_slo_ms is None: + return None + if len(request_slo_ms) != 1: + raise ValueError( + "Multiple SLO specified, please specific only one.") + request_slo_ms = request_slo_ms[0] + request_slo_ms = float(request_slo_ms) + if request_slo_ms < 0: + raise ValueError("Request SLO must be positive, it is {}".format( + request_slo_ms)) + return request_slo_ms + + def _make_error_sender(self, scope, receive, send): + async def 
sender(error_message, status_code): + response = Response(error_message, status_code=status_code) + await response.send(scope, receive, send) + + return sender async def __call__(self, scope, receive, send): # NOTE: This implements ASGI protocol specified in @@ -127,39 +111,39 @@ class HTTPProxy: await self.handle_lifespan_message(scope, receive, send) return + error_sender = self._make_error_sender(scope, receive, send) + assert scope["type"] == "http" current_path = scope["path"] - if current_path == "/": - await JSONResponse(self.route_table_cache)(scope, receive, send) + if current_path == "/-/routes": + await Response(self.route_table_cache).send(scope, receive, send) return # TODO(simon): Use werkzeug route mapper to support variable path if current_path not in self.route_table_cache: - error_message = ("Path {} not found. " - "Please ping http://.../ for routing table" - ).format(current_path) - await JSONResponse( - { - "error": error_message - }, status_code=404)(scope, receive, send) + error_message = ( + "Path {} not found. " + "Please ping http://.../-/routes for routing table" + ).format(current_path) + await error_sender(error_message, 404) + return + + endpoint_name, methods_allowed = self.route_table_cache[current_path] + + if scope["method"] not in methods_allowed: + error_message = ("Methods {} not allowed. 
" + "Avaiable HTTP methods are {}.").format( + scope["method"], methods_allowed) + await error_sender(error_message, 405) return - endpoint_name = self.route_table_cache[current_path] http_body_bytes = await self.receive_http_body(scope, receive, send) # get slo_ms before enqueuing the query - query_string = scope["query_string"].decode("ascii") - query_kwargs = parse_qs(query_string) - relative_slo_ms = query_kwargs.pop("relative_slo_ms", None) - absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None) try: - relative_slo_ms = self._check_slo_ms(relative_slo_ms) - absolute_slo_ms = self._check_slo_ms(absolute_slo_ms) - if relative_slo_ms is not None and absolute_slo_ms is not None: - raise ValueError("Both relative and absolute slo's" - "cannot be specified.") + relative_slo_ms, absolute_slo_ms = self._parse_latency_slo(scope) except ValueError as e: - await JSONResponse({"error": str(e)})(scope, receive, send) + await error_sender(str(e), 400) return # create objects necessary for enqueue @@ -167,23 +151,21 @@ class HTTPProxy: # https://github.com/ray-project/ray/issues/6944 # TODO(alind): remove list enclosing after issue is fixed args = (scope, [http_body_bytes]) + headers = {k.decode(): v.decode() for k, v in scope["headers"]} request_in_object = RequestMetadata( endpoint_name, TaskContext.Web, relative_slo_ms=relative_slo_ms, - absolute_slo_ms=absolute_slo_ms) + absolute_slo_ms=absolute_slo_ms, + call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__")) - actual_result = await (self.serve_global_state.init_or_get_router() - .enqueue_request.remote(request_in_object, - *args)) - result = actual_result - - if isinstance(result, ray.exceptions.RayTaskError): - await JSONResponse({ - "error": "internal error, please use python API to debug" - })(scope, receive, send) - else: - await JSONResponse({"result": result})(scope, receive, send) + try: + result = await (self.serve_global_state.init_or_get_router() + .enqueue_request.remote(request_in_object, 
*args)) + await Response(result).send(scope, receive, send) + except Exception as e: + error_message = "Internal Error. Traceback: {}.".format(e) + await error_sender(error_message, 500) @ray.remote diff --git a/python/ray/serve/task_runner.py b/python/ray/serve/task_runner.py index 65232e9f9..c5ebdd984 100644 --- a/python/ray/serve/task_runner.py +++ b/python/ray/serve/task_runner.py @@ -21,6 +21,9 @@ class TaskRunner: def __init__(self, func_to_run): self.func = func_to_run + # This parameter lets argument inspection work with inner function. + self.__wrapped__ = func_to_run + def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) @@ -111,17 +114,33 @@ class RayServeMixin: "The avaiable methods are {}".format( method_name, dir(self))) - return getattr(self, method_name) + if method_name != "__call__": + return getattr(self, method_name) + else: + # For simple callables, we should just return the object so + # signature recoding will continue to function. + return self + + def _ray_serve_count_num_positional(self, f): + signature = inspect.signature(f) + counter = 0 + for param in signature.parameters.values(): + if (param.kind == param.POSITIONAL_OR_KEYWORD + and param.default is param.empty): + counter += 1 + return counter async def invoke_single(self, request_item): args, kwargs, is_web_context = parse_request_item(request_item) serve_context.web = is_web_context start_timestamp = time.time() + method_to_call = self._ray_serve_get_runner_method(request_item) + args = args if self._ray_serve_count_num_positional( + method_to_call) else [] + method_to_call = ensure_async(method_to_call) try: - result = await ensure_async( - self._ray_serve_get_runner_method(request_item))(*args, - **kwargs) + result = await method_to_call(*args, **kwargs) except Exception as e: result = wrap_to_ray_error(e) self._serve_metric_error_counter += 1 @@ -154,7 +173,8 @@ class RayServeMixin: args, kwargs, is_web_context = parse_request_item(item) 
context_flags.add(is_web_context) - call_methods.add(self._ray_serve_get_runner_method(item)) + call_method = self._ray_serve_get_runner_method(item) + call_methods.add(call_method) if is_web_context: # Python context only have kwargs @@ -168,7 +188,8 @@ class RayServeMixin: # Set the flask request as a list to conform # with batching semantics: when in batching # mode, each argument it turned into list. - arg_list.append(FakeFlaskRequest()) + if self._ray_serve_count_num_positional(call_method): + arg_list.append(FakeFlaskRequest()) try: # check mixing of query context diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 574ac29c4..e63d806b2 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -10,6 +10,41 @@ from ray.serve.exceptions import RayServeException from ray.serve.handle import RayServeHandle +def test_e2e(serve_instance): + serve.init() # so we have access to global state + serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"]) + result = serve.api._get_global_state().route_table.list_service() + assert result["/api"] == "endpoint" + + retry_count = 5 + timeout_sleep = 0.5 + while True: + try: + resp = requests.get( + "http://127.0.0.1:8000/-/routes", timeout=0.5).json() + assert resp == {"/api": ["endpoint", ["GET", "POST"]]} + break + except Exception as e: + time.sleep(timeout_sleep) + timeout_sleep *= 2 + retry_count -= 1 + if retry_count == 0: + assert False, ("Route table hasn't been updated after 3 tries." 
+ "The latest error was {}").format(e) + + def function(flask_request): + return {"method": flask_request.method} + + serve.create_backend(function, "echo:v1") + serve.link("endpoint", "echo:v1") + + resp = requests.get("http://127.0.0.1:8000/api").json()["method"] + assert resp == "GET" + + resp = requests.post("http://127.0.0.1:8000/api").json()["method"] + assert resp == "POST" + + def test_route_decorator(serve_instance): @serve.route("/hello_world") def hello_world(_): @@ -25,38 +60,8 @@ def test_route_decorator(serve_instance): hello_world.set_max_batch_size(2) -def test_e2e(serve_instance): - serve.init() # so we have access to global state - serve.create_endpoint("endpoint", "/api", blocking=True) - result = serve.api._get_global_state().route_table.list_service() - assert result["/api"] == "endpoint" - - retry_count = 5 - timeout_sleep = 0.5 - while True: - try: - resp = requests.get("http://127.0.0.1:8000/", timeout=0.5).json() - assert resp == result - break - except Exception: - time.sleep(timeout_sleep) - timeout_sleep *= 2 - retry_count -= 1 - if retry_count == 0: - assert False, "Route table hasn't been updated after 3 tries." 
- - def function(flask_request): - return "OK" - - serve.create_backend(function, "echo:v1") - serve.link("endpoint", "echo:v1") - - resp = requests.get("http://127.0.0.1:8000/api").json()["result"] - assert resp == "OK" - - def test_no_route(serve_instance): - serve.create_endpoint("noroute-endpoint", blocking=True) + serve.create_endpoint("noroute-endpoint") global_state = serve.api._get_global_state() result = global_state.route_table.list_service(include_headless=True) @@ -87,7 +92,8 @@ def test_scaling_replicas(serve_instance): serve.create_endpoint("counter", "/increment") # Keep checking the routing table until /increment is populated - while "/increment" not in requests.get("http://127.0.0.1:8000/").json(): + while "/increment" not in requests.get( + "http://127.0.0.1:8000/-/routes").json(): time.sleep(0.2) b_config = BackendConfig(num_replicas=2) @@ -96,7 +102,7 @@ def test_scaling_replicas(serve_instance): counter_result = [] for _ in range(10): - resp = requests.get("http://127.0.0.1:8000/increment").json()["result"] + resp = requests.get("http://127.0.0.1:8000/increment").json() counter_result.append(resp) # If the load is shared among two replicas. The max result cannot be 10. @@ -108,7 +114,7 @@ def test_scaling_replicas(serve_instance): counter_result = [] for _ in range(10): - resp = requests.get("http://127.0.0.1:8000/increment").json()["result"] + resp = requests.get("http://127.0.0.1:8000/increment").json() counter_result.append(resp) # Give some time for a replica to spin down. But majority of the request # should be served by the only remaining replica. 
@@ -129,7 +135,8 @@ def test_batching(serve_instance): serve.create_endpoint("counter1", "/increment") # Keep checking the routing table until /increment is populated - while "/increment" not in requests.get("http://127.0.0.1:8000/").json(): + while "/increment" not in requests.get( + "http://127.0.0.1:8000/-/routes").json(): time.sleep(0.2) # set the max batch size diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index 4cc944643..62d63565c 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -1,9 +1,9 @@ import json -from ray.serve.utils import BytesEncoder +from ray.serve.utils import ServeEncoder def test_bytes_encoder(): data_before = {"inp": {"nest": b"bytes"}} data_after = {"inp": {"nest": "bytes"}} - assert json.loads(json.dumps(data_before, cls=BytesEncoder)) == data_after + assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index eb78eab6d..6baa907fb 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -11,6 +11,12 @@ from pygments import formatters, highlight, lexers from ray.serve.context import FakeFlaskRequest, TaskContext from ray.serve.http_util import build_flask_request import itertools +import numpy as np + +try: + import pydantic +except ImportError: + pydantic = None def expand(l): @@ -60,19 +66,23 @@ def _get_logger(): logger = _get_logger() -class BytesEncoder(json.JSONEncoder): - """Allow bytes to be part of the JSON document. - - BytesEncoder will walk the JSON tree and decode bytes with utf-8 codec. - - Example: - >>> json.dumps({b'a': b'c'}, cls=BytesEncoder) - '{"a":"c"}' +class ServeEncoder(json.JSONEncoder): + """Ray.Serve's utility JSON encoder. 
Adds support for: + - bytes + - Pydantic types + - Exceptions + - numpy.ndarray """ def default(self, o): # pylint: disable=E0202 if isinstance(o, bytes): return o.decode("utf-8") + if pydantic is not None and isinstance(o, pydantic.BaseModel): + return o.dict() + if isinstance(o, Exception): + return str(o) + if isinstance(o, np.ndarray): + return o.tolist() return super().default(o) diff --git a/python/ray/worker.py b/python/ray/worker.py index 28623b825..d28fb7f94 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1475,10 +1475,10 @@ def get(object_ids, timeout=None): "core_worker") and worker.core_worker.current_actor_is_asyncio(): global blocking_get_inside_async_warned if not blocking_get_inside_async_warned: - logger.warning("Using blocking ray.get inside async actor. " - "This blocks the event loop. Please use `await` " - "on object id with asyncio.gather if you want to " - "yield execution to the event loop instead.") + logger.debug("Using blocking ray.get inside async actor. " + "This blocks the event loop. Please use `await` " + "on object id with asyncio.gather if you want to " + "yield execution to the event loop instead.") blocking_get_inside_async_warned = True with profiling.profile("ray.get"): @@ -1586,9 +1586,9 @@ def wait(object_ids, num_returns=1, timeout=None): ) and timeout != 0: global blocking_wait_inside_async_warned if not blocking_wait_inside_async_warned: - logger.warning("Using blocking ray.wait inside async method. " - "This blocks the event loop. Please use `await` " - "on object id with asyncio.wait. ") + logger.debug("Using blocking ray.wait inside async method. " + "This blocks the event loop. Please use `await` " + "on object id with asyncio.wait. ") blocking_wait_inside_async_warned = True if isinstance(object_ids, ObjectID):