[serve] Refactor BackendConfig (#8202)

This commit is contained in:
Edward Oakes
2020-04-30 22:31:07 -05:00
committed by GitHub
parent 95d187e556
commit 6373c70661
33 changed files with 406 additions and 310 deletions
+2 -3
View File
@@ -1,13 +1,12 @@
from ray.serve.backend_config import BackendConfig
from ray.serve.policy import RoutePolicy
from ray.serve.api import (init, create_backend, delete_backend,
create_endpoint, delete_endpoint, set_traffic,
get_handle, stat, set_backend_config,
get_handle, stat, update_backend_config,
get_backend_config, accept_batch) # noqa: E402
__all__ = [
"init", "create_backend", "delete_backend", "create_endpoint",
"delete_endpoint", "set_traffic", "get_handle", "stat",
"set_backend_config", "get_backend_config", "BackendConfig", "RoutePolicy",
"update_backend_config", "get_backend_config", "RoutePolicy",
"accept_batch"
]
+32 -48
View File
@@ -1,4 +1,3 @@
import inspect
from functools import wraps
from tempfile import mkstemp
@@ -11,8 +10,8 @@ from ray.serve.master import ServeMaster
from ray.serve.handle import RayServeHandle
from ray.serve.kv_store_service import SQLiteKVStore
from ray.serve.utils import block_until_http_ready, retry_actor_failures
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
from ray.serve.backend_config import BackendConfig
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.policy import RoutePolicy
from ray.serve.router import Query
from ray.serve.request_params import RequestMetadata
@@ -56,7 +55,7 @@ def accept_batch(f):
def __call__(self, *, python_arg=None):
assert isinstance(python_arg, list)
"""
f.serve_accept_batch = True
f._serve_accept_batch = True
return f
@@ -176,15 +175,19 @@ def delete_endpoint(endpoint):
@_ensure_connected
def set_backend_config(backend_tag, backend_config):
"""Set a backend configuration for a backend tag
def update_backend_config(backend_tag, config_options):
"""Update a backend configuration for a backend tag.
Keys not specified in the passed config_options will be left unchanged.
Args:
backend_tag(str): A registered backend.
backend_config(BackendConfig) : Desired backend configuration.
config_options(dict): Backend config options to update.
"""
retry_actor_failures(master_actor.set_backend_config, backend_tag,
backend_config)
if not isinstance(config_options, dict):
raise ValueError("config_options must be a dictionary.")
retry_actor_failures(master_actor.update_backend_config, backend_tag,
config_options)
@_ensure_connected
@@ -197,56 +200,37 @@ def get_backend_config(backend_tag):
return retry_actor_failures(master_actor.get_backend_config, backend_tag)
def _backend_accept_batch(func_or_class):
if inspect.isfunction(func_or_class):
return hasattr(func_or_class, "serve_accept_batch")
elif inspect.isclass(func_or_class):
return hasattr(func_or_class.__call__, "serve_accept_batch")
@_ensure_connected
def create_backend(func_or_class,
backend_tag,
def create_backend(backend_tag,
func_or_class,
*actor_init_args,
backend_config=None):
"""Create a backend using func_or_class and assign backend_tag.
ray_actor_options=None,
config=None):
"""Create a backend with the provided tag.
The backend will serve requests with func_or_class.
Args:
backend_tag (str): a unique tag assign to identify this backend.
func_or_class (callable, class): a function or a class implementing
__call__.
backend_tag (str): a unique tag assign to this backend. It will be used
to associate services in traffic policy.
backend_config (BackendConfig): An object defining backend properties
for starting a backend.
*actor_init_args (optional): the argument to pass to the class
actor_init_args (optional): the arguments to pass to the class.
initialization method.
ray_actor_options (optional): options to be passed into the
@ray.remote decorator for the backend actor.
config: (optional) configuration options for this backend.
"""
# Configure backend_config
if backend_config is None:
backend_config = BackendConfig()
assert isinstance(backend_config,
BackendConfig), ("backend_config must be"
" of instance BackendConfig")
if config is None:
config = {}
if not isinstance(config, dict):
raise TypeError("config must be a dictionary.")
# Validate that func_or_class is a function or class.
if inspect.isfunction(func_or_class):
if len(actor_init_args) != 0:
raise ValueError(
"actor_init_args not supported for function backend.")
elif not inspect.isclass(func_or_class):
raise ValueError(
"Backend must be a function or class, it is {}.".format(
type(func_or_class)))
# Make sure the batch size is correct.
should_accept_batch = backend_config.max_batch_size is not None
if should_accept_batch and not _backend_accept_batch(func_or_class):
raise batch_annotation_not_found
if _backend_accept_batch(func_or_class):
backend_config.has_accept_batch_annotation = True
replica_config = ReplicaConfig(
func_or_class, *actor_init_args, ray_actor_options=ray_actor_options)
backend_config = BackendConfig(config, replica_config.accepts_batches)
retry_actor_failures(master_actor.create_backend, backend_tag,
backend_config, func_or_class, actor_init_args)
backend_config, replica_config)
@_ensure_connected
-62
View File
@@ -1,62 +0,0 @@
from copy import deepcopy
class BackendConfig:
    """Configuration of a serve backend.

    Holds both serve-level options (number of replicas, batching) and the
    options forwarded to the replica actor at creation time (resources,
    CPUs, GPUs, memory).
    """

    # configs not needed for actor creation when
    # instantiating a replica
    _serve_configs = [
        "_num_replicas", "max_batch_size", "has_accept_batch_annotation"
    ]
    # configs which when changed leads to restarting
    # the existing replicas.
    restart_on_change_fields = ["resources", "num_cpus", "num_gpus"]

    def __init__(self,
                 num_replicas=1,
                 resources=None,
                 max_batch_size=None,
                 num_cpus=None,
                 num_gpus=None,
                 memory=None,
                 object_store_memory=None,
                 has_accept_batch_annotation=False):
        """Store the given options on the instance.

        Raises:
            ValueError: if num_replicas is not greater than zero.
        """
        # backend metadata
        self.has_accept_batch_annotation = has_accept_batch_annotation

        # serve configs (num_replicas goes through the validating setter)
        self.num_replicas = num_replicas
        self.max_batch_size = max_batch_size

        # ray actor configs
        self.resources = resources
        self.num_cpus = num_cpus
        self.num_gpus = num_gpus
        self.memory = memory
        self.object_store_memory = object_store_memory

    @property
    def num_replicas(self):
        """Number of replicas requested for this backend."""
        return self._num_replicas

    @num_replicas.setter
    def num_replicas(self, val):
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        if not (val > 0):
            raise ValueError("num_replicas must be greater than zero")
        self._num_replicas = val

    def __iter__(self):
        """Yield (name, value) pairs so that dict(config) works.

        The private "_num_replicas" attribute is exposed under its public
        name "num_replicas".
        """
        for key, val in self.__dict__.items():
            if key == "_num_replicas":
                key = "num_replicas"
            yield key, val

    def get_actor_creation_args(self):
        """Return a deep copy of the options minus the serve-only ones."""
        ret_d = deepcopy(self.__dict__)
        for k in self._serve_configs:
            ret_d.pop(k)
        return ret_d
+130
View File
@@ -0,0 +1,130 @@
import inspect
def _callable_accepts_batch(func_or_class):
if inspect.isfunction(func_or_class):
return hasattr(func_or_class, "_serve_accept_batch")
elif inspect.isclass(func_or_class):
return hasattr(func_or_class.__call__, "_serve_accept_batch")
class BackendConfig:
    """Serve-level options of a backend: num_replicas and max_batch_size.

    Attributes are validated on construction and after every update.
    """

    def __init__(self, config_dict, accepts_batches=False):
        """Build a config from a dictionary of options.

        Args:
            config_dict (dict): may contain "num_replicas" (default 1) and
                "max_batch_size" (default None). It is not modified.
            accepts_batches (bool): whether the backend callable accepts
                batched input.

        Raises:
            TypeError: if config_dict is not a dict or an option has the
                wrong type.
            ValueError: on unknown options or out-of-range values.
        """
        # Explicit raise rather than assert: asserts vanish under -O.
        if not isinstance(config_dict, dict):
            raise TypeError("config_dict must be a dictionary.")

        self.accepts_batches = accepts_batches
        # Work on a copy so the caller's dictionary is left untouched.
        options = dict(config_dict)
        self.num_replicas = options.pop("num_replicas", 1)
        self.max_batch_size = options.pop("max_batch_size", None)

        if options:
            raise ValueError("Unknown options in backend config: {}".format(
                list(options.keys())))

        self._validate()

    def update(self, config_dict):
        """Updates this BackendConfig with options set in the passed config.

        Unspecified keys will remain the same. The update is all-or-nothing:
        if an option is unknown or invalid, the exception is re-raised and
        this object keeps its previous values. config_dict is not modified.
        """
        options = dict(config_dict)
        previous = (self.num_replicas, self.max_batch_size)
        try:
            if "num_replicas" in options:
                self.num_replicas = options.pop("num_replicas")
            if "max_batch_size" in options:
                self.max_batch_size = options.pop("max_batch_size")

            if options:
                raise ValueError(
                    "Unknown options in backend config: {}".format(
                        list(options.keys())))

            self._validate()
        except Exception:
            # Roll back so a rejected update cannot leave a half-applied or
            # invalid configuration behind.
            self.num_replicas, self.max_batch_size = previous
            raise

    def _validate(self):
        """Check types and ranges of the current option values."""
        if not isinstance(self.num_replicas, int):
            raise TypeError("num_replicas must be an int.")
        elif self.num_replicas < 1:
            raise ValueError("num_replicas must be >= 1.")

        if self.max_batch_size is not None:
            if not isinstance(self.max_batch_size, int):
                raise TypeError("max_batch_size must be an integer.")
            elif self.max_batch_size < 1:
                raise ValueError("max_batch_size must be >= 1.")

            # Only meaningful once max_batch_size is known to be an int;
            # kept inside the None guard so None is never compared.
            if not self.accepts_batches and self.max_batch_size > 1:
                raise ValueError(
                    "max_batch_size is set in config but the function or "
                    "method does not accept batching. Please use "
                    "@serve.accept_batch to explicitly mark the function or "
                    "method as batchable and takes in list as arguments.")
class ReplicaConfig:
    """Everything needed to start a replica actor for a backend.

    Attributes set in __init__: func_or_class, accepts_batches,
    actor_init_args (as a list) and ray_actor_options (a dict).
    """

    def __init__(self, func_or_class, *actor_init_args,
                 ray_actor_options=None):
        """Store and validate the replica settings.

        Args:
            func_or_class: the function or class serving requests.
            *actor_init_args: arguments for the class constructor; must be
                empty when func_or_class is a function.
            ray_actor_options (dict, optional): options forwarded to the
                replica actor. Defaults to an empty dict.

        Raises:
            TypeError: if func_or_class is neither a function nor a class,
                or an option has the wrong type.
            ValueError: if init args are given for a function backend, a
                reserved option is set, or a numeric option is negative.
        """
        self.func_or_class = func_or_class
        self.accepts_batches = _callable_accepts_batch(func_or_class)
        self.actor_init_args = list(actor_init_args)
        if ray_actor_options is None:
            self.ray_actor_options = {}
        else:
            self.ray_actor_options = ray_actor_options

        self._validate()

    def _validate(self):
        """Validate func_or_class, actor_init_args and ray_actor_options."""
        # Validate that func_or_class is a function or class.
        if inspect.isfunction(self.func_or_class):
            if len(self.actor_init_args) != 0:
                raise ValueError(
                    "actor_init_args not supported for function backend.")
        elif not inspect.isclass(self.func_or_class):
            raise TypeError(
                "Backend must be a function or class, it is {}.".format(
                    type(self.func_or_class)))

        if not isinstance(self.ray_actor_options, dict):
            raise TypeError("ray_actor_options must be a dictionary.")

        # These options may not be supplied by the user. The messages name
        # ray_actor_options, which is the argument actually being checked
        # (they previously pointed at actor_init_args).
        for reserved in ("detached", "name", "max_reconstructions"):
            if reserved in self.ray_actor_options:
                raise ValueError(
                    "Specifying {} in ray_actor_options is not allowed.".
                    format(reserved))

        # All numeric options share the same rule: int or float, >= 0.
        # Each message names its own option (the memory check used to
        # report num_gpus).
        for option in ("num_cpus", "num_gpus", "memory",
                       "object_store_memory"):
            value = self.ray_actor_options.get(option, 0)
            if not isinstance(value, (int, float)):
                raise TypeError(
                    "{} in ray_actor_options must be an int or a float.".
                    format(option))
            elif value < 0:
                raise ValueError(
                    "{} in ray_actor_options must be >= 0.".format(option))

        if not isinstance(
                self.ray_actor_options.get("resources", {}), dict):
            raise TypeError(
                "resources in ray_actor_options must be a dictionary.")
+1 -1
View File
@@ -13,7 +13,7 @@ def noop(_):
serve.create_endpoint("noop", "/noop")
serve.create_backend(noop, "noop")
serve.create_backend("noop", noop)
serve.set_traffic("noop", {"noop": 1.0})
url = "{}/noop".format(DEFAULT_HTTP_ADDRESS)
@@ -13,7 +13,7 @@ class Counter:
serve.create_endpoint("counter", "/counter")
serve.create_backend(Counter, "counter")
serve.create_backend("counter", Counter)
serve.set_traffic("counter", {"counter": 1.0})
requests.get("http://127.0.0.1:8000/counter").json()
@@ -9,7 +9,7 @@ def echo(flask_request):
serve.create_endpoint("hello", "/hello")
serve.create_backend(echo, "hello")
serve.create_backend("hello", echo)
serve.set_traffic("hello", {"hello": 1.0})
requests.get("http://127.0.0.1:8000/hello").text
@@ -46,7 +46,7 @@ class ImageModel:
# __doc_deploy_begin__
serve.init()
serve.create_endpoint("predictor", "/image_predict", methods=["POST"])
serve.create_backend(ImageModel, "resnet18:v0")
serve.create_backend("resnet18:v0", ImageModel)
serve.set_traffic("predictor", {"resnet18:v0": 1})
# __doc_deploy_end__
@@ -66,7 +66,7 @@ class BoostingModel:
# __doc_deploy_begin__
serve.init()
serve.create_endpoint("iris_classifier", "/regressor")
serve.create_backend(BoostingModel, "lr:v1")
serve.create_backend("lr:v1", BoostingModel)
serve.set_traffic("iris_classifier", {"lr:v1": 1})
# __doc_deploy_end__
@@ -70,7 +70,7 @@ class TFMnistModel:
# __doc_deploy_begin__
serve.init()
serve.create_endpoint(endpoint_name="tf_classifier", route="/mnist")
serve.create_backend(TFMnistModel, "tf:v1", "/tmp/mnist_model.h5")
serve.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
serve.set_traffic("tf_classifier", {"tf:v1": 1})
# __doc_deploy_end__
+1 -1
View File
@@ -17,7 +17,7 @@ def echo(flask_request):
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo, "echo:v1")
serve.create_backend("echo:v1", echo)
serve.set_traffic("my_endpoint", {"echo:v1": 1.0})
while True:
+1 -1
View File
@@ -26,7 +26,7 @@ class MagicCounter:
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter")
serve.create_backend(MagicCounter, "counter:v1", 42) # increment=42
serve.create_backend("counter:v1", MagicCounter, 42) # increment=42
serve.set_traffic("magic_counter", {"counter:v1": 1.0})
print("Sending ten queries via HTTP")
@@ -12,7 +12,6 @@ import requests
import ray
from ray import serve
from ray.serve.utils import pformat_color_json
from ray.serve import BackendConfig
class MagicCounter:
@@ -38,9 +37,9 @@ class MagicCounter:
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter")
b_config = BackendConfig(max_batch_size=5)
serve.create_backend(
MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42
"counter:v1", MagicCounter, 42,
config={"max_batch_size": 5}) # increment=42
serve.set_traffic("magic_counter", {"counter:v1": 1.0})
print("Sending ten queries via HTTP")
+3 -4
View File
@@ -4,7 +4,6 @@ This example has backend which has batching functionality enabled.
import ray
from ray import serve
from ray.serve import BackendConfig
class MagicCounter:
@@ -30,11 +29,11 @@ class MagicCounter:
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter")
# specify max_batch_size in BackendConfig
b_config = BackendConfig(max_batch_size=5)
backend_config = {"max_batch_size": 5}
serve.create_backend(
MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42
"counter:v1", MagicCounter, 42, config=backend_config) # increment=42
print("Backend Config for backend: 'counter:v1'")
print(b_config)
print(backend_config)
serve.set_traffic("magic_counter", {"counter:v1": 1.0})
handle = serve.get_handle("magic_counter")
+1 -1
View File
@@ -29,7 +29,7 @@ def echo(_):
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo, "echo:v1")
serve.create_backend("echo:v1", echo)
serve.set_traffic("my_endpoint", {"echo:v1": 1.0})
for _ in range(2):
@@ -32,10 +32,10 @@ serve.init(
serve.create_endpoint("my_endpoint", "/echo")
# create first backend
serve.create_backend(echo_v1, "echo:v1")
serve.create_backend("echo:v1", echo_v1)
# create second backend
serve.create_backend(echo_v2, "echo:v2")
serve.create_backend("echo:v2", echo_v2)
# link and split the service to two backends
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
+5 -10
View File
@@ -26,8 +26,7 @@ def echo_v1(flask_request, response="hello from python!"):
return response
serve.create_backend(echo_v1, "echo:v1")
backend_config_v1 = serve.get_backend_config("echo:v1")
serve.create_backend("echo:v1", echo_v1)
# We can link an endpoint to a backend; this means all the traffic
# going to my_endpoint will now go to the echo:v1 backend.
@@ -47,8 +46,7 @@ def echo_v2(flask_request):
return "something new"
serve.create_backend(echo_v2, "echo:v2")
backend_config_v2 = serve.get_backend_config("echo:v2")
serve.create_backend("echo:v2", echo_v2)
# The two backend will now split the traffic 50%-50%.
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
@@ -58,12 +56,9 @@ for _ in range(10):
print(requests.get("http://127.0.0.1:8000/echo").text)
time.sleep(0.5)
# You can also change number of replicas
# for each backend independently.
backend_config_v1.num_replicas = 2
serve.set_backend_config("echo:v1", backend_config_v1)
backend_config_v2.num_replicas = 2
serve.set_backend_config("echo:v2", backend_config_v2)
# You can also change number of replicas for each backend independently.
serve.update_backend_config("echo:v1", {"num_replicas": 2})
serve.update_backend_config("echo:v2", {"num_replicas": 2})
# As well as retrieving relevant system metrics
print(pformat_color_json(serve.stat()))
+4 -4
View File
@@ -17,7 +17,7 @@ def echo_v1(_, response="hello from python!"):
serve.create_endpoint("echo_v1", "/echo_v1")
serve.create_backend(echo_v1, "echo_v1")
serve.create_backend("echo_v1", echo_v1)
serve.set_traffic("echo_v1", {"echo_v1": 1.0})
@@ -26,7 +26,7 @@ def echo_v2(_, relay=""):
serve.create_endpoint("echo_v2", "/echo_v2")
serve.create_backend(echo_v2, "echo_v2")
serve.create_backend("echo_v2", echo_v2)
serve.set_traffic("echo_v2", {"echo_v2": 1.0})
@@ -35,7 +35,7 @@ def echo_v3(_, relay=""):
serve.create_endpoint("echo_v3", "/echo_v3")
serve.create_backend(echo_v3, "echo_v3")
serve.create_backend("echo_v3", echo_v3)
serve.set_traffic("echo_v3", {"echo_v3": 1.0})
@@ -44,7 +44,7 @@ def echo_v4(_, relay1="", relay2=""):
serve.create_endpoint("echo_v4", "/echo_v4")
serve.create_backend(echo_v4, "echo_v4")
serve.create_backend("echo_v4", echo_v4)
serve.set_traffic("echo_v4", {"echo_v4": 1.0})
"""
The pipeline created is as follows -
@@ -25,10 +25,10 @@ serve.init(blocking=True, queueing_policy=serve.RoutePolicy.RoundRobin)
serve.create_endpoint("my_endpoint", "/echo")
# create first backend
serve.create_backend(echo_v1, "echo:v1")
serve.create_backend("echo:v1", echo_v1)
# create second backend
serve.create_backend(echo_v2, "echo:v2")
serve.create_backend("echo:v2", echo_v2)
# link and split the service to two backends
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
@@ -25,7 +25,7 @@ def echo_v1(flask_request, response="hello from python!"):
return response
serve.create_backend(echo_v1, "echo:v1")
serve.create_backend("echo:v1", echo_v1)
serve.set_traffic("my_endpoint", {"echo:v1": 1.0})
# wait for routing table to get populated
+2 -2
View File
@@ -21,7 +21,7 @@ def echo_v2(_):
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo")
serve.create_backend(echo_v1, "echo:v1")
serve.create_backend("echo:v1", echo_v1)
serve.set_traffic("my_endpoint", {"echo:v1": 1.0})
for _ in range(3):
@@ -31,7 +31,7 @@ for _ in range(3):
print("...Sleeping for 2 seconds...")
time.sleep(2)
serve.create_backend(echo_v2, "echo:v2")
serve.create_backend("echo:v2", echo_v2)
serve.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
-6
View File
@@ -1,8 +1,2 @@
class RayServeException(Exception):
pass
batch_annotation_not_found = RayServeException(
"max_batch_size is set in config but the function or method does not "
"accept batching. Please use @serve.accept_batch to explicitly mark "
"the function or method as batchable and takes in list as arguments.")
-12
View File
@@ -125,18 +125,6 @@ class RayServeHandle:
backend_tag, list(traffic_policy.keys()))
return backend_tag
def scale(self, new_num_replicas, backend_tag=None):
backend_tag = self._ensure_backend_unique(backend_tag)
config = serve.get_backend_config(backend_tag)
config.num_replicas = new_num_replicas
serve.set_backend_config(backend_tag, config)
def set_max_batch_size(self, new_max_batch_size, backend_tag=None):
backend_tag = self._ensure_backend_unique(backend_tag)
config = serve.get_backend_config(backend_tag)
config.max_batch_size = new_max_batch_size
serve.set_backend_config(backend_tag, config)
def __repr__(self):
return """
RayServeHandle(
+29 -49
View File
@@ -6,10 +6,8 @@ import time
import ray
import ray.cloudpickle as pickle
from ray.serve.backend_config import BackendConfig
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME,
SERVE_PROXY_NAME, SERVE_METRIC_MONITOR_NAME)
from ray.serve.exceptions import batch_annotation_not_found
from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop)
from ray.serve.backend_worker import create_backend_worker
@@ -56,7 +54,7 @@ class ServeMaster:
self.kv_store_client = kv_store_connector("serve_checkpoints")
# path -> (endpoint, methods).
self.routes = {}
# backend -> (worker_creator, init_args, backend_config).
# backend -> (backend_worker, backend_config, replica_config).
self.backends = {}
# backend -> replica_tags.
self.replicas = defaultdict(list)
@@ -238,9 +236,9 @@ class ServeMaster:
await self.router.add_new_worker.remote(
backend_tag, replica_tag, worker)
for backend, (_, _, backend_config_dict) in self.backends.items():
for backend, (_, backend_config, _) in self.backends.items():
await self.router.set_backend_config.remote(
backend, backend_config_dict)
backend, backend_config)
# Push configuration state to the HTTP proxy.
await self.http_proxy.set_route_table.remote(self.routes)
@@ -261,8 +259,8 @@ class ServeMaster:
def get_backend_configs(self):
"""Fetched by the router on startup."""
backend_configs = {}
for backend, (_, _, backend_config_dict) in self.backends.items():
backend_configs[backend] = backend_config_dict
for backend, (_, backend_config, _) in self.backends.items():
backend_configs[backend] = backend_config
return backend_configs
def get_traffic_policies(self):
@@ -285,16 +283,15 @@ class ServeMaster:
"""
logger.debug("Starting worker '{}' for backend '{}'.".format(
replica_tag, backend_tag))
worker_creator, init_args, config_dict = self.backends[backend_tag]
# TODO(edoakes): just store the BackendConfig in self.backends.
backend_config = BackendConfig(**config_dict)
kwargs = backend_config.get_actor_creation_args()
(backend_worker, backend_config,
replica_config) = self.backends[backend_tag]
worker_handle = async_retryable(ray.remote(worker_creator)).options(
worker_handle = async_retryable(ray.remote(backend_worker)).options(
detached=True,
name=replica_tag,
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
**kwargs).remote(backend_tag, replica_tag, init_args)
**replica_config.ray_actor_options).remote(
backend_tag, replica_tag, replica_config.actor_init_args)
# TODO(edoakes): we should probably have a timeout here.
await worker_handle.ready.remote()
return worker_handle
@@ -534,20 +531,19 @@ class ServeMaster:
await self.http_proxy.set_route_table.remote(self.routes)
await self._remove_pending_endpoints()
async def create_backend(self, backend_tag, backend_config, func_or_class,
actor_init_args):
async def create_backend(self, backend_tag, backend_config,
replica_config):
"""Register a new backend under the specified tag."""
async with self.write_lock:
backend_config_dict = dict(backend_config)
backend_worker = create_backend_worker(func_or_class)
backend_worker = create_backend_worker(
replica_config.func_or_class)
# Save creator that starts replicas, the arguments to be passed in,
# and the configuration for the backends.
self.backends[backend_tag] = (backend_worker, actor_init_args,
backend_config_dict)
self.backends[backend_tag] = (backend_worker, backend_config,
replica_config)
self._scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
self._scale_replicas(backend_tag, backend_config.num_replicas)
# NOTE(edoakes): we must write a checkpoint before starting new
# or pushing the updated config to avoid inconsistent state if we
@@ -558,7 +554,7 @@ class ServeMaster:
# Set the backend config inside the router
# (particularly for max-batch-size).
await self.router.set_backend_config.remote(
backend_tag, backend_config_dict)
backend_tag, backend_config)
async def delete_backend(self, backend_tag):
async with self.write_lock:
@@ -592,37 +588,21 @@ class ServeMaster:
await self._stop_pending_replicas()
await self._remove_pending_backends()
async def set_backend_config(self, backend_tag, backend_config):
async def update_backend_config(self, backend_tag, config_options):
"""Set the config for the specified backend."""
async with self.write_lock:
if backend_tag not in self.backends:
raise ValueError(
"Backend '{}' is not registered.".format(backend_tag))
if not isinstance(backend_config, BackendConfig):
raise ValueError("backend_config must be a BackendConfig.")
backend_config_dict = dict(backend_config)
backend_worker, init_args, old_backend_config_dict = self.backends[
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, dict)
backend_worker, backend_config, replica_config = self.backends[
backend_tag]
if (not old_backend_config_dict["has_accept_batch_annotation"]
and backend_config.max_batch_size is not None):
raise batch_annotation_not_found
self.backends[backend_tag] = (backend_worker, init_args,
backend_config_dict)
# Restart replicas if there is a change in the backend config
# related to restart_configs.
need_to_restart_replicas = any(
old_backend_config_dict[k] != backend_config_dict[k]
for k in BackendConfig.restart_on_change_fields)
if need_to_restart_replicas:
# Kill all the replicas for restarting with new configurations.
self._scale_replicas(backend_tag, 0)
backend_config.update(config_options)
self.backends[backend_tag] = (backend_worker, backend_config,
replica_config)
# Scale the replicas with the new configuration.
self._scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
self._scale_replicas(backend_tag, backend_config.num_replicas)
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
@@ -632,7 +612,7 @@ class ServeMaster:
# Inform the router about change in configuration
# (particularly for setting max_batch_size).
await self.router.set_backend_config.remote(
backend_tag, backend_config_dict)
backend_tag, backend_config)
await self._start_pending_replicas()
await self._stop_pending_replicas()
@@ -641,4 +621,4 @@ class ServeMaster:
"""Get the current config for the specified backend."""
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
return BackendConfig(**self.backends[backend_tag][2])
return self.backends[backend_tag][2]
+6 -7
View File
@@ -163,8 +163,8 @@ class Router:
backend_configs = retry_actor_failures(
master_actor.get_backend_configs)
for backend, backend_config_dict in backend_configs.items():
await self.set_backend_config(backend, backend_config_dict)
for backend, backend_config in backend_configs.items():
await self.set_backend_config(backend, backend_config)
def is_ready(self):
return True
@@ -260,10 +260,10 @@ class Router:
if service in self.traffic:
del self.traffic[service]
async def set_backend_config(self, backend, config_dict):
async def set_backend_config(self, backend, config):
logger.debug("Setting backend config for "
"backend {} to {}".format(backend, config_dict))
self.backend_info[backend] = config_dict
"backend {} to {}.".format(backend, config))
self.backend_info[backend] = config
async def remove_backend(self, backend):
logger.debug("Removing backend {}".format(backend))
@@ -330,8 +330,7 @@ class Router:
max_batch_size = None
if backend in self.backend_info:
max_batch_size = self.backend_info[backend][
"max_batch_size"]
max_batch_size = self.backend_info[backend].max_batch_size
await self._assign_query_to_worker(
backend, buffer_queue, worker_queue, max_batch_size)
+21 -60
View File
@@ -3,7 +3,6 @@ import pytest
import requests
from ray import serve
from ray.serve import BackendConfig
import ray
@@ -30,7 +29,7 @@ def test_e2e(serve_instance):
def function(flask_request):
return {"method": flask_request.method}
serve.create_backend(function, "echo:v1")
serve.create_backend("echo:v1", function)
serve.set_traffic("endpoint", {"echo:v1": 1.0})
resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
@@ -47,7 +46,7 @@ def test_call_method(serve_instance):
def method(self, request):
return "hello"
serve.create_backend(CallMethod, "call-method")
serve.create_backend("call-method", CallMethod)
serve.set_traffic("call-method", {"call-method": 1.0})
# Test HTTP path.
@@ -68,7 +67,7 @@ def test_no_route(serve_instance):
def func(_, i=1):
return 1
serve.create_backend(func, "backend:1")
serve.create_backend("backend:1", func)
serve.set_traffic("noroute-endpoint", {"backend:1": 1.0})
service_handle = serve.get_handle("noroute-endpoint")
result = ray.get(service_handle.remote(i=1))
@@ -93,7 +92,7 @@ def test_set_traffic_missing_data(serve_instance):
endpoint_name = "foobar"
backend_name = "foo_backend"
serve.create_endpoint(endpoint_name)
serve.create_backend(lambda: 5, backend_name)
serve.create_backend(backend_name, lambda: 5)
with pytest.raises(ValueError):
serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
with pytest.raises(ValueError):
@@ -116,8 +115,7 @@ def test_scaling_replicas(serve_instance):
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
b_config = BackendConfig(num_replicas=2)
serve.create_backend(Counter, "counter:v1", backend_config=b_config)
serve.create_backend("counter:v1", Counter, config={"num_replicas": 2})
serve.set_traffic("counter", {"counter:v1": 1.0})
counter_result = []
@@ -128,9 +126,7 @@ def test_scaling_replicas(serve_instance):
# If the load is shared among two replicas. The max result cannot be 10.
assert max(counter_result) < 10
b_config = serve.get_backend_config("counter:v1")
b_config.num_replicas = 1
serve.set_backend_config("counter:v1", b_config)
serve.update_backend_config("counter:v1", {"num_replicas": 1})
counter_result = []
for _ in range(10):
@@ -160,9 +156,8 @@ def test_batching(serve_instance):
time.sleep(0.2)
# set the max batch size
b_config = BackendConfig(max_batch_size=5)
serve.create_backend(
BatchingExample, "counter:v11", backend_config=b_config)
"counter:v11", BatchingExample, config={"max_batch_size": 5})
serve.set_traffic("counter1", {"counter:v11": 1.0})
future_list = []
@@ -190,9 +185,8 @@ def test_batching_exception(serve_instance):
serve.create_endpoint("exception-test", "/noListReturned")
# set the max batch size
b_config = BackendConfig(max_batch_size=5)
serve.create_backend(
NoListReturned, "exception:v1", backend_config=b_config)
"exception:v1", NoListReturned, config={"max_batch_size": 5})
serve.set_traffic("exception-test", {"exception:v1": 1.0})
handle = serve.get_handle("exception-test")
@@ -200,41 +194,7 @@ def test_batching_exception(serve_instance):
assert ray.get(handle.remote(temp=1))
def test_killing_replicas(serve_instance):
class Simple:
def __init__(self):
self.count = 0
def __call__(self, flask_request, temp=None):
return temp
serve.create_endpoint("simple", "/simple")
b_config = BackendConfig(num_replicas=3, num_cpus=2)
serve.create_backend(Simple, "simple:v1", backend_config=b_config)
master_actor = serve.api._get_master_actor()
old_replica_tag_list = ray.get(
master_actor._list_replicas.remote("simple:v1"))
bnew_config = serve.get_backend_config("simple:v1")
# change the config
bnew_config.num_cpus = 1
# set the config
serve.set_backend_config("simple:v1", bnew_config)
new_replica_tag_list = ray.get(
master_actor._list_replicas.remote("simple:v1"))
new_all_tag_list = []
for worker_dict in ray.get(
master_actor.get_all_worker_handles.remote()).values():
new_all_tag_list.extend(list(worker_dict.keys()))
# the new_replica_tag_list must be subset of all_tag_list
assert set(new_replica_tag_list) <= set(new_all_tag_list)
# the old_replica_tag_list must not be subset of all_tag_list
assert not set(old_replica_tag_list) <= set(new_all_tag_list)
def test_not_killing_replicas(serve_instance):
def test_updating_config(serve_instance):
class BatchSimple:
def __init__(self):
self.count = 0
@@ -245,17 +205,18 @@ def test_not_killing_replicas(serve_instance):
return [1] * batch_size
serve.create_endpoint("bsimple", "/bsimple")
b_config = BackendConfig(num_replicas=3, max_batch_size=2)
serve.create_backend(BatchSimple, "bsimple:v1", backend_config=b_config)
serve.create_backend(
"bsimple:v1",
BatchSimple,
config={
"max_batch_size": 2,
"num_replicas": 3
})
master_actor = serve.api._get_master_actor()
old_replica_tag_list = ray.get(
master_actor._list_replicas.remote("bsimple:v1"))
bnew_config = serve.get_backend_config("bsimple:v1")
# change the config
bnew_config.max_batch_size = 5
# set the config
serve.set_backend_config("bsimple:v1", bnew_config)
serve.update_backend_config("bsimple:v1", {"max_batch_size": 5})
new_replica_tag_list = ray.get(
master_actor._list_replicas.remote("bsimple:v1"))
new_all_tag_list = []
@@ -275,7 +236,7 @@ def test_delete_backend(serve_instance):
def function():
return "hello"
serve.create_backend(function, "delete:v1")
serve.create_backend("delete:v1", function)
serve.set_traffic("delete_backend", {"delete:v1": 1.0})
assert requests.get("http://127.0.0.1:8000/delete-backend").text == "hello"
@@ -284,7 +245,7 @@ def test_delete_backend(serve_instance):
with pytest.raises(ValueError):
serve.delete_backend("delete:v1")
serve.create_backend(function, "delete:v2")
serve.create_backend("delete:v2", function)
serve.set_traffic("delete_backend", {"delete:v1": 0.5, "delete:v2": 0.5})
with pytest.raises(ValueError):
@@ -302,7 +263,7 @@ def test_delete_backend(serve_instance):
return "olleh"
# Check that we can now reuse the previously delete backend's tag.
serve.create_backend(function2, "delete:v1")
serve.create_backend("delete:v1", function2)
serve.set_traffic("delete_backend", {"delete:v1": 1.0})
assert requests.get("http://127.0.0.1:8000/delete-backend").text == "olleh"
@@ -320,7 +281,7 @@ def test_delete_endpoint(serve_instance, route):
def function():
return "hello"
serve.create_backend(function, "delete-endpoint:v1")
serve.create_backend("delete-endpoint:v1", function)
serve.set_traffic(endpoint_name, {"delete-endpoint:v1": 1.0})
if route is not None:
@@ -8,7 +8,7 @@ import ray.serve.context as context
from ray.serve.policy import RoundRobinPolicyQueueActor
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.request_params import RequestMetadata
from ray.serve.backend_config import BackendConfig
from ray.serve.config import BackendConfig
pytestmark = pytest.mark.asyncio
@@ -161,7 +161,10 @@ async def test_task_runner_custom_method_batch(serve_instance):
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
await q.set_backend_config.remote(
CONSUMER_NAME, BackendConfig(max_batch_size=10).__dict__)
CONSUMER_NAME,
BackendConfig({
"max_batch_size": 10
}, accepts_batches=True))
a_query_param = RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
+129
View File
@@ -0,0 +1,129 @@
import pytest
from ray import serve
from ray.serve.config import BackendConfig, ReplicaConfig
def test_backend_config_validation():
    """Exercise the validation performed when constructing BackendConfig.

    Covers an unknown option key, plus accepted and rejected values for
    ``num_replicas`` and ``max_batch_size``.
    """
    # An unrecognised key is rejected and named in the error message.
    with pytest.raises(ValueError, match="unknown_key"):
        BackendConfig({"unknown_key": -1})

    # num_replicas: 1 is accepted; a string or a negative value is not.
    BackendConfig({"num_replicas": 1})
    bad_num_replicas = (
        ("hello", TypeError),
        (-1, ValueError),
    )
    for value, expected_error in bad_num_replicas:
        with pytest.raises(expected_error):
            BackendConfig({"num_replicas": value})

    # max_batch_size: the same value passes or fails depending on whether
    # the backend is declared to accept batches.
    BackendConfig({"max_batch_size": 10}, accepts_batches=True)
    with pytest.raises(ValueError):
        BackendConfig({"max_batch_size": 10}, accepts_batches=False)

    # max_batch_size: wrong types and non-positive values are rejected.
    bad_max_batch_size = (
        (1.0, TypeError),
        ("hello", TypeError),
        (0, ValueError),
        (-1, ValueError),
    )
    for value, expected_error in bad_max_batch_size:
        with pytest.raises(expected_error):
            BackendConfig({"max_batch_size": value})
