[Serve] Adding BackendConfig (#6541)

This commit is contained in:
alindkhare
2019-12-28 00:34:50 -05:00
committed by Simon Mo
parent 96f2f8ff10
commit a76fadb899
15 changed files with 618 additions and 49 deletions
+6 -5
View File
@@ -1,13 +1,14 @@
import sys
from ray.experimental.serve.backend_config import BackendConfig
from ray.experimental.serve.policy import RoutePolicy
if sys.version_info < (3, 0):
raise ImportError("serve is Python 3 only.")
from ray.experimental.serve.api import (init, create_backend, create_endpoint,
link, split, get_handle, stat,
scale) # noqa: E402
from ray.experimental.serve.api import (
init, create_backend, create_endpoint, link, split, get_handle, stat,
set_backend_config, get_backend_config, accept_batch) # noqa: E402
__all__ = [
"init", "create_backend", "create_endpoint", "link", "split", "get_handle",
"stat", "scale", "RoutePolicy"
"stat", "set_backend_config", "get_backend_config", "BackendConfig",
"RoutePolicy", "accept_batch"
]
+131 -6
View File
@@ -14,6 +14,7 @@ from ray.experimental.serve.task_runner import RayServeMixin, TaskRunnerActor
from ray.experimental.serve.utils import (block_until_http_ready,
get_random_letters)
from ray.experimental.serve.exceptions import RayServeException
from ray.experimental.serve.backend_config import BackendConfig
from ray.experimental.serve.policy import RoutePolicy
global_state = None
@@ -36,6 +37,28 @@ def _ensure_connected(f):
return check
def accept_batch(f):
    """Mark a serving callable as one that accepts batched input.

    The callable is returned as-is, with a ``serve_accept_batch`` attribute
    set to ``True``. A callable marked this way is expected to receive each
    of its arguments as a list.

    Example:
        >>> @serve.accept_batch
        ... def serving_func(flask_request):
        ...     assert isinstance(flask_request, list)

        >>> class ServingActor:
        ...     @serve.accept_batch
        ...     def __call__(self, *, python_arg=None):
        ...         assert isinstance(python_arg, list)
    """
    setattr(f, "serve_accept_batch", True)
    return f
def init(kv_store_connector=None,
kv_store_path=None,
blocking=False,
@@ -126,7 +149,62 @@ def create_endpoint(endpoint_name, route, blocking=True):
@_ensure_connected
def create_backend(func_or_class, backend_tag, *actor_init_args):
def set_backend_config(backend_tag, backend_config):
    """Set a backend configuration for a backend tag.

    The new configuration is stored in the backend table and forwarded to
    the router, then the backend is scaled to ``num_replicas``. When any
    field named in ``BackendConfig.restart_on_change_fields`` differs from
    the previously stored value, the backend is first scaled down to zero
    so that every replica is recreated with the new configuration.

    Args:
        backend_tag(str): A registered backend.
        backend_config(BackendConfig) : Desired backend configuration.

    Raises:
        AssertionError: if the backend is not registered or
            ``backend_config`` is not a ``BackendConfig``.
    """
    assert backend_tag in global_state.backend_table.list_backends(), (
        "Backend {} is not registered.".format(backend_tag))
    assert isinstance(backend_config,
                      BackendConfig), ("backend_config must be"
                                       " of instance BackendConfig")
    backend_config_dict = dict(backend_config)
    old_backend_config_dict = global_state.backend_table.get_info(backend_tag)

    # Decide whether replicas must be restarted *before* anything is
    # written, so a failure here cannot leave a half-applied config.
    # The stored info may lack some keys, hence .get() instead of
    # indexing (a missing key compares as None).
    # TODO(alind) : have replica restarting policies selected by the user
    need_to_restart_replicas = any(
        old_backend_config_dict.get(k) != backend_config_dict.get(k)
        for k in BackendConfig.restart_on_change_fields)

    global_state.backend_table.register_info(backend_tag, backend_config_dict)

    # inform the router about change in configuration
    # particularly for setting max_batch_size
    ray.get(global_state.init_or_get_router().set_backend_config.remote(
        backend_tag, backend_config_dict))

    if need_to_restart_replicas:
        # kill all the replicas for restarting with new configurations
        scale(backend_tag, 0)

    # scale the replicas with new configuration
    scale(backend_tag, backend_config_dict["num_replicas"])
@_ensure_connected
def get_backend_config(backend_tag):
    """Return the stored configuration of a registered backend.

    Args:
        backend_tag(str): A registered backend.

    Returns:
        BackendConfig: built from the info dict kept in the backend table.
    """
    registered_backends = global_state.backend_table.list_backends()
    assert backend_tag in registered_backends, (
        "Backend {} is not registered.".format(backend_tag))
    stored_config = global_state.backend_table.get_info(backend_tag)
    return BackendConfig(**stored_config)
@_ensure_connected
def create_backend(func_or_class,
backend_tag,
*actor_init_args,
backend_config=BackendConfig()):
"""Create a backend using func_or_class and assign backend_tag.
Args:
@@ -134,28 +212,66 @@ def create_backend(func_or_class, backend_tag, *actor_init_args):
__call__ protocol.
backend_tag (str): a unique tag assign to this backend. It will be used
to associate services in traffic policy.
backend_config (BackendConfig): An object defining backend properties
for starting a backend.
*actor_init_args (optional): the argument to pass to the class
initialization method.
"""
assert isinstance(backend_config,
BackendConfig), ("backend_config must be"
" of instance BackendConfig")
backend_config_dict = dict(backend_config)
should_accept_batch = (True if backend_config.max_batch_size is not None
else False)
batch_annotation_not_found = RayServeException(
"max_batch_size is set in config but the function or method does not "
"accept batching. Please use @serve.accept_batch to explicitly mark "
"the function or method as batchable and takes in list as arguments.")
arg_list = []
if inspect.isfunction(func_or_class):
if should_accept_batch and not hasattr(func_or_class,
"serve_accept_batch"):
raise batch_annotation_not_found
# arg list for a fn is function itself
arg_list = [func_or_class]
# ignore lint on lambda expression
creator = lambda: TaskRunnerActor.remote(func_or_class) # noqa: E731
creator = lambda kwrgs: TaskRunnerActor._remote(**kwrgs) # noqa: E731
elif inspect.isclass(func_or_class):
if should_accept_batch and not hasattr(func_or_class.__call__,
"serve_accept_batch"):
raise batch_annotation_not_found
# Python inheritance order is right-to-left. We put RayServeMixin
# on the left to make sure its methods are not overriden.
@ray.remote
class CustomActor(RayServeMixin, func_or_class):
pass
arg_list = actor_init_args
# ignore lint on lambda expression
creator = lambda: CustomActor.remote(*actor_init_args) # noqa: E731
creator = lambda kwargs: CustomActor._remote(**kwargs) # noqa: E731
else:
raise TypeError(
"Backend must be a function or class, it is {}.".format(
type(func_or_class)))
# save creator which starts replicas
global_state.backend_table.register_backend(backend_tag, creator)
scale(backend_tag, 1)
# save information about configurations needed to start the replicas
global_state.backend_table.register_info(backend_tag, backend_config_dict)
# save the initial arguments needed by replicas
global_state.backend_table.save_init_args(backend_tag, arg_list)
# set the backend config inside the router
# particularly for max-batch-size
ray.get(global_state.init_or_get_router().set_backend_config.remote(
backend_tag, backend_config_dict))
scale(backend_tag, backend_config_dict["num_replicas"])
def _start_replica(backend_tag):
@@ -163,12 +279,20 @@ def _start_replica(backend_tag):
"Backend {} is not registered.".format(backend_tag))
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
# get the info which starts the replicas
creator = global_state.backend_table.get_backend_creator(backend_tag)
backend_config_dict = global_state.backend_table.get_info(backend_tag)
backend_config = BackendConfig(**backend_config_dict)
init_args = global_state.backend_table.get_init_args(backend_tag)
# get actor creation kwargs
actor_kwargs = backend_config.get_actor_creation_args(init_args)
# Create the runner in the nursery
[runner_handle] = ray.get(
global_state.actor_nursery_handle.start_actor_with_creator.remote(
creator, replica_tag))
creator, actor_kwargs, replica_tag))
# Setup the worker
ray.get(
@@ -216,7 +340,8 @@ def scale(backend_tag, num_replicas):
"""
assert backend_tag in global_state.backend_table.list_backends(), (
"Backend {} is not registered.".format(backend_tag))
assert num_replicas > 0, "Number of replicas must be greater than 1."
assert num_replicas >= 0, ("Number of replicas must be"
" greater than or equal to 0.")
replicas = global_state.backend_table.list_replicas(backend_tag)
current_num_replicas = len(replicas)
@@ -0,0 +1,58 @@
from copy import deepcopy
class BackendConfig:
    """Configuration of a serve backend and of the actors backing it.

    Iterating over an instance yields ``(name, value)`` pairs, so
    ``dict(config)`` gives a plain dict that can be passed back as
    ``BackendConfig(**d)``.
    """

    # configs not needed for actor creation when
    # instantiating a replica
    _serve_configs = ["_num_replicas", "max_batch_size"]

    # configs which when changed leads to restarting
    # the existing replicas.
    restart_on_change_fields = ["resources", "num_cpus", "num_gpus"]

    def __init__(self,
                 num_replicas=1,
                 resources=None,
                 max_batch_size=None,
                 num_cpus=None,
                 num_gpus=None,
                 memory=None,
                 object_store_memory=None):
        """Store the serve-level and actor-level settings.

        Raises:
            ValueError: if ``num_replicas`` is not greater than zero.
        """
        # serve configs
        self.num_replicas = num_replicas
        self.max_batch_size = max_batch_size

        # ray actor configs
        self.resources = resources
        self.num_cpus = num_cpus
        self.num_gpus = num_gpus
        self.memory = memory
        self.object_store_memory = object_store_memory

    @property
    def num_replicas(self):
        """Number of replicas; always greater than zero."""
        return self._num_replicas

    @num_replicas.setter
    def num_replicas(self, val):
        if not (val > 0):
            # ValueError is still an Exception, so existing handlers
            # that catch Exception keep working.
            raise ValueError("num_replicas must be greater than zero")
        self._num_replicas = val

    def __iter__(self):
        """Yield ``(name, value)`` pairs using public field names."""
        for key, val in self.__dict__.items():
            # the validated value lives in a private attribute
            if key == "_num_replicas":
                key = "num_replicas"
            yield key, val

    def __repr__(self):
        fields = ", ".join("{}={!r}".format(k, v) for k, v in self)
        return "{}({})".format(type(self).__name__, fields)

    def get_actor_creation_args(self, init_args):
        """Return the actor-creation kwargs for one replica.

        The serve-only settings are left out and ``init_args`` is added
        under the ``"args"`` key.
        """
        ret_d = deepcopy(self.__dict__)
        for k in self._serve_configs:
            ret_d.pop(k, None)
        ret_d["args"] = init_args
        return ret_d
+6 -1
View File
@@ -14,6 +14,11 @@ class TaskContext(IntEnum):
# web == False: currently processing a request from python
web = False
# batching information in serve context
# batch_size == None : the backend doesn't support batching
# batch_size(int) : the number of elements of input list
batch_size = None
_not_in_web_context_error = """
Accessing the request object outside of the web context. Please use
"serve.context.web" to determine when the function is called within
@@ -21,7 +26,7 @@ a web context.
"""
class FakeFlaskQuest:
class FakeFlaskRequest:
def __getattribute__(self, name):
raise RayServeException(_not_in_web_context_error)
@@ -21,7 +21,6 @@ class MagicCounter:
def __call__(self, flask_request, base_number=None):
if serve.context.web:
base_number = int(flask_request.args.get("base_number", "0"))
return base_number + self.increment
@@ -0,0 +1,60 @@
"""
Example actor that adds an increment to a number. This number can
come from either web (parsing Flask request) or python call.
The queries incoming to this actor are batched.
This actor can be called from HTTP as well as from Python.
"""
import time
import requests
import ray
from ray.experimental import serve
from ray.experimental.serve.utils import pformat_color_json
from ray.experimental.serve import BackendConfig
class MagicCounter:
    """Adds a fixed increment to every number of a batch of queries."""

    def __init__(self, increment):
        self.increment = increment

    @serve.accept_batch
    def __call__(self, flask_request_list, base_number=None):
        # In batching mode every argument arrives as a list.
        if not serve.context.web:
            # python call: base_number is the list of submitted numbers
            return [number + self.increment for number in base_number]

        # web call: read the number out of each request's query string
        parsed_numbers = [
            int(request.args.get("base_number", "0"))
            for request in flask_request_list
        ]
        return [number + self.increment for number in parsed_numbers]
# Start serve and expose the batched counter under /counter.
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter", blocking=True)
batching_config = BackendConfig(max_batch_size=5)
serve.create_backend(
    MagicCounter, "counter:v1", 42,
    backend_config=batching_config)  # increment=42
serve.link("magic_counter", "counter:v1")

print("Sending ten queries via HTTP")
for base in range(10):
    query_url = "http://127.0.0.1:8000/counter?base_number={}".format(base)
    print("> Pinging {}".format(query_url))
    print(pformat_color_json(requests.get(query_url).json()))
    time.sleep(0.2)

print("Sending ten queries via Python")
handle = serve.get_handle("magic_counter")
for base in range(10):
    print("> Pinging handle.remote(base_number={})".format(base))
    answer = ray.get(handle.remote(base_number=base))
    print("< Result {}".format(answer))
@@ -0,0 +1,56 @@
"""
This example has backend which has batching functionality enabled.
"""
import ray
from ray.experimental import serve
from ray.experimental.serve import BackendConfig
class MagicCounter:
    """Reports, for each query of a batch, its number and the batch size."""

    def __init__(self, increment):
        self.increment = increment

    @serve.accept_batch
    def __call__(self, flask_request, base_number=None):
        # Outside of batching mode there is nothing to report.
        if serve.context.batch_size is None:
            return ""

        # __call__ must return one entry per query in the batch;
        # base_number is a python list here.
        batch_size = serve.context.batch_size
        return [
            "Number: {} Batch size: {}".format(base_num, batch_size)
            for base_num in base_number
        ]
serve.init(blocking=True)
serve.create_endpoint("magic_counter", "/counter", blocking=True)

# specify max_batch_size in BackendConfig
batching_config = BackendConfig(max_batch_size=5)
serve.create_backend(
    MagicCounter, "counter:v1", 42,
    backend_config=batching_config)  # increment=42
print("Backend Config for backend: 'counter:v1'")
print(batching_config)
serve.link("magic_counter", "counter:v1")

handle = serve.get_handle("magic_counter")

# fire 30 requests
pending = []
for r in range(30):
    print("> [REMOTE] Pinging handle.remote(base_number={})".format(r))
    pending.append(handle.remote(base_number=r))

# print results of queries as they complete
while pending:
    finished, pending = ray.wait(pending, timeout=0.05)
    if finished:
        print("< " + ray.get(finished[0]))
@@ -27,6 +27,7 @@ def echo_v1(flask_request, response="hello from python!"):
serve.create_backend(echo_v1, "echo:v1")
backend_config_v1 = serve.get_backend_config("echo:v1")
# We can link an endpoint to a backend, the means all the traffic
# goes to my_endpoint will now goes to echo:v1 backend.
@@ -47,6 +48,7 @@ def echo_v2(flask_request):
serve.create_backend(echo_v2, "echo:v2")
backend_config_v2 = serve.get_backend_config("echo:v2")
# The two backend will now split the traffic 50%-50%.
serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
@@ -56,9 +58,12 @@ for _ in range(10):
print(requests.get("http://127.0.0.1:8000/echo").json())
time.sleep(0.5)
# You can also scale each backend independently.
serve.scale("echo:v1", 2)
serve.scale("echo:v2", 2)
# You can also change number of replicas
# for each backend independently.
backend_config_v1.num_replicas = 2
serve.set_backend_config("echo:v1", backend_config_v1)
backend_config_v2.num_replicas = 2
serve.set_backend_config("echo:v2", backend_config_v2)
# As well as retrieving relevant system metrics
print(pformat_color_json(serve.stat()))
@@ -43,8 +43,14 @@ class ActorNursery:
self.actor_handles[handle] = tag
return [handle]
def start_actor_with_creator(self, creator, tag):
handle = creator()
def start_actor_with_creator(self, creator, kwargs, tag):
    """Create an actor through ``creator`` and record it under ``tag``.

    Args:
        creator (Callable[Dict]): a closure that should return
            a newly created actor handle when called with kwargs.
        kwargs: passed unchanged to ``creator``.
        tag: value stored for the new handle in ``self.actor_handles``.

    Returns:
        A one-element list holding the new handle.
    """
    new_handle = creator(kwargs)
    self.actor_handles[new_handle] = tag
    return [new_handle]
@@ -212,11 +212,26 @@ class BackendTable:
def __init__(self, kv_connector):
    # One key-value table per kind of backend data; each is obtained by
    # calling kv_connector with the table's namespace.
    table_namespaces = (
        ("backend_table", "backend_creator"),
        ("replica_table", "replica_table"),
        ("backend_info", "backend_info"),
        ("backend_init_args", "backend_init_args"),
    )
    for attribute, namespace in table_namespaces:
        setattr(self, attribute, kv_connector(namespace))
def register_backend(self, backend_tag: str, backend_creator):
    """Pickle ``backend_creator`` and store it under ``backend_tag``."""
    self.backend_table.put(backend_tag, pickle.dumps(backend_creator))
def save_init_args(self, backend_tag: str, arg_list):
    """Pickle ``arg_list`` and store it under ``backend_tag``."""
    self.backend_init_args.put(backend_tag, pickle.dumps(arg_list))
def get_init_args(self, backend_tag):
    """Return the unpickled init args stored under ``backend_tag``."""
    # NOTE(review): pickle.loads trusts whatever the kv store returns.
    serialized_arg_list = self.backend_init_args.get(backend_tag)
    return pickle.loads(serialized_arg_list)
def register_info(self, backend_tag: str, backend_info_d):
    """Store ``backend_info_d`` as JSON text under ``backend_tag``."""
    info_as_json = json.dumps(backend_info_d)
    self.backend_info.put(backend_tag, info_as_json)
def get_info(self, backend_tag):
    """Return the info dict stored under ``backend_tag``.

    The lookup falls back to the JSON text ``"{}"``, so a tag without
    stored info yields an empty dict.
    """
    info_as_json = self.backend_info.get(backend_tag, "{}")
    return json.loads(info_as_json)
def get_backend_creator(self, backend_tag):
    """Return the unpickled creator stored under ``backend_tag``."""
    # NOTE(review): pickle.loads trusts whatever the kv store returns.
    serialized_creator = self.backend_table.get(backend_tag)
    return pickle.loads(serialized_creator)
+24 -4
View File
@@ -83,6 +83,9 @@ class CentralizedQueues:
# service_name -> traffic_policy
self.traffic = defaultdict(dict)
# backend_name -> backend_config
self.backend_info = dict()
# backend_name -> worker request queue
self.workers = defaultdict(deque)
@@ -157,6 +160,11 @@ class CentralizedQueues:
self.traffic[service] = traffic_dict
self.flush()
def set_backend_config(self, backend, config_dict):
    """Remember ``config_dict`` as the configuration of ``backend``."""
    debug_message = "Setting backend config for backend {} to {}".format(
        backend, config_dict)
    logger.debug(debug_message)
    self.backend_info[backend] = config_dict
def flush(self):
"""In the default case, flush calls ._flush.
@@ -184,11 +192,23 @@ class CentralizedQueues:
buffer_queue = self.buffer_queues[backend]
work_queue = self.workers[backend]
max_batch_size = None
if backend in self.backend_info:
max_batch_size = self.backend_info[backend][
"max_batch_size"]
while len(buffer_queue) and len(work_queue):
request, work = (
buffer_queue.pop(0),
work_queue.popleft(),
)
# get the work from work intent queue
work = work_queue.popleft()
# see if backend accepts batched queries
if max_batch_size is not None:
pop_size = min(len(buffer_queue), max_batch_size)
request = [
buffer_queue.pop(0) for _ in range(pop_size)
]
else:
request = buffer_queue.pop(0)
work.replica_handle._ray_serve_call.remote(request)
# selects the backend and puts the service queue query to the buffer
+92 -21
View File
@@ -1,11 +1,12 @@
import io
import time
import traceback
import ray
from ray.experimental.serve import context as serve_context
from ray.experimental.serve.context import FakeFlaskQuest, TaskContext
from ray.experimental.serve.http_util import build_flask_request
from ray.experimental.serve.context import FakeFlaskRequest
from collections import defaultdict
from ray.experimental.serve.utils import parse_request_item
from ray.experimental.serve.exceptions import RayServeException
class TaskRunner:
@@ -93,23 +94,10 @@ class RayServeMixin:
self._ray_serve_dequeue_requester_name,
self._ray_serve_self_handle)
def _ray_serve_call(self, request):
work_item = request
if work_item.request_context == TaskContext.Web:
serve_context.web = True
asgi_scope, body_bytes = work_item.request_args
flask_request = build_flask_request(asgi_scope,
io.BytesIO(body_bytes))
args = (flask_request, )
kwargs = {}
else:
serve_context.web = False
args = (FakeFlaskQuest(), )
kwargs = work_item.request_kwargs
result_object_id = work_item.result_object_id
def invoke_single(self, request_item):
args, kwargs, is_web_context, result_object_id = parse_request_item(
request_item)
serve_context.web = is_web_context
start_timestamp = time.time()
try:
result = self.__call__(*args, **kwargs)
@@ -121,8 +109,91 @@ class RayServeMixin:
result_object_id)
self._serve_metric_latency_list.append(time.time() - start_timestamp)
serve_context.web = False
def invoke_batch(self, request_item_list):
    """Run one batch of requests through a single ``__call__``.

    Every item is parsed, the per-request arguments are merged into
    lists, ``self.__call__`` is invoked once, and one result is stored
    per request. If anything raises (mixed contexts, the call itself, or
    a result that is not a list of the batch's length), the wrapped
    exception is stored for every request of the batch instead and the
    error counter grows by the batch size.

    Args:
        request_item_list (list): the request items forming the batch.
    """
    # TODO(alind) : create no-http services. The enqueues
    # from such services will always be TaskContext.Python.

    # Assumption : all the requests in a batch
    # have same serve context.

    # For batching kwargs are modified as follows -
    # kwargs [Python Context] : key,val
    # kwargs_list : key, [val1,val2, ... , valn]
    # or
    # args[Web Context] : val
    # args_list : [val1,val2, ...... , valn]
    # where n (current batch size) <= max_batch_size of a backend

    kwargs_list = defaultdict(list)
    result_object_ids, context_flag_list, arg_list = [], [], []
    curr_batch_size = len(request_item_list)

    for item in request_item_list:
        args, kwargs, is_web_context, result_object_id = (
            parse_request_item(item))
        context_flag_list.append(is_web_context)

        # Python context only have kwargs
        # Web context only have one positional argument
        if is_web_context:
            arg_list.append(args[0])
        else:
            for k, v in kwargs.items():
                kwargs_list[k].append(v)
        result_object_ids.append(result_object_id)

    try:
        # check mixing of query context
        # unified context needed
        if len(set(context_flag_list)) != 1:
            raise RayServeException(
                "Batched queries contain mixed context.")
        # all flags are equal at this point, so all() is the shared flag
        serve_context.web = all(context_flag_list)
        if serve_context.web:
            args = (arg_list, )
        else:
            # Set the flask request as a list to conform
            # with batching semantics: when in batching
            # mode, each argument is turned into a list.
            fake_flask_request_lst = [
                FakeFlaskRequest() for _ in range(curr_batch_size)
            ]
            args = (fake_flask_request_lst, )
        # set the current batch size (n) for serve_context
        serve_context.batch_size = len(result_object_ids)
        start_timestamp = time.time()
        result_list = self.__call__(*args, **kwargs_list)
        # the call must hand back exactly one result per request
        if (not isinstance(result_list, list)) or (len(result_list) !=
                                                   len(result_object_ids)):
            raise RayServeException("__call__ function "
                                    "doesn't preserve batch-size. "
                                    "Please return a list of result "
                                    "with length equals to the batch "
                                    "size.")
        for result, result_object_id in zip(result_list,
                                            result_object_ids):
            ray.worker.global_worker.put_object(result, result_object_id)
        # one latency sample is recorded for the whole batch
        self._serve_metric_latency_list.append(time.time() -
                                               start_timestamp)
    except Exception as e:
        # every request of the batch receives the same wrapped error
        wrapped_exception = wrap_to_ray_error(e)
        self._serve_metric_error_counter += len(result_object_ids)
        for result_object_id in result_object_ids:
            ray.worker.global_worker.put_object(wrapped_exception,
                                                result_object_id)
