From 55146d222fbc04f6300b6e672c930272b0a88cfb Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Tue, 4 Aug 2020 15:57:21 -0500 Subject: [PATCH] [serve] Detect node updates (#9828) --- python/ray/serve/api.py | 4 +- python/ray/serve/controller.py | 169 ++++++++++++++++--------- python/ray/serve/tests/test_failure.py | 4 +- python/ray/serve/tests/test_scaling.py | 56 ++++++-- python/ray/serve/utils.py | 32 ++++- 5 files changed, 188 insertions(+), 77 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 6755134b6..db770c702 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -367,8 +367,10 @@ def get_handle(endpoint_name, if not missing_ok: assert endpoint_name in ray.get(controller.get_all_endpoints.remote()) + # TODO(edoakes): we should choose the router on the same node. + routers = ray.get(controller.get_routers.remote()) return RayServeHandle( - ray.get(controller.get_router.remote())[0], + list(routers.values())[0], endpoint_name, relative_slo_ms, absolute_slo_ms, diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index d16db021d..7592fa10b 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -1,6 +1,5 @@ import asyncio from collections import defaultdict, namedtuple -from itertools import groupby import os import random import time @@ -15,7 +14,7 @@ from ray.serve.kv_store import RayInternalKVStore from ray.serve.metric.exporter import MetricExporterActor from ray.serve.exceptions import RayServeException from ray.serve.utils import (format_actor_name, get_random_letters, logger, - try_schedule_resources_on_nodes) + try_schedule_resources_on_nodes, get_all_node_ids) import numpy as np @@ -28,6 +27,9 @@ CHECKPOINT_KEY = "serve-controller-checkpoint" # error if the desired replicas exceed current resource availability. _RESOURCE_CHECK_ENABLED = True +# How often to call the control loop on the controller. +CONTROL_LOOP_PERIOD_S = 1.0 + class TrafficPolicy: def __init__(self, traffic_dict): @@ -88,7 +90,7 @@ class ServeController: requires all implementations here to be idempotent. """ - async def __init__(self, instance_name, http_proxy_host, http_proxy_port, + async def __init__(self, instance_name, http_host, http_port, metric_exporter_class): # Unique name of the serve instance managed by this actor. Used to # namespace child actors and checkpoints. @@ -122,13 +124,17 @@ class ServeController: self.write_lock = asyncio.Lock() # Cached handles to actors in the system. - self.routers = [] + # node_id -> actor_handle + self.routers = dict() self.metric_exporter = None + self.http_host = http_host + self.http_port = http_port + # If starting the actor for the first time, starts up the other system # components. If recovering, fetches their actor handles. - self._get_or_start_metric_exporter(metric_exporter_class) - self._get_or_start_routers(http_proxy_host, http_proxy_port) + self._start_metric_exporter(metric_exporter_class) + self._start_routers_if_needed() # NOTE(edoakes): unfortunately, we can't completely recover from a # checkpoint in the constructor because we block while waiting for @@ -150,46 +156,69 @@ class ServeController: asyncio.get_event_loop().create_task( self._recover_from_checkpoint(checkpoint)) - def _get_or_start_routers(self, host, port): - """Get the HTTP proxy belonging to this serve instance. + asyncio.get_event_loop().create_task(self.run_control_loop()) - If the HTTP proxy does not already exist, it will be started. + def _start_routers_if_needed(self): + """Start a router on every node if it doesn't already exist.""" + for node_id, node_resource in get_all_node_ids(): + if node_id in self.routers: + continue + + router_name = format_actor_name(SERVE_PROXY_NAME, + self.instance_name, node_id) + try: + router = ray.get_actor(router_name) + except ValueError: + logger.info("Starting router with name '{}' on node '{}' " + "listening on '{}:{}'".format( + router_name, node_id, self.http_host, + self.http_port)) + router = HTTPProxyActor.options( + name=router_name, + max_concurrency=ASYNC_CONCURRENCY, + max_restarts=-1, + max_task_retries=-1, + resources={ + node_resource: 0.01 + }, + ).remote( + self.http_host, + self.http_port, + instance_name=self.instance_name) + + self.routers[node_id] = router + + def _stop_routers_if_needed(self): + """Removes router actors from any nodes that no longer exist. + + Returns whether or not any actors were removed (a checkpoint should + be taken). """ - # TODO(simon): We don't handle nodes being added/removed. To do that, - # we should implement some sort of control loop in master actor. - for _, node_id_group in groupby(sorted(ray.state.node_ids())): - for index, node_id in enumerate(node_id_group): - proxy_name = format_actor_name(SERVE_PROXY_NAME, - self.instance_name) - proxy_name += "-{}-{}".format(node_id, index) - try: - router = ray.get_actor(proxy_name) - except ValueError: - logger.info( - "Starting HTTP proxy with name '{}' on node '{}' " - "listening on port {}".format(proxy_name, node_id, - port)) - router = HTTPProxyActor.options( - name=proxy_name, - max_concurrency=ASYNC_CONCURRENCY, - max_restarts=-1, - max_task_retries=-1, - resources={ - node_id: 0.01 - }, - ).remote( - host, port, instance_name=self.instance_name) - self.routers.append(router) + checkpoint_required = False + all_node_ids = {node_id for node_id, _ in get_all_node_ids()} + to_stop = [] + for node_id in self.routers: + if node_id not in all_node_ids: + logger.info( + "Removing router on removed node '{}'.".format(node_id)) + to_stop.append(node_id) - def get_router(self): - """Returns a handle to the HTTP proxy managed by this actor.""" + for node_id in to_stop: + router_handle = self.routers.pop(node_id) + ray.kill(router_handle, no_restart=True) + checkpoint_required = True + + return checkpoint_required + + def get_routers(self): + """Returns a dictionary of node ID to router actor handles.""" return self.routers def get_router_config(self): - """Called by the HTTP proxy on startup to fetch required state.""" + """Called by the router on startup to fetch required state.""" return self.routes - def _get_or_start_metric_exporter(self, metric_exporter_class): + def _start_metric_exporter(self, metric_exporter_class): """Get the metric exporter belonging to this serve instance. If the metric exporter does not already exist, it will be started. @@ -210,11 +239,13 @@ class ServeController: def _checkpoint(self): """Checkpoint internal state and write it to the KV store.""" + assert self.write_lock.locked() logger.debug("Writing checkpoint") start = time.time() checkpoint = pickle.dumps( - (self.routes, self.backends, self.traffic_policies, self.replicas, - self.replicas_to_start, self.replicas_to_stop, + (self.routes, list( + self.routers.keys()), self.backends, self.traffic_policies, + self.replicas, self.replicas_to_start, self.replicas_to_stop, self.backends_to_remove, self.endpoints_to_remove)) self.kv_store.put(CHECKPOINT_KEY, checkpoint) @@ -229,7 +260,7 @@ class ServeController: Performs the following operations: 1) Deserializes the internal state from the checkpoint. - 2) Pushes the latest configuration to the HTTP proxy and router + 2) Pushes the latest configuration to the routers in case we crashed before updating them. 3) Starts/stops any worker replicas that are pending creation or deletion. @@ -245,6 +276,7 @@ class ServeController: # Load internal state from the checkpoint data. ( self.routes, + router_node_ids, self.backends, self.traffic_policies, self.replicas, @@ -254,6 +286,11 @@ class ServeController: self.endpoints_to_remove, ) = pickle.loads(checkpoint_bytes) + for node_id in router_node_ids: + router_name = format_actor_name(SERVE_PROXY_NAME, + self.instance_name, node_id) + self.routers[node_id] = ray.get_actor(router_name) + # Fetch actor handles for all of the backend replicas in the system. # All of these workers are guaranteed to already exist because they # would not be written to a checkpoint in self.workers until they @@ -270,7 +307,7 @@ class ServeController: for endpoint, traffic_policy in self.traffic_policies.items(): await asyncio.gather(*[ router.set_traffic.remote(endpoint, traffic_policy) - for router in self.routers + for router in self.routers.values() ]) for backend_tag, replica_dict in self.workers.items(): @@ -278,20 +315,20 @@ class ServeController: await asyncio.gather(*[ router.add_new_worker.remote(backend_tag, replica_tag, worker) - for router in self.routers + for router in self.routers.values() ]) for backend, info in self.backends.items(): await asyncio.gather(*[ router.set_backend_config.remote(backend, info.backend_config) - for router in self.routers + for router in self.routers.values() ]) await self.broadcast_backend_config(backend) - # Push configuration state to the HTTP proxy. + # Push configuration state to the routers. await asyncio.gather(*[ router.set_route_table.remote(self.routes) - for router in self.routers + for router in self.routers.values() ]) # Start/stop any pending backend replicas. @@ -307,6 +344,16 @@ class ServeController: self.write_lock.release() + async def run_control_loop(self): + while True: + async with self.write_lock: + self._start_routers_if_needed() + checkpoint_required = self._stop_routers_if_needed() + if checkpoint_required: + self._checkpoint() + + await asyncio.sleep(CONTROL_LOOP_PERIOD_S) + def get_backend_configs(self): """Fetched by the router on startup.""" backend_configs = {} @@ -368,7 +415,7 @@ class ServeController: await asyncio.gather(*[ router.add_new_worker.remote(backend_tag, replica_tag, worker_handle) - for router in self.routers + for router in self.routers.values() ]) async def _start_pending_replicas(self): @@ -409,7 +456,7 @@ class ServeController: # Remove the replica from router. This call is idempotent. await asyncio.gather(*[ router.remove_worker.remote(backend_tag, replica_tag) - for router in self.routers + for router in self.routers.values() ]) # TODO(edoakes): this logic isn't ideal because there may be @@ -429,7 +476,7 @@ class ServeController: for backend_tag in self.backends_to_remove: await asyncio.gather(*[ router.remove_backend.remote(backend_tag) - for router in self.routers + for router in self.routers.values() ]) self.backends_to_remove.clear() @@ -441,7 +488,7 @@ class ServeController: for endpoint_tag in self.endpoints_to_remove: await asyncio.gather(*[ router.remove_endpoint.remote(endpoint_tag) - for router in self.routers + for router in self.routers.values() ]) self.endpoints_to_remove.clear() @@ -558,7 +605,7 @@ class ServeController: self._checkpoint() await asyncio.gather(*[ router.set_traffic.remote(endpoint_name, traffic_policy) - for router in self.routers + for router in self.routers.values() ]) async def set_traffic(self, endpoint_name, traffic_dict): @@ -590,14 +637,14 @@ class ServeController: router.set_traffic.remote( endpoint_name, self.traffic_policies[endpoint_name], - ) for router in self.routers + ) for router in self.routers.values() ]) async def create_endpoint(self, endpoint, traffic_dict, route, methods): """Create a new endpoint with the specified route and methods. If the route is None, this is a "headless" endpoint that will not - be added to the HTTP proxy (can only be accessed via a handle). + be exposed over HTTP and can only be accessed via a handle. """ async with self.write_lock: # If this is a headless endpoint with no route, key the endpoint @@ -632,7 +679,7 @@ class ServeController: await self._set_traffic(endpoint, traffic_dict) await asyncio.gather(*[ router.set_route_table.remote(self.routes) - for router in self.routers + for router in self.routers.values() ]) async def delete_endpoint(self, endpoint): @@ -662,15 +709,13 @@ class ServeController: self.endpoints_to_remove.append(endpoint) # NOTE(edoakes): we must write a checkpoint before pushing the - # updates to the HTTP proxy and router to avoid inconsistent state - # if we crash after pushing the update. + # updates to the routers to avoid inconsistent state if we crash + # after pushing the update. self._checkpoint() - # Update the HTTP proxy first to ensure no new requests for the - # endpoint are sent to the router. await asyncio.gather(*[ router.set_route_table.remote(self.routes) - for router in self.routers + for router in self.routers.values() ]) await self._remove_pending_endpoints() @@ -698,7 +743,7 @@ class ServeController: # (particularly for max-batch-size). await asyncio.gather(*[ router.set_backend_config.remote(backend_tag, backend_config) - for router in self.routers + for router in self.routers.values() ]) await self.broadcast_backend_config(backend_tag) @@ -757,7 +802,7 @@ class ServeController: # (particularly for setting max_batch_size). await asyncio.gather(*[ router.set_backend_config.remote(backend_tag, backend_config) - for router in self.routers + for router in self.routers.values() ]) await self._start_pending_replicas() @@ -788,7 +833,7 @@ class ServeController: async def shutdown(self): """Shuts down the serve instance completely.""" async with self.write_lock: - for router in self.routers: + for router in self.routers.values(): ray.kill(router, no_restart=True) ray.kill(self.metric_exporter, no_restart=True) for replica_dict in self.workers.values(): diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 16db58cc5..4919013a6 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -76,8 +76,8 @@ def test_controller_failure(serve_instance): def _kill_routers(): - routers = ray.get(serve.api._get_controller().get_router.remote()) - for router in routers: + routers = ray.get(serve.api._get_controller().get_routers.remote()) + for router in routers.values(): ray.kill(router, no_restart=False) diff --git a/python/ray/serve/tests/test_scaling.py b/python/ray/serve/tests/test_scaling.py index a44575943..05b616fde 100644 --- a/python/ray/serve/tests/test_scaling.py +++ b/python/ray/serve/tests/test_scaling.py @@ -5,9 +5,10 @@ import pytest import ray from ray import serve +from ray.cluster_utils import Cluster from ray.serve.constants import SERVE_PROXY_NAME from ray.serve.utils import block_until_http_ready -from ray.cluster_utils import Cluster +from ray.test_utils import wait_for_condition @pytest.mark.skipif( @@ -24,19 +25,56 @@ def test_multiple_routers(): assert len(node_ids) == 2 serve.init(http_port=8005) - # two actors should be started - head_http = ray.get_actor(SERVE_PROXY_NAME + - "-{}-{}".format(node_ids[0], 0)) - ray.get_actor(SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], 1)) + def actor_name(index): + return SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], index) - # wait for the actors to come up + # Two actors should be started. + def get_first_two_actors(): + try: + ray.get_actor(actor_name(0)) + ray.get_actor(actor_name(1)) + return True + except ValueError: + return False + + wait_for_condition(get_first_two_actors) + + # Wait for the actors to come up. ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes")) - # kill the head_http server, the HTTP server should still functions - ray.kill(head_http, no_restart=True) + # Kill one of the servers, the HTTP server should still function. + ray.kill(ray.get_actor(actor_name(0)), no_restart=True) ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes")) - # cleanup the nodes (otherwise Ray will segfault) + # Add a new node to the cluster. This should trigger a new router to get + # started. + new_node = cluster.add_node() + + def get_third_actor(): + try: + ray.get_actor(actor_name(2)) + return True + except ValueError: + return False + + wait_for_condition(get_third_actor) + + # Remove the newly-added node from the cluster. The corresponding actor + # should be removed as well. + cluster.remove_node(new_node) + + def third_actor_removed(): + try: + ray.get_actor(actor_name(2)) + return False + except ValueError: + return True + + # Check that the actor is gone and the HTTP server still functions. + wait_for_condition(third_actor_removed) + ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes")) + + # Clean up the nodes (otherwise Ray will segfault). ray.shutdown() cluster.shutdown() diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index d940ad562..4cf46e2a7 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -1,5 +1,6 @@ import asyncio from functools import singledispatch +from itertools import groupby import json import logging import random @@ -104,11 +105,16 @@ def get_random_letters(length=6): return "".join(random.choices(string.ascii_letters, k=length)) -def format_actor_name(actor_name, instance_name=None): +def format_actor_name(actor_name, instance_name=None, *modifiers): if instance_name is None: - return actor_name + name = actor_name else: - return "{}:{}".format(instance_name, actor_name) + name = "{}:{}".format(instance_name, actor_name) + + for modifier in modifiers: + name += "-{}".format(modifier) + + return name @singledispatch @@ -216,3 +222,23 @@ def try_schedule_resources_on_nodes( successfully_scheduled.append(False) return successfully_scheduled + + +def get_all_node_ids(): + """Get IDs for all nodes in the cluster. + + Handles multiple nodes on the same IP by appending an index to the + node_id, e.g., 'node_id-index'. + + Returns a list of ('node_id-index', 'node_id') tuples (the latter can be + used as a resource requirement for actor placements). + """ + node_ids = [] + # We need to use the node_id and index here because we could + # have multiple virtual nodes on the same host. In that case + # they will have the same IP and therefore node_id. + for _, node_id_group in groupby(sorted(ray.state.node_ids())): + for index, node_id in enumerate(node_id_group): + node_ids.append(("{}-{}".format(node_id, index), node_id)) + + return node_ids