mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:51:09 +08:00
[serve] Detect node updates (#9828)
This commit is contained in:
@@ -367,8 +367,10 @@ def get_handle(endpoint_name,
|
||||
if not missing_ok:
|
||||
assert endpoint_name in ray.get(controller.get_all_endpoints.remote())
|
||||
|
||||
# TODO(edoakes): we should choose the router on the same node.
|
||||
routers = ray.get(controller.get_routers.remote())
|
||||
return RayServeHandle(
|
||||
ray.get(controller.get_router.remote())[0],
|
||||
list(routers.values())[0],
|
||||
endpoint_name,
|
||||
relative_slo_ms,
|
||||
absolute_slo_ms,
|
||||
|
||||
+107
-62
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from collections import defaultdict, namedtuple
|
||||
from itertools import groupby
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
@@ -15,7 +14,7 @@ from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.metric.exporter import MetricExporterActor
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes)
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -28,6 +27,9 @@ CHECKPOINT_KEY = "serve-controller-checkpoint"
|
||||
# error if the desired replicas exceed current resource availability.
|
||||
_RESOURCE_CHECK_ENABLED = True
|
||||
|
||||
# How often to call the control loop on the controller.
|
||||
CONTROL_LOOP_PERIOD_S = 1.0
|
||||
|
||||
|
||||
class TrafficPolicy:
|
||||
def __init__(self, traffic_dict):
|
||||
@@ -88,7 +90,7 @@ class ServeController:
|
||||
requires all implementations here to be idempotent.
|
||||
"""
|
||||
|
||||
async def __init__(self, instance_name, http_proxy_host, http_proxy_port,
|
||||
async def __init__(self, instance_name, http_host, http_port,
|
||||
metric_exporter_class):
|
||||
# Unique name of the serve instance managed by this actor. Used to
|
||||
# namespace child actors and checkpoints.
|
||||
@@ -122,13 +124,17 @@ class ServeController:
|
||||
self.write_lock = asyncio.Lock()
|
||||
|
||||
# Cached handles to actors in the system.
|
||||
self.routers = []
|
||||
# node_id -> actor_handle
|
||||
self.routers = dict()
|
||||
self.metric_exporter = None
|
||||
|
||||
self.http_host = http_host
|
||||
self.http_port = http_port
|
||||
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self._get_or_start_metric_exporter(metric_exporter_class)
|
||||
self._get_or_start_routers(http_proxy_host, http_proxy_port)
|
||||
self._start_metric_exporter(metric_exporter_class)
|
||||
self._start_routers_if_needed()
|
||||
|
||||
# NOTE(edoakes): unfortunately, we can't completely recover from a
|
||||
# checkpoint in the constructor because we block while waiting for
|
||||
@@ -150,46 +156,69 @@ class ServeController:
|
||||
asyncio.get_event_loop().create_task(
|
||||
self._recover_from_checkpoint(checkpoint))
|
||||
|
||||
def _get_or_start_routers(self, host, port):
|
||||
"""Get the HTTP proxy belonging to this serve instance.
|
||||
asyncio.get_event_loop().create_task(self.run_control_loop())
|
||||
|
||||
If the HTTP proxy does not already exist, it will be started.
|
||||
def _start_routers_if_needed(self):
    """Start a router on every node if it doesn't already exist.

    Walks the ``(node_id, node_resource)`` pairs returned by
    ``get_all_node_ids()``. Nodes that already have an entry in
    ``self.routers`` are skipped. For every other node the named actor is
    looked up first; a new ``HTTPProxyActor`` is only created when that
    lookup raises ``ValueError``. Either way the handle is cached in
    ``self.routers`` under ``node_id``.
    """
    for node_id, node_resource in get_all_node_ids():
        if node_id in self.routers:
            continue

        router_name = format_actor_name(SERVE_PROXY_NAME,
                                        self.instance_name, node_id)
        try:
            # Reuse the actor if it is already registered under this name.
            router = ray.get_actor(router_name)
        except ValueError:
            logger.info("Starting router with name '{}' on node '{}' "
                        "listening on '{}:{}'".format(
                            router_name, node_id, self.http_host,
                            self.http_port))
            router = HTTPProxyActor.options(
                name=router_name,
                max_concurrency=ASYNC_CONCURRENCY,
                max_restarts=-1,
                max_task_retries=-1,
                # The per-node resource is used as a placement requirement
                # for the actor.
                resources={
                    node_resource: 0.01
                },
            ).remote(
                self.http_host,
                self.http_port,
                instance_name=self.instance_name)

        self.routers[node_id] = router
|
||||
|
||||
def _stop_routers_if_needed(self):
    """Removes router actors from any nodes that no longer exist.

    Returns whether or not any actors were removed (a checkpoint should
    be taken).
    """
    checkpoint_required = False
    all_node_ids = {node_id for node_id, _ in get_all_node_ids()}

    # Collect the stale entries first; self.routers must not be mutated
    # while it is being iterated.
    to_stop = []
    for node_id in self.routers:
        if node_id not in all_node_ids:
            logger.info(
                "Removing router on removed node '{}'.".format(node_id))
            to_stop.append(node_id)

    for node_id in to_stop:
        router_handle = self.routers.pop(node_id)
        ray.kill(router_handle, no_restart=True)
        checkpoint_required = True

    return checkpoint_required
|
||||
|
||||
def get_routers(self):
    """Returns a dictionary of node ID to router actor handles."""
    return self.routers
|
||||
|
||||
def get_router_config(self):
    """Called by the router on startup to fetch required state.

    Returns ``self.routes``.
    """
    return self.routes
|
||||
|
||||
def _get_or_start_metric_exporter(self, metric_exporter_class):
|
||||
def _start_metric_exporter(self, metric_exporter_class):
|
||||
"""Get the metric exporter belonging to this serve instance.
|
||||
|
||||
If the metric exporter does not already exist, it will be started.
|
||||
@@ -210,11 +239,13 @@ class ServeController:
|
||||
|
||||
def _checkpoint(self):
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
assert self.write_lock.locked()
|
||||
logger.debug("Writing checkpoint")
|
||||
start = time.time()
|
||||
checkpoint = pickle.dumps(
|
||||
(self.routes, self.backends, self.traffic_policies, self.replicas,
|
||||
self.replicas_to_start, self.replicas_to_stop,
|
||||
(self.routes, list(
|
||||
self.routers.keys()), self.backends, self.traffic_policies,
|
||||
self.replicas, self.replicas_to_start, self.replicas_to_stop,
|
||||
self.backends_to_remove, self.endpoints_to_remove))
|
||||
|
||||
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
|
||||
@@ -229,7 +260,7 @@ class ServeController:
|
||||
|
||||
Performs the following operations:
|
||||
1) Deserializes the internal state from the checkpoint.
|
||||
2) Pushes the latest configuration to the HTTP proxy and router
|
||||
2) Pushes the latest configuration to the routers
|
||||
in case we crashed before updating them.
|
||||
3) Starts/stops any worker replicas that are pending creation or
|
||||
deletion.
|
||||
@@ -245,6 +276,7 @@ class ServeController:
|
||||
# Load internal state from the checkpoint data.
|
||||
(
|
||||
self.routes,
|
||||
router_node_ids,
|
||||
self.backends,
|
||||
self.traffic_policies,
|
||||
self.replicas,
|
||||
@@ -254,6 +286,11 @@ class ServeController:
|
||||
self.endpoints_to_remove,
|
||||
) = pickle.loads(checkpoint_bytes)
|
||||
|
||||
for node_id in router_node_ids:
|
||||
router_name = format_actor_name(SERVE_PROXY_NAME,
|
||||
self.instance_name, node_id)
|
||||
self.routers[node_id] = ray.get_actor(router_name)
|
||||
|
||||
# Fetch actor handles for all of the backend replicas in the system.
|
||||
# All of these workers are guaranteed to already exist because they
|
||||
# would not be written to a checkpoint in self.workers until they
|
||||
@@ -270,7 +307,7 @@ class ServeController:
|
||||
for endpoint, traffic_policy in self.traffic_policies.items():
|
||||
await asyncio.gather(*[
|
||||
router.set_traffic.remote(endpoint, traffic_policy)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
for backend_tag, replica_dict in self.workers.items():
|
||||
@@ -278,20 +315,20 @@ class ServeController:
|
||||
await asyncio.gather(*[
|
||||
router.add_new_worker.remote(backend_tag, replica_tag,
|
||||
worker)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
for backend, info in self.backends.items():
|
||||
await asyncio.gather(*[
|
||||
router.set_backend_config.remote(backend, info.backend_config)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
await self.broadcast_backend_config(backend)
|
||||
|
||||
# Push configuration state to the HTTP proxy.
|
||||
# Push configuration state to the routers.
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.routes)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
# Start/stop any pending backend replicas.
|
||||
@@ -307,6 +344,16 @@ class ServeController:
|
||||
|
||||
self.write_lock.release()
|
||||
|
||||
async def run_control_loop(self):
    """Periodically reconcile the set of routers with the cluster's nodes.

    Runs forever. On each iteration, while holding ``self.write_lock``, it
    starts routers on nodes that lack one and removes routers whose node
    is gone, writing a checkpoint only when a router was removed. It then
    sleeps for ``CONTROL_LOOP_PERIOD_S`` seconds.
    """
    while True:
        async with self.write_lock:
            self._start_routers_if_needed()
            checkpoint_required = self._stop_routers_if_needed()
            if checkpoint_required:
                self._checkpoint()

        # Sleep outside the lock so other writers can make progress
        # between iterations.
        await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
|
||||
|
||||
def get_backend_configs(self):
|
||||
"""Fetched by the router on startup."""
|
||||
backend_configs = {}
|
||||
@@ -368,7 +415,7 @@ class ServeController:
|
||||
await asyncio.gather(*[
|
||||
router.add_new_worker.remote(backend_tag, replica_tag,
|
||||
worker_handle)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def _start_pending_replicas(self):
|
||||
@@ -409,7 +456,7 @@ class ServeController:
|
||||
# Remove the replica from router. This call is idempotent.
|
||||
await asyncio.gather(*[
|
||||
router.remove_worker.remote(backend_tag, replica_tag)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
# TODO(edoakes): this logic isn't ideal because there may be
|
||||
@@ -429,7 +476,7 @@ class ServeController:
|
||||
for backend_tag in self.backends_to_remove:
|
||||
await asyncio.gather(*[
|
||||
router.remove_backend.remote(backend_tag)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
self.backends_to_remove.clear()
|
||||
|
||||
@@ -441,7 +488,7 @@ class ServeController:
|
||||
for endpoint_tag in self.endpoints_to_remove:
|
||||
await asyncio.gather(*[
|
||||
router.remove_endpoint.remote(endpoint_tag)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
self.endpoints_to_remove.clear()
|
||||
|
||||
@@ -558,7 +605,7 @@ class ServeController:
|
||||
self._checkpoint()
|
||||
await asyncio.gather(*[
|
||||
router.set_traffic.remote(endpoint_name, traffic_policy)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def set_traffic(self, endpoint_name, traffic_dict):
|
||||
@@ -590,14 +637,14 @@ class ServeController:
|
||||
router.set_traffic.remote(
|
||||
endpoint_name,
|
||||
self.traffic_policies[endpoint_name],
|
||||
) for router in self.routers
|
||||
) for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def create_endpoint(self, endpoint, traffic_dict, route, methods):
|
||||
"""Create a new endpoint with the specified route and methods.
|
||||
|
||||
If the route is None, this is a "headless" endpoint that will not
|
||||
be added to the HTTP proxy (can only be accessed via a handle).
|
||||
be exposed over HTTP and can only be accessed via a handle.
|
||||
"""
|
||||
async with self.write_lock:
|
||||
# If this is a headless endpoint with no route, key the endpoint
|
||||
@@ -632,7 +679,7 @@ class ServeController:
|
||||
await self._set_traffic(endpoint, traffic_dict)
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.routes)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def delete_endpoint(self, endpoint):
|
||||
@@ -662,15 +709,13 @@ class ServeController:
|
||||
self.endpoints_to_remove.append(endpoint)
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# updates to the HTTP proxy and router to avoid inconsistent state
|
||||
# if we crash after pushing the update.
|
||||
# updates to the routers to avoid inconsistent state if we crash
|
||||
# after pushing the update.
|
||||
self._checkpoint()
|
||||
|
||||
# Update the HTTP proxy first to ensure no new requests for the
|
||||
# endpoint are sent to the router.
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.routes)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
await self._remove_pending_endpoints()
|
||||
|
||||
@@ -698,7 +743,7 @@ class ServeController:
|
||||
# (particularly for max-batch-size).
|
||||
await asyncio.gather(*[
|
||||
router.set_backend_config.remote(backend_tag, backend_config)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
@@ -757,7 +802,7 @@ class ServeController:
|
||||
# (particularly for setting max_batch_size).
|
||||
await asyncio.gather(*[
|
||||
router.set_backend_config.remote(backend_tag, backend_config)
|
||||
for router in self.routers
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
await self._start_pending_replicas()
|
||||
@@ -788,7 +833,7 @@ class ServeController:
|
||||
async def shutdown(self):
|
||||
"""Shuts down the serve instance completely."""
|
||||
async with self.write_lock:
|
||||
for router in self.routers:
|
||||
for router in self.routers.values():
|
||||
ray.kill(router, no_restart=True)
|
||||
ray.kill(self.metric_exporter, no_restart=True)
|
||||
for replica_dict in self.workers.values():
|
||||
|
||||
@@ -76,8 +76,8 @@ def test_controller_failure(serve_instance):
|
||||
|
||||
|
||||
def _kill_routers():
    """Kill every router known to the controller, allowing restarts."""
    routers = ray.get(serve.api._get_controller().get_routers.remote())
    # get_routers returns a node_id -> actor handle mapping.
    for router in routers.values():
        ray.kill(router, no_restart=False)
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,10 @@ import pytest
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.serve.constants import SERVE_PROXY_NAME
|
||||
from ray.serve.utils import block_until_http_ready
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.test_utils import wait_for_condition
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -24,19 +25,56 @@ def test_multiple_routers():
|
||||
assert len(node_ids) == 2
|
||||
serve.init(http_port=8005)
|
||||
|
||||
# two actors should be started
|
||||
head_http = ray.get_actor(SERVE_PROXY_NAME +
|
||||
"-{}-{}".format(node_ids[0], 0))
|
||||
ray.get_actor(SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], 1))
|
||||
def actor_name(index):
|
||||
return SERVE_PROXY_NAME + "-{}-{}".format(node_ids[0], index)
|
||||
|
||||
# wait for the actors to come up
|
||||
# Two actors should be started.
|
||||
def get_first_two_actors():
|
||||
try:
|
||||
ray.get_actor(actor_name(0))
|
||||
ray.get_actor(actor_name(1))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
wait_for_condition(get_first_two_actors)
|
||||
|
||||
# Wait for the actors to come up.
|
||||
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
|
||||
|
||||
# kill the head_http server, the HTTP server should still functions
|
||||
ray.kill(head_http, no_restart=True)
|
||||
# Kill one of the servers, the HTTP server should still function.
|
||||
ray.kill(ray.get_actor(actor_name(0)), no_restart=True)
|
||||
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
|
||||
|
||||
# cleanup the nodes (otherwise Ray will segfault)
|
||||
# Add a new node to the cluster. This should trigger a new router to get
|
||||
# started.
|
||||
new_node = cluster.add_node()
|
||||
|
||||
def get_third_actor():
|
||||
try:
|
||||
ray.get_actor(actor_name(2))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
wait_for_condition(get_third_actor)
|
||||
|
||||
# Remove the newly-added node from the cluster. The corresponding actor
|
||||
# should be removed as well.
|
||||
cluster.remove_node(new_node)
|
||||
|
||||
def third_actor_removed():
|
||||
try:
|
||||
ray.get_actor(actor_name(2))
|
||||
return False
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
# Check that the actor is gone and the HTTP server still functions.
|
||||
wait_for_condition(third_actor_removed)
|
||||
ray.get(block_until_http_ready.remote("http://127.0.0.1:8005/-/routes"))
|
||||
|
||||
# Clean up the nodes (otherwise Ray will segfault).
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from functools import singledispatch
|
||||
from itertools import groupby
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
@@ -104,11 +105,16 @@ def get_random_letters(length=6):
|
||||
return "".join(random.choices(string.ascii_letters, k=length))
|
||||
|
||||
|
||||
def format_actor_name(actor_name, instance_name=None, *modifiers):
    """Build the name for an actor.

    Args:
        actor_name: Base name of the actor.
        instance_name: Optional instance name. When not None, the result
            is prefixed as ``"<instance_name>:<actor_name>"``.
        *modifiers: Extra values, each appended in order as ``"-<value>"``.

    Returns:
        The formatted name string.
    """
    if instance_name is None:
        name = actor_name
    else:
        name = "{}:{}".format(instance_name, actor_name)

    for modifier in modifiers:
        name += "-{}".format(modifier)

    return name
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -216,3 +222,23 @@ def try_schedule_resources_on_nodes(
|
||||
successfully_scheduled.append(False)
|
||||
|
||||
return successfully_scheduled
|
||||
|
||||
|
||||
def get_all_node_ids():
    """Get IDs for all nodes in the cluster.

    Handles multiple nodes on the same IP by appending an index to the
    node_id, e.g., 'node_id-index'.

    Returns a list of ('node_id-index', 'node_id') tuples (the latter can be
    used as a resource requirement for actor placements).
    """
    # We need to use the node_id and index here because we could
    # have multiple virtual nodes on the same host. In that case
    # they will have the same IP and therefore node_id.
    # groupby only groups adjacent equal items, hence the sort.
    return [
        ("{}-{}".format(node_id, index), node_id)
        for _, node_id_group in groupby(sorted(ray.state.node_ids()))
        for index, node_id in enumerate(node_id_group)
    ]
|
||||
|
||||
Reference in New Issue
Block a user