[serve] Serve client refactor (#10409)

This commit is contained in:
Edward Oakes
2020-09-04 12:02:23 -05:00
committed by GitHub
parent 2e49e22f21
commit 786f12edfd
45 changed files with 1017 additions and 929 deletions
+2 -12
View File
@@ -6,8 +6,7 @@ py_library(
)
serve_tests_srcs = glob(["tests/*.py"],
exclude=["tests/test_nonblocking.py",
"tests/test_controller_crashes.py",
exclude=["tests/test_controller_crashes.py",
"tests/test_serve.py",
])
@@ -115,8 +114,7 @@ py_test(
# srcs = glob(["tests/test_controller_crashes.py",
# "tests/test_api.py",
# "tests/test_failure.py"],
# exclude=["tests/test_nonblocking.py",
# "tests/test_serve.py"]),
# exclude=["tests/test_serve.py"]),
# )
py_test(
@@ -127,14 +125,6 @@ py_test(
deps = [":serve_lib"]
)
py_test(
name = "test_nonblocking",
size = "small",
srcs = glob(["tests/test_nonblocking.py"]),
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Make sure the example showing in doc is tested
py_test(
name = "quickstart_class",
+7 -9
View File
@@ -1,12 +1,10 @@
from ray.serve.api import (init, create_backend, delete_backend,
create_endpoint, delete_endpoint, set_traffic,
shadow_traffic, get_handle, update_backend_config,
get_backend_config, accept_batch, list_backends,
list_endpoints, shutdown) # noqa: E402
from ray.serve.api import (accept_batch, Client, connect, start) # noqa: F401
from ray.serve.config import BackendConfig
__all__ = [
"init", "create_backend", "delete_backend", "create_endpoint",
"delete_endpoint", "set_traffic", "shadow_traffic", "get_handle",
"update_backend_config", "get_backend_config", "accept_batch",
"list_backends", "list_endpoints", "shutdown", "BackendConfig"
"accept_batch",
"BackendConfig",
"connect"
"Client",
"start",
]
+386 -319
View File
@@ -1,3 +1,4 @@
import atexit
from functools import wraps
import ray
@@ -5,99 +6,383 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
from ray.serve.controller import ServeController
from ray.serve.handle import RayServeHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name)
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.actor import ActorHandle
from typing import Any, Callable, Dict, List, Optional, Type, Union
controller = None
_INTERNAL_CONTROLLER_NAME = None
def _get_controller() -> ActorHandle:
"""Used for internal purpose because using just import serve.global_state
will always reference the original None object.
"""
global controller
if controller is None:
raise RayServeException("Please run serve.init to initialize or "
"connect to existing ray serve cluster.")
return controller
def _set_internal_controller_name(name):
global _INTERNAL_CONTROLLER_NAME
_INTERNAL_CONTROLLER_NAME = name
def _ensure_connected(f: Callable) -> Callable:
@wraps(f)
def check(*args, **kwargs):
_get_controller()
return f(*args, **kwargs)
def check(self, *args, **kwargs):
if self._shutdown:
raise RayServeException("Client has already been shut down.")
return f(self, *args, **kwargs)
return check
def accept_batch(f: Callable) -> Callable:
"""Annotation to mark a serving function that batch is accepted.
class Client:
def __init__(self,
controller: ActorHandle,
controller_name: str,
detached: bool = False):
self._controller = controller
self._controller_name = controller_name
self._detached = detached
self._shutdown = False
This annotation need to be used to mark a function expect all arguments
to be passed into a list.
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
# when the interpreter is exiting so we can't rely on __del__ (it
# throws a nasty stacktrace).
if not self._detached:
Example:
def shutdown_serve_client():
self.shutdown()
>>> @serve.accept_batch
def serving_func(flask_request):
assert isinstance(flask_request, list)
...
atexit.register(shutdown_serve_client)
>>> class ServingActor:
@serve.accept_batch
def __call__(self, *, python_arg=None):
assert isinstance(python_arg, list)
"""
f._serve_accept_batch = True
return f
def __del__(self):
if not self._detached:
self.shutdown()
def shutdown(self) -> None:
"""Completely shut down the connected Serve instance.
Shuts down all processes and deletes all state associated with the
instance.
"""
if not self._shutdown:
ray.get(self._controller.shutdown.remote())
ray.kill(self._controller, no_restart=True)
self._shutdown = True
@_ensure_connected
def create_endpoint(self,
endpoint_name: str,
*,
backend: str = None,
route: Optional[str] = None,
methods: List[str] = ["GET"]) -> None:
"""Create a service endpoint given route_expression.
Args:
endpoint_name (str): A name to associate to with the endpoint.
backend (str, required): The backend that will serve requests to
this endpoint. To change this or split traffic among backends,
use `serve.set_traffic`.
route (str, optional): A string begin with "/". HTTP server will
use the string to match the path.
methods(List[str], optional): The HTTP methods that are valid for
this endpoint.
"""
if backend is None:
raise TypeError("backend must be specified when creating "
"an endpoint.")
elif not isinstance(backend, str):
raise TypeError("backend must be a string, got {}.".format(
type(backend)))
if route is not None:
if not isinstance(route, str) or not route.startswith("/"):
raise TypeError("route must be a string starting with '/'.")
if not isinstance(methods, list):
raise TypeError(
"methods must be a list of strings, but got type {}".format(
type(methods)))
endpoints = self.list_endpoints()
if endpoint_name in endpoints:
methods_old = endpoints[endpoint_name]["methods"]
route_old = endpoints[endpoint_name]["route"]
if methods_old.sort() == methods.sort() and route_old == route:
raise ValueError(
"Route '{}' is already registered to endpoint '{}' "
"with methods '{}'. To set the backend for this "
"endpoint, please use serve.set_traffic().".format(
route, endpoint_name, methods))
upper_methods = []
for method in methods:
if not isinstance(method, str):
raise TypeError(
"methods must be a list of strings, but contained "
"an element of type {}".format(type(method)))
upper_methods.append(method.upper())
ray.get(
self._controller.create_endpoint.remote(
endpoint_name, {backend: 1.0}, route, upper_methods))
@_ensure_connected
def delete_endpoint(self, endpoint: str) -> None:
"""Delete the given endpoint.
Does not delete any associated backends.
"""
ray.get(self._controller.delete_endpoint.remote(endpoint))
@_ensure_connected
def list_endpoints(self) -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of all registered endpoints.
The dictionary keys are endpoint names and values are dictionaries
of the form: {"methods": List[str], "traffic": Dict[str, float]}.
"""
return ray.get(self._controller.get_all_endpoints.remote())
@_ensure_connected
def update_backend_config(
self, backend_tag: str,
config_options: Union[BackendConfig, Dict[str, Any]]) -> None:
"""Update a backend configuration for a backend tag.
Keys not specified in the passed will be left unchanged.
Args:
backend_tag(str): A registered backend.
config_options(dict, serve.BackendConfig): Backend config options
to update. Either a BackendConfig object or a dict mapping
strings to values for the following supported options:
- "num_replicas": number of worker processes to start up that
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas
will wait for a full batch of requests before
processing a partial batch.
- "max_concurrent_queries": the maximum number of queries
that will be sent to a replica of this backend
without receiving a response.
"""
if not isinstance(config_options, (BackendConfig, dict)):
raise TypeError(
"config_options must be a BackendConfig or dictionary.")
ray.get(
self._controller.update_backend_config.remote(
backend_tag, config_options))
@_ensure_connected
def get_backend_config(self, backend_tag: str) -> BackendConfig:
"""Get the backend configuration for a backend tag.
Args:
backend_tag(str): A registered backend.
"""
return ray.get(self._controller.get_backend_config.remote(backend_tag))
@_ensure_connected
def create_backend(
self,
backend_tag: str,
func_or_class: Union[Callable, Type[Callable]],
*actor_init_args: Any,
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None
) -> None:
"""Create a backend with the provided tag.
The backend will serve requests with func_or_class.
Args:
backend_tag (str): a unique tag assign to identify this backend.
func_or_class (callable, class): a function or a class implementing
__call__.
actor_init_args (optional): the arguments to pass to the class.
initialization method.
ray_actor_options (optional): options to be passed into the
@ray.remote decorator for the backend actor.
config (dict, serve.BackendConfig, optional): configuration options
for this backend. Either a BackendConfig, or a dictionary
mapping strings to values for the following supported options:
- "num_replicas": number of worker processes to start up that
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas
will wait for a full batch of requests before processing a
partial batch.
- "max_concurrent_queries": the maximum number of queries that
will be sent to a replica of this backend without receiving a
response.
"""
if backend_tag in self.list_backends():
raise ValueError(
"Cannot create backend. "
"Backend '{}' is already registered.".format(backend_tag))
if config is None:
config = {}
replica_config = ReplicaConfig(
func_or_class,
*actor_init_args,
ray_actor_options=ray_actor_options)
metadata = BackendMetadata(
accepts_batches=replica_config.accepts_batches,
is_blocking=replica_config.is_blocking)
if isinstance(config, dict):
backend_config = BackendConfig.parse_obj({
**config, "internal_metadata": metadata
})
elif isinstance(config, BackendConfig):
backend_config = config.copy(
update={"internal_metadata": metadata})
else:
raise TypeError("config must be a BackendConfig or a dictionary.")
backend_config._validate_complete()
ray.get(
self._controller.create_backend.remote(backend_tag, backend_config,
replica_config))
@_ensure_connected
def list_backends(self) -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of all registered backends.
Dictionary maps backend tags to backend configs.
"""
return ray.get(self._controller.get_all_backends.remote())
@_ensure_connected
def delete_backend(self, backend_tag: str) -> None:
"""Delete the given backend.
The backend must not currently be used by any endpoints.
"""
ray.get(self._controller.delete_backend.remote(backend_tag))
@_ensure_connected
def set_traffic(self, endpoint_name: str,
traffic_policy_dictionary: Dict[str, float]) -> None:
"""Associate a service endpoint with traffic policy.
Example:
>>> serve.set_traffic("service-name", {
"backend:v1": 0.5,
"backend:v2": 0.5
})
Args:
endpoint_name (str): A registered service endpoint.
traffic_policy_dictionary (dict): a dictionary maps backend names
to their traffic weights. The weights must sum to 1.
"""
ray.get(
self._controller.set_traffic.remote(endpoint_name,
traffic_policy_dictionary))
@_ensure_connected
def shadow_traffic(self, endpoint_name: str, backend_tag: str,
proportion: float) -> None:
"""Shadow traffic from an endpoint to a backend.
The specified proportion of requests will be duplicated and sent to the
backend. Responses of the duplicated traffic will be ignored.
The backend must not already be in use.
To stop shadowing traffic to a backend, call `shadow_traffic` with
proportion equal to 0.
Args:
endpoint_name (str): A registered service endpoint.
backend_tag (str): A registered backend.
proportion (float): The proportion of traffic from 0 to 1.
"""
if not isinstance(proportion,
(float, int)) or not 0 <= proportion <= 1:
raise TypeError("proportion must be a float from 0 to 1.")
ray.get(
self._controller.shadow_traffic.remote(endpoint_name, backend_tag,
proportion))
@_ensure_connected
def get_handle(self, endpoint_name: str) -> RayServeHandle:
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
endpoint_name (str): A registered service endpoint.
Returns:
RayServeHandle
"""
if endpoint_name not in ray.get(
self._controller.get_all_endpoints.remote()):
raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")
# TODO(edoakes): we should choose the router on the same node.
routers = ray.get(self._controller.get_routers.remote())
return RayServeHandle(
list(routers.values())[0],
endpoint_name,
)
def init(name: Optional[str] = None,
http_host: str = DEFAULT_HTTP_HOST,
http_port: int = DEFAULT_HTTP_PORT,
http_middlewares: List[Any] = []) -> None:
"""Initialize or connect to a serve cluster.
def start(detached: bool = False,
http_host: str = DEFAULT_HTTP_HOST,
http_port: int = DEFAULT_HTTP_PORT,
http_middlewares: List[Any] = []) -> Client:
"""Initialize a serve instance.
If serve cluster is already initialized, this function will just return.
If `ray.init` has not been called in this process, it will be called with
no arguments. To specify kwargs to `ray.init`, it should be called
separately before calling `serve.init`.
By default, the instance will be scoped to the lifetime of the returned
Client object (or when the script exits). If detached is set to True, the
instance will instead persist until client.shutdown() is called and clients
to it can be connected using serve.connect(). This is only relevant if
connecting to a long-running Ray cluster (e.g., with address="auto").
Args:
name (str): A unique name for this serve instance. This allows
multiple serve instances to run on the same ray cluster. Must be
specified in all subsequent serve.init() calls.
http_host (str): Host for HTTP servers. Default to "0.0.0.0". Serve
starts one HTTP server per node in the Ray cluster.
http_port (int, List[int]): Port for HTTP server. Default to 8000.
detached (bool): Whether not the instance should be detached from this
script.
http_host (str): Host for HTTP servers to listen on. Defaults to
"127.0.0.1". To expose Serve publicly, you probably want to set
this to "0.0.0.0". One HTTP server will be started on each node in
the Ray cluster.
http_port (int): Port for HTTP server. Defaults to 8000.
http_middleswares (list): A list of Starlette middlewares that will be
applied to the HTTP servers in the cluster.
"""
if name is not None and not isinstance(name, str):
raise TypeError("name must be a string.")
# Initialize ray if needed.
if not ray.is_initialized():
ray.init()
# Try to get serve controller if it exists
global controller
controller_name = format_actor_name(SERVE_CONTROLLER_NAME, name)
try:
controller = ray.get_actor(controller_name)
return
except ValueError:
pass
if detached:
controller_name = SERVE_CONTROLLER_NAME
try:
ray.get_actor(controller_name)
raise RayServeException("Called serve.start(detached=True) but a "
"detached instance is already running. "
"Please use serve.connect() to connect to "
"the running instance instead.")
except ValueError:
pass
else:
controller_name = format_actor_name(SERVE_CONTROLLER_NAME,
get_random_letters())
controller = ServeController.options(
name=controller_name,
lifetime="detached",
lifetime="detached" if detached else None,
max_restarts=-1,
max_task_retries=-1,
).remote(name, http_host, http_port, http_middlewares)
).remote(
controller_name,
http_host,
http_port,
http_middlewares,
detached=detached)
futures = []
for node_id in ray.state.node_ids():
@@ -110,281 +395,63 @@ def init(name: Optional[str] = None,
futures.append(future)
ray.get(futures)
@_ensure_connected
def shutdown() -> None:
"""Completely shut down the connected Serve instance.
Shuts down all processes and deletes all state associated with the Serve
instance that's currently connected to (via serve.init).
"""
global controller
ray.get(controller.shutdown.remote())
ray.kill(controller, no_restart=True)
controller = None
return Client(controller, controller_name, detached=detached)
@_ensure_connected
def create_endpoint(endpoint_name: str,
*,
backend: str = None,
route: Optional[str] = None,
methods: List[str] = ["GET"]) -> None:
"""Create a service endpoint given route_expression.
def connect() -> Client:
"""Connect to an existing Serve instance on this Ray cluster.
Args:
endpoint_name (str): A name to associate to with the endpoint.
backend (str, required): The backend that will serve requests to
this endpoint. To change this or split traffic among backends, use
`serve.set_traffic`.
route (str, optional): A string begin with "/". HTTP server will use
the string to match the path.
methods(List[str], optional): The HTTP methods that are valid for this
endpoint.
"""
if backend is None:
raise TypeError("backend must be specified when creating "
"an endpoint.")
elif not isinstance(backend, str):
raise TypeError("backend must be a string, got {}.".format(
type(backend)))
If calling from the driver program, the Serve instance on this Ray cluster
must first have been initialized using `serve.start(detached=True)`.
if route is not None:
if not isinstance(route, str) or not route.startswith("/"):
raise TypeError("route must be a string starting with '/'.")
if not isinstance(methods, list):
raise TypeError(
"methods must be a list of strings, but got type {}".format(
type(methods)))
endpoints = list_endpoints()
if endpoint_name in endpoints:
methods_old = endpoints[endpoint_name]["methods"]
route_old = endpoints[endpoint_name]["route"]
if methods_old.sort() == methods.sort() and route_old == route:
raise ValueError(
"Route '{}' is already registered to endpoint '{}' "
"with methods '{}'. To set the backend for this "
"endpoint, please use serve.set_traffic().".format(
route, endpoint_name, methods))
upper_methods = []
for method in methods:
if not isinstance(method, str):
raise TypeError("methods must be a list of strings, but contained "
"an element of type {}".format(type(method)))
upper_methods.append(method.upper())
ray.get(
controller.create_endpoint.remote(endpoint_name, {backend: 1.0}, route,
upper_methods))
@_ensure_connected
def delete_endpoint(endpoint: str) -> None:
"""Delete the given endpoint.
Does not delete any associated backends.
"""
ray.get(controller.delete_endpoint.remote(endpoint))
@_ensure_connected
def list_endpoints() -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of all registered endpoints.
The dictionary keys are endpoint names and values are dictionaries
of the form: {"methods": List[str], "traffic": Dict[str, float]}.
"""
return ray.get(controller.get_all_endpoints.remote())
@_ensure_connected
def update_backend_config(
backend_tag: str,
config_options: Union[BackendConfig, Dict[str, Any]]) -> None:
"""Update a backend configuration for a backend tag.
Keys not specified in the passed will be left unchanged.
Args:
backend_tag(str): A registered backend.
config_options(dict, serve.BackendConfig): Backend config options to
update. Either a BackendConfig object or a dict mapping strings to
values for the following supported options:
- "num_replicas": number of worker processes to start up that
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas
will wait for a full batch of requests before
processing a partial batch.
- "max_concurrent_queries": the maximum number of queries
that will be sent to a replica of this backend
without receiving a response.
If called from within a backend, will connect to the same Serve instance
that the backend is running in.
"""
if not isinstance(config_options, (BackendConfig, dict)):
raise TypeError(
"config_options must be a BackendConfig or dictionary.")
ray.get(
controller.update_backend_config.remote(backend_tag, config_options))
# Initialize ray if needed.
if not ray.is_initialized():
ray.init()
@_ensure_connected
def get_backend_config(backend_tag: str) -> BackendConfig:
"""Get the backend configuration for a backend tag.
Args:
backend_tag(str): A registered backend.
"""
return ray.get(controller.get_backend_config.remote(backend_tag))
@_ensure_connected
def create_backend(
backend_tag: str,
func_or_class: Union[Callable, Type[Callable]],
*actor_init_args: Any,
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None) -> None:
"""Create a backend with the provided tag.
The backend will serve requests with func_or_class.
Args:
backend_tag (str): a unique tag assign to identify this backend.
func_or_class (callable, class): a function or a class implementing
__call__.
actor_init_args (optional): the arguments to pass to the class.
initialization method.
ray_actor_options (optional): options to be passed into the
@ray.remote decorator for the backend actor.
config (dict, serve.BackendConfig, optional): configuration options
for this backend. Either a BackendConfig, or a dictionary mapping
strings to values for the following supported options:
- "num_replicas": number of worker processes to start up that will
handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
be processed in one batch by this backend.
- "batch_wait_timeout": time in seconds that backend replicas
will wait for a full batch of requests before processing a
partial batch.
- "max_concurrent_queries": the maximum number of queries that will
be sent to a replica of this backend without receiving a
response.
"""
if backend_tag in list_backends():
raise ValueError(
"Cannot create backend. "
"Backend '{}' is already registered.".format(backend_tag))
if config is None:
config = {}
replica_config = ReplicaConfig(
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
metadata = BackendMetadata(
accepts_batches=replica_config.accepts_batches,
is_blocking=replica_config.is_blocking)
if isinstance(config, dict):
backend_config = BackendConfig.parse_obj({
**config, "internal_metadata": metadata
})
elif isinstance(config, BackendConfig):
backend_config = config.copy(update={"internal_metadata": metadata})
# When running inside of a backend, _INTERNAL_CONTROLLER_NAME is set to
# ensure that the correct instance is connected to.
if _INTERNAL_CONTROLLER_NAME is None:
controller_name = SERVE_CONTROLLER_NAME
else:
raise TypeError("config must be a BackendConfig or a dictionary.")
backend_config._validate_complete()
ray.get(
controller.create_backend.remote(backend_tag, backend_config,
replica_config))
controller_name = _INTERNAL_CONTROLLER_NAME
# Try to get serve controller if it exists
try:
controller = ray.get_actor(controller_name)
except ValueError:
raise RayServeException("Called `serve.connect()` but there is no "
"instance running on this Ray cluster. Please "
"call `serve.start(detached=True) to start "
"one.")
return Client(controller, controller_name, detached=True)
@_ensure_connected
def list_backends() -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of all registered backends.
def accept_batch(f: Callable) -> Callable:
"""Annotation to mark that a serving function accepts batches of requests.
Dictionary maps backend tags to backend configs.
"""
return ray.get(controller.get_all_backends.remote())
In order to accept batches of requests as input, the implementation must
handle a list of requests being passed in rather than just a single
request.
@_ensure_connected
def delete_backend(backend_tag: str) -> None:
"""Delete the given backend.
The backend must not currently be used by any endpoints.
"""
ray.get(controller.delete_backend.remote(backend_tag))
@_ensure_connected
def set_traffic(endpoint_name: str,
traffic_policy_dictionary: Dict[str, float]) -> None:
"""Associate a service endpoint with traffic policy.
This must be set on any backend implementation that will have
max_batch_size set to greater than 1.
Example:
>>> serve.set_traffic("service-name", {
"backend:v1": 0.5,
"backend:v2": 0.5
})
>>> @serve.accept_batch
def serving_func(requests):
assert isinstance(requests, list)
...
Args:
endpoint_name (str): A registered service endpoint.
traffic_policy_dictionary (dict): a dictionary maps backend names
to their traffic weights. The weights must sum to 1.
>>> class ServingActor:
@serve.accept_batch
def __call__(self, requests):
assert isinstance(requests, list)
"""
ray.get(
controller.set_traffic.remote(endpoint_name,
traffic_policy_dictionary))
@_ensure_connected
def shadow_traffic(endpoint_name: str, backend_tag: str,
proportion: float) -> None:
"""Shadow traffic from an endpoint to a backend.
The specified proportion of requests will be duplicated and sent to the
backend. Responses of the duplicated traffic will be ignored.
The backend must not already be in use.
To stop shadowing traffic to a backend, call `shadow_traffic` with
proportion equal to 0.
Args:
endpoint_name (str): A registered service endpoint.
backend_tag (str): A registered backend.
proportion (float): The proportion of traffic from 0 to 1.
"""
if not isinstance(proportion, (float, int)) or not 0 <= proportion <= 1:
raise TypeError("proportion must be a float from 0 to 1.")
ray.get(
controller.shadow_traffic.remote(endpoint_name, backend_tag,
proportion))
@_ensure_connected
def get_handle(endpoint_name: str, missing_ok: bool = False) -> RayServeHandle:
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
endpoint_name (str): A registered service endpoint.
missing_ok (bool): If true, skip the check for the endpoint existence.
It can be useful when the endpoint has not been registered.
Returns:
RayServeHandle
"""
if not missing_ok:
assert endpoint_name in ray.get(controller.get_all_endpoints.remote())
# TODO(edoakes): we should choose the router on the same node.
routers = ray.get(controller.get_routers.remote())
return RayServeHandle(
list(routers.values())[0],
endpoint_name,
)
f._serve_accept_batch = True
return f
+5 -9
View File
@@ -10,7 +10,6 @@ import time
import ray
from ray.async_compat import sync_to_async
from ray import serve
from ray.serve import context as serve_context
from ray.serve.context import FakeFlaskRequest
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
@@ -102,14 +101,11 @@ def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]):
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
class RayServeWrappedWorker(object):
def __init__(self,
backend_tag,
replica_tag,
init_args,
backend_config: BackendConfig,
instance_name=None):
serve.init(name=instance_name)
def __init__(self, backend_tag, replica_tag, init_args,
backend_config: BackendConfig, controller_name: str):
# Set the controller name so that serve.connect() will connect to
# the instance that this backend is running in.
ray.serve.api._set_internal_controller_name(controller_name)
if is_function:
_callable = func_or_class
else:
@@ -78,10 +78,10 @@ async def trial(actors, session, data_size):
async def main():
ray.init(log_to_driver=False)
serve.init()
client = serve.start()
serve.create_backend("backend", backend)
serve.create_endpoint("endpoint", backend="backend", route="/api")
client.create_backend("backend", backend)
client.create_endpoint("endpoint", backend="backend", route="/api")
actors = [Client.remote() for _ in range(NUM_CLIENTS)]
for num_replicas in [1, 8]:
@@ -100,7 +100,7 @@ async def main():
},
]:
backend_config["num_replicas"] = num_replicas
serve.update_backend_config("backend", backend_config)
client.update_backend_config("backend", backend_config)
print(repr(backend_config) + ":")
async with aiohttp.ClientSession() as session:
# TODO(edoakes): large data causes broken pipe errors.
+3 -3
View File
@@ -42,7 +42,7 @@ def run_http_benchmark(url, num_queries):
@click.option("--max-concurrent-queries", type=int, required=False)
def main(num_replicas: int, num_queries: Optional[int],
max_concurrent_queries: Optional[int], blocking: bool):
serve.init()
client = serve.start()
def noop(_):
return "hello world"
@@ -52,8 +52,8 @@ def main(num_replicas: int, num_queries: Optional[int],
"max_concurrent_queries": max_concurrent_queries
}
print("Using config", config)
serve.create_backend("noop", noop, config=config)
serve.create_endpoint("noop", backend="noop", route="/noop")
client.create_backend("noop", noop, config=config)
client.create_endpoint("noop", backend="noop", route="/noop")
url = "{}/noop".format(DEFAULT_HTTP_ADDRESS)
block_until_ready(url)
+2 -2
View File
@@ -16,10 +16,10 @@ DEFAULT_HTTP_PORT = 8000
#: Max concurrency
ASYNC_CONCURRENCY = int(1e6)
#: Time to wait for HTTP proxy in `serve.init()`
#: Max time to wait for HTTP proxy in `serve.start()`.
HTTP_PROXY_TIMEOUT = 60
#: Default histogram buckets for latency tracker
#: Default histogram buckets for latency tracker.
DEFAULT_LATENCY_BUCKET_MS = [
1,
2,
+27 -23
View File
@@ -102,13 +102,17 @@ class ServeController:
requires all implementations here to be idempotent.
"""
async def __init__(self, instance_name: str, http_host: str,
http_port: str, http_middlewares: List[Any]) -> None:
# Unique name of the serve instance managed by this actor. Used to
# namespace child actors and checkpoints.
self.instance_name = instance_name
async def __init__(self,
controller_name: str,
http_host: str,
http_port: str,
http_middlewares: List[Any],
detached: bool = False):
self.detached = detached
# Name of this controller actor.
self.controller_name = controller_name
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=instance_name)
self.kv_store = RayInternalKVStore(namespace=controller_name)
# path -> (endpoint, methods).
self.routes = dict()
# backend -> BackendInfo.
@@ -180,7 +184,7 @@ class ServeController:
continue
router_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name, node_id)
self.controller_name, node_id)
try:
router = ray.get_actor(router_name)
except ValueError:
@@ -190,7 +194,7 @@ class ServeController:
self.http_port))
router = HTTPProxyActor.options(
name=router_name,
lifetime="detached",
lifetime="detached" if self.detached else None,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
@@ -201,7 +205,7 @@ class ServeController:
node_id,
self.http_host,
self.http_port,
instance_name=self.instance_name,
controller_name=self.controller_name,
http_middlewares=self.http_middlewares)
self.routers[node_id] = router
@@ -287,7 +291,7 @@ class ServeController:
for node_id in router_node_ids:
router_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name, node_id)
self.controller_name, node_id)
self.routers[node_id] = ray.get_actor(router_name)
# Fetch actor handles for all of the backend replicas in the system.
@@ -297,7 +301,7 @@ class ServeController:
for backend_tag, replica_tags in self.replicas.items():
for replica_tag in replica_tags:
replica_name = format_actor_name(replica_tag,
self.instance_name)
self.controller_name)
self.workers[backend_tag][replica_tag] = ray.get_actor(
replica_name)
@@ -389,8 +393,8 @@ class ServeController:
"""Fetched by serve handles."""
return self.traffic_policies[endpoint]
async def _start_backend_worker(self, backend_tag: str,
replica_tag: str) -> ActorHandle:
async def _start_backend_worker(self, backend_tag: str, replica_tag: str,
replica_name: str) -> ActorHandle:
"""Creates a backend worker and waits for it to start up.
Assumes that the backend configuration has already been registered
@@ -400,18 +404,15 @@ class ServeController:
replica_tag, backend_tag))
backend_info = self.backends[backend_tag]
replica_name = format_actor_name(replica_tag, self.instance_name)
worker_handle = ray.remote(backend_info.worker_class).options(
name=replica_name,
lifetime="detached",
lifetime="detached" if self.detached else None,
max_restarts=-1,
max_task_retries=-1,
**backend_info.replica_config.ray_actor_options).remote(
backend_tag,
replica_tag,
backend_tag, replica_tag,
backend_info.replica_config.actor_init_args,
backend_info.backend_config,
instance_name=self.instance_name)
backend_info.backend_config, self.controller_name)
# TODO(edoakes): we should probably have a timeout here.
await worker_handle.ready.remote()
return worker_handle
@@ -420,11 +421,12 @@ class ServeController:
# NOTE(edoakes): the replicas may already be created if we
# failed after creating them but before writing a
# checkpoint.
replica_name = format_actor_name(replica_tag, self.controller_name)
try:
worker_handle = ray.get_actor(replica_tag)
worker_handle = ray.get_actor(replica_name)
except ValueError:
worker_handle = await self._start_backend_worker(
backend_tag, replica_tag)
backend_tag, replica_tag, replica_name)
self.replicas[backend_tag].append(replica_tag)
self.workers[backend_tag][replica_tag] = worker_handle
@@ -466,8 +468,10 @@ class ServeController:
for replica_tag in replicas_to_stop:
# NOTE(edoakes): the replicas may already be stopped if we
# failed after stopping them but before writing a checkpoint.
replica_name = format_actor_name(replica_tag,
self.controller_name)
try:
replica = ray.get_actor(replica_tag)
replica = ray.get_actor(replica_name)
except ValueError:
continue
@@ -556,7 +560,7 @@ class ServeController:
self.replicas_to_start[backend_tag].append(replica_tag)
elif delta_num_replicas < 0:
logger.debug("Removing {} replicas from backend {}".format(
logger.debug("Removing {} replicas from backend '{}'".format(
-delta_num_replicas, backend_tag))
assert len(self.replicas[backend_tag]) >= delta_num_replicas
for _ in range(-delta_num_replicas):
@@ -1,7 +1,7 @@
from ray import serve
import requests
serve.init()
client = serve.start()
class Counter:
@@ -12,8 +12,8 @@ class Counter:
return {"current_counter": self.count}
serve.create_backend("counter", Counter)
serve.create_endpoint("counter", backend="counter", route="/counter")
client.create_backend("counter", Counter)
client.create_endpoint("counter", backend="counter", route="/counter")
requests.get("http://127.0.0.1:8000/counter").json()
# > {"current_counter": self.count}
@@ -1,15 +1,15 @@
from ray import serve
import requests
serve.init()
client = serve.start()
def echo(flask_request):
return "hello " + flask_request.args.get("name", "serve!")
serve.create_backend("hello", echo)
serve.create_endpoint("hello", backend="hello", route="/hello")
client.create_backend("hello", echo)
client.create_endpoint("hello", backend="hello", route="/hello")
requests.get("http://127.0.0.1:8000/hello").text
# > "hello serve!"
@@ -4,7 +4,7 @@ import ray
from ray import serve
ray.init(num_cpus=10)
serve.init()
client = serve.start()
# Our pipeline will be structured as follows:
# - Input comes in, the composed model sends it to model_one
@@ -27,8 +27,9 @@ def model_two(_unused_flask_request, data=None):
class ComposedModel:
def __init__(self):
self.model_one = serve.get_handle("model_one")
self.model_two = serve.get_handle("model_two")
client = serve.connect()
self.model_one = client.get_handle("model_one")
self.model_two = client.get_handle("model_two")
# This method can be called concurrently!
async def __call__(self, flask_request):
@@ -44,17 +45,17 @@ class ComposedModel:
return result
serve.create_backend("model_one", model_one)
serve.create_endpoint("model_one", backend="model_one")
client.create_backend("model_one", model_one)
client.create_endpoint("model_one", backend="model_one")
serve.create_backend("model_two", model_two)
serve.create_endpoint("model_two", backend="model_two")
client.create_backend("model_two", model_two)
client.create_endpoint("model_two", backend="model_two")
# max_concurrent_queries is optional. By default, if you pass in an async
# function, Ray Serve sets the limit to a high number.
serve.create_backend(
client.create_backend(
"composed_backend", ComposedModel, config={"max_concurrent_queries": 10})
serve.create_endpoint(
client.create_endpoint(
"composed", backend="composed_backend", route="/composed")
for _ in range(5):
@@ -30,9 +30,9 @@ def batch_adder_v0(flask_requests: List):
# __doc_deploy_begin__
ray.init(num_cpus=10)
serve.init()
serve.create_backend("adder:v0", batch_adder_v0, config={"max_batch_size": 4})
serve.create_endpoint(
client = serve.start()
client.create_backend("adder:v0", batch_adder_v0, config={"max_batch_size": 4})
client.create_endpoint(
"adder", backend="adder:v0", route="/adder", methods=["GET"])
# __doc_deploy_end__
@@ -80,12 +80,12 @@ def batch_adder_v1(flask_requests: List, *, numbers: List = []):
# __doc_define_servable_v1_end__
# __doc_deploy_v1_begin__
serve.create_backend("adder:v1", batch_adder_v1, config={"max_batch_size": 4})
serve.set_traffic("adder", {"adder:v1": 1})
client.create_backend("adder:v1", batch_adder_v1, config={"max_batch_size": 4})
client.set_traffic("adder", {"adder:v1": 1})
# __doc_deploy_v1_end__
# __doc_query_handle_begin__
handle = serve.get_handle("adder")
handle = client.get_handle("adder")
print(handle)
# Output
# RayServeHandle(
@@ -69,9 +69,9 @@ ray.init(address="auto")
# now we initialize /connect to the Ray service
# listen on 0.0.0.0 to make the HTTP server accessible from other machines.
serve.init(http_host="0.0.0.0")
serve.create_backend("lr:v1", BoostingModel)
serve.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
client = serve.start(http_host="0.0.0.0")
client.create_backend("lr:v1", BoostingModel)
client.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
# __doc_create_deploy_end__
# __doc_query_begin__
@@ -163,7 +163,7 @@ class BoostingModelv2:
# now we initialize /connect to the Ray service
serve.init()
serve.create_backend("lr:v2", BoostingModelv2)
serve.set_traffic("iris_classifier", {"lr:v2": 0.25, "lr:v1": 0.75})
client = serve.connect()
client.create_backend("lr:v2", BoostingModelv2)
client.set_traffic("iris_classifier", {"lr:v2": 0.25, "lr:v1": 0.75})
# __doc_create_deploy_2_end__
@@ -46,9 +46,9 @@ class ImageModel:
# __doc_define_servable_end__
# __doc_deploy_begin__
serve.init()
serve.create_backend("resnet18:v0", ImageModel)
serve.create_endpoint(
client = serve.start()
client.create_backend("resnet18:v0", ImageModel)
client.create_endpoint(
"predictor",
backend="resnet18:v0",
route="/image_predict",
@@ -65,9 +65,9 @@ class BoostingModel:
# __doc_define_servable_end__
# __doc_deploy_begin__
serve.init()
serve.create_backend("lr:v1", BoostingModel)
serve.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
client = serve.start()
client.create_backend("lr:v1", BoostingModel)
client.create_endpoint("iris_classifier", backend="lr:v1", route="/regressor")
# __doc_deploy_end__
# __doc_query_begin__
@@ -68,9 +68,9 @@ class TFMnistModel:
# __doc_define_servable_end__
# __doc_deploy_begin__
serve.init()
serve.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
serve.create_endpoint("tf_classifier", backend="tf:v1", route="/mnist")
client = serve.start()
client.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
client.create_endpoint("tf_classifier", backend="tf:v1", route="/mnist")
# __doc_deploy_end__
# __doc_query_begin__
+3 -4
View File
@@ -13,10 +13,9 @@ def echo(flask_request):
return ["hello " + flask_request.args.get("name", "serve!")]
serve.init()
serve.create_backend("echo:v1", echo)
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
client = serve.start()
client.create_backend("echo:v1", echo)
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
+4 -4
View File
@@ -36,9 +36,9 @@ class MagicCounter:
return base_number + self.increment
serve.init()
serve.create_backend("counter:v1", MagicCounter, 42) # increment=42
serve.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
client = serve.start()
client.create_backend("counter:v1", MagicCounter, 42) # increment=42
client.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
print("Sending ten queries via HTTP")
for i in range(10):
@@ -50,7 +50,7 @@ for i in range(10):
time.sleep(0.2)
print("Sending ten queries via Python")
handle = serve.get_handle("magic_counter")
handle = client.get_handle("magic_counter")
for i in range(10):
print("> Pinging handle.remote(base_number={})".format(i))
result = ray.get(handle.remote(base_number=i))
@@ -47,11 +47,11 @@ class MagicCounter:
return result
serve.init()
serve.create_backend(
client = serve.start()
client.create_backend(
"counter:v1", MagicCounter, 42,
config={"max_batch_size": 5}) # increment=42
serve.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
client.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
print("Sending ten queries via HTTP")
for i in range(10):
@@ -63,7 +63,7 @@ for i in range(10):
time.sleep(0.2)
print("Sending ten queries via Python")
handle = serve.get_handle("magic_counter")
handle = client.get_handle("magic_counter")
for i in range(10):
print("> Pinging handle.remote(base_number={})".format(i))
result = ray.get(handle.remote(base_number=i))
+4 -4
View File
@@ -26,16 +26,16 @@ class MagicCounter:
return ""
serve.init()
client = serve.start()
# specify max_batch_size in BackendConfig
backend_config = {"max_batch_size": 5}
serve.create_backend(
client.create_backend(
"counter:v1", MagicCounter, 42, config=backend_config) # increment=42
print("Backend Config for backend: 'counter:v1'")
print(backend_config)
serve.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
client.create_endpoint("magic_counter", backend="counter:v1", route="/counter")
handle = serve.get_handle("magic_counter")
handle = client.get_handle("magic_counter")
future_list = []
# fire 30 requests
+3 -3
View File
@@ -38,9 +38,9 @@ def echo(_):
raise Exception("Something went wrong...")
serve.init()
client = serve.start()
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
for _ in range(2):
resp = requests.get("http://127.0.0.1:8000/echo").json()
@@ -49,6 +49,6 @@ for _ in range(2):
print("...Sleeping for 2 seconds...")
time.sleep(2)
handle = serve.get_handle("my_endpoint")
handle = client.get_handle("my_endpoint")
print("Invoke from python will raise exception with traceback:")
ray.get(handle.remote())
+8 -8
View File
@@ -7,7 +7,7 @@ import ray.serve as serve
# initialize ray serve system.
ray.init(num_cpus=10)
serve.init()
client = serve.start()
# a backend can be a function or class.
@@ -18,16 +18,16 @@ def echo_v1(flask_request, response="hello from python!"):
return response
serve.create_backend("echo:v1", echo_v1)
client.create_backend("echo:v1", echo_v1)
# An endpoint is associated with an HTTP path and traffic to the endpoint
# will be serviced by the echo:v1 backend.
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
print(requests.get("http://127.0.0.1:8000/echo", timeout=0.5).text)
# The service will be reachable from http
print(ray.get(serve.get_handle("my_endpoint").remote(response="hello")))
print(ray.get(client.get_handle("my_endpoint").remote(response="hello")))
# as well as within the ray system.
@@ -38,10 +38,10 @@ def echo_v2(flask_request):
return "something new"
serve.create_backend("echo:v2", echo_v2)
client.create_backend("echo:v2", echo_v2)
# The two backend will now split the traffic 50%-50%.
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
client.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
# Observe requests are now split between two backends.
for _ in range(10):
@@ -49,5 +49,5 @@ for _ in range(10):
time.sleep(0.5)
# You can also change number of replicas for each backend independently.
serve.update_backend_config("echo:v1", {"num_replicas": 2})
serve.update_backend_config("echo:v2", {"num_replicas": 2})
client.update_backend_config("echo:v1", {"num_replicas": 2})
client.update_backend_config("echo:v2", {"num_replicas": 2})
+16 -16
View File
@@ -5,42 +5,42 @@ import ray
import ray.serve as serve
import time
# initialize ray serve system.
serve.init()
# Initialize ray serve instance.
client = serve.start()
# a backend can be a function or class.
# it can be made to be invoked from web as well as python.
# A backend can be a function or class.
# It can be made to be invoked via HTTP as well as python.
def echo_v1(_, response="hello from python!"):
return f"echo_v1({response})"
serve.create_backend("echo_v1", echo_v1)
serve.create_endpoint("echo_v1", backend="echo_v1", route="/echo_v1")
client.create_backend("echo_v1", echo_v1)
client.create_endpoint("echo_v1", backend="echo_v1", route="/echo_v1")
def echo_v2(_, relay=""):
return f"echo_v2({relay})"
serve.create_backend("echo_v2", echo_v2)
serve.create_endpoint("echo_v2", backend="echo_v2", route="/echo_v2")
client.create_backend("echo_v2", echo_v2)
client.create_endpoint("echo_v2", backend="echo_v2", route="/echo_v2")
def echo_v3(_, relay=""):
return f"echo_v3({relay})"
serve.create_backend("echo_v3", echo_v3)
serve.create_endpoint("echo_v3", backend="echo_v3", route="/echo_v3")
client.create_backend("echo_v3", echo_v3)
client.create_endpoint("echo_v3", backend="echo_v3", route="/echo_v3")
def echo_v4(_, relay1="", relay2=""):
return f"echo_v4({relay1} , {relay2})"
serve.create_backend("echo_v4", echo_v4)
serve.create_endpoint("echo_v4", backend="echo_v4", route="/echo_v4")
client.create_backend("echo_v4", echo_v4)
client.create_endpoint("echo_v4", backend="echo_v4", route="/echo_v4")
"""
The pipeline created is as follows -
"my_endpoint1"
@@ -62,10 +62,10 @@ The pipeline created is as follows -
"""
# get the handle of the endpoints
handle1 = serve.get_handle("echo_v1")
handle2 = serve.get_handle("echo_v2")
handle3 = serve.get_handle("echo_v3")
handle4 = serve.get_handle("echo_v4")
handle1 = client.get_handle("echo_v1")
handle2 = client.get_handle("echo_v2")
handle3 = client.get_handle("echo_v3")
handle4 = client.get_handle("echo_v4")
start = time.time()
print("Start firing to the pipeline: {} s".format(time.time()))
+5 -5
View File
@@ -30,10 +30,10 @@ def echo_v2(_):
return "v2"
serve.init()
client = serve.start()
serve.create_backend("echo:v1", echo_v1)
serve.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
client.create_backend("echo:v1", echo_v1)
client.create_endpoint("my_endpoint", backend="echo:v1", route="/echo")
for _ in range(3):
resp = requests.get("http://127.0.0.1:8000/echo").json()
@@ -42,8 +42,8 @@ for _ in range(3):
print("...Sleeping for 2 seconds...")
time.sleep(2)
serve.create_backend("echo:v2", echo_v2)
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
client.create_backend("echo:v2", echo_v2)
client.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
+17 -9
View File
@@ -1,7 +1,5 @@
from typing import Optional
import ray
from ray import serve
from ray.serve.context import TaskContext
from ray.serve.request_params import RequestMetadata
@@ -42,6 +40,14 @@ class RayServeHandle:
self.shard_key = shard_key
def remote(self, *args, **kwargs):
"""Invoke a request on the endpoint.
Returns a Ray ObjectRef whose result can be waited for or retrieved
using `ray.wait` or `ray.get`, respectively.
Returns:
ray.ObjectRef
"""
if len(args) > 0:
raise ValueError(
"handle.remote must be invoked with keyword arguments.")
@@ -59,6 +65,14 @@ class RayServeHandle:
method_name: Optional[str] = None,
http_method: Optional[str] = None,
shard_key: Optional[str] = None):
"""Set options for this handle.
Args:
method_name(str): The method to invoke on the backend.
http_method(str): The HTTP method to use for the request.
shard_key(str): A string to use to deterministically map this
request to a backend if there are multiple for this endpoint.
"""
return RayServeHandle(
self.router_handle,
self.endpoint_name,
@@ -68,11 +82,5 @@ class RayServeHandle:
shard_key=self.shard_key or shard_key,
)
def _get_traffic_policy(self):
controller = serve.api._get_controller()
return ray.get(
controller.get_traffic_policy.remote(self.endpoint_name))
def __repr__(self):
return (f"RayServeHandle(Endpoint='{self.endpoint_name}', "
f"Traffic={self._get_traffic_policy()})")
return f"RayServeHandle(endpoint='{self.endpoint_name}')"
+5 -7
View File
@@ -6,7 +6,6 @@ import uvicorn
import ray
from ray.exceptions import RayTaskError
from ray import serve
from ray.serve.context import TaskContext
from ray.experimental import metrics
from ray.serve.request_params import RequestMetadata
@@ -27,9 +26,9 @@ class HTTPProxy:
# blocks forever
"""
async def fetch_config_from_controller(self, name, instance_name=None):
async def fetch_config_from_controller(self, name, controller_name):
assert ray.is_initialized()
controller = serve.api._get_controller()
controller = ray.get_actor(controller_name)
self.route_table = await controller.get_router_config.remote()
@@ -38,7 +37,7 @@ class HTTPProxy:
"requests", ["route"])
self.router = Router()
await self.router.setup(name, instance_name)
await self.router.setup(name, controller_name)
def set_route_table(self, route_table):
self.route_table = route_table
@@ -133,15 +132,14 @@ class HTTPProxyActor:
name,
host,
port,
instance_name=None,
controller_name,
http_middlewares: List["starlette.middleware.Middleware"] = []):
serve.init(name=instance_name)
self.app = HTTPProxy()
self.host = host
self.port = port
self.app = HTTPProxy()
await self.app.fetch_config_from_controller(name, instance_name)
await self.app.fetch_config_from_controller(name, controller_name)
self.wrapped_app = self.app
for middleware in http_middlewares:
+2 -4
View File
@@ -9,7 +9,6 @@ from dataclasses import dataclass
from ray.exceptions import RayTaskError
import ray
from ray import serve
from ray.experimental import metrics
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import RandomEndpointPolicy
@@ -53,7 +52,7 @@ class Query:
class Router:
"""A router that routes request to available workers."""
async def setup(self, name, instance_name=None):
async def setup(self, name, controller_name):
# Note: Several queues are used in the router
# - When a request come in, it's placed inside its corresponding
# endpoint_queue.
@@ -104,8 +103,7 @@ class Router:
# the controller. We use a "pull-based" approach instead of pushing
# them from the controller so that the router can transparently recover
# from failure.
serve.init(name=instance_name)
self.controller = serve.api._get_controller()
self.controller = ray.get_actor(controller_name)
traffic_policies = ray.get(
self.controller.get_traffic_policies.remote())
+5 -9
View File
@@ -15,19 +15,15 @@ def _shared_serve_instance():
num_cpus=36,
_metrics_export_port=9999,
_system_config={"metrics_report_interval_ms": 1000})
serve.init()
yield
yield serve.start(detached=True)
@pytest.fixture
def serve_instance(_shared_serve_instance):
serve.init()
yield
# Re-init if necessary.
serve.init()
controller = serve.api._get_controller()
yield _shared_serve_instance
controller = _shared_serve_instance._controller
# Clear all state between tests to avoid naming collisions.
for endpoint in ray.get(controller.get_all_endpoints.remote()):
serve.delete_endpoint(endpoint)
_shared_serve_instance.delete_endpoint(endpoint)
for backend in ray.get(controller.get_all_backends.remote()):
serve.delete_backend(backend)
_shared_serve_instance.delete_backend(backend)
+234 -161
View File
@@ -8,7 +8,7 @@ import requests
import ray
from ray import serve
from ray.test_utils import wait_for_condition
from ray.serve import constants
from ray.serve.constants import SERVE_PROXY_NAME
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig
from ray.serve.utils import (block_until_http_ready, format_actor_name,
@@ -16,13 +16,13 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name,
def test_e2e(serve_instance):
serve.init()
client = serve_instance
def function(flask_request):
return {"method": flask_request.method}
serve.create_backend("echo:v1", function)
serve.create_endpoint(
client.create_backend("echo:v1", function)
client.create_endpoint(
"endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])
retry_count = 5
@@ -49,12 +49,14 @@ def test_e2e(serve_instance):
def test_call_method(serve_instance):
client = serve_instance
class CallMethod:
def method(self, request):
return "hello"
serve.create_backend("backend", CallMethod)
serve.create_endpoint("endpoint", backend="backend", route="/api")
client.create_backend("backend", CallMethod)
client.create_endpoint("endpoint", backend="backend", route="/api")
# Test HTTP path.
resp = requests.get(
@@ -64,59 +66,69 @@ def test_call_method(serve_instance):
assert resp.text == "hello"
# Test serve handle path.
handle = serve.get_handle("endpoint")
handle = client.get_handle("endpoint")
assert ray.get(handle.options("method").remote()) == "hello"
def test_no_route(serve_instance):
client = serve_instance
def func(_, i=1):
return 1
serve.create_backend("backend:1", func)
serve.create_endpoint("noroute-endpoint", backend="backend:1")
service_handle = serve.get_handle("noroute-endpoint")
client.create_backend("backend:1", func)
client.create_endpoint("noroute-endpoint", backend="backend:1")
service_handle = client.get_handle("noroute-endpoint")
result = ray.get(service_handle.remote(i=1))
assert result == 1
def test_reject_duplicate_backend(serve_instance):
client = serve_instance
def f():
pass
def g():
pass
serve.create_backend("backend", f)
client.create_backend("backend", f)
with pytest.raises(ValueError):
serve.create_backend("backend", g)
client.create_backend("backend", g)
def test_reject_duplicate_route(serve_instance):
client = serve_instance
def f():
pass
serve.create_backend("backend", f)
client.create_backend("backend", f)
route = "/foo"
serve.create_endpoint("bar", backend="backend", route=route)
client.create_endpoint("bar", backend="backend", route=route)
with pytest.raises(ValueError):
serve.create_endpoint("foo", backend="backend", route=route)
client.create_endpoint("foo", backend="backend", route=route)
def test_reject_duplicate_endpoint(serve_instance):
client = serve_instance
def f():
pass
serve.create_backend("backend", f)
client.create_backend("backend", f)
endpoint_name = "foo"
serve.create_endpoint(endpoint_name, backend="backend", route="/ok")
client.create_endpoint(endpoint_name, backend="backend", route="/ok")
with pytest.raises(ValueError):
serve.create_endpoint(
client.create_endpoint(
endpoint_name, backend="backend", route="/different")
def test_reject_duplicate_endpoint_and_route(serve_instance):
client = serve_instance
class SimpleBackend(object):
def __init__(self, message):
self.message = message
@@ -124,26 +136,30 @@ def test_reject_duplicate_endpoint_and_route(serve_instance):
def __call__(self, *args, **kwargs):
return {"message": self.message}
serve.create_backend("backend1", SimpleBackend, "First")
serve.create_backend("backend2", SimpleBackend, "Second")
client.create_backend("backend1", SimpleBackend, "First")
client.create_backend("backend2", SimpleBackend, "Second")
serve.create_endpoint("test", backend="backend1", route="/test")
client.create_endpoint("test", backend="backend1", route="/test")
with pytest.raises(ValueError):
serve.create_endpoint("test", backend="backend2", route="/test")
client.create_endpoint("test", backend="backend2", route="/test")
def test_set_traffic_missing_data(serve_instance):
client = serve_instance
endpoint_name = "foobar"
backend_name = "foo_backend"
serve.create_backend(backend_name, lambda: 5)
serve.create_endpoint(endpoint_name, backend=backend_name)
client.create_backend(backend_name, lambda: 5)
client.create_endpoint(endpoint_name, backend=backend_name)
with pytest.raises(ValueError):
serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
client.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
with pytest.raises(ValueError):
serve.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
client.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
def test_scaling_replicas(serve_instance):
client = serve_instance
class Counter:
def __init__(self):
self.count = 0
@@ -152,9 +168,9 @@ def test_scaling_replicas(serve_instance):
self.count += 1
return self.count
serve.create_backend(
client.create_backend(
"counter:v1", Counter, config=BackendConfig(num_replicas=2))
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
client.create_endpoint("counter", backend="counter:v1", route="/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get(
@@ -169,7 +185,7 @@ def test_scaling_replicas(serve_instance):
# If the load is shared among two replicas. The max result cannot be 10.
assert max(counter_result) < 10
serve.update_backend_config("counter:v1", {"num_replicas": 1})
client.update_backend_config("counter:v1", {"num_replicas": 1})
counter_result = []
for _ in range(10):
@@ -181,6 +197,8 @@ def test_scaling_replicas(serve_instance):
def test_scaling_replicas_legacy(serve_instance):
client = serve_instance
class Counter:
def __init__(self):
self.count = 0
@@ -189,8 +207,8 @@ def test_scaling_replicas_legacy(serve_instance):
self.count += 1
return self.count
serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
client.create_backend("counter:v1", Counter, config={"num_replicas": 2})
client.create_endpoint("counter", backend="counter:v1", route="/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get(
@@ -205,7 +223,7 @@ def test_scaling_replicas_legacy(serve_instance):
# If the load is shared among two replicas. The max result cannot be 10.
assert max(counter_result) < 10
serve.update_backend_config("counter:v1", {"num_replicas": 1})
client.update_backend_config("counter:v1", {"num_replicas": 1})
counter_result = []
for _ in range(10):
@@ -217,6 +235,8 @@ def test_scaling_replicas_legacy(serve_instance):
def test_batching(serve_instance):
client = serve_instance
class BatchingExample:
def __init__(self):
self.count = 0
@@ -228,11 +248,11 @@ def test_batching(serve_instance):
return [self.count] * batch_size
# set the max batch size
serve.create_backend(
client.create_backend(
"counter:v11",
BatchingExample,
config=BackendConfig(max_batch_size=5, batch_wait_timeout=1))
serve.create_endpoint(
client.create_endpoint(
"counter1", backend="counter:v11", route="/increment2")
# Keep checking the routing table until /increment is populated
@@ -241,7 +261,7 @@ def test_batching(serve_instance):
time.sleep(0.2)
future_list = []
handle = serve.get_handle("counter1")
handle = client.get_handle("counter1")
for _ in range(20):
f = handle.remote(temp=1)
future_list.append(f)
@@ -254,6 +274,8 @@ def test_batching(serve_instance):
def test_batching_legacy(serve_instance):
client = serve_instance
class BatchingExample:
def __init__(self):
self.count = 0
@@ -265,14 +287,14 @@ def test_batching_legacy(serve_instance):
return [self.count] * batch_size
# set the max batch size
serve.create_backend(
client.create_backend(
"counter:v11",
BatchingExample,
config={
"max_batch_size": 5,
"batch_wait_timeout": 1
})
serve.create_endpoint(
client.create_endpoint(
"counter1", backend="counter:v11", route="/increment2")
# Keep checking the routing table until /increment is populated
@@ -281,7 +303,7 @@ def test_batching_legacy(serve_instance):
time.sleep(0.2)
future_list = []
handle = serve.get_handle("counter1")
handle = client.get_handle("counter1")
for _ in range(20):
f = handle.remote(temp=1)
future_list.append(f)
@@ -294,6 +316,8 @@ def test_batching_legacy(serve_instance):
def test_batching_exception(serve_instance):
client = serve_instance
class NoListReturned:
def __init__(self):
self.count = 0
@@ -304,17 +328,19 @@ def test_batching_exception(serve_instance):
return batch_size
# set the max batch size
serve.create_backend(
client.create_backend(
"exception:v1", NoListReturned, config=BackendConfig(max_batch_size=5))
serve.create_endpoint(
client.create_endpoint(
"exception-test", backend="exception:v1", route="/noListReturned")
handle = serve.get_handle("exception-test")
handle = client.get_handle("exception-test")
with pytest.raises(ray.exceptions.RayTaskError):
assert ray.get(handle.remote(temp=1))
def test_batching_exception_legacy(serve_instance):
client = serve_instance
class NoListReturned:
def __init__(self):
self.count = 0
@@ -325,17 +351,19 @@ def test_batching_exception_legacy(serve_instance):
return batch_size
# set the max batch size
serve.create_backend(
client.create_backend(
"exception:v1", NoListReturned, config={"max_batch_size": 5})
serve.create_endpoint(
client.create_endpoint(
"exception-test", backend="exception:v1", route="/noListReturned")
handle = serve.get_handle("exception-test")
handle = client.get_handle("exception-test")
with pytest.raises(ray.exceptions.RayTaskError):
assert ray.get(handle.remote(temp=1))
def test_updating_config(serve_instance):
client = serve_instance
class BatchSimple:
def __init__(self):
self.count = 0
@@ -345,17 +373,17 @@ def test_updating_config(serve_instance):
batch_size = serve.context.batch_size
return [1] * batch_size
serve.create_backend(
client.create_backend(
"bsimple:v1",
BatchSimple,
config=BackendConfig(max_batch_size=2, num_replicas=3))
serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
controller = serve.api._get_controller()
controller = client._controller
old_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
serve.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
client.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
new_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
new_all_tag_list = []
@@ -370,6 +398,8 @@ def test_updating_config(serve_instance):
def test_updating_config_legacy(serve_instance):
client = serve_instance
class BatchSimple:
def __init__(self):
self.count = 0
@@ -379,20 +409,20 @@ def test_updating_config_legacy(serve_instance):
batch_size = serve.context.batch_size
return [1] * batch_size
serve.create_backend(
client.create_backend(
"bsimple:v1",
BatchSimple,
config={
"max_batch_size": 2,
"num_replicas": 3
})
serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
controller = serve.api._get_controller()
controller = client._controller
old_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
serve.update_backend_config("bsimple:v1", {"max_batch_size": 5})
client.update_backend_config("bsimple:v1", {"max_batch_size": 5})
new_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
new_all_tag_list = []
@@ -407,79 +437,85 @@ def test_updating_config_legacy(serve_instance):
def test_delete_backend(serve_instance):
client = serve_instance
def function():
return "hello"
serve.create_backend("delete:v1", function)
serve.create_endpoint(
client.create_backend("delete:v1", function)
client.create_endpoint(
"delete_backend", backend="delete:v1", route="/delete-backend")
assert requests.get("http://127.0.0.1:8000/delete-backend").text == "hello"
# Check that we can't delete the backend while it's in use.
with pytest.raises(ValueError):
serve.delete_backend("delete:v1")
client.delete_backend("delete:v1")
serve.create_backend("delete:v2", function)
serve.set_traffic("delete_backend", {"delete:v1": 0.5, "delete:v2": 0.5})
client.create_backend("delete:v2", function)
client.set_traffic("delete_backend", {"delete:v1": 0.5, "delete:v2": 0.5})
with pytest.raises(ValueError):
serve.delete_backend("delete:v1")
client.delete_backend("delete:v1")
# Check that the backend can be deleted once it's no longer in use.
serve.set_traffic("delete_backend", {"delete:v2": 1.0})
serve.delete_backend("delete:v1")
client.set_traffic("delete_backend", {"delete:v2": 1.0})
client.delete_backend("delete:v1")
# Check that we can no longer use the previously deleted backend.
with pytest.raises(ValueError):
serve.set_traffic("delete_backend", {"delete:v1": 1.0})
client.set_traffic("delete_backend", {"delete:v1": 1.0})
def function2():
return "olleh"
# Check that we can now reuse the previously delete backend's tag.
serve.create_backend("delete:v1", function2)
serve.set_traffic("delete_backend", {"delete:v1": 1.0})
client.create_backend("delete:v1", function2)
client.set_traffic("delete_backend", {"delete:v1": 1.0})
assert requests.get("http://127.0.0.1:8000/delete-backend").text == "olleh"
@pytest.mark.parametrize("route", [None, "/delete-endpoint"])
def test_delete_endpoint(serve_instance, route):
client = serve_instance
def function():
return "hello"
backend_name = "delete-endpoint:v1"
serve.create_backend(backend_name, function)
client.create_backend(backend_name, function)
endpoint_name = "delete_endpoint" + str(route)
serve.create_endpoint(endpoint_name, backend=backend_name, route=route)
serve.delete_endpoint(endpoint_name)
client.create_endpoint(endpoint_name, backend=backend_name, route=route)
client.delete_endpoint(endpoint_name)
# Check that we can reuse a deleted endpoint name and route.
serve.create_endpoint(endpoint_name, backend=backend_name, route=route)
client.create_endpoint(endpoint_name, backend=backend_name, route=route)
if route is not None:
assert requests.get(
"http://127.0.0.1:8000/delete-endpoint").text == "hello"
else:
handle = serve.get_handle(endpoint_name)
handle = client.get_handle(endpoint_name)
assert ray.get(handle.remote()) == "hello"
# Check that deleting the endpoint doesn't delete the backend.
serve.delete_endpoint(endpoint_name)
serve.create_endpoint(endpoint_name, backend=backend_name, route=route)
client.delete_endpoint(endpoint_name)
client.create_endpoint(endpoint_name, backend=backend_name, route=route)
if route is not None:
assert requests.get(
"http://127.0.0.1:8000/delete-endpoint").text == "hello"
else:
handle = serve.get_handle(endpoint_name)
handle = client.get_handle(endpoint_name)
assert ray.get(handle.remote()) == "hello"
@pytest.mark.parametrize("route", [None, "/shard"])
def test_shard_key(serve_instance, route):
client = serve_instance
# Create five backends that return different integers.
num_backends = 5
traffic_dict = {}
@@ -490,11 +526,11 @@ def test_shard_key(serve_instance, route):
backend_name = "backend-split-" + str(i)
traffic_dict[backend_name] = 1.0 / num_backends
serve.create_backend(backend_name, function)
client.create_backend(backend_name, function)
serve.create_endpoint(
client.create_endpoint(
"endpoint", backend=list(traffic_dict.keys())[0], route=route)
serve.set_traffic("endpoint", traffic_dict)
client.set_traffic("endpoint", traffic_dict)
def do_request(shard_key):
if route is not None:
@@ -502,7 +538,7 @@ def test_shard_key(serve_instance, route):
headers = {"X-SERVE-SHARD-KEY": shard_key}
result = requests.get(url, headers=headers).text
else:
handle = serve.get_handle("endpoint").options(shard_key=shard_key)
handle = client.get_handle("endpoint").options(shard_key=shard_key)
result = ray.get(handle.options(shard_key=shard_key).remote())
return result
@@ -517,49 +553,47 @@ def test_shard_key(serve_instance, route):
assert do_request(shard_key) == results[shard_key]
def test_name():
with pytest.raises(TypeError):
serve.init(name=1)
def test_multiple_instances():
route = "/api"
backend = "backend"
endpoint = "endpoint"
serve.init(name="cluster1", http_port=8001)
client1 = serve.start(http_port=8001)
def function():
return "hello1"
serve.create_backend(backend, function)
serve.create_endpoint(endpoint, backend=backend, route=route)
client1.create_backend(backend, function)
client1.create_endpoint(endpoint, backend=backend, route=route)
assert requests.get("http://127.0.0.1:8001" + route).text == "hello1"
# Create a second cluster on port 8002. Create an endpoint and backend with
# the same names and check that they don't collide.
serve.init(name="cluster2", http_port=8002)
client2 = serve.start(http_port=8002)
def function():
return "hello2"
serve.create_backend(backend, function)
serve.create_endpoint(endpoint, backend=backend, route=route)
client2.create_backend(backend, function)
client2.create_endpoint(endpoint, backend=backend, route=route)
assert requests.get("http://127.0.0.1:8001" + route).text == "hello1"
assert requests.get("http://127.0.0.1:8002" + route).text == "hello2"
# Check that deleting the backend in the current cluster doesn't.
serve.delete_endpoint(endpoint)
serve.delete_backend(backend)
client2.delete_endpoint(endpoint)
client2.delete_backend(backend)
assert requests.get("http://127.0.0.1:8001" + route).text == "hello1"
# Check that we can re-connect to the first cluster.
serve.init(name="cluster1")
serve.delete_endpoint(endpoint)
serve.delete_backend(backend)
# Check that the first client still works.
client1.delete_endpoint(endpoint)
client1.delete_backend(backend)
def test_parallel_start(serve_instance):
client = serve_instance
# Test the ability to start multiple replicas in parallel.
# In the past, when Serve scale up a backend, it does so one by one and
# wait for each replica to initialize. This test avoid this by preventing
@@ -588,15 +622,17 @@ def test_parallel_start(serve_instance):
def __call__(self, _):
return "Ready"
serve.create_backend(
client.create_backend(
"p:v0", LongStartingServable, config=BackendConfig(num_replicas=2))
serve.create_endpoint("test-parallel", backend="p:v0")
handle = serve.get_handle("test-parallel")
client.create_endpoint("test-parallel", backend="p:v0")
handle = client.get_handle("test-parallel")
ray.get(handle.remote(), timeout=10)
def test_parallel_start_legacy(serve_instance):
client = serve_instance
# Test the ability to start multiple replicas in parallel.
# In the past, when Serve scale up a backend, it does so one by one and
# wait for each replica to initialize. This test avoid this by preventing
@@ -625,29 +661,29 @@ def test_parallel_start_legacy(serve_instance):
def __call__(self, _):
return "Ready"
serve.create_backend(
client.create_backend(
"p:v0", LongStartingServable, config={"num_replicas": 2})
serve.create_endpoint("test-parallel", backend="p:v0")
handle = serve.get_handle("test-parallel")
client.create_endpoint("test-parallel", backend="p:v0")
handle = client.get_handle("test-parallel")
ray.get(handle.remote(), timeout=10)
def test_list_endpoints(serve_instance):
serve.init()
client = serve_instance
def f():
pass
serve.create_backend("backend", f)
serve.create_backend("backend2", f)
serve.create_backend("backend3", f)
serve.create_endpoint(
client.create_backend("backend", f)
client.create_backend("backend2", f)
client.create_backend("backend3", f)
client.create_endpoint(
"endpoint", backend="backend", route="/api", methods=["GET", "POST"])
serve.create_endpoint("endpoint2", backend="backend2", methods=["POST"])
serve.shadow_traffic("endpoint", "backend3", 0.5)
client.create_endpoint("endpoint2", backend="backend2", methods=["POST"])
client.shadow_traffic("endpoint", "backend3", 0.5)
endpoints = serve.list_endpoints()
endpoints = client.list_endpoints()
assert "endpoint" in endpoints
assert endpoints["endpoint"] == {
"route": "/api",
@@ -670,92 +706,93 @@ def test_list_endpoints(serve_instance):
"shadows": {}
}
serve.delete_endpoint("endpoint")
assert "endpoint2" in serve.list_endpoints()
client.delete_endpoint("endpoint")
assert "endpoint2" in client.list_endpoints()
serve.delete_endpoint("endpoint2")
assert len(serve.list_endpoints()) == 0
client.delete_endpoint("endpoint2")
assert len(client.list_endpoints()) == 0
def test_list_backends(serve_instance):
serve.init()
client = serve_instance
@serve.accept_batch
def f():
pass
serve.create_backend("backend", f, config=BackendConfig(max_batch_size=10))
backends = serve.list_backends()
client.create_backend(
"backend", f, config=BackendConfig(max_batch_size=10))
backends = client.list_backends()
assert len(backends) == 1
assert "backend" in backends
assert backends["backend"]["max_batch_size"] == 10
serve.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
backends = serve.list_backends()
client.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
backends = client.list_backends()
assert len(backends) == 2
assert backends["backend2"]["num_replicas"] == 10
serve.delete_backend("backend")
backends = serve.list_backends()
client.delete_backend("backend")
backends = client.list_backends()
assert len(backends) == 1
assert "backend2" in backends
serve.delete_backend("backend2")
assert len(serve.list_backends()) == 0
client.delete_backend("backend2")
assert len(client.list_backends()) == 0
def test_list_backends_legacy(serve_instance):
serve.init()
client = serve_instance
@serve.accept_batch
def f():
pass
serve.create_backend("backend", f, config={"max_batch_size": 10})
backends = serve.list_backends()
client.create_backend("backend", f, config={"max_batch_size": 10})
backends = client.list_backends()
assert len(backends) == 1
assert "backend" in backends
assert backends["backend"]["max_batch_size"] == 10
serve.create_backend("backend2", f, config={"num_replicas": 10})
backends = serve.list_backends()
client.create_backend("backend2", f, config={"num_replicas": 10})
backends = client.list_backends()
assert len(backends) == 2
assert backends["backend2"]["num_replicas"] == 10
serve.delete_backend("backend")
backends = serve.list_backends()
client.delete_backend("backend")
backends = client.list_backends()
assert len(backends) == 1
assert "backend2" in backends
serve.delete_backend("backend2")
assert len(serve.list_backends()) == 0
client.delete_backend("backend2")
assert len(client.list_backends()) == 0
def test_endpoint_input_validation(serve_instance):
serve.init()
client = serve_instance
def f():
pass
serve.create_backend("backend", f)
client.create_backend("backend", f)
with pytest.raises(TypeError):
serve.create_endpoint("endpoint")
client.create_endpoint("endpoint")
with pytest.raises(TypeError):
serve.create_endpoint("endpoint", route="/hello")
client.create_endpoint("endpoint", route="/hello")
with pytest.raises(TypeError):
serve.create_endpoint("endpoint", backend=2)
serve.create_endpoint("endpoint", backend="backend")
client.create_endpoint("endpoint", backend=2)
client.create_endpoint("endpoint", backend="backend")
def test_create_infeasible_error(serve_instance):
serve.init()
client = serve_instance
def f():
pass
# Non existent resource should be infeasible.
with pytest.raises(RayServeException, match="Cannot scale backend"):
serve.create_backend(
client.create_backend(
"f:1",
f,
ray_actor_options={"resources": {
@@ -765,7 +802,7 @@ def test_create_infeasible_error(serve_instance):
# Even each replica might be feasible, the total might not be.
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
with pytest.raises(RayServeException, match="Cannot scale backend"):
serve.create_backend(
client.create_backend(
"f:1",
f,
ray_actor_options={"resources": {
@@ -774,19 +811,19 @@ def test_create_infeasible_error(serve_instance):
config=BackendConfig(num_replicas=(current_cpus + 20)))
# No replica should be created!
replicas = ray.get(serve.api.controller._list_replicas.remote("f1"))
replicas = ray.get(client._controller._list_replicas.remote("f1"))
assert len(replicas) == 0
def test_create_infeasible_error_legacy(serve_instance):
serve.init()
client = serve_instance
def f():
pass
# Non existent resource should be infeasible.
with pytest.raises(RayServeException, match="Cannot scale backend"):
serve.create_backend(
client.create_backend(
"f:1",
f,
ray_actor_options={"resources": {
@@ -796,7 +833,7 @@ def test_create_infeasible_error_legacy(serve_instance):
# Even each replica might be feasible, the total might not be.
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
with pytest.raises(RayServeException, match="Cannot scale backend"):
serve.create_backend(
client.create_backend(
"f:1",
f,
ray_actor_options={"resources": {
@@ -805,29 +842,29 @@ def test_create_infeasible_error_legacy(serve_instance):
config={"num_replicas": current_cpus + 20})
# No replica should be created!
replicas = ray.get(serve.api.controller._list_replicas.remote("f1"))
replicas = ray.get(client._controller._list_replicas.remote("f1"))
assert len(replicas) == 0
def test_shutdown(serve_instance):
def test_shutdown():
def f():
pass
instance_name = "shutdown"
serve.init(name=instance_name, http_port=8003)
serve.create_backend("backend", f)
serve.create_endpoint("endpoint", backend="backend")
client = serve.start(http_port=8003)
client.create_backend("backend", f)
client.create_endpoint("endpoint", backend="backend")
serve.shutdown()
with pytest.raises(RayServeException, match="Please run serve.init"):
serve.list_backends()
client.shutdown()
with pytest.raises(RayServeException):
client.list_backends()
def check_dead():
for actor_name in [
constants.SERVE_CONTROLLER_NAME, constants.SERVE_PROXY_NAME
client._controller_name,
format_actor_name(SERVE_PROXY_NAME, client._controller_name)
]:
try:
ray.get_actor(format_actor_name(actor_name, instance_name))
ray.get_actor(actor_name)
return False
except ValueError:
pass
@@ -837,6 +874,8 @@ def test_shutdown(serve_instance):
def test_shadow_traffic(serve_instance):
client = serve_instance
@ray.remote
class RequestCounter:
def __init__(self):
@@ -866,15 +905,15 @@ def test_shadow_traffic(serve_instance):
ray.get(counter.record.remote("backend4"))
return "oops"
serve.create_backend("backend1", f)
serve.create_backend("backend2", f_shadow_1)
serve.create_backend("backend3", f_shadow_2)
serve.create_backend("backend4", f_shadow_3)
client.create_backend("backend1", f)
client.create_backend("backend2", f_shadow_1)
client.create_backend("backend3", f_shadow_2)
client.create_backend("backend4", f_shadow_3)
serve.create_endpoint("endpoint", backend="backend1", route="/api")
serve.shadow_traffic("endpoint", "backend2", 1.0)
serve.shadow_traffic("endpoint", "backend3", 0.5)
serve.shadow_traffic("endpoint", "backend4", 0.1)
client.create_endpoint("endpoint", backend="backend1", route="/api")
client.shadow_traffic("endpoint", "backend2", 1.0)
client.shadow_traffic("endpoint", "backend3", 0.5)
client.shadow_traffic("endpoint", "backend4", 0.1)
start = time.time()
num_requests = 100
@@ -897,13 +936,47 @@ def test_shadow_traffic(serve_instance):
wait_for_condition(check_requests)
def test_connect(serve_instance):
client = serve_instance
# Check that you can have multiple clients to the same detached instance.
client2 = serve.connect()
assert client._controller_name == client2._controller_name
# Check that you can have detached and non-detached instances.
client3 = serve.start(http_port=8004)
assert client3._controller_name != client._controller_name
# Check that you can call serve.connect() from within a backend for both
# detached and non-detached instances.
def connect_in_backend():
client = serve.connect()
client.create_backend("backend-ception", connect_in_backend)
return client._controller_name
client.create_backend("connect_in_backend", connect_in_backend)
client.create_endpoint("endpoint", backend="connect_in_backend")
handle = client.get_handle("endpoint")
assert ray.get(handle.remote()) == client._controller_name
assert "backend-ception" in client.list_backends()
client3.create_backend("connect_in_backend", connect_in_backend)
client3.create_endpoint("endpoint", backend="connect_in_backend")
handle = client3.get_handle("endpoint")
assert ray.get(handle.remote()) == client3._controller_name
assert "backend-ception" in client3.list_backends()
def test_serve_metrics(serve_instance):
client = serve_instance
@serve.accept_batch
def batcher(flask_requests):
return ["hello"] * len(flask_requests)
serve.create_backend("metrics", batcher)
serve.create_endpoint("metrics", backend="metrics", route="/metrics")
client.create_backend("metrics", batcher)
client.create_endpoint("metrics", backend="metrics", route="/metrics")
# send 10 concurrent requests
url = "http://127.0.0.1:8000/metrics"
ray.get([block_until_http_ready.remote(url) for _ in range(10)])
+11 -9
View File
@@ -19,7 +19,8 @@ pytestmark = pytest.mark.asyncio
def setup_worker(name,
func_or_class,
init_args=None,
backend_config=BackendConfig()):
backend_config=BackendConfig(),
controller_name=""):
if init_args is None:
init_args = ()
@@ -27,7 +28,8 @@ def setup_worker(name,
class WorkerActor:
def __init__(self):
self.worker = create_backend_worker(func_or_class)(
name, name + ":tag", init_args, backend_config)
name, name + ":tag", init_args, backend_config,
controller_name)
def ready(self):
pass
@@ -50,7 +52,7 @@ async def test_runner_wraps_error():
async def test_runner_actor(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
def echo(flask_request, i=None):
return i
@@ -72,7 +74,7 @@ async def test_runner_actor(serve_instance):
async def test_ray_serve_mixin(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
CONSUMER_NAME = "runner-cls"
PRODUCER_NAME = "prod-cls"
@@ -98,7 +100,7 @@ async def test_ray_serve_mixin(serve_instance):
async def test_task_runner_check_context(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
def echo(flask_request, i=None):
# Accessing the flask_request without web context should throw.
@@ -120,7 +122,7 @@ async def test_task_runner_check_context(serve_instance):
async def test_task_runner_custom_method_single(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
class NonBatcher:
def a(self, _):
@@ -155,7 +157,7 @@ async def test_task_runner_custom_method_single(serve_instance):
async def test_task_runner_custom_method_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
@serve.accept_batch
class Batcher:
@@ -220,7 +222,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
async def test_task_runner_perform_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
def batcher(*args, **kwargs):
return [serve.context.batch_size] * serve.context.batch_size
@@ -250,7 +252,7 @@ async def test_task_runner_perform_batch(serve_instance):
async def test_task_runner_perform_async(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
@ray.remote
class Barrier:
+37 -38
View File
@@ -21,13 +21,13 @@ def request_with_retries(endpoint, timeout=30):
def test_controller_failure(serve_instance):
serve.init()
client = serve_instance
def function():
return "hello1"
serve.create_backend("controller_failure:v1", function)
serve.create_endpoint(
client.create_backend("controller_failure:v1", function)
client.create_endpoint(
"controller_failure",
backend="controller_failure:v1",
route="/controller_failure")
@@ -39,7 +39,7 @@ def test_controller_failure(serve_instance):
response = request_with_retries("/controller_failure", timeout=30)
assert response.text == "hello1"
ray.kill(serve.api._get_controller(), no_restart=False)
ray.kill(client._controller, no_restart=False)
for _ in range(10):
response = request_with_retries("/controller_failure", timeout=30)
@@ -48,10 +48,10 @@ def test_controller_failure(serve_instance):
def function():
return "hello2"
ray.kill(serve.api._get_controller(), no_restart=False)
ray.kill(client._controller, no_restart=False)
serve.create_backend("controller_failure:v2", function)
serve.set_traffic("controller_failure", {"controller_failure:v2": 1.0})
client.create_backend("controller_failure:v2", function)
client.set_traffic("controller_failure", {"controller_failure:v2": 1.0})
for _ in range(10):
response = request_with_retries("/controller_failure", timeout=30)
@@ -60,14 +60,14 @@ def test_controller_failure(serve_instance):
def function():
return "hello3"
ray.kill(serve.api._get_controller(), no_restart=False)
serve.create_backend("controller_failure_2", function)
ray.kill(serve.api._get_controller(), no_restart=False)
serve.create_endpoint(
ray.kill(client._controller, no_restart=False)
client.create_backend("controller_failure_2", function)
ray.kill(client._controller, no_restart=False)
client.create_endpoint(
"controller_failure_2",
backend="controller_failure_2",
route="/controller_failure_2")
ray.kill(serve.api._get_controller(), no_restart=False)
ray.kill(client._controller, no_restart=False)
for _ in range(10):
response = request_with_retries("/controller_failure", timeout=30)
@@ -76,20 +76,20 @@ def test_controller_failure(serve_instance):
assert response.text == "hello3"
def _kill_routers():
routers = ray.get(serve.api._get_controller().get_routers.remote())
def _kill_routers(client):
routers = ray.get(client._controller.get_routers.remote())
for router in routers.values():
ray.kill(router, no_restart=False)
def test_http_proxy_failure(serve_instance):
serve.init()
client = serve_instance
def function():
return "hello1"
serve.create_backend("proxy_failure:v1", function)
serve.create_endpoint(
client.create_backend("proxy_failure:v1", function)
client.create_endpoint(
"proxy_failure", backend="proxy_failure:v1", route="/proxy_failure")
assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1"
@@ -98,21 +98,21 @@ def test_http_proxy_failure(serve_instance):
response = request_with_retries("/proxy_failure", timeout=30)
assert response.text == "hello1"
_kill_routers()
_kill_routers(client)
def function():
return "hello2"
serve.create_backend("proxy_failure:v2", function)
serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})
client.create_backend("proxy_failure:v2", function)
client.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})
for _ in range(10):
response = request_with_retries("/proxy_failure", timeout=30)
assert response.text == "hello2"
def _get_worker_handles(backend):
controller = serve.api._get_controller()
def _get_worker_handles(client, backend):
controller = client._controller
backend_dict = ray.get(controller.get_all_worker_handles.remote())
return list(backend_dict[backend].values())
@@ -121,21 +121,21 @@ def _get_worker_handles(backend):
# Test that a worker dying unexpectedly causes it to restart and continue
# serving requests.
def test_worker_restart(serve_instance):
serve.init()
client = serve_instance
class Worker1:
def __call__(self):
return os.getpid()
serve.create_backend("worker_failure:v1", Worker1)
serve.create_endpoint(
client.create_backend("worker_failure:v1", Worker1)
client.create_endpoint(
"worker_failure", backend="worker_failure:v1", route="/worker_failure")
# Get the PID of the worker.
old_pid = request_with_retries("/worker_failure", timeout=1).text
# Kill the worker.
handles = _get_worker_handles("worker_failure:v1")
handles = _get_worker_handles(client, "worker_failure:v1")
assert len(handles) == 1
ray.kill(handles[0], no_restart=False)
@@ -152,8 +152,7 @@ def test_worker_restart(serve_instance):
# Test that if there are multiple replicas for a worker and one dies
# unexpectedly, the others continue to serve requests.
def test_worker_replica_failure(serve_instance):
serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
serve.init()
client = serve_instance
class Worker:
# Assumes that two replicas are started. Will hang forever in the
@@ -182,10 +181,10 @@ def test_worker_replica_failure(serve_instance):
temp_path = os.path.join(tempfile.gettempdir(),
serve.utils.get_random_letters())
serve.create_backend("replica_failure", Worker, temp_path)
serve.update_backend_config(
client.create_backend("replica_failure", Worker, temp_path)
client.update_backend_config(
"replica_failure", BackendConfig(num_replicas=2))
serve.create_endpoint(
client.create_endpoint(
"replica_failure", backend="replica_failure", route="/replica_failure")
# Wait until both replicas have been started.
@@ -195,7 +194,7 @@ def test_worker_replica_failure(serve_instance):
time.sleep(0.1)
# Kill one of the replicas.
handles = _get_worker_handles("replica_failure")
handles = _get_worker_handles(client, "replica_failure")
assert len(handles) == 2
ray.kill(handles[0], no_restart=False)
@@ -212,12 +211,12 @@ def test_worker_replica_failure(serve_instance):
def test_create_backend_idempotent(serve_instance):
serve.init()
client = serve_instance
def f():
return "hello"
controller = serve.api._get_controller()
controller = client._controller
replica_config = ReplicaConfig(f)
backend_config = BackendConfig(num_replicas=1)
@@ -228,21 +227,21 @@ def test_create_backend_idempotent(serve_instance):
replica_config))
assert len(ray.get(controller.get_all_backends.remote())) == 1
serve.create_endpoint(
client.create_endpoint(
"my_endpoint", backend="my_backend", route="/my_route")
assert requests.get("http://127.0.0.1:8000/my_route").text == "hello"
def test_create_endpoint_idempotent(serve_instance):
serve.init()
client = serve_instance
def f():
return "hello"
serve.create_backend("my_backend", f)
client.create_backend("my_backend", f)
controller = serve.api._get_controller()
controller = client._controller
for i in range(10):
ray.get(
+7 -6
View File
@@ -5,7 +5,7 @@ import requests
def test_handle_in_endpoint(serve_instance):
serve.init()
client = serve_instance
class Endpoint1:
def __call__(self, flask_request):
@@ -13,20 +13,21 @@ def test_handle_in_endpoint(serve_instance):
class Endpoint2:
def __init__(self):
self.handle = serve.get_handle("endpoint1", missing_ok=True)
client = serve.connect()
self.handle = client.get_handle("endpoint1")
def __call__(self):
return ray.get(self.handle.remote())
serve.create_backend("endpoint1:v0", Endpoint1)
serve.create_endpoint(
client.create_backend("endpoint1:v0", Endpoint1)
client.create_endpoint(
"endpoint1",
backend="endpoint1:v0",
route="/endpoint1",
methods=["GET", "POST"])
serve.create_backend("endpoint2:v0", Endpoint2)
serve.create_endpoint(
client.create_backend("endpoint2:v0", Endpoint2)
client.create_endpoint(
"endpoint2",
backend="endpoint2:v0",
route="/endpoint2",
@@ -1,23 +0,0 @@
import requests
import sys
from ray import serve
def test_nonblocking():
serve.init()
def function(flask_request):
return {"method": flask_request.method}
serve.create_backend("nonblocking:v1", function)
serve.create_endpoint(
"nonblocking", backend="nonblocking:v1", route="/nonblocking")
resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
assert resp == "GET"
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))
+6 -5
View File
@@ -1,25 +1,26 @@
import ray
import ray.test_utils
from ray import serve
def test_new_driver(serve_instance):
client = serve_instance
script = """
import ray
ray.init(address="{}")
from ray import serve
serve.init()
client = serve.connect()
def driver(flask_request):
return "OK!"
serve.create_backend("driver", driver)
serve.create_endpoint("driver", backend="driver", route="/driver")
client.create_backend("driver", driver)
client.create_endpoint("driver", backend="driver", route="/driver")
""".format(ray.worker._global_node._redis_address)
ray.test_utils.run_string_as_driver(script)
handle = serve.get_handle("driver")
handle = client.get_handle("driver")
assert ray.get(handle.remote()) == "OK!"
+8 -5
View File
@@ -5,6 +5,8 @@ from ray import serve
def test_np_in_composed_model(serve_instance):
client = serve_instance
# https://github.com/ray-project/ray/issues/9441
# AttributeError: 'bytes' object has no attribute 'readonly'
# in cloudpickle _from_numpy_buffer
@@ -14,17 +16,18 @@ def test_np_in_composed_model(serve_instance):
class ComposedModel:
def __init__(self):
self.model = serve.get_handle("sum_model")
client = serve.connect()
self.model = client.get_handle("sum_model")
async def __call__(self, _request):
data = np.ones((10, 10))
result = await self.model.remote(data=data)
return result
serve.create_backend("sum_model", sum_model)
serve.create_endpoint("sum_model", backend="sum_model")
serve.create_backend("model", ComposedModel)
serve.create_endpoint(
client.create_backend("sum_model", sum_model)
client.create_endpoint("sum_model", backend="sum_model")
client.create_backend("model", ComposedModel)
client.create_endpoint(
"model", backend="model", route="/model", methods=["GET"])
result = requests.get("http://127.0.0.1:8000/model")
+6 -6
View File
@@ -49,7 +49,7 @@ def task_runner_mock_actor():
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
q.add_new_worker.remote("backend-single-prod", "replica-1",
@@ -67,7 +67,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
async def test_alter_backend(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
await q.add_new_worker.remote("backend-alter", "replica-1",
@@ -86,7 +86,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
await q.set_traffic.remote(
"svc", TrafficPolicy({
@@ -116,7 +116,7 @@ async def test_queue_remove_replicas(serve_instance):
temp_actor = mock_task_runner()
q = ray.remote(TestRouter).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
await q.remove_worker.remote("backend-remove", "replica-1")
assert ray.get(q.worker_queue_size.remote("backend")) == 0
@@ -124,7 +124,7 @@ async def test_queue_remove_replicas(serve_instance):
async def test_shard_key(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
num_backends = 5
traffic_dict = {}
@@ -179,7 +179,7 @@ async def test_router_use_max_concurrency(serve_instance):
worker = MockWorker.remote()
q = ray.remote(VisibleRouter).remote()
await q.setup.remote("")
await q.setup.remote("", serve_instance._controller_name)
backend_name = "max-concurrent-test"
config = BackendConfig(max_concurrent_queries=1)
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
+26 -11
View File
@@ -12,7 +12,8 @@ import ray
from ray import serve
from ray.cluster_utils import Cluster
from ray.serve.constants import SERVE_PROXY_NAME
from ray.serve.utils import block_until_http_ready
from ray.serve.utils import (block_until_http_ready, get_all_node_ids,
format_actor_name)
from ray.test_utils import wait_for_condition
from ray.services import new_port
@@ -29,16 +30,24 @@ def test_multiple_routers():
ray.init(head_node.address)
node_ids = ray.state.node_ids()
assert len(node_ids) == 2
serve.init(http_port=8005)
client = serve.start(http_port=8005) # noqa: F841
def actor_name(index):
return SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], index)
def get_proxy_names():
proxy_names = []
for node_id, _ in get_all_node_ids():
proxy_names.append(
format_actor_name(SERVE_PROXY_NAME, client._controller_name,
node_id))
return proxy_names
wait_for_condition(lambda: len(get_proxy_names()) == 2)
proxy_names = get_proxy_names()
# Two actors should be started.
def get_first_two_actors():
try:
ray.get_actor(actor_name(0))
ray.get_actor(actor_name(1))
ray.get_actor(proxy_names[0])
ray.get_actor(proxy_names[1])
return True
except ValueError:
return False
@@ -49,18 +58,22 @@ def test_multiple_routers():
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
# Kill one of the servers, the HTTP server should still function.
ray.kill(ray.get_actor(actor_name(0)), no_restart=True)
ray.kill(ray.get_actor(get_proxy_names()[0]), no_restart=True)
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
# Add a new node to the cluster. This should trigger a new router to get
# started.
new_node = cluster.add_node()
wait_for_condition(lambda: len(get_proxy_names()) == 3)
third_proxy = get_proxy_names()[2]
def get_third_actor():
try:
ray.get_actor(actor_name(2))
ray.get_actor(third_proxy)
return True
except ValueError:
# IndexErrors covers when cluster resources aren't updated yet.
except (IndexError, ValueError):
return False
wait_for_condition(get_third_actor)
@@ -71,7 +84,7 @@ def test_multiple_routers():
def third_actor_removed():
try:
ray.get_actor(actor_name(2))
ray.get_actor(third_proxy)
return False
except ValueError:
return True
@@ -90,7 +103,7 @@ def test_middleware():
from starlette.middleware.cors import CORSMiddleware
port = new_port()
serve.init(
serve.start(
http_port=port,
http_middlewares=[
Middleware(
@@ -112,6 +125,8 @@ def test_middleware():
resp = requests.get(f"{root}/-/routes", headers=headers)
assert resp.headers["access-control-allow-origin"] == "*"
ray.shutdown()
if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
+3 -3
View File
@@ -105,11 +105,11 @@ def get_random_letters(length=6):
return "".join(random.choices(string.ascii_letters, k=length))
def format_actor_name(actor_name, instance_name=None, *modifiers):
if instance_name is None:
def format_actor_name(actor_name, controller_name=None, *modifiers):
if controller_name is None:
name = actor_name
else:
name = "{}:{}".format(instance_name, actor_name)
name = "{}:{}".format(controller_name, actor_name)
for modifier in modifiers:
name += "-{}".format(modifier)