mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 07:07:00 +08:00
[Serve] Enhancement in HTTP Methods and Multi-route support (#7709)
This commit is contained in:
+63
-43
@@ -9,15 +9,15 @@ import numpy as np
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_NURSERY_NAME)
|
||||
from ray.serve.global_state import (GlobalState, start_initial_state)
|
||||
from ray.serve.global_state import GlobalState, start_initial_state
|
||||
from ray.serve.kv_store_service import SQLiteKVStore
|
||||
from ray.serve.task_runner import RayServeMixin, TaskRunnerActor
|
||||
from ray.serve.utils import (block_until_http_ready, get_random_letters,
|
||||
expand)
|
||||
from ray.serve.utils import block_until_http_ready, get_random_letters, expand
|
||||
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
|
||||
from ray.serve.backend_config import BackendConfig
|
||||
from ray.serve.policy import RoutePolicy
|
||||
from ray.serve.queues import Query
|
||||
|
||||
global_state = None
|
||||
|
||||
|
||||
@@ -61,19 +61,21 @@ def accept_batch(f):
|
||||
return f
|
||||
|
||||
|
||||
def init(kv_store_connector=None,
|
||||
kv_store_path=None,
|
||||
blocking=False,
|
||||
start_server=True,
|
||||
http_host=DEFAULT_HTTP_HOST,
|
||||
http_port=DEFAULT_HTTP_PORT,
|
||||
ray_init_kwargs={
|
||||
"object_store_memory": int(1e8),
|
||||
"num_cpus": max(cpu_count(), 8)
|
||||
},
|
||||
gc_window_seconds=3600,
|
||||
queueing_policy=RoutePolicy.Random,
|
||||
policy_kwargs={}):
|
||||
def init(
|
||||
kv_store_connector=None,
|
||||
kv_store_path=None,
|
||||
blocking=False,
|
||||
start_server=True,
|
||||
http_host=DEFAULT_HTTP_HOST,
|
||||
http_port=DEFAULT_HTTP_PORT,
|
||||
ray_init_kwargs={
|
||||
"object_store_memory": int(1e8),
|
||||
"num_cpus": max(cpu_count(), 8)
|
||||
},
|
||||
gc_window_seconds=3600,
|
||||
queueing_policy=RoutePolicy.Random,
|
||||
policy_kwargs={},
|
||||
):
|
||||
"""Initialize a serve cluster.
|
||||
|
||||
If serve cluster has already initialized, this function will just return.
|
||||
@@ -127,7 +129,7 @@ def init(kv_store_connector=None,
|
||||
_, kv_store_path = mkstemp()
|
||||
|
||||
# Serve has not been initialized, perform init sequence
|
||||
# Todo, move the db to session_dir
|
||||
# TODO move the db to session_dir
|
||||
# ray.worker._global_node.address_info["session_dir"]
|
||||
def kv_store_connector(namespace):
|
||||
return SQLiteKVStore(namespace, db_path=kv_store_path)
|
||||
@@ -143,11 +145,12 @@ def init(kv_store_connector=None,
|
||||
gc_window_seconds=gc_window_seconds)
|
||||
|
||||
if start_server and blocking:
|
||||
block_until_http_ready("http://{}:{}".format(http_host, http_port))
|
||||
block_until_http_ready("http://{}:{}/-/routes".format(
|
||||
http_host, http_port))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def create_endpoint(endpoint_name, route=None, blocking=True):
|
||||
def create_endpoint(endpoint_name, route=None, methods=["GET"]):
|
||||
"""Create a service endpoint given route_expression.
|
||||
|
||||
Args:
|
||||
@@ -158,7 +161,9 @@ def create_endpoint(endpoint_name, route=None, blocking=True):
|
||||
blocking (bool): If true, the function will wait for service to be
|
||||
registered before returning
|
||||
"""
|
||||
global_state.route_table.register_service(route, endpoint_name)
|
||||
methods = [m.upper() for m in methods]
|
||||
global_state.route_table.register_service(
|
||||
route, endpoint_name, methods=methods)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -169,8 +174,8 @@ def set_backend_config(backend_tag, backend_config):
|
||||
backend_tag(str): A registered backend.
|
||||
backend_config(BackendConfig) : Desired backend configuration.
|
||||
"""
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(backend_config,
|
||||
BackendConfig), ("backend_config must be"
|
||||
" of instance BackendConfig")
|
||||
@@ -211,8 +216,8 @@ def get_backend_config(backend_tag):
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
"""
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
backend_config_dict = global_state.backend_table.get_info(backend_tag)
|
||||
return BackendConfig(**backend_config_dict)
|
||||
|
||||
@@ -249,8 +254,7 @@ def create_backend(func_or_class,
|
||||
" of instance BackendConfig")
|
||||
|
||||
# Make sure the batch size is correct
|
||||
should_accept_batch = (True if backend_config.max_batch_size is not None
|
||||
else False)
|
||||
should_accept_batch = backend_config.max_batch_size is not None
|
||||
if should_accept_batch and not _backend_accept_batch(func_or_class):
|
||||
raise batch_annotation_not_found
|
||||
if _backend_accept_batch(func_or_class):
|
||||
@@ -267,7 +271,10 @@ def create_backend(func_or_class,
|
||||
# on the left to make sure its methods are not overriden.
|
||||
@ray.remote
|
||||
class CustomActor(RayServeMixin, func_or_class):
|
||||
pass
|
||||
@wraps(func_or_class.__init__)
|
||||
def __init__(self, *args, **kwargs):
|
||||
init() # serve init
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
arg_list = actor_init_args
|
||||
# ignore lint on lambda expression
|
||||
@@ -296,8 +303,8 @@ def create_backend(func_or_class,
|
||||
|
||||
|
||||
def _start_replica(backend_tag):
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
|
||||
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
|
||||
|
||||
@@ -327,11 +334,12 @@ def _start_replica(backend_tag):
|
||||
|
||||
|
||||
def _remove_replica(backend_tag):
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert len(global_state.backend_table.list_replicas(backend_tag)) > 0, (
|
||||
"Backend {} does not have enough replicas to be removed.".format(
|
||||
backend_tag))
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert (
|
||||
len(global_state.backend_table.list_replicas(backend_tag)) >
|
||||
0), "Backend {} does not have enough replicas to be removed.".format(
|
||||
backend_tag)
|
||||
|
||||
replica_tag = global_state.backend_table.remove_replica(backend_tag)
|
||||
[replica_handle] = ray.get(
|
||||
@@ -347,8 +355,9 @@ def _remove_replica(backend_tag):
|
||||
|
||||
# Remove the replica from router.
|
||||
# This will also destory the actor handle.
|
||||
ray.get(global_state.init_or_get_router()
|
||||
.remove_and_destory_replica.remote(backend_tag, replica_handle))
|
||||
ray.get(
|
||||
global_state.init_or_get_router().remove_and_destory_replica.remote(
|
||||
backend_tag, replica_handle))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -359,8 +368,8 @@ def _scale(backend_tag, num_replicas):
|
||||
backend_tag (str): A registered backend.
|
||||
num_replicas (int): Desired number of replicas
|
||||
"""
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert num_replicas >= 0, ("Number of replicas must be"
|
||||
" greater than or equal to 0.")
|
||||
|
||||
@@ -430,7 +439,10 @@ def split(endpoint_name, traffic_policy_dictionary):
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None):
|
||||
def get_handle(endpoint_name,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None,
|
||||
missing_ok=False):
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
@@ -439,18 +451,26 @@ def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None):
|
||||
queries fired using this handle. (Default: None)
|
||||
absolute_slo_ms(float): Specify absolute deadline in milliseconds for
|
||||
queries fired using this handle. (Default: None)
|
||||
missing_ok (bool): If true, skip the check for the endpoint existence.
|
||||
It can be useful when the endpoint has not been registered.
|
||||
|
||||
Returns:
|
||||
RayServeHandle
|
||||
"""
|
||||
assert endpoint_name in expand(
|
||||
global_state.route_table.list_service(include_headless=True).values())
|
||||
if not missing_ok:
|
||||
assert endpoint_name in expand(
|
||||
global_state.route_table.list_service(
|
||||
include_headless=True).values())
|
||||
|
||||
# Delay import due to it's dependency on global_state
|
||||
from ray.serve.handle import RayServeHandle
|
||||
|
||||
return RayServeHandle(global_state.init_or_get_router(), endpoint_name,
|
||||
relative_slo_ms, absolute_slo_ms)
|
||||
return RayServeHandle(
|
||||
global_state.init_or_get_router(),
|
||||
endpoint_name,
|
||||
relative_slo_ms,
|
||||
absolute_slo_ms,
|
||||
)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
|
||||
@@ -16,7 +16,7 @@ def echo(flask_request):
|
||||
|
||||
serve.init(blocking=True)
|
||||
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_endpoint("my_endpoint", "/echo")
|
||||
serve.create_backend(echo, "echo:v1")
|
||||
serve.link("my_endpoint", "echo:v1")
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class MagicCounter:
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter", blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter")
|
||||
serve.create_backend(MagicCounter, "counter:v1", 42) # increment=42
|
||||
serve.link("magic_counter", "counter:v1")
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class MagicCounter:
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter", blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter")
|
||||
b_config = BackendConfig(max_batch_size=5)
|
||||
serve.create_backend(
|
||||
MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42
|
||||
|
||||
@@ -28,7 +28,7 @@ class MagicCounter:
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter", blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter")
|
||||
# specify max_batch_size in BackendConfig
|
||||
b_config = BackendConfig(max_batch_size=5)
|
||||
serve.create_backend(
|
||||
|
||||
@@ -28,7 +28,7 @@ def echo(_):
|
||||
|
||||
serve.init(blocking=True)
|
||||
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_endpoint("my_endpoint", "/echo")
|
||||
serve.create_backend(echo, "echo:v1")
|
||||
serve.link("my_endpoint", "echo:v1")
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ serve.init(
|
||||
policy_kwargs={"packing_num": 5})
|
||||
|
||||
# create a service
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_endpoint("my_endpoint", "/echo")
|
||||
|
||||
# create first backend
|
||||
serve.create_backend(echo_v1, "echo:v1")
|
||||
|
||||
@@ -33,7 +33,7 @@ backend_config_v1 = serve.get_backend_config("echo:v1")
|
||||
# goes to my_endpoint will now goes to echo:v1 backend.
|
||||
serve.link("my_endpoint", "echo:v1")
|
||||
|
||||
print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).json())
|
||||
print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).text)
|
||||
# The service will be reachable from http
|
||||
|
||||
print(ray.get(serve.get_handle("my_endpoint").remote(response="hello")))
|
||||
@@ -55,7 +55,7 @@ serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
|
||||
# Observe requests are now split between two backends.
|
||||
for _ in range(10):
|
||||
print(requests.get("http://127.0.0.1:8000/echo").json())
|
||||
print(requests.get("http://127.0.0.1:8000/echo").text)
|
||||
time.sleep(0.5)
|
||||
|
||||
# You can also change number of replicas
|
||||
|
||||
@@ -22,7 +22,7 @@ def echo_v2(_):
|
||||
serve.init(blocking=True, queueing_policy=serve.RoutePolicy.RoundRobin)
|
||||
|
||||
# create a service
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_endpoint("my_endpoint", "/echo")
|
||||
|
||||
# create first backend
|
||||
serve.create_backend(echo_v1, "echo:v1")
|
||||
|
||||
@@ -20,7 +20,7 @@ def echo_v2(_):
|
||||
|
||||
serve.init(blocking=True)
|
||||
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_endpoint("my_endpoint", "/echo")
|
||||
serve.create_backend(echo_v1, "echo:v1")
|
||||
serve.link("my_endpoint", "echo:v1")
|
||||
|
||||
|
||||
+51
-24
@@ -27,19 +27,23 @@ class RayServeHandle:
|
||||
# raises RayTaskError Exception
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
router_handle,
|
||||
endpoint_name,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None):
|
||||
def __init__(
|
||||
self,
|
||||
router_handle,
|
||||
endpoint_name,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None,
|
||||
method_name=None,
|
||||
):
|
||||
self.router_handle = router_handle
|
||||
self.endpoint_name = endpoint_name
|
||||
assert (relative_slo_ms is None
|
||||
or absolute_slo_ms is None), ("Can't specify both "
|
||||
"relative and absolute "
|
||||
"slo's together!")
|
||||
assert relative_slo_ms is None or absolute_slo_ms is None, (
|
||||
"Can't specify both "
|
||||
"relative and absolute "
|
||||
"slo's together!")
|
||||
self.relative_slo_ms = self._check_slo_ms(relative_slo_ms)
|
||||
self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
|
||||
self.method_name = method_name
|
||||
|
||||
def _check_slo_ms(self, slo_value):
|
||||
if slo_value is not None:
|
||||
@@ -59,23 +63,44 @@ class RayServeHandle:
|
||||
raise RayServeException(
|
||||
"handle.remote must be invoked with keyword arguments.")
|
||||
|
||||
method_name = self.method_name
|
||||
if method_name is None:
|
||||
method_name = "__call__"
|
||||
|
||||
# create RequestMetadata instance
|
||||
request_in_object = RequestMetadata(
|
||||
self.endpoint_name, TaskContext.Python, self.relative_slo_ms,
|
||||
self.absolute_slo_ms)
|
||||
self.endpoint_name,
|
||||
TaskContext.Python,
|
||||
self.relative_slo_ms,
|
||||
self.absolute_slo_ms,
|
||||
call_method=method_name,
|
||||
)
|
||||
return self.router_handle.enqueue_request.remote(
|
||||
request_in_object, **kwargs)
|
||||
|
||||
def options(self,
            method_name=None,
            relative_slo_ms=None,
            absolute_slo_ms=None):
    """Return a new handle to the same endpoint with different options.

    Args:
        method_name (str, optional): Name of the backend method the new
            handle should call. When None, the method name already set
            on this handle (if any) is carried over.
        relative_slo_ms (float, optional): Relative SLO passed to the
            new handle. Not inherited from this handle.
        absolute_slo_ms (float, optional): Absolute SLO passed to the
            new handle. Not inherited from this handle.

    Returns:
        RayServeHandle

    Raises:
        AssertionError: If both relative and absolute SLOs are given.
    """
    # If both the slo's are None then we use a high default
    # value so other queries can be prioritized and put in front of these
    # queries.
    # NOTE: `all()` accepts a single iterable, so the two values must be
    # checked explicitly; this mirrors the check done in __init__.
    assert relative_slo_ms is None or absolute_slo_ms is None, (
        "Can't specify both "
        "relative and absolute "
        "slo's together!")

    # Don't override existing method
    if method_name is None:
        method_name = self.method_name

    return RayServeHandle(
        self.router_handle,
        self.endpoint_name,
        relative_slo_ms,
        absolute_slo_ms,
        method_name=method_name,
    )
|
||||
|
||||
def get_traffic_policy(self):
|
||||
policy_table = serve.api._get_global_state().policy_table
|
||||
@@ -94,9 +119,9 @@ class RayServeHandle:
|
||||
backends = set(traffic_policy.keys())
|
||||
return backends.pop()
|
||||
else:
|
||||
assert backend_tag in traffic_policy, (
|
||||
"Backend {} not found in avaiable backends: {}.".format(
|
||||
backend_tag, list(traffic_policy.keys())))
|
||||
assert (backend_tag in traffic_policy
|
||||
), "Backend {} not found in avaiable backends: {}.".format(
|
||||
backend_tag, list(traffic_policy.keys()))
|
||||
return backend_tag
|
||||
|
||||
def scale(self, new_num_replicas, backend_tag=None):
|
||||
@@ -118,9 +143,11 @@ RayServeHandle(
|
||||
URL="{http_endpoint}/{endpoint_name}",
|
||||
Traffic={traffic_policy}
|
||||
)
|
||||
""".format(endpoint_name=self.endpoint_name,
|
||||
http_endpoint=self.get_http_endpoint(),
|
||||
traffic_policy=self.get_traffic_policy())
|
||||
""".format(
|
||||
endpoint_name=self.endpoint_name,
|
||||
http_endpoint=self.get_http_endpoint(),
|
||||
traffic_policy=self.get_traffic_policy(),
|
||||
)
|
||||
|
||||
# TODO(simon): a convenience function that dumps equivalent requests
|
||||
# code for a given call.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import io
|
||||
import json
|
||||
|
||||
import flask
|
||||
|
||||
@@ -67,3 +68,58 @@ def build_wsgi_environ(scope, body):
|
||||
|
||||
environ[corrected_name] = value
|
||||
return environ
|
||||
|
||||
|
||||
class Response:
    """ASGI compliant response class.

    It is expected to be called in async context and pass along
    `scope, receive, send` as in ASGI spec.

    >>> await Response({"k": "v"}).send(scope, receive, send)
    """

    # Maps the short content type names accepted by `set_content_type`
    # to the raw header value sent on the wire.
    _CONTENT_TYPE_HEADERS = {
        "text": b"text/plain",
        "text-utf8": b"text/plain; charset=utf-8",
        "json": b"application/json",
    }

    def __init__(self, content=None, status_code=200):
        """Construct a HTTP Response based on input type.

        Args:
            content (optional): None produces an empty body, bytes are
                sent unchanged, str is encoded as utf-8, and any other
                object is serialized to JSON.
            status_code (int, optional): Default status code is 200.
        """
        self.status_code = status_code
        self.raw_headers = []

        if content is None:
            self.body = b""
            self.set_content_type("text")
        elif isinstance(content, bytes):
            self.body = content
            self.set_content_type("text")
        elif isinstance(content, str):
            self.body = content.encode("utf-8")
            self.set_content_type("text-utf8")
        else:
            # Delayed import since utils depends on http_util
            from ray.serve.utils import ServeEncoder
            self.body = json.dumps(
                content, cls=ServeEncoder, indent=2).encode()
            self.set_content_type("json")

    def set_content_type(self, content_type):
        """Append a content-type header to the response.

        Args:
            content_type (str): One of "text", "text-utf8" or "json".

        Raises:
            ValueError: If `content_type` is not one of the known names.
        """
        try:
            header_value = self._CONTENT_TYPE_HEADERS[content_type]
        except KeyError:
            raise ValueError(
                "Invalid content type {}".format(content_type)) from None
        self.raw_headers.append([b"content-type", header_value])

    async def send(self, scope, receive, send):
        """Send the response start message followed by the body."""
        await send({
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        })
        await send({"type": "http.response.body", "body": self.body})
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from abc import ABC
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
from ray import cloudpickle as pickle
|
||||
import ray.experimental.internal_kv as ray_kv
|
||||
@@ -173,9 +173,11 @@ class SQLiteKVStore(NamespacedKVStore):
|
||||
class RoutingTable:
|
||||
def __init__(self, kv_connector):
|
||||
self.routing_table = kv_connector("routing_table")
|
||||
self.methods_table = kv_connector("methods_table")
|
||||
self.request_count = 0
|
||||
|
||||
def register_service(self, route: Union[str, None], service: str):
|
||||
def register_service(self, route: Union[str, None], service: str,
|
||||
methods: List[str]):
|
||||
"""Create an entry in the routing table
|
||||
|
||||
Args:
|
||||
@@ -183,8 +185,9 @@ class RoutingTable:
|
||||
service: service name. This is the name http actor will push
|
||||
the request to.
|
||||
"""
|
||||
logger.debug("[KV] Registering route {} to service {}.".format(
|
||||
route, service))
|
||||
logger.debug(
|
||||
"[KV] Registering route {} to service {} with methods {}.".format(
|
||||
route, service, methods))
|
||||
|
||||
# put no route services in default key
|
||||
if route is None:
|
||||
@@ -194,14 +197,23 @@ class RoutingTable:
|
||||
self.routing_table.put(NO_ROUTE_KEY, json.dumps(no_http_services))
|
||||
else:
|
||||
self.routing_table.put(route, service)
|
||||
self.methods_table.put(route, json.dumps(methods))
|
||||
|
||||
def list_service(self, include_headless=False):
|
||||
def list_service(self, include_headless=False, include_methods=False):
|
||||
"""Returns the routing table.
|
||||
Args:
|
||||
include_headless: If True, returns a no route services (headless)
|
||||
services with normal services. (Default: False)
|
||||
include_methods: If True, returns a mapping include the methods
|
||||
list for each route.
|
||||
"""
|
||||
table = self.routing_table.as_dict()
|
||||
if include_methods:
|
||||
methods_table = self.methods_table.as_dict()
|
||||
for route, methods in methods_table.items():
|
||||
if route in table:
|
||||
table[route] = (table[route], json.loads(methods))
|
||||
|
||||
if include_headless:
|
||||
table[NO_ROUTE_KEY] = json.loads(table.get(NO_ROUTE_KEY, "[]"))
|
||||
else:
|
||||
|
||||
+66
-84
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
|
||||
@@ -7,48 +6,12 @@ import ray
|
||||
from ray.experimental.async_api import _async_init
|
||||
from ray.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.utils import BytesEncoder
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.http_util import Response
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
|
||||
class JSONResponse:
|
||||
"""ASGI compliant response class.
|
||||
|
||||
It is expected to be called in async context and pass along
|
||||
`scope, receive, send` as in ASGI spec.
|
||||
|
||||
>>> await JSONResponse({"k": "v"})(scope, receive, send)
|
||||
"""
|
||||
|
||||
def __init__(self, content=None, status_code=200):
|
||||
"""Construct a JSON HTTP Response.
|
||||
|
||||
Args:
|
||||
content (optional): Any JSON serializable object.
|
||||
status_code (int, optional): Default status code is 200.
|
||||
"""
|
||||
self.body = self.render(content)
|
||||
self.status_code = status_code
|
||||
self.raw_headers = [[b"content-type", b"application/json"]]
|
||||
|
||||
def render(self, content):
|
||||
if content is None:
|
||||
return b""
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
return json.dumps(content, cls=BytesEncoder, indent=2).encode()
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
})
|
||||
await send({"type": "http.response.body", "body": self.body})
|
||||
|
||||
|
||||
class HTTPProxy:
|
||||
"""
|
||||
This class should be instantiated and ran by ASGI server.
|
||||
@@ -63,6 +26,7 @@ class HTTPProxy:
|
||||
|
||||
# Delay import due to GlobalState depends on HTTP actor
|
||||
from ray.serve.global_state import GlobalState
|
||||
|
||||
self.serve_global_state = GlobalState()
|
||||
self.route_table_cache = dict()
|
||||
|
||||
@@ -75,7 +39,8 @@ class HTTPProxy:
|
||||
return
|
||||
|
||||
self.route_table_cache = (
|
||||
self.serve_global_state.route_table.list_service())
|
||||
self.serve_global_state.route_table.list_service(
|
||||
include_methods=True, include_headless=False))
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
@@ -105,19 +70,38 @@ class HTTPProxy:
|
||||
|
||||
return b"".join(body_buffer)
|
||||
|
||||
def _check_slo_ms(self, request_slo_ms):
|
||||
if request_slo_ms is not None:
|
||||
if len(request_slo_ms) != 1:
|
||||
raise ValueError(
|
||||
"Multiple SLO specified, please specific only one.")
|
||||
request_slo_ms = request_slo_ms[0]
|
||||
request_slo_ms = float(request_slo_ms)
|
||||
if request_slo_ms < 0:
|
||||
raise ValueError(
|
||||
"Request SLO must be positive, it is {}".format(
|
||||
request_slo_ms))
|
||||
return request_slo_ms
|
||||
return None
|
||||
def _parse_latency_slo(self, scope):
|
||||
query_string = scope["query_string"].decode("ascii")
|
||||
query_kwargs = parse_qs(query_string)
|
||||
|
||||
relative_slo_ms = query_kwargs.pop("relative_slo_ms", None)
|
||||
absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None)
|
||||
relative_slo_ms = self._validate_slo_ms(relative_slo_ms)
|
||||
absolute_slo_ms = self._validate_slo_ms(absolute_slo_ms)
|
||||
if relative_slo_ms is not None and absolute_slo_ms is not None:
|
||||
raise ValueError("Both relative and absolute slo's"
|
||||
"cannot be specified.")
|
||||
return relative_slo_ms, absolute_slo_ms
|
||||
|
||||
def _validate_slo_ms(self, request_slo_ms):
|
||||
if request_slo_ms is None:
|
||||
return None
|
||||
if len(request_slo_ms) != 1:
|
||||
raise ValueError(
|
||||
"Multiple SLO specified, please specific only one.")
|
||||
request_slo_ms = request_slo_ms[0]
|
||||
request_slo_ms = float(request_slo_ms)
|
||||
if request_slo_ms < 0:
|
||||
raise ValueError("Request SLO must be positive, it is {}".format(
|
||||
request_slo_ms))
|
||||
return request_slo_ms
|
||||
|
||||
def _make_error_sender(self, scope, receive, send):
|
||||
async def sender(error_message, status_code):
|
||||
response = Response(error_message, status_code=status_code)
|
||||
await response.send(scope, receive, send)
|
||||
|
||||
return sender
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# NOTE: This implements ASGI protocol specified in
|
||||
@@ -127,39 +111,39 @@ class HTTPProxy:
|
||||
await self.handle_lifespan_message(scope, receive, send)
|
||||
return
|
||||
|
||||
error_sender = self._make_error_sender(scope, receive, send)
|
||||
|
||||
assert scope["type"] == "http"
|
||||
current_path = scope["path"]
|
||||
if current_path == "/":
|
||||
await JSONResponse(self.route_table_cache)(scope, receive, send)
|
||||
if current_path == "/-/routes":
|
||||
await Response(self.route_table_cache).send(scope, receive, send)
|
||||
return
|
||||
|
||||
# TODO(simon): Use werkzeug route mapper to support variable path
|
||||
if current_path not in self.route_table_cache:
|
||||
error_message = ("Path {} not found. "
|
||||
"Please ping http://.../ for routing table"
|
||||
).format(current_path)
|
||||
await JSONResponse(
|
||||
{
|
||||
"error": error_message
|
||||
}, status_code=404)(scope, receive, send)
|
||||
error_message = (
|
||||
"Path {} not found. "
|
||||
"Please ping http://.../-/routes for routing table"
|
||||
).format(current_path)
|
||||
await error_sender(error_message, 404)
|
||||
return
|
||||
|
||||
endpoint_name, methods_allowed = self.route_table_cache[current_path]
|
||||
|
||||
if scope["method"] not in methods_allowed:
|
||||
error_message = ("Methods {} not allowed. "
|
||||
"Avaiable HTTP methods are {}.").format(
|
||||
scope["method"], methods_allowed)
|
||||
await error_sender(error_message, 405)
|
||||
return
|
||||
|
||||
endpoint_name = self.route_table_cache[current_path]
|
||||
http_body_bytes = await self.receive_http_body(scope, receive, send)
|
||||
|
||||
# get slo_ms before enqueuing the query
|
||||
query_string = scope["query_string"].decode("ascii")
|
||||
query_kwargs = parse_qs(query_string)
|
||||
relative_slo_ms = query_kwargs.pop("relative_slo_ms", None)
|
||||
absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None)
|
||||
try:
|
||||
relative_slo_ms = self._check_slo_ms(relative_slo_ms)
|
||||
absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
|
||||
if relative_slo_ms is not None and absolute_slo_ms is not None:
|
||||
raise ValueError("Both relative and absolute slo's"
|
||||
"cannot be specified.")
|
||||
relative_slo_ms, absolute_slo_ms = self._parse_latency_slo(scope)
|
||||
except ValueError as e:
|
||||
await JSONResponse({"error": str(e)})(scope, receive, send)
|
||||
await error_sender(str(e), 400)
|
||||
return
|
||||
|
||||
# create objects necessary for enqueue
|
||||
@@ -167,23 +151,21 @@ class HTTPProxy:
|
||||
# https://github.com/ray-project/ray/issues/6944
|
||||
# TODO(alind): remove list enclosing after issue is fixed
|
||||
args = (scope, [http_body_bytes])
|
||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||
request_in_object = RequestMetadata(
|
||||
endpoint_name,
|
||||
TaskContext.Web,
|
||||
relative_slo_ms=relative_slo_ms,
|
||||
absolute_slo_ms=absolute_slo_ms)
|
||||
absolute_slo_ms=absolute_slo_ms,
|
||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"))
|
||||
|
||||
actual_result = await (self.serve_global_state.init_or_get_router()
|
||||
.enqueue_request.remote(request_in_object,
|
||||
*args))
|
||||
result = actual_result
|
||||
|
||||
if isinstance(result, ray.exceptions.RayTaskError):
|
||||
await JSONResponse({
|
||||
"error": "internal error, please use python API to debug"
|
||||
})(scope, receive, send)
|
||||
else:
|
||||
await JSONResponse({"result": result})(scope, receive, send)
|
||||
try:
|
||||
result = await (self.serve_global_state.init_or_get_router()
|
||||
.enqueue_request.remote(request_in_object, *args))
|
||||
await Response(result).send(scope, receive, send)
|
||||
except Exception as e:
|
||||
error_message = "Internal Error. Traceback: {}.".format(e)
|
||||
await error_sender(error_message, 500)
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
||||
@@ -21,6 +21,9 @@ class TaskRunner:
|
||||
def __init__(self, func_to_run):
|
||||
self.func = func_to_run
|
||||
|
||||
# This attribute lets argument inspection work with the inner function.
|
||||
self.__wrapped__ = func_to_run
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
@@ -111,17 +114,33 @@ class RayServeMixin:
|
||||
"The avaiable methods are {}".format(
|
||||
method_name, dir(self)))
|
||||
|
||||
return getattr(self, method_name)
|
||||
if method_name != "__call__":
|
||||
return getattr(self, method_name)
|
||||
else:
|
||||
# For simple callables, we should just return the object so
|
||||
# signature inspection will continue to function.
|
||||
return self
|
||||
|
||||
def _ray_serve_count_num_positional(self, f):
|
||||
signature = inspect.signature(f)
|
||||
counter = 0
|
||||
for param in signature.parameters.values():
|
||||
if (param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
and param.default is param.empty):
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
async def invoke_single(self, request_item):
|
||||
args, kwargs, is_web_context = parse_request_item(request_item)
|
||||
serve_context.web = is_web_context
|
||||
start_timestamp = time.time()
|
||||
|
||||
method_to_call = self._ray_serve_get_runner_method(request_item)
|
||||
args = args if self._ray_serve_count_num_positional(
|
||||
method_to_call) else []
|
||||
method_to_call = ensure_async(method_to_call)
|
||||
try:
|
||||
result = await ensure_async(
|
||||
self._ray_serve_get_runner_method(request_item))(*args,
|
||||
**kwargs)
|
||||
result = await method_to_call(*args, **kwargs)
|
||||
except Exception as e:
|
||||
result = wrap_to_ray_error(e)
|
||||
self._serve_metric_error_counter += 1
|
||||
@@ -154,7 +173,8 @@ class RayServeMixin:
|
||||
args, kwargs, is_web_context = parse_request_item(item)
|
||||
context_flags.add(is_web_context)
|
||||
|
||||
call_methods.add(self._ray_serve_get_runner_method(item))
|
||||
call_method = self._ray_serve_get_runner_method(item)
|
||||
call_methods.add(call_method)
|
||||
|
||||
if is_web_context:
|
||||
# Python context only have kwargs
|
||||
@@ -168,7 +188,8 @@ class RayServeMixin:
|
||||
# Set the flask request as a list to conform
|
||||
# with batching semantics: when in batching
|
||||
# mode, each argument is turned into a list.
|
||||
arg_list.append(FakeFlaskRequest())
|
||||
if self._ray_serve_count_num_positional(call_method):
|
||||
arg_list.append(FakeFlaskRequest())
|
||||
|
||||
try:
|
||||
# check mixing of query context
|
||||
|
||||
@@ -10,6 +10,41 @@ from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.handle import RayServeHandle
|
||||
|
||||
|
||||
def test_e2e(serve_instance):
    """End-to-end check of an endpoint that accepts both GET and POST.

    Creates the endpoint, waits until the server at 127.0.0.1:8000 lists it
    under ``/-/routes`` together with its allowed methods, then links a
    backend that echoes the HTTP method and verifies both verbs.
    """
    serve.init()  # so we have access to global state
    serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"])
    result = serve.api._get_global_state().route_table.list_service()
    assert result["/api"] == "endpoint"

    max_retries = 5
    timeout_sleep = 0.5
    last_error = None
    for _ in range(max_retries):
        try:
            resp = requests.get(
                "http://127.0.0.1:8000/-/routes", timeout=0.5).json()
            assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
            break
        except Exception as e:
            # Deliberately broad: connection errors, bad JSON and a stale
            # route table (AssertionError) are all reasons to retry.
            last_error = e
            time.sleep(timeout_sleep)
            timeout_sleep *= 2  # exponential back-off between attempts
    else:
        # Only reached when every attempt failed (no ``break``).
        raise AssertionError(
            "Route table hasn't been updated after {} tries. "
            "The latest error was {}".format(max_retries, last_error))

    def function(flask_request):
        return {"method": flask_request.method}

    serve.create_backend(function, "echo:v1")
    serve.link("endpoint", "echo:v1")

    resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "GET"

    resp = requests.post("http://127.0.0.1:8000/api").json()["method"]
    assert resp == "POST"
