mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 09:44:50 +08:00
[Serve] Reimplement BackendConfig as pydantic model (#10389)
This commit is contained in:
@@ -3,20 +3,10 @@ from ray.serve.api import (init, create_backend, delete_backend,
|
||||
shadow_traffic, get_handle, update_backend_config,
|
||||
get_backend_config, accept_batch, list_backends,
|
||||
list_endpoints, shutdown) # noqa: E402
|
||||
|
||||
from ray.serve.config import BackendConfig
|
||||
__all__ = [
|
||||
"init",
|
||||
"create_backend",
|
||||
"delete_backend",
|
||||
"create_endpoint",
|
||||
"delete_endpoint",
|
||||
"set_traffic",
|
||||
"shadow_traffic",
|
||||
"get_handle",
|
||||
"update_backend_config",
|
||||
"get_backend_config",
|
||||
"accept_batch",
|
||||
"list_backends",
|
||||
"list_endpoints",
|
||||
"shutdown",
|
||||
"init", "create_backend", "delete_backend", "create_endpoint",
|
||||
"delete_endpoint", "set_traffic", "shadow_traffic", "get_handle",
|
||||
"update_backend_config", "get_backend_config", "accept_batch",
|
||||
"list_backends", "list_endpoints", "shutdown", "BackendConfig"
|
||||
]
|
||||
|
||||
+33
-21
@@ -7,7 +7,7 @@ from ray.serve.controller import ServeController
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
@@ -201,16 +201,18 @@ def list_endpoints() -> Dict[str, Dict[str, Any]]:
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def update_backend_config(backend_tag: str,
|
||||
config_options: Dict[str, Any]) -> None:
|
||||
def update_backend_config(
|
||||
backend_tag: str,
|
||||
config_options: Union[BackendConfig, Dict[str, Any]]) -> None:
|
||||
"""Update a backend configuration for a backend tag.
|
||||
|
||||
Keys not specified in the passed config will be left unchanged.
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
config_options(dict): Backend config options to update.
|
||||
Supported options:
|
||||
config_options(dict, serve.BackendConfig): Backend config options to
|
||||
update. Either a BackendConfig object or a dict mapping strings to
|
||||
values for the following supported options:
|
||||
- "num_replicas": number of worker processes to start up that
|
||||
will handle requests to this backend.
|
||||
- "max_batch_size": the maximum number of requests that will
|
||||
@@ -222,14 +224,16 @@ def update_backend_config(backend_tag: str,
|
||||
that will be sent to a replica of this backend
|
||||
without receiving a response.
|
||||
"""
|
||||
if not isinstance(config_options, dict):
|
||||
raise ValueError("config_options must be a dictionary.")
|
||||
|
||||
if not isinstance(config_options, (BackendConfig, dict)):
|
||||
raise TypeError(
|
||||
"config_options must be a BackendConfig or dictionary.")
|
||||
ray.get(
|
||||
controller.update_backend_config.remote(backend_tag, config_options))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_backend_config(backend_tag: str):
|
||||
def get_backend_config(backend_tag: str) -> BackendConfig:
|
||||
"""Get the backend configuration for a backend tag.
|
||||
|
||||
Args:
|
||||
@@ -239,11 +243,12 @@ def get_backend_config(backend_tag: str):
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def create_backend(backend_tag: str,
|
||||
func_or_class: Union[Callable, Type[Callable]],
|
||||
*actor_init_args: Any,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
config: Optional[Dict[str, Any]] = None) -> None:
|
||||
def create_backend(
|
||||
backend_tag: str,
|
||||
func_or_class: Union[Callable, Type[Callable]],
|
||||
*actor_init_args: Any,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None) -> None:
|
||||
"""Create a backend with the provided tag.
|
||||
|
||||
The backend will serve requests with func_or_class.
|
||||
@@ -256,8 +261,9 @@ def create_backend(backend_tag: str,
|
||||
initialization method.
|
||||
ray_actor_options (optional): options to be passed into the
|
||||
@ray.remote decorator for the backend actor.
|
||||
config (optional): configuration options for this backend.
|
||||
Supported options:
|
||||
config (dict, serve.BackendConfig, optional): configuration options
|
||||
for this backend. Either a BackendConfig, or a dictionary mapping
|
||||
strings to values for the following supported options:
|
||||
- "num_replicas": number of worker processes to start up that will
|
||||
handle requests to this backend.
|
||||
- "max_batch_size": the maximum number of requests that will
|
||||
@@ -276,14 +282,20 @@ def create_backend(backend_tag: str,
|
||||
|
||||
if config is None:
|
||||
config = {}
|
||||
if not isinstance(config, dict):
|
||||
raise TypeError("config must be a dictionary.")
|
||||
|
||||
replica_config = ReplicaConfig(
|
||||
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
|
||||
backend_config = BackendConfig(config, replica_config.accepts_batches,
|
||||
replica_config.is_blocking)
|
||||
|
||||
metadata = BackendMetadata(
|
||||
accepts_batches=replica_config.accepts_batches,
|
||||
is_blocking=replica_config.is_blocking)
|
||||
if isinstance(config, dict):
|
||||
backend_config = BackendConfig.parse_obj({
|
||||
**config, "internal_metadata": metadata
|
||||
})
|
||||
elif isinstance(config, BackendConfig):
|
||||
backend_config = config.copy(update={"internal_metadata": metadata})
|
||||
else:
|
||||
raise TypeError("config must be a BackendConfig or a dictionary.")
|
||||
backend_config._validate_complete()
|
||||
ray.get(
|
||||
controller.create_backend.remote(backend_tag, backend_config,
|
||||
replica_config))
|
||||
|
||||
@@ -318,7 +318,7 @@ class RayServeWorker:
|
||||
|
||||
all_evaluated_futures = []
|
||||
|
||||
if not self.config.accepts_batches:
|
||||
if not self.config.internal_metadata.accepts_batches:
|
||||
query = batch[0]
|
||||
evaluated = asyncio.ensure_future(self.invoke_single(query))
|
||||
all_evaluated_futures = [evaluated]
|
||||
@@ -336,7 +336,7 @@ class RayServeWorker:
|
||||
chain_future(
|
||||
unpack_future(evaluated, len(group)), result_futures)
|
||||
|
||||
if self.config.is_blocking:
|
||||
if self.config.internal_metadata.is_blocking:
|
||||
# We use asyncio.wait here so if the result is exception,
|
||||
# it will not be raised.
|
||||
await asyncio.wait(all_evaluated_futures)
|
||||
|
||||
+72
-66
@@ -1,6 +1,9 @@
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, PositiveInt, validator
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def _callable_accepts_batch(func_or_class):
|
||||
@@ -17,86 +20,89 @@ def _callable_is_blocking(func_or_class):
|
||||
return not inspect.iscoroutinefunction(func_or_class.__call__)
|
||||
|
||||
|
||||
class BackendConfig:
|
||||
def __init__(self, config_dict, accepts_batches=False, is_blocking=True):
|
||||
assert isinstance(config_dict, dict)
|
||||
# Make a copy so that we don't modify the input dict.
|
||||
config_dict = config_dict.copy()
|
||||
@dataclass
|
||||
class BackendMetadata:
|
||||
accepts_batches: bool = False
|
||||
is_blocking: bool = True
|
||||
autoscaling_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
self.accepts_batches = accepts_batches
|
||||
self.is_blocking = is_blocking
|
||||
self.num_replicas = config_dict.pop("num_replicas", 1)
|
||||
self.max_batch_size = config_dict.pop("max_batch_size", None)
|
||||
self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
|
||||
self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
|
||||
None)
|
||||
self.autoscaling_config = config_dict.pop("autoscaling", None)
|
||||
|
||||
if self.max_concurrent_queries is None:
|
||||
class BackendConfig(BaseModel):
|
||||
"""Configuration options for a backend, to be set by the user.
|
||||
|
||||
:param num_replicas: The number of worker processes to start up that will
|
||||
handle requests to this backend. Defaults to 1.
|
||||
:type num_replicas: int, optional
|
||||
:param max_batch_size: The maximum number of requests that will be
|
||||
processed in one batch by this backend. Defaults to None (no
|
||||
maximum).
|
||||
:type max_batch_size: int, optional
|
||||
:param batch_wait_timeout: The time in seconds that backend replicas will
|
||||
wait for a full batch of requests before processing a partial batch.
|
||||
Defaults to 0.
|
||||
:type batch_wait_timeout: float, optional
|
||||
:param max_concurrent_queries: The maximum number of queries that will be
|
||||
sent to a replica of this backend without receiving a response.
|
||||
Defaults to None (no maximum).
|
||||
:type max_concurrent_queries: int, optional
|
||||
"""
|
||||
|
||||
internal_metadata: BackendMetadata = BackendMetadata()
|
||||
num_replicas: PositiveInt = 1
|
||||
max_batch_size: Optional[PositiveInt] = None
|
||||
batch_wait_timeout: float = 0
|
||||
max_concurrent_queries: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _validate_batch_size(self):
|
||||
if (self.max_batch_size is not None
|
||||
and not self.internal_metadata.accepts_batches
|
||||
and self.max_batch_size > 1):
|
||||
raise ValueError(
|
||||
"max_batch_size is set in config but the function or "
|
||||
"method does not accept batching. Please use "
|
||||
"@serve.accept_batch to explicitly mark that the function or "
|
||||
"method accepts a list of requests as an argument.")
|
||||
|
||||
# This is not a pydantic validator, so that we may skip this method when
|
||||
# creating partially filled BackendConfig objects to pass as updates--for
|
||||
# example, BackendConfig(max_batch_size=5).
|
||||
def _validate_complete(self):
|
||||
self._validate_batch_size()
|
||||
|
||||
# Dynamic default for max_concurrent_queries
|
||||
@validator("max_concurrent_queries", always=True)
|
||||
def set_max_queries_by_mode(cls, v, values):
|
||||
if v is None:
|
||||
# Model serving mode: if the servable is blocking and the wait
|
||||
# timeout is default zero seconds, then we keep the existing
|
||||
# behavior to allow at most max batch size queries.
|
||||
if self.is_blocking and self.batch_wait_timeout == 0:
|
||||
if self.max_batch_size:
|
||||
self.max_concurrent_queries = 2 * self.max_batch_size
|
||||
if (values["internal_metadata"].is_blocking
|
||||
and values["batch_wait_timeout"] == 0):
|
||||
if ("max_batch_size" in values
|
||||
and values["max_batch_size"] is not None):
|
||||
v = 2 * values["max_batch_size"]
|
||||
else:
|
||||
self.max_concurrent_queries = 8
|
||||
v = 8
|
||||
|
||||
# Pipeline/async mode: if the servable is not blocking,
|
||||
# router should just keep pushing queries to the worker
|
||||
# replicas until a high limit.
|
||||
if not self.is_blocking:
|
||||
self.max_concurrent_queries = ASYNC_CONCURRENCY
|
||||
if not values["internal_metadata"].is_blocking:
|
||||
v = ASYNC_CONCURRENCY
|
||||
|
||||
# Batch inference mode: user specifies non zero timeout to wait for
|
||||
# full batch. We will use 2*max_batch_size to perform double
|
||||
# buffering to keep the replica busy.
|
||||
if self.max_batch_size is not None and self.batch_wait_timeout > 0:
|
||||
self.max_concurrent_queries = 2 * self.max_batch_size
|
||||
|
||||
if len(config_dict) != 0:
|
||||
raise ValueError("Unknown options in backend config: {}".format(
|
||||
list(config_dict.keys())))
|
||||
|
||||
self._validate()
|
||||
|
||||
def update(self, config_dict):
|
||||
"""Updates this BackendConfig with options set in the passed config.
|
||||
|
||||
Unspecified keys will remain the same.
|
||||
"""
|
||||
if "num_replicas" in config_dict:
|
||||
self.num_replicas = config_dict.pop("num_replicas")
|
||||
if "max_batch_size" in config_dict:
|
||||
self.max_batch_size = config_dict.pop("max_batch_size")
|
||||
if "max_concurrent_queries" in config_dict:
|
||||
self.max_concurrent_queries = config_dict.pop(
|
||||
"max_concurrent_queries")
|
||||
|
||||
if len(config_dict) != 0:
|
||||
raise ValueError("Unknown options in backend config: {}".format(
|
||||
list(config_dict.keys())))
|
||||
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
if not isinstance(self.num_replicas, int):
|
||||
raise TypeError("num_replicas must be an int.")
|
||||
elif self.num_replicas < 1:
|
||||
raise ValueError("num_replicas must be >= 1.")
|
||||
|
||||
if self.max_batch_size is not None:
|
||||
if not isinstance(self.max_batch_size, int):
|
||||
raise TypeError("max_batch_size must be an integer.")
|
||||
elif self.max_batch_size < 1:
|
||||
raise ValueError("max_batch_size must be >= 1.")
|
||||
|
||||
if not self.accepts_batches and self.max_batch_size > 1:
|
||||
raise ValueError(
|
||||
"max_batch_size is set in config but the function or "
|
||||
"method does not accept batching. Please use "
|
||||
"@serve.accept_batch to explicitly mark the function or "
|
||||
"method as batchable and takes in list as arguments.")
|
||||
if ("max_batch_size" in values
|
||||
and values["max_batch_size"] is not None
|
||||
and values["batch_wait_timeout"] > 0):
|
||||
v = 2 * values["max_batch_size"]
|
||||
return v
|
||||
|
||||
|
||||
class ReplicaConfig:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
from collections import defaultdict, namedtuple
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Union, Dict, Any, List, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
@@ -16,7 +18,6 @@ from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Dict, List, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -62,8 +63,17 @@ class TrafficPolicy:
|
||||
self.shadow_dict[backend] = proportion
|
||||
|
||||
|
||||
BackendInfo = namedtuple("BackendInfo",
|
||||
["worker_class", "backend_config", "replica_config"])
|
||||
class BackendInfo(BaseModel):
|
||||
# TODO(architkulkarni): Add type hint for worker_class after upgrading
|
||||
# cloudpickle and adding types to RayServeWrappedWorker
|
||||
worker_class: Any
|
||||
backend_config: BackendConfig
|
||||
replica_config: ReplicaConfig
|
||||
|
||||
class Config:
|
||||
# TODO(architkulkarni): Remove once ReplicaConfig is a pydantic
|
||||
# model
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -313,9 +323,10 @@ class ServeController:
|
||||
for router in self.routers.values()
|
||||
])
|
||||
await self.broadcast_backend_config(backend)
|
||||
if info.backend_config.autoscaling_config is not None:
|
||||
metadata = info.backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
self.autoscaling_policies[backend] = BasicAutoscalingPolicy(
|
||||
backend, info.backend_config.autoscaling_config)
|
||||
backend, metadata.autoscaling_config)
|
||||
|
||||
# Push configuration state to the routers.
|
||||
await asyncio.gather(*[
|
||||
@@ -753,11 +764,14 @@ class ServeController:
|
||||
# Save creator that starts replicas, the arguments to be passed in,
|
||||
# and the configuration for the backends.
|
||||
self.backends[backend_tag] = BackendInfo(
|
||||
backend_worker, backend_config, replica_config)
|
||||
if backend_config.autoscaling_config is not None:
|
||||
worker_class=backend_worker,
|
||||
backend_config=backend_config,
|
||||
replica_config=replica_config)
|
||||
metadata = backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
self.autoscaling_policies[
|
||||
backend_tag] = BasicAutoscalingPolicy(
|
||||
backend_tag, backend_config.autoscaling_config)
|
||||
backend_tag, metadata.autoscaling_config)
|
||||
|
||||
try:
|
||||
self._scale_replicas(backend_tag, backend_config.num_replicas)
|
||||
@@ -814,16 +828,25 @@ class ServeController:
|
||||
await self._stop_pending_replicas()
|
||||
await self._remove_pending_backends()
|
||||
|
||||
async def update_backend_config(self, backend_tag: str,
|
||||
config_options: Dict[str, Any]) -> None:
|
||||
async def update_backend_config(
|
||||
self, backend_tag: str,
|
||||
config_options: "Union[BackendConfig, Dict[str, Any]]") -> None:
|
||||
"""Set the config for the specified backend."""
|
||||
async with self.write_lock:
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(config_options, dict)
|
||||
assert isinstance(config_options, BackendConfig) or isinstance(
|
||||
config_options, dict)
|
||||
|
||||
self.backends[backend_tag].backend_config.update(config_options)
|
||||
backend_config = self.backends[backend_tag].backend_config
|
||||
if isinstance(config_options, BackendConfig):
|
||||
update_data = config_options.dict(exclude_unset=True)
|
||||
elif isinstance(config_options, dict):
|
||||
update_data = config_options
|
||||
|
||||
stored_backend_config = self.backends[backend_tag].backend_config
|
||||
backend_config = stored_backend_config.copy(update=update_data)
|
||||
backend_config._validate_complete()
|
||||
self.backends[backend_tag].backend_config = backend_config
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
self._scale_replicas(backend_tag, backend_config.num_replicas)
|
||||
|
||||
@@ -11,6 +11,7 @@ from ray.test_utils import wait_for_condition
|
||||
from ray.serve import constants
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import format_actor_name, get_random_letters
|
||||
from ray.serve.config import BackendConfig
|
||||
|
||||
|
||||
def test_e2e(serve_instance):
|
||||
@@ -150,6 +151,43 @@ def test_scaling_replicas(serve_instance):
|
||||
self.count += 1
|
||||
return self.count
|
||||
|
||||
serve.create_backend(
|
||||
"counter:v1", Counter, config=BackendConfig(num_replicas=2))
|
||||
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()
|
||||
counter_result.append(resp)
|
||||
|
||||
# If the load is shared among two replicas, the max result cannot be 10.
|
||||
assert max(counter_result) < 10
|
||||
|
||||
serve.update_backend_config("counter:v1", {"num_replicas": 1})
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()
|
||||
counter_result.append(resp)
|
||||
# Give some time for a replica to spin down. But majority of the request
|
||||
# should be served by the only remaining replica.
|
||||
assert max(counter_result) - min(counter_result) > 6
|
||||
|
||||
|
||||
def test_scaling_replicas_legacy(serve_instance):
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def __call__(self, _):
|
||||
self.count += 1
|
||||
return self.count
|
||||
|
||||
serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
|
||||
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
|
||||
@@ -188,6 +226,43 @@ def test_batching(serve_instance):
|
||||
batch_size = serve.context.batch_size
|
||||
return [self.count] * batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
"counter:v11",
|
||||
BatchingExample,
|
||||
config=BackendConfig(max_batch_size=5, batch_wait_timeout=1))
|
||||
serve.create_endpoint(
|
||||
"counter1", backend="counter:v11", route="/increment2")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment2" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
future_list = []
|
||||
handle = serve.get_handle("counter1")
|
||||
for _ in range(20):
|
||||
f = handle.remote(temp=1)
|
||||
future_list.append(f)
|
||||
|
||||
counter_result = ray.get(future_list)
|
||||
# since count is only updated per batch of queries
|
||||
# If there is at least one __call__ fn call with batch size greater than 1
|
||||
# counter result will always be less than 20
|
||||
assert max(counter_result) < 20
|
||||
|
||||
|
||||
def test_batching_legacy(serve_instance):
|
||||
class BatchingExample:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, temp=None):
|
||||
self.count += 1
|
||||
batch_size = serve.context.batch_size
|
||||
return [self.count] * batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
"counter:v11",
|
||||
@@ -227,6 +302,27 @@ def test_batching_exception(serve_instance):
|
||||
batch_size = serve.context.batch_size
|
||||
return batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
"exception:v1", NoListReturned, config=BackendConfig(max_batch_size=5))
|
||||
serve.create_endpoint(
|
||||
"exception-test", backend="exception:v1", route="/noListReturned")
|
||||
|
||||
handle = serve.get_handle("exception-test")
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
assert ray.get(handle.remote(temp=1))
|
||||
|
||||
|
||||
def test_batching_exception_legacy(serve_instance):
|
||||
class NoListReturned:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, temp=None):
|
||||
batch_size = serve.context.batch_size
|
||||
return batch_size
|
||||
|
||||
# set the max batch size
|
||||
serve.create_backend(
|
||||
"exception:v1", NoListReturned, config={"max_batch_size": 5})
|
||||
@@ -248,6 +344,40 @@ def test_updating_config(serve_instance):
|
||||
batch_size = serve.context.batch_size
|
||||
return [1] * batch_size
|
||||
|
||||
serve.create_backend(
|
||||
"bsimple:v1",
|
||||
BatchSimple,
|
||||
config=BackendConfig(max_batch_size=2, num_replicas=3))
|
||||
serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
|
||||
|
||||
controller = serve.api._get_controller()
|
||||
old_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
|
||||
serve.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
|
||||
new_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
new_all_tag_list = []
|
||||
for worker_dict in ray.get(
|
||||
controller.get_all_worker_handles.remote()).values():
|
||||
new_all_tag_list.extend(list(worker_dict.keys()))
|
||||
|
||||
# the old and new replica tag list should be identical
|
||||
# and should be subset of all_tag_list
|
||||
assert set(old_replica_tag_list) <= set(new_all_tag_list)
|
||||
assert set(old_replica_tag_list) == set(new_replica_tag_list)
|
||||
|
||||
|
||||
def test_updating_config_legacy(serve_instance):
|
||||
class BatchSimple:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, temp=None):
|
||||
batch_size = serve.context.batch_size
|
||||
return [1] * batch_size
|
||||
|
||||
serve.create_backend(
|
||||
"bsimple:v1",
|
||||
BatchSimple,
|
||||
@@ -450,6 +580,43 @@ def test_parallel_start(serve_instance):
|
||||
|
||||
barrier = Barrier.remote(release_on=2)
|
||||
|
||||
class LongStartingServable:
|
||||
def __init__(self):
|
||||
ray.get(barrier.wait.remote(), timeout=10)
|
||||
|
||||
def __call__(self, _):
|
||||
return "Ready"
|
||||
|
||||
serve.create_backend(
|
||||
"p:v0", LongStartingServable, config=BackendConfig(num_replicas=2))
|
||||
serve.create_endpoint("test-parallel", backend="p:v0")
|
||||
handle = serve.get_handle("test-parallel")
|
||||
|
||||
ray.get(handle.remote(), timeout=10)
|
||||
|
||||
|
||||
def test_parallel_start_legacy(serve_instance):
|
||||
# Test the ability to start multiple replicas in parallel.
|
||||
# In the past, when Serve scale up a backend, it does so one by one and
|
||||
# wait for each replica to initialize. This test avoid this by preventing
|
||||
# the first replica to finish initialization unless the second replica is
|
||||
# also started.
|
||||
@ray.remote
|
||||
class Barrier:
|
||||
def __init__(self, release_on):
|
||||
self.release_on = release_on
|
||||
self.current_waiters = 0
|
||||
self.event = asyncio.Event()
|
||||
|
||||
async def wait(self):
|
||||
self.current_waiters += 1
|
||||
if self.current_waiters == self.release_on:
|
||||
self.event.set()
|
||||
else:
|
||||
await self.event.wait()
|
||||
|
||||
barrier = Barrier.remote(release_on=2)
|
||||
|
||||
class LongStartingServable:
|
||||
def __init__(self):
|
||||
ray.get(barrier.wait.remote(), timeout=10)
|
||||
@@ -512,6 +679,33 @@ def test_list_endpoints(serve_instance):
|
||||
def test_list_backends(serve_instance):
|
||||
serve.init()
|
||||
|
||||
@serve.accept_batch
|
||||
def f():
|
||||
pass
|
||||
|
||||
serve.create_backend("backend", f, config=BackendConfig(max_batch_size=10))
|
||||
backends = serve.list_backends()
|
||||
assert len(backends) == 1
|
||||
assert "backend" in backends
|
||||
assert backends["backend"]["max_batch_size"] == 10
|
||||
|
||||
serve.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
|
||||
backends = serve.list_backends()
|
||||
assert len(backends) == 2
|
||||
assert backends["backend2"]["num_replicas"] == 10
|
||||
|
||||
serve.delete_backend("backend")
|
||||
backends = serve.list_backends()
|
||||
assert len(backends) == 1
|
||||
assert "backend2" in backends
|
||||
|
||||
serve.delete_backend("backend2")
|
||||
assert len(serve.list_backends()) == 0
|
||||
|
||||
|
||||
def test_list_backends_legacy(serve_instance):
|
||||
serve.init()
|
||||
|
||||
@serve.accept_batch
|
||||
def f():
|
||||
pass
|
||||
@@ -555,6 +749,37 @@ def test_endpoint_input_validation(serve_instance):
|
||||
def test_create_infeasible_error(serve_instance):
|
||||
serve.init()
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
# Non existent resource should be infeasible.
|
||||
with pytest.raises(RayServeException, match="Cannot scale backend"):
|
||||
serve.create_backend(
|
||||
"f:1",
|
||||
f,
|
||||
ray_actor_options={"resources": {
|
||||
"MagicMLResource": 100
|
||||
}})
|
||||
|
||||
# Even if each replica is feasible, the total might not be.
|
||||
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
|
||||
with pytest.raises(RayServeException, match="Cannot scale backend"):
|
||||
serve.create_backend(
|
||||
"f:1",
|
||||
f,
|
||||
ray_actor_options={"resources": {
|
||||
"CPU": 1,
|
||||
}},
|
||||
config=BackendConfig(num_replicas=(current_cpus + 20)))
|
||||
|
||||
# No replica should be created!
|
||||
replicas = ray.get(serve.api.controller._list_replicas.remote("f1"))
|
||||
assert len(replicas) == 0
|
||||
|
||||
|
||||
def test_create_infeasible_error_legacy(serve_instance):
|
||||
serve.init()
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
|
||||
from ray.serve.controller import TrafficPolicy
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.config import BackendConfig, BackendMetadata
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
@@ -19,7 +19,7 @@ pytestmark = pytest.mark.asyncio
|
||||
def setup_worker(name,
|
||||
func_or_class,
|
||||
init_args=None,
|
||||
backend_config=BackendConfig({})):
|
||||
backend_config=BackendConfig()):
|
||||
if init_args is None:
|
||||
init_args = ()
|
||||
|
||||
@@ -178,10 +178,9 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
backend_config = BackendConfig(
|
||||
{
|
||||
"max_batch_size": 4,
|
||||
"batch_wait_timeout": 2
|
||||
}, accepts_batches=True)
|
||||
max_batch_size=4,
|
||||
batch_wait_timeout=2,
|
||||
internal_metadata=BackendMetadata(accepts_batches=True))
|
||||
worker = setup_worker(
|
||||
CONSUMER_NAME, Batcher, backend_config=backend_config)
|
||||
|
||||
@@ -230,10 +229,9 @@ async def test_task_runner_perform_batch(serve_instance):
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
config = BackendConfig(
|
||||
{
|
||||
"max_batch_size": 2,
|
||||
"batch_wait_timeout": 10
|
||||
}, accepts_batches=True)
|
||||
max_batch_size=2,
|
||||
batch_wait_timeout=10,
|
||||
internal_metadata=BackendMetadata(accepts_batches=True))
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
@@ -277,7 +275,9 @@ async def test_task_runner_perform_async(serve_instance):
|
||||
CONSUMER_NAME = "runner"
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)
|
||||
config = BackendConfig(
|
||||
max_concurrent_queries=10,
|
||||
internal_metadata=BackendMetadata(is_blocking=False))
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
@@ -1,67 +1,75 @@
|
||||
import pytest
|
||||
|
||||
from ray import serve
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def test_backend_config_validation():
|
||||
# Test unknown key.
|
||||
with pytest.raises(ValueError, match="unknown_key"):
|
||||
BackendConfig({"unknown_key": -1})
|
||||
|
||||
# Test that the input dict isn't modified.
|
||||
config = {"num_replicas": 2}
|
||||
BackendConfig(config)
|
||||
assert len(config) == 1 and config["num_replicas"] == 2
|
||||
with pytest.raises(ValidationError):
|
||||
BackendConfig(unknown_key=-1)
|
||||
|
||||
# Test num_replicas validation.
|
||||
BackendConfig({"num_replicas": 1})
|
||||
with pytest.raises(TypeError):
|
||||
BackendConfig({"num_replicas": "hello"})
|
||||
with pytest.raises(ValueError):
|
||||
BackendConfig({"num_replicas": -1})
|
||||
BackendConfig(num_replicas=1)
|
||||
with pytest.raises(ValidationError, match="type_error"):
|
||||
BackendConfig(num_replicas="hello")
|
||||
with pytest.raises(ValidationError, match="value_error"):
|
||||
BackendConfig(num_replicas=-1)
|
||||
|
||||
# Test max_batch_size validation.
|
||||
BackendConfig({"max_batch_size": 10}, accepts_batches=True)
|
||||
BackendConfig(
|
||||
max_batch_size=10,
|
||||
internal_metadata=BackendMetadata(
|
||||
accepts_batches=True))._validate_complete()
|
||||
with pytest.raises(ValueError):
|
||||
BackendConfig({"max_batch_size": 10}, accepts_batches=False)
|
||||
with pytest.raises(TypeError):
|
||||
BackendConfig({"max_batch_size": 1.0})
|
||||
with pytest.raises(TypeError):
|
||||
BackendConfig({"max_batch_size": "hello"})
|
||||
with pytest.raises(ValueError):
|
||||
BackendConfig({"max_batch_size": 0})
|
||||
with pytest.raises(ValueError):
|
||||
BackendConfig({"max_batch_size": -1})
|
||||
BackendConfig(
|
||||
max_batch_size=10,
|
||||
internal_metadata=BackendMetadata(
|
||||
accepts_batches=False))._validate_complete()
|
||||
with pytest.raises(ValidationError, match="type_error"):
|
||||
BackendConfig(max_batch_size="hello")
|
||||
with pytest.raises(ValidationError, match="value_error"):
|
||||
BackendConfig(max_batch_size=0)
|
||||
with pytest.raises(ValidationError, match="value_error"):
|
||||
BackendConfig(max_batch_size=-1)
|
||||
|
||||
# Test dynamic default for max_concurrent_queries.
|
||||
assert BackendConfig().max_concurrent_queries == 8
|
||||
assert BackendConfig(max_batch_size=7).max_concurrent_queries == 14
|
||||
assert BackendConfig(
|
||||
max_batch_size=10,
|
||||
internal_metadata=BackendMetadata(
|
||||
is_blocking=False)).max_concurrent_queries == ASYNC_CONCURRENCY
|
||||
assert BackendConfig(
|
||||
max_batch_size=7, batch_wait_timeout=1.0).max_concurrent_queries == 14
|
||||
|
||||
|
||||
def test_backend_config_update():
|
||||
b = BackendConfig({"num_replicas": 1, "max_batch_size": 1})
|
||||
b = BackendConfig(num_replicas=1, max_batch_size=1)
|
||||
|
||||
# Test updating a key works.
|
||||
b.update({"num_replicas": 2})
|
||||
b.num_replicas = 2
|
||||
assert b.num_replicas == 2
|
||||
# Check that not specifying a key doesn't update it.
|
||||
assert b.max_batch_size == 1
|
||||
|
||||
# Check that passing an invalid key fails.
|
||||
with pytest.raises(ValueError):
|
||||
b.update({"unknown": 1})
|
||||
|
||||
# Check that input is validated.
|
||||
with pytest.raises(TypeError):
|
||||
b.update({"num_replicas": "hello"})
|
||||
with pytest.raises(ValueError):
|
||||
b.update({"num_replicas": -1})
|
||||
with pytest.raises(ValidationError):
|
||||
b.num_replicas = "Hello"
|
||||
with pytest.raises(ValidationError):
|
||||
b.num_replicas = -1
|
||||
|
||||
# Test batch validation.
|
||||
b = BackendConfig({}, accepts_batches=False)
|
||||
b.update({"max_batch_size": 1})
|
||||
b = BackendConfig(internal_metadata=BackendMetadata(accepts_batches=False))
|
||||
b.max_batch_size = 1
|
||||
with pytest.raises(ValueError):
|
||||
b.update({"max_batch_size": 2})
|
||||
b.max_batch_size = 2
|
||||
b._validate_complete()
|
||||
|
||||
b = BackendConfig({}, accepts_batches=True)
|
||||
b.update({"max_batch_size": 2})
|
||||
b = BackendConfig(internal_metadata=BackendMetadata(accepts_batches=True))
|
||||
b.max_batch_size = 2
|
||||
|
||||
|
||||
def test_replica_config_validation():
|
||||
|
||||
@@ -183,7 +183,8 @@ def test_worker_replica_failure(serve_instance):
|
||||
temp_path = os.path.join(tempfile.gettempdir(),
|
||||
serve.utils.get_random_letters())
|
||||
serve.create_backend("replica_failure", Worker, temp_path)
|
||||
serve.update_backend_config("replica_failure", {"num_replicas": 2})
|
||||
serve.update_backend_config(
|
||||
"replica_failure", BackendConfig(num_replicas=2))
|
||||
serve.create_endpoint(
|
||||
"replica_failure", backend="replica_failure", route="/replica_failure")
|
||||
|
||||
@@ -219,7 +220,7 @@ def test_create_backend_idempotent(serve_instance):
|
||||
controller = serve.api._get_controller()
|
||||
|
||||
replica_config = ReplicaConfig(f)
|
||||
backend_config = BackendConfig({"num_replicas": 1})
|
||||
backend_config = BackendConfig(num_replicas=1)
|
||||
|
||||
for i in range(10):
|
||||
ray.get(
|
||||
|
||||
@@ -181,7 +181,7 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
await q.setup.remote("")
|
||||
backend_name = "max-concurrent-test"
|
||||
config = BackendConfig({"max_concurrent_queries": 1})
|
||||
config = BackendConfig(max_concurrent_queries=1)
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
|
||||
await q.add_new_worker.remote(backend_name, "replica-tag", worker)
|
||||
await q.set_backend_config.remote(backend_name, config)
|
||||
|
||||
@@ -37,6 +37,7 @@ scipy==1.4.1
|
||||
tabulate
|
||||
tensorboardX
|
||||
uvicorn
|
||||
pydantic
|
||||
dataclasses
|
||||
|
||||
# Requirements for running tests
|
||||
|
||||
+1
-1
@@ -110,7 +110,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on":
|
||||
# in this directory
|
||||
extras = {
|
||||
"debug": [],
|
||||
"serve": ["uvicorn", "flask", "requests", "dataclasses"],
|
||||
"serve": ["uvicorn", "flask", "requests", "pydantic", "dataclasses"],
|
||||
"tune": ["tabulate", "tensorboardX", "pandas"]
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user