def _ray_serve_call(self, request):
    """Serve ``request``, reset the serve context, then fetch again."""
    # A list is treated as a batch of work items, anything else as a
    # single work item.
    if isinstance(request, list):
        self.invoke_batch(request)
    else:
        self.invoke_single(request)

    # re-assign to default values
    serve_context.web = False
    serve_context.batch_size = None

    self._ray_serve_fetch()
@@ -10,7 +10,10 @@ from ray.experimental import serve
@pytest.fixture(scope="session")
def serve_instance():
_, new_db_path = tempfile.mkstemp(suffix=".test.db")
serve.init(kv_store_path=new_db_path, blocking=True)
serve.init(
kv_store_path=new_db_path,
blocking=True,
ray_init_kwargs={"num_cpus": 36})
yield
os.remove(new_db_path)
+131 -5
View File
@@ -1,8 +1,10 @@
import time
import pytest
import requests
from ray.experimental import serve
from ray.experimental.serve import BackendConfig
import ray
def test_e2e(serve_instance):
@@ -50,11 +52,10 @@ def test_scaling_replicas(serve_instance):
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
time.sleep(0.2)
serve.create_backend(Counter, "counter:v1")
b_config = BackendConfig(num_replicas=2)
serve.create_backend(Counter, "counter:v1", backend_config=b_config)
serve.link("counter", "counter:v1")
serve.scale("counter:v1", 2)
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
@@ -63,7 +64,9 @@ def test_scaling_replicas(serve_instance):
# If the load is shared among two replicas. The max result cannot be 10.
assert max(counter_result) < 10
serve.scale("counter:v1", 1)
b_config = serve.get_backend_config("counter:v1")
b_config.num_replicas = 1
serve.set_backend_config("counter:v1", b_config)
counter_result = []
for _ in range(10):
@@ -72,3 +75,126 @@ def test_scaling_replicas(serve_instance):
# Give some time for a replica to spin down. But majority of the request
# should be served by the only remaining replica.
assert max(counter_result) - min(counter_result) > 6
def test_batching(serve_instance):
    """Queries sent together should share at least one __call__."""

    class BatchingExample:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            # count grows once per call, not once per query
            self.count += 1
            return [self.count] * serve.context.batch_size

    serve.create_endpoint("counter1", "/increment")

    # Keep checking the routing table until /increment is populated
    while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
        time.sleep(0.2)

    # set the max batch size
    serve.create_backend(
        BatchingExample,
        "counter:v11",
        backend_config=BackendConfig(max_batch_size=5))
    serve.link("counter1", "counter:v11")

    handle = serve.get_handle("counter1")
    pending = [handle.remote(temp=1) for _ in range(20)]
    counter_result = ray.get(pending)

    # since count is only updated per batch of queries,
    # if there is at least one __call__ with batch size greater than 1
    # the largest counter value stays below 20
    assert max(counter_result) < 20
