From 7f9ddfcfd869b16b49e47f2d48c027d13dd45922 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 2 Apr 2020 16:44:53 -0500 Subject: [PATCH] Only access route_table and policy_table in master actor (#7835) --- python/ray/serve/api.py | 41 ++++++----------------- python/ray/serve/global_state.py | 54 ++++++++++++++++++++++++++---- python/ray/serve/handle.py | 8 ++--- python/ray/serve/tests/test_api.py | 10 ------ 4 files changed, 60 insertions(+), 53 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 2a2b9f579..1c76c8779 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -4,15 +4,13 @@ from tempfile import mkstemp from multiprocessing import cpu_count -import numpy as np - import ray from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, SERVE_MASTER_NAME) from ray.serve.global_state import GlobalState, ServeMaster from ray.serve.kv_store_service import SQLiteKVStore from ray.serve.task_runner import RayServeMixin, TaskRunnerActor -from ray.serve.utils import block_until_http_ready, expand +from ray.serve.utils import block_until_http_ready from ray.serve.exceptions import RayServeException, batch_annotation_not_found from ray.serve.backend_config import BackendConfig from ray.serve.policy import RoutePolicy @@ -114,6 +112,10 @@ def init( if not ray.is_initialized(): ray.init(**ray_init_kwargs) + # Register serialization context once + ray.register_custom_serializer(Query, Query.ray_serialize, + Query.ray_deserialize) + # Try to get serve master actor if it exists try: ray.util.get_actor(SERVE_MASTER_NAME) @@ -166,13 +168,9 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]): registered before returning """ methods = [m.upper() for m in methods] - global_state.route_table.register_service( - route, endpoint_name, methods=methods) - http_proxy = global_state.get_http_proxy() ray.get( - http_proxy.set_route_table.remote( - global_state.route_table.list_service( - include_methods=True, include_headless=False))) + global_state.master_actor.create_endpoint.remote( + route, endpoint_name, methods)) @_ensure_connected @@ -371,26 +369,9 @@ def split(endpoint_name, traffic_policy_dictionary): traffic_policy_dictionary (dict): a dictionary maps backend names to their traffic weights. The weights must sum to 1. """ - assert endpoint_name in expand( - global_state.route_table.list_service(include_headless=True).values()) - - assert isinstance(traffic_policy_dictionary, - dict), "Traffic policy must be dictionary" - prob = 0 - for backend, weight in traffic_policy_dictionary.items(): - prob += weight - assert (backend in global_state.backend_table.list_backends() - ), "backend {} is not registered".format(backend) - assert np.isclose( - prob, 1, - atol=0.02), "weights must sum to 1, currently it sums to {}".format( - prob) - - global_state.policy_table.register_traffic_policy( - endpoint_name, traffic_policy_dictionary) - router = global_state.get_router() ray.get( - router.set_traffic.remote(endpoint_name, traffic_policy_dictionary)) + global_state.master_actor.split_traffic.remote( + endpoint_name, traffic_policy_dictionary)) @_ensure_connected @@ -413,9 +394,7 @@ def get_handle(endpoint_name, RayServeHandle """ if not missing_ok: - assert endpoint_name in expand( - global_state.route_table.list_service( - include_headless=True).values()) + assert endpoint_name in global_state.get_all_endpoints() # Delay import due to it's dependency on global_state from ray.serve.handle import RayServeHandle diff --git a/python/ray/serve/global_state.py b/python/ray/serve/global_state.py index d5085ca83..af39d2d8e 100644 --- a/python/ray/serve/global_state.py +++ b/python/ray/serve/global_state.py @@ -5,7 +5,9 @@ from ray.serve.http_proxy import HTTPProxyActor from ray.serve.kv_store_service import (BackendTable, RoutingTable, TrafficPolicyTable) from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop) -from ray.serve.utils import get_random_letters +from ray.serve.utils import expand, get_random_letters + +import numpy as np @ray.remote @@ -32,6 +34,9 @@ class ServeMaster: def get_kv_store_connector(self): return self.kv_store_connector + def get_traffic_policy(self, endpoint_name): + return self.policy_table.list_traffic_policy()[endpoint_name] + def start_router(self, router_class, init_kwargs): assert self.router is None, "Router already started." self.router = router_class.options( @@ -119,6 +124,41 @@ class ServeMaster: def get_all_handles(self): return self.tag_to_actor_handles + def get_all_endpoints(self): + return expand( + self.route_table.list_service(include_headless=True).values()) + + def split_traffic(self, endpoint_name, traffic_policy_dictionary): + assert endpoint_name in expand( + self.route_table.list_service(include_headless=True).values()) + + assert isinstance(traffic_policy_dictionary, + dict), "Traffic policy must be dictionary" + prob = 0 + for backend, weight in traffic_policy_dictionary.items(): + prob += weight + assert (backend in self.backend_table.list_backends() + ), "backend {} is not registered".format(backend) + assert np.isclose( + prob, 1, atol=0.02 + ), "weights must sum to 1, currently it sums to {}".format(prob) + + self.policy_table.register_traffic_policy(endpoint_name, + traffic_policy_dictionary) + [router] = self.get_router() + ray.get( + router.set_traffic.remote(endpoint_name, + traffic_policy_dictionary)) + + def create_endpoint(self, route, endpoint_name, methods): + self.route_table.register_service( + route, endpoint_name, methods=methods) + [http_proxy] = self.get_http_proxy() + ray.get( + http_proxy.set_route_table.remote( + self.route_table.list_service( + include_methods=True, include_headless=False))) + class GlobalState: """Encapsulate all global state in the serving system. @@ -137,15 +177,17 @@ class GlobalState: # Connect to all the tables. kv_store_connector = ray.get( self.master_actor.get_kv_store_connector.remote()) - self.route_table = RoutingTable(kv_store_connector) self.backend_table = BackendTable(kv_store_connector) - self.policy_table = TrafficPolicyTable(kv_store_connector) - - def get_http_proxy(self): - return ray.get(self.master_actor.get_http_proxy.remote())[0] def get_router(self): return ray.get(self.master_actor.get_router.remote())[0] def get_metric_monitor(self): return ray.get(self.master_actor.get_metric_monitor.remote())[0] + + def get_traffic_policy(self, endpoint_name): + return ray.get( + self.master_actor.get_traffic_policy.remote(endpoint_name)) + + def get_all_endpoints(self): + return ray.get(self.master_actor.get_all_endpoints.remote()) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index e209c65a8..ccb499197 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -102,16 +102,12 @@ class RayServeHandle: method_name=method_name, ) - def get_traffic_policy(self): - policy_table = serve.api._get_global_state().policy_table - all_services = policy_table.list_traffic_policy() - return all_services[self.endpoint_name] - def get_http_endpoint(self): return DEFAULT_HTTP_ADDRESS def _ensure_backend_unique(self, backend_tag=None): - traffic_policy = self.get_traffic_policy() + global_state = serve.api._get_global_state() + traffic_policy = global_state.get_traffic_policy(self.endpoint_name) if backend_tag is None: assert len(traffic_policy) == 1, ( "Multiple backends detected. " diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 2cb475bb3..f0a4bd772 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -5,7 +5,6 @@ import requests from ray import serve from ray.serve import BackendConfig import ray -from ray.serve.constants import NO_ROUTE_KEY from ray.serve.exceptions import RayServeException from ray.serve.handle import RayServeHandle @@ -13,8 +12,6 @@ from ray.serve.handle import RayServeHandle def test_e2e(serve_instance): serve.init() # so we have access to global state serve.create_endpoint("endpoint", "/api", methods=["GET", "POST"]) - result = serve.api._get_global_state().route_table.list_service() - assert result["/api"] == "endpoint" retry_count = 5 timeout_sleep = 0.5 @@ -62,13 +59,6 @@ def test_route_decorator(serve_instance): def test_no_route(serve_instance): serve.create_endpoint("noroute-endpoint") - global_state = serve.api._get_global_state() - - result = global_state.route_table.list_service(include_headless=True) - assert result[NO_ROUTE_KEY] == ["noroute-endpoint"] - - without_headless_result = global_state.route_table.list_service() - assert NO_ROUTE_KEY not in without_headless_result def func(_, i=1): return 1