diff --git a/.travis.yml b/.travis.yml index 5a6d4c0dd..650fdace8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -175,6 +175,7 @@ script: # ray tests # Python3.5+ only. Otherwise we will get `SyntaxError` regardless of how we set the tester. - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v --durations=5 --timeout=300 python/ray/experimental/test/async_test.py; fi + - if [ $RAY_CI_SERVE_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v --durations=5 --timeout=300 python/ray/experimental/serve/tests; fi - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=10 --timeout=300 python/ray/tests --ignore=python/ray/tests/perf_integration_tests; fi deploy: diff --git a/ci/travis/determine_tests_to_run.py b/ci/travis/determine_tests_to_run.py index 7518a0148..4346302d4 100644 --- a/ci/travis/determine_tests_to_run.py +++ b/ci/travis/determine_tests_to_run.py @@ -5,6 +5,9 @@ from __future__ import print_function import os import subprocess +import sys +from functools import partial +from pprint import pformat def list_changed_files(commit_range): @@ -30,6 +33,7 @@ if __name__ == "__main__": RAY_CI_TUNE_AFFECTED = 0 RAY_CI_RLLIB_AFFECTED = 0 + RAY_CI_SERVE_AFFECTED = 0 RAY_CI_JAVA_AFFECTED = 0 RAY_CI_PYTHON_AFFECTED = 0 RAY_CI_LINUX_WHEELS_AFFECTED = 0 @@ -40,6 +44,8 @@ if __name__ == "__main__": files = list_changed_files(os.environ["TRAVIS_COMMIT_RANGE"].replace( "...", "..")) + print(pformat(files), file=sys.stderr) + skip_prefix_list = [ "doc/", "examples/", "dev/", "docker/", "kubernetes/", "site/" ] @@ -54,9 +60,14 @@ if __name__ == "__main__": RAY_CI_RLLIB_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 + elif changed_file.startswith("python/ray/experimental/serve"): + RAY_CI_SERVE_AFFECTED = 1 + RAY_CI_LINUX_WHEELS_AFFECTED = 1 + RAY_CI_MACOS_WHEELS_AFFECTED = 1 elif changed_file.startswith("python/"): RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 + RAY_CI_SERVE_AFFECTED = 1 RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 @@ -70,6 +81,7 @@ if __name__ == "__main__": elif changed_file.startswith("src/"): RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 + RAY_CI_SERVE_AFFECTED = 1 RAY_CI_JAVA_AFFECTED = 1 RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 @@ -77,6 +89,7 @@ if __name__ == "__main__": else: RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 + RAY_CI_SERVE_AFFECTED = 1 RAY_CI_JAVA_AFFECTED = 1 RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 @@ -84,16 +97,22 @@ if __name__ == "__main__": else: RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 + RAY_CI_SERVE_AFFECTED = 1 RAY_CI_JAVA_AFFECTED = 1 RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 - print("export RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED)) - print("export RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED)) - print("export RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED)) - print("export RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED)) - print("export RAY_CI_LINUX_WHEELS_AFFECTED={}" - .format(RAY_CI_LINUX_WHEELS_AFFECTED)) - print("export RAY_CI_MACOS_WHEELS_AFFECTED={}" - .format(RAY_CI_MACOS_WHEELS_AFFECTED)) + # Log the modified environment variables visible in console. + for output_stream in [sys.stdout, sys.stderr]: + _print = partial(print, file=output_stream) + _print("export RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED)) + _print("export RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED)) + _print("export RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED)) + _print("export RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED)) + _print( + "export RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED)) + _print("export RAY_CI_LINUX_WHEELS_AFFECTED={}" + .format(RAY_CI_LINUX_WHEELS_AFFECTED)) + _print("export RAY_CI_MACOS_WHEELS_AFFECTED={}" + .format(RAY_CI_MACOS_WHEELS_AFFECTED)) diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 8864c2fae..b1fa9c2ce 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -34,7 +34,8 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.24.2 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp + feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp \ + uvicorn dataclasses pygments werkzeug elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # Install miniconda. wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv @@ -48,7 +49,8 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.24.2 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp + feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp \ + uvicorn dataclasses pygments werkzeug elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install -y build-essential curl unzip diff --git a/python/ray/experimental/serve/__init__.py b/python/ray/experimental/serve/__init__.py new file mode 100644 index 000000000..cd3a84c42 --- /dev/null +++ b/python/ray/experimental/serve/__init__.py @@ -0,0 +1,12 @@ +import sys +if sys.version_info < (3, 0): + raise ImportError("serve is Python 3 only.") + +from ray.experimental.serve.api import (init, create_backend, create_endpoint, + link, split, rollback, get_handle, + global_state) # noqa: E402 + +__all__ = [ + "init", "create_backend", "create_endpoint", "link", "split", "rollback", + "get_handle", "global_state" +] diff --git a/python/ray/experimental/serve/api.py b/python/ray/experimental/serve/api.py new file mode 100644 index 000000000..dd89961b6 --- /dev/null +++ b/python/ray/experimental/serve/api.py @@ -0,0 +1,187 @@ +import inspect + +import numpy as np + +import ray +from ray.experimental.serve.task_runner import RayServeMixin, TaskRunnerActor +from ray.experimental.serve.utils import pformat_color_json, logger +from ray.experimental.serve.global_state import GlobalState + +global_state = GlobalState() + + +def init(blocking=False, object_store_memory=int(1e8)): + """Initialize a serve cluster. + + Calling `ray.init` before `serve.init` is optional. When there is not a ray + cluster initialized, serve will call `ray.init` with `object_store_memory` + requirement. + + Args: + blocking (bool): If true, the function will wait for the HTTP server to + be healthy before returns. + object_store_memory (int): Allocated shared memory size in bytes. The + default is 100MiB. The default is kept low for latency stability + reason. + """ + if not ray.is_initialized(): + ray.init(object_store_memory=object_store_memory) + + # NOTE(simon): Currently the initialization order is fixed. + # HTTP server depends on the API server. + global_state.init_api_server() + global_state.init_router() + global_state.init_http_server() + + if blocking: + global_state.wait_until_http_ready() + + +def create_endpoint(endpoint_name, route_expression, blocking=True): + """Create a service endpoint given route_expression. + + Args: + endpoint_name (str): A name to associate to the endpoint. It will be + used as key to set traffic policy. + route_expression (str): A string begin with "/". HTTP server will use + the string to match the path. + blocking (bool): If true, the function will wait for service to be + registered before returning + """ + future = global_state.kv_store_actor_handle.register_service.remote( + route_expression, endpoint_name) + if blocking: + ray.get(future) + global_state.registered_endpoints.add(endpoint_name) + + +def create_backend(func_or_class, backend_tag, *actor_init_args): + """Create a backend using func_or_class and assign backend_tag. + + Args: + func_or_class (callable, class): a function or a class implements + __call__ protocol. + backend_tag (str): a unique tag assign to this backend. It will be used + to associate services in traffic policy. + *actor_init_args (optional): the argument to pass to the class + initialization method. + """ + if inspect.isfunction(func_or_class): + runner = TaskRunnerActor.remote(func_or_class) + elif inspect.isclass(func_or_class): + # Python inheritance order is right-to-left. We put RayServeMixin + # on the left to make sure its methods are not overriden. + @ray.remote + class CustomActor(RayServeMixin, func_or_class): + pass + + runner = CustomActor.remote(*actor_init_args) + else: + raise TypeError( + "Backend must be a function or class, it is {}.".format( + type(func_or_class))) + + global_state.backend_actor_handles.append(runner) + + runner._ray_serve_setup.remote(backend_tag, + global_state.router_actor_handle) + runner._ray_serve_main_loop.remote(runner) + + global_state.registered_backends.add(backend_tag) + + +def link(endpoint_name, backend_tag): + """Associate a service endpoint with backend tag. + + Example: + + >>> serve.link("service-name", "backend:v1") + + Note: + This is equivalent to + + >>> serve.split("service-name", {"backend:v1": 1.0}) + """ + assert endpoint_name in global_state.registered_endpoints + + global_state.router_actor_handle.link.remote(endpoint_name, backend_tag) + global_state.policy_action_history[endpoint_name].append({backend_tag: 1}) + + +def split(endpoint_name, traffic_policy_dictionary): + """Associate a service endpoint with traffic policy. + + Example: + + >>> serve.split("service-name", { + "backend:v1": 0.5, + "backend:v2": 0.5 + }) + + Args: + endpoint_name (str): A registered service endpoint. + traffic_policy_dictionary (dict): a dictionary maps backend names + to their traffic weights. The weights must sum to 1. + """ + + # Perform dictionary checks + assert endpoint_name in global_state.registered_endpoints + + assert isinstance(traffic_policy_dictionary, + dict), "Traffic policy must be dictionary" + prob = 0 + for backend, weight in traffic_policy_dictionary.items(): + prob += weight + assert (backend in global_state.registered_backends + ), "backend {} is not registered".format(backend) + assert np.isclose( + prob, 1, + atol=0.02), "weights must sum to 1, currently it sums to {}".format( + prob) + + global_state.router_actor_handle.set_traffic.remote( + endpoint_name, traffic_policy_dictionary) + global_state.policy_action_history[endpoint_name].append( + traffic_policy_dictionary) + + +def rollback(endpoint_name): + """Rollback a traffic policy decision. + + Args: + endpoint_name (str): A registered service endpoint. + """ + assert endpoint_name in global_state.registered_endpoints + action_queues = global_state.policy_action_history[endpoint_name] + cur_policy, prev_policy = action_queues[-1], action_queues[-2] + + logger.warning(""" +Current traffic policy is: +{cur_policy} + +Will rollback to: +{prev_policy} +""".format( + cur_policy=pformat_color_json(cur_policy), + prev_policy=pformat_color_json(prev_policy))) + + action_queues.pop() + global_state.router_actor_handle.set_traffic.remote( + endpoint_name, prev_policy) + + +def get_handle(endpoint_name): + """Retrieve RayServeHandle for service endpoint to invoke it from Python. + + Args: + endpoint_name (str): A registered service endpoint. + + Returns: + RayServeHandle + """ + assert endpoint_name in global_state.registered_endpoints + + # Delay import due to it's dependency on global_state + from ray.experimental.serve.handle import RayServeHandle + + return RayServeHandle(global_state.router_actor_handle, endpoint_name) diff --git a/python/ray/experimental/serve/constants.py b/python/ray/experimental/serve/constants.py new file mode 100644 index 000000000..e3aa5fc77 --- /dev/null +++ b/python/ray/experimental/serve/constants.py @@ -0,0 +1,2 @@ +#: The interval which http server refreshes its routing table +HTTP_ROUTER_CHECKER_INTERVAL_S = 2 diff --git a/python/ray/experimental/serve/examples/echo.py b/python/ray/experimental/serve/examples/echo.py new file mode 100644 index 000000000..f6cf5afbe --- /dev/null +++ b/python/ray/experimental/serve/examples/echo.py @@ -0,0 +1,28 @@ +""" +Example service that prints out http context. +""" + +import time + +import requests + +from ray.experimental import serve +from ray.experimental.serve.utils import pformat_color_json + + +def echo(context): + return context + + +serve.init(blocking=True) + +serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_backend(echo, "echo:v1") +serve.link("my_endpoint", "echo:v1") + +while True: + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) diff --git a/python/ray/experimental/serve/examples/echo_actor.py b/python/ray/experimental/serve/examples/echo_actor.py new file mode 100644 index 000000000..98422679d --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_actor.py @@ -0,0 +1,41 @@ +""" +Example actor that adds message to the end of query_string. +""" + +import time + +import requests +from werkzeug import urls + +from ray.experimental import serve +from ray.experimental.serve.utils import pformat_color_json + + +class EchoActor: + def __init__(self, message): + self.message = message + + def __call__(self, context): + query_string_dict = urls.url_decode(context["query_string"]) + message = "" + message += query_string_dict.get("message", "") + message += " " + message += self.message + return message + + +serve.init(blocking=True) + +serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_backend(EchoActor, "echo:v1", "world") +serve.link("my_endpoint", "echo:v1") + +while True: + resp = requests.get("http://127.0.0.1:8000/echo?message=hello").json() + print(pformat_color_json(resp)) + + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) diff --git a/python/ray/experimental/serve/examples/echo_error.py b/python/ray/experimental/serve/examples/echo_error.py new file mode 100644 index 000000000..25d900068 --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_error.py @@ -0,0 +1,44 @@ +""" +Example of error handling mechanism in ray serve. + +We are going to define a buggy function that raise some exception: +>>> def echo(_): + raise Exception("oh no") + +The expected behavior is: +- HTTP server should respond with "internal error" in the response JSON +- ray.get(handle.remote(33)) should raise RayTaskError with traceback. + +This shows that error is hidden from HTTP side but always visible when calling +from Python. +""" + +import time + +import requests + +import ray +from ray.experimental import serve +from ray.experimental.serve.utils import pformat_color_json + + +def echo(_): + raise Exception("Something went wrong...") + + +serve.init(blocking=True) + +serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_backend(echo, "echo:v1") +serve.link("my_endpoint", "echo:v1") + +for _ in range(2): + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) + +handle = serve.get_handle("my_endpoint") + +ray.get(handle.remote(33)) diff --git a/python/ray/experimental/serve/examples/echo_rollback.py b/python/ray/experimental/serve/examples/echo_rollback.py new file mode 100644 index 000000000..bcdf7e14e --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_rollback.py @@ -0,0 +1,50 @@ +""" +Example rollback action in ray serve. We first deploy only v1, then set a + 50/50 deployment between v1 and v2, and finally roll back to only v1. +""" +import time + +import requests + +from ray.experimental import serve +from ray.experimental.serve.utils import pformat_color_json + + +def echo_v1(_): + return "v1" + + +def echo_v2(_): + return "v2" + + +serve.init(blocking=True) + +serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_backend(echo_v1, "echo:v1") +serve.link("my_endpoint", "echo:v1") + +for _ in range(3): + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) + +serve.create_backend(echo_v2, "echo:v2") +serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) + +for _ in range(6): + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) + +serve.rollback("my_endpoint") +for _ in range(6): + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) diff --git a/python/ray/experimental/serve/examples/echo_split.py b/python/ray/experimental/serve/examples/echo_split.py new file mode 100644 index 000000000..6942db5a5 --- /dev/null +++ b/python/ray/experimental/serve/examples/echo_split.py @@ -0,0 +1,41 @@ +""" +Example of traffic splitting. We will first use echo:v1. Then v1 and v2 +will split the incoming traffic evenly. +""" +import time + +import requests + +from ray.experimental import serve +from ray.experimental.serve.utils import pformat_color_json + + +def echo_v1(_): + return "v1" + + +def echo_v2(_): + return "v2" + + +serve.init(blocking=True) + +serve.create_endpoint("my_endpoint", "/echo", blocking=True) +serve.create_backend(echo_v1, "echo:v1") +serve.link("my_endpoint", "echo:v1") + +for _ in range(3): + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) + +serve.create_backend(echo_v2, "echo:v2") +serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) +while True: + resp = requests.get("http://127.0.0.1:8000/echo").json() + print(pformat_color_json(resp)) + + print("...Sleeping for 2 seconds...") + time.sleep(2) diff --git a/python/ray/experimental/serve/global_state.py b/python/ray/experimental/serve/global_state.py new file mode 100644 index 000000000..aea51d526 --- /dev/null +++ b/python/ray/experimental/serve/global_state.py @@ -0,0 +1,84 @@ +import time +from collections import defaultdict, deque + +import ray +from ray.experimental.serve.kv_store_service import KVStoreProxyActor +from ray.experimental.serve.queues import CentralizedQueuesActor +from ray.experimental.serve.utils import logger +from ray.experimental.serve.server import HTTPActor + +# TODO(simon): Global state currently is designed to resides in the driver +# process. In the next iteration, we will move all mutable states into +# two actors: (1) namespaced key-value store backed by persistent store +# (2) actor supervisors holding all actor handles and is responsible +# for new actor instantiation and dead actor termination. + +LOG_PREFIX = "[Global State] " + + +class GlobalState: + """Encapsulate all global state in the serving system. + + Warning: + Currently the state resides inside driver process. The state will be + moved into a key value stored service AND a supervisor service. + """ + + def __init__(self): + #: holds all actor handles. + self.backend_actor_handles = [] + + #: actor handle to KV store actor + self.kv_store_actor_handle = None + #: actor handle to HTTP server + self.http_actor_handle = None + #: actor handle the router actor + self.router_actor_handle = None + + #: Set[str] list of backend names, used for deduplication + self.registered_backends = set() + #: Set[str] list of service endpoint names, used for deduplication + self.registered_endpoints = set() + + #: Mapping of endpoints -> a stack of traffic policy + self.policy_action_history = defaultdict(deque) + + #: HTTP address. Currently it's hard coded to localhost with port 8000 + # In future iteration, HTTP server will be started on every node and + # use random/available port in a pre-defined port range. TODO(simon) + self.http_address = "" + + def init_api_server(self): + logger.info(LOG_PREFIX + "Initalizing routing table") + self.kv_store_actor_handle = KVStoreProxyActor.remote() + logger.info((LOG_PREFIX + "Health checking routing table {}").format( + ray.get(self.kv_store_actor_handle.get_request_count.remote())), ) + + def init_http_server(self): + logger.info(LOG_PREFIX + "Initializing HTTP server") + self.http_actor_handle = HTTPActor.remote(self.kv_store_actor_handle, + self.router_actor_handle) + self.http_actor_handle.run.remote(host="0.0.0.0", port=8000) + self.http_address = "http://localhost:8000" + + def init_router(self): + logger.info(LOG_PREFIX + "Initializing queuing system") + self.router_actor_handle = CentralizedQueuesActor.remote() + self.router_actor_handle.register_self_handle.remote( + self.router_actor_handle) + + def wait_until_http_ready(self, num_retries=5, backoff_time_s=1): + routing_table_request_count = 0 + retries = num_retries + + while not routing_table_request_count: + routing_table_request_count = (ray.get( + self.kv_store_actor_handle.get_request_count.remote())) + logger.debug((LOG_PREFIX + "Checking if HTTP server is ready." + "{} retries left.").format(retries)) + time.sleep(backoff_time_s) + retries -= 1 + if retries == 0: + raise Exception( + "HTTP server not ready after {} retries.".format( + num_retries)) diff --git a/python/ray/experimental/serve/handle.py b/python/ray/experimental/serve/handle.py new file mode 100644 index 000000000..56886f25c --- /dev/null +++ b/python/ray/experimental/serve/handle.py @@ -0,0 +1,64 @@ +import ray +from ray.experimental import serve + + +class RayServeHandle: + """A handle to a service endpoint. + + Invoking this endpoint with .remote is equivalent to pinging + an HTTP endpoint. + + Example: + >>> handle = serve.get_handle("my_endpoint") + >>> handle + RayServeHandle( + Endpoint="my_endpoint", + URL="...", + Traffic=... + ) + >>> handle.remote(my_request_content) + ObjectID(...) + >>> ray.get(handle.remote(...)) + # result + >>> ray.get(handle.remote(let_it_crash_request)) + # raises RayTaskError Exception + """ + + def __init__(self, router_handle, endpoint_name): + self.router_handle = router_handle + self.endpoint_name = endpoint_name + + def remote(self, *args): + # TODO(simon): Support kwargs once #5606 is merged. + result_object_id_bytes = ray.get( + self.router_handle.enqueue_request.remote(self.endpoint_name, + *args)) + return ray.ObjectID(result_object_id_bytes) + + def get_traffic_policy(self): + # TODO(simon): This method is implemented via checking global state + # because we are sure handle and global_state are in the same process. + # However, once global_state is deprecated, this method need to be + # updated accordingly. + history = serve.global_state.policy_action_history[self.endpoint_name] + if len(history): + return history[-1] + else: + return None + + def get_http_endpoint(self): + return serve.global_state.http_address + + def __repr__(self): + return """ +RayServeHandle( + Endpoint="{endpoint_name}", + URL="{http_endpoint}/{endpoint_name}, + Traffic={traffic_policy} +) +""".format(endpoint_name=self.endpoint_name, + http_endpoint=self.get_http_endpoint(), + traffic_policy=self.get_traffic_policy()) + + # TODO(simon): a convenience function that dumps equivalent requests + # code for a given call. diff --git a/python/ray/experimental/serve/kv_store_service.py b/python/ray/experimental/serve/kv_store_service.py new file mode 100644 index 000000000..e472fee69 --- /dev/null +++ b/python/ray/experimental/serve/kv_store_service.py @@ -0,0 +1,173 @@ +import json +from abc import ABC + +import ray +import ray.experimental.internal_kv as ray_kv +from ray.experimental.serve.utils import logger + + +class NamespacedKVStore(ABC): + """Abstract base class for a namespaced key-value store. + + The idea is that multiple key-value stores can be created while sharing + the same storage system. The keys of each instance are namespaced to avoid + object_id key collision. + + Example: + + >>> store_ns1 = NamespacedKVStore(namespace="ns1") + >>> store_ns2 = NamespacedKVStore(namespace="ns2") + # Two stores can share the same connection like Redis or SQL Table + >>> store_ns1.put("same-key", 1) + >>> store_ns1.get("same-key") + 1 + >>> store_ns2.put("same-key", 2) + >>> store_ns2.get("same-key", 2) + 2 + """ + + def __init__(self, namespace): + raise NotImplementedError() + + def get(self, key): + """Retrieve the value for the given key. + + Args: + key (str) + """ + raise NotImplementedError() + + def put(self, key, value): + """Serialize the value and store it under the given key. + + Args: + key (str) + value (object): any serializable object. The serialization method + is determined by the subclass implementation. + """ + raise NotImplementedError() + + def as_dict(self): + """Return the entire namespace as a dictionary. + + Returns: + data (dict): key value pairs in current namespace + """ + raise NotImplementedError() + + +class InMemoryKVStore(NamespacedKVStore): + """A reference implementation used for testing.""" + + def __init__(self, namespace): + self.data = dict() + + # Namespace is ignored, because each namespace is backed by + # an in-memory Python dictionary. + self.namespace = namespace + + def get(self, key): + return self.data[key] + + def put(self, key, value): + self.data[key] = value + + def as_dict(self): + return self.data.copy() + + +class RayInternalKVStore(NamespacedKVStore): + """A NamespacedKVStore implementation using ray's `internal_kv`.""" + + def __init__(self, namespace): + assert ray_kv._internal_kv_initialized() + self.index_key = "RAY_SERVE_INDEX" + self.namespace = namespace + self._put(self.index_key, []) + + def _format_key(self, key): + return "{ns}-{key}".format(ns=self.namespace, key=key) + + def _remove_format_key(self, formatted_key): + return formatted_key.replace(self.namespace + "-", "", 1) + + def _serialize(self, obj): + return json.dumps(obj) + + def _deserialize(self, buffer): + return json.loads(buffer) + + def _put(self, key, value): + ray_kv._internal_kv_put( + self._format_key(self._serialize(key)), + self._serialize(value), + overwrite=True, + ) + + def _get(self, key): + return self._deserialize( + ray_kv._internal_kv_get(self._format_key(self._serialize(key)))) + + def get(self, key): + return self._get(key) + + def put(self, key, value): + assert isinstance(key, str), "Key must be a string." + + self._put(key, value) + + all_keys = set(self._get(self.index_key)) + all_keys.add(key) + self._put(self.index_key, list(all_keys)) + + def as_dict(self): + data = {} + all_keys = self._get(self.index_key) + for key in all_keys: + data[self._remove_format_key(key)] = self._get(key) + return data + + +class KVStoreProxy: + def __init__(self, kv_class=InMemoryKVStore): + self.routing_table = kv_class(namespace="routes") + self.request_count = 0 + + def register_service(self, route: str, service: str): + """Create an entry in the routing table + + Args: + route: http path name. Must begin with '/'. + service: service name. This is the name http actor will push + the request to. + """ + logger.debug("[KV] Registering route {} to service {}.".format( + route, service)) + self.routing_table.put(route, service) + + def list_service(self): + """Returns the routing table.""" + self.request_count += 1 + table = self.routing_table.as_dict() + return table + + def get_request_count(self): + """Return the number of requests that fetched the routing table. + + This method is used for two purpose: + + 1. Make sure HTTP server has started and healthy. Incremented request + count means HTTP server is actively fetching routing table. + + 2. Make sure HTTP server does not have stale routing table. This number + should be incremented every HTTP_ROUTER_CHECKER_INTERVAL_S seconds. + Supervisor should check this number as indirect indicator of http + server's health. + """ + return self.request_count + + +@ray.remote +class KVStoreProxyActor(KVStoreProxy): + def __init__(self, kv_class=RayInternalKVStore): + super().__init__(kv_class=kv_class) diff --git a/python/ray/experimental/serve/queues.py b/python/ray/experimental/serve/queues.py new file mode 100644 index 000000000..582113ab4 --- /dev/null +++ b/python/ray/experimental/serve/queues.py @@ -0,0 +1,155 @@ +from collections import defaultdict, deque + +import numpy as np + +import ray +from ray.experimental.serve.utils import get_custom_object_id, logger + + +class Query: + def __init__(self, request_body, result_object_id=None): + self.request_body = request_body + if result_object_id is None: + self.result_object_id = get_custom_object_id() + else: + self.result_object_id = result_object_id + + +class WorkIntent: + def __init__(self, work_object_id=None): + if work_object_id is None: + self.work_object_id = get_custom_object_id() + else: + self.work_object_id = work_object_id + + +class CentralizedQueues: + """A router that routes request to available workers. + + Router aceepts each request from the `enqueue_request` method and enqueues + it. It also accepts worker request to work (called work_intention in code) + from workers via the `dequeue_request` method. The traffic policy is used + to match requests with their corresponding workers. + + Behavior: + >>> # psuedo-code + >>> queue = CentralizedQueues() + >>> queue.enqueue_request('service-name', data) + # nothing happens, request is queued. + # returns result ObjectID, which will contains the final result + >>> queue.dequeue_request('backend-1') + # nothing happens, work intention is queued. + # return work ObjectID, which will contains the future request payload + >>> queue.link('service-name', 'backend-1') + # here the enqueue_requester is matched with worker, request + # data is put into work ObjectID, and the worker processes the request + # and store the result into result ObjectID + + Traffic policy splits the traffic among different workers + probabilistically: + + 1. When all backends are ready to receive traffic, we will randomly + choose a backend based on the weights assigned by the traffic policy + dictionary. + + 2. When more than 1 but not all backends are ready, we will normalize the + weights of the ready backends to 1 and choose a backend via sampling. + + 3. When there is only 1 backend ready, we will only use that backend. + """ + + def __init__(self): + # service_name -> request queue + self.queues = defaultdict(deque) + + # service_name -> traffic_policy + self.traffic = defaultdict(dict) + + # backend_name -> worker queue + self.workers = defaultdict(deque) + + def enqueue_request(self, service, request_data): + query = Query(request_data) + self.queues[service].append(query) + self.flush() + return query.result_object_id.binary() + + def dequeue_request(self, backend): + intention = WorkIntent() + self.workers[backend].append(intention) + self.flush() + return intention.work_object_id.binary() + + def link(self, service, backend): + logger.debug("Link %s with %s", service, backend) + self.traffic[service][backend] = 1.0 + self.flush() + + def set_traffic(self, service, traffic_dict): + logger.debug("Setting traffic for service %s to %s", service, + traffic_dict) + self.traffic[service] = traffic_dict + self.flush() + + def flush(self): + """In the default case, flush calls ._flush. + + When this class is a Ray actor, .flush can be scheduled as a remote + method invocation. + """ + self._flush() + + def _get_available_backends(self, service): + backends_in_policy = set(self.traffic[service].keys()) + available_workers = set((backend + for backend, queues in self.workers.items() + if len(queues) > 0)) + return list(backends_in_policy.intersection(available_workers)) + + def _flush(self): + for service, queue in self.queues.items(): + ready_backends = self._get_available_backends(service) + + while len(queue) and len(ready_backends): + # Fast path, only one backend available. + if len(ready_backends) == 1: + backend = ready_backends[0] + request, work = (queue.popleft(), + self.workers[backend].popleft()) + ray.worker.global_worker.put_object( + work.work_object_id, request) + + # We have more than one backend available. + # We will roll a dice among the multiple backends. + else: + backend_weights = np.array([ + self.traffic[service][backend_name] + for backend_name in ready_backends + ]) + # Normalize the weights to 1. + backend_weights /= backend_weights.sum() + chosen_backend = np.random.choice( + ready_backends, p=backend_weights).squeeze() + + request, work = ( + queue.popleft(), + self.workers[chosen_backend].popleft(), + ) + ray.worker.global_worker.put_object( + work.work_object_id, request) + + ready_backends = self._get_available_backends(service) + + +@ray.remote +class CentralizedQueuesActor(CentralizedQueues): + self_handle = None + + def register_self_handle(self, handle_to_this_actor): + self.self_handle = handle_to_this_actor + + def flush(self): + if self.self_handle: + self.self_handle._flush.remote() + else: + self._flush() diff --git a/python/ray/experimental/serve/server.py b/python/ray/experimental/serve/server.py new file mode 100644 index 000000000..af70a782b --- /dev/null +++ b/python/ray/experimental/serve/server.py @@ -0,0 +1,125 @@ +import asyncio +import json + +import uvicorn + +import ray +from ray.experimental.async_api import _async_init, as_future +from ray.experimental.serve.utils import BytesEncoder +from ray.experimental.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S + + +class JSONResponse: + """ASGI compliant response class. + + It is expected to be called in async context and pass along + `scope, receive, send` as in ASGI spec. + + >>> await JSONResponse({"k": "v"})(scope, receive, send) + """ + + def __init__(self, content=None, status_code=200): + """Construct a JSON HTTP Response. + + Args: + content (optional): Any JSON serializable object. + status_code (int, optional): Default status code is 200. + """ + self.body = self.render(content) + self.status_code = status_code + self.raw_headers = [[b"content-type", b"application/json"]] + + def render(self, content): + if content is None: + return b"" + if isinstance(content, bytes): + return content + return json.dumps(content, cls=BytesEncoder, indent=2).encode() + + async def __call__(self, scope, receive, send): + await send({ + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + }) + await send({"type": "http.response.body", "body": self.body}) + + +class HTTPProxy: + """ + This class should be instantiated and ran by ASGI server. + + >>> import uvicorn + >>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle)) + # blocks forever + """ + + def __init__(self, kv_store_actor_handle, router_handle): + """ + Args: + kv_store_actor_handle (ray.actor.ActorHandle): handle to routing + table actor. It will be used to populate routing table. It + should implement `handle.list_service()` + router_handle (ray.actor.ActorHandle): actor handle to push request + to. It should implement + `handle.enqueue_request.remote(endpoint, body)` + """ + assert ray.is_initialized() + + self.admin_actor = kv_store_actor_handle + self.router = router_handle + self.route_table = dict() + + async def route_checker(self, interval): + while True: + try: + self.route_table = await as_future( + self.admin_actor.list_service.remote()) + except ray.exceptions.RayletError: # Gracefully handle termination + return + + await asyncio.sleep(interval) + + async def __call__(self, scope, receive, send): + # NOTE: This implements ASGI protocol specified in + # https://asgi.readthedocs.io/en/latest/specs/index.html + + if scope["type"] == "lifespan": + await _async_init() + asyncio.ensure_future( + self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S)) + return + + current_path = scope["path"] + if current_path == "/": + await JSONResponse(self.route_table)(scope, receive, send) + elif current_path in self.route_table: + endpoint_name = self.route_table[current_path] + result_object_id_bytes = await as_future( + self.router.enqueue_request.remote(endpoint_name, scope)) + result = await as_future(ray.ObjectID(result_object_id_bytes)) + + if isinstance(result, ray.exceptions.RayTaskError): + await JSONResponse({ + "error": "internal error, please use python API to debug" + })(scope, receive, send) + else: + await JSONResponse({"result": result})(scope, receive, send) + else: + error_message = ("Path {} not found. " + "Please ping http://.../ for routing table" + ).format(current_path) + + await JSONResponse( + { + "error": error_message + }, status_code=404)(scope, receive, send) + + +@ray.remote +class HTTPActor: + def __init__(self, kv_store_actor_handle, router_handle): + self.app = HTTPProxy(kv_store_actor_handle, router_handle) + + def run(self, host="0.0.0.0", port=8000): + uvicorn.run(self.app, host=host, port=port, lifespan="on") diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py new file mode 100644 index 000000000..0533a6c02 --- /dev/null +++ b/python/ray/experimental/serve/task_runner.py @@ -0,0 +1,96 @@ +import traceback + +import ray + + +class TaskRunner: + """A simple class that runs a function. + + The purpose of this class is to model what the most basic actor could be. + That is, a ray serve actor should implement the TaskRunner interface. + """ + + def __init__(self, func_to_run): + self.func = func_to_run + + def __call__(self, *args): + return self.func(*args) + + +def wrap_to_ray_error(callable_obj, *args): + """Utility method that catch and seal exceptions in execution""" + try: + return callable_obj(*args) + except Exception: + traceback_str = ray.utils.format_error_message(traceback.format_exc()) + return ray.exceptions.RayTaskError(str(callable_obj), traceback_str) + + +class RayServeMixin: + """This mixin class adds the functionality to fetch from router queues. + + Warning: + It assumes the main execution method is `__call__` of the user defined + class. This means that serve will call `your_instance.__call__` when + each request comes in. This behavior will be fixed in the future to + allow assigning artibrary methods. + + Example: + >>> # Use ray.remote decorator and RayServeMixin + >>> # to make MyClass servable. + >>> @ray.remote + class RayServeActor(RayServeMixin, MyClass): + pass + """ + _ray_serve_self_handle = None + _ray_serve_router_handle = None + _ray_serve_setup_completed = False + _ray_serve_dequeue_requestr_name = None + + def _ray_serve_setup(self, my_name, _ray_serve_router_handle): + self._ray_serve_dequeue_requestr_name = my_name + self._ray_serve_router_handle = _ray_serve_router_handle + self._ray_serve_setup_completed = True + + def _ray_serve_main_loop(self, my_handle): + assert self._ray_serve_setup_completed + self._ray_serve_self_handle = my_handle + + work_token = ray.get( + self._ray_serve_router_handle.dequeue_request.remote( + self._ray_serve_dequeue_requestr_name)) + work_item = ray.get(ray.ObjectID(work_token)) + + # TODO(simon): + # __call__ should be able to take multiple *args and **kwargs. + result = wrap_to_ray_error(self.__call__, work_item.request_body) + result_object_id = work_item.result_object_id + ray.worker.global_worker.put_object(result_object_id, result) + + # The worker finished one unit of work. + # It will now tail recursively schedule the main_loop again. + + # TODO(simon): remove tail recursion, ask router to callback instead + self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle) + + +class TaskRunnerBackend(TaskRunner, RayServeMixin): + """A simple function serving backend + + Note that this is not yet an actor. To make it an actor: + + >>> @ray.remote + class TaskRunnerActor(TaskRunnerBackend): + pass + + Note: + This class is not used in the actual ray serve system. It exists + for documentation purpose. + """ + + pass + + +@ray.remote +class TaskRunnerActor(TaskRunnerBackend): + pass diff --git a/python/ray/experimental/serve/tests/conftest.py b/python/ray/experimental/serve/tests/conftest.py new file mode 100644 index 000000000..9f784a18e --- /dev/null +++ b/python/ray/experimental/serve/tests/conftest.py @@ -0,0 +1,21 @@ +import pytest + +import ray +from ray.experimental import serve + + +@pytest.fixture(scope="session") +def serve_instance(): + serve.init() + serve.global_state.wait_until_http_ready() + yield + + +@pytest.fixture(scope="session") +def ray_instance(): + ray_already_initialized = ray.is_initialized() + if not ray_already_initialized: + ray.init(object_store_memory=int(1e8)) + yield + if not ray_already_initialized: + ray.shutdown() diff --git a/python/ray/experimental/serve/tests/test_api.py b/python/ray/experimental/serve/tests/test_api.py new file mode 100644 index 000000000..aa2002dea --- /dev/null +++ b/python/ray/experimental/serve/tests/test_api.py @@ -0,0 +1,33 @@ +import time + +import requests +from flaky import flaky + +import ray +from ray.experimental import serve + + +def delay_rerun(*_): + time.sleep(1) + return True + + +# flaky test because the routing table might not be populated +@flaky(rerun_filter=delay_rerun) +def test_e2e(serve_instance): + serve.create_endpoint("endpoint", "/api") + result = ray.get( + serve.global_state.kv_store_actor_handle.list_service.remote()) + assert result == {"/api": "endpoint"} + + assert requests.get("http://127.0.0.1:8000/").json() == result + + def echo(i): + return i + + serve.create_backend(echo, "echo:v1") + serve.link("endpoint", "echo:v1") + + resp = requests.get("http://127.0.0.1:8000/api").json()["result"] + assert resp["path"] == "/api" + assert resp["method"] == "GET" diff --git a/python/ray/experimental/serve/tests/test_queue.py b/python/ray/experimental/serve/tests/test_queue.py new file mode 100644 index 000000000..6bb231169 --- /dev/null +++ b/python/ray/experimental/serve/tests/test_queue.py @@ -0,0 +1,72 @@ +import ray +from ray.experimental.serve.queues import CentralizedQueues + + +def test_single_prod_cons_queue(serve_instance): + q = CentralizedQueues() + q.link("svc", "backend") + + result_object_id = q.enqueue_request("svc", 1) + work_object_id = q.dequeue_request("backend") + got_work = ray.get(ray.ObjectID(work_object_id)) + assert got_work.request_body == 1 + + ray.worker.global_worker.put_object(got_work.result_object_id, 2) + assert ray.get(ray.ObjectID(result_object_id)) == 2 + + +def test_alter_backend(serve_instance): + q = CentralizedQueues() + + result_object_id = q.enqueue_request("svc", 1) + work_object_id = q.dequeue_request("backend-1") + q.set_traffic("svc", {"backend-1": 1}) + got_work = ray.get(ray.ObjectID(work_object_id)) + assert got_work.request_body == 1 + ray.worker.global_worker.put_object(got_work.result_object_id, 2) + assert ray.get(ray.ObjectID(result_object_id)) == 2 + + result_object_id = q.enqueue_request("svc", 1) + work_object_id = q.dequeue_request("backend-2") + q.set_traffic("svc", {"backend-2": 1}) + got_work = ray.get(ray.ObjectID(work_object_id)) + assert got_work.request_body == 1 + ray.worker.global_worker.put_object(got_work.result_object_id, 2) + assert ray.get(ray.ObjectID(result_object_id)) == 2 + + +def test_split_traffic(serve_instance): + q = CentralizedQueues() + + q.enqueue_request("svc", 1) + q.enqueue_request("svc", 1) + q.set_traffic("svc", {}) + work_object_id_1 = q.dequeue_request("backend-1") + work_object_id_2 = q.dequeue_request("backend-2") + q.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5}) + + got_work = ray.get( + [ray.ObjectID(work_object_id_1), + ray.ObjectID(work_object_id_2)]) + assert [g.request_body for g in got_work] == [1, 1] + + +def test_probabilities(serve_instance): + q = CentralizedQueues() + + [q.enqueue_request("svc", 1) for i in range(100)] + + work_object_id_1_s = [ + ray.ObjectID(q.dequeue_request("backend-1")) for i in range(100) + ] + work_object_id_2_s = [ + ray.ObjectID(q.dequeue_request("backend-2")) for i in range(100) + ] + + q.set_traffic("svc", {"backend-1": 0.1, "backend-2": 0.9}) + + backend_1_ready_object_ids, _ = ray.wait( + work_object_id_1_s, num_returns=100, timeout=0.0) + backend_2_ready_object_ids, _ = ray.wait( + work_object_id_2_s, num_returns=100, timeout=0.0) + assert len(backend_1_ready_object_ids) < len(backend_2_ready_object_ids) diff --git a/python/ray/experimental/serve/tests/test_routing.py b/python/ray/experimental/serve/tests/test_routing.py new file mode 100644 index 000000000..bc66cd136 --- /dev/null +++ b/python/ray/experimental/serve/tests/test_routing.py @@ -0,0 +1,27 @@ +from ray.experimental.serve.kv_store_service import (InMemoryKVStore, + RayInternalKVStore) + + +def test_default_in_memory_kv(): + kv = InMemoryKVStore("") + kv.put("1", 2) + assert kv.get("1") == 2 + kv.put("1", 3) + assert kv.get("1") == 3 + assert kv.as_dict() == {"1": 3} + + +def test_ray_interal_kv(ray_instance): + kv = RayInternalKVStore("") + kv.put("1", 2) + assert kv.get("1") == 2 + kv.put("1", 3) + assert kv.get("1") == 3 + assert kv.as_dict() == {"1": 3} + + kv = RayInternalKVStore("othernamespace") + kv.put("1", 2) + assert kv.get("1") == 2 + kv.put("1", 3) + assert kv.get("1") == 3 + assert kv.as_dict() == {"1": 3} diff --git a/python/ray/experimental/serve/tests/test_task_runner.py b/python/ray/experimental/serve/tests/test_task_runner.py new file mode 100644 index 000000000..e8d9fafdd --- /dev/null +++ b/python/ray/experimental/serve/tests/test_task_runner.py @@ -0,0 +1,80 @@ +import ray +from ray.experimental.serve.queues import CentralizedQueuesActor +from ray.experimental.serve.task_runner import ( + RayServeMixin, + TaskRunner, + TaskRunnerActor, + wrap_to_ray_error, +) + + +def test_runner_basic(): + def echo(i): + return i + + r = TaskRunner(echo) + assert r(1) == 1 + + +def test_runner_wraps_error(): + def echo(i): + return i + + assert wrap_to_ray_error(echo, 2) == 2 + + def error(_): + return 1 / 0 + + assert isinstance(wrap_to_ray_error(error, 1), ray.exceptions.RayTaskError) + + +def test_runner_actor(serve_instance): + q = CentralizedQueuesActor.remote() + + def echo(i): + return i + + CONSUMER_NAME = "runner" + PRODUCER_NAME = "prod" + + runner = TaskRunnerActor.remote(echo) + + runner._ray_serve_setup.remote(CONSUMER_NAME, q) + runner._ray_serve_main_loop.remote(runner) + + q.link.remote(PRODUCER_NAME, CONSUMER_NAME) + + for query in [333, 444, 555]: + result_token = ray.ObjectID( + ray.get(q.enqueue_request.remote(PRODUCER_NAME, query))) + assert ray.get(result_token) == query + + +def test_ray_serve_mixin(serve_instance): + q = CentralizedQueuesActor.remote() + + CONSUMER_NAME = "runner-cls" + PRODUCER_NAME = "prod-cls" + + class MyAdder: + def __init__(self, inc): + self.increment = inc + + def __call__(self, context): + return context + self.increment + + @ray.remote + class CustomActor(MyAdder, RayServeMixin): + pass + + runner = CustomActor.remote(3) + + runner._ray_serve_setup.remote(CONSUMER_NAME, q) + runner._ray_serve_main_loop.remote(runner) + + q.link.remote(PRODUCER_NAME, CONSUMER_NAME) + + for query in [333, 444, 555]: + result_token = ray.ObjectID( + ray.get(q.enqueue_request.remote(PRODUCER_NAME, query))) + assert ray.get(result_token) == query + 3 diff --git a/python/ray/experimental/serve/tests/test_util.py b/python/ray/experimental/serve/tests/test_util.py new file mode 100644 index 000000000..be17e4ba6 --- /dev/null +++ b/python/ray/experimental/serve/tests/test_util.py @@ -0,0 +1,9 @@ +import json + +from ray.experimental.serve.utils import BytesEncoder + + +def test_bytes_encoder(): + data_before = {"inp": {"nest": b"bytes"}} + data_after = {"inp": {"nest": "bytes"}} + assert json.loads(json.dumps(data_before, cls=BytesEncoder)) == data_after diff --git a/python/ray/experimental/serve/utils.py b/python/ray/experimental/serve/utils.py new file mode 100644 index 000000000..62fd663e6 --- /dev/null +++ b/python/ray/experimental/serve/utils.py @@ -0,0 +1,51 @@ +import json +import logging + +from pygments import formatters, highlight, lexers + +import ray + + +def _get_logger(): + logger = logging.getLogger("ray.serve") + # TODO(simon): Make logging level configurable. + logger.setLevel(logging.INFO) + return logger + + +logger = _get_logger() + + +class BytesEncoder(json.JSONEncoder): + """Allow bytes to be part of the JSON document. + + BytesEncoder will walk the JSON tree and decode bytes with utf-8 codec. + + Example: + >>> json.dumps({b'a': b'c'}, cls=BytesEncoder) + '{"a":"c"}' + """ + + def default(self, o): # pylint: disable=E0202 + if isinstance(o, bytes): + return o.decode("utf-8") + return super().default(o) + + +def get_custom_object_id(): + """Use ray worker API to get computed ObjectID""" + worker = ray.worker.global_worker + object_id = ray._raylet.compute_put_id(worker.current_task_id, + worker.task_context.put_index) + worker.task_context.put_index += 1 + return object_id + + +def pformat_color_json(d): + """Use pygments to pretty format and colroize dictionary""" + formatted_json = json.dumps(d, sort_keys=True, indent=4) + + colorful_json = highlight(formatted_json, lexers.JsonLexer(), + formatters.TerminalFormatter()) + + return colorful_json diff --git a/python/setup.py b/python/setup.py index f995a6c36..316a3d6b8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -68,6 +68,7 @@ extras = { ], "debug": ["psutil", "setproctitle", "py-spy"], "dashboard": ["psutil", "aiohttp"], + "serve": ["uvicorn", "pygments", "werkzeug"], }