[Serve] Enhancement in HTTP Methods and Multi-route support (#7709)

This commit is contained in:
Simon Mo
2020-03-24 20:25:05 -07:00
committed by GitHub
parent a1cee6af7b
commit a519b4f2a9
19 changed files with 359 additions and 224 deletions
+63 -43
View File
@@ -9,15 +9,15 @@ import numpy as np
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_NURSERY_NAME)
from ray.serve.global_state import (GlobalState, start_initial_state)
from ray.serve.global_state import GlobalState, start_initial_state
from ray.serve.kv_store_service import SQLiteKVStore
from ray.serve.task_runner import RayServeMixin, TaskRunnerActor
from ray.serve.utils import (block_until_http_ready, get_random_letters,
expand)
from ray.serve.utils import block_until_http_ready, get_random_letters, expand
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
from ray.serve.backend_config import BackendConfig
from ray.serve.policy import RoutePolicy
from ray.serve.queues import Query
global_state = None
@@ -61,19 +61,21 @@ def accept_batch(f):
return f
def init(kv_store_connector=None,
kv_store_path=None,
blocking=False,
start_server=True,
http_host=DEFAULT_HTTP_HOST,
http_port=DEFAULT_HTTP_PORT,
ray_init_kwargs={
"object_store_memory": int(1e8),
"num_cpus": max(cpu_count(), 8)
},
gc_window_seconds=3600,
queueing_policy=RoutePolicy.Random,
policy_kwargs={}):
def init(
kv_store_connector=None,
kv_store_path=None,
blocking=False,
start_server=True,
http_host=DEFAULT_HTTP_HOST,
http_port=DEFAULT_HTTP_PORT,
ray_init_kwargs={
"object_store_memory": int(1e8),
"num_cpus": max(cpu_count(), 8)
},
gc_window_seconds=3600,
queueing_policy=RoutePolicy.Random,
policy_kwargs={},
):
"""Initialize a serve cluster.
If serve cluster has already initialized, this function will just return.
@@ -127,7 +129,7 @@ def init(kv_store_connector=None,
_, kv_store_path = mkstemp()
# Serve has not been initialized, perform init sequence
# Todo, move the db to session_dir
# TODO move the db to session_dir
# ray.worker._global_node.address_info["session_dir"]
def kv_store_connector(namespace):
return SQLiteKVStore(namespace, db_path=kv_store_path)
@@ -143,11 +145,12 @@ def init(kv_store_connector=None,
gc_window_seconds=gc_window_seconds)
if start_server and blocking:
block_until_http_ready("http://{}:{}".format(http_host, http_port))
block_until_http_ready("http://{}:{}/-/routes".format(
http_host, http_port))
@_ensure_connected
def create_endpoint(endpoint_name, route=None, blocking=True):
def create_endpoint(endpoint_name, route=None, methods=["GET"]):
"""Create a service endpoint given route_expression.
Args:
@@ -158,7 +161,9 @@ def create_endpoint(endpoint_name, route=None, blocking=True):
blocking (bool): If true, the function will wait for service to be
registered before returning
"""
global_state.route_table.register_service(route, endpoint_name)
methods = [m.upper() for m in methods]
global_state.route_table.register_service(
route, endpoint_name, methods=methods)
@_ensure_connected
@@ -169,8 +174,8 @@ def set_backend_config(backend_tag, backend_config):
backend_tag(str): A registered backend.
backend_config(BackendConfig) : Desired backend configuration.
"""
assert backend_tag in global_state.backend_table.list_backends(), (
"Backend {} is not registered.".format(backend_tag))
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(backend_config,
BackendConfig), ("backend_config must be"
" of instance BackendConfig")
@@ -211,8 +216,8 @@ def get_backend_config(backend_tag):
Args:
backend_tag(str): A registered backend.
"""
assert backend_tag in global_state.backend_table.list_backends(), (
"Backend {} is not registered.".format(backend_tag))
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
backend_config_dict = global_state.backend_table.get_info(backend_tag)
return BackendConfig(**backend_config_dict)
@@ -249,8 +254,7 @@ def create_backend(func_or_class,
" of instance BackendConfig")
# Make sure the batch size is correct
should_accept_batch = (True if backend_config.max_batch_size is not None
else False)
should_accept_batch = backend_config.max_batch_size is not None
if should_accept_batch and not _backend_accept_batch(func_or_class):
raise batch_annotation_not_found
if _backend_accept_batch(func_or_class):
@@ -267,7 +271,10 @@ def create_backend(func_or_class,
# on the left to make sure its methods are not overriden.
@ray.remote
class CustomActor(RayServeMixin, func_or_class):
pass
@wraps(func_or_class.__init__)
def __init__(self, *args, **kwargs):
init() # serve init
super().__init__(*args, **kwargs)
arg_list = actor_init_args
# ignore lint on lambda expression
@@ -296,8 +303,8 @@ def create_backend(func_or_class,
def _start_replica(backend_tag):
assert backend_tag in global_state.backend_table.list_backends(), (
"Backend {} is not registered.".format(backend_tag))
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
@@ -327,11 +334,12 @@ def _start_replica(backend_tag):
def _remove_replica(backend_tag):
assert backend_tag in global_state.backend_table.list_backends(), (
"Backend {} is not registered.".format(backend_tag))
assert len(global_state.backend_table.list_replicas(backend_tag)) > 0, (
"Backend {} does not have enough replicas to be removed.".format(
backend_tag))
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert (
len(global_state.backend_table.list_replicas(backend_tag)) >
0), "Backend {} does not have enough replicas to be removed.".format(
backend_tag)
replica_tag = global_state.backend_table.remove_replica(backend_tag)
[replica_handle] = ray.get(
@@ -347,8 +355,9 @@ def _remove_replica(backend_tag):
# Remove the replica from router.
# This will also destory the actor handle.
ray.get(global_state.init_or_get_router()
.remove_and_destory_replica.remote(backend_tag, replica_handle))
ray.get(
global_state.init_or_get_router().remove_and_destory_replica.remote(
backend_tag, replica_handle))
@_ensure_connected
@@ -359,8 +368,8 @@ def _scale(backend_tag, num_replicas):
backend_tag (str): A registered backend.
num_replicas (int): Desired number of replicas
"""
assert backend_tag in global_state.backend_table.list_backends(), (
"Backend {} is not registered.".format(backend_tag))
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert num_replicas >= 0, ("Number of replicas must be"
" greater than or equal to 0.")
@@ -430,7 +439,10 @@ def split(endpoint_name, traffic_policy_dictionary):
@_ensure_connected
def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None):
def get_handle(endpoint_name,
relative_slo_ms=None,
absolute_slo_ms=None,
missing_ok=False):
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
@@ -439,18 +451,26 @@ def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None):
queries fired using this handle. (Default: None)
absolute_slo_ms(float): Specify absolute deadline in milliseconds for
queries fired using this handle. (Default: None)
missing_ok (bool): If true, skip the check for the endpoint existence.
It can be useful when the endpoint has not been registered.
Returns:
RayServeHandle
"""
assert endpoint_name in expand(
global_state.route_table.list_service(include_headless=True).values())
if not missing_ok:
assert endpoint_name in expand(
global_state.route_table.list_service(
include_headless=True).values())
# Delay import due to it's dependency on global_state
from ray.serve.handle import RayServeHandle
return RayServeHandle(global_state.init_or_get_router(), endpoint_name,
relative_slo_ms, absolute_slo_ms)
return RayServeHandle(
global_state.init_or_get_router(),
endpoint_name,
relative_slo_ms,
absolute_slo_ms,
)
@_ensure_connected
+1 -1
View File
@@ -16,7 +16,7 @@ def echo(flask_request):
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo, "echo:v1")
serve.link("my_endpoint", "echo:v1")
+1 -1
View File
@@ -25,7 +25,7 @@ class MagicCounter:
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter", blocking=True)
serve.create_endpoint("magic_counter", "/counter")
serve.create_backend(MagicCounter, "counter:v1", 42) # increment=42
serve.link("magic_counter", "counter:v1")
@@ -37,7 +37,7 @@ class MagicCounter:
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter", blocking=True)
serve.create_endpoint("magic_counter", "/counter")
b_config = BackendConfig(max_batch_size=5)
serve.create_backend(
MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42
+1 -1
View File
@@ -28,7 +28,7 @@ class MagicCounter:
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter", blocking=True)
serve.create_endpoint("magic_counter", "/counter")
# specify max_batch_size in BackendConfig
b_config = BackendConfig(max_batch_size=5)
serve.create_backend(
+1 -1
View File
@@ -28,7 +28,7 @@ def echo(_):
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo, "echo:v1")
serve.link("my_endpoint", "echo:v1")
@@ -29,7 +29,7 @@ serve.init(
policy_kwargs={"packing_num": 5})
# create a service
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
# create first backend
serve.create_backend(echo_v1, "echo:v1")
+2 -2
View File
@@ -33,7 +33,7 @@ backend_config_v1 = serve.get_backend_config("echo:v1")
# goes to my_endpoint will now goes to echo:v1 backend.
serve.link("my_endpoint", "echo:v1")
print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).json())
print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).text)
# The service will be reachable from http
print(ray.get(serve.get_handle("my_endpoint").remote(response="hello")))
@@ -55,7 +55,7 @@ serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
# Observe requests are now split between two backends.
for _ in range(10):
print(requests.get("http://127.0.0.1:8000/echo").json())
print(requests.get("http://127.0.0.1:8000/echo").text)
time.sleep(0.5)
# You can also change number of replicas
@@ -22,7 +22,7 @@ def echo_v2(_):
serve.init(blocking=True, queueing_policy=serve.RoutePolicy.RoundRobin)
# create a service
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
# create first backend
serve.create_backend(echo_v1, "echo:v1")
+1 -1
View File
@@ -20,7 +20,7 @@ def echo_v2(_):
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo_v1, "echo:v1")
serve.link("my_endpoint", "echo:v1")
+51 -24
View File
@@ -27,19 +27,23 @@ class RayServeHandle:
# raises RayTaskError Exception
"""
def __init__(self,
router_handle,
endpoint_name,
relative_slo_ms=None,
absolute_slo_ms=None):
def __init__(
self,
router_handle,
endpoint_name,
relative_slo_ms=None,
absolute_slo_ms=None,
method_name=None,
):
self.router_handle = router_handle
self.endpoint_name = endpoint_name
assert (relative_slo_ms is None
or absolute_slo_ms is None), ("Can't specify both "
"relative and absolute "
"slo's together!")
assert relative_slo_ms is None or absolute_slo_ms is None, (
"Can't specify both "
"relative and absolute "
"slo's together!")
self.relative_slo_ms = self._check_slo_ms(relative_slo_ms)
self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
self.method_name = method_name
def _check_slo_ms(self, slo_value):
if slo_value is not None:
@@ -59,23 +63,44 @@ class RayServeHandle:
raise RayServeException(
"handle.remote must be invoked with keyword arguments.")
method_name = self.method_name
if method_name is None:
method_name = "__call__"
# create RequestMetadata instance
request_in_object = RequestMetadata(
self.endpoint_name, TaskContext.Python, self.relative_slo_ms,
self.absolute_slo_ms)
self.endpoint_name,
TaskContext.Python,
self.relative_slo_ms,
self.absolute_slo_ms,
call_method=method_name,
)
return self.router_handle.enqueue_request.remote(
request_in_object, **kwargs)
def options(self,
            method_name=None,
            relative_slo_ms=None,
            absolute_slo_ms=None):
    """Return a new handle for the same endpoint with updated options.

    Args:
        method_name (str, optional): Name of the backend method the new
            handle should invoke. When omitted, the method name already
            set on this handle (if any) is carried over.
        relative_slo_ms (float, optional): Relative latency SLO in
            milliseconds for queries fired using the new handle.
        absolute_slo_ms (float, optional): Absolute deadline in
            milliseconds for queries fired using the new handle.

    Returns:
        RayServeHandle: A new handle bound to the same router and
            endpoint. The SLO values of this handle are not inherited;
            only the ones passed here are used.

    Raises:
        AssertionError: If both relative and absolute SLOs are given.
    """
    # If both the slo's are None then we use a high default
    # value so other queries can be prioritized and put in front of these
    # queries.
    # NOTE: `all()` takes a single iterable, so `all(a, b)` raises
    # TypeError on every call. Use the same explicit `is None` check as
    # the constructor; it also treats an SLO of 0 as "specified".
    assert relative_slo_ms is None or absolute_slo_ms is None, (
        "Can't specify both "
        "relative and absolute "
        "slo's together!")

    # Don't override existing method
    if method_name is None:
        method_name = self.method_name

    return RayServeHandle(
        self.router_handle,
        self.endpoint_name,
        relative_slo_ms,
        absolute_slo_ms,
        method_name=method_name,
    )
