mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 05:17:38 +08:00
[Serve] Implement flask_request and named python request (#5849)
* Implement flask_request and named python request * Forgot to include missing files * Address comment * Add flask to requirements for doc (lint failed) * Update doc requirement so lint will build * Install flask in CI * Fix typo in .travis.yml
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
from enum import IntEnum
|
||||
from ray.experimental.serve.exceptions import RayServeException
|
||||
|
||||
|
||||
class TaskContext(IntEnum):
    """Origin tag attached to a request when it is enqueued.

    The value is passed as ``request_context`` to ``enqueue_request`` so the
    worker can tell how the request entered the system.
    """

    # Request was submitted by the HTTP proxy.
    Web = 1
    # Request was submitted through a Python handle (``handle.remote``).
    Python = 2
|
||||
|
||||
|
||||
# Global variable will be modified in worker
|
||||
# web == True: currently processing a request from web server
|
||||
# web == False: currently processing a request from python
|
||||
web = False
|
||||
|
||||
_not_in_web_context_error = """
|
||||
Accessing the request object outside of the web context. Please use
|
||||
"serve.context.web" to determine when the function is called within
|
||||
a web context.
|
||||
"""
|
||||
|
||||
|
||||
class FakeFlaskQuest:
    """Placeholder handed to a backend in place of a real flask request.

    Reading or assigning any attribute on an instance raises
    ``RayServeException`` carrying the "not in web context" message.
    """

    def __getattribute__(self, name):
        # Every attribute lookup on an instance is rejected, whatever `name`.
        raise RayServeException(_not_in_web_context_error)

    def __setattr__(self, name, value):
        # Assignments are rejected the same way as reads.
        raise RayServeException(_not_in_web_context_error)
|
||||
@@ -10,8 +10,8 @@ from ray.experimental import serve
|
||||
from ray.experimental.serve.utils import pformat_color_json
|
||||
|
||||
|
||||
def echo(context):
|
||||
return context
|
||||
def echo(flask_request):
|
||||
return "hello " + flask_request.args.get("name", "serve!")
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
|
||||
@@ -1,41 +1,47 @@
|
||||
"""
|
||||
Example actor that adds message to the end of query_string.
|
||||
Example actor that adds an increment to a number. This number can
|
||||
come from either web (parsing Flask request) or python call.
|
||||
|
||||
This actor can be called from HTTP as well as from Python.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from werkzeug import urls
|
||||
|
||||
import ray
|
||||
from ray.experimental import serve
|
||||
from ray.experimental.serve.utils import pformat_color_json
|
||||
|
||||
|
||||
class EchoActor:
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
class MagicCounter:
|
||||
def __init__(self, increment):
|
||||
self.increment = increment
|
||||
|
||||
def __call__(self, context):
|
||||
query_string_dict = urls.url_decode(context["query_string"])
|
||||
message = ""
|
||||
message += query_string_dict.get("message", "")
|
||||
message += " "
|
||||
message += self.message
|
||||
return message
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
if serve.context.web:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
|
||||
return base_number + self.increment
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
serve.create_endpoint("magic_counter", "/counter", blocking=True)
|
||||
serve.create_backend(MagicCounter, "counter:v1", 42) # increment=42
|
||||
serve.link("magic_counter", "counter:v1")
|
||||
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_backend(EchoActor, "echo:v1", "world")
|
||||
serve.link("my_endpoint", "echo:v1")
|
||||
|
||||
while True:
|
||||
resp = requests.get("http://127.0.0.1:8000/echo?message=hello").json()
|
||||
print("Sending ten queries via HTTP")
|
||||
for i in range(10):
|
||||
url = "http://127.0.0.1:8000/counter?base_number={}".format(i)
|
||||
print("> Pinging {}".format(url))
|
||||
resp = requests.get(url).json()
|
||||
print(pformat_color_json(resp))
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
print(pformat_color_json(resp))
|
||||
time.sleep(0.2)
|
||||
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
print("Sending ten queries via Python")
|
||||
handle = serve.get_handle("magic_counter")
|
||||
for i in range(10):
|
||||
print("> Pinging handle.remote(base_number={})".format(i))
|
||||
result = ray.get(handle.remote(base_number=i))
|
||||
print("< Result {}".format(result))
|
||||
|
||||
@@ -7,7 +7,7 @@ We are going to define a buggy function that raise some exception:
|
||||
|
||||
The expected behavior is:
|
||||
- HTTP server should respond with "internal error" in the response JSON
|
||||
- ray.get(handle.remote(33)) should raise RayTaskError with traceback.
|
||||
- ray.get(handle.remote()) should raise RayTaskError with traceback.
|
||||
|
||||
This shows that error is hidden from HTTP side but always visible when calling
|
||||
from Python.
|
||||
@@ -40,5 +40,5 @@ for _ in range(2):
|
||||
time.sleep(2)
|
||||
|
||||
handle = serve.get_handle("my_endpoint")
|
||||
|
||||
ray.get(handle.remote(33))
|
||||
print("Invoke from python will raise exception with traceback:")
|
||||
ray.get(handle.remote())
|
||||
|
||||
@@ -16,8 +16,11 @@ serve.create_endpoint("my_endpoint", "/echo")
|
||||
|
||||
|
||||
# a backend can be a function or class.
|
||||
def echo_v1(request):
|
||||
return request
|
||||
# it can be made to be invoked from web as well as python.
|
||||
def echo_v1(flask_request, response="hello from python!"):
|
||||
if serve.context.web:
|
||||
response = flask_request.url
|
||||
return response
|
||||
|
||||
|
||||
serve.create_backend(echo_v1, "echo:v1")
|
||||
@@ -29,14 +32,14 @@ serve.link("my_endpoint", "echo:v1")
|
||||
print(requests.get("http://127.0.0.1:8000/echo").json())
|
||||
# The service will be reachable from http
|
||||
|
||||
print(ray.get(serve.get_handle("my_endpoint").remote("hello")))
|
||||
print(ray.get(serve.get_handle("my_endpoint").remote(response="hello")))
|
||||
|
||||
# as well as within the ray system.
|
||||
|
||||
|
||||
# We can also add a new backend and split the traffic.
|
||||
def echo_v2(request):
|
||||
# magic
|
||||
def echo_v2(flask_request):
|
||||
# magic, only from web.
|
||||
return "something new"
|
||||
|
||||
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
"""
|
||||
Example rollback action in ray serve. We first deploy only v1, then set a
|
||||
50/50 deployment between v1 and v2, and finally roll back to only v1.
|
||||
"""
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from ray.experimental import serve
|
||||
from ray.experimental.serve.utils import pformat_color_json
|
||||
|
||||
|
||||
def echo_v1(_):
|
||||
return "v1"
|
||||
|
||||
|
||||
def echo_v2(_):
|
||||
return "v2"
|
||||
|
||||
|
||||
serve.init(blocking=True)
|
||||
|
||||
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
|
||||
serve.create_backend(echo_v1, "echo:v1")
|
||||
serve.link("my_endpoint", "echo:v1")
|
||||
|
||||
for _ in range(3):
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
print(pformat_color_json(resp))
|
||||
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
|
||||
serve.create_backend(echo_v2, "echo:v2")
|
||||
serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
|
||||
for _ in range(6):
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
print(pformat_color_json(resp))
|
||||
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
|
||||
serve.rollback("my_endpoint")
|
||||
for _ in range(6):
|
||||
resp = requests.get("http://127.0.0.1:8000/echo").json()
|
||||
print(pformat_color_json(resp))
|
||||
|
||||
print("...Sleeping for 2 seconds...")
|
||||
time.sleep(2)
|
||||
@@ -0,0 +1,2 @@
|
||||
class RayServeException(Exception):
    """Exception type raised by the serve package itself."""
    pass
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import requests
|
||||
|
||||
import ray
|
||||
from ray.experimental.serve.kv_store_service import KVStoreProxyActor
|
||||
from ray.experimental.serve.queues import CentralizedQueuesActor
|
||||
@@ -68,12 +70,17 @@ class GlobalState:
|
||||
self.router_actor_handle)
|
||||
|
||||
def wait_until_http_ready(self, num_retries=5, backoff_time_s=1):
|
||||
routing_table_request_count = 0
|
||||
http_is_ready = False
|
||||
retries = num_retries
|
||||
|
||||
while not routing_table_request_count:
|
||||
routing_table_request_count = (ray.get(
|
||||
self.kv_store_actor_handle.get_request_count.remote()))
|
||||
while not http_is_ready:
|
||||
try:
|
||||
resp = requests.get(self.http_address)
|
||||
assert resp.status_code == 200
|
||||
http_is_ready = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.debug((LOG_PREFIX + "Checking if HTTP server is ready."
|
||||
"{} retries left.").format(retries))
|
||||
time.sleep(backoff_time_s)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import ray
|
||||
from ray.experimental import serve
|
||||
from ray.experimental.serve.context import TaskContext
|
||||
from ray.experimental.serve.exceptions import RayServeException
|
||||
|
||||
|
||||
class RayServeHandle:
|
||||
@@ -28,11 +30,17 @@ class RayServeHandle:
|
||||
self.router_handle = router_handle
|
||||
self.endpoint_name = endpoint_name
|
||||
|
||||
def remote(self, *args):
|
||||
# TODO(simon): Support kwargs once #5606 is merged.
|
||||
def remote(self, *args, **kwargs):
|
||||
if len(args) != 0:
|
||||
raise RayServeException(
|
||||
"handle.remote must be invoked with keyword arguments.")
|
||||
|
||||
result_object_id_bytes = ray.get(
|
||||
self.router_handle.enqueue_request.remote(self.endpoint_name,
|
||||
*args))
|
||||
self.router_handle.enqueue_request.remote(
|
||||
service=self.endpoint_name,
|
||||
request_args=(),
|
||||
request_kwargs=kwargs,
|
||||
request_context=TaskContext.Python))
|
||||
return ray.ObjectID(result_object_id_bytes)
|
||||
|
||||
def get_traffic_policy(self):
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import flask
|
||||
import io
|
||||
|
||||
|
||||
def build_flask_request(asgi_scope_dict, request_body):
    """Build and return a flask request from an ASGI payload.

    This function is intended to be used immediately before task invocation
    happens.

    Args:
        asgi_scope_dict: ASGI scope of the HTTP request; forwarded to
            ``build_wsgi_environ``.
        request_body: Body of the HTTP request; forwarded to
            ``build_wsgi_environ``.

    Returns:
        A ``flask.Request`` constructed from the translated WSGI environ.
    """
    wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body)
    return flask.Request(wsgi_environ)
|
||||
|
||||
|
||||
def build_wsgi_environ(scope, body):
    """Translate an ASGI HTTP scope and request body into a WSGI environ.

    This code snippet is adapted from https://github.com/django/asgiref/blob
    /36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52

    WSGI specification can be found at
    https://www.python.org/dev/peps/pep-0333/

    This function helps translate ASGI scope and body into a flask request.

    Args:
        scope (dict): ASGI HTTP connection scope. ``method``, ``path``,
            ``query_string`` and ``http_version`` are required; ``root_path``,
            ``scheme``, ``server``, ``client`` and ``headers`` are optional.
        body: Request body, either raw ``bytes``/``bytearray`` or an already
            opened binary file-like object.

    Returns:
        dict: The WSGI environ describing the request.
    """
    # WSGI requires ``wsgi.input`` to be a readable stream, so a raw byte
    # string is wrapped. File-like objects are passed through untouched.
    if isinstance(body, (bytes, bytearray)):
        body = io.BytesIO(body)

    environ = {
        "REQUEST_METHOD": scope["method"],
        "SCRIPT_NAME": scope.get("root_path", ""),
        "PATH_INFO": scope["path"],
        "QUERY_STRING": scope["query_string"].decode("ascii"),
        "SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]),
        "wsgi.version": (1, 0),
        "wsgi.url_scheme": scope.get("scheme", "http"),
        "wsgi.input": body,
        "wsgi.errors": io.BytesIO(),
        "wsgi.multithread": True,
        "wsgi.multiprocess": True,
        "wsgi.run_once": False,
    }

    # Server name and port are required in WSGI but optional (possibly None)
    # in ASGI, so fall back to the same defaults asgiref uses.
    server = scope.get("server")
    if server:
        environ["SERVER_NAME"] = server[0]
        environ["SERVER_PORT"] = str(server[1])
    else:
        environ["SERVER_NAME"] = "localhost"
        environ["SERVER_PORT"] = "80"

    # The client address is optional in ASGI; only set it when known.
    client = scope.get("client")
    if client:
        environ["REMOTE_ADDR"] = client[0]

    # Transforms headers into environ entries.
    for name, value in scope.get("headers", []):
        # name, values are both bytes, we need to decode them to string
        name = name.decode("latin1")
        value = value.decode("latin1")

        # Handle name correction to conform to WSGI spec
        # https://www.python.org/dev/peps/pep-0333/#environ-variables
        if name == "content-length":
            corrected_name = "CONTENT_LENGTH"
        elif name == "content-type":
            corrected_name = "CONTENT_TYPE"
        else:
            corrected_name = "HTTP_%s" % name.upper().replace("-", "_")

        # If the header value repeated,
        # we will just concatenate it to the field.
        if corrected_name in environ:
            value = environ[corrected_name] + "," + value

        environ[corrected_name] = value
    return environ
|
||||
@@ -7,8 +7,15 @@ from ray.experimental.serve.utils import get_custom_object_id, logger
|
||||
|
||||
|
||||
class Query:
|
||||
def __init__(self, request_body, result_object_id=None):
|
||||
self.request_body = request_body
|
||||
def __init__(self,
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
result_object_id=None):
|
||||
self.request_args = request_args
|
||||
self.request_kwargs = request_kwargs
|
||||
self.request_context = request_context
|
||||
|
||||
if result_object_id is None:
|
||||
self.result_object_id = get_custom_object_id()
|
||||
else:
|
||||
@@ -34,7 +41,8 @@ class CentralizedQueues:
|
||||
Behavior:
|
||||
>>> # pseudo-code
|
||||
>>> queue = CentralizedQueues()
|
||||
>>> queue.enqueue_request('service-name', data)
|
||||
>>> queue.enqueue_request(
|
||||
"service-name", request_args, request_kwargs, request_context)
|
||||
# nothing happens, request is queued.
|
||||
# returns result ObjectID, which will contain the final result
|
||||
>>> queue.dequeue_request('backend-1')
|
||||
@@ -68,8 +76,9 @@ class CentralizedQueues:
|
||||
# backend_name -> worker queue
|
||||
self.workers = defaultdict(deque)
|
||||
|
||||
def enqueue_request(self, service, request_data):
|
||||
query = Query(request_data)
|
||||
def enqueue_request(self, service, request_args, request_kwargs,
|
||||
request_context):
|
||||
query = Query(request_args, request_kwargs, request_context)
|
||||
self.queues[service].append(query)
|
||||
self.flush()
|
||||
return query.result_object_id.binary()
|
||||
|
||||
@@ -7,6 +7,7 @@ import ray
|
||||
from ray.experimental.async_api import _async_init, as_future
|
||||
from ray.experimental.serve.utils import BytesEncoder
|
||||
from ray.experimental.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S
|
||||
from ray.experimental.serve.context import TaskContext
|
||||
|
||||
|
||||
class JSONResponse:
|
||||
@@ -70,8 +71,13 @@ class HTTPProxy:
|
||||
self.router = router_handle
|
||||
self.route_table = dict()
|
||||
|
||||
self.route_checker_should_shutdown = False
|
||||
|
||||
async def route_checker(self, interval):
|
||||
while True:
|
||||
if self.route_checker_should_shutdown:
|
||||
return
|
||||
|
||||
try:
|
||||
self.route_table = await as_future(
|
||||
self.admin_actor.list_service.remote())
|
||||
@@ -80,40 +86,73 @@ class HTTPProxy:
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
async def handle_lifespan_message(self, scope, receive, send):
    """Serve the ASGI lifespan protocol until shutdown is requested.

    On ``lifespan.startup`` the async API is initialised and the route
    checker is scheduled; on ``lifespan.shutdown`` the route checker is
    flagged to stop. Each event is acknowledged with the matching
    ``*.complete`` message.

    Args:
        scope (dict): ASGI scope; its ``type`` must be ``"lifespan"``.
        receive: ASGI awaitable callable yielding lifespan messages.
        send: ASGI awaitable callable used to acknowledge events.
    """
    assert scope["type"] == "lifespan"

    # Keep receiving: startup and shutdown arrive as separate messages on
    # the same lifespan connection, so returning after the first message
    # would mean the shutdown branch is never reached.
    while True:
        message = await receive()
        if message["type"] == "lifespan.startup":
            await _async_init()
            asyncio.ensure_future(
                self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S))
            await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown":
            self.route_checker_should_shutdown = True
            await send({"type": "lifespan.shutdown.complete"})
            return
|
||||
|
||||
async def receive_http_body(self, scope, receive, send):
    """Read the whole HTTP request body from the ASGI ``receive`` channel.

    Args:
        scope (dict): ASGI scope of the request (unused here).
        receive: ASGI awaitable callable yielding ``http.request`` messages.
        send: ASGI send callable (unused here).

    Returns:
        bytes: All received body chunks joined in arrival order.
    """
    body_buffer = []
    more_body = True
    while more_body:
        message = await receive()
        assert message["type"] == "http.request"

        # Both keys are optional in the ASGI spec: a missing ``more_body``
        # means this is the last chunk, a missing ``body`` means empty.
        more_body = message.get("more_body", False)
        body_buffer.append(message.get("body", b""))

    return b"".join(body_buffer)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# NOTE: This implements ASGI protocol specified in
|
||||
# https://asgi.readthedocs.io/en/latest/specs/index.html
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await _async_init()
|
||||
asyncio.ensure_future(
|
||||
self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S))
|
||||
await self.handle_lifespan_message(scope, receive, send)
|
||||
return
|
||||
|
||||
assert scope["type"] == "http"
|
||||
current_path = scope["path"]
|
||||
if current_path == "/":
|
||||
await JSONResponse(self.route_table)(scope, receive, send)
|
||||
elif current_path in self.route_table:
|
||||
endpoint_name = self.route_table[current_path]
|
||||
result_object_id_bytes = await as_future(
|
||||
self.router.enqueue_request.remote(endpoint_name, scope))
|
||||
result = await as_future(ray.ObjectID(result_object_id_bytes))
|
||||
return
|
||||
|
||||
if isinstance(result, ray.exceptions.RayTaskError):
|
||||
await JSONResponse({
|
||||
"error": "internal error, please use python API to debug"
|
||||
})(scope, receive, send)
|
||||
else:
|
||||
await JSONResponse({"result": result})(scope, receive, send)
|
||||
else:
|
||||
if current_path not in self.route_table:
|
||||
error_message = ("Path {} not found. "
|
||||
"Please ping http://.../ for routing table"
|
||||
).format(current_path)
|
||||
|
||||
await JSONResponse(
|
||||
{
|
||||
"error": error_message
|
||||
}, status_code=404)(scope, receive, send)
|
||||
return
|
||||
|
||||
endpoint_name = self.route_table[current_path]
|
||||
http_body_bytes = await self.receive_http_body(scope, receive, send)
|
||||
|
||||
result_object_id_bytes = await as_future(
|
||||
self.router.enqueue_request.remote(
|
||||
service=endpoint_name,
|
||||
request_args=(scope, http_body_bytes),
|
||||
request_kwargs=dict(),
|
||||
request_context=TaskContext.Web))
|
||||
|
||||
result = await as_future(ray.ObjectID(result_object_id_bytes))
|
||||
|
||||
if isinstance(result, ray.exceptions.RayTaskError):
|
||||
await JSONResponse({
|
||||
"error": "internal error, please use python API to debug"
|
||||
})(scope, receive, send)
|
||||
else:
|
||||
await JSONResponse({"result": result})(scope, receive, send)
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -122,4 +161,5 @@ class HTTPActor:
|
||||
self.app = HTTPProxy(kv_store_actor_handle, router_handle)
|
||||
|
||||
def run(self, host="0.0.0.0", port=8000):
|
||||
uvicorn.run(self.app, host=host, port=port, lifespan="on")
|
||||
uvicorn.run(
|
||||
self.app, host=host, port=port, lifespan="on", access_log=False)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray.experimental.serve import context as serve_context
|
||||
from ray.experimental.serve.context import TaskContext, FakeFlaskQuest
|
||||
from ray.experimental.serve.http_util import build_flask_request
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
@@ -13,18 +15,18 @@ class TaskRunner:
|
||||
def __init__(self, func_to_run):
|
||||
self.func = func_to_run
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.func(*args)
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def wrap_to_ray_error(callable_obj, *args):
|
||||
def wrap_to_ray_error(exception):
|
||||
"""Utility method that catch and seal exceptions in execution"""
|
||||
try:
|
||||
return callable_obj(*args)
|
||||
# Raise and catch so we can access traceback.format_exc()
|
||||
raise exception
|
||||
except Exception as e:
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
return ray.exceptions.RayTaskError(
|
||||
str(callable_obj), traceback_str, e.__class__)
|
||||
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
|
||||
|
||||
|
||||
class RayServeMixin:
|
||||
@@ -62,12 +64,28 @@ class RayServeMixin:
|
||||
self._ray_serve_dequeue_requestr_name))
|
||||
work_item = ray.get(ray.ObjectID(work_token))
|
||||
|
||||
# TODO(simon):
|
||||
# __call__ should be able to take multiple *args and **kwargs.
|
||||
result = wrap_to_ray_error(self.__call__, work_item.request_body)
|
||||
result_object_id = work_item.result_object_id
|
||||
ray.worker.global_worker.put_object(result_object_id, result)
|
||||
if work_item.request_context == TaskContext.Web:
|
||||
serve_context.web = True
|
||||
asgi_scope, body_bytes = work_item.request_args
|
||||
flask_request = build_flask_request(asgi_scope, body_bytes)
|
||||
args = (flask_request, )
|
||||
kwargs = {}
|
||||
else:
|
||||
serve_context.web = False
|
||||
args = (FakeFlaskQuest(), )
|
||||
kwargs = work_item.request_kwargs
|
||||
|
||||
result_object_id = work_item.result_object_id
|
||||
|
||||
try:
|
||||
result = self.__call__(*args, **kwargs)
|
||||
ray.worker.global_worker.put_object(result_object_id, result)
|
||||
except Exception as e:
|
||||
wrapped_exception = wrap_to_ray_error(e)
|
||||
ray.worker.global_worker.put_object(result_object_id,
|
||||
wrapped_exception)
|
||||
|
||||
serve_context.web = False
|
||||
# The worker finished one unit of work.
|
||||
# It will now tail recursively schedule the main_loop again.
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ from ray.experimental import serve
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def serve_instance():
|
||||
serve.init()
|
||||
serve.global_state.wait_until_http_ready()
|
||||
serve.init(blocking=True)
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,34 @@
|
||||
import time
|
||||
|
||||
import requests
|
||||
from flaky import flaky
|
||||
|
||||
import ray
|
||||
from ray.experimental import serve
|
||||
|
||||
|
||||
def delay_rerun(*_):
|
||||
time.sleep(1)
|
||||
return True
|
||||
|
||||
|
||||
# flaky test because the routing table might not be populated
|
||||
@flaky(rerun_filter=delay_rerun)
|
||||
def test_e2e(serve_instance):
|
||||
serve.create_endpoint("endpoint", "/api")
|
||||
result = ray.get(
|
||||
serve.global_state.kv_store_actor_handle.list_service.remote())
|
||||
assert result == {"/api": "endpoint"}
|
||||
|
||||
assert requests.get("http://127.0.0.1:8000/").json() == result
|
||||
retry_count = 3
|
||||
while True:
|
||||
try:
|
||||
resp = requests.get("http://127.0.0.1:8000/").json()
|
||||
assert resp == result
|
||||
break
|
||||
except Exception:
|
||||
time.sleep(0.5)
|
||||
retry_count -= 1
|
||||
if retry_count == 0:
|
||||
assert False, "Route table hasn't been updated after 3 tries."
|
||||
|
||||
def echo(i):
|
||||
return i
|
||||
def function(flask_request):
|
||||
return "OK"
|
||||
|
||||
serve.create_backend(echo, "echo:v1")
|
||||
serve.create_backend(function, "echo:v1")
|
||||
serve.link("endpoint", "echo:v1")
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
|
||||
assert resp["path"] == "/api"
|
||||
assert resp["method"] == "GET"
|
||||
assert resp == "OK"
|
||||
|
||||
@@ -6,10 +6,11 @@ def test_single_prod_cons_queue(serve_instance):
|
||||
q = CentralizedQueues()
|
||||
q.link("svc", "backend")
|
||||
|
||||
result_object_id = q.enqueue_request("svc", 1)
|
||||
result_object_id = q.enqueue_request("svc", 1, "kwargs", None)
|
||||
work_object_id = q.dequeue_request("backend")
|
||||
got_work = ray.get(ray.ObjectID(work_object_id))
|
||||
assert got_work.request_body == 1
|
||||
assert got_work.request_args == 1
|
||||
assert got_work.request_kwargs == "kwargs"
|
||||
|
||||
ray.worker.global_worker.put_object(got_work.result_object_id, 2)
|
||||
assert ray.get(ray.ObjectID(result_object_id)) == 2
|
||||
@@ -18,19 +19,19 @@ def test_single_prod_cons_queue(serve_instance):
|
||||
def test_alter_backend(serve_instance):
|
||||
q = CentralizedQueues()
|
||||
|
||||
result_object_id = q.enqueue_request("svc", 1)
|
||||
result_object_id = q.enqueue_request("svc", 1, "kwargs", None)
|
||||
work_object_id = q.dequeue_request("backend-1")
|
||||
q.set_traffic("svc", {"backend-1": 1})
|
||||
got_work = ray.get(ray.ObjectID(work_object_id))
|
||||
assert got_work.request_body == 1
|
||||
assert got_work.request_args == 1
|
||||
ray.worker.global_worker.put_object(got_work.result_object_id, 2)
|
||||
assert ray.get(ray.ObjectID(result_object_id)) == 2
|
||||
|
||||
result_object_id = q.enqueue_request("svc", 1)
|
||||
result_object_id = q.enqueue_request("svc", 1, "kwargs", None)
|
||||
work_object_id = q.dequeue_request("backend-2")
|
||||
q.set_traffic("svc", {"backend-2": 1})
|
||||
got_work = ray.get(ray.ObjectID(work_object_id))
|
||||
assert got_work.request_body == 1
|
||||
assert got_work.request_args == 1
|
||||
ray.worker.global_worker.put_object(got_work.result_object_id, 2)
|
||||
assert ray.get(ray.ObjectID(result_object_id)) == 2
|
||||
|
||||
@@ -38,8 +39,8 @@ def test_alter_backend(serve_instance):
|
||||
def test_split_traffic(serve_instance):
|
||||
q = CentralizedQueues()
|
||||
|
||||
q.enqueue_request("svc", 1)
|
||||
q.enqueue_request("svc", 1)
|
||||
q.enqueue_request("svc", 1, "kwargs", None)
|
||||
q.enqueue_request("svc", 1, "kwargs", None)
|
||||
q.set_traffic("svc", {})
|
||||
work_object_id_1 = q.dequeue_request("backend-1")
|
||||
work_object_id_2 = q.dequeue_request("backend-2")
|
||||
@@ -48,13 +49,13 @@ def test_split_traffic(serve_instance):
|
||||
got_work = ray.get(
|
||||
[ray.ObjectID(work_object_id_1),
|
||||
ray.ObjectID(work_object_id_2)])
|
||||
assert [g.request_body for g in got_work] == [1, 1]
|
||||
assert [g.request_args for g in got_work] == [1, 1]
|
||||
|
||||
|
||||
def test_probabilities(serve_instance):
|
||||
q = CentralizedQueues()
|
||||
|
||||
[q.enqueue_request("svc", 1) for i in range(100)]
|
||||
[q.enqueue_request("svc", 1, "kwargs", None) for i in range(100)]
|
||||
|
||||
work_object_id_1_s = [
|
||||
ray.ObjectID(q.dequeue_request("backend-1")) for i in range(100)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import ray
|
||||
from ray.experimental.serve.queues import CentralizedQueuesActor
|
||||
from ray.experimental.serve.task_runner import (
|
||||
@@ -6,6 +7,7 @@ from ray.experimental.serve.task_runner import (
|
||||
TaskRunnerActor,
|
||||
wrap_to_ray_error,
|
||||
)
|
||||
import ray.experimental.serve.context as context
|
||||
|
||||
|
||||
def test_runner_basic():
|
||||
@@ -17,21 +19,14 @@ def test_runner_basic():
|
||||
|
||||
|
||||
def test_runner_wraps_error():
|
||||
def echo(i):
|
||||
return i
|
||||
|
||||
assert wrap_to_ray_error(echo, 2) == 2
|
||||
|
||||
def error(_):
|
||||
return 1 / 0
|
||||
|
||||
assert isinstance(wrap_to_ray_error(error, 1), ray.exceptions.RayTaskError)
|
||||
wrapped = wrap_to_ray_error(Exception())
|
||||
assert isinstance(wrapped, ray.exceptions.RayTaskError)
|
||||
|
||||
|
||||
def test_runner_actor(serve_instance):
|
||||
q = CentralizedQueuesActor.remote()
|
||||
|
||||
def echo(i):
|
||||
def echo(flask_request, i=None):
|
||||
return i
|
||||
|
||||
CONSUMER_NAME = "runner"
|
||||
@@ -46,7 +41,12 @@ def test_runner_actor(serve_instance):
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
result_token = ray.ObjectID(
|
||||
ray.get(q.enqueue_request.remote(PRODUCER_NAME, query)))
|
||||
ray.get(
|
||||
q.enqueue_request.remote(
|
||||
PRODUCER_NAME,
|
||||
request_args=None,
|
||||
request_kwargs={"i": query},
|
||||
request_context=context.TaskContext.Python)))
|
||||
assert ray.get(result_token) == query
|
||||
|
||||
|
||||
@@ -60,8 +60,8 @@ def test_ray_serve_mixin(serve_instance):
|
||||
def __init__(self, inc):
|
||||
self.increment = inc
|
||||
|
||||
def __call__(self, context):
|
||||
return context + self.increment
|
||||
def __call__(self, flask_request, i=None):
|
||||
return i + self.increment
|
||||
|
||||
@ray.remote
|
||||
class CustomActor(MyAdder, RayServeMixin):
|
||||
@@ -76,5 +76,38 @@ def test_ray_serve_mixin(serve_instance):
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
result_token = ray.ObjectID(
|
||||
ray.get(q.enqueue_request.remote(PRODUCER_NAME, query)))
|
||||
ray.get(
|
||||
q.enqueue_request.remote(
|
||||
PRODUCER_NAME,
|
||||
request_args=None,
|
||||
request_kwargs={"i": query},
|
||||
request_context=context.TaskContext.Python)))
|
||||
assert ray.get(result_token) == query + 3
|
||||
|
||||
|
||||
def test_task_runner_check_context(serve_instance):
    """Check that using ``flask_request`` in a Python-context call fails.

    A request is enqueued with ``TaskContext.Python`` for a backend that
    reads ``flask_request.args``; fetching the result is expected to raise
    ``RayTaskError``.
    """
    q = CentralizedQueuesActor.remote()

    def echo(flask_request, i=None):
        # Accessing the flask_request without web context should throw.
        return flask_request.args["i"]

    CONSUMER_NAME = "runner"
    PRODUCER_NAME = "producer"

    runner = TaskRunnerActor.remote(echo)

    runner._ray_serve_setup.remote(CONSUMER_NAME, q)
    runner._ray_serve_main_loop.remote(runner)

    q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
    result_token = ray.ObjectID(
        ray.get(
            q.enqueue_request.remote(
                PRODUCER_NAME,
                request_args=None,
                request_kwargs={"i": 42},
                request_context=context.TaskContext.Python)))

    with pytest.raises(ray.exceptions.RayTaskError):
        ray.get(result_token)
|
||||
|
||||
+1
-1
@@ -77,7 +77,7 @@ extras = {
|
||||
],
|
||||
"debug": ["psutil", "setproctitle", "py-spy >= 0.2.0"],
|
||||
"dashboard": ["aiohttp", "psutil", "setproctitle"],
|
||||
"serve": ["uvicorn", "pygments", "werkzeug"],
|
||||
"serve": ["uvicorn", "pygments", "werkzeug", "flask"],
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user