|
||||
|
||||
|
||||
def test_route_decorator(serve_instance):
|
||||
@serve.route("/hello_world")
|
||||
def hello_world(_):
|
||||
@@ -25,38 +60,8 @@ def test_route_decorator(serve_instance):
|
||||
hello_world.set_max_batch_size(2)
|
||||
|
||||
|
||||
def test_e2e(serve_instance):
|
||||
serve.init() # so we have access to global state
|
||||
serve.create_endpoint("endpoint", "/api", blocking=True)
|
||||
result = serve.api._get_global_state().route_table.list_service()
|
||||
assert result["/api"] == "endpoint"
|
||||
|
||||
retry_count = 5
|
||||
timeout_sleep = 0.5
|
||||
while True:
|
||||
try:
|
||||
resp = requests.get("http://127.0.0.1:8000/", timeout=0.5).json()
|
||||
assert resp == result
|
||||
break
|
||||
except Exception:
|
||||
time.sleep(timeout_sleep)
|
||||
timeout_sleep *= 2
|
||||
retry_count -= 1
|
||||
if retry_count == 0:
|
||||
assert False, "Route table hasn't been updated after 3 tries."
|
||||
|
||||
def function(flask_request):
|
||||
return "OK"
|
||||
|
||||
serve.create_backend(function, "echo:v1")
|
||||
serve.link("endpoint", "echo:v1")
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
|
||||
assert resp == "OK"
|
||||
|
||||
|
||||
def test_no_route(serve_instance):
|
||||
serve.create_endpoint("noroute-endpoint", blocking=True)
|
||||
serve.create_endpoint("noroute-endpoint")
|
||||
global_state = serve.api._get_global_state()
|
||||
|
||||
result = global_state.route_table.list_service(include_headless=True)
|
||||
@@ -87,7 +92,8 @@ def test_scaling_replicas(serve_instance):
|
||||
serve.create_endpoint("counter", "/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
|
||||
while "/increment" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
b_config = BackendConfig(num_replicas=2)
|
||||
@@ -96,7 +102,7 @@ def test_scaling_replicas(serve_instance):
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()
|
||||
counter_result.append(resp)
|
||||
|
||||
# If the load is shared among two replicas. The max result cannot be 10.
|
||||
@@ -108,7 +114,7 @@ def test_scaling_replicas(serve_instance):
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()
|
||||
counter_result.append(resp)
|
||||
# Give some time for a replica to spin down. But majority of the request
|
||||
# should be served by the only remaining replica.
|
||||
@@ -129,7 +135,8 @@ def test_batching(serve_instance):
|
||||
serve.create_endpoint("counter1", "/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
|
||||
while "/increment" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
# set the max batch size
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
|
||||
from ray.serve.utils import BytesEncoder
|
||||
from ray.serve.utils import ServeEncoder
|
||||
|
||||
|
||||
def test_bytes_encoder():
|
||||
data_before = {"inp": {"nest": b"bytes"}}
|
||||
data_after = {"inp": {"nest": "bytes"}}
|
||||
assert json.loads(json.dumps(data_before, cls=BytesEncoder)) == data_after
|
||||
assert json.loads(json.dumps(data_before, cls=ServeEncoder)) == data_after
|
||||
|
||||
@@ -11,6 +11,12 @@ from pygments import formatters, highlight, lexers
|
||||
from ray.serve.context import FakeFlaskRequest, TaskContext
|
||||
from ray.serve.http_util import build_flask_request
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import pydantic
|
||||
except ImportError:
|
||||
pydantic = None
|
||||
|
||||
|
||||
def expand(l):
|
||||
@@ -60,19 +66,23 @@ def _get_logger():
|
||||
logger = _get_logger()
|
||||
|
||||
|
||||
class BytesEncoder(json.JSONEncoder):
|
||||
"""Allow bytes to be part of the JSON document.
|
||||
|
||||
BytesEncoder will walk the JSON tree and decode bytes with utf-8 codec.
|
||||
|
||||
Example:
|
||||
>>> json.dumps({b'a': b'c'}, cls=BytesEncoder)
|
||||
'{"a":"c"}'
|
||||
class ServeEncoder(json.JSONEncoder):
    """Ray.Serve's utility JSON encoder. Adds support for:
        - bytes
        - Pydantic types
        - Exceptions
        - numpy.ndarray
        - numpy scalars (e.g. numpy.int64, numpy.bool_)

    Example:
        >>> json.dumps({"a": b"c"}, cls=ServeEncoder)
        '{"a": "c"}'
    """

    def default(self, o):  # pylint: disable=E0202
        """Return a JSON-serializable stand-in for ``o``.

        Called by ``json`` only for objects it cannot serialize natively.

        Args:
            o: The object the standard encoder could not handle.

        Returns:
            A ``str`` for bytes (decoded as UTF-8) and exceptions, a
            ``dict`` for pydantic models, a (nested) ``list`` for numpy
            arrays, or a plain Python scalar for numpy scalars.

        Raises:
            TypeError: From the base class, for any other type.
        """
        if isinstance(o, bytes):
            return o.decode("utf-8")
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.generic):
            # numpy scalars such as np.int64 are not Python ints, so the
            # stock encoder rejects them; unwrap to the native scalar.
            return o.item()
        # pydantic is optional; the module-level name is None when the
        # import failed.
        if pydantic is not None and isinstance(o, pydantic.BaseModel):
            return o.dict()
        if isinstance(o, Exception):
            return str(o)
        return super().default(o)
|
||||
|
||||
|
||||
|
||||
@@ -1475,10 +1475,10 @@ def get(object_ids, timeout=None):
|
||||
"core_worker") and worker.core_worker.current_actor_is_asyncio():
|
||||
global blocking_get_inside_async_warned
|
||||
if not blocking_get_inside_async_warned:
|
||||
logger.warning("Using blocking ray.get inside async actor. "
|
||||
"This blocks the event loop. Please use `await` "
|
||||
"on object id with asyncio.gather if you want to "
|
||||
"yield execution to the event loop instead.")
|
||||
logger.debug("Using blocking ray.get inside async actor. "
|
||||
"This blocks the event loop. Please use `await` "
|
||||
"on object id with asyncio.gather if you want to "
|
||||
"yield execution to the event loop instead.")
|
||||
blocking_get_inside_async_warned = True
|
||||
|
||||
with profiling.profile("ray.get"):
|
||||
@@ -1586,9 +1586,9 @@ def wait(object_ids, num_returns=1, timeout=None):
|
||||
) and timeout != 0:
|
||||
global blocking_wait_inside_async_warned
|
||||
if not blocking_wait_inside_async_warned:
|
||||
logger.warning("Using blocking ray.wait inside async method. "
|
||||
"This blocks the event loop. Please use `await` "
|
||||
"on object id with asyncio.wait. ")
|
||||
logger.debug("Using blocking ray.wait inside async method. "
|
||||
"This blocks the event loop. Please use `await` "
|
||||
"on object id with asyncio.wait. ")
|
||||
blocking_wait_inside_async_warned = True
|
||||
|
||||
if isinstance(object_ids, ObjectID):
|
||||
|
||||
Reference in New Issue
Block a user