mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
[Serve] Adding BackendConfig (#6541)
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
import sys
|
||||
from ray.experimental.serve.backend_config import BackendConfig
|
||||
from ray.experimental.serve.policy import RoutePolicy
|
||||
if sys.version_info < (3, 0):
|
||||
raise ImportError("serve is Python 3 only.")
|
||||
|
||||
from ray.experimental.serve.api import (init, create_backend, create_endpoint,
|
||||
link, split, get_handle, stat,
|
||||
scale) # noqa: E402
|
||||
|
||||
from ray.experimental.serve.api import (
|
||||
init, create_backend, create_endpoint, link, split, get_handle, stat,
|
||||
set_backend_config, get_backend_config, accept_batch) # noqa: E402
|
||||
__all__ = [
|
||||
"init", "create_backend", "create_endpoint", "link", "split", "get_handle",
|
||||
"stat", "scale", "RoutePolicy"
|
||||
"stat", "set_backend_config", "get_backend_config", "BackendConfig",
|
||||
"RoutePolicy", "accept_batch"
|
||||
]
|
||||
|
||||
@@ -14,6 +14,7 @@ from ray.experimental.serve.task_runner import RayServeMixin, TaskRunnerActor
|
||||
from ray.experimental.serve.utils import (block_until_http_ready,
|
||||
get_random_letters)
|
||||
from ray.experimental.serve.exceptions import RayServeException
|
||||
from ray.experimental.serve.backend_config import BackendConfig
|
||||
from ray.experimental.serve.policy import RoutePolicy
|
||||
global_state = None
|
||||
|
||||
@@ -36,6 +37,28 @@ def _ensure_connected(f):
|
||||
return check
|
||||
|
||||
|
||||
def accept_batch(f):
|
||||
"""Annotation to mark a serving function that batch is accepted.
|
||||
|
||||
This annotation need to be used to mark a function expect all arguments
|
||||
to be passed into a list.
|
||||
|
||||
Example:
|
||||
|
||||
>>> @serve.accept_batch
|
||||
def serving_func(flask_request):
|
||||
assert isinstance(flask_request, list)
|
||||
...
|
||||
|
||||
>>> class ServingActor:
|
||||
@serve.accept_batch
|
||||
def __call__(self, *, python_arg=None):
|
||||
assert isinstance(python_arg, list)
|
||||
"""
|
||||
f.serve_accept_batch = True
|
||||
return f
|
||||
|
||||
|
||||
def init(kv_store_connector=None,
|
||||
kv_store_path=None,
|
||||
blocking=False,
|
||||
@@ -126,7 +149,62 @@ def create_endpoint(endpoint_name, route, blocking=True):
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def create_backend(func_or_class, backend_tag, *actor_init_args):
|
||||
def set_backend_config(backend_tag, backend_config):
|
||||
"""Set a backend configuration for a backend tag
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
backend_config(BackendConfig) : Desired backend configuration.
|
||||
"""
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert isinstance(backend_config,
|
||||
BackendConfig), ("backend_config must be"
|
||||
" of instance BackendConfig")
|
||||
backend_config_dict = dict(backend_config)
|
||||
|
||||
old_backend_config_dict = global_state.backend_table.get_info(backend_tag)
|
||||
global_state.backend_table.register_info(backend_tag, backend_config_dict)
|
||||
|
||||
# inform the router about change in configuration
|
||||
# particularly for setting max_batch_size
|
||||
ray.get(global_state.init_or_get_router().set_backend_config.remote(
|
||||
backend_tag, backend_config_dict))
|
||||
|
||||
# checking if replicas need to be restarted
|
||||
# Replicas are restarted if there is any change in the backend config
|
||||
# related to restart_configs
|
||||
# TODO(alind) : have replica restarting policies selected by the user
|
||||
|
||||
need_to_restart_replicas = any(
|
||||
old_backend_config_dict[k] != backend_config_dict[k]
|
||||
for k in BackendConfig.restart_on_change_fields)
|
||||
if need_to_restart_replicas:
|
||||
# kill all the replicas for restarting with new configurations
|
||||
scale(backend_tag, 0)
|
||||
|
||||
# scale the replicas with new configuration
|
||||
scale(backend_tag, backend_config_dict["num_replicas"])
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_backend_config(backend_tag):
|
||||
"""get the backend configuration for a backend tag
|
||||
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
"""
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
backend_config_dict = global_state.backend_table.get_info(backend_tag)
|
||||
return BackendConfig(**backend_config_dict)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def create_backend(func_or_class,
|
||||
backend_tag,
|
||||
*actor_init_args,
|
||||
backend_config=BackendConfig()):
|
||||
"""Create a backend using func_or_class and assign backend_tag.
|
||||
|
||||
Args:
|
||||
@@ -134,28 +212,66 @@ def create_backend(func_or_class, backend_tag, *actor_init_args):
|
||||
__call__ protocol.
|
||||
backend_tag (str): a unique tag assign to this backend. It will be used
|
||||
to associate services in traffic policy.
|
||||
backend_config (BackendConfig): An object defining backend properties
|
||||
for starting a backend.
|
||||
*actor_init_args (optional): the argument to pass to the class
|
||||
initialization method.
|
||||
"""
|
||||
assert isinstance(backend_config,
|
||||
BackendConfig), ("backend_config must be"
|
||||
" of instance BackendConfig")
|
||||
backend_config_dict = dict(backend_config)
|
||||
|
||||
should_accept_batch = (True if backend_config.max_batch_size is not None
|
||||
else False)
|
||||
batch_annotation_not_found = RayServeException(
|
||||
"max_batch_size is set in config but the function or method does not "
|
||||
"accept batching. Please use @serve.accept_batch to explicitly mark "
|
||||
"the function or method as batchable and takes in list as arguments.")
|
||||
|
||||
arg_list = []
|
||||
if inspect.isfunction(func_or_class):
|
||||
if should_accept_batch and not hasattr(func_or_class,
|
||||
"serve_accept_batch"):
|
||||
raise batch_annotation_not_found
|
||||
|
||||
# arg list for a fn is function itself
|
||||
arg_list = [func_or_class]
|
||||
# ignore lint on lambda expression
|
||||
creator = lambda: TaskRunnerActor.remote(func_or_class) # noqa: E731
|
||||
creator = lambda kwrgs: TaskRunnerActor._remote(**kwrgs) # noqa: E731
|
||||
elif inspect.isclass(func_or_class):
|
||||
if should_accept_batch and not hasattr(func_or_class.__call__,
|
||||
"serve_accept_batch"):
|
||||
raise batch_annotation_not_found
|
||||
|
||||
# Python inheritance order is right-to-left. We put RayServeMixin
|
||||
# on the left to make sure its methods are not overriden.
|
||||
@ray.remote
|
||||
class CustomActor(RayServeMixin, func_or_class):
|
||||
pass
|
||||
|
||||
arg_list = actor_init_args
|
||||
# ignore lint on lambda expression
|
||||
creator = lambda: CustomActor.remote(*actor_init_args) # noqa: E731
|
||||
creator = lambda kwargs: CustomActor._remote(**kwargs) # noqa: E731
|
||||
else:
|
||||
raise TypeError(
|
||||
"Backend must be a function or class, it is {}.".format(
|
||||
type(func_or_class)))
|
||||
|
||||
# save creator which starts replicas
|
||||
global_state.backend_table.register_backend(backend_tag, creator)
|
||||
scale(backend_tag, 1)
|
||||
|
||||
# save information about configurations needed to start the replicas
|
||||
global_state.backend_table.register_info(backend_tag, backend_config_dict)
|
||||
|
||||
# save the initial arguments needed by replicas
|
||||
global_state.backend_table.save_init_args(backend_tag, arg_list)
|
||||
|
||||
# set the backend config inside the router
|
||||
# particularly for max-batch-size
|
||||
ray.get(global_state.init_or_get_router().set_backend_config.remote(
|
||||
backend_tag, backend_config_dict))
|
||||
scale(backend_tag, backend_config_dict["num_replicas"])
|
||||
|
||||
|
||||
def _start_replica(backend_tag):
|
||||
@@ -163,12 +279,20 @@ def _start_replica(backend_tag):
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
|
||||
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
|
||||
|
||||
# get the info which starts the replicas
|
||||
creator = global_state.backend_table.get_backend_creator(backend_tag)
|
||||
backend_config_dict = global_state.backend_table.get_info(backend_tag)
|
||||
backend_config = BackendConfig(**backend_config_dict)
|
||||
init_args = global_state.backend_table.get_init_args(backend_tag)
|
||||
|
||||
# get actor creation kwargs
|
||||
actor_kwargs = backend_config.get_actor_creation_args(init_args)
|
||||
|
||||
# Create the runner in the nursery
|
||||
[runner_handle] = ray.get(
|
||||
global_state.actor_nursery_handle.start_actor_with_creator.remote(
|
||||
creator, replica_tag))
|
||||
creator, actor_kwargs, replica_tag))
|
||||
|
||||
# Setup the worker
|
||||
ray.get(
|
||||
@@ -216,7 +340,8 @@ def scale(backend_tag, num_replicas):
|
||||
"""
|
||||
assert backend_tag in global_state.backend_table.list_backends(), (
|
||||
"Backend {} is not registered.".format(backend_tag))
|
||||
assert num_replicas > 0, "Number of replicas must be greater than 1."
|
||||
assert num_replicas >= 0, ("Number of replicas must be"
|
||||
" greater than or equal to 0.")
|
||||
|
||||
replicas = global_state.backend_table.list_replicas(backend_tag)
|
||||
current_num_replicas = len(replicas)
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class BackendConfig:
|
||||
# configs not needed for actor creation when
|
||||
# instantiating a replica
|
||||
_serve_configs = ["_num_replicas", "max_batch_size"]
|
||||
|
||||
# configs which when changed leads to restarting
|
||||
# the existing replicas.
|
||||
restart_on_change_fields = ["resources", "num_cpus", "num_gpus"]
|
||||
|
||||
def __init__(self,
|
||||
num_replicas=1,
|
||||
resources=None,
|
||||
max_batch_size=None,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
memory=None,
|
||||
object_store_memory=None):
|
||||
"""
|
||||
Class for defining backend configuration.
|
||||
"""
|
||||
|
||||
# serve configs
|
||||
self.num_replicas = num_replicas
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
# ray actor configs
|
||||
self.resources = resources
|
||||
self.num_cpus = num_cpus
|
||||
self.num_gpus = num_gpus
|
||||
self.memory = memory
|
||||
self.object_store_memory = object_store_memory
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
return self._num_replicas
|
||||
|
||||
@num_replicas.setter
|
||||
def num_replicas(self, val):
|
||||
if not (val > 0):
|
||||
raise Exception("num_replicas must be greater than zero")
|
||||
self._num_replicas = val
|
||||
|
||||
def __iter__(self):
|
||||
for k in self.__dict__.keys():
|
||||
key, val = k, self.__dict__[k]
|
||||
if key == "_num_replicas":
|
||||
key = "num_replicas"
|
||||
yield key, val
|
||||
|
||||
def get_actor_creation_args(self, init_args):
|
||||
ret_d = deepcopy(self.__dict__)
|
||||
for k in self._serve_configs:
|
||||
ret_d.pop(k)
|
||||
ret_d["args"] = init_args
|
||||
return ret_d
|
||||
@@ -14,6 +14,11 @@ class TaskContext(IntEnum):
|
||||
# web == False: currently processing a request from python
|
||||
web = False
|
||||
|
||||
# batching information in serve context
|
||||
# batch_size == None : the backend doesn't support batching
|
||||
# batch_size(int) : the number of elements of input list
|
||||
batch_size = None
|
||||
|
||||
_not_in_web_context_error = """
|
||||
Accessing the request object outside of the web context. Please use
|
||||
"serve.context.web" to determine when the function is called within
|
||||
@@ -21,7 +26,7 @@ a web context.
|
||||
"""
|
||||
|
||||
|
||||
class FakeFlaskQuest:
|
||||
class FakeFlaskRequest:
|
||||
def __getattribute__(self, name):
|
||||
raise RayServeException(_not_in_web_context_error)
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ class MagicCounter:
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
if serve.context.web:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
|
||||
return base_number + self.increment
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Example actor that adds an increment to a number. This number can
|
||||
come from either web (parsing Flask request) or python call.
|
||||
The queries incoming to this actor are batched.
|
||||
This actor can be called from HTTP as well as from Python.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
import ray
|
||||
from ray.experimental import serve
|
||||
from ray.experimental.serve.utils import pformat_color_json
|
||||
from ray.experimental.serve import BackendConfig
|
||||
|
||||
|
||||
class MagicCounter:
|
||||
def __init__(self, increment):
|
||||
self.increment = increment
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request_list, base_number=None):
|
||||
# batch_size = serve.context.batch_size
|
||||
if serve.context.web:
|
||||
result = []
|
||||
for flask_request in flask_request_list:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
result.append(base_number)
|
||||
return list(map(lambda x: x + self.increment, result))
|
||||
else:
|
||||
result = []
|
||||
for b in base_number:
|
||||
ans = b + self.increment
|
||||
result.append(ans)
|
||||
return result
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter", blocking=True)
|
||||
b_config = BackendConfig(max_batch_size=5)
|
||||
serve.create_backend(
|
||||
MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42
|
||||
serve.link("magic_counter", "counter:v1")
|
||||
|
||||
print("Sending ten queries via HTTP")
|
||||
for i in range(10):
|
||||
url = "http://127.0.0.1:8000/counter?base_number={}".format(i)
|
||||
print("> Pinging {}".format(url))
|
||||
resp = requests.get(url).json()
|
||||
print(pformat_color_json(resp))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
print("Sending ten queries via Python")
|
||||
handle = serve.get_handle("magic_counter")
|
||||
for i in range(10):
|
||||
print("> Pinging handle.remote(base_number={})".format(i))
|
||||
result = ray.get(handle.remote(base_number=i))
|
||||
print("< Result {}".format(result))
|
||||
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
This example has backend which has batching functionality enabled.
|
||||
"""
|
||||
|
||||
import ray
|
||||
from ray.experimental import serve
|
||||
from ray.experimental.serve import BackendConfig
|
||||
|
||||
|
||||
class MagicCounter:
|
||||
def __init__(self, increment):
|
||||
self.increment = increment
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
# __call__ fn should preserve the batch size
|
||||
# base_number is a python list
|
||||
|
||||
if serve.context.batch_size is not None:
|
||||
batch_size = serve.context.batch_size
|
||||
result = []
|
||||
for base_num in base_number:
|
||||
ret_str = "Number: {} Batch size: {}".format(
|
||||
base_num, batch_size)
|
||||
result.append(ret_str)
|
||||
return result
|
||||
return ""
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter", blocking=True)
|
||||
# specify max_batch_size in BackendConfig
|
||||
b_config = BackendConfig(max_batch_size=5)
|
||||
serve.create_backend(
|
||||
MagicCounter, "counter:v1", 42, backend_config=b_config) # increment=42
|
||||
print("Backend Config for backend: 'counter:v1'")
|
||||
print(b_config)
|
||||
serve.link("magic_counter", "counter:v1")
|
||||
|
||||
handle = serve.get_handle("magic_counter")
|
||||
future_list = []
|
||||
|
||||
# fire 30 requests
|
||||
for r in range(30):
|
||||
print("> [REMOTE] Pinging handle.remote(base_number={})".format(r))
|
||||
f = handle.remote(base_number=r)
|
||||
future_list.append(f)
|
||||
|
||||
# get results of queries as they complete
|
||||
left_futures = future_list
|
||||
while left_futures:
|
||||
completed_futures, remaining_futures = ray.wait(left_futures, timeout=0.05)
|
||||
if len(completed_futures) > 0:
|
||||
result = ray.get(completed_futures[0])
|
||||
print("< " + result)
|
||||
left_futures = remaining_futures
|
||||
@@ -27,6 +27,7 @@ def echo_v1(flask_request, response="hello from python!"):
|
||||
|
||||
|
||||
serve.create_backend(echo_v1, "echo:v1")
|
||||
backend_config_v1 = serve.get_backend_config("echo:v1")
|
||||
|
||||
# We can link an endpoint to a backend, the means all the traffic
|
||||
# goes to my_endpoint will now goes to echo:v1 backend.
|
||||
@@ -47,6 +48,7 @@ def echo_v2(flask_request):
|
||||
|
||||
|
||||
serve.create_backend(echo_v2, "echo:v2")
|
||||
backend_config_v2 = serve.get_backend_config("echo:v2")
|
||||
|
||||
# The two backend will now split the traffic 50%-50%.
|
||||
serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
@@ -56,9 +58,12 @@ for _ in range(10):
|
||||
print(requests.get("http://127.0.0.1:8000/echo").json())
|
||||
time.sleep(0.5)
|
||||
|
||||
# You can also scale each backend independently.
|
||||
serve.scale("echo:v1", 2)
|
||||
serve.scale("echo:v2", 2)
|
||||
# You can also change number of replicas
|
||||
# for each backend independently.
|
||||
backend_config_v1.num_replicas = 2
|
||||
serve.set_backend_config("echo:v1", backend_config_v1)
|
||||
backend_config_v2.num_replicas = 2
|
||||
serve.set_backend_config("echo:v2", backend_config_v2)
|
||||
|
||||
# As well as retrieving relevant system metrics
|
||||
print(pformat_color_json(serve.stat()))
|
||||
|
||||
@@ -43,8 +43,14 @@ class ActorNursery:
|
||||
self.actor_handles[handle] = tag
|
||||
return [handle]
|
||||
|
||||
def start_actor_with_creator(self, creator, tag):
|
||||
handle = creator()
|
||||
def start_actor_with_creator(self, creator, kwargs, tag):
|
||||
"""
|
||||
Args:
|
||||
creator (Callable[Dict]): a closure that should return
|
||||
a newly created actor handle when called with kwargs.
|
||||
The kwargs input is passed to `ActorCls_remote` method.
|
||||
"""
|
||||
handle = creator(kwargs)
|
||||
self.actor_handles[handle] = tag
|
||||
return [handle]
|
||||
|
||||
|
||||
@@ -212,11 +212,26 @@ class BackendTable:
|
||||
def __init__(self, kv_connector):
|
||||
self.backend_table = kv_connector("backend_creator")
|
||||
self.replica_table = kv_connector("replica_table")
|
||||
self.backend_info = kv_connector("backend_info")
|
||||
self.backend_init_args = kv_connector("backend_init_args")
|
||||
|
||||
def register_backend(self, backend_tag: str, backend_creator):
|
||||
backend_creator_serialized = pickle.dumps(backend_creator)
|
||||
self.backend_table.put(backend_tag, backend_creator_serialized)
|
||||
|
||||
def save_init_args(self, backend_tag: str, arg_list):
|
||||
serialized_arg_list = pickle.dumps(arg_list)
|
||||
self.backend_init_args.put(backend_tag, serialized_arg_list)
|
||||
|
||||
def get_init_args(self, backend_tag):
|
||||
return pickle.loads(self.backend_init_args.get(backend_tag))
|
||||
|
||||
def register_info(self, backend_tag: str, backend_info_d):
|
||||
self.backend_info.put(backend_tag, json.dumps(backend_info_d))
|
||||
|
||||
def get_info(self, backend_tag):
|
||||
return json.loads(self.backend_info.get(backend_tag, "{}"))
|
||||
|
||||
def get_backend_creator(self, backend_tag):
|
||||
return pickle.loads(self.backend_table.get(backend_tag))
|
||||
|
||||
|
||||
@@ -83,6 +83,9 @@ class CentralizedQueues:
|
||||
# service_name -> traffic_policy
|
||||
self.traffic = defaultdict(dict)
|
||||
|
||||
# backend_name -> backend_config
|
||||
self.backend_info = dict()
|
||||
|
||||
# backend_name -> worker request queue
|
||||
self.workers = defaultdict(deque)
|
||||
|
||||
@@ -157,6 +160,11 @@ class CentralizedQueues:
|
||||
self.traffic[service] = traffic_dict
|
||||
self.flush()
|
||||
|
||||
def set_backend_config(self, backend, config_dict):
|
||||
logger.debug("Setting backend config for "
|
||||
"backend {} to {}".format(backend, config_dict))
|
||||
self.backend_info[backend] = config_dict
|
||||
|
||||
def flush(self):
|
||||
"""In the default case, flush calls ._flush.
|
||||
|
||||
@@ -184,11 +192,23 @@ class CentralizedQueues:
|
||||
|
||||
buffer_queue = self.buffer_queues[backend]
|
||||
work_queue = self.workers[backend]
|
||||
max_batch_size = None
|
||||
if backend in self.backend_info:
|
||||
max_batch_size = self.backend_info[backend][
|
||||
"max_batch_size"]
|
||||
|
||||
while len(buffer_queue) and len(work_queue):
|
||||
request, work = (
|
||||
buffer_queue.pop(0),
|
||||
work_queue.popleft(),
|
||||
)
|
||||
# get the work from work intent queue
|
||||
work = work_queue.popleft()
|
||||
# see if backend accepts batched queries
|
||||
if max_batch_size is not None:
|
||||
pop_size = min(len(buffer_queue), max_batch_size)
|
||||
request = [
|
||||
buffer_queue.pop(0) for _ in range(pop_size)
|
||||
]
|
||||
else:
|
||||
request = buffer_queue.pop(0)
|
||||
|
||||
work.replica_handle._ray_serve_call.remote(request)
|
||||
|
||||
# selects the backend and puts the service queue query to the buffer
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import io
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray.experimental.serve import context as serve_context
|
||||
from ray.experimental.serve.context import FakeFlaskQuest, TaskContext
|
||||
from ray.experimental.serve.http_util import build_flask_request
|
||||
from ray.experimental.serve.context import FakeFlaskRequest
|
||||
from collections import defaultdict
|
||||
from ray.experimental.serve.utils import parse_request_item
|
||||
from ray.experimental.serve.exceptions import RayServeException
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
@@ -93,23 +94,10 @@ class RayServeMixin:
|
||||
self._ray_serve_dequeue_requester_name,
|
||||
self._ray_serve_self_handle)
|
||||
|
||||
def _ray_serve_call(self, request):
|
||||
work_item = request
|
||||
|
||||
if work_item.request_context == TaskContext.Web:
|
||||
serve_context.web = True
|
||||
asgi_scope, body_bytes = work_item.request_args
|
||||
flask_request = build_flask_request(asgi_scope,
|
||||
io.BytesIO(body_bytes))
|
||||
args = (flask_request, )
|
||||
kwargs = {}
|
||||
else:
|
||||
serve_context.web = False
|
||||
args = (FakeFlaskQuest(), )
|
||||
kwargs = work_item.request_kwargs
|
||||
|
||||
result_object_id = work_item.result_object_id
|
||||
|
||||
def invoke_single(self, request_item):
|
||||
args, kwargs, is_web_context, result_object_id = parse_request_item(
|
||||
request_item)
|
||||
serve_context.web = is_web_context
|
||||
start_timestamp = time.time()
|
||||
try:
|
||||
result = self.__call__(*args, **kwargs)
|
||||
@@ -121,8 +109,91 @@ class RayServeMixin:
|
||||
result_object_id)
|
||||
self._serve_metric_latency_list.append(time.time() - start_timestamp)
|
||||
|
||||
serve_context.web = False
|
||||
def invoke_batch(self, request_item_list):
|
||||
# TODO(alind) : create no-http services. The enqueues
|
||||
# from such services will always be TaskContext.Python.
|
||||
|
||||
# Assumption : all the requests in a bacth
|
||||
# have same serve context.
|
||||
|
||||
# For batching kwargs are modified as follows -
|
||||
# kwargs [Python Context] : key,val
|
||||
# kwargs_list : key, [val1,val2, ... , valn]
|
||||
# or
|
||||
# args[Web Context] : val
|
||||
# args_list : [val1,val2, ...... , valn]
|
||||
# where n (current batch size) <= max_batch_size of a backend
|
||||
|
||||
kwargs_list = defaultdict(list)
|
||||
result_object_ids, context_flag_list, arg_list = [], [], []
|
||||
curr_batch_size = len(request_item_list)
|
||||
|
||||
for item in request_item_list:
|
||||
args, kwargs, is_web_context, result_object_id = (
|
||||
parse_request_item(item))
|
||||
context_flag_list.append(is_web_context)
|
||||
|
||||
# Python context only have kwargs
|
||||
# Web context only have one positional argument
|
||||
if is_web_context:
|
||||
arg_list.append(args[0])
|
||||
else:
|
||||
for k, v in kwargs.items():
|
||||
kwargs_list[k].append(v)
|
||||
result_object_ids.append(result_object_id)
|
||||
|
||||
try:
|
||||
# check mixing of query context
|
||||
# unified context needed
|
||||
if len(set(context_flag_list)) != 1:
|
||||
raise RayServeException(
|
||||
"Batched queries contain mixed context.")
|
||||
serve_context.web = all(context_flag_list)
|
||||
if serve_context.web:
|
||||
args = (arg_list, )
|
||||
else:
|
||||
# Set the flask request as a list to conform
|
||||
# with batching semantics: when in batching
|
||||
# mode, each argument it turned into list.
|
||||
fake_flask_request_lst = [
|
||||
FakeFlaskRequest() for _ in range(curr_batch_size)
|
||||
]
|
||||
args = (fake_flask_request_lst, )
|
||||
# set the current batch size (n) for serve_context
|
||||
serve_context.batch_size = len(result_object_ids)
|
||||
start_timestamp = time.time()
|
||||
result_list = self.__call__(*args, **kwargs_list)
|
||||
if (not isinstance(result_list, list)) or (len(result_list) !=
|
||||
len(result_object_ids)):
|
||||
raise RayServeException("__call__ function "
|
||||
"doesn't preserve batch-size. "
|
||||
"Please return a list of result "
|
||||
"with length equals to the batch "
|
||||
"size.")
|
||||
for result, result_object_id in zip(result_list,
|
||||
result_object_ids):
|
||||
ray.worker.global_worker.put_object(result, result_object_id)
|
||||
self._serve_metric_latency_list.append(time.time() -
|
||||
start_timestamp)
|
||||
except Exception as e:
|
||||
wrapped_exception = wrap_to_ray_error(e)
|
||||
self._serve_metric_error_counter += len(result_object_ids)
|
||||
for result_object_id in result_object_ids:
|
||||
ray.worker.global_worker.put_object(wrapped_exception,
|
||||
result_object_id)
|
||||
|
||||
def _ray_serve_call(self, request):
|
||||
work_item = request
|
||||
# check if work_item is a list or not
|
||||
# if it is list: then batching supported
|
||||
if not isinstance(work_item, list):
|
||||
self.invoke_single(work_item)
|
||||
else:
|
||||
self.invoke_batch(work_item)
|
||||
|
||||
# re-assign to default values
|
||||
serve_context.web = False
|
||||
serve_context.batch_size = None
|
||||
self._ray_serve_fetch()
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from ray.experimental import serve
|
||||
@pytest.fixture(scope="session")
|
||||
def serve_instance():
|
||||
_, new_db_path = tempfile.mkstemp(suffix=".test.db")
|
||||
serve.init(kv_store_path=new_db_path, blocking=True)
|
||||
serve.init(
|
||||
kv_store_path=new_db_path,
|
||||
blocking=True,
|
||||
ray_init_kwargs={"num_cpus": 36})
|
||||
yield
|
||||
os.remove(new_db_path)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ray.experimental import serve
|
||||
from ray.experimental.serve import BackendConfig
|
||||
import ray
|
||||
|
||||
|
||||
def test_e2e(serve_instance):
|
||||
@@ -50,11 +52,10 @@ def test_scaling_replicas(serve_instance):
|
||||
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
serve.create_backend(Counter, "counter:v1")
|
||||
b_config = BackendConfig(num_replicas=2)
|
||||
serve.create_backend(Counter, "counter:v1", backend_config=b_config)
|
||||
serve.link("counter", "counter:v1")
|
||||
|
||||
serve.scale("counter:v1", 2)
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
|
||||
@@ -63,7 +64,9 @@ def test_scaling_replicas(serve_instance):
|
||||
# If the load is shared among two replicas. The max result cannot be 10.
|
||||
assert max(counter_result) < 10
|
||||
|
||||
serve.scale("counter:v1", 1)
|
||||
b_config = serve.get_backend_config("counter:v1")
|
||||
b_config.num_replicas = 1
|
||||
serve.set_backend_config("counter:v1", b_config)
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
@@ -72,3 +75,126 @@ def test_scaling_replicas(serve_instance):
|
||||
# Give some time for a replica to spin down. But majority of the request
|
||||
# should be served by the only remaining replica.
|
||||
assert max(counter_result) - min(counter_result) > 6
|
||||
|
||||
|
||||
def test_batching(serve_instance):
|
||||
class BatchingExample:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, temp=None):
|
||||
self.count += 1
|
||||
batch_size = serve.context.batch_size
|
||||
return [self.count] * batch_size
|
||||
|
||||
serve.create_endpoint("counter1", "/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
# set the max batch size
|
||||
b_config = BackendConfig(max_batch_size=5)
|
||||
serve.create_backend(
|
||||
BatchingExample, "counter:v11", backend_config=b_config)
|
||||
serve.link("counter1", "counter:v11")
|
||||
|
||||
future_list = []
|
||||
handle = serve.get_handle("counter1")
|
||||
for _ in range(20):
|
||||
f = handle.remote(temp=1)
|
||||
future_list.append(f)
|
||||
|
||||
counter_result = ray.get(future_list)
|
||||
# since count is only updated per batch of queries
|
||||
# If there atleast one __call__ fn call with batch size greater than 1
|
||||
# counter result will always be less than 20
|
||||
assert max(counter_result) < 20
|
||||
|
||||
|
||||
def test_batching_exception(serve_instance):
|
||||
class NoListReturned:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, temp=None):
|
||||
batch_size = serve.context.batch_size
|
||||
return batch_size
|
||||
|
||||
serve.create_endpoint("exception-test", "/noListReturned")
|
||||
# set the max batch size
|
||||
b_config = BackendConfig(max_batch_size=5)
|
||||
serve.create_backend(
|
||||
NoListReturned, "exception:v1", backend_config=b_config)
|
||||
serve.link("exception-test", "exception:v1")
|
||||
|
||||
handle = serve.get_handle("exception-test")
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
assert ray.get(handle.remote(temp=1))
|
||||
|
||||
|
||||
def test_killing_replicas(serve_instance):
|
||||
class Simple:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def __call__(self, flask_request, temp=None):
|
||||
return temp
|
||||
|
||||
serve.create_endpoint("simple", "/simple")
|
||||
b_config = BackendConfig(num_replicas=3, num_cpus=2)
|
||||
serve.create_backend(Simple, "simple:v1", backend_config=b_config)
|
||||
global_state = serve.api._get_global_state()
|
||||
old_replica_tag_list = global_state.backend_table.list_replicas(
|
||||
"simple:v1")
|
||||
|
||||
bnew_config = serve.get_backend_config("simple:v1")
|
||||
# change the config
|
||||
bnew_config.num_cpus = 1
|
||||
# set the config
|
||||
serve.set_backend_config("simple:v1", bnew_config)
|
||||
new_replica_tag_list = global_state.backend_table.list_replicas(
|
||||
"simple:v1")
|
||||
global_state.refresh_actor_handle_cache()
|
||||
new_all_tag_list = list(global_state.actor_handle_cache.keys())
|
||||
|
||||
# the new_replica_tag_list must be subset of all_tag_list
|
||||
assert set(new_replica_tag_list) <= set(new_all_tag_list)
|
||||
|
||||
# the old_replica_tag_list must not be subset of all_tag_list
|
||||
assert not set(old_replica_tag_list) <= set(new_all_tag_list)
|
||||
|
||||
|
||||
def test_not_killing_replicas(serve_instance):
|
||||
class BatchSimple:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, temp=None):
|
||||
batch_size = serve.context.batch_size
|
||||
return [1] * batch_size
|
||||
|
||||
serve.create_endpoint("bsimple", "/bsimple")
|
||||
b_config = BackendConfig(num_replicas=3, max_batch_size=2)
|
||||
serve.create_backend(BatchSimple, "bsimple:v1", backend_config=b_config)
|
||||
global_state = serve.api._get_global_state()
|
||||
old_replica_tag_list = global_state.backend_table.list_replicas(
|
||||
"bsimple:v1")
|
||||
|
||||
bnew_config = serve.get_backend_config("bsimple:v1")
|
||||
# change the config
|
||||
bnew_config.max_batch_size = 5
|
||||
# set the config
|
||||
serve.set_backend_config("bsimple:v1", bnew_config)
|
||||
new_replica_tag_list = global_state.backend_table.list_replicas(
|
||||
"bsimple:v1")
|
||||
global_state.refresh_actor_handle_cache()
|
||||
new_all_tag_list = list(global_state.actor_handle_cache.keys())
|
||||
|
||||
# the old and new replica tag list should be identical
|
||||
# and should be subset of all_tag_list
|
||||
assert set(old_replica_tag_list) <= set(new_all_tag_list)
|
||||
assert set(old_replica_tag_list) == set(new_replica_tag_list)
|
||||
|
||||
@@ -3,9 +3,28 @@ import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import io
|
||||
|
||||
import requests
|
||||
from pygments import formatters, highlight, lexers
|
||||
from ray.experimental.serve.context import FakeFlaskRequest, TaskContext
|
||||
from ray.experimental.serve.http_util import build_flask_request
|
||||
|
||||
|
||||
def parse_request_item(request_item):
|
||||
if request_item.request_context == TaskContext.Web:
|
||||
is_web_context = True
|
||||
asgi_scope, body_bytes = request_item.request_args
|
||||
flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes))
|
||||
args = (flask_request, )
|
||||
kwargs = {}
|
||||
else:
|
||||
is_web_context = False
|
||||
args = (FakeFlaskRequest(), )
|
||||
kwargs = request_item.request_kwargs
|
||||
|
||||
result_object_id = request_item.result_object_id
|
||||
return args, kwargs, is_web_context, result_object_id
|
||||
|
||||
|
||||
def _get_logger():
|
||||
|
||||
Reference in New Issue
Block a user