[serve] Detect node updates (#9828)

This commit is contained in:
Edward Oakes
2020-08-04 15:57:21 -05:00
committed by GitHub
parent ef190f358b
commit 55146d222f
5 changed files with 188 additions and 77 deletions
+3 -1
View File
@@ -367,8 +367,10 @@ def get_handle(endpoint_name,
if not missing_ok:
assert endpoint_name in ray.get(controller.get_all_endpoints.remote())
# TODO(edoakes): we should choose the router on the same node.
routers = ray.get(controller.get_routers.remote())
return RayServeHandle(
ray.get(controller.get_router.remote())[0],
list(routers.values())[0],
endpoint_name,
relative_slo_ms,
absolute_slo_ms,
+107 -62
View File
@@ -1,6 +1,5 @@
import asyncio
from collections import defaultdict, namedtuple
from itertools import groupby
import os
import random
import time
@@ -15,7 +14,7 @@ from ray.serve.kv_store import RayInternalKVStore
from ray.serve.metric.exporter import MetricExporterActor
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes)
try_schedule_resources_on_nodes, get_all_node_ids)
import numpy as np
@@ -28,6 +27,9 @@ CHECKPOINT_KEY = "serve-controller-checkpoint"
# error if the desired replicas exceed current resource availability.
_RESOURCE_CHECK_ENABLED = True
# How often to call the control loop on the controller.
CONTROL_LOOP_PERIOD_S = 1.0
class TrafficPolicy:
def __init__(self, traffic_dict):
@@ -88,7 +90,7 @@ class ServeController:
requires all implementations here to be idempotent.
"""
async def __init__(self, instance_name, http_proxy_host, http_proxy_port,
async def __init__(self, instance_name, http_host, http_port,
metric_exporter_class):
# Unique name of the serve instance managed by this actor. Used to
# namespace child actors and checkpoints.
@@ -122,13 +124,17 @@ class ServeController:
self.write_lock = asyncio.Lock()
# Cached handles to actors in the system.
self.routers = []
# node_id -> actor_handle
self.routers = dict()
self.metric_exporter = None
self.http_host = http_host
self.http_port = http_port
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self._get_or_start_metric_exporter(metric_exporter_class)
self._get_or_start_routers(http_proxy_host, http_proxy_port)
self._start_metric_exporter(metric_exporter_class)
self._start_routers_if_needed()
# NOTE(edoakes): unfortunately, we can't completely recover from a
# checkpoint in the constructor because we block while waiting for
@@ -150,46 +156,69 @@ class ServeController:
asyncio.get_event_loop().create_task(
self._recover_from_checkpoint(checkpoint))
def _get_or_start_routers(self, host, port):
"""Get the HTTP proxy belonging to this serve instance.
asyncio.get_event_loop().create_task(self.run_control_loop())
If the HTTP proxy does not already exist, it will be started.
def _start_routers_if_needed(self):
"""Start a router on every node if it doesn't already exist."""
for node_id, node_resource in get_all_node_ids():
if node_id in self.routers:
continue
router_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name, node_id)
try:
router = ray.get_actor(router_name)
except ValueError:
logger.info("Starting router with name '{}' on node '{}' "
"listening on '{}:{}'".format(
router_name, node_id, self.http_host,
self.http_port))
router = HTTPProxyActor.options(
name=router_name,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
resources={
node_resource: 0.01
},
).remote(
self.http_host,
self.http_port,
instance_name=self.instance_name)
self.routers[node_id] = router
def _stop_routers_if_needed(self):
"""Removes router actors from any nodes that no longer exist.
Returns whether or not any actors were removed (a checkpoint should
be taken).
"""
# TODO(simon): We don't handle nodes being added/removed. To do that,
# we should implement some sort of control loop in master actor.
for _, node_id_group in groupby(sorted(ray.state.node_ids())):
for index, node_id in enumerate(node_id_group):
proxy_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name)
proxy_name += "-{}-{}".format(node_id, index)
try:
router = ray.get_actor(proxy_name)
except ValueError:
logger.info(
"Starting HTTP proxy with name '{}' on node '{}' "
"listening on port {}".format(proxy_name, node_id,
port))
router = HTTPProxyActor.options(
name=proxy_name,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
resources={
node_id: 0.01
},
).remote(
host, port, instance_name=self.instance_name)
self.routers.append(router)
checkpoint_required = False
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
to_stop = []
for node_id in self.routers:
if node_id not in all_node_ids:
logger.info(
"Removing router on removed node '{}'.".format(node_id))
to_stop.append(node_id)
def get_router(self):
"""Returns a handle to the HTTP proxy managed by this actor."""
for node_id in to_stop:
router_handle = self.routers.pop(node_id)
ray.kill(router_handle, no_restart=True)
checkpoint_required = True
return checkpoint_required
def get_routers(self):
"""Returns a dictionary of node ID to router actor handles."""
return self.routers
def get_router_config(self):
"""Called by the HTTP proxy on startup to fetch required state."""
"""Called by the router on startup to fetch required state."""
return self.routes
def _get_or_start_metric_exporter(self, metric_exporter_class):
def _start_metric_exporter(self, metric_exporter_class):
"""Get the metric exporter belonging to this serve instance.
If the metric exporter does not already exist, it will be started.
@@ -210,11 +239,13 @@ class ServeController:
def _checkpoint(self):
"""Checkpoint internal state and write it to the KV store."""
assert self.write_lock.locked()
logger.debug("Writing checkpoint")
start = time.time()
checkpoint = pickle.dumps(
(self.routes, self.backends, self.traffic_policies, self.replicas,
self.replicas_to_start, self.replicas_to_stop,
(self.routes, list(
self.routers.keys()), self.backends, self.traffic_policies,
self.replicas, self.replicas_to_start, self.replicas_to_stop,
self.backends_to_remove, self.endpoints_to_remove))
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
@@ -229,7 +260,7 @@ class ServeController:
Performs the following operations:
1) Deserializes the internal state from the checkpoint.
2) Pushes the latest configuration to the HTTP proxy and router
2) Pushes the latest configuration to the routers
in case we crashed before updating them.
3) Starts/stops any worker replicas that are pending creation or
deletion.
@@ -245,6 +276,7 @@ class ServeController:
# Load internal state from the checkpoint data.
(
self.routes,
router_node_ids,
self.backends,
self.traffic_policies,
self.replicas,
@@ -254,6 +286,11 @@ class ServeController:
self.endpoints_to_remove,
) = pickle.loads(checkpoint_bytes)
for node_id in router_node_ids:
router_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name, node_id)
self.routers[node_id] = ray.get_actor(router_name)
# Fetch actor handles for all of the backend replicas in the system.
# All of these workers are guaranteed to already exist because they
# would not be written to a checkpoint in self.workers until they
@@ -270,7 +307,7 @@ class ServeController:
for endpoint, traffic_policy in self.traffic_policies.items():
await asyncio.gather(*[
router.set_traffic.remote(endpoint, traffic_policy)
for router in self.routers
for router in self.routers.values()
])
for backend_tag, replica_dict in self.workers.items():
@@ -278,20 +315,20 @@ class ServeController:
await asyncio.gather(*[
router.add_new_worker.remote(backend_tag, replica_tag,
worker)
for router in self.routers
for router in self.routers.values()
])
for backend, info in self.backends.items():
await asyncio.gather(*[
router.set_backend_config.remote(backend, info.backend_config)
for router in self.routers
for router in self.routers.values()
])
await self.broadcast_backend_config(backend)
# Push configuration state to the HTTP proxy.
# Push configuration state to the routers.
await asyncio.gather(*[
router.set_route_table.remote(self.routes)
for router in self.routers
for router in self.routers.values()
])
# Start/stop any pending backend replicas.
@@ -307,6 +344,16 @@ class ServeController:
self.write_lock.release()
async def run_control_loop(self):
while True:
async with self.write_lock:
self._start_routers_if_needed()
checkpoint_required = self._stop_routers_if_needed()
if checkpoint_required:
self._checkpoint()
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
def get_backend_configs(self):
"""Fetched by the router on startup."""
backend_configs = {}
@@ -368,7 +415,7 @@ class ServeController:
await asyncio.gather(*[
router.add_new_worker.remote(backend_tag, replica_tag,
worker_handle)
for router in self.routers
for router in self.routers.values()
])
async def _start_pending_replicas(self):
@@ -409,7 +456,7 @@ class ServeController:
# Remove the replica from router. This call is idempotent.
await asyncio.gather(*[
router.remove_worker.remote(backend_tag, replica_tag)
for router in self.routers
for router in self.routers.values()
])
# TODO(edoakes): this logic isn't ideal because there may be
@@ -429,7 +476,7 @@ class ServeController:
for backend_tag in self.backends_to_remove:
await asyncio.gather(*[
router.remove_backend.remote(backend_tag)
for router in self.routers
for router in self.routers.values()
])
self.backends_to_remove.clear()
@@ -441,7 +488,7 @@ class ServeController:
for endpoint_tag in self.endpoints_to_remove:
await asyncio.gather(*[
router.remove_endpoint.remote(endpoint_tag)
for router in self.routers
for router in self.routers.values()
])
self.endpoints_to_remove.clear()
@@ -558,7 +605,7 @@ class ServeController:
self._checkpoint()
await asyncio.gather(*[
router.set_traffic.remote(endpoint_name, traffic_policy)
for router in self.routers
for router in self.routers.values()
])
async def set_traffic(self, endpoint_name, traffic_dict):
@@ -590,14 +637,14 @@ class ServeController:
router.set_traffic.remote(
endpoint_name,
self.traffic_policies[endpoint_name],
) for router in self.routers
) for router in self.routers.values()
])
async def create_endpoint(self, endpoint, traffic_dict, route, methods):
"""Create a new endpoint with the specified route and methods.
If the route is None, this is a "headless" endpoint that will not
be added to the HTTP proxy (can only be accessed via a handle).
be exposed over HTTP and can only be accessed via a handle.
"""
async with self.write_lock:
# If this is a headless endpoint with no route, key the endpoint
@@ -632,7 +679,7 @@ class ServeController:
await self._set_traffic(endpoint, traffic_dict)
await asyncio.gather(*[
router.set_route_table.remote(self.routes)
for router in self.routers
for router in self.routers.values()
])
async def delete_endpoint(self, endpoint):
@@ -662,15 +709,13 @@ class ServeController:
self.endpoints_to_remove.append(endpoint)
# NOTE(edoakes): we must write a checkpoint before pushing the
# updates to the HTTP proxy and router to avoid inconsistent state
# if we crash after pushing the update.
# updates to the routers to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
# Update the HTTP proxy first to ensure no new requests for the
# endpoint are sent to the router.
await asyncio.gather(*[
router.set_route_table.remote(self.routes)
for router in self.routers
for router in self.routers.values()
])
await self._remove_pending_endpoints()
@@ -698,7 +743,7 @@ class ServeController:
# (particularly for max-batch-size).
await asyncio.gather(*[
router.set_backend_config.remote(backend_tag, backend_config)
for router in self.routers
for router in self.routers.values()
])
await self.broadcast_backend_config(backend_tag)
@@ -757,7 +802,7 @@ class ServeController:
# (particularly for setting max_batch_size).
await asyncio.gather(*[
router.set_backend_config.remote(backend_tag, backend_config)
for router in self.routers
for router in self.routers.values()
])
await self._start_pending_replicas()
@@ -788,7 +833,7 @@ class ServeController:
async def shutdown(self):
"""Shuts down the serve instance completely."""
async with self.write_lock:
for router in self.routers:
for router in self.routers.values():
ray.kill(router, no_restart=True)
ray.kill(self.metric_exporter, no_restart=True)
for replica_dict in self.workers.values():
+2 -2
View File
@@ -76,8 +76,8 @@ def test_controller_failure(serve_instance):
def _kill_routers():
routers = ray.get(serve.api._get_controller().get_router.remote())
for router in routers:
routers = ray.get(serve.api._get_controller().get_routers.remote())
for router in routers.values():
ray.kill(router, no_restart=False)
+47 -9
View File
@@ -5,9 +5,10 @@ import pytest
import ray
from ray import serve
from ray.cluster_utils import Cluster
from ray.serve.constants import SERVE_PROXY_NAME
from ray.serve.utils import block_until_http_ready
from ray.cluster_utils import Cluster
from ray.test_utils import wait_for_condition
@pytest.mark.skipif(
@@ -24,19 +25,56 @@ def test_multiple_routers():
assert len(node_ids) == 2
serve.init(http_port=8005)
# two actors should be started
head_http = ray.get_actor(SERVE_PROXY_NAME +
"-{}-{}".format(node_ids[0], 0))
ray.get_actor(SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], 1))
def actor_name(index):
return SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], index)
# wait for the actors to come up
# Two actors should be started.
def get_first_two_actors():
try:
ray.get_actor(actor_name(0))
ray.get_actor(actor_name(1))
return True
except ValueError:
return False
wait_for_condition(get_first_two_actors)
# Wait for the actors to come up.
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
# kill the head_http server, the HTTP server should still functions
ray.kill(head_http, no_restart=True)
# Kill one of the servers, the HTTP server should still function.
ray.kill(ray.get_actor(actor_name(0)), no_restart=True)
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
# cleanup the nodes (otherwise Ray will segfault)
# Add a new node to the cluster. This should trigger a new router to get
# started.
new_node = cluster.add_node()
def get_third_actor():
try:
ray.get_actor(actor_name(2))
return True
except ValueError:
return False
wait_for_condition(get_third_actor)
# Remove the newly-added node from the cluster. The corresponding actor
# should be removed as well.
cluster.remove_node(new_node)
def third_actor_removed():
try:
ray.get_actor(actor_name(2))
return False
except ValueError:
return True
# Check that the actor is gone and the HTTP server still functions.
wait_for_condition(third_actor_removed)
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
# Clean up the nodes (otherwise Ray will segfault).
ray.shutdown()
cluster.shutdown()
+29 -3
View File
@@ -1,5 +1,6 @@
import asyncio
from functools import singledispatch
from itertools import groupby
import json
import logging
import random
@@ -104,11 +105,16 @@ def get_random_letters(length=6):
return "".join(random.choices(string.ascii_letters, k=length))
def format_actor_name(actor_name, instance_name=None):
def format_actor_name(actor_name, instance_name=None, *modifiers):
if instance_name is None:
return actor_name
name = actor_name
else:
return "{}:{}".format(instance_name, actor_name)
name = "{}:{}".format(instance_name, actor_name)
for modifier in modifiers:
name += "-{}".format(modifier)
return name
@singledispatch
@@ -216,3 +222,23 @@ def try_schedule_resources_on_nodes(
successfully_scheduled.append(False)
return successfully_scheduled
def get_all_node_ids():
"""Get IDs for all nodes in the cluster.
Handles multiple nodes on the same IP by appending an index to the
node_id, e.g., 'node_id-index'.
Returns a list of ('node_id-index', 'node_id') tuples (the latter can be
used as a resource requirement for actor placements).
"""
node_ids = []
# We need to use the node_id and index here because we could
# have multiple virtual nodes on the same host. In that case
# they will have the same IP and therefore node_id.
for _, node_id_group in groupby(sorted(ray.state.node_ids())):
for index, node_id in enumerate(node_id_group):
node_ids.append(("{}-{}".format(node_id, index), node_id))
return node_ids