From a76fadb89945becefc9a8973182bfd2b7969f50d Mon Sep 17 00:00:00 2001 From: alindkhare <55404298+alindkhare@users.noreply.github.com> Date: Sat, 28 Dec 2019 00:34:50 -0500 Subject: [PATCH] [Serve] Adding BackendConfig (#6541) --- python/ray/experimental/serve/__init__.py | 11 +- python/ray/experimental/serve/api.py | 137 +++++++++++++++++- .../ray/experimental/serve/backend_config.py | 58 ++++++++ python/ray/experimental/serve/context.py | 7 +- .../experimental/serve/examples/echo_actor.py | 1 - .../serve/examples/echo_actor_batch.py | 60 ++++++++ .../serve/examples/echo_batching.py | 56 +++++++ .../experimental/serve/examples/echo_full.py | 11 +- python/ray/experimental/serve/global_state.py | 10 +- .../experimental/serve/kv_store_service.py | 15 ++ python/ray/experimental/serve/queues.py | 28 +++- python/ray/experimental/serve/task_runner.py | 113 ++++++++++++--- .../ray/experimental/serve/tests/conftest.py | 5 +- .../ray/experimental/serve/tests/test_api.py | 136 ++++++++++++++++- python/ray/experimental/serve/utils.py | 19 +++ 15 files changed, 618 insertions(+), 49 deletions(-) create mode 100644 python/ray/experimental/serve/backend_config.py create mode 100644 python/ray/experimental/serve/examples/echo_actor_batch.py create mode 100644 python/ray/experimental/serve/examples/echo_batching.py diff --git a/python/ray/experimental/serve/__init__.py b/python/ray/experimental/serve/__init__.py index 05d391876..4d86ca222 100644 --- a/python/ray/experimental/serve/__init__.py +++ b/python/ray/experimental/serve/__init__.py @@ -1,13 +1,14 @@ import sys +from ray.experimental.serve.backend_config import BackendConfig from ray.experimental.serve.policy import RoutePolicy if sys.version_info < (3, 0): raise ImportError("serve is Python 3 only.") -from ray.experimental.serve.api import (init, create_backend, create_endpoint, - link, split, get_handle, stat, - scale) # noqa: E402 - +from ray.experimental.serve.api import ( + init, create_backend, create_endpoint, link, split, get_handle, stat, + set_backend_config, get_backend_config, accept_batch) # noqa: E402 __all__ = [ "init", "create_backend", "create_endpoint", "link", "split", "get_handle", - "stat", "scale", "RoutePolicy" + "stat", "set_backend_config", "get_backend_config", "BackendConfig", + "RoutePolicy", "accept_batch" ] diff --git a/python/ray/experimental/serve/api.py b/python/ray/experimental/serve/api.py index f810b3539..e1ed0e2bd 100644 --- a/python/ray/experimental/serve/api.py +++ b/python/ray/experimental/serve/api.py @@ -14,6 +14,7 @@ from ray.experimental.serve.task_runner import RayServeMixin, TaskRunnerActor from ray.experimental.serve.utils import (block_until_http_ready, get_random_letters) from ray.experimental.serve.exceptions import RayServeException +from ray.experimental.serve.backend_config import BackendConfig from ray.experimental.serve.policy import RoutePolicy global_state = None @@ -36,6 +37,28 @@ def _ensure_connected(f): return check +def accept_batch(f): + """Annotation to mark a serving function that batch is accepted. + + This annotation need to be used to mark a function expect all arguments + to be passed into a list. + + Example: + + >>> @serve.accept_batch + def serving_func(flask_request): + assert isinstance(flask_request, list) + ... + + >>> class ServingActor: + @serve.accept_batch + def __call__(self, *, python_arg=None): + assert isinstance(python_arg, list) + """ + f.serve_accept_batch = True + return f + + def init(kv_store_connector=None, kv_store_path=None, blocking=False, @@ -126,7 +149,62 @@ def create_endpoint(endpoint_name, route, blocking=True): @_ensure_connected -def create_backend(func_or_class, backend_tag, *actor_init_args): +def set_backend_config(backend_tag, backend_config): + """Set a backend configuration for a backend tag + + Args: + backend_tag(str): A registered backend. + backend_config(BackendConfig) : Desired backend configuration. + """ + assert backend_tag in global_state.backend_table.list_backends(), ( + "Backend {} is not registered.".format(backend_tag)) + assert isinstance(backend_config, + BackendConfig), ("backend_config must be" + " of instance BackendConfig") + backend_config_dict = dict(backend_config) + + old_backend_config_dict = global_state.backend_table.get_info(backend_tag) + global_state.backend_table.register_info(backend_tag, backend_config_dict) + + # inform the router about change in configuration + # particularly for setting max_batch_size + ray.get(global_state.init_or_get_router().set_backend_config.remote( + backend_tag, backend_config_dict)) + + # checking if replicas need to be restarted + # Replicas are restarted if there is any change in the backend config + # related to restart_configs + # TODO(alind) : have replica restarting policies selected by the user + + need_to_restart_replicas = any( + old_backend_config_dict[k] != backend_config_dict[k] + for k in BackendConfig.restart_on_change_fields) + if need_to_restart_replicas: + # kill all the replicas for restarting with new configurations + scale(backend_tag, 0) + + # scale the replicas with new configuration + scale(backend_tag, backend_config_dict["num_replicas"]) + + +@_ensure_connected +def get_backend_config(backend_tag): + """get the backend configuration for a backend tag + + Args: + backend_tag(str): A registered backend. + """ + assert backend_tag in global_state.backend_table.list_backends(), ( + "Backend {} is not registered.".format(backend_tag)) + backend_config_dict = global_state.backend_table.get_info(backend_tag) + return BackendConfig(**backend_config_dict) + + +@_ensure_connected +def create_backend(func_or_class, + backend_tag, + *actor_init_args, + backend_config=BackendConfig()): """Create a backend using func_or_class and assign backend_tag. Args: @@ -134,28 +212,66 @@ def create_backend(func_or_class, backend_tag, *actor_init_args): __call__ protocol. backend_tag (str): a unique tag assign to this backend. It will be used to associate services in traffic policy. + backend_config (BackendConfig): An object defining backend properties + for starting a backend. *actor_init_args (optional): the argument to pass to the class initialization method. """ + assert isinstance(backend_config, + BackendConfig), ("backend_config must be" + " of instance BackendConfig") + backend_config_dict = dict(backend_config) + + should_accept_batch = (True if backend_config.max_batch_size is not None + else False) + batch_annotation_not_found = RayServeException( + "max_batch_size is set in config but the function or method does not " + "accept batching. Please use @serve.accept_batch to explicitly mark " + "the function or method as batchable and takes in list as arguments.") + + arg_list = [] if inspect.isfunction(func_or_class): + if should_accept_batch and not hasattr(func_or_class, + "serve_accept_batch"): + raise batch_annotation_not_found + + # arg list for a fn is function itself + arg_list = [func_or_class] # ignore lint on lambda expression - creator = lambda: TaskRunnerActor.remote(func_or_class) # noqa: E731 + creator = lambda kwrgs: TaskRunnerActor._remote(**kwrgs) # noqa: E731 elif inspect.isclass(func_or_class): + if should_accept_batch and not hasattr(func_or_class.__call__, + "serve_accept_batch"): + raise batch_annotation_not_found + # Python inheritance order is right-to-left. We put RayServeMixin # on the left to make sure its methods are not overriden. @ray.remote class CustomActor(RayServeMixin, func_or_class): pass + arg_list = actor_init_args # ignore lint on lambda expression - creator = lambda: CustomActor.remote(*actor_init_args) # noqa: E731 + creator = lambda kwargs: CustomActor._remote(**kwargs) # noqa: E731 else: raise TypeError( "Backend must be a function or class, it is {}.".format( type(func_or_class))) + # save creator which starts replicas global_state.backend_table.register_backend(backend_tag, creator) - scale(backend_tag, 1) + + # save information about configurations needed to start the replicas + global_state.backend_table.register_info(backend_tag, backend_config_dict) + + # save the initial arguments needed by replicas + global_state.backend_table.save_init_args(backend_tag, arg_list) + + # set the backend config inside the router + # particularly for max-batch-size + ray.get(global_state.init_or_get_router().set_backend_config.remote( + backend_tag, backend_config_dict)) + scale(backend_tag, backend_config_dict["num_replicas"]) def _start_replica(backend_tag): @@ -163,12 +279,20 @@ def _start_replica(backend_tag): "Backend {} is not registered.".format(backend_tag)) replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6)) + + # get the info which starts the replicas creator = global_state.backend_table.get_backend_creator(backend_tag) + backend_config_dict = global_state.backend_table.get_info(backend_tag) + backend_config = BackendConfig(**backend_config_dict) + init_args = global_state.backend_table.get_init_args(backend_tag) + + # get actor creation kwargs + actor_kwargs = backend_config.get_actor_creation_args(init_args) # Create the runner in the nursery [runner_handle] = ray.get( global_state.actor_nursery_handle.start_actor_with_creator.remote( - creator, replica_tag)) + creator, actor_kwargs, replica_tag)) # Setup the worker ray.get( @@ -216,7 +340,8 @@ def scale(backend_tag, num_replicas): """ assert backend_tag in global_state.backend_table.list_backends(), ( "Backend {} is not registered.".format(backend_tag)) - assert num_replicas > 0, "Number of replicas must be greater than 1." + assert num_replicas >= 0, ("Number of replicas must be" + " greater than or equal to 0.") replicas = global_state.backend_table.list_replicas(backend_tag) current_num_replicas = len(replicas) diff --git a/python/ray/experimental/serve/backend_config.py b/python/ray/experimental/serve/backend_config.py new file mode 100644 index 000000000..d4cde75f5 --- /dev/null +++ b/python/ray/experimental/serve/backend_config.py @@ -0,0 +1,58 @@ +from copy import deepcopy + + +class BackendConfig: + # configs not needed for actor creation when + # instantiating a replica + _serve_configs = ["_num_replicas", "max_batch_size"] + + # configs which when changed leads to restarting + # the existing replicas. + restart_on_change_fields = ["resources", "num_cpus", "num_gpus"] + + def __init__(self, + num_replicas=1, + resources=None, + max_batch_size=None, + num_cpus=None, + num_gpus=None, + memory=None, + object_store_memory=None): + """ + Class for defining backend configuration. + """ + + # serve configs + self.num_replicas = num_replicas + self.max_batch_size = max_batch_size + + # ray actor configs + self.resources = resources + self.num_cpus = num_cpus + self.num_gpus = num_gpus + self.memory = memory + self.object_store_memory = object_store_memory + + @property + def num_replicas(self): + return self._num_replicas + + @num_replicas.setter + def num_replicas(self, val): + if not (val > 0): + raise Exception("num_replicas must be greater than zero") + self._num_replicas = val + + def __iter__(self): + for k in self.__dict__.keys(): + key, val = k, self.__dict__[k] + if key == "_num_replicas": + key = "num_replicas" + yield key, val + + def get_actor_creation_args(self, init_args): + ret_d = deepcopy(self.__dict__) + for k in self._serve_configs: + ret_d.pop(k) + ret_d["args"] = init_args + return ret_d diff --git a/python/ray/experimental/serve/context.py b/python/ray/experimental/serve/context.py index 3221bfcc6..25ccf0329 100644 --- a/python/ray/experimental/serve/context.py +++ b/python/ray/experimental/serve/context.py @@ -14,6 +14,11 @@ class TaskContext(IntEnum): # web == False: currently processing a request from python web = False +# batching information in serve context +# batch_size == None : the backend doesn't support batching +# batch_size(int) : the number of elements of input list +batch_size = None + _not_in_web_context_error = """ Accessing the request object outside of the web context. Please use "serve.context.web" to determine when the function is called within @@ -21,7 +26,7 @@ a web context. """ -class FakeFlaskQuest: +class FakeFlaskRequest: def __getattribute__(self, name): raise RayServeException(_not_in_web_context_error) diff --git a/python/ray/experimental/serve/examples/echo_actor.py b/python/ray/experimental/serve/examples/echo_actor.py index 1e279ceb4..49ed23fe7 100644 --- a/python/ray/experimental/serve/examples/echo_actor.py +++ b/python/ray/experimental/serve/examples/echo_actor.py @@ -21,7 +21,6 @@ class MagicCounter: def __call__(self, flask_request, base_number=None): if serve.context.web: base_number = int(flask_request.args.get("base_number", "0")) - return base_number + self.increment diff --git a/python/ray/experimental/serve/examples/echo_actor_batch.py b/python/ray/experimental/serve/examples/echo_actor_batch.py new file mode 100644 index 000000000..51844e6b6 --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_actor_batch.py @@ -0,0 +1,60 @@ +""" +Example actor that adds an increment to a number. This number can +come from either web (parsing Flask request) or python call. +The queries incoming to this actor are batched. +This actor can be called from HTTP as well as from Python. +""" + +import time + +import requests + +import ray +from ray.experimental import serve +from ray.experimental.serve.utils import pformat_color_json +from ray.experimental.serve import BackendConfig + + +class MagicCounter: + def __init__(self, increment): + self.increment = increment + + @serve.accept_batch + def __call__(self, flask_request_list, base_number=None): + # batch_size = serve.context.batch_size + if serve.context.web: + result = [] + for flask_request in flask_request_list: + base_number = int(flask_request.args.get("base_number", "0")) + result.append(base_number) + return list(map(lambda x: x + self.increment, result)) + else: + result = [] + for b in base_number: + ans = b + self.increment + result.append(ans) + return result + + +serve.init(blocking=True) +serve.create_endpoint("magic_counter", "/counter", blocking=True) +b_config = BackendConfig(max_batch_size=5) +serve.create_backend( + MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42 +serve.link("magic_counter", "counter:v1") + +print("Sending ten queries via HTTP") +for i in range(10): + url = "http://127.0.0.1:8000/counter?base_number={}".format(i) + print("> Pinging {}".format(url)) + resp = requests.get(url).json() + print(pformat_color_json(resp)) + + time.sleep(0.2) + +print("Sending ten queries via Python") +handle = serve.get_handle("magic_counter") +for i in range(10): + print("> Pinging handle.remote(base_number={})".format(i)) + result = ray.get(handle.remote(base_number=i)) + print("< Result {}".format(result)) diff --git a/python/ray/experimental/serve/examples/echo_batching.py b/python/ray/experimental/serve/examples/echo_batching.py new file mode 100644 index 000000000..0622c2efe --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_batching.py @@ -0,0 +1,56 @@ +""" +This example has backend which has batching functionality enabled. +""" + +import ray +from ray.experimental import serve +from ray.experimental.serve import BackendConfig + + +class MagicCounter: + def __init__(self, increment): + self.increment = increment + + @serve.accept_batch + def __call__(self, flask_request, base_number=None): + # __call__ fn should preserve the batch size + # base_number is a python list + + if serve.context.batch_size is not None: + batch_size = serve.context.batch_size + result = [] + for base_num in base_number: + ret_str = "Number: {} Batch size: {}".format( + base_num, batch_size) + result.append(ret_str) + return result + return "" + + +serve.init(blocking=True) +serve.create_endpoint("magic_counter", "/counter", blocking=True) +# specify max_batch_size in BackendConfig +b_config = BackendConfig(max_batch_size=5) +serve.create_backend( + MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42 +print("Backend Config for backend: 'counter:v1'") +print(b_config) +serve.link("magic_counter", "counter:v1") + +handle = serve.get_handle("magic_counter") +future_list = [] + +# fire 30 requests +for r in range(30): + print("> [REMOTE] Pinging handle.remote(base_number={})".format(r)) + f = handle.remote(base_number=r) + future_list.append(f) + +# get results of queries as they complete +left_futures = future_list +while left_futures: + completed_futures, remaining_futures = ray.wait(left_futures, timeout=0.05) + if len(completed_futures) > 0: + result = ray.get(completed_futures[0]) + print("< " + result) + left_futures = remaining_futures diff --git a/python/ray/experimental/serve/examples/echo_full.py b/python/ray/experimental/serve/examples/echo_full.py index 3d6e60276..ba3e36618 100644 --- a/python/ray/experimental/serve/examples/echo_full.py +++ b/python/ray/experimental/serve/examples/echo_full.py @@ -27,6 +27,7 @@ def echo_v1(flask_request, response="hello from python!"): serve.create_backend(echo_v1, "echo:v1") +backend_config_v1 = serve.get_backend_config("echo:v1") # We can link an endpoint to a backend, the means all the traffic # goes to my_endpoint will now goes to echo:v1 backend. @@ -47,6 +48,7 @@ def echo_v2(flask_request): serve.create_backend(echo_v2, "echo:v2") +backend_config_v2 = serve.get_backend_config("echo:v2") # The two backend will now split the traffic 50%-50%. serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) @@ -56,9 +58,12 @@ for _ in range(10): print(requests.get("http://127.0.0.1:8000/echo").json()) time.sleep(0.5) -# You can also scale each backend independently. -serve.scale("echo:v1", 2) -serve.scale("echo:v2", 2) +# You can also change number of replicas +# for each backend independently. +backend_config_v1.num_replicas = 2 +serve.set_backend_config("echo:v1", backend_config_v1) +backend_config_v2.num_replicas = 2 +serve.set_backend_config("echo:v2", backend_config_v2) # As well as retrieving relevant system metrics print(pformat_color_json(serve.stat())) diff --git a/python/ray/experimental/serve/global_state.py b/python/ray/experimental/serve/global_state.py index d60a80dec..a4c78e9e2 100644 --- a/python/ray/experimental/serve/global_state.py +++ b/python/ray/experimental/serve/global_state.py @@ -43,8 +43,14 @@ class ActorNursery: self.actor_handles[handle] = tag return [handle] - def start_actor_with_creator(self, creator, tag): - handle = creator() + def start_actor_with_creator(self, creator, kwargs, tag): + """ + Args: + creator (Callable[Dict]): a closure that should return + a newly created actor handle when called with kwargs. + The kwargs input is passed to `ActorCls_remote` method. + """ + handle = creator(kwargs) self.actor_handles[handle] = tag return [handle] diff --git a/python/ray/experimental/serve/kv_store_service.py b/python/ray/experimental/serve/kv_store_service.py index 05f8cb138..83aacc5f9 100644 --- a/python/ray/experimental/serve/kv_store_service.py +++ b/python/ray/experimental/serve/kv_store_service.py @@ -212,11 +212,26 @@ class BackendTable: def __init__(self, kv_connector): self.backend_table = kv_connector("backend_creator") self.replica_table = kv_connector("replica_table") + self.backend_info = kv_connector("backend_info") + self.backend_init_args = kv_connector("backend_init_args") def register_backend(self, backend_tag: str, backend_creator): backend_creator_serialized = pickle.dumps(backend_creator) self.backend_table.put(backend_tag, backend_creator_serialized) + def save_init_args(self, backend_tag: str, arg_list): + serialized_arg_list = pickle.dumps(arg_list) + self.backend_init_args.put(backend_tag, serialized_arg_list) + + def get_init_args(self, backend_tag): + return pickle.loads(self.backend_init_args.get(backend_tag)) + + def register_info(self, backend_tag: str, backend_info_d): + self.backend_info.put(backend_tag, json.dumps(backend_info_d)) + + def get_info(self, backend_tag): + return json.loads(self.backend_info.get(backend_tag, "{}")) + def get_backend_creator(self, backend_tag): return pickle.loads(self.backend_table.get(backend_tag)) diff --git a/python/ray/experimental/serve/queues.py b/python/ray/experimental/serve/queues.py index 62d96862f..b952ecad2 100644 --- a/python/ray/experimental/serve/queues.py +++ b/python/ray/experimental/serve/queues.py @@ -83,6 +83,9 @@ class CentralizedQueues: # service_name -> traffic_policy self.traffic = defaultdict(dict) + # backend_name -> backend_config + self.backend_info = dict() + # backend_name -> worker request queue self.workers = defaultdict(deque) @@ -157,6 +160,11 @@ class CentralizedQueues: self.traffic[service] = traffic_dict self.flush() + def set_backend_config(self, backend, config_dict): + logger.debug("Setting backend config for " + "backend {} to {}".format(backend, config_dict)) + self.backend_info[backend] = config_dict + def flush(self): """In the default case, flush calls ._flush. @@ -184,11 +192,23 @@ class CentralizedQueues: buffer_queue = self.buffer_queues[backend] work_queue = self.workers[backend] + max_batch_size = None + if backend in self.backend_info: + max_batch_size = self.backend_info[backend][ + "max_batch_size"] + while len(buffer_queue) and len(work_queue): - request, work = ( - buffer_queue.pop(0), - work_queue.popleft(), - ) + # get the work from work intent queue + work = work_queue.popleft() + # see if backend accepts batched queries + if max_batch_size is not None: + pop_size = min(len(buffer_queue), max_batch_size) + request = [ + buffer_queue.pop(0) for _ in range(pop_size) + ] + else: + request = buffer_queue.pop(0) + work.replica_handle._ray_serve_call.remote(request) # selects the backend and puts the service queue query to the buffer diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py index 16c723b21..40d8ef20c 100644 --- a/python/ray/experimental/serve/task_runner.py +++ b/python/ray/experimental/serve/task_runner.py @@ -1,11 +1,12 @@ -import io import time import traceback import ray from ray.experimental.serve import context as serve_context -from ray.experimental.serve.context import FakeFlaskQuest, TaskContext -from ray.experimental.serve.http_util import build_flask_request +from ray.experimental.serve.context import FakeFlaskRequest +from collections import defaultdict +from ray.experimental.serve.utils import parse_request_item +from ray.experimental.serve.exceptions import RayServeException class TaskRunner: @@ -93,23 +94,10 @@ class RayServeMixin: self._ray_serve_dequeue_requester_name, self._ray_serve_self_handle) - def _ray_serve_call(self, request): - work_item = request - - if work_item.request_context == TaskContext.Web: - serve_context.web = True - asgi_scope, body_bytes = work_item.request_args - flask_request = build_flask_request(asgi_scope, - io.BytesIO(body_bytes)) - args = (flask_request, ) - kwargs = {} - else: - serve_context.web = False - args = (FakeFlaskQuest(), ) - kwargs = work_item.request_kwargs - - result_object_id = work_item.result_object_id - + def invoke_single(self, request_item): + args, kwargs, is_web_context, result_object_id = parse_request_item( + request_item) + serve_context.web = is_web_context start_timestamp = time.time() try: result = self.__call__(*args, **kwargs) @@ -121,8 +109,91 @@ class RayServeMixin: result_object_id) self._serve_metric_latency_list.append(time.time() - start_timestamp) - serve_context.web = False + def invoke_batch(self, request_item_list): + # TODO(alind) : create no-http services. The enqueues + # from such services will always be TaskContext.Python. + # Assumption : all the requests in a bacth + # have same serve context. + + # For batching kwargs are modified as follows - + # kwargs [Python Context] : key,val + # kwargs_list : key, [val1,val2, ... , valn] + # or + # args[Web Context] : val + # args_list : [val1,val2, ...... , valn] + # where n (current batch size) <= max_batch_size of a backend + + kwargs_list = defaultdict(list) + result_object_ids, context_flag_list, arg_list = [], [], [] + curr_batch_size = len(request_item_list) + + for item in request_item_list: + args, kwargs, is_web_context, result_object_id = ( + parse_request_item(item)) + context_flag_list.append(is_web_context) + + # Python context only have kwargs + # Web context only have one positional argument + if is_web_context: + arg_list.append(args[0]) + else: + for k, v in kwargs.items(): + kwargs_list[k].append(v) + result_object_ids.append(result_object_id) + + try: + # check mixing of query context + # unified context needed + if len(set(context_flag_list)) != 1: + raise RayServeException( + "Batched queries contain mixed context.") + serve_context.web = all(context_flag_list) + if serve_context.web: + args = (arg_list, ) + else: + # Set the flask request as a list to conform + # with batching semantics: when in batching + # mode, each argument it turned into list. + fake_flask_request_lst = [ + FakeFlaskRequest() for _ in range(curr_batch_size) + ] + args = (fake_flask_request_lst, ) + # set the current batch size (n) for serve_context + serve_context.batch_size = len(result_object_ids) + start_timestamp = time.time() + result_list = self.__call__(*args, **kwargs_list) + if (not isinstance(result_list, list)) or (len(result_list) != + len(result_object_ids)): + raise RayServeException("__call__ function " + "doesn't preserve batch-size. " + "Please return a list of result " + "with length equals to the batch " + "size.") + for result, result_object_id in zip(result_list, + result_object_ids): + ray.worker.global_worker.put_object(result, result_object_id) + self._serve_metric_latency_list.append(time.time() - + start_timestamp) + except Exception as e: + wrapped_exception = wrap_to_ray_error(e) + self._serve_metric_error_counter += len(result_object_ids) + for result_object_id in result_object_ids: + ray.worker.global_worker.put_object(wrapped_exception, + result_object_id) + + def _ray_serve_call(self, request): + work_item = request + # check if work_item is a list or not + # if it is list: then batching supported + if not isinstance(work_item, list): + self.invoke_single(work_item) + else: + self.invoke_batch(work_item) + + # re-assign to default values + serve_context.web = False + serve_context.batch_size = None self._ray_serve_fetch() diff --git a/python/ray/experimental/serve/tests/conftest.py b/python/ray/experimental/serve/tests/conftest.py index 69a757e46..cf8047e42 100644 --- a/python/ray/experimental/serve/tests/conftest.py +++ b/python/ray/experimental/serve/tests/conftest.py @@ -10,7 +10,10 @@ from ray.experimental import serve @pytest.fixture(scope="session") def serve_instance(): _, new_db_path = tempfile.mkstemp(suffix=".test.db") - serve.init(kv_store_path=new_db_path, blocking=True) + serve.init( + kv_store_path=new_db_path, + blocking=True, + ray_init_kwargs={"num_cpus": 36}) yield os.remove(new_db_path) diff --git a/python/ray/experimental/serve/tests/test_api.py b/python/ray/experimental/serve/tests/test_api.py index d4a974499..7cf360b31 100644 --- a/python/ray/experimental/serve/tests/test_api.py +++ b/python/ray/experimental/serve/tests/test_api.py @@ -1,8 +1,10 @@ import time - +import pytest import requests from ray.experimental import serve +from ray.experimental.serve import BackendConfig +import ray def test_e2e(serve_instance): @@ -50,11 +52,10 @@ def test_scaling_replicas(serve_instance): while "/increment" not in requests.get("http://127.0.0.1:8000/").json(): time.sleep(0.2) - serve.create_backend(Counter, "counter:v1") + b_config = BackendConfig(num_replicas=2) + serve.create_backend(Counter, "counter:v1", backend_config=b_config) serve.link("counter", "counter:v1") - serve.scale("counter:v1", 2) - counter_result = [] for _ in range(10): resp = requests.get("http://127.0.0.1:8000/increment").json()["result"] @@ -63,7 +64,9 @@ def test_scaling_replicas(serve_instance): # If the load is shared among two replicas. The max result cannot be 10. assert max(counter_result) < 10 - serve.scale("counter:v1", 1) + b_config = serve.get_backend_config("counter:v1") + b_config.num_replicas = 1 + serve.set_backend_config("counter:v1", b_config) counter_result = [] for _ in range(10): @@ -72,3 +75,126 @@ def test_scaling_replicas(serve_instance): # Give some time for a replica to spin down. But majority of the request # should be served by the only remaining replica. assert max(counter_result) - min(counter_result) > 6 + + +def test_batching(serve_instance): + class BatchingExample: + def __init__(self): + self.count = 0 + + @serve.accept_batch + def __call__(self, flask_request, temp=None): + self.count += 1 + batch_size = serve.context.batch_size + return [self.count] * batch_size + + serve.create_endpoint("counter1", "/increment") + + # Keep checking the routing table until /increment is populated + while "/increment" not in requests.get("http://127.0.0.1:8000/").json(): + time.sleep(0.2) + + # set the max batch size + b_config = BackendConfig(max_batch_size=5) + serve.create_backend( + BatchingExample, "counter:v11", backend_config=b_config) + serve.link("counter1", "counter:v11") + + future_list = [] + handle = serve.get_handle("counter1") + for _ in range(20): + f = handle.remote(temp=1) + future_list.append(f) + + counter_result = ray.get(future_list) + # since count is only updated per batch of queries + # If there atleast one __call__ fn call with batch size greater than 1 + # counter result will always be less than 20 + assert max(counter_result) < 20 + + +def test_batching_exception(serve_instance): + class NoListReturned: + def __init__(self): + self.count = 0 + + @serve.accept_batch + def __call__(self, flask_request, temp=None): + batch_size = serve.context.batch_size + return batch_size + + serve.create_endpoint("exception-test", "/noListReturned") + # set the max batch size + b_config = BackendConfig(max_batch_size=5) + serve.create_backend( + NoListReturned, "exception:v1", backend_config=b_config) + serve.link("exception-test", "exception:v1") + + handle = serve.get_handle("exception-test") + with pytest.raises(ray.exceptions.RayTaskError): + assert ray.get(handle.remote(temp=1)) + + +def test_killing_replicas(serve_instance): + class Simple: + def __init__(self): + self.count = 0 + + def __call__(self, flask_request, temp=None): + return temp + + serve.create_endpoint("simple", "/simple") + b_config = BackendConfig(num_replicas=3, num_cpus=2) + serve.create_backend(Simple, "simple:v1", backend_config=b_config) + global_state = serve.api._get_global_state() + old_replica_tag_list = global_state.backend_table.list_replicas( + "simple:v1") + + bnew_config = serve.get_backend_config("simple:v1") + # change the config + bnew_config.num_cpus = 1 + # set the config + serve.set_backend_config("simple:v1", bnew_config) + new_replica_tag_list = global_state.backend_table.list_replicas( + "simple:v1") + global_state.refresh_actor_handle_cache() + new_all_tag_list = list(global_state.actor_handle_cache.keys()) + + # the new_replica_tag_list must be subset of all_tag_list + assert set(new_replica_tag_list) <= set(new_all_tag_list) + + # the old_replica_tag_list must not be subset of all_tag_list + assert not set(old_replica_tag_list) <= set(new_all_tag_list) + + +def test_not_killing_replicas(serve_instance): + class BatchSimple: + def __init__(self): + self.count = 0 + + @serve.accept_batch + def __call__(self, flask_request, temp=None): + batch_size = serve.context.batch_size + return [1] * batch_size + + serve.create_endpoint("bsimple", "/bsimple") + b_config = BackendConfig(num_replicas=3, max_batch_size=2) + serve.create_backend(BatchSimple, "bsimple:v1", backend_config=b_config) + global_state = serve.api._get_global_state() + old_replica_tag_list = global_state.backend_table.list_replicas( + "bsimple:v1") + + bnew_config = serve.get_backend_config("bsimple:v1") + # change the config + bnew_config.max_batch_size = 5 + # set the config + serve.set_backend_config("bsimple:v1", bnew_config) + new_replica_tag_list = global_state.backend_table.list_replicas( + "bsimple:v1") + global_state.refresh_actor_handle_cache() + new_all_tag_list = list(global_state.actor_handle_cache.keys()) + + # the old and new replica tag list should be identical + # and should be subset of all_tag_list + assert set(old_replica_tag_list) <= set(new_all_tag_list) + assert set(old_replica_tag_list) == set(new_replica_tag_list) diff --git a/python/ray/experimental/serve/utils.py b/python/ray/experimental/serve/utils.py index 0e180952c..bf41513dc 100644 --- a/python/ray/experimental/serve/utils.py +++ b/python/ray/experimental/serve/utils.py @@ -3,9 +3,28 @@ import logging import random import string import time +import io import requests from pygments import formatters, highlight, lexers +from ray.experimental.serve.context import FakeFlaskRequest, TaskContext +from ray.experimental.serve.http_util import build_flask_request + + +def parse_request_item(request_item): + if request_item.request_context == TaskContext.Web: + is_web_context = True + asgi_scope, body_bytes = request_item.request_args + flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes)) + args = (flask_request, ) + kwargs = {} + else: + is_web_context = False + args = (FakeFlaskRequest(), ) + kwargs = request_item.request_kwargs + + result_object_id = request_item.result_object_id + return args, kwargs, is_web_context, result_object_id def _get_logger():