diff --git a/.travis.yml b/.travis.yml index d4f6d339f..d29a78b3d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -174,7 +174,7 @@ script: # ray serve tests - if [ $RAY_CI_SERVE_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v --durations=5 --timeout=300 python/ray/experimental/serve/tests; fi - - if [ $RAY_CI_SERVE_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || ./ci/suppress_output python python/ray/experimental/serve/example/echo_full.py; fi + - if [ $RAY_CI_SERVE_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || ./ci/suppress_output python python/ray/experimental/serve/examples/echo_full.py; fi # ray tests # Python3.5+ only. Otherwise we will get `SyntaxError` regardless of how we set the tester. diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index da866ef6d..ed29836d1 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -35,7 +35,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then export PATH="$HOME/miniconda/bin:$PATH" pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.24.2 requests \ feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp \ - uvicorn dataclasses pygments werkzeug kubernetes + uvicorn dataclasses pygments werkzeug kubernetes flask elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # Install miniconda. wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv @@ -50,7 +50,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then export PATH="$HOME/miniconda/bin:$PATH" pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.24.2 requests \ feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp \ - uvicorn dataclasses pygments werkzeug kubernetes + uvicorn dataclasses pygments werkzeug kubernetes flask elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install -y build-essential curl unzip diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 9f44fcfcd..8a7c91eef 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -18,3 +18,7 @@ sphinx-gallery sphinx-jsonschema sphinx_rtd_theme pandas +flask +uvicorn +pygments +werkzeug diff --git a/python/ray/experimental/serve/context.py b/python/ray/experimental/serve/context.py new file mode 100644 index 000000000..43e1fd710 --- /dev/null +++ b/python/ray/experimental/serve/context.py @@ -0,0 +1,28 @@ +from enum import IntEnum +from ray.experimental.serve.exceptions import RayServeException + + +class TaskContext(IntEnum): + """TaskContext constants for queue.enqueue method""" + Web = 1 + Python = 2 + + +# Global variable will be modified in worker +# web == True: currrently processing a request from web server +# web == False: currently processing a request from python +web = False + +_not_in_web_context_error = """ +Accessing the request object outside of the web context. Please use +"serve.context.web" to determine when the function is called within +a web context. +""" + + +class FakeFlaskQuest: + def __getattribute__(self, name): + raise RayServeException(_not_in_web_context_error) + + def __setattr__(self, name, value): + raise RayServeException(_not_in_web_context_error) diff --git a/python/ray/experimental/serve/examples/echo.py b/python/ray/experimental/serve/examples/echo.py index f6cf5afbe..e124367ea 100644 --- a/python/ray/experimental/serve/examples/echo.py +++ b/python/ray/experimental/serve/examples/echo.py @@ -10,8 +10,8 @@ from ray.experimental import serve from ray.experimental.serve.utils import pformat_color_json -def echo(context): - return context +def echo(flask_request): + return "hello " + flask_request.args.get("name", "serve!") serve.init(blocking=True) diff --git a/python/ray/experimental/serve/examples/echo_actor.py b/python/ray/experimental/serve/examples/echo_actor.py index 98422679d..1e279ceb4 100644 --- a/python/ray/experimental/serve/examples/echo_actor.py +++ b/python/ray/experimental/serve/examples/echo_actor.py @@ -1,41 +1,47 @@ """ -Example actor that adds message to the end of query_string. +Example actor that adds an increment to a number. This number can +come from either web (parsing Flask request) or python call. + +This actor can be called from HTTP as well as from Python. """ import time import requests -from werkzeug import urls +import ray from ray.experimental import serve from ray.experimental.serve.utils import pformat_color_json -class EchoActor: - def __init__(self, message): - self.message = message +class MagicCounter: + def __init__(self, increment): + self.increment = increment - def __call__(self, context): - query_string_dict = urls.url_decode(context["query_string"]) - message = "" - message += query_string_dict.get("message", "") - message += " " - message += self.message - return message + def __call__(self, flask_request, base_number=None): + if serve.context.web: + base_number = int(flask_request.args.get("base_number", "0")) + + return base_number + self.increment serve.init(blocking=True) +serve.create_endpoint("magic_counter", "/counter", blocking=True) +serve.create_backend(MagicCounter, "counter:v1", 42) # increment=42 +serve.link("magic_counter", "counter:v1") -serve.create_endpoint("my_endpoint", "/echo", blocking=True) -serve.create_backend(EchoActor, "echo:v1", "world") -serve.link("my_endpoint", "echo:v1") - -while True: - resp = requests.get("http://127.0.0.1:8000/echo?message=hello").json() +print("Sending ten queries via HTTP") +for i in range(10): + url = "http://127.0.0.1:8000/counter?base_number={}".format(i) + print("> Pinging {}".format(url)) + resp = requests.get(url).json() print(pformat_color_json(resp)) - resp = requests.get("http://127.0.0.1:8000/echo").json() - print(pformat_color_json(resp)) + time.sleep(0.2) - print("...Sleeping for 2 seconds...") - time.sleep(2) +print("Sending ten queries via Python") +handle = serve.get_handle("magic_counter") +for i in range(10): + print("> Pinging handle.remote(base_number={})".format(i)) + result = ray.get(handle.remote(base_number=i)) + print("< Result {}".format(result)) diff --git a/python/ray/experimental/serve/examples/echo_error.py b/python/ray/experimental/serve/examples/echo_error.py index 25d900068..a6ed56b99 100644 --- a/python/ray/experimental/serve/examples/echo_error.py +++ b/python/ray/experimental/serve/examples/echo_error.py @@ -7,7 +7,7 @@ We are going to define a buggy function that raise some exception: The expected behavior is: - HTTP server should respond with "internal error" in the response JSON -- ray.get(handle.remote(33)) should raise RayTaskError with traceback. +- ray.get(handle.remote()) should raise RayTaskError with traceback. This shows that error is hidden from HTTP side but always visible when calling from Python. @@ -40,5 +40,5 @@ for _ in range(2): time.sleep(2) handle = serve.get_handle("my_endpoint") - -ray.get(handle.remote(33)) +print("Invoke from python will raise exception with traceback:") +ray.get(handle.remote()) diff --git a/python/ray/experimental/serve/examples/echo_full.py b/python/ray/experimental/serve/examples/echo_full.py index 86801d4f1..2edfff4c2 100644 --- a/python/ray/experimental/serve/examples/echo_full.py +++ b/python/ray/experimental/serve/examples/echo_full.py @@ -16,8 +16,11 @@ serve.create_endpoint("my_endpoint", "/echo") # a backend can be a function or class. -def echo_v1(request): - return request +# it can be made to be invoked from web as well as python. +def echo_v1(flask_request, response="hello from python!"): + if serve.context.web: + response = flask_request.url + return response serve.create_backend(echo_v1, "echo:v1") @@ -29,14 +32,14 @@ serve.link("my_endpoint", "echo:v1") print(requests.get("http://127.0.0.1:8000/echo").json()) # The service will be reachable from http -print(ray.get(serve.get_handle("my_endpoint").remote("hello"))) +print(ray.get(serve.get_handle("my_endpoint").remote(response="hello"))) # as well as within the ray system. # We can also add a new backend and split the traffic. -def echo_v2(request): - # magic +def echo_v2(flask_request): + # magic, only from web. return "something new" diff --git a/python/ray/experimental/serve/examples/echo_rollback.py b/python/ray/experimental/serve/examples/echo_rollback.py deleted file mode 100644 index bcdf7e14e..000000000 --- a/python/ray/experimental/serve/examples/echo_rollback.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Example rollback action in ray serve. We first deploy only v1, then set a - 50/50 deployment between v1 and v2, and finally roll back to only v1. -""" -import time - -import requests - -from ray.experimental import serve -from ray.experimental.serve.utils import pformat_color_json - - -def echo_v1(_): - return "v1" - - -def echo_v2(_): - return "v2" - - -serve.init(blocking=True) - -serve.create_endpoint("my_endpoint", "/echo", blocking=True) -serve.create_backend(echo_v1, "echo:v1") -serve.link("my_endpoint", "echo:v1") - -for _ in range(3): - resp = requests.get("http://127.0.0.1:8000/echo").json() - print(pformat_color_json(resp)) - - print("...Sleeping for 2 seconds...") - time.sleep(2) - -serve.create_backend(echo_v2, "echo:v2") -serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) - -for _ in range(6): - resp = requests.get("http://127.0.0.1:8000/echo").json() - print(pformat_color_json(resp)) - - print("...Sleeping for 2 seconds...") - time.sleep(2) - -serve.rollback("my_endpoint") -for _ in range(6): - resp = requests.get("http://127.0.0.1:8000/echo").json() - print(pformat_color_json(resp)) - - print("...Sleeping for 2 seconds...") - time.sleep(2) diff --git a/python/ray/experimental/serve/exceptions.py b/python/ray/experimental/serve/exceptions.py new file mode 100644 index 000000000..7e1f957cf --- /dev/null +++ b/python/ray/experimental/serve/exceptions.py @@ -0,0 +1,2 @@ +class RayServeException(Exception): + pass diff --git a/python/ray/experimental/serve/global_state.py b/python/ray/experimental/serve/global_state.py index aea51d526..e817ca185 100644 --- a/python/ray/experimental/serve/global_state.py +++ b/python/ray/experimental/serve/global_state.py @@ -1,6 +1,8 @@ import time from collections import defaultdict, deque +import requests + import ray from ray.experimental.serve.kv_store_service import KVStoreProxyActor from ray.experimental.serve.queues import CentralizedQueuesActor @@ -68,12 +70,17 @@ class GlobalState: self.router_actor_handle) def wait_until_http_ready(self, num_retries=5, backoff_time_s=1): - routing_table_request_count = 0 + http_is_ready = False retries = num_retries - while not routing_table_request_count: - routing_table_request_count = (ray.get( - self.kv_store_actor_handle.get_request_count.remote())) + while not http_is_ready: + try: + resp = requests.get(self.http_address) + assert resp.status_code == 200 + http_is_ready = True + except Exception: + pass + logger.debug((LOG_PREFIX + "Checking if HTTP server is ready." "{} retries left.").format(retries)) time.sleep(backoff_time_s) diff --git a/python/ray/experimental/serve/handle.py b/python/ray/experimental/serve/handle.py index 56886f25c..4ce6a30e8 100644 --- a/python/ray/experimental/serve/handle.py +++ b/python/ray/experimental/serve/handle.py @@ -1,5 +1,7 @@ import ray from ray.experimental import serve +from ray.experimental.serve.context import TaskContext +from ray.experimental.serve.exceptions import RayServeException class RayServeHandle: @@ -28,11 +30,17 @@ class RayServeHandle: self.router_handle = router_handle self.endpoint_name = endpoint_name - def remote(self, *args): - # TODO(simon): Support kwargs once #5606 is merged. + def remote(self, *args, **kwargs): + if len(args) != 0: + raise RayServeException( + "handle.remote must be invoked with keyword arguments.") + result_object_id_bytes = ray.get( - self.router_handle.enqueue_request.remote(self.endpoint_name, - *args)) + self.router_handle.enqueue_request.remote( + service=self.endpoint_name, + request_args=(), + request_kwargs=kwargs, + request_context=TaskContext.Python)) return ray.ObjectID(result_object_id_bytes) def get_traffic_policy(self): diff --git a/python/ray/experimental/serve/http_util.py b/python/ray/experimental/serve/http_util.py new file mode 100644 index 000000000..667aabd3c --- /dev/null +++ b/python/ray/experimental/serve/http_util.py @@ -0,0 +1,68 @@ +import flask +import io + + +def build_flask_request(asgi_scope_dict, request_body): + """Build and return a flask request from ASGI payload + + This function is indented to be used immediately before task invocation + happen. + """ + wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body) + return flask.Request(wsgi_environ) + + +def build_wsgi_environ(scope, body): + """ + Builds a scope and request body into a WSGI environ object. + + This code snippet is taken from https://github.com/django/asgiref/blob + /36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52 + + WSGI specification can be found at + https://www.python.org/dev/peps/pep-0333/ + + This function helps translate ASGI scope and body into a flask request. + """ + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": scope.get("root_path", ""), + "PATH_INFO": scope["path"], + "QUERY_STRING": scope["query_string"].decode("ascii"), + "SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]), + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": body, + "wsgi.errors": io.BytesIO(), + "wsgi.multithread": True, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + + # Get server name and port - required in WSGI, not in ASGI + environ["SERVER_NAME"] = scope["server"][0] + environ["SERVER_PORT"] = str(scope["server"][1]) + environ["REMOTE_ADDR"] = scope["client"][0] + + # Transforms headers into environ entries. + for name, value in scope.get("headers", []): + # name, values are both bytes, we need to decode them to string + name = name.decode("latin1") + value = value.decode("latin1") + + # Handle name correction to conform to WSGI spec + # https://www.python.org/dev/peps/pep-0333/#environ-variables + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" + else: + corrected_name = "HTTP_%s" % name.upper().replace("-", "_") + + # If the header value repeated, + # we will just concatenate it to the field. + if corrected_name in environ: + value = environ[corrected_name] + "," + value + + environ[corrected_name] = value + return environ diff --git a/python/ray/experimental/serve/queues.py b/python/ray/experimental/serve/queues.py index f5718a39f..9b48ae2b2 100644 --- a/python/ray/experimental/serve/queues.py +++ b/python/ray/experimental/serve/queues.py @@ -7,8 +7,15 @@ from ray.experimental.serve.utils import get_custom_object_id, logger class Query: - def __init__(self, request_body, result_object_id=None): - self.request_body = request_body + def __init__(self, + request_args, + request_kwargs, + request_context, + result_object_id=None): + self.request_args = request_args + self.request_kwargs = request_kwargs + self.request_context = request_context + if result_object_id is None: self.result_object_id = get_custom_object_id() else: @@ -34,7 +41,8 @@ class CentralizedQueues: Behavior: >>> # psuedo-code >>> queue = CentralizedQueues() - >>> queue.enqueue_request('service-name', data) + >>> queue.enqueue_request( + "service-name", request_args, request_kwargs, request_context) # nothing happens, request is queued. # returns result ObjectID, which will contains the final result >>> queue.dequeue_request('backend-1') @@ -68,8 +76,9 @@ class CentralizedQueues: # backend_name -> worker queue self.workers = defaultdict(deque) - def enqueue_request(self, service, request_data): - query = Query(request_data) + def enqueue_request(self, service, request_args, request_kwargs, + request_context): + query = Query(request_args, request_kwargs, request_context) self.queues[service].append(query) self.flush() return query.result_object_id.binary() diff --git a/python/ray/experimental/serve/server.py b/python/ray/experimental/serve/server.py index af70a782b..f1ec15587 100644 --- a/python/ray/experimental/serve/server.py +++ b/python/ray/experimental/serve/server.py @@ -7,6 +7,7 @@ import ray from ray.experimental.async_api import _async_init, as_future from ray.experimental.serve.utils import BytesEncoder from ray.experimental.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S +from ray.experimental.serve.context import TaskContext class JSONResponse: @@ -70,8 +71,13 @@ class HTTPProxy: self.router = router_handle self.route_table = dict() + self.route_checker_should_shutdown = False + async def route_checker(self, interval): while True: + if self.route_checker_should_shutdown: + return + try: self.route_table = await as_future( self.admin_actor.list_service.remote()) @@ -80,40 +86,73 @@ class HTTPProxy: await asyncio.sleep(interval) + async def handle_lifespan_message(self, scope, receive, send): + assert scope["type"] == "lifespan" + + message = await receive() + if message["type"] == "lifespan.startup": + await _async_init() + asyncio.ensure_future( + self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S)) + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + self.route_checker_should_shutdown = True + await send({"type": "lifespan.shutdown.complete"}) + + async def receive_http_body(self, scope, receive, send): + body_buffer = [] + more_body = True + while more_body: + message = await receive() + assert message["type"] == "http.request" + + more_body = message["more_body"] + body_buffer.append(message["body"]) + + return b"".join(body_buffer) + async def __call__(self, scope, receive, send): # NOTE: This implements ASGI protocol specified in # https://asgi.readthedocs.io/en/latest/specs/index.html if scope["type"] == "lifespan": - await _async_init() - asyncio.ensure_future( - self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S)) + await self.handle_lifespan_message(scope, receive, send) return + assert scope["type"] == "http" current_path = scope["path"] if current_path == "/": await JSONResponse(self.route_table)(scope, receive, send) - elif current_path in self.route_table: - endpoint_name = self.route_table[current_path] - result_object_id_bytes = await as_future( - self.router.enqueue_request.remote(endpoint_name, scope)) - result = await as_future(ray.ObjectID(result_object_id_bytes)) + return - if isinstance(result, ray.exceptions.RayTaskError): - await JSONResponse({ - "error": "internal error, please use python API to debug" - })(scope, receive, send) - else: - await JSONResponse({"result": result})(scope, receive, send) - else: + if current_path not in self.route_table: error_message = ("Path {} not found. " "Please ping http://.../ for routing table" ).format(current_path) - await JSONResponse( { "error": error_message }, status_code=404)(scope, receive, send) + return + + endpoint_name = self.route_table[current_path] + http_body_bytes = await self.receive_http_body(scope, receive, send) + + result_object_id_bytes = await as_future( + self.router.enqueue_request.remote( + service=endpoint_name, + request_args=(scope, http_body_bytes), + request_kwargs=dict(), + request_context=TaskContext.Web)) + + result = await as_future(ray.ObjectID(result_object_id_bytes)) + + if isinstance(result, ray.exceptions.RayTaskError): + await JSONResponse({ + "error": "internal error, please use python API to debug" + })(scope, receive, send) + else: + await JSONResponse({"result": result})(scope, receive, send) @ray.remote @@ -122,4 +161,5 @@ class HTTPActor: self.app = HTTPProxy(kv_store_actor_handle, router_handle) def run(self, host="0.0.0.0", port=8000): - uvicorn.run(self.app, host=host, port=port, lifespan="on") + uvicorn.run( + self.app, host=host, port=port, lifespan="on", access_log=False) diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py index ff42cf67c..4aecaa1b4 100644 --- a/python/ray/experimental/serve/task_runner.py +++ b/python/ray/experimental/serve/task_runner.py @@ -1,6 +1,8 @@ import traceback - import ray +from ray.experimental.serve import context as serve_context +from ray.experimental.serve.context import TaskContext, FakeFlaskQuest +from ray.experimental.serve.http_util import build_flask_request class TaskRunner: @@ -13,18 +15,18 @@ class TaskRunner: def __init__(self, func_to_run): self.func = func_to_run - def __call__(self, *args): - return self.func(*args) + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) -def wrap_to_ray_error(callable_obj, *args): +def wrap_to_ray_error(exception): """Utility method that catch and seal exceptions in execution""" try: - return callable_obj(*args) + # Raise and catch so we can access traceback.format_exc() + raise exception except Exception as e: traceback_str = ray.utils.format_error_message(traceback.format_exc()) - return ray.exceptions.RayTaskError( - str(callable_obj), traceback_str, e.__class__) + return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__) class RayServeMixin: @@ -62,12 +64,28 @@ class RayServeMixin: self._ray_serve_dequeue_requestr_name)) work_item = ray.get(ray.ObjectID(work_token)) - # TODO(simon): - # __call__ should be able to take multiple *args and **kwargs. - result = wrap_to_ray_error(self.__call__, work_item.request_body) - result_object_id = work_item.result_object_id - ray.worker.global_worker.put_object(result_object_id, result) + if work_item.request_context == TaskContext.Web: + serve_context.web = True + asgi_scope, body_bytes = work_item.request_args + flask_request = build_flask_request(asgi_scope, body_bytes) + args = (flask_request, ) + kwargs = {} + else: + serve_context.web = False + args = (FakeFlaskQuest(), ) + kwargs = work_item.request_kwargs + result_object_id = work_item.result_object_id + + try: + result = self.__call__(*args, **kwargs) + ray.worker.global_worker.put_object(result_object_id, result) + except Exception as e: + wrapped_exception = wrap_to_ray_error(e) + ray.worker.global_worker.put_object(result_object_id, + wrapped_exception) + + serve_context.web = False # The worker finished one unit of work. # It will now tail recursively schedule the main_loop again. diff --git a/python/ray/experimental/serve/tests/conftest.py b/python/ray/experimental/serve/tests/conftest.py index 9f784a18e..688d7e976 100644 --- a/python/ray/experimental/serve/tests/conftest.py +++ b/python/ray/experimental/serve/tests/conftest.py @@ -6,8 +6,7 @@ from ray.experimental import serve @pytest.fixture(scope="session") def serve_instance(): - serve.init() - serve.global_state.wait_until_http_ready() + serve.init(blocking=True) yield diff --git a/python/ray/experimental/serve/tests/test_api.py b/python/ray/experimental/serve/tests/test_api.py index aa2002dea..a45c37ba2 100644 --- a/python/ray/experimental/serve/tests/test_api.py +++ b/python/ray/experimental/serve/tests/test_api.py @@ -1,33 +1,34 @@ import time import requests -from flaky import flaky import ray from ray.experimental import serve -def delay_rerun(*_): - time.sleep(1) - return True - - -# flaky test because the routing table might not be populated -@flaky(rerun_filter=delay_rerun) def test_e2e(serve_instance): serve.create_endpoint("endpoint", "/api") result = ray.get( serve.global_state.kv_store_actor_handle.list_service.remote()) assert result == {"/api": "endpoint"} - assert requests.get("http://127.0.0.1:8000/").json() == result + retry_count = 3 + while True: + try: + resp = requests.get("http://127.0.0.1:8000/").json() + assert resp == result + break + except Exception: + time.sleep(0.5) + retry_count -= 1 + if retry_count == 0: + assert False, "Route table hasn't been updated after 3 tries." - def echo(i): - return i + def function(flask_request): + return "OK" - serve.create_backend(echo, "echo:v1") + serve.create_backend(function, "echo:v1") serve.link("endpoint", "echo:v1") resp = requests.get("http://127.0.0.1:8000/api").json()["result"] - assert resp["path"] == "/api" - assert resp["method"] == "GET" + assert resp == "OK" diff --git a/python/ray/experimental/serve/tests/test_queue.py b/python/ray/experimental/serve/tests/test_queue.py index 6bb231169..7f89cfbce 100644 --- a/python/ray/experimental/serve/tests/test_queue.py +++ b/python/ray/experimental/serve/tests/test_queue.py @@ -6,10 +6,11 @@ def test_single_prod_cons_queue(serve_instance): q = CentralizedQueues() q.link("svc", "backend") - result_object_id = q.enqueue_request("svc", 1) + result_object_id = q.enqueue_request("svc", 1, "kwargs", None) work_object_id = q.dequeue_request("backend") got_work = ray.get(ray.ObjectID(work_object_id)) - assert got_work.request_body == 1 + assert got_work.request_args == 1 + assert got_work.request_kwargs == "kwargs" ray.worker.global_worker.put_object(got_work.result_object_id, 2) assert ray.get(ray.ObjectID(result_object_id)) == 2 @@ -18,19 +19,19 @@ def test_single_prod_cons_queue(serve_instance): def test_alter_backend(serve_instance): q = CentralizedQueues() - result_object_id = q.enqueue_request("svc", 1) + result_object_id = q.enqueue_request("svc", 1, "kwargs", None) work_object_id = q.dequeue_request("backend-1") q.set_traffic("svc", {"backend-1": 1}) got_work = ray.get(ray.ObjectID(work_object_id)) - assert got_work.request_body == 1 + assert got_work.request_args == 1 ray.worker.global_worker.put_object(got_work.result_object_id, 2) assert ray.get(ray.ObjectID(result_object_id)) == 2 - result_object_id = q.enqueue_request("svc", 1) + result_object_id = q.enqueue_request("svc", 1, "kwargs", None) work_object_id = q.dequeue_request("backend-2") q.set_traffic("svc", {"backend-2": 1}) got_work = ray.get(ray.ObjectID(work_object_id)) - assert got_work.request_body == 1 + assert got_work.request_args == 1 ray.worker.global_worker.put_object(got_work.result_object_id, 2) assert ray.get(ray.ObjectID(result_object_id)) == 2 @@ -38,8 +39,8 @@ def test_alter_backend(serve_instance): def test_split_traffic(serve_instance): q = CentralizedQueues() - q.enqueue_request("svc", 1) - q.enqueue_request("svc", 1) + q.enqueue_request("svc", 1, "kwargs", None) + q.enqueue_request("svc", 1, "kwargs", None) q.set_traffic("svc", {}) work_object_id_1 = q.dequeue_request("backend-1") work_object_id_2 = q.dequeue_request("backend-2") @@ -48,13 +49,13 @@ def test_split_traffic(serve_instance): got_work = ray.get( [ray.ObjectID(work_object_id_1), ray.ObjectID(work_object_id_2)]) - assert [g.request_body for g in got_work] == [1, 1] + assert [g.request_args for g in got_work] == [1, 1] def test_probabilities(serve_instance): q = CentralizedQueues() - [q.enqueue_request("svc", 1) for i in range(100)] + [q.enqueue_request("svc", 1, "kwargs", None) for i in range(100)] work_object_id_1_s = [ ray.ObjectID(q.dequeue_request("backend-1")) for i in range(100) diff --git a/python/ray/experimental/serve/tests/test_task_runner.py b/python/ray/experimental/serve/tests/test_task_runner.py index e8d9fafdd..e26fa4434 100644 --- a/python/ray/experimental/serve/tests/test_task_runner.py +++ b/python/ray/experimental/serve/tests/test_task_runner.py @@ -1,3 +1,4 @@ +import pytest import ray from ray.experimental.serve.queues import CentralizedQueuesActor from ray.experimental.serve.task_runner import ( @@ -6,6 +7,7 @@ from ray.experimental.serve.task_runner import ( TaskRunnerActor, wrap_to_ray_error, ) +import ray.experimental.serve.context as context def test_runner_basic(): @@ -17,21 +19,14 @@ def test_runner_basic(): def test_runner_wraps_error(): - def echo(i): - return i - - assert wrap_to_ray_error(echo, 2) == 2 - - def error(_): - return 1 / 0 - - assert isinstance(wrap_to_ray_error(error, 1), ray.exceptions.RayTaskError) + wrapped = wrap_to_ray_error(Exception()) + assert isinstance(wrapped, ray.exceptions.RayTaskError) def test_runner_actor(serve_instance): q = CentralizedQueuesActor.remote() - def echo(i): + def echo(flask_request, i=None): return i CONSUMER_NAME = "runner" @@ -46,7 +41,12 @@ def test_runner_actor(serve_instance): for query in [333, 444, 555]: result_token = ray.ObjectID( - ray.get(q.enqueue_request.remote(PRODUCER_NAME, query))) + ray.get( + q.enqueue_request.remote( + PRODUCER_NAME, + request_args=None, + request_kwargs={"i": query}, + request_context=context.TaskContext.Python))) assert ray.get(result_token) == query @@ -60,8 +60,8 @@ def test_ray_serve_mixin(serve_instance): def __init__(self, inc): self.increment = inc - def __call__(self, context): - return context + self.increment + def __call__(self, flask_request, i=None): + return i + self.increment @ray.remote class CustomActor(MyAdder, RayServeMixin): @@ -76,5 +76,38 @@ def test_ray_serve_mixin(serve_instance): for query in [333, 444, 555]: result_token = ray.ObjectID( - ray.get(q.enqueue_request.remote(PRODUCER_NAME, query))) + ray.get( + q.enqueue_request.remote( + PRODUCER_NAME, + request_args=None, + request_kwargs={"i": query}, + request_context=context.TaskContext.Python))) assert ray.get(result_token) == query + 3 + + +def test_task_runner_check_context(serve_instance): + q = CentralizedQueuesActor.remote() + + def echo(flask_request, i=None): + # Accessing the flask_request without web context should throw. + return flask_request.args["i"] + + CONSUMER_NAME = "runner" + PRODUCER_NAME = "producer" + + runner = TaskRunnerActor.remote(echo) + + runner._ray_serve_setup.remote(CONSUMER_NAME, q) + runner._ray_serve_main_loop.remote(runner) + + q.link.remote(PRODUCER_NAME, CONSUMER_NAME) + result_token = ray.ObjectID( + ray.get( + q.enqueue_request.remote( + PRODUCER_NAME, + request_args=None, + request_kwargs={"i": 42}, + request_context=context.TaskContext.Python))) + + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(result_token) diff --git a/python/setup.py b/python/setup.py index a951aa666..3e92b4999 100644 --- a/python/setup.py +++ b/python/setup.py @@ -77,7 +77,7 @@ extras = { ], "debug": ["psutil", "setproctitle", "py-spy >= 0.2.0"], "dashboard": ["aiohttp", "psutil", "setproctitle"], - "serve": ["uvicorn", "pygments", "werkzeug"], + "serve": ["uvicorn", "pygments", "werkzeug", "flask"], }