diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 18d73244f..83e0979e5 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -11,6 +11,7 @@ opencv-python-headless==4.3.0.36 pandas pickle5 pillow +pydantic pygments pyyaml recommonmark diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 719e9ef20..a397d6d4b 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -5,6 +5,7 @@ Basic APIs ---------- .. autofunction:: ray.serve.init .. autofunction:: ray.serve.shutdown +.. autoclass:: ray.serve.BackendConfig .. autofunction:: ray.serve.create_backend .. autofunction:: ray.serve.create_endpoint diff --git a/python/ray/serve/__init__.py b/python/ray/serve/__init__.py index f8b86c20a..41fa0e786 100644 --- a/python/ray/serve/__init__.py +++ b/python/ray/serve/__init__.py @@ -3,20 +3,10 @@ from ray.serve.api import (init, create_backend, delete_backend, shadow_traffic, get_handle, update_backend_config, get_backend_config, accept_batch, list_backends, list_endpoints, shutdown) # noqa: E402 - +from ray.serve.config import BackendConfig __all__ = [ - "init", - "create_backend", - "delete_backend", - "create_endpoint", - "delete_endpoint", - "set_traffic", - "shadow_traffic", - "get_handle", - "update_backend_config", - "get_backend_config", - "accept_batch", - "list_backends", - "list_endpoints", - "shutdown", + "init", "create_backend", "delete_backend", "create_endpoint", + "delete_endpoint", "set_traffic", "shadow_traffic", "get_handle", + "update_backend_config", "get_backend_config", "accept_batch", + "list_backends", "list_endpoints", "shutdown", "BackendConfig" ] diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index f34623c6d..0dc4d0f06 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -7,7 +7,7 @@ from ray.serve.controller import ServeController from ray.serve.handle import RayServeHandle from ray.serve.utils import (block_until_http_ready, format_actor_name) from ray.serve.exceptions import RayServeException -from ray.serve.config import BackendConfig, ReplicaConfig +from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata from ray.actor import ActorHandle from typing import Any, Callable, Dict, List, Optional, Type, Union @@ -201,16 +201,18 @@ def list_endpoints() -> Dict[str, Dict[str, Any]]: @_ensure_connected -def update_backend_config(backend_tag: str, - config_options: Dict[str, Any]) -> None: +def update_backend_config( + backend_tag: str, + config_options: Union[BackendConfig, Dict[str, Any]]) -> None: """Update a backend configuration for a backend tag. Keys not specified in the passed will be left unchanged. Args: backend_tag(str): A registered backend. - config_options(dict): Backend config options to update. - Supported options: + config_options(dict, serve.BackendConfig): Backend config options to + update. Either a BackendConfig object or a dict mapping strings to + values for the following supported options: - "num_replicas": number of worker processes to start up that will handle requests to this backend. - "max_batch_size": the maximum number of requests that will @@ -222,14 +224,16 @@ def update_backend_config(backend_tag: str, that will be sent to a replica of this backend without receiving a response. """ - if not isinstance(config_options, dict): - raise ValueError("config_options must be a dictionary.") + + if not isinstance(config_options, (BackendConfig, dict)): + raise TypeError( + "config_options must be a BackendConfig or dictionary.") ray.get( controller.update_backend_config.remote(backend_tag, config_options)) @_ensure_connected -def get_backend_config(backend_tag: str): +def get_backend_config(backend_tag: str) -> BackendConfig: """Get the backend configuration for a backend tag. Args: @@ -239,11 +243,12 @@ def get_backend_config(backend_tag: str): @_ensure_connected -def create_backend(backend_tag: str, - func_or_class: Union[Callable, Type[Callable]], - *actor_init_args: Any, - ray_actor_options: Optional[Dict] = None, - config: Optional[Dict[str, Any]] = None) -> None: +def create_backend( + backend_tag: str, + func_or_class: Union[Callable, Type[Callable]], + *actor_init_args: Any, + ray_actor_options: Optional[Dict] = None, + config: Optional[Union[BackendConfig, Dict[str, Any]]] = None) -> None: """Create a backend with the provided tag. The backend will serve requests with func_or_class. @@ -256,8 +261,9 @@ def create_backend(backend_tag: str, initialization method. ray_actor_options (optional): options to be passed into the @ray.remote decorator for the backend actor. - config (optional): configuration options for this backend. - Supported options: + config (dict, serve.BackendConfig, optional): configuration options + for this backend. Either a BackendConfig, or a dictionary mapping + strings to values for the following supported options: - "num_replicas": number of worker processes to start up that will handle requests to this backend. - "max_batch_size": the maximum number of requests that will @@ -276,14 +282,20 @@ def create_backend(backend_tag: str, if config is None: config = {} - if not isinstance(config, dict): - raise TypeError("config must be a dictionary.") - replica_config = ReplicaConfig( func_or_class, *actor_init_args, ray_actor_options=ray_actor_options) - backend_config = BackendConfig(config, replica_config.accepts_batches, - replica_config.is_blocking) - + metadata = BackendMetadata( + accepts_batches=replica_config.accepts_batches, + is_blocking=replica_config.is_blocking) + if isinstance(config, dict): + backend_config = BackendConfig.parse_obj({ + **config, "internal_metadata": metadata + }) + elif isinstance(config, BackendConfig): + backend_config = config.copy(update={"internal_metadata": metadata}) + else: + raise TypeError("config must be a BackendConfig or a dictionary.") + backend_config._validate_complete() ray.get( controller.create_backend.remote(backend_tag, backend_config, replica_config)) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index e5e9bba2e..bcf2bfe36 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -318,7 +318,7 @@ class RayServeWorker: all_evaluated_futures = [] - if not self.config.accepts_batches: + if not self.config.internal_metadata.accepts_batches: query = batch[0] evaluated = asyncio.ensure_future(self.invoke_single(query)) all_evaluated_futures = [evaluated] @@ -336,7 +336,7 @@ class RayServeWorker: chain_future( unpack_future(evaluated, len(group)), result_futures) - if self.config.is_blocking: + if self.config.internal_metadata.is_blocking: # We use asyncio.wait here so if the result is exception, # it will not be raised. await asyncio.wait(all_evaluated_futures) diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 4fbe0dccd..216e26848 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,6 +1,9 @@ import inspect +from pydantic import BaseModel, PositiveInt, validator from ray.serve.constants import ASYNC_CONCURRENCY +from typing import Optional, Dict, Any +from dataclasses import dataclass def _callable_accepts_batch(func_or_class): @@ -17,86 +20,89 @@ def _callable_is_blocking(func_or_class): return not inspect.iscoroutinefunction(func_or_class.__call__) -class BackendConfig: - def __init__(self, config_dict, accepts_batches=False, is_blocking=True): - assert isinstance(config_dict, dict) - # Make a copy so that we don't modify the input dict. - config_dict = config_dict.copy() +@dataclass +class BackendMetadata: + accepts_batches: bool = False + is_blocking: bool = True + autoscaling_config: Optional[Dict[str, Any]] = None - self.accepts_batches = accepts_batches - self.is_blocking = is_blocking - self.num_replicas = config_dict.pop("num_replicas", 1) - self.max_batch_size = config_dict.pop("max_batch_size", None) - self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0) - self.max_concurrent_queries = config_dict.pop("max_concurrent_queries", - None) - self.autoscaling_config = config_dict.pop("autoscaling", None) - if self.max_concurrent_queries is None: +class BackendConfig(BaseModel): + """Configuration options for a backend, to be set by the user. + + :param num_replicas: The number of worker processes to start up that will + handle requests to this backend. Defaults to 0. + :type num_replicas: int, optional + :param max_batch_size: The maximum number of requests that will be + processed in one batch by this backend. Defaults to None (no + maximium). + :type max_batch_size: int, optional + :param batch_wait_timeout: The time in seconds that backend replicas will + wait for a full batch of requests before processing a partial batch. + Defaults to 0. + :type batch_wait_timeout: float, optional + :param max_concurrent_queries: The maximum number of queries that will be + sent to a replica of this backend without receiving a response. + Defaults to None (no maximum). + :type max_concurrent_queries: int, optional + """ + + internal_metadata: BackendMetadata = BackendMetadata() + num_replicas: PositiveInt = 1 + max_batch_size: Optional[PositiveInt] = None + batch_wait_timeout: float = 0 + max_concurrent_queries: Optional[int] = None + + class Config: + validate_assignment = True + extra = "forbid" + arbitrary_types_allowed = True + + def _validate_batch_size(self): + if (self.max_batch_size is not None + and not self.internal_metadata.accepts_batches + and self.max_batch_size > 1): + raise ValueError( + "max_batch_size is set in config but the function or " + "method does not accept batching. Please use " + "@serve.accept_batch to explicitly mark that the function or " + "method accepts a list of requests as an argument.") + + # This is not a pydantic validator, so that we may skip this method when + # creating partially filled BackendConfig objects to pass as updates--for + # example, BackendConfig(max_batch_size=5). + def _validate_complete(self): + self._validate_batch_size() + + # Dynamic default for max_concurrent_queries + @validator("max_concurrent_queries", always=True) + def set_max_queries_by_mode(cls, v, values): + if v is None: # Model serving mode: if the servable is blocking and the wait # timeout is default zero seconds, then we keep the existing # behavior to allow at most max batch size queries. - if self.is_blocking and self.batch_wait_timeout == 0: - if self.max_batch_size: - self.max_concurrent_queries = 2 * self.max_batch_size + if (values["internal_metadata"].is_blocking + and values["batch_wait_timeout"] == 0): + if ("max_batch_size" in values + and values["max_batch_size"] is not None): + v = 2 * values["max_batch_size"] else: - self.max_concurrent_queries = 8 + v = 8 # Pipeline/async mode: if the servable is not blocking, # router should just keep pushing queries to the worker # replicas until a high limit. - if not self.is_blocking: - self.max_concurrent_queries = ASYNC_CONCURRENCY + if not values["internal_metadata"].is_blocking: + v = ASYNC_CONCURRENCY # Batch inference mode: user specifies non zero timeout to wait for # full batch. We will use 2*max_batch_size to perform double # buffering to keep the replica busy. - if self.max_batch_size is not None and self.batch_wait_timeout > 0: - self.max_concurrent_queries = 2 * self.max_batch_size - - if len(config_dict) != 0: - raise ValueError("Unknown options in backend config: {}".format( - list(config_dict.keys()))) - - self._validate() - - def update(self, config_dict): - """Updates this BackendConfig with options set in the passed config. - - Unspecified keys will remain the same. - """ - if "num_replicas" in config_dict: - self.num_replicas = config_dict.pop("num_replicas") - if "max_batch_size" in config_dict: - self.max_batch_size = config_dict.pop("max_batch_size") - if "max_concurrent_queries" in config_dict: - self.max_concurrent_queries = config_dict.pop( - "max_concurrent_queries") - - if len(config_dict) != 0: - raise ValueError("Unknown options in backend config: {}".format( - list(config_dict.keys()))) - - self._validate() - - def _validate(self): - if not isinstance(self.num_replicas, int): - raise TypeError("num_replicas must be an int.") - elif self.num_replicas < 1: - raise ValueError("num_replicas must be >= 1.") - - if self.max_batch_size is not None: - if not isinstance(self.max_batch_size, int): - raise TypeError("max_batch_size must be an integer.") - elif self.max_batch_size < 1: - raise ValueError("max_batch_size must be >= 1.") - - if not self.accepts_batches and self.max_batch_size > 1: - raise ValueError( - "max_batch_size is set in config but the function or " - "method does not accept batching. Please use " - "@serve.accept_batch to explicitly mark the function or " - "method as batchable and takes in list as arguments.") + if ("max_batch_size" in values + and values["max_batch_size"] is not None + and values["batch_wait_timeout"] > 0): + v = 2 * values["max_batch_size"] + return v class ReplicaConfig: diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index bf2faa2a8..d42541eed 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -1,8 +1,10 @@ import asyncio -from collections import defaultdict, namedtuple +from collections import defaultdict import os import random import time +from typing import Union, Dict, Any, List, Tuple +from pydantic import BaseModel import ray import ray.cloudpickle as pickle @@ -16,7 +18,6 @@ from ray.serve.utils import (format_actor_name, get_random_letters, logger, try_schedule_resources_on_nodes, get_all_node_ids) from ray.serve.config import BackendConfig, ReplicaConfig from ray.actor import ActorHandle -from typing import Dict, List, Any, Tuple import numpy as np @@ -62,8 +63,17 @@ class TrafficPolicy: self.shadow_dict[backend] = proportion -BackendInfo = namedtuple("BackendInfo", - ["worker_class", "backend_config", "replica_config"]) +class BackendInfo(BaseModel): + # TODO(architkulkarni): Add type hint for worker_class after upgrading + # cloudpickle and adding types to RayServeWrappedWorker + worker_class: Any + backend_config: BackendConfig + replica_config: ReplicaConfig + + class Config: + # TODO(architkulkarni): Remove once ReplicaConfig is a pydantic + # model + arbitrary_types_allowed = True @ray.remote @@ -313,9 +323,10 @@ class ServeController: for router in self.routers.values() ]) await self.broadcast_backend_config(backend) - if info.backend_config.autoscaling_config is not None: + metadata = info.backend_config.internal_metadata + if metadata.autoscaling_config is not None: self.autoscaling_policies[backend] = BasicAutoscalingPolicy( - backend, info.backend_config.autoscaling_config) + backend, metadata.autoscaling_config) # Push configuration state to the routers. await asyncio.gather(*[ @@ -753,11 +764,14 @@ class ServeController: # Save creator that starts replicas, the arguments to be passed in, # and the configuration for the backends. self.backends[backend_tag] = BackendInfo( - backend_worker, backend_config, replica_config) - if backend_config.autoscaling_config is not None: + worker_class=backend_worker, + backend_config=backend_config, + replica_config=replica_config) + metadata = backend_config.internal_metadata + if metadata.autoscaling_config is not None: self.autoscaling_policies[ backend_tag] = BasicAutoscalingPolicy( - backend_tag, backend_config.autoscaling_config) + backend_tag, metadata.autoscaling_config) try: self._scale_replicas(backend_tag, backend_config.num_replicas) @@ -814,16 +828,25 @@ class ServeController: await self._stop_pending_replicas() await self._remove_pending_backends() - async def update_backend_config(self, backend_tag: str, - config_options: Dict[str, Any]) -> None: + async def update_backend_config( + self, backend_tag: str, + config_options: "Union[BackendConfig, Dict[str, Any]]") -> None: """Set the config for the specified backend.""" async with self.write_lock: assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) - assert isinstance(config_options, dict) + assert isinstance(config_options, BackendConfig) or isinstance( + config_options, dict) - self.backends[backend_tag].backend_config.update(config_options) - backend_config = self.backends[backend_tag].backend_config + if isinstance(config_options, BackendConfig): + update_data = config_options.dict(exclude_unset=True) + elif isinstance(config_options, dict): + update_data = config_options + + stored_backend_config = self.backends[backend_tag].backend_config + backend_config = stored_backend_config.copy(update=update_data) + backend_config._validate_complete() + self.backends[backend_tag].backend_config = backend_config # Scale the replicas with the new configuration. self._scale_replicas(backend_tag, backend_config.num_replicas) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index c1b141724..803171c9e 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -11,6 +11,7 @@ from ray.test_utils import wait_for_condition from ray.serve import constants from ray.serve.exceptions import RayServeException from ray.serve.utils import format_actor_name, get_random_letters +from ray.serve.config import BackendConfig def test_e2e(serve_instance): @@ -150,6 +151,43 @@ def test_scaling_replicas(serve_instance): self.count += 1 return self.count + serve.create_backend( + "counter:v1", Counter, config=BackendConfig(num_replicas=2)) + serve.create_endpoint("counter", backend="counter:v1", route="/increment") + + # Keep checking the routing table until /increment is populated + while "/increment" not in requests.get( + "http://127.0.0.1:8000/-/routes").json(): + time.sleep(0.2) + + counter_result = [] + for _ in range(10): + resp = requests.get("http://127.0.0.1:8000/increment").json() + counter_result.append(resp) + + # If the load is shared among two replicas. The max result cannot be 10. + assert max(counter_result) < 10 + + serve.update_backend_config("counter:v1", {"num_replicas": 1}) + + counter_result = [] + for _ in range(10): + resp = requests.get("http://127.0.0.1:8000/increment").json() + counter_result.append(resp) + # Give some time for a replica to spin down. But majority of the request + # should be served by the only remaining replica. + assert max(counter_result) - min(counter_result) > 6 + + +def test_scaling_replicas_legacy(serve_instance): + class Counter: + def __init__(self): + self.count = 0 + + def __call__(self, _): + self.count += 1 + return self.count + serve.create_backend("counter:v1", Counter, config={"num_replicas": 2}) serve.create_endpoint("counter", backend="counter:v1", route="/increment") @@ -188,6 +226,43 @@ def test_batching(serve_instance): batch_size = serve.context.batch_size return [self.count] * batch_size + # set the max batch size + serve.create_backend( + "counter:v11", + BatchingExample, + config=BackendConfig(max_batch_size=5, batch_wait_timeout=1)) + serve.create_endpoint( + "counter1", backend="counter:v11", route="/increment2") + + # Keep checking the routing table until /increment is populated + while "/increment2" not in requests.get( + "http://127.0.0.1:8000/-/routes").json(): + time.sleep(0.2) + + future_list = [] + handle = serve.get_handle("counter1") + for _ in range(20): + f = handle.remote(temp=1) + future_list.append(f) + + counter_result = ray.get(future_list) + # since count is only updated per batch of queries + # If there atleast one __call__ fn call with batch size greater than 1 + # counter result will always be less than 20 + assert max(counter_result) < 20 + + +def test_batching_legacy(serve_instance): + class BatchingExample: + def __init__(self): + self.count = 0 + + @serve.accept_batch + def __call__(self, flask_request, temp=None): + self.count += 1 + batch_size = serve.context.batch_size + return [self.count] * batch_size + # set the max batch size serve.create_backend( "counter:v11", @@ -227,6 +302,27 @@ def test_batching_exception(serve_instance): batch_size = serve.context.batch_size return batch_size + # set the max batch size + serve.create_backend( + "exception:v1", NoListReturned, config=BackendConfig(max_batch_size=5)) + serve.create_endpoint( + "exception-test", backend="exception:v1", route="/noListReturned") + + handle = serve.get_handle("exception-test") + with pytest.raises(ray.exceptions.RayTaskError): + assert ray.get(handle.remote(temp=1)) + + +def test_batching_exception_legacy(serve_instance): + class NoListReturned: + def __init__(self): + self.count = 0 + + @serve.accept_batch + def __call__(self, flask_request, temp=None): + batch_size = serve.context.batch_size + return batch_size + # set the max batch size serve.create_backend( "exception:v1", NoListReturned, config={"max_batch_size": 5}) @@ -248,6 +344,40 @@ def test_updating_config(serve_instance): batch_size = serve.context.batch_size return [1] * batch_size + serve.create_backend( + "bsimple:v1", + BatchSimple, + config=BackendConfig(max_batch_size=2, num_replicas=3)) + serve.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple") + + controller = serve.api._get_controller() + old_replica_tag_list = ray.get( + controller._list_replicas.remote("bsimple:v1")) + + serve.update_backend_config("bsimple:v1", BackendConfig(max_batch_size=5)) + new_replica_tag_list = ray.get( + controller._list_replicas.remote("bsimple:v1")) + new_all_tag_list = [] + for worker_dict in ray.get( + controller.get_all_worker_handles.remote()).values(): + new_all_tag_list.extend(list(worker_dict.keys())) + + # the old and new replica tag list should be identical + # and should be subset of all_tag_list + assert set(old_replica_tag_list) <= set(new_all_tag_list) + assert set(old_replica_tag_list) == set(new_replica_tag_list) + + +def test_updating_config_legacy(serve_instance): + class BatchSimple: + def __init__(self): + self.count = 0 + + @serve.accept_batch + def __call__(self, flask_request, temp=None): + batch_size = serve.context.batch_size + return [1] * batch_size + serve.create_backend( "bsimple:v1", BatchSimple, @@ -450,6 +580,43 @@ def test_parallel_start(serve_instance): barrier = Barrier.remote(release_on=2) + class LongStartingServable: + def __init__(self): + ray.get(barrier.wait.remote(), timeout=10) + + def __call__(self, _): + return "Ready" + + serve.create_backend( + "p:v0", LongStartingServable, config=BackendConfig(num_replicas=2)) + serve.create_endpoint("test-parallel", backend="p:v0") + handle = serve.get_handle("test-parallel") + + ray.get(handle.remote(), timeout=10) + + +def test_parallel_start_legacy(serve_instance): + # Test the ability to start multiple replicas in parallel. + # In the past, when Serve scale up a backend, it does so one by one and + # wait for each replica to initialize. This test avoid this by preventing + # the first replica to finish initialization unless the second replica is + # also started. + @ray.remote + class Barrier: + def __init__(self, release_on): + self.release_on = release_on + self.current_waiters = 0 + self.event = asyncio.Event() + + async def wait(self): + self.current_waiters += 1 + if self.current_waiters == self.release_on: + self.event.set() + else: + await self.event.wait() + + barrier = Barrier.remote(release_on=2) + class LongStartingServable: def __init__(self): ray.get(barrier.wait.remote(), timeout=10) @@ -512,6 +679,33 @@ def test_list_endpoints(serve_instance): def test_list_backends(serve_instance): serve.init() + @serve.accept_batch + def f(): + pass + + serve.create_backend("backend", f, config=BackendConfig(max_batch_size=10)) + backends = serve.list_backends() + assert len(backends) == 1 + assert "backend" in backends + assert backends["backend"]["max_batch_size"] == 10 + + serve.create_backend("backend2", f, config=BackendConfig(num_replicas=10)) + backends = serve.list_backends() + assert len(backends) == 2 + assert backends["backend2"]["num_replicas"] == 10 + + serve.delete_backend("backend") + backends = serve.list_backends() + assert len(backends) == 1 + assert "backend2" in backends + + serve.delete_backend("backend2") + assert len(serve.list_backends()) == 0 + + +def test_list_backends_legacy(serve_instance): + serve.init() + @serve.accept_batch def f(): pass @@ -555,6 +749,37 @@ def test_endpoint_input_validation(serve_instance): def test_create_infeasible_error(serve_instance): serve.init() + def f(): + pass + + # Non existent resource should be infeasible. + with pytest.raises(RayServeException, match="Cannot scale backend"): + serve.create_backend( + "f:1", + f, + ray_actor_options={"resources": { + "MagicMLResource": 100 + }}) + + # Even each replica might be feasible, the total might not be. + current_cpus = int(ray.nodes()[0]["Resources"]["CPU"]) + with pytest.raises(RayServeException, match="Cannot scale backend"): + serve.create_backend( + "f:1", + f, + ray_actor_options={"resources": { + "CPU": 1, + }}, + config=BackendConfig(num_replicas=(current_cpus + 20))) + + # No replica should be created! + replicas = ray.get(serve.api.controller._list_replicas.remote("f1")) + assert len(replicas) == 0 + + +def test_create_infeasible_error_legacy(serve_instance): + serve.init() + def f(): pass diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 0b13c6cfe..9c77590e5 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -10,7 +10,7 @@ from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error from ray.serve.controller import TrafficPolicy from ray.serve.request_params import RequestMetadata from ray.serve.router import Router -from ray.serve.config import BackendConfig +from ray.serve.config import BackendConfig, BackendMetadata from ray.serve.exceptions import RayServeException pytestmark = pytest.mark.asyncio @@ -19,7 +19,7 @@ pytestmark = pytest.mark.asyncio def setup_worker(name, func_or_class, init_args=None, - backend_config=BackendConfig({})): + backend_config=BackendConfig()): if init_args is None: init_args = () @@ -178,10 +178,9 @@ async def test_task_runner_custom_method_batch(serve_instance): PRODUCER_NAME = "producer" backend_config = BackendConfig( - { - "max_batch_size": 4, - "batch_wait_timeout": 2 - }, accepts_batches=True) + max_batch_size=4, + batch_wait_timeout=2, + internal_metadata=BackendMetadata(accepts_batches=True)) worker = setup_worker( CONSUMER_NAME, Batcher, backend_config=backend_config) @@ -230,10 +229,9 @@ async def test_task_runner_perform_batch(serve_instance): PRODUCER_NAME = "producer" config = BackendConfig( - { - "max_batch_size": 2, - "batch_wait_timeout": 10 - }, accepts_batches=True) + max_batch_size=2, + batch_wait_timeout=10, + internal_metadata=BackendMetadata(accepts_batches=True)) worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) @@ -277,7 +275,9 @@ async def test_task_runner_perform_async(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False) + config = BackendConfig( + max_concurrent_queries=10, + internal_metadata=BackendMetadata(is_blocking=False)) worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 7c9b3ad26..00dc6fa2f 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -1,67 +1,75 @@ import pytest from ray import serve -from ray.serve.config import BackendConfig, ReplicaConfig +from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata +from ray.serve.constants import ASYNC_CONCURRENCY +from pydantic import ValidationError def test_backend_config_validation(): # Test unknown key. - with pytest.raises(ValueError, match="unknown_key"): - BackendConfig({"unknown_key": -1}) - - # Test that the input dict isn't modified. - config = {"num_replicas": 2} - BackendConfig(config) - assert len(config) == 1 and config["num_replicas"] == 2 + with pytest.raises(ValidationError): + BackendConfig(unknown_key=-1) # Test num_replicas validation. - BackendConfig({"num_replicas": 1}) - with pytest.raises(TypeError): - BackendConfig({"num_replicas": "hello"}) - with pytest.raises(ValueError): - BackendConfig({"num_replicas": -1}) + BackendConfig(num_replicas=1) + with pytest.raises(ValidationError, match="type_error"): + BackendConfig(num_replicas="hello") + with pytest.raises(ValidationError, match="value_error"): + BackendConfig(num_replicas=-1) # Test max_batch_size validation. - BackendConfig({"max_batch_size": 10}, accepts_batches=True) + BackendConfig( + max_batch_size=10, + internal_metadata=BackendMetadata( + accepts_batches=True))._validate_complete() with pytest.raises(ValueError): - BackendConfig({"max_batch_size": 10}, accepts_batches=False) - with pytest.raises(TypeError): - BackendConfig({"max_batch_size": 1.0}) - with pytest.raises(TypeError): - BackendConfig({"max_batch_size": "hello"}) - with pytest.raises(ValueError): - BackendConfig({"max_batch_size": 0}) - with pytest.raises(ValueError): - BackendConfig({"max_batch_size": -1}) + BackendConfig( + max_batch_size=10, + internal_metadata=BackendMetadata( + accepts_batches=False))._validate_complete() + with pytest.raises(ValidationError, match="type_error"): + BackendConfig(max_batch_size="hello") + with pytest.raises(ValidationError, match="value_error"): + BackendConfig(max_batch_size=0) + with pytest.raises(ValidationError, match="value_error"): + BackendConfig(max_batch_size=-1) + + # Test dynamic default for max_concurrent_queries. + assert BackendConfig().max_concurrent_queries == 8 + assert BackendConfig(max_batch_size=7).max_concurrent_queries == 14 + assert BackendConfig( + max_batch_size=10, + internal_metadata=BackendMetadata( + is_blocking=False)).max_concurrent_queries == ASYNC_CONCURRENCY + assert BackendConfig( + max_batch_size=7, batch_wait_timeout=1.0).max_concurrent_queries == 14 def test_backend_config_update(): - b = BackendConfig({"num_replicas": 1, "max_batch_size": 1}) + b = BackendConfig(num_replicas=1, max_batch_size=1) # Test updating a key works. - b.update({"num_replicas": 2}) + b.num_replicas = 2 assert b.num_replicas == 2 # Check that not specifying a key doesn't update it. assert b.max_batch_size == 1 - # Check that passing an invalid key fails. - with pytest.raises(ValueError): - b.update({"unknown": 1}) - # Check that input is validated. - with pytest.raises(TypeError): - b.update({"num_replicas": "hello"}) - with pytest.raises(ValueError): - b.update({"num_replicas": -1}) + with pytest.raises(ValidationError): + b.num_replicas = "Hello" + with pytest.raises(ValidationError): + b.num_replicas = -1 # Test batch validation. - b = BackendConfig({}, accepts_batches=False) - b.update({"max_batch_size": 1}) + b = BackendConfig(internal_metadata=BackendMetadata(accepts_batches=False)) + b.max_batch_size = 1 with pytest.raises(ValueError): - b.update({"max_batch_size": 2}) + b.max_batch_size = 2 + b._validate_complete() - b = BackendConfig({}, accepts_batches=True) - b.update({"max_batch_size": 2}) + b = BackendConfig(internal_metadata=BackendMetadata(accepts_batches=True)) + b.max_batch_size = 2 def test_replica_config_validation(): diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 3917a4050..6a309e69d 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -183,7 +183,8 @@ def test_worker_replica_failure(serve_instance): temp_path = os.path.join(tempfile.gettempdir(), serve.utils.get_random_letters()) serve.create_backend("replica_failure", Worker, temp_path) - serve.update_backend_config("replica_failure", {"num_replicas": 2}) + serve.update_backend_config( + "replica_failure", BackendConfig(num_replicas=2)) serve.create_endpoint( "replica_failure", backend="replica_failure", route="/replica_failure") @@ -219,7 +220,7 @@ def test_create_backend_idempotent(serve_instance): controller = serve.api._get_controller() replica_config = ReplicaConfig(f) - backend_config = BackendConfig({"num_replicas": 1}) + backend_config = BackendConfig(num_replicas=1) for i in range(10): ray.get( diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 67cdc6749..7a080fb6b 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -181,7 +181,7 @@ async def test_router_use_max_concurrency(serve_instance): q = ray.remote(VisibleRouter).remote() await q.setup.remote("") backend_name = "max-concurrent-test" - config = BackendConfig({"max_concurrent_queries": 1}) + config = BackendConfig(max_concurrent_queries=1) await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0})) await q.add_new_worker.remote(backend_name, "replica-tag", worker) await q.set_backend_config.remote(backend_name, config) diff --git a/python/requirements.txt b/python/requirements.txt index 799fd78ed..b4411342f 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -37,6 +37,7 @@ scipy==1.4.1 tabulate tensorboardX uvicorn +pydantic dataclasses # Requirements for running tests diff --git a/python/setup.py b/python/setup.py index ced73786a..02d29ce0b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -110,7 +110,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on": # in this directory extras = { "debug": [], - "serve": ["uvicorn", "flask", "requests", "dataclasses"], + "serve": ["uvicorn", "flask", "requests", "pydantic", "dataclasses"], "tune": ["tabulate", "tensorboardX", "pandas"] }