def get_traffic_policy(self):
policy_table = serve.api._get_global_state().policy_table
@@ -94,9 +119,9 @@ class RayServeHandle:
backends = set(traffic_policy.keys())
return backends.pop()
else:
assert backend_tag in traffic_policy, (
"Backend {} not found in avaiable backends: {}.".format(
backend_tag, list(traffic_policy.keys())))
assert (backend_tag in traffic_policy
), "Backend {} not found in avaiable backends: {}.".format(
backend_tag, list(traffic_policy.keys()))
return backend_tag
def scale(self, new_num_replicas, backend_tag=None):
@@ -118,9 +143,11 @@ RayServeHandle(
URL="{http_endpoint}/{endpoint_name}",
Traffic={traffic_policy}
)
""".format(endpoint_name=self.endpoint_name,
http_endpoint=self.get_http_endpoint(),
traffic_policy=self.get_traffic_policy())
""".format(
endpoint_name=self.endpoint_name,
http_endpoint=self.get_http_endpoint(),
traffic_policy=self.get_traffic_policy(),
)
# TODO(simon): a convenience function that dumps equivalent requests
# code for a given call.
+56
View File
@@ -1,4 +1,5 @@
import io
import json
import flask
@@ -67,3 +68,58 @@ def build_wsgi_environ(scope, body):
environ[corrected_name] = value
return environ
class Response:
    """ASGI compliant response class.

    It is expected to be called in async context and pass along
    `scope, receive, send` as in ASGI spec.

    >>> await Response({"k": "v"}).send(scope, receive, send)
    """

    def __init__(self, content=None, status_code=200):
        """Construct a HTTP Response based on input type.

        The body and the content-type header are chosen from the type of
        `content`: None and bytes are sent as-is as "text/plain", str is
        UTF-8 encoded and sent as "text/plain; charset=utf-8", anything
        else is dumped to JSON and sent as "application/json".

        Args:
            content (optional): None, bytes, str, or any JSON serializable
                object.
            status_code (int, optional): Default status code is 200.
        """
        self.status_code = status_code
        self.raw_headers = []

        if content is None:
            self.body = b""
            self.set_content_type("text")
        elif isinstance(content, bytes):
            self.body = content
            self.set_content_type("text")
        elif isinstance(content, str):
            self.body = content.encode("utf-8")
            self.set_content_type("text-utf8")
        else:
            # Delayed import since utils depends on http_util
            from ray.serve.utils import ServeEncoder

            self.body = json.dumps(
                content, cls=ServeEncoder, indent=2).encode()
            self.set_content_type("json")

    def set_content_type(self, content_type):
        """Append a content-type header for the given shorthand.

        Args:
            content_type (str): One of "text", "text-utf8" or "json".

        Raises:
            ValueError: If `content_type` is not a known shorthand.
        """
        if content_type == "text":
            self.raw_headers.append([b"content-type", b"text/plain"])
        elif content_type == "text-utf8":
            self.raw_headers.append(
                [b"content-type", b"text/plain; charset=utf-8"])
        elif content_type == "json":
            self.raw_headers.append([b"content-type", b"application/json"])
        else:
            raise ValueError("Invalid content type {}".format(content_type))

    async def send(self, scope, receive, send):
        """Emit the response as two ASGI messages: start, then body.

        Only `send` is used; `scope` and `receive` are accepted so the
        call site can pass the ASGI triple along unchanged.
        """
        await send({
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        })
        await send({"type": "http.response.body", "body": self.body})