def test_backend_config_update():
    """Exercise BackendConfig.update: partial updates and re-validation."""
    config = BackendConfig({"num_replicas": 1, "max_batch_size": 1})

    # Updating one key changes it and leaves the unspecified key alone.
    config.update({"num_replicas": 2})
    assert config.num_replicas == 2
    assert config.max_batch_size == 1

    # Unknown keys and invalid values are rejected on update too.
    rejected_updates = (
        ({"unknown": 1}, ValueError),
        ({"num_replicas": "hello"}, TypeError),
        ({"num_replicas": -1}, ValueError),
    )
    for options, expected_error in rejected_updates:
        with pytest.raises(expected_error):
            config.update(options)

    # Without batch support, a batch size of 1 is fine but 2 is rejected.
    non_batching = BackendConfig({}, accepts_batches=False)
    non_batching.update({"max_batch_size": 1})
    with pytest.raises(ValueError):
        non_batching.update({"max_batch_size": 2})

    # With batch support, a batch size of 2 is accepted.
    batching = BackendConfig({}, accepts_batches=True)
    batching.update({"max_batch_size": 2})
def test_replica_config_validation():
    """Exercise the validation performed when constructing ReplicaConfig.

    Covers which callables are accepted, how ``accepts_batches`` is derived
    from ``serve.accept_batch``, and ``ray_actor_options`` checking.
    """

    class Class:
        pass

    class BatchClass:
        @serve.accept_batch
        def __call__(self):
            pass

    def function():
        pass

    @serve.accept_batch
    def batch_function():
        pass

    # A class or a function is accepted; an instance of a class is not.
    ReplicaConfig(Class)
    ReplicaConfig(function)
    with pytest.raises(TypeError):
        ReplicaConfig(Class())

    # accepts_batches is set only for callables marked with accept_batch.
    for plain in (function, Class):
        assert not ReplicaConfig(plain).accepts_batches
    for batched in (batch_function, BatchClass):
        assert ReplicaConfig(batched).accepts_batches

    # A fully populated, well-typed ray_actor_options dict is accepted.
    valid_options = {
        "num_cpus": 1.0,
        "num_gpus": 10,
        "resources": {
            "abc": 1.0
        },
        "memory": 1000000.0,
        "object_store_memory": 1000000,
    }
    ReplicaConfig(Class, ray_actor_options=valid_options)

    # Each entry pairs a rejected ray_actor_options value with the
    # exception type the constructor is expected to raise for it.
    rejected_options = (
        (1.0, TypeError),
        (False, TypeError),
        ({"num_cpus": "hello"}, TypeError),
        ({"num_cpus": -1}, ValueError),
        ({"num_gpus": "hello"}, TypeError),
        ({"num_gpus": -1}, ValueError),
        ({"memory": "hello"}, TypeError),
        ({"memory": -1}, ValueError),
        ({"object_store_memory": "hello"}, TypeError),
        ({"object_store_memory": -1}, ValueError),
        ({"resources": None}, TypeError),
        ({"name": None}, ValueError),
        ({"detached": None}, ValueError),
        ({"max_reconstructions": None}, ValueError),
    )
    for options, expected_error in rejected_options:
        with pytest.raises(expected_error):
            ReplicaConfig(Class, ray_actor_options=options)