def test_batching_exception(serve_instance):
    """A batched backend that does not return a list must fail the query."""

    class NoListReturned:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            # deliberately returns an int instead of a list
            batch_size = serve.context.batch_size
            return batch_size

    serve.create_endpoint("exception-test", "/noListReturned")

    # set the max batch size
    b_config = BackendConfig(max_batch_size=5)
    serve.create_backend(
        NoListReturned, "exception:v1", backend_config=b_config)
    serve.link("exception-test", "exception:v1")

    handle = serve.get_handle("exception-test")
    with pytest.raises(ray.exceptions.RayTaskError):
        # No assert here: only the raised error matters, and an assert
        # would turn a missing error into an unrelated AssertionError.
        ray.get(handle.remote(temp=1))
def test_killing_replicas(serve_instance):
    """After num_cpus changes, the old replica set is not fully alive."""

    class Simple:
        def __init__(self):
            self.count = 0

        def __call__(self, flask_request, temp=None):
            return temp

    backend_tag = "simple:v1"
    serve.create_endpoint("simple", "/simple")
    serve.create_backend(
        Simple,
        backend_tag,
        backend_config=BackendConfig(num_replicas=3, num_cpus=2))
    state = serve.api._get_global_state()
    replicas_before = set(state.backend_table.list_replicas(backend_tag))

    # change the config and set it
    updated_config = serve.get_backend_config(backend_tag)
    updated_config.num_cpus = 1
    serve.set_backend_config(backend_tag, updated_config)

    replicas_after = set(state.backend_table.list_replicas(backend_tag))
    state.refresh_actor_handle_cache()
    live_tags = set(state.actor_handle_cache.keys())

    # the new replicas must all be among the live tags
    assert replicas_after.issubset(live_tags)
    # the old replicas must not all be among the live tags
    assert not replicas_before.issubset(live_tags)