+17 -5
View File
@@ -1,7 +1,7 @@
import json
import sqlite3
from abc import ABC
from typing import Union
from typing import Union, List
from ray import cloudpickle as pickle
import ray.experimental.internal_kv as ray_kv
@@ -173,9 +173,11 @@ class SQLiteKVStore(NamespacedKVStore):
class RoutingTable:
def __init__(self, kv_connector):
self.routing_table = kv_connector("routing_table")
self.methods_table = kv_connector("methods_table")
self.request_count = 0
def register_service(self, route: Union[str, None], service: str):
def register_service(self, route: Union[str, None], service: str,
methods: List[str]):
"""Create an entry in the routing table
Args:
@@ -183,8 +185,9 @@ class RoutingTable:
service: service name. This is the name http actor will push
the request to.
"""
logger.debug("[KV] Registering route {} to service {}.".format(
route, service))
logger.debug(
"[KV] Registering route {} to service {} with methods {}.".format(
route, service, methods))
# put no route services in default key
if route is None:
@@ -194,14 +197,23 @@ class RoutingTable:
self.routing_table.put(NO_ROUTE_KEY, json.dumps(no_http_services))
else:
self.routing_table.put(route, service)
self.methods_table.put(route, json.dumps(methods))
def list_service(self, include_headless=False):
def list_service(self, include_headless=False, include_methods=False):
"""Returns the routing table.
Args:
include_headless: If True, returns a no route services (headless)
services with normal services. (Default: False)
include_methods: If True, returns a mapping include the methods
list for each route.
"""
table = self.routing_table.as_dict()
if include_methods:
methods_table = self.methods_table.as_dict()
for route, methods in methods_table.items():
if route in table:
table[route] = (table[route], json.loads(methods))
if include_headless:
table[NO_ROUTE_KEY] = json.loads(table.get(NO_ROUTE_KEY, "[]"))
else:
+66 -84
View File
@@ -1,5 +1,4 @@
import asyncio
import json
import uvicorn
@@ -7,48 +6,12 @@ import ray
from ray.experimental.async_api import _async_init
from ray.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S
from ray.serve.context import TaskContext
from ray.serve.utils import BytesEncoder
from ray.serve.request_params import RequestMetadata
from ray.serve.http_util import Response
from urllib.parse import parse_qs
class JSONResponse:
"""ASGI compliant response class.
It is expected to be called in async context and pass along
`scope, receive, send` as in ASGI spec.
>>> await JSONResponse({"k": "v"})(scope, receive, send)
"""
def __init__(self, content=None, status_code=200):
"""Construct a JSON HTTP Response.
Args:
content (optional): Any JSON serializable object.
status_code (int, optional): Default status code is 200.
"""
self.body = self.render(content)
self.status_code = status_code
self.raw_headers = [[b"content-type", b"application/json"]]
def render(self, content):
if content is None:
return b""
if isinstance(content, bytes):
return content
return json.dumps(content, cls=BytesEncoder, indent=2).encode()
async def __call__(self, scope, receive, send):
await send({
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
})
await send({"type": "http.response.body", "body": self.body})
class HTTPProxy:
"""
This class should be instantiated and ran by ASGI server.
@@ -63,6 +26,7 @@ class HTTPProxy:
# Delay import due to GlobalState depends on HTTP actor
from ray.serve.global_state import GlobalState
self.serve_global_state = GlobalState()
self.route_table_cache = dict()
@@ -75,7 +39,8 @@ class HTTPProxy:
return
self.route_table_cache = (
self.serve_global_state.route_table.list_service())
self.serve_global_state.route_table.list_service(
include_methods=True, include_headless=False))
await asyncio.sleep(interval)
@@ -105,19 +70,38 @@ class HTTPProxy:
return b"".join(body_buffer)
def _check_slo_ms(self, request_slo_ms):
if request_slo_ms is not None:
if len(request_slo_ms) != 1:
raise ValueError(
"Multiple SLO specified, please specific only one.")
request_slo_ms = request_slo_ms[0]
request_slo_ms = float(request_slo_ms)
if request_slo_ms < 0:
raise ValueError(
"Request SLO must be positive, it is {}".format(
request_slo_ms))
return request_slo_ms
return None
def _parse_latency_slo(self, scope):
query_string = scope["query_string"].decode("ascii")
query_kwargs = parse_qs(query_string)
relative_slo_ms = query_kwargs.pop("relative_slo_ms", None)
absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None)
relative_slo_ms = self._validate_slo_ms(relative_slo_ms)
absolute_slo_ms = self._validate_slo_ms(absolute_slo_ms)
if relative_slo_ms is not None and absolute_slo_ms is not None:
raise ValueError("Both relative and absolute slo's"
"cannot be specified.")
return relative_slo_ms, absolute_slo_ms
def _validate_slo_ms(self, request_slo_ms):
if request_slo_ms is None:
return None
if len(request_slo_ms) != 1:
raise ValueError(
"Multiple SLO specified, please specific only one.")
request_slo_ms = request_slo_ms[0]
request_slo_ms = float(request_slo_ms)
if request_slo_ms < 0:
raise ValueError("Request SLO must be positive, it is {}".format(
request_slo_ms))
return request_slo_ms
def _make_error_sender(self, scope, receive, send):
async def sender(error_message, status_code):
response = Response(error_message, status_code=status_code)
await response.send(scope, receive, send)
return sender
async def __call__(self, scope, receive, send):
# NOTE: This implements ASGI protocol specified in
@@ -127,39 +111,39 @@ class HTTPProxy:
await self.handle_lifespan_message(scope, receive, send)
return
error_sender = self._make_error_sender(scope, receive, send)
assert scope["type"] == "http"
current_path = scope["path"]
if current_path == "/":
await JSONResponse(self.route_table_cache)(scope, receive, send)
if current_path == "/-/routes":
await Response(self.route_table_cache).send(scope, receive, send)
return
# TODO(simon): Use werkzeug route mapper to support variable path
if current_path not in self.route_table_cache:
error_message = ("Path {} not found. "
"Please ping http://.../ for routing table"
).format(current_path)
await JSONResponse(
{
"error": error_message
}, status_code=404)(scope, receive, send)
error_message = (
"Path {} not found. "
"Please ping http://.../-/routes for routing table"
).format(current_path)
await error_sender(error_message, 404)
return
endpoint_name, methods_allowed = self.route_table_cache[current_path]
if scope["method"] not in methods_allowed:
error_message = ("Methods {} not allowed. "
"Avaiable HTTP methods are {}.").format(
scope["method"], methods_allowed)
await error_sender(error_message, 405)
return
endpoint_name = self.route_table_cache[current_path]
http_body_bytes = await self.receive_http_body(scope, receive, send)
# get slo_ms before enqueuing the query
query_string = scope["query_string"].decode("ascii")
query_kwargs = parse_qs(query_string)
relative_slo_ms = query_kwargs.pop("relative_slo_ms", None)
absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None)
try:
relative_slo_ms = self._check_slo_ms(relative_slo_ms)
absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
if relative_slo_ms is not None and absolute_slo_ms is not None:
raise ValueError("Both relative and absolute slo's"
"cannot be specified.")
relative_slo_ms, absolute_slo_ms = self._parse_latency_slo(scope)
except ValueError as e:
await JSONResponse({"error": str(e)})(scope, receive, send)
await error_sender(str(e), 400)
return
# create objects necessary for enqueue
@@ -167,23 +151,21 @@ class HTTPProxy:
# https://github.com/ray-project/ray/issues/6944
# TODO(alind): remove list enclosing after issue is fixed
args = (scope, [http_body_bytes])
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
request_in_object = RequestMetadata(
endpoint_name,
TaskContext.Web,
relative_slo_ms=relative_slo_ms,
absolute_slo_ms=absolute_slo_ms)
absolute_slo_ms=absolute_slo_ms,
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"))
actual_result = await (self.serve_global_state.init_or_get_router()
.enqueue_request.remote(request_in_object,
*args))
result = actual_result
if isinstance(result, ray.exceptions.RayTaskError):
await JSONResponse({
"error": "internal error, please use python API to debug"
})(scope, receive, send)
else:
await JSONResponse({"result": result})(scope, receive, send)
try:
result = await (self.serve_global_state.init_or_get_router()
.enqueue_request.remote(request_in_object, *args))
await Response(result).send(scope, receive, send)
except Exception as e:
error_message = "Internal Error. Traceback: {}.".format(e)
await error_sender(error_message, 500)
@ray.remote
+27 -6
View File
@@ -21,6 +21,9 @@ class TaskRunner:
def __init__(self, func_to_run):
self.func = func_to_run
# This parameter let argument inspection work with inner function.
self.__wrapped__ = func_to_run
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
@@ -111,17 +114,33 @@ class RayServeMixin:
"The avaiable methods are {}".format(
method_name, dir(self)))
return getattr(self, method_name)
if method_name != "__call__":
return getattr(self, method_name)
else:
# For simple callables, we should just return the object so
# signature recoding will continue to funciton.
return self
def _ray_serve_count_num_positional(self, f):
signature = inspect.signature(f)
counter = 0
for param in signature.parameters.values():
if (param.kind == param.POSITIONAL_OR_KEYWORD
and param.default is param.empty):
counter += 1
return counter
async def invoke_single(self, request_item):
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
start_timestamp = time.time()
method_to_call = self._ray_serve_get_runner_method(request_item)
args = args if self._ray_serve_count_num_positional(
method_to_call) else []
method_to_call = ensure_async(method_to_call)
try:
result = await ensure_async(
self._ray_serve_get_runner_method(request_item))(*args,
**kwargs)
result = await method_to_call(*args, **kwargs)
except Exception as e:
result = wrap_to_ray_error(e)
self._serve_metric_error_counter += 1
@@ -154,7 +173,8 @@ class RayServeMixin:
args, kwargs, is_web_context = parse_request_item(item)
context_flags.add(is_web_context)
call_methods.add(self._ray_serve_get_runner_method(item))
call_method = self._ray_serve_get_runner_method(item)
call_methods.add(call_method)
if is_web_context:
# Python context only have kwargs
@@ -168,7 +188,8 @@ class RayServeMixin:
# Set the flask request as a list to conform
# with batching semantics: when in batching
# mode, each argument it turned into list.
arg_list.append(FakeFlaskRequest())
if self._ray_serve_count_num_positional(call_method):
arg_list.append(FakeFlaskRequest())
try:
# check mixing of query context
+42 -35
View File
@@ -10,6 +10,41 @@ from ray.serve.exceptions import RayServeException
from ray.serve.handle import RayServeHandle
def test_e2e(serve_instance):
    """End-to-end check of route registration and per-method dispatch.

    Registers an endpoint for GET and POST, waits (with exponential
    backoff) until the HTTP proxy reports it under `/-/routes`, then
    verifies that the backend sees the HTTP method actually used.
    """
    serve.init()  # so we have access to global state
    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])
    result = serve.api._get_global_state().route_table.list_service()
    assert result["/api"] == "endpoint"

    retry_count = 5
    timeout_sleep = 0.5
    last_error = None
    for _ in range(retry_count):
        try:
            resp = requests.get(
                "http://127.0.0.1:8000/-/routes", timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            # The proxy refreshes its route cache periodically, so both
            # connection errors and a stale table are retried.
            last_error = e
            time.sleep(timeout_sleep)
            timeout_sleep *= 2
    else:
        # Report the real number of attempts, not a hard-coded one.
        assert False, ("Route table hasn't been updated after {} tries. "
                       "The latest error was {}").format(
                           retry_count, last_error)

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
def test_route_decorator(serve_instance):
@serve.route("/hello_world")
def hello_world(_):
@@ -25,38 +60,8 @@ def test_route_decorator(serve_instance):
hello_world.set_max_batch_size(2)
def test_e2e(serve_instance):
serve.init() # so we have access to global state
serve.create_endpoint("endpoint", "/api", blocking=True)
result = serve.api._get_global_state().route_table.list_service()
assert result["/api"] == "endpoint"
retry_count = 5
timeout_sleep = 0.5
while True:
try:
resp = requests.get("http://127.0.0.1:8000/", timeout=0.5).json()
assert resp == result
break
except Exception:
time.sleep(timeout_sleep)
timeout_sleep *= 2
retry_count -= 1
if retry_count == 0:
assert False, "Route table hasn't been updated after 3 tries."
def function(flask_request):
return "OK"
serve.create_backend(function, "echo:v1")
serve.link("endpoint", "echo:v1")
resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
assert resp == "OK"
def test_no_route(serve_instance):
serve.create_endpoint("noroute-endpoint", blocking=True)
serve.create_endpoint("noroute-endpoint")
global_state = serve.api._get_global_state()
result = global_state.route_table.list_service(include_headless=True)
@@ -87,7 +92,8 @@ def test_scaling_replicas(serve_instance):
serve.create_endpoint("counter", "/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
while "/increment" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
b_config = BackendConfig(num_replicas=2)
@@ -96,7 +102,7 @@ def test_scaling_replicas(serve_instance):
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
resp = requests.get("http://127.0.0.1:8000/increment").json()
counter_result.append(resp)
# If the load is shared among two replicas. The max result cannot be 10.
@@ -108,7 +114,7 @@ def test_scaling_replicas(serve_instance):
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
resp = requests.get("http://127.0.0.1:8000/increment").json()
counter_result.append(resp)
# Give some time for a replica to spin down. But majority of the request
# should be served by the only remaining replica.
@@ -129,7 +135,8 @@ def test_batching(serve_instance):
serve.create_endpoint("counter1", "/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
while "/increment" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
# set the max batch size
+2 -2
View File
@@ -1,9 +1,9 @@
import json
from ray.serve.utils import BytesEncoder
from ray.serve.utils import ServeEncoder
def test_bytes_encoder():
data_before = {"inp": {"nest": b"bytes"}}
data_after = {"inp": {"nest": "bytes"}}
assert json.loads(json.dumps(data_before, cls=BytesEncoder)) == data_after
assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after
+18 -8
View File
@@ -11,6 +11,12 @@ from pygments import formatters, highlight, lexers
from ray.serve.context import FakeFlaskRequest, TaskContext
from ray.serve.http_util import build_flask_request
import itertools
import numpy as np
try:
import pydantic
except ImportError:
pydantic = None
def expand(l):
@@ -60,19 +66,23 @@ def _get_logger():
logger = _get_logger()
class BytesEncoder(json.JSONEncoder):
"""Allow bytes to be part of the JSON document.
BytesEncoder will walk the JSON tree and decode bytes with utf-8 codec.
Example:
>>> json.dumps({b'a': b'c'}, cls=BytesEncoder)
'{"a":"c"}'
class ServeEncoder(json.JSONEncoder):
    """Ray.Serve's utility JSON encoder. Adds support for:
        - bytes
        - Pydantic types
        - Exceptions
        - numpy.ndarray
        - numpy scalars (e.g. numpy.int64, numpy.float32)
    """

    def default(self, o):  # pylint: disable=E0202
        """Convert `o` into something the stock JSON encoder can handle.

        Called by `json` only for objects it cannot serialize natively.
        Falls back to the base class, which raises TypeError, for any
        type not listed in the class docstring.
        """
        if isinstance(o, bytes):
            return o.decode("utf-8")
        if isinstance(o, Exception):
            return str(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.generic):
            # numpy scalars (np.int64, np.float32, np.bool_, ...) are not
            # JSON serializable; .item() gives the native Python value.
            return o.item()
        # Optional dependency is checked last, after the cheap checks.
        if pydantic is not None and isinstance(o, pydantic.BaseModel):
            return o.dict()
        return super().default(o)