+10 -12
View File
@@ -26,7 +26,7 @@ def test_master_failure(serve_instance):
def function():
return "hello1"
serve.create_backend(function, "master_failure:v1")
serve.create_backend("master_failure:v1", function)
serve.set_traffic("master_failure", {"master_failure:v1": 1.0})
assert request_with_retries("/master_failure", timeout=1).text == "hello1"
@@ -46,7 +46,7 @@ def test_master_failure(serve_instance):
ray.kill(serve.api._get_master_actor())
serve.create_backend(function, "master_failure:v2")
serve.create_backend("master_failure:v2", function)
serve.set_traffic("master_failure", {"master_failure:v2": 1.0})
for _ in range(10):
@@ -59,7 +59,7 @@ def test_master_failure(serve_instance):
ray.kill(serve.api._get_master_actor())
serve.create_endpoint("master_failure_2", "/master_failure_2")
ray.kill(serve.api._get_master_actor())
serve.create_backend(function, "master_failure_2")
serve.create_backend("master_failure_2", function)
ray.kill(serve.api._get_master_actor())
serve.set_traffic("master_failure_2", {"master_failure_2": 1.0})
@@ -83,7 +83,7 @@ def test_http_proxy_failure(serve_instance):
def function():
return "hello1"
serve.create_backend(function, "proxy_failure:v1")
serve.create_backend("proxy_failure:v1", function)
serve.set_traffic("proxy_failure", {"proxy_failure:v1": 1.0})
assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1"
@@ -97,7 +97,7 @@ def test_http_proxy_failure(serve_instance):
def function():
return "hello2"
serve.create_backend(function, "proxy_failure:v2")
serve.create_backend("proxy_failure:v2", function)
serve.set_traffic("proxy_failure", {"proxy_failure:v2": 1.0})
for _ in range(10):
@@ -117,7 +117,7 @@ def test_router_failure(serve_instance):
def function():
return "hello1"
serve.create_backend(function, "router_failure:v1")
serve.create_backend("router_failure:v1", function)
serve.set_traffic("router_failure", {"router_failure:v1": 1.0})
assert request_with_retries("/router_failure", timeout=5).text == "hello1"
@@ -135,7 +135,7 @@ def test_router_failure(serve_instance):
def function():
return "hello2"
serve.create_backend(function, "router_failure:v2")
serve.create_backend("router_failure:v2", function)
serve.set_traffic("router_failure", {"router_failure:v2": 1.0})
for _ in range(10):
@@ -160,7 +160,7 @@ def test_worker_restart(serve_instance):
def __call__(self):
return os.getpid()
serve.create_backend(Worker1, "worker_failure:v1")
serve.create_backend("worker_failure:v1", Worker1)
serve.set_traffic("worker_failure", {"worker_failure:v1": 1.0})
# Get the PID of the worker.
@@ -214,10 +214,8 @@ def test_worker_replica_failure(serve_instance):
pass
temp_path = tempfile.gettempdir() + "/" + serve.utils.get_random_letters()
serve.create_backend(Worker, "replica_failure", temp_path)
backend_config = serve.get_backend_config("replica_failure")
backend_config.num_replicas = 2
serve.set_backend_config("replica_failure", backend_config)
serve.create_backend("replica_failure", Worker, temp_path)
serve.update_backend_config("replica_failure", {"num_replicas": 2})
serve.set_traffic("replica_failure", {"replica_failure": 1.0})
# Wait until both replicas have been started.
+2 -2
View File
@@ -19,11 +19,11 @@ def test_handle_in_endpoint(serve_instance):
return ray.get(self.handle.remote())
serve.create_endpoint("endpoint1", "/endpoint1", methods=["GET", "POST"])
serve.create_backend(Endpoint1, "endpoint1:v0")
serve.create_backend("endpoint1:v0", Endpoint1)
serve.set_traffic("endpoint1", {"endpoint1:v0": 1.0})
serve.create_endpoint("endpoint2", "/endpoint2", methods=["GET", "POST"])
serve.create_backend(Endpoint2, "endpoint2:v0")
serve.create_backend("endpoint2:v0", Endpoint2)
serve.set_traffic("endpoint2", {"endpoint2:v0": 1.0})
assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
+1 -1
View File
@@ -11,7 +11,7 @@ def test_nonblocking():
def function(flask_request):
return {"method": flask_request.method}
serve.create_backend(function, "nonblocking:v1")
serve.create_backend("nonblocking:v1", function)
serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0})
resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
+1 -1
View File
@@ -18,7 +18,7 @@ def driver(flask_request):
return "OK!"
serve.create_endpoint("driver", "/driver")
serve.create_backend(driver, "driver")
serve.create_backend("driver", driver)
serve.set_traffic("driver", {{"driver": 1.0}})
""".format(ray.worker._global_node._redis_address)