mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[serve] Serve client refactor (#10409)
This commit is contained in:
+2
-12
@@ -6,8 +6,7 @@ py_library(
|
||||
)
|
||||
|
||||
serve_tests_srcs = glob(["tests/*.py"],
|
||||
exclude=["tests/test_nonblocking.py",
|
||||
"tests/test_controller_crashes.py",
|
||||
exclude=["tests/test_controller_crashes.py",
|
||||
"tests/test_serve.py",
|
||||
])
|
||||
|
||||
@@ -115,8 +114,7 @@ py_test(
|
||||
# srcs = glob(["tests/test_controller_crashes.py",
|
||||
# "tests/test_api.py",
|
||||
# "tests/test_failure.py"],
|
||||
# exclude=["tests/test_nonblocking.py",
|
||||
# "tests/test_serve.py"]),
|
||||
# exclude=["tests/test_serve.py"]),
|
||||
# )
|
||||
|
||||
py_test(
|
||||
@@ -127,14 +125,6 @@ py_test(
|
||||
deps = [":serve_lib"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_nonblocking",
|
||||
size = "small",
|
||||
srcs = glob(["tests/test_nonblocking.py"]),
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
# Make sure the example showing in doc is tested
|
||||
py_test(
|
||||
name = "quickstart_class",
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from ray.serve.api import (init, create_backend, delete_backend,
|
||||
create_endpoint, delete_endpoint, set_traffic,
|
||||
shadow_traffic, get_handle, update_backend_config,
|
||||
get_backend_config, accept_batch, list_backends,
|
||||
list_endpoints, shutdown) # noqa: E402
|
||||
from ray.serve.api import (accept_batch, Client, connect, start) # noqa: F401
|
||||
from ray.serve.config import BackendConfig
|
||||
|
||||
__all__ = [
|
||||
"init", "create_backend", "delete_backend", "create_endpoint",
|
||||
"delete_endpoint", "set_traffic", "shadow_traffic", "get_handle",
|
||||
"update_backend_config", "get_backend_config", "accept_batch",
|
||||
"list_backends", "list_endpoints", "shutdown", "BackendConfig"
|
||||
"accept_batch",
|
||||
"BackendConfig",
|
||||
"connect"
|
||||
"Client",
|
||||
"start",
|
||||
]
|
||||
|
||||
+386
-319
@@ -1,3 +1,4 @@
|
||||
import atexit
|
||||
from functools import wraps
|
||||
|
||||
import ray
|
||||
@@ -5,99 +6,383 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
||||
from ray.serve.controller import ServeController
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name)
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_random_letters)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
controller = None
|
||||
_INTERNAL_CONTROLLER_NAME = None
|
||||
|
||||
|
||||
def _get_controller() -> ActorHandle:
|
||||
"""Used for internal purpose because using just import serve.global_state
|
||||
will always reference the original None object.
|
||||
"""
|
||||
global controller
|
||||
if controller is None:
|
||||
raise RayServeException("Please run serve.init to initialize or "
|
||||
"connect to existing ray serve cluster.")
|
||||
return controller
|
||||
def _set_internal_controller_name(name):
|
||||
global _INTERNAL_CONTROLLER_NAME
|
||||
_INTERNAL_CONTROLLER_NAME = name
|
||||
|
||||
|
||||
def _ensure_connected(f: Callable) -> Callable:
|
||||
@wraps(f)
|
||||
def check(*args, **kwargs):
|
||||
_get_controller()
|
||||
return f(*args, **kwargs)
|
||||
def check(self, *args, **kwargs):
|
||||
if self._shutdown:
|
||||
raise RayServeException("Client has already been shut down.")
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return check
|
||||
|
||||
|
||||
def accept_batch(f: Callable) -> Callable:
|
||||
"""Annotation to mark a serving function that batch is accepted.
|
||||
class Client:
|
||||
def __init__(self,
|
||||
controller: ActorHandle,
|
||||
controller_name: str,
|
||||
detached: bool = False):
|
||||
self._controller = controller
|
||||
self._controller_name = controller_name
|
||||
self._detached = detached
|
||||
self._shutdown = False
|
||||
|
||||
This annotation need to be used to mark a function expect all arguments
|
||||
to be passed into a list.
|
||||
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
|
||||
# when the interpreter is exiting so we can't rely on __del__ (it
|
||||
# throws a nasty stacktrace).
|
||||
if not self._detached:
|
||||
|
||||
Example:
|
||||
def shutdown_serve_client():
|
||||
self.shutdown()
|
||||
|
||||
>>> @serve.accept_batch
|
||||
def serving_func(flask_request):
|
||||
assert isinstance(flask_request, list)
|
||||
...
|
||||
atexit.register(shutdown_serve_client)
|
||||
|
||||
>>> class ServingActor:
|
||||
@serve.accept_batch
|
||||
def __call__(self, *, python_arg=None):
|
||||
assert isinstance(python_arg, list)
|
||||
"""
|
||||
f._serve_accept_batch = True
|
||||
return f
|
||||
def __del__(self):
|
||||
if not self._detached:
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Completely shut down the connected Serve instance.
|
||||
|
||||
Shuts down all processes and deletes all state associated with the
|
||||
instance.
|
||||
"""
|
||||
if not self._shutdown:
|
||||
ray.get(self._controller.shutdown.remote())
|
||||
ray.kill(self._controller, no_restart=True)
|
||||
self._shutdown = True
|
||||
|
||||
@_ensure_connected
|
||||
def create_endpoint(self,
|
||||
endpoint_name: str,
|
||||
*,
|
||||
backend: str = None,
|
||||
route: Optional[str] = None,
|
||||
methods: List[str] = ["GET"]) -> None:
|
||||
"""Create a service endpoint given route_expression.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A name to associate to with the endpoint.
|
||||
backend (str, required): The backend that will serve requests to
|
||||
this endpoint. To change this or split traffic among backends,
|
||||
use `serve.set_traffic`.
|
||||
route (str, optional): A string begin with "/". HTTP server will
|
||||
use the string to match the path.
|
||||
methods(List[str], optional): The HTTP methods that are valid for
|
||||
this endpoint.
|
||||
"""
|
||||
if backend is None:
|
||||
raise TypeError("backend must be specified when creating "
|
||||
"an endpoint.")
|
||||
elif not isinstance(backend, str):
|
||||
raise TypeError("backend must be a string, got {}.".format(
|
||||
type(backend)))
|
||||
|
||||
if route is not None:
|
||||
if not isinstance(route, str) or not route.startswith("/"):
|
||||
raise TypeError("route must be a string starting with '/'.")
|
||||
|
||||
if not isinstance(methods, list):
|
||||
raise TypeError(
|
||||
"methods must be a list of strings, but got type {}".format(
|
||||
type(methods)))
|
||||
|
||||
endpoints = self.list_endpoints()
|
||||
if endpoint_name in endpoints:
|
||||
methods_old = endpoints[endpoint_name]["methods"]
|
||||
route_old = endpoints[endpoint_name]["route"]
|
||||
if methods_old.sort() == methods.sort() and route_old == route:
|
||||
raise ValueError(
|
||||
"Route '{}' is already registered to endpoint '{}' "
|
||||
"with methods '{}'. To set the backend for this "
|
||||
"endpoint, please use serve.set_traffic().".format(
|
||||
route, endpoint_name, methods))
|
||||
|
||||
upper_methods = []
|
||||
for method in methods:
|
||||
if not isinstance(method, str):
|
||||
raise TypeError(
|
||||
"methods must be a list of strings, but contained "
|
||||
"an element of type {}".format(type(method)))
|
||||
upper_methods.append(method.upper())
|
||||
|
||||
ray.get(
|
||||
self._controller.create_endpoint.remote(
|
||||
endpoint_name, {backend: 1.0}, route, upper_methods))
|
||||
|
||||
@_ensure_connected
|
||||
def delete_endpoint(self, endpoint: str) -> None:
|
||||
"""Delete the given endpoint.
|
||||
|
||||
Does not delete any associated backends.
|
||||
"""
|
||||
ray.get(self._controller.delete_endpoint.remote(endpoint))
|
||||
|
||||
@_ensure_connected
|
||||
def list_endpoints(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of all registered endpoints.
|
||||
|
||||
The dictionary keys are endpoint names and values are dictionaries
|
||||
of the form: {"methods": List[str], "traffic": Dict[str, float]}.
|
||||
"""
|
||||
return ray.get(self._controller.get_all_endpoints.remote())
|
||||
|
||||
@_ensure_connected
|
||||
def update_backend_config(
|
||||
self, backend_tag: str,
|
||||
config_options: Union[BackendConfig, Dict[str, Any]]) -> None:
|
||||
"""Update a backend configuration for a backend tag.
|
||||
|
||||
Keys not specified in the passed will be left unchanged.
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
config_options(dict, serve.BackendConfig): Backend config options
|
||||
to update. Either a BackendConfig object or a dict mapping
|
||||
strings to values for the following supported options:
|
||||
- "num_replicas": number of worker processes to start up that
|
||||
will handle requests to this backend.
|
||||
- "max_batch_size": the maximum number of requests that will
|
||||
be processed in one batch by this backend.
|
||||
- "batch_wait_timeout": time in seconds that backend replicas
|
||||
will wait for a full batch of requests before
|
||||
processing a partial batch.
|
||||
- "max_concurrent_queries": the maximum number of queries
|
||||
that will be sent to a replica of this backend
|
||||
without receiving a response.
|
||||
"""
|
||||
|
||||
if not isinstance(config_options, (BackendConfig, dict)):
|
||||
raise TypeError(
|
||||
"config_options must be a BackendConfig or dictionary.")
|
||||
ray.get(
|
||||
self._controller.update_backend_config.remote(
|
||||
backend_tag, config_options))
|
||||
|
||||
@_ensure_connected
|
||||
def get_backend_config(self, backend_tag: str) -> BackendConfig:
|
||||
"""Get the backend configuration for a backend tag.
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
"""
|
||||
return ray.get(self._controller.get_backend_config.remote(backend_tag))
|
||||
|
||||
@_ensure_connected
|
||||
def create_backend(
|
||||
self,
|
||||
backend_tag: str,
|
||||
func_or_class: Union[Callable, Type[Callable]],
|
||||
*actor_init_args: Any,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None
|
||||
) -> None:
|
||||
"""Create a backend with the provided tag.
|
||||
|
||||
The backend will serve requests with func_or_class.
|
||||
|
||||
Args:
|
||||
backend_tag (str): a unique tag assign to identify this backend.
|
||||
func_or_class (callable, class): a function or a class implementing
|
||||
__call__.
|
||||
actor_init_args (optional): the arguments to pass to the class.
|
||||
initialization method.
|
||||
ray_actor_options (optional): options to be passed into the
|
||||
@ray.remote decorator for the backend actor.
|
||||
config (dict, serve.BackendConfig, optional): configuration options
|
||||
for this backend. Either a BackendConfig, or a dictionary
|
||||
mapping strings to values for the following supported options:
|
||||
- "num_replicas": number of worker processes to start up that
|
||||
will handle requests to this backend.
|
||||
- "max_batch_size": the maximum number of requests that will
|
||||
be processed in one batch by this backend.
|
||||
- "batch_wait_timeout": time in seconds that backend replicas
|
||||
will wait for a full batch of requests before processing a
|
||||
partial batch.
|
||||
- "max_concurrent_queries": the maximum number of queries that
|
||||
will be sent to a replica of this backend without receiving a
|
||||
response.
|
||||
"""
|
||||
if backend_tag in self.list_backends():
|
||||
raise ValueError(
|
||||
"Cannot create backend. "
|
||||
"Backend '{}' is already registered.".format(backend_tag))
|
||||
|
||||
if config is None:
|
||||
config = {}
|
||||
replica_config = ReplicaConfig(
|
||||
func_or_class,
|
||||
*actor_init_args,
|
||||
ray_actor_options=ray_actor_options)
|
||||
metadata = BackendMetadata(
|
||||
accepts_batches=replica_config.accepts_batches,
|
||||
is_blocking=replica_config.is_blocking)
|
||||
if isinstance(config, dict):
|
||||
backend_config = BackendConfig.parse_obj({
|
||||
**config, "internal_metadata": metadata
|
||||
})
|
||||
elif isinstance(config, BackendConfig):
|
||||
backend_config = config.copy(
|
||||
update={"internal_metadata": metadata})
|
||||
else:
|
||||
raise TypeError("config must be a BackendConfig or a dictionary.")
|
||||
backend_config._validate_complete()
|
||||
ray.get(
|
||||
self._controller.create_backend.remote(backend_tag, backend_config,
|
||||
replica_config))
|
||||
|
||||
@_ensure_connected
|
||||
def list_backends(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of all registered backends.
|
||||
|
||||
Dictionary maps backend tags to backend configs.
|
||||
"""
|
||||
return ray.get(self._controller.get_all_backends.remote())
|
||||
|
||||
@_ensure_connected
|
||||
def delete_backend(self, backend_tag: str) -> None:
|
||||
"""Delete the given backend.
|
||||
|
||||
The backend must not currently be used by any endpoints.
|
||||
"""
|
||||
ray.get(self._controller.delete_backend.remote(backend_tag))
|
||||
|
||||
@_ensure_connected
|
||||
def set_traffic(self, endpoint_name: str,
|
||||
traffic_policy_dictionary: Dict[str, float]) -> None:
|
||||
"""Associate a service endpoint with traffic policy.
|
||||
|
||||
Example:
|
||||
|
||||
>>> serve.set_traffic("service-name", {
|
||||
"backend:v1": 0.5,
|
||||
"backend:v2": 0.5
|
||||
})
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
traffic_policy_dictionary (dict): a dictionary maps backend names
|
||||
to their traffic weights. The weights must sum to 1.
|
||||
"""
|
||||
ray.get(
|
||||
self._controller.set_traffic.remote(endpoint_name,
|
||||
traffic_policy_dictionary))
|
||||
|
||||
@_ensure_connected
|
||||
def shadow_traffic(self, endpoint_name: str, backend_tag: str,
|
||||
proportion: float) -> None:
|
||||
"""Shadow traffic from an endpoint to a backend.
|
||||
|
||||
The specified proportion of requests will be duplicated and sent to the
|
||||
backend. Responses of the duplicated traffic will be ignored.
|
||||
The backend must not already be in use.
|
||||
|
||||
To stop shadowing traffic to a backend, call `shadow_traffic` with
|
||||
proportion equal to 0.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
backend_tag (str): A registered backend.
|
||||
proportion (float): The proportion of traffic from 0 to 1.
|
||||
"""
|
||||
|
||||
if not isinstance(proportion,
|
||||
(float, int)) or not 0 <= proportion <= 1:
|
||||
raise TypeError("proportion must be a float from 0 to 1.")
|
||||
|
||||
ray.get(
|
||||
self._controller.shadow_traffic.remote(endpoint_name, backend_tag,
|
||||
proportion))
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(self, endpoint_name: str) -> RayServeHandle:
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
|
||||
Returns:
|
||||
RayServeHandle
|
||||
"""
|
||||
if endpoint_name not in ray.get(
|
||||
self._controller.get_all_endpoints.remote()):
|
||||
raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")
|
||||
|
||||
# TODO(edoakes): we should choose the router on the same node.
|
||||
routers = ray.get(self._controller.get_routers.remote())
|
||||
return RayServeHandle(
|
||||
list(routers.values())[0],
|
||||
endpoint_name,
|
||||
)
|
||||
|
||||
|
||||
def init(name: Optional[str] = None,
|
||||
http_host: str = DEFAULT_HTTP_HOST,
|
||||
http_port: int = DEFAULT_HTTP_PORT,
|
||||
http_middlewares: List[Any] = []) -> None:
|
||||
"""Initialize or connect to a serve cluster.
|
||||
def start(detached: bool = False,
|
||||
http_host: str = DEFAULT_HTTP_HOST,
|
||||
http_port: int = DEFAULT_HTTP_PORT,
|
||||
http_middlewares: List[Any] = []) -> Client:
|
||||
"""Initialize a serve instance.
|
||||
|
||||
If serve cluster is already initialized, this function will just return.
|
||||
|
||||
If `ray.init` has not been called in this process, it will be called with
|
||||
no arguments. To specify kwargs to `ray.init`, it should be called
|
||||
separately before calling `serve.init`.
|
||||
By default, the instance will be scoped to the lifetime of the returned
|
||||
Client object (or when the script exits). If detached is set to True, the
|
||||
instance will instead persist until client.shutdown() is called and clients
|
||||
to it can be connected using serve.connect(). This is only relevant if
|
||||
connecting to a long-running Ray cluster (e.g., with address="auto").
|
||||
|
||||
Args:
|
||||
name (str): A unique name for this serve instance. This allows
|
||||
multiple serve instances to run on the same ray cluster. Must be
|
||||
specified in all subsequent serve.init() calls.
|
||||
http_host (str): Host for HTTP servers. Default to "0.0.0.0". Serve
|
||||
starts one HTTP server per node in the Ray cluster.
|
||||
http_port (int, List[int]): Port for HTTP server. Default to 8000.
|
||||
detached (bool): Whether not the instance should be detached from this
|
||||
script.
|
||||
http_host (str): Host for HTTP servers to listen on. Defaults to
|
||||
"127.0.0.1". To expose Serve publicly, you probably want to set
|
||||
this to "0.0.0.0". One HTTP server will be started on each node in
|
||||
the Ray cluster.
|
||||
http_port (int): Port for HTTP server. Defaults to 8000.
|
||||
http_middleswares (list): A list of Starlette middlewares that will be
|
||||
applied to the HTTP servers in the cluster.
|
||||
"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TypeError("name must be a string.")
|
||||
|
||||
# Initialize ray if needed.
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
|
||||
# Try to get serve controller if it exists
|
||||
global controller
|
||||
controller_name = format_actor_name(SERVE_CONTROLLER_NAME, name)
|
||||
try:
|
||||
controller = ray.get_actor(controller_name)
|
||||
return
|
||||
except ValueError:
|
||||
pass
|
||||
if detached:
|
||||
controller_name = SERVE_CONTROLLER_NAME
|
||||
try:
|
||||
ray.get_actor(controller_name)
|
||||
raise RayServeException("Called serve.start(detached=True) but a "
|
||||
"detached instance is already running. "
|
||||
"Please use serve.connect() to connect to "
|
||||
"the running instance instead.")
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
controller_name = format_actor_name(SERVE_CONTROLLER_NAME,
|
||||
get_random_letters())
|
||||
|
||||
controller = ServeController.options(
|
||||
name=controller_name,
|
||||
lifetime="detached",
|
||||
lifetime="detached" if detached else None,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
).remote(name, http_host, http_port, http_middlewares)
|
||||
).remote(
|
||||
controller_name,
|
||||
http_host,
|
||||
http_port,
|
||||
http_middlewares,
|
||||
detached=detached)
|
||||
|
||||
futures = []
|
||||
for node_id in ray.state.node_ids():
|
||||
@@ -110,281 +395,63 @@ def init(name: Optional[str] = None,
|
||||
futures.append(future)
|
||||
ray.get(futures)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def shutdown() -> None:
|
||||
"""Completely shut down the connected Serve instance.
|
||||
|
||||
Shuts down all processes and deletes all state associated with the Serve
|
||||
instance that's currently connected to (via serve.init).
|
||||
"""
|
||||
global controller
|
||||
ray.get(controller.shutdown.remote())
|
||||
ray.kill(controller, no_restart=True)
|
||||
controller = None
|
||||
return Client(controller, controller_name, detached=detached)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def create_endpoint(endpoint_name: str,
|
||||
*,
|
||||
backend: str = None,
|
||||
route: Optional[str] = None,
|
||||
methods: List[str] = ["GET"]) -> None:
|
||||
"""Create a service endpoint given route_expression.
|
||||
def connect() -> Client:
|
||||
"""Connect to an existing Serve instance on this Ray cluster.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A name to associate to with the endpoint.
|
||||
backend (str, required): The backend that will serve requests to
|
||||
this endpoint. To change this or split traffic among backends, use
|
||||
`serve.set_traffic`.
|
||||
route (str, optional): A string begin with "/". HTTP server will use
|
||||
the string to match the path.
|
||||
methods(List[str], optional): The HTTP methods that are valid for this
|
||||
endpoint.
|
||||
"""
|
||||
if backend is None:
|
||||
raise TypeError("backend must be specified when creating "
|
||||
"an endpoint.")
|
||||
elif not isinstance(backend, str):
|
||||
raise TypeError("backend must be a string, got {}.".format(
|
||||
type(backend)))
|
||||
If calling from the driver program, the Serve instance on this Ray cluster
|
||||
must first have been initialized using `serve.start(detached=True)`.
|
||||
|
||||
if route is not None:
|
||||
if not isinstance(route, str) or not route.startswith("/"):
|
||||
raise TypeError("route must be a string starting with '/'.")
|
||||
|
||||
if not isinstance(methods, list):
|
||||
raise TypeError(
|
||||
"methods must be a list of strings, but got type {}".format(
|
||||
type(methods)))
|
||||
|
||||
endpoints = list_endpoints()
|
||||
if endpoint_name in endpoints:
|
||||
methods_old = endpoints[endpoint_name]["methods"]
|
||||
route_old = endpoints[endpoint_name]["route"]
|
||||
if methods_old.sort() == methods.sort() and route_old == route:
|
||||
raise ValueError(
|
||||
"Route '{}' is already registered to endpoint '{}' "
|
||||
"with methods '{}'. To set the backend for this "
|
||||
"endpoint, please use serve.set_traffic().".format(
|
||||
route, endpoint_name, methods))
|
||||
|
||||
upper_methods = []
|
||||
for method in methods:
|
||||
if not isinstance(method, str):
|
||||
raise TypeError("methods must be a list of strings, but contained "
|
||||
"an element of type {}".format(type(method)))
|
||||
upper_methods.append(method.upper())
|
||||
|
||||
ray.get(
|
||||
controller.create_endpoint.remote(endpoint_name, {backend: 1.0}, route,
|
||||
upper_methods))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def delete_endpoint(endpoint: str) -> None:
|
||||
"""Delete the given endpoint.
|
||||
|
||||
Does not delete any associated backends.
|
||||
"""
|
||||
ray.get(controller.delete_endpoint.remote(endpoint))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def list_endpoints() -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of all registered endpoints.
|
||||
|
||||
The dictionary keys are endpoint names and values are dictionaries
|
||||
of the form: {"methods": List[str], "traffic": Dict[str, float]}.
|
||||
"""
|
||||
return ray.get(controller.get_all_endpoints.remote())
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def update_backend_config(
|
||||
backend_tag: str,
|
||||
config_options: Union[BackendConfig, Dict[str, Any]]) -> None:
|
||||
"""Update a backend configuration for a backend tag.
|
||||
|
||||
Keys not specified in the passed will be left unchanged.
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
config_options(dict, serve.BackendConfig): Backend config options to
|
||||
update. Either a BackendConfig object or a dict mapping strings to
|
||||
values for the following supported options:
|
||||
- "num_replicas": number of worker processes to start up that
|
||||
will handle requests to this backend.
|
||||
- "max_batch_size": the maximum number of requests that will
|
||||
be processed in one batch by this backend.
|
||||
- "batch_wait_timeout": time in seconds that backend replicas
|
||||
will wait for a full batch of requests before
|
||||
processing a partial batch.
|
||||
- "max_concurrent_queries": the maximum number of queries
|
||||
that will be sent to a replica of this backend
|
||||
without receiving a response.
|
||||
If called from within a backend, will connect to the same Serve instance
|
||||
that the backend is running in.
|
||||
"""
|
||||
|
||||
if not isinstance(config_options, (BackendConfig, dict)):
|
||||
raise TypeError(
|
||||
"config_options must be a BackendConfig or dictionary.")
|
||||
ray.get(
|
||||
controller.update_backend_config.remote(backend_tag, config_options))
|
||||
# Initialize ray if needed.
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_backend_config(backend_tag: str) -> BackendConfig:
|
||||
"""Get the backend configuration for a backend tag.
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
"""
|
||||
return ray.get(controller.get_backend_config.remote(backend_tag))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def create_backend(
|
||||
backend_tag: str,
|
||||
func_or_class: Union[Callable, Type[Callable]],
|
||||
*actor_init_args: Any,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None) -> None:
|
||||
"""Create a backend with the provided tag.
|
||||
|
||||
The backend will serve requests with func_or_class.
|
||||
|
||||
Args:
|
||||
backend_tag (str): a unique tag assign to identify this backend.
|
||||
func_or_class (callable, class): a function or a class implementing
|
||||
__call__.
|
||||
actor_init_args (optional): the arguments to pass to the class.
|
||||
initialization method.
|
||||
ray_actor_options (optional): options to be passed into the
|
||||
@ray.remote decorator for the backend actor.
|
||||
config (dict, serve.BackendConfig, optional): configuration options
|
||||
for this backend. Either a BackendConfig, or a dictionary mapping
|
||||
strings to values for the following supported options:
|
||||
- "num_replicas": number of worker processes to start up that will
|
||||
handle requests to this backend.
|
||||
- "max_batch_size": the maximum number of requests that will
|
||||
be processed in one batch by this backend.
|
||||
- "batch_wait_timeout": time in seconds that backend replicas
|
||||
will wait for a full batch of requests before processing a
|
||||
partial batch.
|
||||
- "max_concurrent_queries": the maximum number of queries that will
|
||||
be sent to a replica of this backend without receiving a
|
||||
response.
|
||||
"""
|
||||
if backend_tag in list_backends():
|
||||
raise ValueError(
|
||||
"Cannot create backend. "
|
||||
"Backend '{}' is already registered.".format(backend_tag))
|
||||
|
||||
if config is None:
|
||||
config = {}
|
||||
replica_config = ReplicaConfig(
|
||||
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
|
||||
metadata = BackendMetadata(
|
||||
accepts_batches=replica_config.accepts_batches,
|
||||
is_blocking=replica_config.is_blocking)
|
||||
if isinstance(config, dict):
|
||||
backend_config = BackendConfig.parse_obj({
|
||||
**config, "internal_metadata": metadata
|
||||
})
|
||||
elif isinstance(config, BackendConfig):
|
||||
backend_config = config.copy(update={"internal_metadata": metadata})
|
||||
# When running inside of a backend, _INTERNAL_CONTROLLER_NAME is set to
|
||||
# ensure that the correct instance is connected to.
|
||||
if _INTERNAL_CONTROLLER_NAME is None:
|
||||
controller_name = SERVE_CONTROLLER_NAME
|
||||
else:
|
||||
raise TypeError("config must be a BackendConfig or a dictionary.")
|
||||
backend_config._validate_complete()
|
||||
ray.get(
|
||||
controller.create_backend.remote(backend_tag, backend_config,
|
||||
replica_config))
|
||||
controller_name = _INTERNAL_CONTROLLER_NAME
|
||||
|
||||
# Try to get serve controller if it exists
|
||||
try:
|
||||
controller = ray.get_actor(controller_name)
|
||||
except ValueError:
|
||||
raise RayServeException("Called `serve.connect()` but there is no "
|
||||
"instance running on this Ray cluster. Please "
|
||||
"call `serve.start(detached=True) to start "
|
||||
"one.")
|
||||
|
||||
return Client(controller, controller_name, detached=True)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def list_backends() -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of all registered backends.
|
||||
def accept_batch(f: Callable) -> Callable:
|
||||
"""Annotation to mark that a serving function accepts batches of requests.
|
||||
|
||||
Dictionary maps backend tags to backend configs.
|
||||
"""
|
||||
return ray.get(controller.get_all_backends.remote())
|
||||
In order to accept batches of requests as input, the implementation must
|
||||
handle a list of requests being passed in rather than just a single
|
||||
request.
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def delete_backend(backend_tag: str) -> None:
|
||||
"""Delete the given backend.
|
||||
|
||||
The backend must not currently be used by any endpoints.
|
||||
"""
|
||||
ray.get(controller.delete_backend.remote(backend_tag))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def set_traffic(endpoint_name: str,
|
||||
traffic_policy_dictionary: Dict[str, float]) -> None:
|
||||
"""Associate a service endpoint with traffic policy.
|
||||
This must be set on any backend implementation that will have
|
||||
max_batch_size set to greater than 1.
|
||||
|
||||
Example:
|
||||
|
||||
>>> serve.set_traffic("service-name", {
|
||||
"backend:v1": 0.5,
|
||||
"backend:v2": 0.5
|
||||
})
|
||||
>>> @serve.accept_batch
|
||||
def serving_func(requests):
|
||||
assert isinstance(requests, list)
|
||||
...
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
traffic_policy_dictionary (dict): a dictionary maps backend names
|
||||
to their traffic weights. The weights must sum to 1.
|
||||
>>> class ServingActor:
|
||||
@serve.accept_batch
|
||||
def __call__(self, requests):
|
||||
assert isinstance(requests, list)
|
||||
"""
|
||||
ray.get(
|
||||
controller.set_traffic.remote(endpoint_name,
|
||||
traffic_policy_dictionary))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def shadow_traffic(endpoint_name: str, backend_tag: str,
|
||||
proportion: float) -> None:
|
||||
"""Shadow traffic from an endpoint to a backend.
|
||||
|
||||
The specified proportion of requests will be duplicated and sent to the
|
||||
backend. Responses of the duplicated traffic will be ignored.
|
||||
The backend must not already be in use.
|
||||
|
||||
To stop shadowing traffic to a backend, call `shadow_traffic` with
|
||||
proportion equal to 0.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
backend_tag (str): A registered backend.
|
||||
proportion (float): The proportion of traffic from 0 to 1.
|
||||
"""
|
||||
|
||||
if not isinstance(proportion, (float, int)) or not 0 <= proportion <= 1:
|
||||
raise TypeError("proportion must be a float from 0 to 1.")
|
||||
|
||||
ray.get(
|
||||
controller.shadow_traffic.remote(endpoint_name, backend_tag,
|
||||
proportion))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(endpoint_name: str, missing_ok: bool = False) -> RayServeHandle:
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
missing_ok (bool): If true, skip the check for the endpoint existence.
|
||||
It can be useful when the endpoint has not been registered.
|
||||
|
||||
Returns:
|
||||
RayServeHandle
|
||||
"""
|
||||
if not missing_ok:
|
||||
assert endpoint_name in ray.get(controller.get_all_endpoints.remote())
|
||||
|
||||
# TODO(edoakes): we should choose the router on the same node.
|
||||
routers = ray.get(controller.get_routers.remote())
|
||||
return RayServeHandle(
|
||||
list(routers.values())[0],
|
||||
endpoint_name,
|
||||
)
|
||||
f._serve_accept_batch = True
|
||||
return f
|
||||
|
||||
@@ -10,7 +10,6 @@ import time
|
||||
import ray
|
||||
from ray.async_compat import sync_to_async
|
||||
|
||||
from ray import serve
|
||||
from ray.serve import context as serve_context
|
||||
from ray.serve.context import FakeFlaskRequest
|
||||
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
|
||||
@@ -102,14 +101,11 @@ def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]):
|
||||
|
||||
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
|
||||
class RayServeWrappedWorker(object):
|
||||
def __init__(self,
|
||||
backend_tag,
|
||||
replica_tag,
|
||||
init_args,
|
||||
backend_config: BackendConfig,
|
||||
instance_name=None):
|
||||
serve.init(name=instance_name)
|
||||
|
||||
def __init__(self, backend_tag, replica_tag, init_args,
|
||||
backend_config: BackendConfig, controller_name: str):
|
||||
# Set the controller name so that serve.connect() will connect to
|
||||
# the instance that this backend is running in.
|
||||
ray.serve.api._set_internal_controller_name(controller_name)
|
||||
if is_function:
|
||||
_callable = func_or_class
|
||||
else:
|
||||
|
||||
@@ -78,10 +78,10 @@ async def trial(actors, session, data_size):
|
||||
|
||||
async def main():
|
||||
ray.init(log_to_driver=False)
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
serve.create_backend("backend", backend)
|
||||
serve.create_endpoint("endpoint", backend="backend", route="/api")
|
||||
client.create_backend("backend", backend)
|
||||
client.create_endpoint("endpoint", backend="backend", route="/api")
|
||||
|
||||
actors = [Client.remote() for _ in range(NUM_CLIENTS)]
|
||||
for num_replicas in [1, 8]:
|
||||
@@ -100,7 +100,7 @@ async def main():
|
||||
},
|
||||
]:
|
||||
backend_config["num_replicas"] = num_replicas
|
||||
serve.update_backend_config("backend", backend_config)
|
||||
client.update_backend_config("backend", backend_config)
|
||||
print(repr(backend_config) + ":")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# TODO(edoakes): large data causes broken pipe errors.
|
||||
|
||||
@@ -42,7 +42,7 @@ def run_http_benchmark(url, num_queries):
|
||||
@click.option("--max-concurrent-queries", type=int, required=False)
|
||||
def main(num_replicas: int, num_queries: Optional[int],
|
||||
max_concurrent_queries: Optional[int], blocking: bool):
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
def noop(_):
|
||||
return "hello world"
|
||||
@@ -52,8 +52,8 @@ def main(num_replicas: int, num_queries: Optional[int],
|
||||
"max_concurrent_queries": max_concurrent_queries
|
||||
}
|
||||
print("Using config", config)
|
||||
serve.create_backend("noop", noop, config=config)
|
||||
serve.create_endpoint("noop", backend="noop", route="/noop")
|
||||
client.create_backend("noop", noop, config=config)
|
||||
client.create_endpoint("noop", backend="noop", route="/noop")
|
||||
|
||||
url = "{}/noop".format(DEFAULT_HTTP_ADDRESS)
|
||||
block_until_ready(url)
|
||||
|
||||
@@ -16,10 +16,10 @@ DEFAULT_HTTP_PORT = 8000
|
||||
#: Max concurrency
|
||||
ASYNC_CONCURRENCY = int(1e6)
|
||||
|
||||
#: Time to wait for HTTP proxy in `serve.init()`
|
||||
#: Max time to wait for HTTP proxy in `serve.start()`.
|
||||
HTTP_PROXY_TIMEOUT = 60
|
||||
|
||||
#: Default histogram buckets for latency tracker
|
||||
#: Default histogram buckets for latency tracker.
|
||||
DEFAULT_LATENCY_BUCKET_MS = [
|
||||
1,
|
||||
2,
|
||||
|
||||
@@ -102,13 +102,17 @@ class ServeController:
|
||||
requires all implementations here to be idempotent.
|
||||
"""
|
||||
|
||||
async def __init__(self, instance_name: str, http_host: str,
|
||||
http_port: str, http_middlewares: List[Any]) -> None:
|
||||
# Unique name of the serve instance managed by this actor. Used to
|
||||
# namespace child actors and checkpoints.
|
||||
self.instance_name = instance_name
|
||||
async def __init__(self,
|
||||
controller_name: str,
|
||||
http_host: str,
|
||||
http_port: str,
|
||||
http_middlewares: List[Any],
|
||||
detached: bool = False):
|
||||
self.detached = detached
|
||||
# Name of this controller actor.
|
||||
self.controller_name = controller_name
|
||||
# Used to read/write checkpoints.
|
||||
self.kv_store = RayInternalKVStore(namespace=instance_name)
|
||||
self.kv_store = RayInternalKVStore(namespace=controller_name)
|
||||
# path -> (endpoint, methods).
|
||||
self.routes = dict()
|
||||
# backend -> BackendInfo.
|
||||
@@ -180,7 +184,7 @@ class ServeController:
|
||||
continue
|
||||
|
||||
router_name = format_actor_name(SERVE_PROXY_NAME,
|
||||
self.instance_name, node_id)
|
||||
self.controller_name, node_id)
|
||||
try:
|
||||
router = ray.get_actor(router_name)
|
||||
except ValueError:
|
||||
@@ -190,7 +194,7 @@ class ServeController:
|
||||
self.http_port))
|
||||
router = HTTPProxyActor.options(
|
||||
name=router_name,
|
||||
lifetime="detached",
|
||||
lifetime="detached" if self.detached else None,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
@@ -201,7 +205,7 @@ class ServeController:
|
||||
node_id,
|
||||
self.http_host,
|
||||
self.http_port,
|
||||
instance_name=self.instance_name,
|
||||
controller_name=self.controller_name,
|
||||
http_middlewares=self.http_middlewares)
|
||||
|
||||
self.routers[node_id] = router
|
||||
@@ -287,7 +291,7 @@ class ServeController:
|
||||
|
||||
for node_id in router_node_ids:
|
||||
router_name = format_actor_name(SERVE_PROXY_NAME,
|
||||
self.instance_name, node_id)
|
||||
self.controller_name, node_id)
|
||||
self.routers[node_id] = ray.get_actor(router_name)
|
||||
|
||||
# Fetch actor handles for all of the backend replicas in the system.
|
||||
@@ -297,7 +301,7 @@ class ServeController:
|
||||
for backend_tag, replica_tags in self.replicas.items():
|
||||
for replica_tag in replica_tags:
|
||||
replica_name = format_actor_name(replica_tag,
|
||||
self.instance_name)
|
||||
self.controller_name)
|
||||
self.workers[backend_tag][replica_tag] = ray.get_actor(
|
||||
replica_name)
|
||||
|
||||
@@ -389,8 +393,8 @@ class ServeController:
|
||||
"""Fetched by serve handles."""
|
||||
return self.traffic_policies[endpoint]
|
||||
|
||||
async def _start_backend_worker(self, backend_tag: str,
|
||||
replica_tag: str) -> ActorHandle:
|
||||
async def _start_backend_worker(self, backend_tag: str, replica_tag: str,
|
||||
replica_name: str) -> ActorHandle:
|
||||
"""Creates a backend worker and waits for it to start up.
|
||||
|
||||
Assumes that the backend configuration has already been registered
|
||||
@@ -400,18 +404,15 @@ class ServeController:
|
||||
replica_tag, backend_tag))
|
||||
backend_info = self.backends[backend_tag]
|
||||
|
||||
replica_name = format_actor_name(replica_tag, self.instance_name)
|
||||
worker_handle = ray.remote(backend_info.worker_class).options(
|
||||
name=replica_name,
|
||||
lifetime="detached",
|
||||
lifetime="detached" if self.detached else None,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
**backend_info.replica_config.ray_actor_options).remote(
|
||||
backend_tag,
|
||||
replica_tag,
|
||||
backend_tag, replica_tag,
|
||||
backend_info.replica_config.actor_init_args,
|
||||
backend_info.backend_config,
|
||||
instance_name=self.instance_name)
|
||||
backend_info.backend_config, self.controller_name)
|
||||
# TODO(edoakes): we should probably have a timeout here.
|
||||
await worker_handle.ready.remote()
|
||||
return worker_handle
|
||||
@@ -420,11 +421,12 @@ class ServeController:
|
||||
# NOTE(edoakes): the replicas may already be created if we
|
||||
# failed after creating them but before writing a
|
||||
# checkpoint.
|
||||
replica_name = format_actor_name(replica_tag, self.controller_name)
|
||||
try:
|
||||
worker_handle = ray.get_actor(replica_tag)
|
||||
worker_handle = ray.get_actor(replica_name)
|
||||
except ValueError:
|
||||
worker_handle = await self._start_backend_worker(
|
||||
backend_tag, replica_tag)
|
||||
backend_tag, replica_tag, replica_name)
|
||||
|
||||
self.replicas[backend_tag].append(replica_tag)
|
||||
self.workers[backend_tag][replica_tag] = worker_handle
|
||||
@@ -466,8 +468,10 @@ class ServeController:
|
||||
for replica_tag in replicas_to_stop:
|
||||
# NOTE(edoakes): the replicas may already be stopped if we
|
||||
# failed after stopping them but before writing a checkpoint.
|
||||
replica_name = format_actor_name(replica_tag,
|
||||
self.controller_name)
|
||||
try:
|
||||
replica = ray.get_actor(replica_tag)
|
||||
replica = ray.get_actor(replica_name)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
@@ -556,7 +560,7 @@ class ServeController:
|
||||
self.replicas_to_start[backend_tag].append(replica_tag)
|
||||
|
||||
elif delta_num_replicas < 0:
|
||||
logger.debug("Removing {} replicas from backend {}".format(
|
||||
logger.debug("Removing {} replicas from backend '{}'".format(
|
||||
-delta_num_replicas, backend_tag))
|
||||
assert len(self.replicas[backend_tag]) >= delta_num_replicas
|
||||
for _ in range(-delta_num_replicas):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from ray import serve
|
||||
import requests
|
||||
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
|
||||
class Counter:
|
||||
@@ -12,8 +12,8 @@ class Counter:
|
||||
return {"current_counter": self.count}
|
||||
|
||||
|
||||
serve.create_backend("counter", Counter)
|
||||
serve.create_endpoint("counter", backend="counter", route="/counter")
|
||||
client.create_backend("counter", Counter)
|
||||
client.create_endpoint("counter", backend="counter", route="/counter")
|
||||
|
||||
requests.get("http://127.0.0.1:8000/counter").json()
|
||||
# > {"current_counter": self.count}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from ray import serve
|
||||
import requests
|
||||
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
|
||||
def echo(flask_request):
|
||||
return "hello " + flask_request.args.get("name", "serve!")
|
||||
|
||||
|
||||
serve.create_backend("hello", echo)
|
||||
serve.create_endpoint("hello", backend="hello", route="/hello")
|
||||
client.create_backend("hello", echo)
|
||||
client.create_endpoint("hello", backend="hello", route="/hello")
|
||||
|
||||
requests.get("http://127.0.0.1:8000/hello").text
|
||||
# > "hello serve!"
|
||||
|
||||
@@ -4,7 +4,7 @@ import ray
|
||||
from ray import serve
|
||||
|
||||
ray.init(num_cpus=10)
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
# Our pipeline will be structured as follows:
|
||||
# - Input comes in, the composed model sends it to model_one
|
||||
@@ -27,8 +27,9 @@ def model_two(_unused_flask_request, data=None):
|
||||
|
||||
class ComposedModel:
|
||||
def __init__(self):
|
||||
self.model_one = serve.get_handle("model_one")
|
||||
self.model_two = serve.get_handle("model_two")
|
||||
client = serve.connect()
|
||||
self.model_one = client.get_handle("model_one")
|
||||
self.model_two = client.get_handle("model_two")
|
||||
|
||||
# This method can be called concurrently!
|
||||
async def __call__(self, flask_request):
|
||||
@@ -44,17 +45,17 @@ class ComposedModel:
|
||||
return result
|
||||
|
||||
|
||||
serve.create_backend("model_one", model_one)
|
||||
serve.create_endpoint("model_one", backend="model_one")
|
||||
client.create_backend("model_one", model_one)
|
||||
client.create_endpoint("model_one", backend="model_one")
|
||||
|
||||
serve.create_backend("model_two", model_two)
|
||||
serve.create_endpoint("model_two", backend="model_two")
|
||||
client.create_backend("model_two", model_two)
|
||||
client.create_endpoint("model_two", backend="model_two")
|
||||
|
||||
# max_concurrent_queries is optional. By default, if you pass in an async
|
||||
# function, Ray Serve sets the limit to a high number.
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"composed_backend", ComposedModel, config={"max_concurrent_queries": 10})
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"composed", backend="composed_backend", route="/composed")
|
||||
|
||||
for _ in range(5):
|
||||
|
||||
@@ -30,9 +30,9 @@ def batch_adder_v0(flask_requests: List):
|
||||
|
||||
# __doc_deploy_begin__
|
||||
ray.init(num_cpus=10)
|
||||
serve.init()
|
||||
serve.create_backend("adder:v0", batch_adder_v0, config={"max_batch_size": 4})
|
||||
serve.create_endpoint(
|
||||
client = serve.start()
|
||||
client.create_backend("adder:v0", batch_adder_v0, config={"max_batch_size": 4})
|
||||
client.create_endpoint(
|
||||
"adder", backend="adder:v0", route="/adder", methods=["GET"])
|
||||
# __doc_deploy_end__
|
||||
|
||||
@@ -80,12 +80,12 @@ def batch_adder_v1(flask_requests: List, *, numbers: List = []):
|
||||
# __doc_define_servable_v1_end__
|
||||
|
||||
# __doc_deploy_v1_begin__
|
||||
serve.create_backend("adder:v1", batch_adder_v1, config={"max_batch_size": 4})
|
||||
serve.set_traffic("adder", {"adder:v1": 1})
|
||||
client.create_backend("adder:v1", batch_adder_v1, config={"max_batch_size": 4})
|
||||
client.set_traffic("adder", {"adder:v1": 1})
|
||||
# __doc_deploy_v1_end__
|
||||
|
||||
# __doc_query_handle_begin__
|
||||
handle = serve.get_handle("adder")
|
||||
handle = client.get_handle("adder")
|
||||
print(handle)
|
||||
# Output
|
||||
# RayServeHandle(
|
||||
|
||||
@@ -69,9 +69,9 @@ ray.init(address="auto")
|
||||
# now we initialize /connect to the Ray service
|
||||
|
||||
# listen on 0.0.0.0 to make the HTTP server accessible from other machines.
|
||||
serve.init(http_host="0.0.0.0")
|
||||
serve.create_backend("lr:v1", BoostingModel)
|
||||
serve.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
|
||||
client = serve.start(http_host="0.0.0.0")
|
||||
client.create_backend("lr:v1", BoostingModel)
|
||||
client.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
|
||||
# __doc_create_deploy_end__
|
||||
|
||||
# __doc_query_begin__
|
||||
@@ -163,7 +163,7 @@ class BoostingModelv2:
|
||||
# now we initialize /connect to the Ray service
|
||||
|
||||
|
||||
serve.init()
|
||||
serve.create_backend("lr:v2", BoostingModelv2)
|
||||
serve.set_traffic("iris_classifier", {"lr:v2": 0.25, "lr:v1": 0.75})
|
||||
client = serve.connect()
|
||||
client.create_backend("lr:v2", BoostingModelv2)
|
||||
client.set_traffic("iris_classifier", {"lr:v2": 0.25, "lr:v1": 0.75})
|
||||
# __doc_create_deploy_2_end__
|
||||
|
||||
@@ -46,9 +46,9 @@ class ImageModel:
|
||||
# __doc_define_servable_end__
|
||||
|
||||
# __doc_deploy_begin__
|
||||
serve.init()
|
||||
serve.create_backend("resnet18:v0", ImageModel)
|
||||
serve.create_endpoint(
|
||||
client = serve.start()
|
||||
client.create_backend("resnet18:v0", ImageModel)
|
||||
client.create_endpoint(
|
||||
"predictor",
|
||||
backend="resnet18:v0",
|
||||
route="/image_predict",
|
||||
|
||||
@@ -65,9 +65,9 @@ class BoostingModel:
|
||||
# __doc_define_servable_end__
|
||||
|
||||
# __doc_deploy_begin__
|
||||
serve.init()
|
||||
serve.create_backend("lr:v1", BoostingModel)
|
||||
serve.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
|
||||
client = serve.start()
|
||||
client.create_backend("lr:v1", BoostingModel)
|
||||
client.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
|
||||
# __doc_deploy_end__
|
||||
|
||||
# __doc_query_begin__
|
||||
|
||||
@@ -68,9 +68,9 @@ class TFMnistModel:
|
||||
# __doc_define_servable_end__
|
||||
|
||||
# __doc_deploy_begin__
|
||||
serve.init()
|
||||
serve.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
|
||||
serve.create_endpoint("tf_classifier", backend="tf:v1", route="/mnist")
|
||||
client = serve.start()
|
||||
client.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
|
||||
client.create_endpoint("tf_classifier", backend="tf:v1", route="/mnist")
|
||||
# __doc_deploy_end__
|
||||
|
||||
# __doc_query_begin__
|
||||
|
||||
@@ -13,10 +13,9 @@ def echo(flask_request):
|
||||
return ["hello " + flask_request.args.get("name", "serve!")]
|
||||
|
||||
|
||||
serve.init()
|
||||
|
||||
serve.create_backend("echo:v1", echo)
|
||||
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
client = serve.start()
|
||||
client.create_backend("echo:v1", echo)
|
||||
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
|
||||
while True:
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
|
||||
@@ -36,9 +36,9 @@ class MagicCounter:
|
||||
return base_number + self.increment
|
||||
|
||||
|
||||
serve.init()
|
||||
serve.create_backend("counter:v1", MagicCounter, 42) # increment=42
|
||||
serve.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
|
||||
client = serve.start()
|
||||
client.create_backend("counter:v1", MagicCounter, 42) # increment=42
|
||||
client.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
|
||||
|
||||
print("Sending ten queries via HTTP")
|
||||
for i in range(10):
|
||||
@@ -50,7 +50,7 @@ for i in range(10):
|
||||
time.sleep(0.2)
|
||||
|
||||
print("Sending ten queries via Python")
|
||||
handle = serve.get_handle("magic_counter")
|
||||
handle = client.get_handle("magic_counter")
|
||||
for i in range(10):
|
||||
print("> Pinging handle.remote(base_number={})".format(i))
|
||||
result = ray.get(handle.remote(base_number=i))
|
||||
|
||||
@@ -47,11 +47,11 @@ class MagicCounter:
|
||||
return result
|
||||
|
||||
|
||||
serve.init()
|
||||
serve.create_backend(
|
||||
client = serve.start()
|
||||
client.create_backend(
|
||||
"counter:v1", MagicCounter, 42,
|
||||
config={"max_batch_size": 5}) # increment=42
|
||||
serve.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
|
||||
client.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
|
||||
|
||||
print("Sending ten queries via HTTP")
|
||||
for i in range(10):
|
||||
@@ -63,7 +63,7 @@ for i in range(10):
|
||||
time.sleep(0.2)
|
||||
|
||||
print("Sending ten queries via Python")
|
||||
handle = serve.get_handle("magic_counter")
|
||||
handle = client.get_handle("magic_counter")
|
||||
for i in range(10):
|
||||
print("> Pinging handle.remote(base_number={})".format(i))
|
||||
result = ray.get(handle.remote(base_number=i))
|
||||
|
||||
@@ -26,16 +26,16 @@ class MagicCounter:
|
||||
return ""
|
||||
|
||||
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
# specify max_batch_size in BackendConfig
|
||||
backend_config = {"max_batch_size": 5}
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"counter:v1", MagicCounter, 42, config=backend_config) # increment=42
|
||||
print("Backend Config for backend: 'counter:v1'")
|
||||
print(backend_config)
|
||||
serve.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
|
||||
client.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
|
||||
|
||||
handle = serve.get_handle("magic_counter")
|
||||
handle = client.get_handle("magic_counter")
|
||||
future_list = []
|
||||
|
||||
# fire 30 requests
|
||||
|
||||
@@ -38,9 +38,9 @@ def echo(_):
|
||||
raise Exception("Something went wrong...")
|
||||
|
||||
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
|
||||
for _ in range(2):
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
@@ -49,6 +49,6 @@ for _ in range(2):
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
|
||||
handle = serve.get_handle("my_endpoint")
|
||||
handle = client.get_handle("my_endpoint")
|
||||
print("Invoke from python will raise exception with traceback:")
|
||||
ray.get(handle.remote())
|
||||
|
||||
@@ -7,7 +7,7 @@ import ray.serve as serve
|
||||
|
||||
# initialize ray serve system.
|
||||
ray.init(num_cpus=10)
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
|
||||
# a backend can be a function or class.
|
||||
@@ -18,16 +18,16 @@ def echo_v1(flask_request, response="hello from python!"):
|
||||
return response
|
||||
|
||||
|
||||
serve.create_backend("echo:v1", echo_v1)
|
||||
client.create_backend("echo:v1", echo_v1)
|
||||
|
||||
# An endpoint is associated with an HTTP path and traffic to the endpoint
|
||||
# will be serviced by the echo:v1 backend.
|
||||
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
|
||||
print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).text)
|
||||
# The service will be reachable from http
|
||||
|
||||
print(ray.get(serve.get_handle("my_endpoint").remote(response="hello")))
|
||||
print(ray.get(client.get_handle("my_endpoint").remote(response="hello")))
|
||||
|
||||
# as well as within the ray system.
|
||||
|
||||
@@ -38,10 +38,10 @@ def echo_v2(flask_request):
|
||||
return "something new"
|
||||
|
||||
|
||||
serve.create_backend("echo:v2", echo_v2)
|
||||
client.create_backend("echo:v2", echo_v2)
|
||||
|
||||
# The two backend will now split the traffic 50%-50%.
|
||||
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
client.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
|
||||
# Observe requests are now split between two backends.
|
||||
for _ in range(10):
|
||||
@@ -49,5 +49,5 @@ for _ in range(10):
|
||||
time.sleep(0.5)
|
||||
|
||||
# You can also change number of replicas for each backend independently.
|
||||
serve.update_backend_config("echo:v1", {"num_replicas": 2})
|
||||
serve.update_backend_config("echo:v2", {"num_replicas": 2})
|
||||
client.update_backend_config("echo:v1", {"num_replicas": 2})
|
||||
client.update_backend_config("echo:v2", {"num_replicas": 2})
|
||||
|
||||
@@ -5,42 +5,42 @@ import ray
|
||||
import ray.serve as serve
|
||||
import time
|
||||
|
||||
# initialize ray serve system.
|
||||
serve.init()
|
||||
# Initialize ray serve instance.
|
||||
client = serve.start()
|
||||
|
||||
|
||||
# a backend can be a function or class.
|
||||
# it can be made to be invoked from web as well as python.
|
||||
# A backend can be a function or class.
|
||||
# It can be made to be invoked via HTTP as well as python.
|
||||
def echo_v1(_, response="hello from python!"):
|
||||
return f"echo_v1({response})"
|
||||
|
||||
|
||||
serve.create_backend("echo_v1", echo_v1)
|
||||
serve.create_endpoint("echo_v1", backend="echo_v1", route="/echo_v1")
|
||||
client.create_backend("echo_v1", echo_v1)
|
||||
client.create_endpoint("echo_v1", backend="echo_v1", route="/echo_v1")
|
||||
|
||||
|
||||
def echo_v2(_, relay=""):
|
||||
return f"echo_v2({relay})"
|
||||
|
||||
|
||||
serve.create_backend("echo_v2", echo_v2)
|
||||
serve.create_endpoint("echo_v2", backend="echo_v2", route="/echo_v2")
|
||||
client.create_backend("echo_v2", echo_v2)
|
||||
client.create_endpoint("echo_v2", backend="echo_v2", route="/echo_v2")
|
||||
|
||||
|
||||
def echo_v3(_, relay=""):
|
||||
return f"echo_v3({relay})"
|
||||
|
||||
|
||||
serve.create_backend("echo_v3", echo_v3)
|
||||
serve.create_endpoint("echo_v3", backend="echo_v3", route="/echo_v3")
|
||||
client.create_backend("echo_v3", echo_v3)
|
||||
client.create_endpoint("echo_v3", backend="echo_v3", route="/echo_v3")
|
||||
|
||||
|
||||
def echo_v4(_, relay1="", relay2=""):
|
||||
return f"echo_v4({relay1} , {relay2})"
|
||||
|
||||
|
||||
serve.create_backend("echo_v4", echo_v4)
|
||||
serve.create_endpoint("echo_v4", backend="echo_v4", route="/echo_v4")
|
||||
client.create_backend("echo_v4", echo_v4)
|
||||
client.create_endpoint("echo_v4", backend="echo_v4", route="/echo_v4")
|
||||
"""
|
||||
The pipeline created is as follows -
|
||||
"my_endpoint1"
|
||||
@@ -62,10 +62,10 @@ The pipeline created is as follows -
|
||||
"""
|
||||
|
||||
# get the handle of the endpoints
|
||||
handle1 = serve.get_handle("echo_v1")
|
||||
handle2 = serve.get_handle("echo_v2")
|
||||
handle3 = serve.get_handle("echo_v3")
|
||||
handle4 = serve.get_handle("echo_v4")
|
||||
handle1 = client.get_handle("echo_v1")
|
||||
handle2 = client.get_handle("echo_v2")
|
||||
handle3 = client.get_handle("echo_v3")
|
||||
handle4 = client.get_handle("echo_v4")
|
||||
|
||||
start = time.time()
|
||||
print("Start firing to the pipeline: {} s".format(time.time()))
|
||||
|
||||
@@ -30,10 +30,10 @@ def echo_v2(_):
|
||||
return "v2"
|
||||
|
||||
|
||||
serve.init()
|
||||
client = serve.start()
|
||||
|
||||
serve.create_backend("echo:v1", echo_v1)
|
||||
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
client.create_backend("echo:v1", echo_v1)
|
||||
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
|
||||
|
||||
for _ in range(3):
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
@@ -42,8 +42,8 @@ for _ in range(3):
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
|
||||
serve.create_backend("echo:v2", echo_v2)
|
||||
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
client.create_backend("echo:v2", echo_v2)
|
||||
client.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
while True:
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
print(pformat_color_json(resp))
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
|
||||
@@ -42,6 +40,14 @@ class RayServeHandle:
|
||||
self.shard_key = shard_key
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
"""Invoke a request on the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose result can be waited for or retrieved
|
||||
using `ray.wait` or `ray.get`, respectively.
|
||||
|
||||
Returns:
|
||||
ray.ObjectRef
|
||||
"""
|
||||
if len(args) > 0:
|
||||
raise ValueError(
|
||||
"handle.remote must be invoked with keyword arguments.")
|
||||
@@ -59,6 +65,14 @@ class RayServeHandle:
|
||||
method_name: Optional[str] = None,
|
||||
http_method: Optional[str] = None,
|
||||
shard_key: Optional[str] = None):
|
||||
"""Set options for this handle.
|
||||
|
||||
Args:
|
||||
method_name(str): The method to invoke on the backend.
|
||||
http_method(str): The HTTP method to use for the request.
|
||||
shard_key(str): A string to use to deterministically map this
|
||||
request to a backend if there are multiple for this endpoint.
|
||||
"""
|
||||
return RayServeHandle(
|
||||
self.router_handle,
|
||||
self.endpoint_name,
|
||||
@@ -68,11 +82,5 @@ class RayServeHandle:
|
||||
shard_key=self.shard_key or shard_key,
|
||||
)
|
||||
|
||||
def _get_traffic_policy(self):
|
||||
controller = serve.api._get_controller()
|
||||
return ray.get(
|
||||
controller.get_traffic_policy.remote(self.endpoint_name))
|
||||
|
||||
def __repr__(self):
|
||||
return (f"RayServeHandle(Endpoint='{self.endpoint_name}', "
|
||||
f"Traffic={self._get_traffic_policy()})")
|
||||
return f"RayServeHandle(endpoint='{self.endpoint_name}')"
|
||||
|
||||
@@ -6,7 +6,6 @@ import uvicorn
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray import serve
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.experimental import metrics
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
@@ -27,9 +26,9 @@ class HTTPProxy:
|
||||
# blocks forever
|
||||
"""
|
||||
|
||||
async def fetch_config_from_controller(self, name, instance_name=None):
|
||||
async def fetch_config_from_controller(self, name, controller_name):
|
||||
assert ray.is_initialized()
|
||||
controller = serve.api._get_controller()
|
||||
controller = ray.get_actor(controller_name)
|
||||
|
||||
self.route_table = await controller.get_router_config.remote()
|
||||
|
||||
@@ -38,7 +37,7 @@ class HTTPProxy:
|
||||
"requests", ["route"])
|
||||
|
||||
self.router = Router()
|
||||
await self.router.setup(name, instance_name)
|
||||
await self.router.setup(name, controller_name)
|
||||
|
||||
def set_route_table(self, route_table):
|
||||
self.route_table = route_table
|
||||
@@ -133,15 +132,14 @@ class HTTPProxyActor:
|
||||
name,
|
||||
host,
|
||||
port,
|
||||
instance_name=None,
|
||||
controller_name,
|
||||
http_middlewares: List["starlette.middleware.Middleware"] = []):
|
||||
serve.init(name=instance_name)
|
||||
self.app = HTTPProxy()
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.app = HTTPProxy()
|
||||
await self.app.fetch_config_from_controller(name, instance_name)
|
||||
await self.app.fetch_config_from_controller(name, controller_name)
|
||||
|
||||
self.wrapped_app = self.app
|
||||
for middleware in http_middlewares:
|
||||
|
||||
@@ -9,7 +9,6 @@ from dataclasses import dataclass
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.experimental import metrics
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.endpoint_policy import RandomEndpointPolicy
|
||||
@@ -53,7 +52,7 @@ class Query:
|
||||
class Router:
|
||||
"""A router that routes request to available workers."""
|
||||
|
||||
async def setup(self, name, instance_name=None):
|
||||
async def setup(self, name, controller_name):
|
||||
# Note: Several queues are used in the router
|
||||
# - When a request come in, it's placed inside its corresponding
|
||||
# endpoint_queue.
|
||||
@@ -104,8 +103,7 @@ class Router:
|
||||
# the controller. We use a "pull-based" approach instead of pushing
|
||||
# them from the controller so that the router can transparently recover
|
||||
# from failure.
|
||||
serve.init(name=instance_name)
|
||||
self.controller = serve.api._get_controller()
|
||||
self.controller = ray.get_actor(controller_name)
|
||||
|
||||
traffic_policies = ray.get(
|
||||
self.controller.get_traffic_policies.remote())
|
||||
|
||||
@@ -15,19 +15,15 @@ def _shared_serve_instance():
|
||||
num_cpus=36,
|
||||
_metrics_export_port=9999,
|
||||
_system_config={"metrics_report_interval_ms": 1000})
|
||||
serve.init()
|
||||
yield
|
||||
yield serve.start(detached=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serve_instance(_shared_serve_instance):
|
||||
serve.init()
|
||||
yield
|
||||
# Re-init if necessary.
|
||||
serve.init()
|
||||
controller = serve.api._get_controller()
|
||||
yield _shared_serve_instance
|
||||
controller = _shared_serve_instance._controller
|
||||
# Clear all state between tests to avoid naming collisions.
|
||||
for endpoint in ray.get(controller.get_all_endpoints.remote()):
|
||||
serve.delete_endpoint(endpoint)
|
||||
_shared_serve_instance.delete_endpoint(endpoint)
|
||||
for backend in ray.get(controller.get_all_backends.remote()):
|
||||
serve.delete_backend(backend)
|
||||
_shared_serve_instance.delete_backend(backend)
|
||||
|
||||
+234
-161
@@ -8,7 +8,7 @@ import requests
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.serve import constants
|
||||
from ray.serve.constants import SERVE_PROXY_NAME
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
@@ -16,13 +16,13 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
|
||||
|
||||
def test_e2e(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def function(flask_request):
|
||||
return {"method": flask_request.method}
|
||||
|
||||
serve.create_backend("echo:v1", function)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("echo:v1", function)
|
||||
client.create_endpoint(
|
||||
"endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])
|
||||
|
||||
retry_count = 5
|
||||
@@ -49,12 +49,14 @@ def test_e2e(serve_instance):
|
||||
|
||||
|
||||
def test_call_method(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class CallMethod:
|
||||
def method(self, request):
|
||||
return "hello"
|
||||
|
||||
serve.create_backend("backend", CallMethod)
|
||||
serve.create_endpoint("endpoint", backend="backend", route="/api")
|
||||
client.create_backend("backend", CallMethod)
|
||||
client.create_endpoint("endpoint", backend="backend", route="/api")
|
||||
|
||||
# Test HTTP path.
|
||||
resp = requests.get(
|
||||
@@ -64,59 +66,69 @@ def test_call_method(serve_instance):
|
||||
assert resp.text == "hello"
|
||||
|
||||
# Test serve handle path.
|
||||
handle = serve.get_handle("endpoint")
|
||||
handle = client.get_handle("endpoint")
|
||||
assert ray.get(handle.options("method").remote()) == "hello"
|
||||
|
||||
|
||||
def test_no_route(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def func(_, i=1):
|
||||
return 1
|
||||
|
||||
serve.create_backend("backend:1", func)
|
||||
serve.create_endpoint("noroute-endpoint", backend="backend:1")
|
||||
service_handle = serve.get_handle("noroute-endpoint")
|
||||
client.create_backend("backend:1", func)
|
||||
client.create_endpoint("noroute-endpoint", backend="backend:1")
|
||||
service_handle = client.get_handle("noroute-endpoint")
|
||||
result = ray.get(service_handle.remote(i=1))
|
||||
assert result == 1
|
||||
|
||||
|
||||
def test_reject_duplicate_backend(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
def g():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f)
|
||||
client.create_backend("backend", f)
|
||||
with pytest.raises(ValueError):
|
||||
serve.create_backend("backend", g)
|
||||
client.create_backend("backend", g)
|
||||
|
||||
|
||||
def test_reject_duplicate_route(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f)
|
||||
client.create_backend("backend", f)
|
||||
|
||||
route = "/foo"
|
||||
serve.create_endpoint("bar", backend="backend", route=route)
|
||||
client.create_endpoint("bar", backend="backend", route=route)
|
||||
with pytest.raises(ValueError):
|
||||
serve.create_endpoint("foo", backend="backend", route=route)
|
||||
client.create_endpoint("foo", backend="backend", route=route)
|
||||
|
||||
|
||||
def test_reject_duplicate_endpoint(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f)
|
||||
client.create_backend("backend", f)
|
||||
|
||||
endpoint_name = "foo"
|
||||
serve.create_endpoint(endpoint_name, backend="backend", route="/ok")
|
||||
client.create_endpoint(endpoint_name, backend="backend", route="/ok")
|
||||
with pytest.raises(ValueError):
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
endpoint_name, backend="backend", route="/different")
|
||||
|
||||
|
||||
def test_reject_duplicate_endpoint_and_route(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class SimpleBackend(object):
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
@@ -124,26 +136,30 @@ def test_reject_duplicate_endpoint_and_route(serve_instance):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return {"message": self.message}
|
||||
|
||||
serve.create_backend("backend1", SimpleBackend, "First")
|
||||
serve.create_backend("backend2", SimpleBackend, "Second")
|
||||
client.create_backend("backend1", SimpleBackend, "First")
|
||||
client.create_backend("backend2", SimpleBackend, "Second")
|
||||
|
||||
serve.create_endpoint("test", backend="backend1", route="/test")
|
||||
client.create_endpoint("test", backend="backend1", route="/test")
|
||||
with pytest.raises(ValueError):
|
||||
serve.create_endpoint("test", backend="backend2", route="/test")
|
||||
client.create_endpoint("test", backend="backend2", route="/test")
|
||||
|
||||
|
||||
def test_set_traffic_missing_data(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
endpoint_name = "foobar"
|
||||
backend_name = "foo_backend"
|
||||
serve.create_backend(backend_name, lambda: 5)
|
||||
serve.create_endpoint(endpoint_name, backend=backend_name)
|
||||
client.create_backend(backend_name, lambda: 5)
|
||||
client.create_endpoint(endpoint_name, backend=backend_name)
|
||||
with pytest.raises(ValueError):
|
||||
serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
|
||||
client.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
|
||||
with pytest.raises(ValueError):
|
||||
serve.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
|
||||
client.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
|
||||
|
||||
|
||||
def test_scaling_replicas(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -152,9 +168,9 @@ def test_scaling_replicas(serve_instance):
|
||||
self.count += 1
|
||||
return self.count
|
||||
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"counter:v1", Counter, config=BackendConfig(num_replicas=2))
|
||||
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
client.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get(
|
||||
@@ -169,7 +185,7 @@ def test_scaling_replicas(serve_instance):
|
||||
# If the load is shared among two replicas. The max result cannot be 10.
|
||||
assert max(counter_result) < 10
|
||||
|
||||
serve.update_backend_config("counter:v1", {"num_replicas": 1})
|
||||
client.update_backend_config("counter:v1", {"num_replicas": 1})
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
@@ -181,6 +197,8 @@ def test_scaling_replicas(serve_instance):
|
||||
|
||||
|
||||
def test_scaling_replicas_legacy(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -189,8 +207,8 @@ def test_scaling_replicas_legacy(serve_instance):
|
||||
self.count += 1
|
||||
return self.count
|
||||
|
||||
serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
|
||||
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
client.create_backend("counter:v1", Counter, config={"num_replicas": 2})
|
||||
client.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get(
|
||||
@@ -205,7 +223,7 @@ def test_scaling_replicas_legacy(serve_instance):
|
||||
# If the load is shared among two replicas. The max result cannot be 10.
|
||||
assert max(counter_result) < 10
|
||||
|
||||
serve.update_backend_config("counter:v1", {"num_replicas": 1})
|
||||
client.update_backend_config("counter:v1", {"num_replicas": 1})
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
@@ -217,6 +235,8 @@ def test_scaling_replicas_legacy(serve_instance):
|
||||
|
||||
|
||||
def test_batching(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class BatchingExample:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -228,11 +248,11 @@ def test_batching(serve_instance):
|
||||
return [self.count] * batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"counter:v11",
|
||||
BatchingExample,
|
||||
config=BackendConfig(max_batch_size=5, batch_wait_timeout=1))
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"counter1", backend="counter:v11", route="/increment2")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
@@ -241,7 +261,7 @@ def test_batching(serve_instance):
|
||||
time.sleep(0.2)
|
||||
|
||||
future_list = []
|
||||
handle = serve.get_handle("counter1")
|
||||
handle = client.get_handle("counter1")
|
||||
for _ in range(20):
|
||||
f = handle.remote(temp=1)
|
||||
future_list.append(f)
|
||||
@@ -254,6 +274,8 @@ def test_batching(serve_instance):
|
||||
|
||||
|
||||
def test_batching_legacy(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class BatchingExample:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -265,14 +287,14 @@ def test_batching_legacy(serve_instance):
|
||||
return [self.count] * batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"counter:v11",
|
||||
BatchingExample,
|
||||
config={
|
||||
"max_batch_size": 5,
|
||||
"batch_wait_timeout": 1
|
||||
})
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"counter1", backend="counter:v11", route="/increment2")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
@@ -281,7 +303,7 @@ def test_batching_legacy(serve_instance):
|
||||
time.sleep(0.2)
|
||||
|
||||
future_list = []
|
||||
handle = serve.get_handle("counter1")
|
||||
handle = client.get_handle("counter1")
|
||||
for _ in range(20):
|
||||
f = handle.remote(temp=1)
|
||||
future_list.append(f)
|
||||
@@ -294,6 +316,8 @@ def test_batching_legacy(serve_instance):
|
||||
|
||||
|
||||
def test_batching_exception(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class NoListReturned:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -304,17 +328,19 @@ def test_batching_exception(serve_instance):
|
||||
return batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"exception:v1", NoListReturned, config=BackendConfig(max_batch_size=5))
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"exception-test", backend="exception:v1", route="/noListReturned")
|
||||
|
||||
handle = serve.get_handle("exception-test")
|
||||
handle = client.get_handle("exception-test")
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
assert ray.get(handle.remote(temp=1))
|
||||
|
||||
|
||||
def test_batching_exception_legacy(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class NoListReturned:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -325,17 +351,19 @@ def test_batching_exception_legacy(serve_instance):
|
||||
return batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"exception:v1", NoListReturned, config={"max_batch_size": 5})
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"exception-test", backend="exception:v1", route="/noListReturned")
|
||||
|
||||
handle = serve.get_handle("exception-test")
|
||||
handle = client.get_handle("exception-test")
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
assert ray.get(handle.remote(temp=1))
|
||||
|
||||
|
||||
def test_updating_config(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class BatchSimple:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -345,17 +373,17 @@ def test_updating_config(serve_instance):
|
||||
batch_size = serve.context.batch_size
|
||||
return [1] * batch_size
|
||||
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"bsimple:v1",
|
||||
BatchSimple,
|
||||
config=BackendConfig(max_batch_size=2, num_replicas=3))
|
||||
serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
|
||||
client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
|
||||
|
||||
controller = serve.api._get_controller()
|
||||
controller = client._controller
|
||||
old_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
|
||||
serve.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
|
||||
client.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
|
||||
new_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
new_all_tag_list = []
|
||||
@@ -370,6 +398,8 @@ def test_updating_config(serve_instance):
|
||||
|
||||
|
||||
def test_updating_config_legacy(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class BatchSimple:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
@@ -379,20 +409,20 @@ def test_updating_config_legacy(serve_instance):
|
||||
batch_size = serve.context.batch_size
|
||||
return [1] * batch_size
|
||||
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"bsimple:v1",
|
||||
BatchSimple,
|
||||
config={
|
||||
"max_batch_size": 2,
|
||||
"num_replicas": 3
|
||||
})
|
||||
serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
|
||||
client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
|
||||
|
||||
controller = serve.api._get_controller()
|
||||
controller = client._controller
|
||||
old_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
|
||||
serve.update_backend_config("bsimple:v1", {"max_batch_size": 5})
|
||||
client.update_backend_config("bsimple:v1", {"max_batch_size": 5})
|
||||
new_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
new_all_tag_list = []
|
||||
@@ -407,79 +437,85 @@ def test_updating_config_legacy(serve_instance):
|
||||
|
||||
|
||||
def test_delete_backend(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def function():
|
||||
return "hello"
|
||||
|
||||
serve.create_backend("delete:v1", function)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("delete:v1", function)
|
||||
client.create_endpoint(
|
||||
"delete_backend", backend="delete:v1", route="/delete-backend")
|
||||
|
||||
assert requests.get("http://127.0.0.1:8000/delete-backend").text == "hello"
|
||||
|
||||
# Check that we can't delete the backend while it's in use.
|
||||
with pytest.raises(ValueError):
|
||||
serve.delete_backend("delete:v1")
|
||||
client.delete_backend("delete:v1")
|
||||
|
||||
serve.create_backend("delete:v2", function)
|
||||
serve.set_traffic("delete_backend", {"delete:v1": 0.5, "delete:v2": 0.5})
|
||||
client.create_backend("delete:v2", function)
|
||||
client.set_traffic("delete_backend", {"delete:v1": 0.5, "delete:v2": 0.5})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
serve.delete_backend("delete:v1")
|
||||
client.delete_backend("delete:v1")
|
||||
|
||||
# Check that the backend can be deleted once it's no longer in use.
|
||||
serve.set_traffic("delete_backend", {"delete:v2": 1.0})
|
||||
serve.delete_backend("delete:v1")
|
||||
client.set_traffic("delete_backend", {"delete:v2": 1.0})
|
||||
client.delete_backend("delete:v1")
|
||||
|
||||
# Check that we can no longer use the previously deleted backend.
|
||||
with pytest.raises(ValueError):
|
||||
serve.set_traffic("delete_backend", {"delete:v1": 1.0})
|
||||
client.set_traffic("delete_backend", {"delete:v1": 1.0})
|
||||
|
||||
def function2():
|
||||
return "olleh"
|
||||
|
||||
# Check that we can now reuse the previously delete backend's tag.
|
||||
serve.create_backend("delete:v1", function2)
|
||||
serve.set_traffic("delete_backend", {"delete:v1": 1.0})
|
||||
client.create_backend("delete:v1", function2)
|
||||
client.set_traffic("delete_backend", {"delete:v1": 1.0})
|
||||
|
||||
assert requests.get("http://127.0.0.1:8000/delete-backend").text == "olleh"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("route", [None, "/delete-endpoint"])
|
||||
def test_delete_endpoint(serve_instance, route):
|
||||
client = serve_instance
|
||||
|
||||
def function():
|
||||
return "hello"
|
||||
|
||||
backend_name = "delete-endpoint:v1"
|
||||
serve.create_backend(backend_name, function)
|
||||
client.create_backend(backend_name, function)
|
||||
|
||||
endpoint_name = "delete_endpoint" + str(route)
|
||||
serve.create_endpoint(endpoint_name, backend=backend_name, route=route)
|
||||
serve.delete_endpoint(endpoint_name)
|
||||
client.create_endpoint(endpoint_name, backend=backend_name, route=route)
|
||||
client.delete_endpoint(endpoint_name)
|
||||
|
||||
# Check that we can reuse a deleted endpoint name and route.
|
||||
serve.create_endpoint(endpoint_name, backend=backend_name, route=route)
|
||||
client.create_endpoint(endpoint_name, backend=backend_name, route=route)
|
||||
|
||||
if route is not None:
|
||||
assert requests.get(
|
||||
"http://127.0.0.1:8000/delete-endpoint").text == "hello"
|
||||
else:
|
||||
handle = serve.get_handle(endpoint_name)
|
||||
handle = client.get_handle(endpoint_name)
|
||||
assert ray.get(handle.remote()) == "hello"
|
||||
|
||||
# Check that deleting the endpoint doesn't delete the backend.
|
||||
serve.delete_endpoint(endpoint_name)
|
||||
serve.create_endpoint(endpoint_name, backend=backend_name, route=route)
|
||||
client.delete_endpoint(endpoint_name)
|
||||
client.create_endpoint(endpoint_name, backend=backend_name, route=route)
|
||||
|
||||
if route is not None:
|
||||
assert requests.get(
|
||||
"http://127.0.0.1:8000/delete-endpoint").text == "hello"
|
||||
else:
|
||||
handle = serve.get_handle(endpoint_name)
|
||||
handle = client.get_handle(endpoint_name)
|
||||
assert ray.get(handle.remote()) == "hello"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("route", [None, "/shard"])
|
||||
def test_shard_key(serve_instance, route):
|
||||
client = serve_instance
|
||||
|
||||
# Create five backends that return different integers.
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
@@ -490,11 +526,11 @@ def test_shard_key(serve_instance, route):
|
||||
|
||||
backend_name = "backend-split-" + str(i)
|
||||
traffic_dict[backend_name] = 1.0 / num_backends
|
||||
serve.create_backend(backend_name, function)
|
||||
client.create_backend(backend_name, function)
|
||||
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"endpoint", backend=list(traffic_dict.keys())[0], route=route)
|
||||
serve.set_traffic("endpoint", traffic_dict)
|
||||
client.set_traffic("endpoint", traffic_dict)
|
||||
|
||||
def do_request(shard_key):
|
||||
if route is not None:
|
||||
@@ -502,7 +538,7 @@ def test_shard_key(serve_instance, route):
|
||||
headers = {"X-SERVE-SHARD-KEY": shard_key}
|
||||
result = requests.get(url, headers=headers).text
|
||||
else:
|
||||
handle = serve.get_handle("endpoint").options(shard_key=shard_key)
|
||||
handle = client.get_handle("endpoint").options(shard_key=shard_key)
|
||||
result = ray.get(handle.options(shard_key=shard_key).remote())
|
||||
return result
|
||||
|
||||
@@ -517,49 +553,47 @@ def test_shard_key(serve_instance, route):
|
||||
assert do_request(shard_key) == results[shard_key]
|
||||
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
serve.init(name=1)
|
||||
|
||||
def test_multiple_instances():
|
||||
route = "/api"
|
||||
backend = "backend"
|
||||
endpoint = "endpoint"
|
||||
|
||||
serve.init(name="cluster1", http_port=8001)
|
||||
client1 = serve.start(http_port=8001)
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
|
||||
serve.create_backend(backend, function)
|
||||
serve.create_endpoint(endpoint, backend=backend, route=route)
|
||||
client1.create_backend(backend, function)
|
||||
client1.create_endpoint(endpoint, backend=backend, route=route)
|
||||
|
||||
assert requests.get("http://127.0.0.1:8001" + route).text == "hello1"
|
||||
|
||||
# Create a second cluster on port 8002. Create an endpoint and backend with
|
||||
# the same names and check that they don't collide.
|
||||
serve.init(name="cluster2", http_port=8002)
|
||||
client2 = serve.start(http_port=8002)
|
||||
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
serve.create_backend(backend, function)
|
||||
serve.create_endpoint(endpoint, backend=backend, route=route)
|
||||
client2.create_backend(backend, function)
|
||||
client2.create_endpoint(endpoint, backend=backend, route=route)
|
||||
|
||||
assert requests.get("http://127.0.0.1:8001" + route).text == "hello1"
|
||||
assert requests.get("http://127.0.0.1:8002" + route).text == "hello2"
|
||||
|
||||
# Check that deleting the backend in the current cluster doesn't.
|
||||
serve.delete_endpoint(endpoint)
|
||||
serve.delete_backend(backend)
|
||||
client2.delete_endpoint(endpoint)
|
||||
client2.delete_backend(backend)
|
||||
assert requests.get("http://127.0.0.1:8001" + route).text == "hello1"
|
||||
|
||||
# Check that we can re-connect to the first cluster.
|
||||
serve.init(name="cluster1")
|
||||
serve.delete_endpoint(endpoint)
|
||||
serve.delete_backend(backend)
|
||||
# Check that the first client still works.
|
||||
client1.delete_endpoint(endpoint)
|
||||
client1.delete_backend(backend)
|
||||
|
||||
|
||||
def test_parallel_start(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
# Test the ability to start multiple replicas in parallel.
|
||||
# In the past, when Serve scale up a backend, it does so one by one and
|
||||
# wait for each replica to initialize. This test avoid this by preventing
|
||||
@@ -588,15 +622,17 @@ def test_parallel_start(serve_instance):
|
||||
def __call__(self, _):
|
||||
return "Ready"
|
||||
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"p:v0", LongStartingServable, config=BackendConfig(num_replicas=2))
|
||||
serve.create_endpoint("test-parallel", backend="p:v0")
|
||||
handle = serve.get_handle("test-parallel")
|
||||
client.create_endpoint("test-parallel", backend="p:v0")
|
||||
handle = client.get_handle("test-parallel")
|
||||
|
||||
ray.get(handle.remote(), timeout=10)
|
||||
|
||||
|
||||
def test_parallel_start_legacy(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
# Test the ability to start multiple replicas in parallel.
|
||||
# In the past, when Serve scale up a backend, it does so one by one and
|
||||
# wait for each replica to initialize. This test avoid this by preventing
|
||||
@@ -625,29 +661,29 @@ def test_parallel_start_legacy(serve_instance):
|
||||
def __call__(self, _):
|
||||
return "Ready"
|
||||
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"p:v0", LongStartingServable, config={"num_replicas": 2})
|
||||
serve.create_endpoint("test-parallel", backend="p:v0")
|
||||
handle = serve.get_handle("test-parallel")
|
||||
client.create_endpoint("test-parallel", backend="p:v0")
|
||||
handle = client.get_handle("test-parallel")
|
||||
|
||||
ray.get(handle.remote(), timeout=10)
|
||||
|
||||
|
||||
def test_list_endpoints(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f)
|
||||
serve.create_backend("backend2", f)
|
||||
serve.create_backend("backend3", f)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("backend", f)
|
||||
client.create_backend("backend2", f)
|
||||
client.create_backend("backend3", f)
|
||||
client.create_endpoint(
|
||||
"endpoint", backend="backend", route="/api", methods=["GET", "POST"])
|
||||
serve.create_endpoint("endpoint2", backend="backend2", methods=["POST"])
|
||||
serve.shadow_traffic("endpoint", "backend3", 0.5)
|
||||
client.create_endpoint("endpoint2", backend="backend2", methods=["POST"])
|
||||
client.shadow_traffic("endpoint", "backend3", 0.5)
|
||||
|
||||
endpoints = serve.list_endpoints()
|
||||
endpoints = client.list_endpoints()
|
||||
assert "endpoint" in endpoints
|
||||
assert endpoints["endpoint"] == {
|
||||
"route": "/api",
|
||||
@@ -670,92 +706,93 @@ def test_list_endpoints(serve_instance):
|
||||
"shadows": {}
|
||||
}
|
||||
|
||||
serve.delete_endpoint("endpoint")
|
||||
assert "endpoint2" in serve.list_endpoints()
|
||||
client.delete_endpoint("endpoint")
|
||||
assert "endpoint2" in client.list_endpoints()
|
||||
|
||||
serve.delete_endpoint("endpoint2")
|
||||
assert len(serve.list_endpoints()) == 0
|
||||
client.delete_endpoint("endpoint2")
|
||||
assert len(client.list_endpoints()) == 0
|
||||
|
||||
|
||||
def test_list_backends(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
@serve.accept_batch
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f, config=BackendConfig(max_batch_size=10))
|
||||
backends = serve.list_backends()
|
||||
client.create_backend(
|
||||
"backend", f, config=BackendConfig(max_batch_size=10))
|
||||
backends = client.list_backends()
|
||||
assert len(backends) == 1
|
||||
assert "backend" in backends
|
||||
assert backends["backend"]["max_batch_size"] == 10
|
||||
|
||||
serve.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
|
||||
backends = serve.list_backends()
|
||||
client.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
|
||||
backends = client.list_backends()
|
||||
assert len(backends) == 2
|
||||
assert backends["backend2"]["num_replicas"] == 10
|
||||
|
||||
serve.delete_backend("backend")
|
||||
backends = serve.list_backends()
|
||||
client.delete_backend("backend")
|
||||
backends = client.list_backends()
|
||||
assert len(backends) == 1
|
||||
assert "backend2" in backends
|
||||
|
||||
serve.delete_backend("backend2")
|
||||
assert len(serve.list_backends()) == 0
|
||||
client.delete_backend("backend2")
|
||||
assert len(client.list_backends()) == 0
|
||||
|
||||
|
||||
def test_list_backends_legacy(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
@serve.accept_batch
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f, config={"max_batch_size": 10})
|
||||
backends = serve.list_backends()
|
||||
client.create_backend("backend", f, config={"max_batch_size": 10})
|
||||
backends = client.list_backends()
|
||||
assert len(backends) == 1
|
||||
assert "backend" in backends
|
||||
assert backends["backend"]["max_batch_size"] == 10
|
||||
|
||||
serve.create_backend("backend2", f, config={"num_replicas": 10})
|
||||
backends = serve.list_backends()
|
||||
client.create_backend("backend2", f, config={"num_replicas": 10})
|
||||
backends = client.list_backends()
|
||||
assert len(backends) == 2
|
||||
assert backends["backend2"]["num_replicas"] == 10
|
||||
|
||||
serve.delete_backend("backend")
|
||||
backends = serve.list_backends()
|
||||
client.delete_backend("backend")
|
||||
backends = client.list_backends()
|
||||
assert len(backends) == 1
|
||||
assert "backend2" in backends
|
||||
|
||||
serve.delete_backend("backend2")
|
||||
assert len(serve.list_backends()) == 0
|
||||
client.delete_backend("backend2")
|
||||
assert len(client.list_backends()) == 0
|
||||
|
||||
|
||||
def test_endpoint_input_validation(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f)
|
||||
client.create_backend("backend", f)
|
||||
with pytest.raises(TypeError):
|
||||
serve.create_endpoint("endpoint")
|
||||
client.create_endpoint("endpoint")
|
||||
with pytest.raises(TypeError):
|
||||
serve.create_endpoint("endpoint", route="/hello")
|
||||
client.create_endpoint("endpoint", route="/hello")
|
||||
with pytest.raises(TypeError):
|
||||
serve.create_endpoint("endpoint", backend=2)
|
||||
serve.create_endpoint("endpoint", backend="backend")
|
||||
client.create_endpoint("endpoint", backend=2)
|
||||
client.create_endpoint("endpoint", backend="backend")
|
||||
|
||||
|
||||
def test_create_infeasible_error(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
# Non existent resource should be infeasible.
|
||||
with pytest.raises(RayServeException, match="Cannot scale backend"):
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"f:1",
|
||||
f,
|
||||
ray_actor_options={"resources": {
|
||||
@@ -765,7 +802,7 @@ def test_create_infeasible_error(serve_instance):
|
||||
# Even each replica might be feasible, the total might not be.
|
||||
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
|
||||
with pytest.raises(RayServeException, match="Cannot scale backend"):
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"f:1",
|
||||
f,
|
||||
ray_actor_options={"resources": {
|
||||
@@ -774,19 +811,19 @@ def test_create_infeasible_error(serve_instance):
|
||||
config=BackendConfig(num_replicas=(current_cpus + 20)))
|
||||
|
||||
# No replica should be created!
|
||||
replicas = ray.get(serve.api.controller._list_replicas.remote("f1"))
|
||||
replicas = ray.get(client._controller._list_replicas.remote("f1"))
|
||||
assert len(replicas) == 0
|
||||
|
||||
|
||||
def test_create_infeasible_error_legacy(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
# Non existent resource should be infeasible.
|
||||
with pytest.raises(RayServeException, match="Cannot scale backend"):
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"f:1",
|
||||
f,
|
||||
ray_actor_options={"resources": {
|
||||
@@ -796,7 +833,7 @@ def test_create_infeasible_error_legacy(serve_instance):
|
||||
# Even each replica might be feasible, the total might not be.
|
||||
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
|
||||
with pytest.raises(RayServeException, match="Cannot scale backend"):
|
||||
serve.create_backend(
|
||||
client.create_backend(
|
||||
"f:1",
|
||||
f,
|
||||
ray_actor_options={"resources": {
|
||||
@@ -805,29 +842,29 @@ def test_create_infeasible_error_legacy(serve_instance):
|
||||
config={"num_replicas": current_cpus + 20})
|
||||
|
||||
# No replica should be created!
|
||||
replicas = ray.get(serve.api.controller._list_replicas.remote("f1"))
|
||||
replicas = ray.get(client._controller._list_replicas.remote("f1"))
|
||||
assert len(replicas) == 0
|
||||
|
||||
|
||||
def test_shutdown(serve_instance):
|
||||
def test_shutdown():
|
||||
def f():
|
||||
pass
|
||||
|
||||
instance_name = "shutdown"
|
||||
serve.init(name=instance_name, http_port=8003)
|
||||
serve.create_backend("backend", f)
|
||||
serve.create_endpoint("endpoint", backend="backend")
|
||||
client = serve.start(http_port=8003)
|
||||
client.create_backend("backend", f)
|
||||
client.create_endpoint("endpoint", backend="backend")
|
||||
|
||||
serve.shutdown()
|
||||
with pytest.raises(RayServeException, match="Please run serve.init"):
|
||||
serve.list_backends()
|
||||
client.shutdown()
|
||||
with pytest.raises(RayServeException):
|
||||
client.list_backends()
|
||||
|
||||
def check_dead():
|
||||
for actor_name in [
|
||||
constants.SERVE_CONTROLLER_NAME, constants.SERVE_PROXY_NAME
|
||||
client._controller_name,
|
||||
format_actor_name(SERVE_PROXY_NAME, client._controller_name)
|
||||
]:
|
||||
try:
|
||||
ray.get_actor(format_actor_name(actor_name, instance_name))
|
||||
ray.get_actor(actor_name)
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -837,6 +874,8 @@ def test_shutdown(serve_instance):
|
||||
|
||||
|
||||
def test_shadow_traffic(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
@ray.remote
|
||||
class RequestCounter:
|
||||
def __init__(self):
|
||||
@@ -866,15 +905,15 @@ def test_shadow_traffic(serve_instance):
|
||||
ray.get(counter.record.remote("backend4"))
|
||||
return "oops"
|
||||
|
||||
serve.create_backend("backend1", f)
|
||||
serve.create_backend("backend2", f_shadow_1)
|
||||
serve.create_backend("backend3", f_shadow_2)
|
||||
serve.create_backend("backend4", f_shadow_3)
|
||||
client.create_backend("backend1", f)
|
||||
client.create_backend("backend2", f_shadow_1)
|
||||
client.create_backend("backend3", f_shadow_2)
|
||||
client.create_backend("backend4", f_shadow_3)
|
||||
|
||||
serve.create_endpoint("endpoint", backend="backend1", route="/api")
|
||||
serve.shadow_traffic("endpoint", "backend2", 1.0)
|
||||
serve.shadow_traffic("endpoint", "backend3", 0.5)
|
||||
serve.shadow_traffic("endpoint", "backend4", 0.1)
|
||||
client.create_endpoint("endpoint", backend="backend1", route="/api")
|
||||
client.shadow_traffic("endpoint", "backend2", 1.0)
|
||||
client.shadow_traffic("endpoint", "backend3", 0.5)
|
||||
client.shadow_traffic("endpoint", "backend4", 0.1)
|
||||
|
||||
start = time.time()
|
||||
num_requests = 100
|
||||
@@ -897,13 +936,47 @@ def test_shadow_traffic(serve_instance):
|
||||
wait_for_condition(check_requests)
|
||||
|
||||
|
||||
def test_connect(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
# Check that you can have multiple clients to the same detached instance.
|
||||
client2 = serve.connect()
|
||||
assert client._controller_name == client2._controller_name
|
||||
|
||||
# Check that you can have detached and non-detached instances.
|
||||
client3 = serve.start(http_port=8004)
|
||||
assert client3._controller_name != client._controller_name
|
||||
|
||||
# Check that you can call serve.connect() from within a backend for both
|
||||
# detached and non-detached instances.
|
||||
|
||||
def connect_in_backend():
|
||||
client = serve.connect()
|
||||
client.create_backend("backend-ception", connect_in_backend)
|
||||
return client._controller_name
|
||||
|
||||
client.create_backend("connect_in_backend", connect_in_backend)
|
||||
client.create_endpoint("endpoint", backend="connect_in_backend")
|
||||
handle = client.get_handle("endpoint")
|
||||
assert ray.get(handle.remote()) == client._controller_name
|
||||
assert "backend-ception" in client.list_backends()
|
||||
|
||||
client3.create_backend("connect_in_backend", connect_in_backend)
|
||||
client3.create_endpoint("endpoint", backend="connect_in_backend")
|
||||
handle = client3.get_handle("endpoint")
|
||||
assert ray.get(handle.remote()) == client3._controller_name
|
||||
assert "backend-ception" in client3.list_backends()
|
||||
|
||||
|
||||
def test_serve_metrics(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
@serve.accept_batch
|
||||
def batcher(flask_requests):
|
||||
return ["hello"] * len(flask_requests)
|
||||
|
||||
serve.create_backend("metrics", batcher)
|
||||
serve.create_endpoint("metrics", backend="metrics", route="/metrics")
|
||||
client.create_backend("metrics", batcher)
|
||||
client.create_endpoint("metrics", backend="metrics", route="/metrics")
|
||||
# send 10 concurrent requests
|
||||
url = "http://127.0.0.1:8000/metrics"
|
||||
ray.get([block_until_http_ready.remote(url) for _ in range(10)])
|
||||
|
||||
@@ -19,7 +19,8 @@ pytestmark = pytest.mark.asyncio
|
||||
def setup_worker(name,
|
||||
func_or_class,
|
||||
init_args=None,
|
||||
backend_config=BackendConfig()):
|
||||
backend_config=BackendConfig(),
|
||||
controller_name=""):
|
||||
if init_args is None:
|
||||
init_args = ()
|
||||
|
||||
@@ -27,7 +28,8 @@ def setup_worker(name,
|
||||
class WorkerActor:
|
||||
def __init__(self):
|
||||
self.worker = create_backend_worker(func_or_class)(
|
||||
name, name + ":tag", init_args, backend_config)
|
||||
name, name + ":tag", init_args, backend_config,
|
||||
controller_name)
|
||||
|
||||
def ready(self):
|
||||
pass
|
||||
@@ -50,7 +52,7 @@ async def test_runner_wraps_error():
|
||||
|
||||
async def test_runner_actor(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
return i
|
||||
@@ -72,7 +74,7 @@ async def test_runner_actor(serve_instance):
|
||||
|
||||
async def test_ray_serve_mixin(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
CONSUMER_NAME = "runner-cls"
|
||||
PRODUCER_NAME = "prod-cls"
|
||||
@@ -98,7 +100,7 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
|
||||
async def test_task_runner_check_context(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
# Accessing the flask_request without web context should throw.
|
||||
@@ -120,7 +122,7 @@ async def test_task_runner_check_context(serve_instance):
|
||||
|
||||
async def test_task_runner_custom_method_single(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
class NonBatcher:
|
||||
def a(self, _):
|
||||
@@ -155,7 +157,7 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
|
||||
async def test_task_runner_custom_method_batch(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
@serve.accept_batch
|
||||
class Batcher:
|
||||
@@ -220,7 +222,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
|
||||
async def test_task_runner_perform_batch(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
def batcher(*args, **kwargs):
|
||||
return [serve.context.batch_size] * serve.context.batch_size
|
||||
@@ -250,7 +252,7 @@ async def test_task_runner_perform_batch(serve_instance):
|
||||
|
||||
async def test_task_runner_perform_async(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
@ray.remote
|
||||
class Barrier:
|
||||
|
||||
@@ -21,13 +21,13 @@ def request_with_retries(endpoint, timeout=30):
|
||||
|
||||
|
||||
def test_controller_failure(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
|
||||
serve.create_backend("controller_failure:v1", function)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("controller_failure:v1", function)
|
||||
client.create_endpoint(
|
||||
"controller_failure",
|
||||
backend="controller_failure:v1",
|
||||
route="/controller_failure")
|
||||
@@ -39,7 +39,7 @@ def test_controller_failure(serve_instance):
|
||||
response = request_with_retries("/controller_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
ray.kill(serve.api._get_controller(), no_restart=False)
|
||||
ray.kill(client._controller, no_restart=False)
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/controller_failure", timeout=30)
|
||||
@@ -48,10 +48,10 @@ def test_controller_failure(serve_instance):
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
ray.kill(serve.api._get_controller(), no_restart=False)
|
||||
ray.kill(client._controller, no_restart=False)
|
||||
|
||||
serve.create_backend("controller_failure:v2", function)
|
||||
serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})
|
||||
client.create_backend("controller_failure:v2", function)
|
||||
client.set_traffic("controller_failure", {"controller_failure:v2": 1.0})
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/controller_failure", timeout=30)
|
||||
@@ -60,14 +60,14 @@ def test_controller_failure(serve_instance):
|
||||
def function():
|
||||
return "hello3"
|
||||
|
||||
ray.kill(serve.api._get_controller(), no_restart=False)
|
||||
serve.create_backend("controller_failure_2", function)
|
||||
ray.kill(serve.api._get_controller(), no_restart=False)
|
||||
serve.create_endpoint(
|
||||
ray.kill(client._controller, no_restart=False)
|
||||
client.create_backend("controller_failure_2", function)
|
||||
ray.kill(client._controller, no_restart=False)
|
||||
client.create_endpoint(
|
||||
"controller_failure_2",
|
||||
backend="controller_failure_2",
|
||||
route="/controller_failure_2")
|
||||
ray.kill(serve.api._get_controller(), no_restart=False)
|
||||
ray.kill(client._controller, no_restart=False)
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/controller_failure", timeout=30)
|
||||
@@ -76,20 +76,20 @@ def test_controller_failure(serve_instance):
|
||||
assert response.text == "hello3"
|
||||
|
||||
|
||||
def _kill_routers():
|
||||
routers = ray.get(serve.api._get_controller().get_routers.remote())
|
||||
def _kill_routers(client):
|
||||
routers = ray.get(client._controller.get_routers.remote())
|
||||
for router in routers.values():
|
||||
ray.kill(router, no_restart=False)
|
||||
|
||||
|
||||
def test_http_proxy_failure(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
|
||||
serve.create_backend("proxy_failure:v1", function)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("proxy_failure:v1", function)
|
||||
client.create_endpoint(
|
||||
"proxy_failure", backend="proxy_failure:v1", route="/proxy_failure")
|
||||
|
||||
assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1"
|
||||
@@ -98,21 +98,21 @@ def test_http_proxy_failure(serve_instance):
|
||||
response = request_with_retries("/proxy_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
_kill_routers()
|
||||
_kill_routers(client)
|
||||
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
serve.create_backend("proxy_failure:v2", function)
|
||||
serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})
|
||||
client.create_backend("proxy_failure:v2", function)
|
||||
client.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/proxy_failure", timeout=30)
|
||||
assert response.text == "hello2"
|
||||
|
||||
|
||||
def _get_worker_handles(backend):
|
||||
controller = serve.api._get_controller()
|
||||
def _get_worker_handles(client, backend):
|
||||
controller = client._controller
|
||||
backend_dict = ray.get(controller.get_all_worker_handles.remote())
|
||||
|
||||
return list(backend_dict[backend].values())
|
||||
@@ -121,21 +121,21 @@ def _get_worker_handles(backend):
|
||||
# Test that a worker dying unexpectedly causes it to restart and continue
|
||||
# serving requests.
|
||||
def test_worker_restart(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
class Worker1:
|
||||
def __call__(self):
|
||||
return os.getpid()
|
||||
|
||||
serve.create_backend("worker_failure:v1", Worker1)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("worker_failure:v1", Worker1)
|
||||
client.create_endpoint(
|
||||
"worker_failure", backend="worker_failure:v1", route="/worker_failure")
|
||||
|
||||
# Get the PID of the worker.
|
||||
old_pid = request_with_retries("/worker_failure", timeout=1).text
|
||||
|
||||
# Kill the worker.
|
||||
handles = _get_worker_handles("worker_failure:v1")
|
||||
handles = _get_worker_handles(client, "worker_failure:v1")
|
||||
assert len(handles) == 1
|
||||
ray.kill(handles[0], no_restart=False)
|
||||
|
||||
@@ -152,8 +152,7 @@ def test_worker_restart(serve_instance):
|
||||
# Test that if there are multiple replicas for a worker and one dies
|
||||
# unexpectedly, the others continue to serve requests.
|
||||
def test_worker_replica_failure(serve_instance):
|
||||
serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
class Worker:
|
||||
# Assumes that two replicas are started. Will hang forever in the
|
||||
@@ -182,10 +181,10 @@ def test_worker_replica_failure(serve_instance):
|
||||
|
||||
temp_path = os.path.join(tempfile.gettempdir(),
|
||||
serve.utils.get_random_letters())
|
||||
serve.create_backend("replica_failure", Worker, temp_path)
|
||||
serve.update_backend_config(
|
||||
client.create_backend("replica_failure", Worker, temp_path)
|
||||
client.update_backend_config(
|
||||
"replica_failure", BackendConfig(num_replicas=2))
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"replica_failure", backend="replica_failure", route="/replica_failure")
|
||||
|
||||
# Wait until both replicas have been started.
|
||||
@@ -195,7 +194,7 @@ def test_worker_replica_failure(serve_instance):
|
||||
time.sleep(0.1)
|
||||
|
||||
# Kill one of the replicas.
|
||||
handles = _get_worker_handles("replica_failure")
|
||||
handles = _get_worker_handles(client, "replica_failure")
|
||||
assert len(handles) == 2
|
||||
ray.kill(handles[0], no_restart=False)
|
||||
|
||||
@@ -212,12 +211,12 @@ def test_worker_replica_failure(serve_instance):
|
||||
|
||||
|
||||
def test_create_backend_idempotent(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
return "hello"
|
||||
|
||||
controller = serve.api._get_controller()
|
||||
controller = client._controller
|
||||
|
||||
replica_config = ReplicaConfig(f)
|
||||
backend_config = BackendConfig(num_replicas=1)
|
||||
@@ -228,21 +227,21 @@ def test_create_backend_idempotent(serve_instance):
|
||||
replica_config))
|
||||
|
||||
assert len(ray.get(controller.get_all_backends.remote())) == 1
|
||||
serve.create_endpoint(
|
||||
client.create_endpoint(
|
||||
"my_endpoint", backend="my_backend", route="/my_route")
|
||||
|
||||
assert requests.get("http://127.0.0.1:8000/my_route").text == "hello"
|
||||
|
||||
|
||||
def test_create_endpoint_idempotent(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
def f():
|
||||
return "hello"
|
||||
|
||||
serve.create_backend("my_backend", f)
|
||||
client.create_backend("my_backend", f)
|
||||
|
||||
controller = serve.api._get_controller()
|
||||
controller = client._controller
|
||||
|
||||
for i in range(10):
|
||||
ray.get(
|
||||
|
||||
@@ -5,7 +5,7 @@ import requests
|
||||
|
||||
|
||||
def test_handle_in_endpoint(serve_instance):
|
||||
serve.init()
|
||||
client = serve_instance
|
||||
|
||||
class Endpoint1:
|
||||
def __call__(self, flask_request):
|
||||
@@ -13,20 +13,21 @@ def test_handle_in_endpoint(serve_instance):
|
||||
|
||||
class Endpoint2:
|
||||
def __init__(self):
|
||||
self.handle = serve.get_handle("endpoint1", missing_ok=True)
|
||||
client = serve.connect()
|
||||
self.handle = client.get_handle("endpoint1")
|
||||
|
||||
def __call__(self):
|
||||
return ray.get(self.handle.remote())
|
||||
|
||||
serve.create_backend("endpoint1:v0", Endpoint1)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("endpoint1:v0", Endpoint1)
|
||||
client.create_endpoint(
|
||||
"endpoint1",
|
||||
backend="endpoint1:v0",
|
||||
route="/endpoint1",
|
||||
methods=["GET", "POST"])
|
||||
|
||||
serve.create_backend("endpoint2:v0", Endpoint2)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("endpoint2:v0", Endpoint2)
|
||||
client.create_endpoint(
|
||||
"endpoint2",
|
||||
backend="endpoint2:v0",
|
||||
route="/endpoint2",
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import requests
|
||||
import sys
|
||||
|
||||
from ray import serve
|
||||
|
||||
|
||||
def test_nonblocking():
|
||||
serve.init()
|
||||
|
||||
def function(flask_request):
|
||||
return {"method": flask_request.method}
|
||||
|
||||
serve.create_backend("nonblocking:v1", function)
|
||||
serve.create_endpoint(
|
||||
"nonblocking", backend="nonblocking:v1", route="/nonblocking")
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
|
||||
assert resp == "GET"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
@@ -1,25 +1,26 @@
|
||||
import ray
|
||||
import ray.test_utils
|
||||
from ray import serve
|
||||
|
||||
|
||||
def test_new_driver(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
script = """
|
||||
import ray
|
||||
ray.init(address="{}")
|
||||
|
||||
from ray import serve
|
||||
serve.init()
|
||||
client = serve.connect()
|
||||
|
||||
def driver(flask_request):
|
||||
return "OK!"
|
||||
|
||||
serve.create_backend("driver", driver)
|
||||
serve.create_endpoint("driver", backend="driver", route="/driver")
|
||||
client.create_backend("driver", driver)
|
||||
client.create_endpoint("driver", backend="driver", route="/driver")
|
||||
""".format(ray.worker._global_node._redis_address)
|
||||
ray.test_utils.run_string_as_driver(script)
|
||||
|
||||
handle = serve.get_handle("driver")
|
||||
handle = client.get_handle("driver")
|
||||
assert ray.get(handle.remote()) == "OK!"
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ from ray import serve
|
||||
|
||||
|
||||
def test_np_in_composed_model(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
# https://github.com/ray-project/ray/issues/9441
|
||||
# AttributeError: 'bytes' object has no attribute 'readonly'
|
||||
# in cloudpickle _from_numpy_buffer
|
||||
@@ -14,17 +16,18 @@ def test_np_in_composed_model(serve_instance):
|
||||
|
||||
class ComposedModel:
|
||||
def __init__(self):
|
||||
self.model = serve.get_handle("sum_model")
|
||||
client = serve.connect()
|
||||
self.model = client.get_handle("sum_model")
|
||||
|
||||
async def __call__(self, _request):
|
||||
data = np.ones((10, 10))
|
||||
result = await self.model.remote(data=data)
|
||||
return result
|
||||
|
||||
serve.create_backend("sum_model", sum_model)
|
||||
serve.create_endpoint("sum_model", backend="sum_model")
|
||||
serve.create_backend("model", ComposedModel)
|
||||
serve.create_endpoint(
|
||||
client.create_backend("sum_model", sum_model)
|
||||
client.create_endpoint("sum_model", backend="sum_model")
|
||||
client.create_backend("model", ComposedModel)
|
||||
client.create_endpoint(
|
||||
"model", backend="model", route="/model", methods=["GET"])
|
||||
|
||||
result = requests.get("http://127.0.0.1:8000/model")
|
||||
|
||||
@@ -49,7 +49,7 @@ def task_runner_mock_actor():
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
|
||||
q.add_new_worker.remote("backend-single-prod", "replica-1",
|
||||
@@ -67,7 +67,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
|
||||
await q.add_new_worker.remote("backend-alter", "replica-1",
|
||||
@@ -86,7 +86,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
await q.set_traffic.remote(
|
||||
"svc", TrafficPolicy({
|
||||
@@ -116,7 +116,7 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
temp_actor = mock_task_runner()
|
||||
q = ray.remote(TestRouter).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
|
||||
await q.remove_worker.remote("backend-remove", "replica-1")
|
||||
assert ray.get(q.worker_queue_size.remote("backend")) == 0
|
||||
@@ -124,7 +124,7 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
@@ -179,7 +179,7 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
|
||||
worker = MockWorker.remote()
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
await q.setup.remote("")
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
backend_name = "max-concurrent-test"
|
||||
config = BackendConfig(max_concurrent_queries=1)
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
|
||||
|
||||
@@ -12,7 +12,8 @@ import ray
|
||||
from ray import serve
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.serve.constants import SERVE_PROXY_NAME
|
||||
from ray.serve.utils import block_until_http_ready
|
||||
from ray.serve.utils import (block_until_http_ready, get_all_node_ids,
|
||||
format_actor_name)
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.services import new_port
|
||||
|
||||
@@ -29,16 +30,24 @@ def test_multiple_routers():
|
||||
ray.init(head_node.address)
|
||||
node_ids = ray.state.node_ids()
|
||||
assert len(node_ids) == 2
|
||||
serve.init(http_port=8005)
|
||||
client = serve.start(http_port=8005) # noqa: F841
|
||||
|
||||
def actor_name(index):
|
||||
return SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], index)
|
||||
def get_proxy_names():
|
||||
proxy_names = []
|
||||
for node_id, _ in get_all_node_ids():
|
||||
proxy_names.append(
|
||||
format_actor_name(SERVE_PROXY_NAME, client._controller_name,
|
||||
node_id))
|
||||
return proxy_names
|
||||
|
||||
wait_for_condition(lambda: len(get_proxy_names()) == 2)
|
||||
proxy_names = get_proxy_names()
|
||||
|
||||
# Two actors should be started.
|
||||
def get_first_two_actors():
|
||||
try:
|
||||
ray.get_actor(actor_name(0))
|
||||
ray.get_actor(actor_name(1))
|
||||
ray.get_actor(proxy_names[0])
|
||||
ray.get_actor(proxy_names[1])
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
@@ -49,18 +58,22 @@ def test_multiple_routers():
|
||||
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
|
||||
|
||||
# Kill one of the servers, the HTTP server should still function.
|
||||
ray.kill(ray.get_actor(actor_name(0)), no_restart=True)
|
||||
ray.kill(ray.get_actor(get_proxy_names()[0]), no_restart=True)
|
||||
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
|
||||
|
||||
# Add a new node to the cluster. This should trigger a new router to get
|
||||
# started.
|
||||
new_node = cluster.add_node()
|
||||
|
||||
wait_for_condition(lambda: len(get_proxy_names()) == 3)
|
||||
third_proxy = get_proxy_names()[2]
|
||||
|
||||
def get_third_actor():
|
||||
try:
|
||||
ray.get_actor(actor_name(2))
|
||||
ray.get_actor(third_proxy)
|
||||
return True
|
||||
except ValueError:
|
||||
# IndexErrors covers when cluster resources aren't updated yet.
|
||||
except (IndexError, ValueError):
|
||||
return False
|
||||
|
||||
wait_for_condition(get_third_actor)
|
||||
@@ -71,7 +84,7 @@ def test_multiple_routers():
|
||||
|
||||
def third_actor_removed():
|
||||
try:
|
||||
ray.get_actor(actor_name(2))
|
||||
ray.get_actor(third_proxy)
|
||||
return False
|
||||
except ValueError:
|
||||
return True
|
||||
@@ -90,7 +103,7 @@ def test_middleware():
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
port = new_port()
|
||||
serve.init(
|
||||
serve.start(
|
||||
http_port=port,
|
||||
http_middlewares=[
|
||||
Middleware(
|
||||
@@ -112,6 +125,8 @@ def test_middleware():
|
||||
resp = requests.get(f"{root}/-/routes", headers=headers)
|
||||
assert resp.headers["access-control-allow-origin"] == "*"
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
@@ -105,11 +105,11 @@ def get_random_letters(length=6):
|
||||
return "".join(random.choices(string.ascii_letters, k=length))
|
||||
|
||||
|
||||
def format_actor_name(actor_name, instance_name=None, *modifiers):
|
||||
if instance_name is None:
|
||||
def format_actor_name(actor_name, controller_name=None, *modifiers):
|
||||
if controller_name is None:
|
||||
name = actor_name
|
||||
else:
|
||||
name = "{}:{}".format(instance_name, actor_name)
|
||||
name = "{}:{}".format(controller_name, actor_name)
|
||||
|
||||
for modifier in modifiers:
|
||||
name += "-{}".format(modifier)
|
||||
|
||||
Reference in New Issue
Block a user