+7 -7
View File
@@ -1475,10 +1475,10 @@ def get(object_ids, timeout=None):
"core_worker") and worker.core_worker.current_actor_is_asyncio():
global blocking_get_inside_async_warned
if not blocking_get_inside_async_warned:
logger.warning("Using blocking ray.get inside async actor. "
"This blocks the event loop. Please use `await` "
"on object id with asyncio.gather if you want to "
"yield execution to the event loop instead.")
logger.debug("Using blocking ray.get inside async actor. "
"This blocks the event loop. Please use `await` "
"on object id with asyncio.gather if you want to "
"yield execution to the event loop instead.")
blocking_get_inside_async_warned = True
with profiling.profile("ray.get"):
@@ -1586,9 +1586,9 @@ def wait(object_ids, num_returns=1, timeout=None):
) and timeout != 0:
global blocking_wait_inside_async_warned
if not blocking_wait_inside_async_warned:
logger.warning("Using blocking ray.wait inside async method. "
"This blocks the event loop. Please use `await` "
"on object id with asyncio.wait. ")
logger.debug("Using blocking ray.wait inside async method. "
"This blocks the event loop. Please use `await` "
"on object id with asyncio.wait. ")
blocking_wait_inside_async_warned = True
if isinstance(object_ids, ObjectID):