def test_not_killing_replicas(serve_instance):
    """Changing only max_batch_size must keep the same replicas."""

    class BatchSimple:
        def __init__(self):
            self.count = 0

        @serve.accept_batch
        def __call__(self, flask_request, temp=None):
            return [1] * serve.context.batch_size

    backend_tag = "bsimple:v1"
    serve.create_endpoint("bsimple", "/bsimple")
    serve.create_backend(
        BatchSimple,
        backend_tag,
        backend_config=BackendConfig(num_replicas=3, max_batch_size=2))
    state = serve.api._get_global_state()
    replicas_before = set(state.backend_table.list_replicas(backend_tag))

    # change the config and set it
    updated_config = serve.get_backend_config(backend_tag)
    updated_config.max_batch_size = 5
    serve.set_backend_config(backend_tag, updated_config)

    replicas_after = set(state.backend_table.list_replicas(backend_tag))
    state.refresh_actor_handle_cache()
    live_tags = set(state.actor_handle_cache.keys())

    # the old and new replica tags should be identical
    # and should all be among the live tags
    assert replicas_before.issubset(live_tags)
    assert replicas_before == replicas_after
+19
View File
@@ -3,9 +3,28 @@ import logging
import random
import string
import time
import io
import requests
from pygments import formatters, highlight, lexers
from ray.experimental.serve.context import FakeFlaskRequest, TaskContext
from ray.experimental.serve.http_util import build_flask_request
def parse_request_item(request_item):
    """Unpack a queued request item into call arguments.

    Returns:
        tuple: ``(args, kwargs, is_web_context, result_object_id)``. For a
        web request ``args`` holds the rebuilt flask request and
        ``kwargs`` is empty; otherwise ``args`` holds a
        ``FakeFlaskRequest`` and ``kwargs`` are the request's kwargs.
    """
    if request_item.request_context == TaskContext.Web:
        asgi_scope, body_bytes = request_item.request_args
        flask_request = build_flask_request(asgi_scope,
                                            io.BytesIO(body_bytes))
        return (flask_request, ), {}, True, request_item.result_object_id

    placeholder_request = FakeFlaskRequest()
    python_kwargs = request_item.request_kwargs
    return ((placeholder_request, ), python_kwargs, False,
            request_item.result_object_id)
def _get_logger():