diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index bd2b5d25b..8c658de9d 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -366,7 +366,7 @@ def get_handle(endpoint_name, master_actor.get_all_endpoints.remote()) return RayServeHandle( - ray.get(master_actor.get_router.remote())[0], + ray.get(master_actor.get_http_proxy.remote())[0], endpoint_name, relative_slo_ms, absolute_slo_ms, diff --git a/python/ray/serve/constants.py b/python/ray/serve/constants.py index e72f66657..8261496dc 100644 --- a/python/ray/serve/constants.py +++ b/python/ray/serve/constants.py @@ -1,9 +1,6 @@ #: Actor name used to register master actor SERVE_MASTER_NAME = "SERVE_MASTER_ACTOR" -#: Actor name used to register router actor -SERVE_ROUTER_NAME = "SERVE_ROUTER_ACTOR" - #: Actor name used to register HTTP proxy actor SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR" diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index d236d1d66..0b7d25573 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -1,16 +1,16 @@ import asyncio +from urllib.parse import parse_qs import uvicorn import ray +from ray.exceptions import RayTaskError from ray import serve from ray.serve.context import TaskContext from ray.serve.metric import MetricClient from ray.serve.request_params import RequestMetadata from ray.serve.http_util import Response -from ray.serve.utils import logger - -from urllib.parse import parse_qs +from ray.serve.router import Router # The maximum number of times to retry a request due to actor failure. # TODO(edoakes): this should probably be configurable. @@ -26,12 +26,11 @@ class HTTPProxy: # blocks forever """ - async def fetch_config_from_master(self): + async def fetch_config_from_master(self, instance_name=None): assert ray.is_initialized() master = serve.api._get_master_actor() - self.route_table, [self.router_handle - ] = await master.get_http_proxy_config.remote() + self.route_table = await master.get_http_proxy_config.remote() # The exporter is required to return results for /-/metrics endpoint. [self.metric_exporter] = await master.get_metric_exporter.remote() @@ -42,6 +41,9 @@ class HTTPProxy: description="The number of requests processed", label_names=("route", )) + self.router = Router() + await self.router.setup(instance_name) + def set_route_table(self, route_table): self.route_table = route_table @@ -155,24 +157,14 @@ class HTTPProxy: shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None), ) - retries = 0 - while retries <= MAX_ACTOR_DEAD_RETRIES: - try: - result = await self.router_handle.enqueue_request.remote( - request_metadata, scope, http_body_bytes) - if not isinstance(result, ray.exceptions.RayActorError): - await Response(result).send(scope, receive, send) - break - logger.warning("Got RayActorError: {}".format(str(result))) - await asyncio.sleep(0.1) - except Exception as e: - error_message = "Internal Error. Traceback: {}.".format(e) - await error_sender(error_message, 500) - break + result = await self.router.enqueue_request(request_metadata, scope, + http_body_bytes) + + if isinstance(result, RayTaskError): + error_message = "Task Error. Traceback: {}.".format(result) + await error_sender(error_message, 500) else: - logger.debug("Maximum actor death retries exceeded") - await error_sender( - "Internal Error. Maximum actor death retries exceeded", 500) + await Response(result).send(scope, receive, send) @ray.remote @@ -180,7 +172,7 @@ class HTTPProxyActor: async def __init__(self, host, port, instance_name=None): serve.init(name=instance_name) self.app = HTTPProxy() - await self.app.fetch_config_from_master() + await self.app.fetch_config_from_master(instance_name) self.host = host self.port = port @@ -206,3 +198,28 @@ class HTTPProxyActor: async def set_route_table(self, route_table): self.app.set_route_table(route_table) + + # ------ Proxy router logic ------ # + async def add_new_worker(self, backend_tag, replica_tag, worker_handle): + return await self.app.router.add_new_worker(backend_tag, replica_tag, + worker_handle) + + async def set_traffic(self, endpoint, traffic_policy): + return await self.app.router.set_traffic(endpoint, traffic_policy) + + async def set_backend_config(self, backend, config): + return await self.app.router.set_backend_config(backend, config) + + async def remove_backend(self, backend): + return await self.app.router.remove_backend(backend) + + async def remove_endpoint(self, endpoint): + return await self.app.router.remove_endpoint(endpoint) + + async def remove_worker(self, backend_tag, replica_tag): + return await self.app.router.remove_worker(backend_tag, replica_tag) + + async def enqueue_request(self, request_meta, *request_args, + **request_kwargs): + return await self.app.router.enqueue_request( + request_meta, *request_args, **request_kwargs) diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 02c7a4e54..5111f6251 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -7,12 +7,11 @@ import time import ray import ray.cloudpickle as pickle from ray.serve.backend_worker import create_backend_worker -from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME, - SERVE_PROXY_NAME, SERVE_METRIC_SINK_NAME) +from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME, + SERVE_METRIC_SINK_NAME) from ray.serve.http_proxy import HTTPProxyActor from ray.serve.kv_store import RayInternalKVStore from ray.serve.metric.exporter import MetricExporterActor -from ray.serve.router import Router from ray.serve.exceptions import RayServeException from ray.serve.utils import (format_actor_name, get_random_letters, logger, try_schedule_resources_on_nodes) @@ -129,7 +128,6 @@ class ServeMaster: # If starting the actor for the first time, starts up the other system # components. If recovering, fetches their actor handles. self._get_or_start_metric_exporter(metric_exporter_class) - self._get_or_start_router() self._get_or_start_http_proxy(http_node_id, http_proxy_host, http_proxy_port) @@ -153,27 +151,6 @@ class ServeMaster: asyncio.get_event_loop().create_task( self._recover_from_checkpoint(checkpoint)) - def _get_or_start_router(self): - """Get the router belonging to this serve instance. - - If the router does not already exist, it will be started. - """ - router_name = format_actor_name(SERVE_ROUTER_NAME, self.instance_name) - try: - self.router = ray.get_actor(router_name) - except ValueError: - logger.info("Starting router with name '{}'".format(router_name)) - self.router = ray.remote(Router).options( - name=router_name, - max_concurrency=ASYNC_CONCURRENCY, - max_restarts=-1, - max_task_retries=-1, - ).remote(instance_name=self.instance_name) - - def get_router(self): - """Returns a handle to the router managed by this actor.""" - return [self.router] - def _get_or_start_http_proxy(self, node_id, host, port): """Get the HTTP proxy belonging to this serve instance. @@ -197,13 +174,19 @@ class ServeMaster: ).remote( host, port, instance_name=self.instance_name) + # Since router is a merged with HTTP proxy actor, the router will be + # proxied via the HTTP actor. Even though the two variable names are + # pointing to the same object, their semantic differences make the code + # more readable. (e.g. http_proxy.set_route_table, router.add_worker) + self.router = self.http_proxy + def get_http_proxy(self): """Returns a handle to the HTTP proxy managed by this actor.""" return [self.http_proxy] def get_http_proxy_config(self): """Called by the HTTP proxy on startup to fetch required state.""" - return self.routes, self.get_router() + return self.routes def _get_or_start_metric_exporter(self, metric_exporter_class): """Get the metric exporter belonging to this serve instance. @@ -766,7 +749,6 @@ class ServeMaster: """Shuts down the serve instance completely.""" async with self.write_lock: ray.kill(self.http_proxy, no_restart=True) - ray.kill(self.router, no_restart=True) ray.kill(self.metric_exporter, no_restart=True) for replica_dict in self.workers.values(): for replica in replica_dict.values(): diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 1743dd14e..b10229e87 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -87,7 +87,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future], class Router: """A router that routes request to available workers.""" - async def __init__(self, instance_name=None): + async def setup(self, instance_name=None): # Note: Several queues are used in the router # - When a request come in, it's placed inside its corresponding # endpoint_queue. @@ -198,8 +198,6 @@ class Router: self.endpoint_queues[endpoint].appendleft(query) self.flush_endpoint_queue(endpoint) - # Note: a future change can be to directly return the ObjectRef from - # replica task submission try: result = await query.async_future except RayTaskError as e: diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 81b8db028..cbda31dbd 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -570,7 +570,7 @@ def test_shutdown(serve_instance): def check_dead(): for actor_name in [ constants.SERVE_MASTER_NAME, constants.SERVE_PROXY_NAME, - constants.SERVE_ROUTER_NAME, constants.SERVE_METRIC_SINK_NAME + constants.SERVE_METRIC_SINK_NAME ]: try: ray.get_actor(format_actor_name(actor_name, instance_name)) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 5beff28ec..e3a402f4d 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -50,6 +50,7 @@ async def test_runner_wraps_error(): async def test_runner_actor(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() def echo(flask_request, i=None): return i @@ -71,6 +72,7 @@ async def test_runner_actor(serve_instance): async def test_ray_serve_mixin(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() CONSUMER_NAME = "runner-cls" PRODUCER_NAME = "prod-cls" @@ -96,6 +98,7 @@ async def test_ray_serve_mixin(serve_instance): async def test_task_runner_check_context(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() def echo(flask_request, i=None): # Accessing the flask_request without web context should throw. @@ -117,6 +120,7 @@ async def test_task_runner_check_context(serve_instance): async def test_task_runner_custom_method_single(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() class NonBatcher: def a(self, _): @@ -151,6 +155,7 @@ async def test_task_runner_custom_method_single(serve_instance): async def test_task_runner_custom_method_batch(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() @serve.accept_batch class Batcher: @@ -216,6 +221,7 @@ async def test_task_runner_custom_method_batch(serve_instance): async def test_task_runner_perform_batch(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() def batcher(*args, **kwargs): return [serve.context.batch_size] * serve.context.batch_size @@ -246,6 +252,7 @@ async def test_task_runner_perform_batch(serve_instance): async def test_task_runner_perform_async(serve_instance): q = ray.remote(Router).remote() + await q.setup.remote() @ray.remote class Barrier: diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 959384fa2..0e0f2c45d 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -107,44 +107,6 @@ def test_http_proxy_failure(serve_instance): assert response.text == "hello2" -def _kill_router(): - [router] = ray.get(serve.api._get_master_actor().get_router.remote()) - ray.kill(router, no_restart=False) - - -def test_router_failure(serve_instance): - serve.init() - - def function(): - return "hello1" - - serve.create_backend("router_failure:v1", function) - serve.create_endpoint( - "router_failure", backend="router_failure:v1", route="/router_failure") - - assert request_with_retries("/router_failure", timeout=5).text == "hello1" - - for _ in range(10): - response = request_with_retries("/router_failure", timeout=30) - assert response.text == "hello1" - - _kill_router() - - for _ in range(10): - response = request_with_retries("/router_failure", timeout=30) - assert response.text == "hello1" - - def function(): - return "hello2" - - serve.create_backend("router_failure:v2", function) - serve.set_traffic("router_failure", {"router_failure:v2": 1.0}) - - for _ in range(10): - response = request_with_retries("/router_failure", timeout=30) - assert response.text == "hello2" - - def _get_worker_handles(backend): master_actor = serve.api._get_master_actor() backend_dict = ray.get(master_actor.get_all_worker_handles.remote()) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 54b29c60a..1954d46de 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -48,6 +48,8 @@ def task_runner_mock_actor(): async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() + await q.setup.remote() + q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0})) q.add_new_worker.remote("backend-single-prod", "replica-1", task_runner_mock_actor) @@ -64,6 +66,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): async def test_slo(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() + await q.setup.remote() await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0})) all_request_sent = [] @@ -88,6 +91,7 @@ async def test_slo(serve_instance, task_runner_mock_actor): async def test_alter_backend(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() + await q.setup.remote() await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1})) await q.add_new_worker.remote("backend-alter", "replica-1", @@ -106,6 +110,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor): async def test_split_traffic_random(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() + await q.setup.remote() await q.set_traffic.remote( "svc", TrafficPolicy({ @@ -135,6 +140,7 @@ async def test_queue_remove_replicas(serve_instance): temp_actor = mock_task_runner() q = ray.remote(TestRouter).remote() + await q.setup.remote() await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor) await q.remove_worker.remote("backend-remove", "replica-1") assert ray.get(q.worker_queue_size.remote("backend")) == 0 @@ -142,6 +148,7 @@ async def test_queue_remove_replicas(serve_instance): async def test_shard_key(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() + await q.setup.remote() num_backends = 5 traffic_dict = {} @@ -196,6 +203,7 @@ async def test_router_use_max_concurrency(serve_instance): worker = MockWorker.remote() q = ray.remote(VisibleRouter).remote() + await q.setup.remote() BACKEND_NAME = "max-concurrent-test" config = BackendConfig({"max_concurrent_queries": 1}) await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))