[Serve] Reimplement BackendConfig as pydantic model (#10389)

This commit is contained in:
architkulkarni
2020-09-03 17:16:17 -07:00
committed by GitHub
parent 43a7a64b30
commit 0d93e92720
14 changed files with 439 additions and 171 deletions
+5 -15
View File
@@ -3,20 +3,10 @@ from ray.serve.api import (init, create_backend, delete_backend,
shadow_traffic, get_handle, update_backend_config,
get_backend_config, accept_batch, list_backends,
list_endpoints, shutdown) # noqa: E402
from ray.serve.config import BackendConfig
__all__ = [
"init",
"create_backend",
"delete_backend",
"create_endpoint",
"delete_endpoint",
"set_traffic",
"shadow_traffic",
"get_handle",
"update_backend_config",
"get_backend_config",
"accept_batch",
"list_backends",
"list_endpoints",
"shutdown",
"init", "create_backend", "delete_backend", "create_endpoint",
"delete_endpoint", "set_traffic", "shadow_traffic", "get_handle",
"update_backend_config", "get_backend_config", "accept_batch",
"list_backends", "list_endpoints", "shutdown", "BackendConfig"
]
+33 -21
View File
@@ -7,7 +7,7 @@ from ray.serve.controller import ServeController
from ray.serve.handle import RayServeHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.actor import ActorHandle
from typing import Any, Callable, Dict, List, Optional, Type, Union
@@ -201,16 +201,18 @@ def list_endpoints() -> Dict[str, Dict[str, Any]]:
@_ensure_connected
def update_backend_config(backend_tag: str,
config_options: Dict[str, Any]) -> None:
def update_backend_config(
backend_tag: str,
config_options: Union[BackendConfig, Dict[str, Any]]) -> None:
"""Update a backend configuration for a backend tag.
Keys not specified in the passed config will be left unchanged.
Args:
backend_tag(str): A registered backend.
config_options(dict): Backend config options to update.
Supported options:
config_options(dict, serve.BackendConfig): Backend config options to
update. Either a BackendConfig object or a dict mapping strings to
values for the following supported options:
- "num_replicas": number of worker processes to start up that
will handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
@@ -222,14 +224,16 @@ def update_backend_config(backend_tag: str,
that will be sent to a replica of this backend
without receiving a response.
"""
if not isinstance(config_options, dict):
raise ValueError("config_options must be a dictionary.")
if not isinstance(config_options, (BackendConfig, dict)):
raise TypeError(
"config_options must be a BackendConfig or dictionary.")
ray.get(
controller.update_backend_config.remote(backend_tag, config_options))
@_ensure_connected
def get_backend_config(backend_tag: str):
def get_backend_config(backend_tag: str) -> BackendConfig:
"""Get the backend configuration for a backend tag.
Args:
@@ -239,11 +243,12 @@ def get_backend_config(backend_tag: str):
@_ensure_connected
def create_backend(backend_tag: str,
func_or_class: Union[Callable, Type[Callable]],
*actor_init_args: Any,
ray_actor_options: Optional[Dict] = None,
config: Optional[Dict[str, Any]] = None) -> None:
def create_backend(
backend_tag: str,
func_or_class: Union[Callable, Type[Callable]],
*actor_init_args: Any,
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None) -> None:
"""Create a backend with the provided tag.
The backend will serve requests with func_or_class.
@@ -256,8 +261,9 @@ def create_backend(backend_tag: str,
initialization method.
ray_actor_options (optional): options to be passed into the
@ray.remote decorator for the backend actor.
config (optional): configuration options for this backend.
Supported options:
config (dict, serve.BackendConfig, optional): configuration options
for this backend. Either a BackendConfig, or a dictionary mapping
strings to values for the following supported options:
- "num_replicas": number of worker processes to start up that will
handle requests to this backend.
- "max_batch_size": the maximum number of requests that will
@@ -276,14 +282,20 @@ def create_backend(backend_tag: str,
if config is None:
config = {}
if not isinstance(config, dict):
raise TypeError("config must be a dictionary.")
replica_config = ReplicaConfig(
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
backend_config = BackendConfig(config, replica_config.accepts_batches,
replica_config.is_blocking)
metadata = BackendMetadata(
accepts_batches=replica_config.accepts_batches,
is_blocking=replica_config.is_blocking)
if isinstance(config, dict):
backend_config = BackendConfig.parse_obj({
**config, "internal_metadata": metadata
})
elif isinstance(config, BackendConfig):
backend_config = config.copy(update={"internal_metadata": metadata})
else:
raise TypeError("config must be a BackendConfig or a dictionary.")
backend_config._validate_complete()
ray.get(
controller.create_backend.remote(backend_tag, backend_config,
replica_config))
+2 -2
View File
@@ -318,7 +318,7 @@ class RayServeWorker:
all_evaluated_futures = []
if not self.config.accepts_batches:
if not self.config.internal_metadata.accepts_batches:
query = batch[0]
evaluated = asyncio.ensure_future(self.invoke_single(query))
all_evaluated_futures = [evaluated]
@@ -336,7 +336,7 @@ class RayServeWorker:
chain_future(
unpack_future(evaluated, len(group)), result_futures)
if self.config.is_blocking:
if self.config.internal_metadata.is_blocking:
# We use asyncio.wait here so if the result is exception,
# it will not be raised.
await asyncio.wait(all_evaluated_futures)
+72 -66
View File
@@ -1,6 +1,9 @@
import inspect
from pydantic import BaseModel, PositiveInt, validator
from ray.serve.constants import ASYNC_CONCURRENCY
from typing import Optional, Dict, Any
from dataclasses import dataclass
def _callable_accepts_batch(func_or_class):
@@ -17,86 +20,89 @@ def _callable_is_blocking(func_or_class):
return not inspect.iscoroutinefunction(func_or_class.__call__)
class BackendConfig:
def __init__(self, config_dict, accepts_batches=False, is_blocking=True):
assert isinstance(config_dict, dict)
# Make a copy so that we don't modify the input dict.
config_dict = config_dict.copy()
@dataclass
class BackendMetadata:
    """Internal, non-user-facing facts about a backend.

    Stored on BackendConfig as ``internal_metadata``. create_backend fills
    in accepts_batches and is_blocking from the backend's ReplicaConfig.
    """
    # Whether the backend callable accepts a batch of requests. When False,
    # BackendConfig rejects max_batch_size > 1 and the worker invokes the
    # callable on a single query.
    accepts_batches: bool = False
    # Whether the backend callable is blocking (not a coroutine). Used to
    # pick the dynamic default for max_concurrent_queries.
    is_blocking: bool = True
    # When not None, the controller builds a BasicAutoscalingPolicy from it.
    # NOTE(review): the expected keys of this dict are not visible here.
    autoscaling_config: Optional[Dict[str, Any]] = None
self.accepts_batches = accepts_batches
self.is_blocking = is_blocking
self.num_replicas = config_dict.pop("num_replicas", 1)
self.max_batch_size = config_dict.pop("max_batch_size", None)
self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
None)
self.autoscaling_config = config_dict.pop("autoscaling", None)
if self.max_concurrent_queries is None:
class BackendConfig(BaseModel):
"""Configuration options for a backend, to be set by the user.
:param num_replicas: The number of worker processes to start up that will
handle requests to this backend. Defaults to 1.
:type num_replicas: int, optional
:param max_batch_size: The maximum number of requests that will be
processed in one batch by this backend. Defaults to None (no
maximum).
:type max_batch_size: int, optional
:param batch_wait_timeout: The time in seconds that backend replicas will
wait for a full batch of requests before processing a partial batch.
Defaults to 0.
:type batch_wait_timeout: float, optional
:param max_concurrent_queries: The maximum number of queries that will be
sent to a replica of this backend without receiving a response.
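Defaults to None, in which case a value is chosen automatically based
on whether the backend is blocking, max_batch_size, and
batch_wait_timeout.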
:type max_concurrent_queries: int, optional
"""
internal_metadata: BackendMetadata = BackendMetadata()
num_replicas: PositiveInt = 1
max_batch_size: Optional[PositiveInt] = None
batch_wait_timeout: float = 0
max_concurrent_queries: Optional[int] = None
class Config:
validate_assignment = True
extra = "forbid"
arbitrary_types_allowed = True
def _validate_batch_size(self):
if (self.max_batch_size is not None
and not self.internal_metadata.accepts_batches
and self.max_batch_size > 1):
raise ValueError(
"max_batch_size is set in config but the function or "
"method does not accept batching. Please use "
"@serve.accept_batch to explicitly mark that the function or "
"method accepts a list of requests as an argument.")
# This is not a pydantic validator, so that we may skip this method when
# creating partially filled BackendConfig objects to pass as updates--for
# example, BackendConfig(max_batch_size=5).
def _validate_complete(self):
self._validate_batch_size()
# Dynamic default for max_concurrent_queries
@validator("max_concurrent_queries", always=True)
def set_max_queries_by_mode(cls, v, values):
if v is None:
# Model serving mode: if the servable is blocking and the wait
# timeout is default zero seconds, then we keep the existing
# behavior to allow at most max batch size queries.
if self.is_blocking and self.batch_wait_timeout == 0:
if self.max_batch_size:
self.max_concurrent_queries = 2 * self.max_batch_size
if (values["internal_metadata"].is_blocking
and values["batch_wait_timeout"] == 0):
if ("max_batch_size" in values
and values["max_batch_size"] is not None):
v = 2 * values["max_batch_size"]
else:
self.max_concurrent_queries = 8
v = 8
# Pipeline/async mode: if the servable is not blocking,
# router should just keep pushing queries to the worker
# replicas until a high limit.
if not self.is_blocking:
self.max_concurrent_queries = ASYNC_CONCURRENCY
if not values["internal_metadata"].is_blocking:
v = ASYNC_CONCURRENCY
# Batch inference mode: user specifies non zero timeout to wait for
# full batch. We will use 2*max_batch_size to perform double
# buffering to keep the replica busy.
if self.max_batch_size is not None and self.batch_wait_timeout > 0:
self.max_concurrent_queries = 2 * self.max_batch_size
if len(config_dict) != 0:
raise ValueError("Unknown options in backend config: {}".format(
list(config_dict.keys())))
self._validate()
def update(self, config_dict):
"""Updates this BackendConfig with options set in the passed config.
Unspecified keys will remain the same.
"""
if "num_replicas" in config_dict:
self.num_replicas = config_dict.pop("num_replicas")
if "max_batch_size" in config_dict:
self.max_batch_size = config_dict.pop("max_batch_size")
if "max_concurrent_queries" in config_dict:
self.max_concurrent_queries = config_dict.pop(
"max_concurrent_queries")
if len(config_dict) != 0:
raise ValueError("Unknown options in backend config: {}".format(
list(config_dict.keys())))
self._validate()
def _validate(self):
if not isinstance(self.num_replicas, int):
raise TypeError("num_replicas must be an int.")
elif self.num_replicas < 1:
raise ValueError("num_replicas must be >= 1.")
if self.max_batch_size is not None:
if not isinstance(self.max_batch_size, int):
raise TypeError("max_batch_size must be an integer.")
elif self.max_batch_size < 1:
raise ValueError("max_batch_size must be >= 1.")
if not self.accepts_batches and self.max_batch_size > 1:
raise ValueError(
"max_batch_size is set in config but the function or "
"method does not accept batching. Please use "
"@serve.accept_batch to explicitly mark the function or "
"method as batchable and takes in list as arguments.")
if ("max_batch_size" in values
and values["max_batch_size"] is not None
and values["batch_wait_timeout"] > 0):
v = 2 * values["max_batch_size"]
return v
class ReplicaConfig:
+37 -14
View File
@@ -1,8 +1,10 @@
import asyncio
from collections import defaultdict, namedtuple
from collections import defaultdict
import os
import random
import time
from typing import Union, Dict, Any, List, Tuple
from pydantic import BaseModel
import ray
import ray.cloudpickle as pickle
@@ -16,7 +18,6 @@ from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes, get_all_node_ids)
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.actor import ActorHandle
from typing import Dict, List, Any, Tuple
import numpy as np
@@ -62,8 +63,17 @@ class TrafficPolicy:
self.shadow_dict[backend] = proportion
BackendInfo = namedtuple("BackendInfo",
["worker_class", "backend_config", "replica_config"])
class BackendInfo(BaseModel):
    """Record the controller stores for each registered backend.

    Holds the worker class used to start replicas together with the
    backend's BackendConfig and ReplicaConfig.
    """
    # TODO(architkulkarni): Add type hint for worker_class after upgrading
    # cloudpickle and adding types to RayServeWrappedWorker
    worker_class: Any
    backend_config: BackendConfig
    replica_config: ReplicaConfig

    class Config:
        # ReplicaConfig is a plain class, so pydantic must be told to accept
        # it as a field type.
        # TODO(architkulkarni): Remove once ReplicaConfig is a pydantic
        # model
        arbitrary_types_allowed = True
@ray.remote
@@ -313,9 +323,10 @@ class ServeController:
for router in self.routers.values()
])
await self.broadcast_backend_config(backend)
if info.backend_config.autoscaling_config is not None:
metadata = info.backend_config.internal_metadata
if metadata.autoscaling_config is not None:
self.autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, info.backend_config.autoscaling_config)
backend, metadata.autoscaling_config)
# Push configuration state to the routers.
await asyncio.gather(*[
@@ -753,11 +764,14 @@ class ServeController:
# Save creator that starts replicas, the arguments to be passed in,
# and the configuration for the backends.
self.backends[backend_tag] = BackendInfo(
backend_worker, backend_config, replica_config)
if backend_config.autoscaling_config is not None:
worker_class=backend_worker,
backend_config=backend_config,
replica_config=replica_config)
metadata = backend_config.internal_metadata
if metadata.autoscaling_config is not None:
self.autoscaling_policies[
backend_tag] = BasicAutoscalingPolicy(
backend_tag, backend_config.autoscaling_config)
backend_tag, metadata.autoscaling_config)
try:
self._scale_replicas(backend_tag, backend_config.num_replicas)
@@ -814,16 +828,25 @@ class ServeController:
await self._stop_pending_replicas()
await self._remove_pending_backends()
async def update_backend_config(self, backend_tag: str,
config_options: Dict[str, Any]) -> None:
async def update_backend_config(
self, backend_tag: str,
config_options: "Union[BackendConfig, Dict[str, Any]]") -> None:
"""Set the config for the specified backend."""
async with self.write_lock:
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, dict)
assert isinstance(config_options, BackendConfig) or isinstance(
config_options, dict)
self.backends[backend_tag].backend_config.update(config_options)
backend_config = self.backends[backend_tag].backend_config
if isinstance(config_options, BackendConfig):
update_data = config_options.dict(exclude_unset=True)
elif isinstance(config_options, dict):
update_data = config_options
stored_backend_config = self.backends[backend_tag].backend_config
backend_config = stored_backend_config.copy(update=update_data)
backend_config._validate_complete()
self.backends[backend_tag].backend_config = backend_config
# Scale the replicas with the new configuration.
self._scale_replicas(backend_tag, backend_config.num_replicas)
+225
View File
@@ -11,6 +11,7 @@ from ray.test_utils import wait_for_condition
from ray.serve import constants
from ray.serve.exceptions import RayServeException
from ray.serve.utils import format_actor_name, get_random_letters
from ray.serve.config import BackendConfig
def test_e2e(serve_instance):
@@ -150,6 +151,43 @@ def test_scaling_replicas(serve_instance):
self.count += 1
return self.count
serve.create_backend(
"counter:v1", Counter, config=BackendConfig(num_replicas=2))
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()
counter_result.append(resp)
# If the load is shared among two replicas, the max result cannot be 10.
assert max(counter_result) < 10
serve.update_backend_config("counter:v1", {"num_replicas": 1})
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()
counter_result.append(resp)
# Give some time for a replica to spin down. The majority of the requests
# should still be served by the only remaining replica.
assert max(counter_result) - min(counter_result) > 6
def test_scaling_replicas_legacy(serve_instance):
class Counter:
def __init__(self):
self.count = 0
def __call__(self, _):
self.count += 1
return self.count
serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
serve.create_endpoint("counter", backend="counter:v1", route="/increment")
@@ -188,6 +226,43 @@ def test_batching(serve_instance):
batch_size = serve.context.batch_size
return [self.count] * batch_size
# set the max batch size
serve.create_backend(
"counter:v11",
BatchingExample,
config=BackendConfig(max_batch_size=5, batch_wait_timeout=1))
serve.create_endpoint(
"counter1", backend="counter:v11", route="/increment2")
# Keep checking the routing table until /increment is populated
while "/increment2" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
future_list = []
handle = serve.get_handle("counter1")
for _ in range(20):
f = handle.remote(temp=1)
future_list.append(f)
counter_result = ray.get(future_list)
# Since count is only updated once per batch of queries, if there is at
# least one __call__ invocation with batch size greater than 1, the
# counter result will always be less than 20.
assert max(counter_result) < 20
def test_batching_legacy(serve_instance):
class BatchingExample:
def __init__(self):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
self.count += 1
batch_size = serve.context.batch_size
return [self.count] * batch_size
# set the max batch size
serve.create_backend(
"counter:v11",
@@ -227,6 +302,27 @@ def test_batching_exception(serve_instance):
batch_size = serve.context.batch_size
return batch_size
# set the max batch size
serve.create_backend(
"exception:v1", NoListReturned, config=BackendConfig(max_batch_size=5))
serve.create_endpoint(
"exception-test", backend="exception:v1", route="/noListReturned")
handle = serve.get_handle("exception-test")
with pytest.raises(ray.exceptions.RayTaskError):
assert ray.get(handle.remote(temp=1))
def test_batching_exception_legacy(serve_instance):
class NoListReturned:
def __init__(self):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
batch_size = serve.context.batch_size
return batch_size
# set the max batch size
serve.create_backend(
"exception:v1", NoListReturned, config={"max_batch_size": 5})
@@ -248,6 +344,40 @@ def test_updating_config(serve_instance):
batch_size = serve.context.batch_size
return [1] * batch_size
serve.create_backend(
"bsimple:v1",
BatchSimple,
config=BackendConfig(max_batch_size=2, num_replicas=3))
serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
controller = serve.api._get_controller()
old_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
serve.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5))
new_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
new_all_tag_list = []
for worker_dict in ray.get(
controller.get_all_worker_handles.remote()).values():
new_all_tag_list.extend(list(worker_dict.keys()))
# the old and new replica tag list should be identical
# and should be subset of all_tag_list
assert set(old_replica_tag_list) <= set(new_all_tag_list)
assert set(old_replica_tag_list) == set(new_replica_tag_list)
def test_updating_config_legacy(serve_instance):
class BatchSimple:
def __init__(self):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
batch_size = serve.context.batch_size
return [1] * batch_size
serve.create_backend(
"bsimple:v1",
BatchSimple,
@@ -450,6 +580,43 @@ def test_parallel_start(serve_instance):
barrier = Barrier.remote(release_on=2)
class LongStartingServable:
def __init__(self):
ray.get(barrier.wait.remote(), timeout=10)
def __call__(self, _):
return "Ready"
serve.create_backend(
"p:v0", LongStartingServable, config=BackendConfig(num_replicas=2))
serve.create_endpoint("test-parallel", backend="p:v0")
handle = serve.get_handle("test-parallel")
ray.get(handle.remote(), timeout=10)
def test_parallel_start_legacy(serve_instance):
# Test the ability to start multiple replicas in parallel.
# In the past, when Serve scaled up a backend, it did so one by one and
# waited for each replica to initialize. This test guards against that by
# preventing the first replica from finishing initialization unless the
# second replica is also started.
@ray.remote
class Barrier:
def __init__(self, release_on):
self.release_on = release_on
self.current_waiters = 0
self.event = asyncio.Event()
async def wait(self):
self.current_waiters += 1
if self.current_waiters == self.release_on:
self.event.set()
else:
await self.event.wait()
barrier = Barrier.remote(release_on=2)
class LongStartingServable:
def __init__(self):
ray.get(barrier.wait.remote(), timeout=10)
@@ -512,6 +679,33 @@ def test_list_endpoints(serve_instance):
def test_list_backends(serve_instance):
    """Check that list_backends() reflects backend creation and deletion."""
    serve.init()

    # Marked as batch-accepting so that max_batch_size > 1 passes the
    # BackendConfig batch-size validation in create_backend.
    @serve.accept_batch
    def f():
        pass

    serve.create_backend("backend", f, config=BackendConfig(max_batch_size=10))
    backends = serve.list_backends()
    assert len(backends) == 1
    assert "backend" in backends
    # The listed entry exposes the config values the backend was created with.
    assert backends["backend"]["max_batch_size"] == 10

    serve.create_backend("backend2", f, config=BackendConfig(num_replicas=10))
    backends = serve.list_backends()
    assert len(backends) == 2
    assert backends["backend2"]["num_replicas"] == 10

    # Deleting a backend removes only that backend from the listing.
    serve.delete_backend("backend")
    backends = serve.list_backends()
    assert len(backends) == 1
    assert "backend2" in backends

    serve.delete_backend("backend2")
    assert len(serve.list_backends()) == 0
def test_list_backends_legacy(serve_instance):
serve.init()
@serve.accept_batch
def f():
pass
@@ -555,6 +749,37 @@ def test_endpoint_input_validation(serve_instance):
def test_create_infeasible_error(serve_instance):
    """Check that creating a backend with infeasible resources fails cleanly.

    Both a non-existent custom resource and a replica count whose total CPU
    demand exceeds the cluster must raise RayServeException, and neither
    failed attempt may leave replicas behind.
    """
    serve.init()

    def f():
        pass

    # Non existent resource should be infeasible.
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        serve.create_backend(
            "f:1",
            f,
            ray_actor_options={"resources": {
                "MagicMLResource": 100
            }})

    # Even if each replica is feasible on its own, the total might not be.
    current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
    with pytest.raises(RayServeException, match="Cannot scale backend"):
        serve.create_backend(
            "f:1",
            f,
            ray_actor_options={"resources": {
                "CPU": 1,
            }},
            config=BackendConfig(num_replicas=(current_cpus + 20)))

    # No replica should be created! Query the same tag that was passed to
    # create_backend above ("f:1"); the previous "f1" did not match it, so
    # the assertion never inspected the backend under test.
    replicas = ray.get(serve.api.controller._list_replicas.remote("f:1"))
    assert len(replicas) == 0
def test_create_infeasible_error_legacy(serve_instance):
serve.init()
def f():
pass
+11 -11
View File
@@ -10,7 +10,7 @@ from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.controller import TrafficPolicy
from ray.serve.request_params import RequestMetadata
from ray.serve.router import Router
from ray.serve.config import BackendConfig
from ray.serve.config import BackendConfig, BackendMetadata
from ray.serve.exceptions import RayServeException
pytestmark = pytest.mark.asyncio
@@ -19,7 +19,7 @@ pytestmark = pytest.mark.asyncio
def setup_worker(name,
func_or_class,
init_args=None,
backend_config=BackendConfig({})):
backend_config=BackendConfig()):
if init_args is None:
init_args = ()
@@ -178,10 +178,9 @@ async def test_task_runner_custom_method_batch(serve_instance):
PRODUCER_NAME = "producer"
backend_config = BackendConfig(
{
"max_batch_size": 4,
"batch_wait_timeout": 2
}, accepts_batches=True)
max_batch_size=4,
batch_wait_timeout=2,
internal_metadata=BackendMetadata(accepts_batches=True))
worker = setup_worker(
CONSUMER_NAME, Batcher, backend_config=backend_config)
@@ -230,10 +229,9 @@ async def test_task_runner_perform_batch(serve_instance):
PRODUCER_NAME = "producer"
config = BackendConfig(
{
"max_batch_size": 2,
"batch_wait_timeout": 10
}, accepts_batches=True)
max_batch_size=2,
batch_wait_timeout=10,
internal_metadata=BackendMetadata(accepts_batches=True))
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
@@ -277,7 +275,9 @@ async def test_task_runner_perform_async(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False)
config = BackendConfig(
max_concurrent_queries=10,
internal_metadata=BackendMetadata(is_blocking=False))
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
+46 -38
View File
@@ -1,67 +1,75 @@
import pytest
from ray import serve
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.serve.constants import ASYNC_CONCURRENCY
from pydantic import ValidationError
def test_backend_config_validation():
# Test unknown key.
with pytest.raises(ValueError, match="unknown_key"):
BackendConfig({"unknown_key": -1})
# Test that the input dict isn't modified.
config = {"num_replicas": 2}
BackendConfig(config)
assert len(config) == 1 and config["num_replicas"] == 2
with pytest.raises(ValidationError):
BackendConfig(unknown_key=-1)
# Test num_replicas validation.
BackendConfig({"num_replicas": 1})
with pytest.raises(TypeError):
BackendConfig({"num_replicas": "hello"})
with pytest.raises(ValueError):
BackendConfig({"num_replicas": -1})
BackendConfig(num_replicas=1)
with pytest.raises(ValidationError, match="type_error"):
BackendConfig(num_replicas="hello")
with pytest.raises(ValidationError, match="value_error"):
BackendConfig(num_replicas=-1)
# Test max_batch_size validation.
BackendConfig({"max_batch_size": 10}, accepts_batches=True)
BackendConfig(
max_batch_size=10,
internal_metadata=BackendMetadata(
accepts_batches=True))._validate_complete()
with pytest.raises(ValueError):
BackendConfig({"max_batch_size": 10}, accepts_batches=False)
with pytest.raises(TypeError):
BackendConfig({"max_batch_size": 1.0})
with pytest.raises(TypeError):
BackendConfig({"max_batch_size": "hello"})
with pytest.raises(ValueError):
BackendConfig({"max_batch_size": 0})
with pytest.raises(ValueError):
BackendConfig({"max_batch_size": -1})
BackendConfig(
max_batch_size=10,
internal_metadata=BackendMetadata(
accepts_batches=False))._validate_complete()
with pytest.raises(ValidationError, match="type_error"):
BackendConfig(max_batch_size="hello")
with pytest.raises(ValidationError, match="value_error"):
BackendConfig(max_batch_size=0)
with pytest.raises(ValidationError, match="value_error"):
BackendConfig(max_batch_size=-1)
# Test dynamic default for max_concurrent_queries.
assert BackendConfig().max_concurrent_queries == 8
assert BackendConfig(max_batch_size=7).max_concurrent_queries == 14
assert BackendConfig(
max_batch_size=10,
internal_metadata=BackendMetadata(
is_blocking=False)).max_concurrent_queries == ASYNC_CONCURRENCY
assert BackendConfig(
max_batch_size=7, batch_wait_timeout=1.0).max_concurrent_queries == 14
def test_backend_config_update():
b = BackendConfig({"num_replicas": 1, "max_batch_size": 1})
b = BackendConfig(num_replicas=1, max_batch_size=1)
# Test updating a key works.
b.update({"num_replicas": 2})
b.num_replicas = 2
assert b.num_replicas == 2
# Check that not specifying a key doesn't update it.
assert b.max_batch_size == 1
# Check that passing an invalid key fails.
with pytest.raises(ValueError):
b.update({"unknown": 1})
# Check that input is validated.
with pytest.raises(TypeError):
b.update({"num_replicas": "hello"})
with pytest.raises(ValueError):
b.update({"num_replicas": -1})
with pytest.raises(ValidationError):
b.num_replicas = "Hello"
with pytest.raises(ValidationError):
b.num_replicas = -1
# Test batch validation.
b = BackendConfig({}, accepts_batches=False)
b.update({"max_batch_size": 1})
b = BackendConfig(internal_metadata=BackendMetadata(accepts_batches=False))
b.max_batch_size = 1
with pytest.raises(ValueError):
b.update({"max_batch_size": 2})
b.max_batch_size = 2
b._validate_complete()
b = BackendConfig({}, accepts_batches=True)
b.update({"max_batch_size": 2})
b = BackendConfig(internal_metadata=BackendMetadata(accepts_batches=True))
b.max_batch_size = 2
def test_replica_config_validation():
+3 -2
View File
@@ -183,7 +183,8 @@ def test_worker_replica_failure(serve_instance):
temp_path = os.path.join(tempfile.gettempdir(),
serve.utils.get_random_letters())
serve.create_backend("replica_failure", Worker, temp_path)
serve.update_backend_config("replica_failure", {"num_replicas": 2})
serve.update_backend_config(
"replica_failure", BackendConfig(num_replicas=2))
serve.create_endpoint(
"replica_failure", backend="replica_failure", route="/replica_failure")
@@ -219,7 +220,7 @@ def test_create_backend_idempotent(serve_instance):
controller = serve.api._get_controller()
replica_config = ReplicaConfig(f)
backend_config = BackendConfig({"num_replicas": 1})
backend_config = BackendConfig(num_replicas=1)
for i in range(10):
ray.get(
+1 -1
View File
@@ -181,7 +181,7 @@ async def test_router_use_max_concurrency(serve_instance):
q = ray.remote(VisibleRouter).remote()
await q.setup.remote("")
backend_name = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
config = BackendConfig(max_concurrent_queries=1)
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
await q.add_new_worker.remote(backend_name, "replica-tag", worker)
await q.set_backend_config.remote(backend_name, config)
+1
View File
@@ -37,6 +37,7 @@ scipy==1.4.1
tabulate
tensorboardX
uvicorn
pydantic
dataclasses
# Requirements for running tests
+1 -1
View File
@@ -110,7 +110,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on":
# in this directory
extras = {
"debug": [],
"serve": ["uvicorn", "flask", "requests", "dataclasses"],
"serve": ["uvicorn", "flask", "requests", "pydantic", "dataclasses"],
"tune": ["tabulate", "tensorboardX", "pandas"]
}