[Serve] Implement ServeHandle refactoring (#10527)

This commit is contained in:
Simon Mo
2020-09-04 15:50:56 -07:00
committed by GitHub
parent 6b6780a108
commit 55b6c19d98
21 changed files with 396 additions and 384 deletions
+15 -73
View File
@@ -2,7 +2,6 @@ import asyncio
import traceback
import inspect
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from typing import Union, List, Any, Callable, Type
import time
@@ -10,8 +9,6 @@ import time
import ray
from ray.async_compat import sync_to_async
from ray.serve import context as serve_context
from ray.serve.context import FakeFlaskRequest
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
unpack_future)
from ray.serve.exceptions import RayServeException
@@ -212,43 +209,18 @@ class RayServeWorker:
return self.callable
return getattr(self.callable, method_name)
def has_positional_args(self, f: Callable) -> bool:
# NOTE:
# In the case of simple functions, not actors, the f will be
# function.__call__, but we need to inspect the function itself.
if self.is_function:
f = self.callable
signature = inspect.signature(f)
for param in signature.parameters.values():
if (param.kind == param.POSITIONAL_OR_KEYWORD
and param.default is param.empty):
return True
return False
def _reset_context(self) -> None:
# NOTE(simon): context management won't work in async mode because
# many concurrent queries might be running at the same time.
serve_context.web = None
serve_context.batch_size = None
async def invoke_single(self, request_item: Query) -> Any:
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
method_to_call = self.get_runner_method(request_item)
args = args if self.has_positional_args(method_to_call) else []
method_to_call = ensure_async(method_to_call)
method_to_call = ensure_async(self.get_runner_method(request_item))
arg = parse_request_item(request_item)
start = time.time()
try:
result = await method_to_call(*args, **kwargs)
result = await method_to_call(arg)
self.request_counter.record(1, {"backend": self.backend_tag})
except Exception as e:
result = wrap_to_ray_error(e)
self.error_counter.record(1, {"backend": self.backend_tag})
finally:
self._reset_context()
self.processing_latency_tracker.record(
(time.time() - start) * 1000, {
"backend": self.backend_tag,
@@ -259,56 +231,28 @@ class RayServeWorker:
return result
async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]:
arg_list = []
kwargs_list = defaultdict(list)
context_flags = set()
batch_size = len(request_item_list)
args = []
call_methods = set()
batch_size = len(request_item_list)
# Construct the batch of requests
for item in request_item_list:
args, kwargs, is_web_context = parse_request_item(item)
context_flags.add(is_web_context)
call_method = self.get_runner_method(item)
call_methods.add(call_method)
if is_web_context:
# Python context only have kwargs
flask_request = args[0]
arg_list.append(flask_request)
else:
# Web context only have one positional argument
for k, v in kwargs.items():
kwargs_list[k].append(v)
# Set the flask request as a list to conform
# with batching semantics: when in batching
# mode, each argument is turned into list.
if self.has_positional_args(call_method):
arg_list.append(FakeFlaskRequest())
args.append(parse_request_item(item))
call_methods.add(self.get_runner_method(item))
timing_start = time.time()
try:
# Check mixing of query context (unified context needed).
if len(context_flags) != 1:
raise RayServeException(
"Batched queries contain mixed context. Please only send "
"the same type of requests in batching mode.")
serve_context.web = context_flags.pop()
if len(call_methods) != 1:
raise RayServeException(
"Queries contain mixed calling methods. Please only send "
"the same type of requests in batching mode.")
call_method = ensure_async(call_methods.pop())
serve_context.batch_size = batch_size
# Flask requests are passed to __call__ as a list
arg_list = [arg_list]
f"Queries contain mixed calling methods: {call_methods}. "
"Please only send the same type of requests in batching "
"mode.")
self.request_counter.record(batch_size,
{"backend": self.backend_tag})
result_list = await call_method(*arg_list, **kwargs_list)
call_method = ensure_async(call_methods.pop())
result_list = await call_method(args)
if not isinstance(result_list, Iterable) or isinstance(
result_list, (dict, set)):
@@ -328,11 +272,9 @@ class RayServeWorker:
"results with length equal to the batch size"
".".format(batch_size, len(result_list)))
raise RayServeException(error_message)
self._reset_context()
except Exception as e:
wrapped_exception = wrap_to_ray_error(e)
self.error_counter.record(1, {"backend": self.backend_tag})
self._reset_context()
result_list = [wrapped_exception for _ in range(batch_size)]
self.processing_latency_tracker.record(
-27
View File
@@ -1,34 +1,7 @@
from enum import IntEnum
from ray.serve.exceptions import RayServeException
class TaskContext(IntEnum):
"""TaskContext constants for queue.enqueue method"""
Web = 1
Python = 2
# Global variable will be modified in worker
# web == True: currrently processing a request from web server
# web == False: currently processing a request from python
web = False
# batching information in serve context
# batch_size == None : the backend doesn't support batching
# batch_size(int) : the number of elements of input list
batch_size = None
_not_in_web_context_error = """
Accessing the request object outside of the web context. Please use
"serve.context.web" to determine when the function is called within
a web context.
"""
class FakeFlaskRequest:
def __getattribute__(self, name):
raise RayServeException(_not_in_web_context_error)
def __setattr__(self, name, value):
raise RayServeException(_not_in_web_context_error)
@@ -15,14 +15,14 @@ client = serve.start()
# Let's define two models that just print out the data they received.
def model_one(_unused_flask_request, data=None):
print("Model 1 called with data ", data)
def model_one(request):
print("Model 1 called with data ", request.args.get("data"))
return random()
def model_two(_unused_flask_request, data=None):
print("Model 2 called with data ", data)
return data
def model_two(request):
print("Model 2 called with data ", request.args.get("data"))
return request.args.get("data")
class ComposedModel:
@@ -57,17 +57,8 @@ print("Result returned:", results)
# __doc_define_servable_v1_begin__
@serve.accept_batch
def batch_adder_v1(flask_requests: List, *, numbers: List = []):
# Depending on request context, we process the input data differently.
print("Current context is", "web" if serve.context.web else "python")
if serve.context.web:
# If the requests come from web request, we parse the flask request
# to numbers
numbers = [int(request.args["number"]) for request in flask_requests]
else:
# Otherwise, we are processing requests invoked directly from Python.
numbers = numbers
def batch_adder_v1(requests: List):
numbers = [int(request.args["number"]) for request in requests]
input_array = np.array(numbers)
print("Our input array has shape:", input_array.shape)
# Sleep for 200ms, this could be performing CPU intensive computation
@@ -97,7 +88,7 @@ input_batch = list(range(9))
print("Input batch is", input_batch)
# Input batch is [0, 1, 2, 3, 4, 5, 6, 7, 8]
result_batch = ray.get([handle.remote(numbers=i) for i in input_batch])
result_batch = ray.get([handle.remote(number=i) for i in input_batch])
# Output
# (pid=...) Current context is python
# (pid=...) Our input array has shape: (1,)
+3 -4
View File
@@ -12,9 +12,8 @@ client = serve.start()
# a backend can be a function or class.
# it can be made to be invoked from web as well as python.
def echo_v1(flask_request, response="hello from python!"):
if serve.context.web:
response = flask_request.url
def echo_v1(flask_request):
response = flask_request.args.get("response", "web")
return response
@@ -46,7 +45,7 @@ client.set_traffic("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
# Observe requests are now split between two backends.
for _ in range(10):
print(requests.get("http://127.0.0.1:8000/echo").text)
time.sleep(0.5)
time.sleep(0.2)
# You can also change number of replicas for each backend independently.
client.update_backend_config("echo:v1", {"num_replicas": 2})
+27 -15
View File
@@ -1,7 +1,7 @@
from typing import Optional
from typing import Optional, Dict, Any, Union
from ray.serve.context import TaskContext
from ray.serve.request_params import RequestMetadata
from ray.serve.router import RequestMetadata
class RayServeHandle:
@@ -29,42 +29,53 @@ class RayServeHandle:
self,
router_handle,
endpoint_name,
http_method=None,
*,
method_name=None,
shard_key=None,
http_method=None,
http_headers=None,
):
self.router_handle = router_handle
self.endpoint_name = endpoint_name
self.http_method = http_method
self.method_name = method_name
self.shard_key = shard_key
self.http_method = http_method
self.http_headers = http_headers
def remote(self, *args, **kwargs):
"""Invoke a request on the endpoint.
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
Returns a Ray ObjectRef whose result can be waited for or retrieved
using `ray.wait` or `ray.get`, respectively.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get, respectively.
Returns:
ray.ObjectRef
Input:
request_data(dict, Any): If it's a dictionary, the data will be
available in ``request.json()`` or ``request.form()``. Otherwise,
it will be available in ``request.data``.
``**kwargs``: All keyword arguments will be available in
``request.args``.
"""
if len(args) > 0:
raise ValueError(
"handle.remote must be invoked with keyword arguments.")
request_metadata = RequestMetadata(
self.endpoint_name,
TaskContext.Python,
http_method=self.http_method or "GET",
call_method=self.method_name or "__call__",
shard_key=self.shard_key,
http_method=self.http_method or "GET",
http_headers=self.http_headers or dict(),
)
return self.router_handle.enqueue_request.remote(
request_metadata, **kwargs)
request_metadata, request_data, **kwargs)
def options(self,
method_name: Optional[str] = None,
*,
shard_key: Optional[str] = None,
http_method: Optional[str] = None,
shard_key: Optional[str] = None):
http_headers: Optional[Dict[str, str]] = None):
"""Set options for this handle.
Args:
@@ -77,9 +88,10 @@ class RayServeHandle:
self.router_handle,
self.endpoint_name,
# Don't override existing method
http_method=self.http_method or http_method,
method_name=self.method_name or method_name,
shard_key=self.shard_key or shard_key,
http_method=self.http_method or http_method,
http_headers=self.http_headers or http_headers,
)
def __repr__(self):
+1 -2
View File
@@ -8,9 +8,8 @@ import ray
from ray.exceptions import RayTaskError
from ray.serve.context import TaskContext
from ray.experimental import metrics
from ray.serve.request_params import RequestMetadata
from ray.serve.http_util import Response
from ray.serve.router import Router
from ray.serve.router import Router, RequestMetadata
# The maximum number of times to retry a request due to actor failure.
# TODO(edoakes): this should probably be configurable.
-15
View File
@@ -1,15 +0,0 @@
from dataclasses import dataclass
from typing import Optional
from ray.serve.context import TaskContext
@dataclass
class RequestMetadata:
endpoint: str
request_context: TaskContext
call_method: str = "__call__"
shard_key: Optional[str] = None
http_method: str = "GET"
is_shadow_query: bool = False
+19 -2
View File
@@ -4,7 +4,7 @@ from collections import defaultdict, deque
import time
from typing import DefaultDict, List, Dict, Any, Optional
import pickle
from dataclasses import dataclass
from dataclasses import dataclass, field
from ray.exceptions import RayTaskError
@@ -12,12 +12,29 @@ import ray
from ray.experimental import metrics
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import RandomEndpointPolicy
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import logger, chain_future
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
@dataclass
class RequestMetadata:
endpoint: str
request_context: TaskContext
call_method: str = "__call__"
shard_key: Optional[str] = None
http_method: str = "GET"
http_headers: Dict[str, str] = field(default_factory=dict)
is_shadow_query: bool = False
def __post_init__(self):
self.http_headers.setdefault("X-Serve-Call-Method", self.call_method)
self.http_headers.setdefault("X-Serve-Shard-Key", self.shard_key)
@dataclass
class Query:
args: List[Any]
+2
View File
@@ -11,6 +11,8 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False):
@pytest.fixture(scope="session")
def _shared_serve_instance():
# Uncomment the line below to turn on debug log for tests.
# os.environ["SERVE_LOG_DEBUG"] = "1"
ray.init(
num_cpus=36,
_metrics_export_port=9999,
+22 -26
View File
@@ -242,9 +242,9 @@ def test_batching(serve_instance):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
def __call__(self, requests):
self.count += 1
batch_size = serve.context.batch_size
batch_size = len(requests)
return [self.count] * batch_size
# set the max batch size
@@ -281,10 +281,9 @@ def test_batching_legacy(serve_instance):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
def __call__(self, request):
self.count += 1
batch_size = serve.context.batch_size
return [self.count] * batch_size
return [self.count] * len(request)
# set the max batch size
client.create_backend(
@@ -305,7 +304,7 @@ def test_batching_legacy(serve_instance):
future_list = []
handle = client.get_handle("counter1")
for _ in range(20):
f = handle.remote(temp=1)
f = handle.remote()
future_list.append(f)
counter_result = ray.get(future_list)
@@ -323,9 +322,8 @@ def test_batching_exception(serve_instance):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
batch_size = serve.context.batch_size
return batch_size
def __call__(self, requests):
return len(requests)
# set the max batch size
client.create_backend(
@@ -369,9 +367,8 @@ def test_updating_config(serve_instance):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
batch_size = serve.context.batch_size
return [1] * batch_size
def __call__(self, request):
return [1] * len(request)
client.create_backend(
"bsimple:v1",
@@ -405,9 +402,8 @@ def test_updating_config_legacy(serve_instance):
self.count = 0
@serve.accept_batch
def __call__(self, flask_request, temp=None):
batch_size = serve.context.batch_size
return [1] * batch_size
def __call__(self, request):
return [1] * len(request)
client.create_backend(
"bsimple:v1",
@@ -439,7 +435,7 @@ def test_updating_config_legacy(serve_instance):
def test_delete_backend(serve_instance):
client = serve_instance
def function():
def function(_):
return "hello"
client.create_backend("delete:v1", function)
@@ -466,7 +462,7 @@ def test_delete_backend(serve_instance):
with pytest.raises(ValueError):
client.set_traffic("delete_backend", {"delete:v1": 1.0})
def function2():
def function2(_):
return "olleh"
# Check that we can now reuse the previously delete backend's tag.
@@ -480,7 +476,7 @@ def test_delete_backend(serve_instance):
def test_delete_endpoint(serve_instance, route):
client = serve_instance
def function():
def function(_):
return "hello"
backend_name = "delete-endpoint:v1"
@@ -521,7 +517,7 @@ def test_shard_key(serve_instance, route):
traffic_dict = {}
for i in range(num_backends):
def function():
def function(_):
return i
backend_name = "backend-split-" + str(i)
@@ -560,7 +556,7 @@ def test_multiple_instances():
client1 = serve.start(http_port=8001)
def function():
def function(_):
return "hello1"
client1.create_backend(backend, function)
@@ -572,7 +568,7 @@ def test_multiple_instances():
# the same names and check that they don't collide.
client2 = serve.start(http_port=8002)
def function():
def function(_):
return "hello2"
client2.create_backend(backend, function)
@@ -889,19 +885,19 @@ def test_shadow_traffic(serve_instance):
counter = RequestCounter.remote()
def f():
def f(_):
ray.get(counter.record.remote("backend1"))
return "hello"
def f_shadow_1():
def f_shadow_1(_):
ray.get(counter.record.remote("backend2"))
return "oops"
def f_shadow_2():
def f_shadow_2(_):
ray.get(counter.record.remote("backend3"))
return "oops"
def f_shadow_3():
def f_shadow_3(_):
ray.get(counter.record.remote("backend4"))
return "oops"
@@ -950,7 +946,7 @@ def test_connect(serve_instance):
# Check that you can call serve.connect() from within a backend for both
# detached and non-detached instances.
def connect_in_backend():
def connect_in_backend(_):
client = serve.connect()
client.create_backend("backend-ception", connect_in_backend)
return client._controller_name
+92 -148
View File
@@ -8,8 +8,7 @@ from ray import serve
import ray.serve.context as context
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.controller import TrafficPolicy
from ray.serve.request_params import RequestMetadata
from ray.serve.router import Router
from ray.serve.router import Router, RequestMetadata
from ray.serve.config import BackendConfig, BackendMetadata
from ray.serve.exceptions import RayServeException
@@ -45,85 +44,66 @@ def setup_worker(name,
return worker
async def add_servable_to_router(servable, router, **kwargs):
worker = setup_worker("backend", servable, **kwargs)
await router.add_new_worker.remote("backend", "replica", worker)
await router.set_traffic.remote("endpoint", TrafficPolicy({
"backend": 1.0
}))
if "backend_config" in kwargs:
await router.set_backend_config.remote("backend",
kwargs["backend_config"])
return worker
def make_request_param(call_method="__call__"):
return RequestMetadata(
"endpoint", context.TaskContext.Python, call_method=call_method)
@pytest.fixture
def router(serve_instance):
q = ray.remote(Router).remote()
ray.get(q.setup.remote("", serve_instance._controller_name))
yield q
ray.kill(q)
async def test_runner_wraps_error():
wrapped = wrap_to_ray_error(Exception())
assert isinstance(wrapped, ray.exceptions.RayTaskError)
async def test_runner_actor(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
async def test_servable_function(serve_instance, router):
def echo(request):
return request.args["i"]
def echo(flask_request, i=None):
return i
CONSUMER_NAME = "runner"
PRODUCER_NAME = "prod"
worker = setup_worker(CONSUMER_NAME, echo)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
_ = await add_servable_to_router(echo, router)
for query in [333, 444, 555]:
query_param = RequestMetadata(PRODUCER_NAME,
context.TaskContext.Python)
result = await q.enqueue_request.remote(query_param, i=query)
query_param = make_request_param()
result = await router.enqueue_request.remote(query_param, i=query)
assert result == query
async def test_ray_serve_mixin(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
CONSUMER_NAME = "runner-cls"
PRODUCER_NAME = "prod-cls"
async def test_servable_class(serve_instance, router):
class MyAdder:
def __init__(self, inc):
self.increment = inc
def __call__(self, flask_request, i=None):
return i + self.increment
def __call__(self, request):
return request.args["i"] + self.increment
worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
_ = await add_servable_to_router(MyAdder, router, init_args=(3, ))
for query in [333, 444, 555]:
query_param = RequestMetadata(PRODUCER_NAME,
context.TaskContext.Python)
result = await q.enqueue_request.remote(query_param, i=query)
query_param = make_request_param()
result = await router.enqueue_request.remote(query_param, i=query)
assert result == query + 3
async def test_task_runner_check_context(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
def echo(flask_request, i=None):
# Accessing the flask_request without web context should throw.
return flask_request.args["i"]
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
worker = setup_worker(CONSUMER_NAME, echo)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
result_oid = q.enqueue_request.remote(query_param, i=42)
with pytest.raises(ray.exceptions.RayTaskError):
await result_oid
async def test_task_runner_custom_method_single(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
async def test_task_runner_custom_method_single(serve_instance, router):
class NonBatcher:
def a(self, _):
return "a"
@@ -131,129 +111,97 @@ async def test_task_runner_custom_method_single(serve_instance):
def b(self, _):
return "b"
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
_ = await add_servable_to_router(NonBatcher, router)
worker = setup_worker(CONSUMER_NAME, NonBatcher)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
query_param = RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
a_result = await q.enqueue_request.remote(query_param)
query_param = make_request_param("a")
a_result = await router.enqueue_request.remote(query_param)
assert a_result == "a"
query_param = RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method="b")
b_result = await q.enqueue_request.remote(query_param)
query_param = make_request_param("b")
b_result = await router.enqueue_request.remote(query_param)
assert b_result == "b"
query_param = RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method="non_exist")
query_param = make_request_param("non_exist")
with pytest.raises(ray.exceptions.RayTaskError):
await q.enqueue_request.remote(query_param)
await router.enqueue_request.remote(query_param)
async def test_task_runner_custom_method_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
async def test_task_runner_custom_method_batch(serve_instance, router):
@serve.accept_batch
class Batcher:
def a(self, _):
return ["a-{}".format(i) for i in range(serve.context.batch_size)]
def a(self, requests):
return ["a-{}".format(i) for i in range(len(requests))]
def b(self, _):
return ["b-{}".format(i) for i in range(serve.context.batch_size)]
def error_different_size(self, _):
return [""] * (serve.context.batch_size * 2)
def error_non_iterable(self, _):
return 42
def return_np_array(self, _):
return np.array([1] * serve.context.batch_size).astype(np.int32)
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
def b(self, requests):
return ["b-{}".format(i) for i in range(len(requests))]
backend_config = BackendConfig(
max_batch_size=4,
batch_wait_timeout=2,
batch_wait_timeout=10,
internal_metadata=BackendMetadata(accepts_batches=True))
worker = setup_worker(
CONSUMER_NAME, Batcher, backend_config=backend_config)
await q.set_traffic.remote(PRODUCER_NAME,
TrafficPolicy({
CONSUMER_NAME: 1.0
}))
await q.set_backend_config.remote(CONSUMER_NAME, backend_config)
def make_request_param(call_method):
return RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method=call_method)
_ = await add_servable_to_router(
Batcher, router, backend_config=backend_config)
a_query_param = make_request_param("a")
b_query_param = make_request_param("b")
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
futures = [router.enqueue_request.remote(a_query_param) for _ in range(2)]
futures += [router.enqueue_request.remote(b_query_param) for _ in range(2)]
gathered = await asyncio.gather(*futures)
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
async def test_servable_batch_error(serve_instance, router):
@serve.accept_batch
class ErrorBatcher:
def error_different_size(self, requests):
return [""] * (len(requests) + 10)
def error_non_iterable(self, _):
return 42
def return_np_array(self, requests):
return np.array([1] * len(requests)).astype(np.int32)
backend_config = BackendConfig(
max_batch_size=4,
internal_metadata=BackendMetadata(accepts_batches=True))
_ = await add_servable_to_router(
ErrorBatcher, router, backend_config=backend_config)
with pytest.raises(RayServeException, match="doesn't preserve batch size"):
different_size = make_request_param("error_different_size")
await q.enqueue_request.remote(different_size)
await router.enqueue_request.remote(different_size)
with pytest.raises(RayServeException, match="iterable"):
non_iterable = make_request_param("error_non_iterable")
await q.enqueue_request.remote(non_iterable)
await router.enqueue_request.remote(non_iterable)
np_array = make_request_param("return_np_array")
result_np_value = await q.enqueue_request.remote(np_array)
result_np_value = await router.enqueue_request.remote(np_array)
assert isinstance(result_np_value, np.int32)
async def test_task_runner_perform_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
def batcher(*args, **kwargs):
return [serve.context.batch_size] * serve.context.batch_size
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
async def test_task_runner_perform_batch(serve_instance, router):
def batcher(requests):
batch_size = len(requests)
return [batch_size] * batch_size
config = BackendConfig(
max_batch_size=2,
batch_wait_timeout=10,
internal_metadata=BackendMetadata(accepts_batches=True))
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
await q.set_traffic.remote(PRODUCER_NAME,
TrafficPolicy({
CONSUMER_NAME: 1.0
}))
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
_ = await add_servable_to_router(batcher, router, backend_config=config)
query_param = make_request_param()
my_batch_sizes = await asyncio.gather(
*[q.enqueue_request.remote(query_param) for _ in range(3)])
*[router.enqueue_request.remote(query_param) for _ in range(3)])
assert my_batch_sizes == [2, 2, 1]
async def test_task_runner_perform_async(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote("", serve_instance._controller_name)
async def test_task_runner_perform_async(serve_instance, router):
@ray.remote
class Barrier:
def __init__(self, release_on):
@@ -274,22 +222,18 @@ async def test_task_runner_perform_async(serve_instance):
await barrier.wait.remote()
return "done!"
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
config = BackendConfig(
max_concurrent_queries=10,
internal_metadata=BackendMetadata(is_blocking=False))
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
_ = await add_servable_to_router(
wait_and_go, router, backend_config=config)
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
query_param = make_request_param()
done, not_done = await asyncio.wait(
[q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10)
[router.enqueue_request.remote(query_param) for _ in range(10)],
timeout=10)
assert len(done) == 10
for item in done:
await item == "done!"
+2 -2
View File
@@ -81,11 +81,11 @@ def test_replica_config_validation():
def __call__(self):
pass
def function():
def function(_):
pass
@serve.accept_batch
def batch_function():
def batch_function(_):
pass
ReplicaConfig(Class)
+7 -7
View File
@@ -23,7 +23,7 @@ def request_with_retries(endpoint, timeout=30):
def test_controller_failure(serve_instance):
client = serve_instance
def function():
def function(_):
return "hello1"
client.create_backend("controller_failure:v1", function)
@@ -45,7 +45,7 @@ def test_controller_failure(serve_instance):
response = request_with_retries("/controller_failure", timeout=30)
assert response.text == "hello1"
def function():
def function(_):
return "hello2"
ray.kill(client._controller, no_restart=False)
@@ -57,7 +57,7 @@ def test_controller_failure(serve_instance):
response = request_with_retries("/controller_failure", timeout=30)
assert response.text == "hello2"
def function():
def function(_):
return "hello3"
ray.kill(client._controller, no_restart=False)
@@ -85,7 +85,7 @@ def _kill_routers(client):
def test_http_proxy_failure(serve_instance):
client = serve_instance
def function():
def function(_):
return "hello1"
client.create_backend("proxy_failure:v1", function)
@@ -100,7 +100,7 @@ def test_http_proxy_failure(serve_instance):
_kill_routers(client)
def function():
def function(_):
return "hello2"
client.create_backend("proxy_failure:v2", function)
@@ -213,7 +213,7 @@ def test_worker_replica_failure(serve_instance):
def test_create_backend_idempotent(serve_instance):
client = serve_instance
def f():
def f(_):
return "hello"
controller = client._controller
@@ -236,7 +236,7 @@ def test_create_backend_idempotent(serve_instance):
def test_create_endpoint_idempotent(serve_instance):
client = serve_instance
def f():
def f(_):
return "hello"
client.create_backend("my_backend", f)
+51 -3
View File
@@ -1,8 +1,8 @@
import requests
import ray
from ray import serve
import requests
def test_handle_in_endpoint(serve_instance):
client = serve_instance
@@ -16,7 +16,7 @@ def test_handle_in_endpoint(serve_instance):
client = serve.connect()
self.handle = client.get_handle("endpoint1")
def __call__(self):
def __call__(self, _):
return ray.get(self.handle.remote())
client.create_backend("endpoint1:v0", Endpoint1)
@@ -36,6 +36,54 @@ def test_handle_in_endpoint(serve_instance):
assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
def test_handle_http_args(serve_instance):
client = serve_instance
class Endpoint:
def __call__(self, request):
return {
"args": dict(request.args),
"headers": dict(request.headers),
"method": request.method,
"json": request.json
}
client.create_backend("backend", Endpoint)
client.create_endpoint(
"endpoint", backend="backend", route="/endpoint", methods=["POST"])
ground_truth = {
"args": {
"arg1": "1",
"arg2": "2"
},
"headers": {
"X-Custom-Header": "value"
},
"method": "POST",
"json": {
"json_key": "json_val"
}
}
resp_web = requests.post(
"http://127.0.0.1:8000/endpoint?arg1=1&arg2=2",
headers=ground_truth["headers"],
json=ground_truth["json"]).json()
handle = client.get_handle("endpoint")
resp_handle = ray.get(
handle.options(
http_method=ground_truth["method"],
http_headers=ground_truth["headers"]).remote(
ground_truth["json"], **ground_truth["args"]))
for resp in [resp_web, resp_handle]:
for field in ["args", "method", "json"]:
assert resp[field] == ground_truth[field]
resp["headers"]["X-Custom-Header"] == "value"
if __name__ == "__main__":
import sys
import pytest
+2 -2
View File
@@ -11,8 +11,8 @@ def test_np_in_composed_model(serve_instance):
# AttributeError: 'bytes' object has no attribute 'readonly'
# in cloudpickle _from_numpy_buffer
def sum_model(_request, data=None):
return np.sum(data)
def sum_model(request):
return np.sum(request.args["data"])
class ComposedModel:
def __init__(self):
+1 -2
View File
@@ -4,8 +4,7 @@ import pytest
import ray
from ray.serve.controller import TrafficPolicy
from ray.serve.router import Router, Query
from ray.serve.request_params import RequestMetadata
from ray.serve.router import Router, Query, RequestMetadata
from ray.serve.utils import get_random_letters
from ray.test_utils import SignalActor
from ray.serve.config import BackendConfig
+64 -18
View File
@@ -9,37 +9,83 @@ import time
from typing import List
import io
import os
from ray.serve.exceptions import RayServeException
import requests
import numpy as np
import pydantic
from werkzeug.datastructures import ImmutableMultiDict
import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.serve.context import FakeFlaskRequest, TaskContext
from ray.serve.context import TaskContext
from ray.serve.http_util import build_flask_request
import numpy as np
try:
import pydantic
except ImportError:
pydantic = None
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
class ServeRequest:
"""The request object used in Python context.
ServeRequest is built to have similar API as Flask.Request. You only need
to write your model serving code once; it can be queried by both HTTP and
Python.
"""
def __init__(self, data, kwargs, headers, method):
self._data = data
self._kwargs = kwargs
self._headers = headers
self._method = method
@property
def headers(self):
"""The HTTP headers from ``handle.option(http_headers=...)``."""
return self._headers
@property
def method(self):
"""The HTTP method data from ``handle.option(http_method=...)``."""
return self._method
@property
def args(self):
"""The keyword arguments from ``handle.remote(**kwargs)``."""
return ImmutableMultiDict(self._kwargs)
@property
def json(self):
"""The request dictionary, from ``handle.remote(dict)``."""
if not isinstance(self._data, dict):
raise RayServeException("Request data is not a dictionary. "
f"It is {type(self._data)}.")
return self._data
@property
def form(self):
"""The request dictionary, from ``handle.remote(dict)``."""
if not isinstance(self._data, dict):
raise RayServeException("Request data is not a dictionary. "
f"It is {type(self._data)}.")
return self._data
@property
def data(self):
"""The request data from ``handle.remote(obj)``."""
return self._data
def parse_request_item(request_item):
if request_item.metadata.request_context == TaskContext.Web:
is_web_context = True
asgi_scope, body_bytes = request_item.args
flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes))
args = (flask_request, )
kwargs = {}
return build_flask_request(asgi_scope, io.BytesIO(body_bytes))
else:
is_web_context = False
args = (FakeFlaskRequest(), )
kwargs = request_item.kwargs
return args, kwargs, is_web_context
return ServeRequest(
request_item.args[0] if len(request_item.args) == 1 else None,
request_item.kwargs,
headers=request_item.metadata.http_headers,
method=request_item.metadata.http_method,
)
def _get_logger():
@@ -66,7 +112,7 @@ class ServeEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=E0202
if isinstance(o, bytes):
return o.decode("utf-8")
if pydantic is not None and isinstance(o, pydantic.BaseModel):
if isinstance(o, pydantic.BaseModel):
return o.dict()
if isinstance(o, Exception):
return str(o)