[serve] Master actor fault tolerance (#8116)

This commit is contained in:
Edward Oakes
2020-04-28 15:52:29 -05:00
committed by GitHub
parent ebdccde030
commit 7c0200c93b
18 changed files with 789 additions and 535 deletions
+14 -1
View File
@@ -13,11 +13,24 @@ py_library(
py_test(
name = "test_serve",
size = "medium",
srcs = glob(["tests/*.py"], exclude=["tests/test_nonblocking.py"]),
srcs = glob(["tests/*.py"],
exclude=["tests/test_nonblocking.py",
"tests/test_master_crashes.py"]),
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Runs test_api and test_failure with injected failures in the master actor.
py_test(
name = "test_master_crashes",
size = "medium",
srcs = glob(["tests/test_master_crashes.py",
"tests/test_api.py",
"tests/test_failure.py"],
exclude=["tests/test_nonblocking.py",
"tests/test_serve.py"]),
)
py_test(
name = "echo_full",
size = "small",
+19 -25
View File
@@ -10,7 +10,7 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
from ray.serve.master import ServeMaster
from ray.serve.handle import RayServeHandle
from ray.serve.kv_store_service import SQLiteKVStore
from ray.serve.utils import block_until_http_ready
from ray.serve.utils import block_until_http_ready, retry_actor_failures
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
from ray.serve.backend_config import BackendConfig
from ray.serve.policy import RoutePolicy
@@ -139,14 +139,11 @@ def init(
return SQLiteKVStore(namespace, db_path=kv_store_path)
master_actor = ServeMaster.options(
detached=True, name=SERVE_MASTER_NAME).remote(kv_store_connector)
ray.get(
master_actor.start_router.remote(queueing_policy.value, policy_kwargs))
ray.get(master_actor.start_metric_monitor.remote(gc_window_seconds))
if start_server:
ray.get(master_actor.start_http_proxy.remote(http_host, http_port))
detached=True,
name=SERVE_MASTER_NAME,
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
).remote(kv_store_connector, queueing_policy.value, policy_kwargs,
start_server, http_host, http_port, gc_window_seconds)
if start_server and blocking:
block_until_http_ready("http://{}:{}/-/routes".format(
@@ -165,9 +162,8 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]):
blocking (bool): If true, the function will wait for service to be
registered before returning
"""
ray.get(
master_actor.create_endpoint.remote(route, endpoint_name,
[m.upper() for m in methods]))
retry_actor_failures(master_actor.create_endpoint, route, endpoint_name,
[m.upper() for m in methods])
@_ensure_connected
@@ -178,8 +174,8 @@ def set_backend_config(backend_tag, backend_config):
backend_tag(str): A registered backend.
backend_config(BackendConfig) : Desired backend configuration.
"""
ray.get(
master_actor.set_backend_config.remote(backend_tag, backend_config))
retry_actor_failures(master_actor.set_backend_config, backend_tag,
backend_config)
@_ensure_connected
@@ -189,7 +185,7 @@ def get_backend_config(backend_tag):
Args:
backend_tag(str): A registered backend.
"""
return ray.get(master_actor.get_backend_config.remote(backend_tag))
return retry_actor_failures(master_actor.get_backend_config, backend_tag)
def _backend_accept_batch(func_or_class):
@@ -240,9 +236,8 @@ def create_backend(func_or_class,
if _backend_accept_batch(func_or_class):
backend_config.has_accept_batch_annotation = True
ray.get(
master_actor.create_backend.remote(backend_tag, backend_config,
func_or_class, actor_init_args))
retry_actor_failures(master_actor.create_backend, backend_tag,
backend_config, func_or_class, actor_init_args)
@_ensure_connected
@@ -261,9 +256,8 @@ def set_traffic(endpoint_name, traffic_policy_dictionary):
traffic_policy_dictionary (dict): a dictionary maps backend names
to their traffic weights. The weights must sum to 1.
"""
ray.get(
master_actor.set_traffic.remote(endpoint_name,
traffic_policy_dictionary))
retry_actor_failures(master_actor.set_traffic, endpoint_name,
traffic_policy_dictionary)
@_ensure_connected
@@ -286,11 +280,11 @@ def get_handle(endpoint_name,
RayServeHandle
"""
if not missing_ok:
assert endpoint_name in ray.get(
master_actor.get_all_endpoints.remote())
assert endpoint_name in retry_actor_failures(
master_actor.get_all_endpoints)
return RayServeHandle(
ray.get(master_actor.get_router.remote())[0],
retry_actor_failures(master_actor.get_router)[0],
endpoint_name,
relative_slo_ms,
absolute_slo_ms,
@@ -309,5 +303,5 @@ def stat(percentiles=[50, 90, 95],
The longest aggregation window must be shorter or equal to the
gc_window_seconds.
"""
[monitor] = ray.get(master_actor.get_metric_monitor.remote())
[monitor] = retry_actor_failures(master_actor.get_metric_monitor)
return ray.get(monitor.collect.remote(percentiles, agg_windows_seconds))
+1 -2
View File
@@ -55,9 +55,8 @@ class BackendConfig:
key = "num_replicas"
yield key, val
def get_actor_creation_args(self, init_args):
def get_actor_creation_args(self):
ret_d = deepcopy(self.__dict__)
for k in self._serve_configs:
ret_d.pop(k)
ret_d["args"] = init_args
return ret_d
+3 -14
View File
@@ -23,24 +23,14 @@ def create_backend_worker(func_or_class):
assert False, "func_or_class must be function or class."
class RayServeWrappedWorker(object):
def __init__(self,
backend_tag,
replica_tag,
init_args,
router_handle=None):
def __init__(self, backend_tag, replica_tag, init_args):
serve.init()
if is_function:
_callable = func_or_class
else:
_callable = func_or_class(*init_args)
if router_handle is None:
master_actor = serve.api._get_master_actor()
[router_handle] = ray.get(
master_actor.get_backend_worker_config.remote())
self.backend = RayServeWorker(backend_tag, _callable,
router_handle, is_function)
self.backend = RayServeWorker(backend_tag, _callable, is_function)
def get_metrics(self):
return self.backend.get_metrics()
@@ -76,10 +66,9 @@ def ensure_async(func):
class RayServeWorker:
"""Handles requests with the provided callable."""
def __init__(self, name, _callable, router_handle, is_function):
def __init__(self, name, _callable, is_function):
self.name = name
self.callable = _callable
self.router_handle = router_handle
self.is_function = is_function
self.error_counter = 0
+9 -3
View File
@@ -1,6 +1,15 @@
#: Actor name used to register master actor
SERVE_MASTER_NAME = "SERVE_MASTER_ACTOR"
#: Actor name used to register router actor
SERVE_ROUTER_NAME = "SERVE_ROUTER_ACTOR"
#: Actor name used to register HTTP proxy actor
SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR"
#: Actor name used to register metric monitor actor
SERVE_METRIC_MONITOR_NAME = "SERVE_METRIC_MONITOR_ACTOR"
#: HTTP Address
DEFAULT_HTTP_ADDRESS = "http://127.0.0.1:8000"
@@ -15,6 +24,3 @@ ASYNC_CONCURRENCY = int(1e6)
#: Default latency SLO
DEFAULT_LATENCY_SLO_MS = 1e9
#: Key for storing no http route services
NO_ROUTE_KEY = "NO_ROUTE"
+4 -3
View File
@@ -8,7 +8,7 @@ from ray.serve.constants import SERVE_MASTER_NAME
from ray.serve.context import TaskContext
from ray.serve.request_params import RequestMetadata
from ray.serve.http_util import Response
from ray.serve.utils import logger
from ray.serve.utils import logger, retry_actor_failures_async
from urllib.parse import parse_qs
@@ -29,8 +29,9 @@ class HTTPProxy:
async def fetch_config_from_master(self):
assert ray.is_initialized()
master = ray.util.get_actor(SERVE_MASTER_NAME)
self.route_table, [self.router_handle
] = await master.get_http_proxy_config.remote()
self.route_table, [
self.router_handle
] = await retry_actor_failures_async(master.get_http_proxy_config)
def set_route_table(self, route_table):
self.route_table = route_table
-130
View File
@@ -1,12 +1,8 @@
import json
import sqlite3
from abc import ABC
from typing import Union, List
from ray import cloudpickle as pickle
import ray.experimental.internal_kv as ray_kv
from ray.serve.utils import logger
from ray.serve.constants import NO_ROUTE_KEY
class NamespacedKVStore(ABC):
@@ -167,129 +163,3 @@ class SQLiteKVStore(NamespacedKVStore):
result = list(
cursor.execute("SELECT key, value FROM {}".format(self.namespace)))
return dict(result)
# Tables
class RoutingTable:
    """Route-to-service bookkeeping kept in two namespaced KV stores.

    ``routing_table`` maps each route to its service and additionally holds
    one JSON list of route-less services under ``NO_ROUTE_KEY``;
    ``methods_table`` maps each route to a JSON list of HTTP methods.
    """

    def __init__(self, kv_connector):
        """Open both stores by calling ``kv_connector`` with a namespace."""
        self.routing_table = kv_connector("routing_table")
        self.methods_table = kv_connector("methods_table")
        # Only read back by get_request_count(); this class never updates it.
        self.request_count = 0

    def register_service(self, route: Union[str, None], service: str,
                         methods: List[str]):
        """Record ``service`` under ``route``, or as a route-less service.

        Args:
            route: HTTP path name (must begin with '/'; not checked here), or
                None for a service without an HTTP route.
            service: Service name the HTTP actor will push requests to.
            methods: HTTP methods saved for the route. Not stored when
                ``route`` is None.
        """
        logger.debug(
            "[KV] Registering route {} to service {} with methods {}.".format(
                route, service, methods))

        if route is not None:
            self.routing_table.put(route, service)
            self.methods_table.put(route, json.dumps(methods))
            return

        # Route-less services share a single JSON list under NO_ROUTE_KEY.
        headless = json.loads(self.routing_table.get(NO_ROUTE_KEY, "[]"))
        headless.append(service)
        self.routing_table.put(NO_ROUTE_KEY, json.dumps(headless))

    def list_service(self, include_headless=False, include_methods=False):
        """Return the routing table as a dict.

        Args:
            include_headless: If True, the ``NO_ROUTE_KEY`` entry is decoded
                into a list of route-less services; otherwise that entry is
                dropped from the result. (Default: False)
            include_methods: If True, every route that also has a methods
                entry maps to ``(service, methods)`` instead of ``service``.
        """
        services = self.routing_table.as_dict()

        if include_methods:
            for path, encoded in self.methods_table.as_dict().items():
                if path not in services:
                    continue
                services[path] = (services[path], json.loads(encoded))

        if not include_headless:
            services.pop(NO_ROUTE_KEY, None)
            return services

        services[NO_ROUTE_KEY] = json.loads(services.get(NO_ROUTE_KEY, "[]"))
        return services

    def get_request_count(self):
        """Return the number of requests that fetched the routing table.

        Intended as an indirect health signal for the HTTP proxy: a growing
        count means the proxy is up and actively fetching the table, and
        regular growth (every HTTP_ROUTER_CHECKER_INTERVAL_S seconds) means
        its copy is not stale. A supervisor is expected to poll this value.
        """
        return self.request_count
class BackendTable:
    """Per-backend metadata persisted across four namespaced KV stores.

    Backend creators and init args are stored pickled; backend info and
    replica tag lists are stored as JSON strings.
    """

    def __init__(self, kv_connector):
        """Open one store per kind of metadata by calling ``kv_connector``."""
        stores = (
            ("backend_table", "backend_creator"),
            ("replica_table", "replica_table"),
            ("backend_info", "backend_info"),
            ("backend_init_args", "backend_init_args"),
        )
        for attribute, namespace in stores:
            setattr(self, attribute, kv_connector(namespace))

    def register_backend(self, backend_tag: str, backend_creator):
        """Store the pickled ``backend_creator`` under ``backend_tag``."""
        self.backend_table.put(backend_tag, pickle.dumps(backend_creator))

    def save_init_args(self, backend_tag: str, arg_list):
        """Store the pickled ``arg_list`` under ``backend_tag``."""
        self.backend_init_args.put(backend_tag, pickle.dumps(arg_list))

    def get_init_args(self, backend_tag):
        """Return the unpickled init args saved for ``backend_tag``."""
        stored = self.backend_init_args.get(backend_tag)
        return pickle.loads(stored)

    def register_info(self, backend_tag: str, backend_info_d):
        """Store ``backend_info_d`` as JSON under ``backend_tag``."""
        encoded = json.dumps(backend_info_d)
        self.backend_info.put(backend_tag, encoded)

    def get_info(self, backend_tag):
        """Return the decoded info for ``backend_tag`` ({} if none stored)."""
        encoded = self.backend_info.get(backend_tag, "{}")
        return json.loads(encoded)

    def get_backend_creator(self, backend_tag):
        """Return the unpickled creator registered for ``backend_tag``."""
        stored = self.backend_table.get(backend_tag)
        return pickle.loads(stored)

    def list_backends(self):
        """Return the tags of all registered backends as a list."""
        registered = self.backend_table.as_dict()
        return list(registered)

    def list_replicas(self, backend_tag: str):
        """Return the replica tags of ``backend_tag`` ([] if none stored)."""
        encoded = self.replica_table.get(backend_tag, "[]")
        return json.loads(encoded)

    def add_replica(self, backend_tag: str, new_replica_tag: str):
        """Append ``new_replica_tag`` to the stored list for ``backend_tag``."""
        updated = self.list_replicas(backend_tag) + [new_replica_tag]
        self.replica_table.put(backend_tag, json.dumps(updated))

    def remove_replica(self, backend_tag):
        """Drop and return the most recently added replica tag."""
        remaining = self.list_replicas(backend_tag)
        # pop() raises IndexError when the backend has no replicas stored.
        dropped = remaining.pop()
        self.replica_table.put(backend_tag, json.dumps(remaining))
        return dropped
class TrafficPolicyTable:
    """Traffic policies per service, kept as JSON in one KV namespace."""

    def __init__(self, kv_connector):
        """Open the "traffic_policy" store by calling ``kv_connector``."""
        self.traffic_policy_table = kv_connector("traffic_policy")

    def register_traffic_policy(self, service_name, policy_dict):
        """Store ``policy_dict`` as JSON, replacing any earlier policy."""
        encoded = json.dumps(policy_dict)
        self.traffic_policy_table.put(service_name, encoded)

    def list_traffic_policy(self):
        """Return a dict of every service name to its decoded policy."""
        policies = {}
        for name, encoded in self.traffic_policy_table.as_dict().items():
            policies[name] = json.loads(encoded)
        return policies
+454 -248
View File
@@ -1,334 +1,540 @@
import asyncio
from collections import defaultdict
from functools import wraps
import inspect
import os
import random
import time
import ray
import ray.cloudpickle as pickle
from ray.serve.backend_config import BackendConfig
from ray.serve.constants import ASYNC_CONCURRENCY
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME,
SERVE_PROXY_NAME, SERVE_METRIC_MONITOR_NAME)
from ray.serve.exceptions import batch_annotation_not_found
from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.kv_store_service import (BackendTable, RoutingTable,
TrafficPolicyTable)
from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop)
from ray.serve.backend_worker import create_backend_worker
from ray.serve.utils import expand, get_random_letters, logger
from ray.serve.utils import async_retryable, get_random_letters, logger
import numpy as np
def async_retryable(cls):
    """Make all actor method invocations on the class retryable.

    Every plain function found on ``cls`` gets a
    ``__ray_invocation_decorator__`` that wraps the invocation in a loop:
    the call is awaited, and while the awaited result is a
    ``ray.exceptions.RayActorError`` a warning is logged and the call is
    repeated after 100ms. Any other result is returned as-is.

    Note: This will retry actor_handle.method_name.remote(), but it must
    be invoked in an async context.

    Usage:
        @ray.remote(max_reconstructions=10000)
        @async_retryable
        class A:
            pass

    Args:
        cls: Class whose methods should become retryable.

    Returns:
        The same class object, with its functions annotated in place.
    """

    def make_retry_decorator(method_name):
        # Built once per method so the warning names the method that actually
        # failed. Reading the loop variable from inside the closure would
        # report whichever method the loop visited last.
        def decorate_with_retry(f):
            @wraps(f)
            async def retry_method(*args, **kwargs):
                while True:
                    result = await f(*args, **kwargs)
                    if isinstance(result, ray.exceptions.RayActorError):
                        logger.warning(
                            "Actor method '{}' failed, retrying after 100ms.".
                            format(method_name))
                        await asyncio.sleep(0.1)
                    else:
                        return result

            return retry_method

        return decorate_with_retry

    for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
        method.__ray_invocation_decorator__ = make_retry_decorator(name)
    return cls
# Used for testing purposes only. If this is set, the master actor will crash
# after writing each checkpoint with the specified probability.
_CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.0
@ray.remote
class ServeMaster:
"""Initialize and store all actor handles.
"""Responsible for managing the state of the serving system.
Note:
This actor is necessary because ray will destroy actors when the
original actor handle goes out of scope (when driver exit). Therefore
we need to initialize and store actor handles in a seperate actor.
The master actor implements fault tolerance by persisting its state in
a new checkpoint each time a state change is made. If the actor crashes,
the latest checkpoint is loaded and the state is recovered. Checkpoints
are written/read using a provided KV-store interface.
All hard state in the system is maintained by this actor and persisted via
these checkpoints. Soft state required by other components is fetched by
those actors from this actor on startup and updates are pushed out from
this actor.
All other actors started by the master actor are named, detached actors
so they will not fate share with the master if it crashes.
The following guarantees are provided for state-changing calls to the
master actor:
- If the call succeeds, the change was made and will be reflected in
the system even if the master actor or other actors die unexpectedly.
- If the call fails, the change may have been made but isn't guaranteed
to have been. The client should retry in this case. Note that this
requires all implementations here to be idempotent.
"""
def __init__(self, kv_store_connector):
self.kv_store_connector = kv_store_connector
self.route_table = RoutingTable(kv_store_connector)
self.backend_table = BackendTable(kv_store_connector)
self.policy_table = TrafficPolicyTable(kv_store_connector)
async def __init__(self, kv_store_connector, router_class, router_kwargs,
start_http_proxy, http_proxy_host, http_proxy_port,
metric_gc_window_s):
# Used to read/write checkpoints.
# TODO(edoakes): namespace the master actor and its checkpoints.
self.kv_store_client = kv_store_connector("serve_checkpoints")
# path -> (endpoint, methods).
self.routes = {}
# backend -> (worker_creator, init_args, backend_config).
self.backends = {}
# backend -> replica_tags.
self.replicas = defaultdict(list)
# replicas that should be started if recovering from a checkpoint.
self.replicas_to_start = defaultdict(list)
# replicas that should be stopped if recovering from a checkpoint.
self.replicas_to_stop = defaultdict(list)
# endpoint -> traffic_dict
self.traffic_policies = dict()
# Dictionary of backend tag to dictionaries of replica tag to worker.
# TODO(edoakes): consider removing this and just using the names.
self.workers = defaultdict(dict)
# Used to ensure that only a single state-changing operation happens
# at any given time.
self.write_lock = asyncio.Lock()
# Cached handles to actors in the system.
self.router = None
self.http_proxy = None
self.metric_monitor = None
def get_traffic_policy(self, endpoint_name):
return self.policy_table.list_traffic_policy()[endpoint_name]
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self._get_or_start_router(router_class, router_kwargs)
if start_http_proxy:
self._get_or_start_http_proxy(http_proxy_host, http_proxy_port)
self._get_or_start_metric_monitor(metric_gc_window_s)
def start_router(self, router_class, init_kwargs):
assert self.router is None, "Router already started."
self.router = async_retryable(router_class).options(
max_concurrency=ASYNC_CONCURRENCY,
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
).remote(**init_kwargs)
# NOTE(edoakes): unfortunately, we can't completely recover from a
# checkpoint in the constructor because we block while waiting for
# other actors to start up, and those actors fetch soft state from
# this actor. Because no other tasks will start executing until after
# the constructor finishes, if we were to run this logic in the
# constructor it could lead to deadlock between this actor and a child.
# However we do need to guarantee that we have fully recovered from a
# checkpoint before any other state-changing calls run. We address this
# by acquiring the write_lock and then posting the task to recover from
# a checkpoint to the event loop. Other state-changing calls acquire
# this lock and will be blocked until recovering from the checkpoint
# finishes.
checkpoint = self.kv_store_client.get("checkpoint")
if checkpoint is None:
logger.debug("No checkpoint found")
else:
await self.write_lock.acquire()
asyncio.get_event_loop().create_task(
self._recover_from_checkpoint(checkpoint))
def _get_or_start_router(self, router_class, router_kwargs):
    """Attach to the named router actor, launching it if the lookup fails.

    The resulting handle is cached on ``self.router``.

    Args:
        router_class: Class that is wrapped with ``async_retryable`` and
            launched as the router actor when no existing one is found.
        router_kwargs: Keyword arguments for the router's constructor.
    """
    try:
        handle = ray.util.get_actor(SERVE_ROUTER_NAME)
    except ValueError:
        # A ValueError from the lookup is taken to mean "not started yet".
        logger.info(
            "Starting router with name '{}'".format(SERVE_ROUTER_NAME))
        retryable_router = async_retryable(router_class)
        configured_router = retryable_router.options(
            detached=True,
            name=SERVE_ROUTER_NAME,
            max_concurrency=ASYNC_CONCURRENCY,
            max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
        )
        handle = configured_router.remote(**router_kwargs)
    self.router = handle
def get_router(self):
assert self.router is not None, "Router not started yet."
"""Returns a handle to the router managed by this actor."""
return [self.router]
def start_http_proxy(self, host, port):
"""Start the HTTP proxy on the given host:port.
def _get_or_start_http_proxy(self, host, port):
"""Get the HTTP proxy belonging to this serve cluster.
On startup (or restart), the HTTP proxy will fetch its config via
get_http_proxy_config.
If the HTTP proxy does not already exist, it will be started.
"""
assert self.http_proxy is None, "HTTP proxy already started."
assert self.router is not None, (
"Router must be started before HTTP proxy.")
self.http_proxy = async_retryable(HTTPProxyActor).options(
max_concurrency=ASYNC_CONCURRENCY,
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
).remote(host, port)
async def get_http_proxy_config(self):
route_table = self.route_table.list_service(
include_methods=True, include_headless=False)
return route_table, self.get_router()
try:
self.http_proxy = ray.util.get_actor(SERVE_PROXY_NAME)
except ValueError:
logger.info(
"Starting HTTP proxy with name '{}'".format(SERVE_PROXY_NAME))
self.http_proxy = async_retryable(HTTPProxyActor).options(
detached=True,
name=SERVE_PROXY_NAME,
max_concurrency=ASYNC_CONCURRENCY,
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
).remote(host, port)
def get_http_proxy(self):
assert self.http_proxy is not None, "HTTP proxy not started yet."
"""Returns a handle to the HTTP proxy managed by this actor."""
return [self.http_proxy]
def start_metric_monitor(self, gc_window_seconds):
assert self.metric_monitor is None, "Metric monitor already started."
self.metric_monitor = MetricMonitor.remote(gc_window_seconds)
# TODO(edoakes): this should be an actor method, not a separate task.
start_metric_monitor_loop.remote(self.metric_monitor)
self.metric_monitor.add_target.remote(self.router)
def get_http_proxy_config(self):
"""Called by the HTTP proxy on startup to fetch required state."""
return self.routes, self.get_router()
def _get_or_start_metric_monitor(self, gc_window_s):
"""Get the metric monitor belonging to this serve cluster.
If the metric monitor does not already exist, it will be started.

The handle is cached on self.metric_monitor.

Args:
gc_window_s: passed unchanged as the only constructor argument of the
MetricMonitor actor when one has to be created.
"""
try:
# Look the monitor up by its registered actor name first.
self.metric_monitor = ray.util.get_actor(SERVE_METRIC_MONITOR_NAME)
except ValueError:
# The lookup raised ValueError: create a detached, named monitor.
logger.info("Starting metric monitor with name '{}'".format(
SERVE_METRIC_MONITOR_NAME))
self.metric_monitor = MetricMonitor.options(
detached=True,
name=SERVE_METRIC_MONITOR_NAME).remote(gc_window_s)
# TODO(edoakes): move these into the constructor.
# NOTE(review): this listing lost its indentation, so it is unclear whether
# the two calls below run only when a new monitor was created - confirm.
start_metric_monitor_loop.remote(self.metric_monitor)
# Uses self.router, so the router handle must already be set.
self.metric_monitor.add_target.remote(self.router)
def get_metric_monitor(self):
assert self.metric_monitor is not None, (
"Metric monitor not started yet.")
"""Returns a handle to the metric monitor managed by this actor."""
return [self.metric_monitor]
def _list_replicas(self, backend_tag):
return self.backend_table.list_replicas(backend_tag)
def _checkpoint(self):
"""Checkpoint internal state and write it to the KV store."""
logger.debug("Writing checkpoint")
start = time.time()
checkpoint = pickle.dumps(
(self.routes, self.backends, self.traffic_policies, self.replicas,
self.replicas_to_start, self.replicas_to_stop))
async def scale_replicas(self, backend_tag, num_replicas):
self.kv_store_client.put("checkpoint", checkpoint)
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY:
logger.warning("Intentionally crashing after checkpoint")
os._exit(0)
async def _recover_from_checkpoint(self, checkpoint_bytes):
    """Recover the cluster state from the provided checkpoint.

    Performs the following operations:
        1) Deserializes the internal state from the checkpoint.
        2) Pushes the latest configuration to the HTTP proxy (if one was
           started) and router in case we crashed before updating them.
        3) Starts/stops any worker replicas that are pending creation or
           deletion.

    NOTE: this requires that self.write_lock is already acquired and will
    release it before returning, including when recovery raises. Holding the
    lock forever after a failure would block every later state-changing
    call.

    Args:
        checkpoint_bytes: Pickled tuple of (routes, backends,
            traffic_policies, replicas, replicas_to_start,
            replicas_to_stop), as read back from the KV store.
    """
    assert self.write_lock.locked()
    try:
        start = time.time()
        logger.info("Recovering from checkpoint")

        # Load internal state from the checkpoint data.
        (self.routes, self.backends, self.traffic_policies, self.replicas,
         self.replicas_to_start,
         self.replicas_to_stop) = pickle.loads(checkpoint_bytes)

        # Fetch actor handles for all of the backend replicas in the system.
        # All of these workers are guaranteed to already exist because they
        # would not be written to a checkpoint in self.workers until they
        # were created.
        for backend_tag, replica_tags in self.replicas.items():
            for replica_tag in replica_tags:
                self.workers[backend_tag][replica_tag] = ray.util.get_actor(
                    replica_tag)

        # Push configuration state to the router.
        # TODO(edoakes): should we make this a pull-only model for
        # simplicity?
        for endpoint, traffic_policy in self.traffic_policies.items():
            await self.router.set_traffic.remote(endpoint, traffic_policy)

        for backend_tag, replica_dict in self.workers.items():
            for replica_tag, worker in replica_dict.items():
                await self.router.add_new_worker.remote(
                    backend_tag, replica_tag, worker)

        for backend, (_, _, backend_config_dict) in self.backends.items():
            await self.router.set_backend_config.remote(
                backend, backend_config_dict)

        # Push configuration state to the HTTP proxy. The handle stays None
        # when the master was created without an HTTP proxy, so there is
        # nothing to update in that case.
        if self.http_proxy is not None:
            await self.http_proxy.set_route_table.remote(self.routes)

        # Start/stop any pending backend replicas.
        await self._start_pending_replicas()
        await self._stop_pending_replicas()

        logger.info(
            "Recovered from checkpoint in {:.3f}s".format(time.time() - start))
    finally:
        self.write_lock.release()
def get_backend_configs(self):
    """Fetched by the router on startup.

    Returns:
        A new dict mapping each backend to the third element (the backend
        config dict) of its entry in ``self.backends``.
    """
    return {
        backend: backend_config_dict
        for backend, (_, _, backend_config_dict) in self.backends.items()
    }
def get_traffic_policies(self):
    """Fetched by the router on startup.

    Returns:
        The ``self.traffic_policies`` mapping itself (not a copy).
    """
    policies = self.traffic_policies
    return policies
def _list_replicas(self, backend_tag):
"""Used only for testing."""
return self.replicas[backend_tag]
def get_traffic_policy(self, endpoint):
    """Fetched by serve handles.

    Returns:
        The traffic policy stored for ``endpoint``; the lookup raises
        KeyError when ``self.traffic_policies`` has no entry for it.
    """
    policy = self.traffic_policies[endpoint]
    return policy
async def _start_backend_worker(self, backend_tag, replica_tag):
    """Launch one backend worker actor and wait until it reports ready.

    The backend must already be registered in ``self.backends``.

    Args:
        backend_tag: Backend whose creator, init args and config are used.
        replica_tag: Name under which the worker actor is registered.

    Returns:
        The handle of the started worker.
    """
    logger.debug("Starting worker '{}' for backend '{}'.".format(
        replica_tag, backend_tag))

    creator, creator_args, raw_config = self.backends[backend_tag]
    # TODO(edoakes): just store the BackendConfig in self.backends.
    actor_options = BackendConfig(**raw_config).get_actor_creation_args()

    retryable_worker = async_retryable(ray.remote(creator))
    handle = retryable_worker.options(
        detached=True,
        name=replica_tag,
        max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
        **actor_options).remote(backend_tag, replica_tag, creator_args)

    # TODO(edoakes): we should probably have a timeout here.
    await handle.ready.remote()
    return handle
async def _start_pending_replicas(self):
    """Bring up every replica listed in self.replicas_to_start.

    For each pending replica the worker is looked up by name and only
    created when that lookup fails; either way it is then recorded in
    self.replicas / self.workers, registered with the router and handed to
    the metric monitor.

    Clears self.replicas_to_start.
    """
    for backend_tag, pending_tags in self.replicas_to_start.items():
        for pending_tag in pending_tags:
            # NOTE(edoakes): the replicas may already be created if we
            # failed after creating them but before writing a checkpoint.
            try:
                handle = ray.util.get_actor(pending_tag)
            except ValueError:
                handle = await self._start_backend_worker(
                    backend_tag, pending_tag)

            self.replicas[backend_tag].append(pending_tag)
            self.workers[backend_tag][pending_tag] = handle

            # Register the worker with the router.
            await self.router.add_new_worker.remote(backend_tag, pending_tag,
                                                    handle)

            # Register the worker with the metric monitor (not awaited).
            self.metric_monitor.add_target.remote(handle)

    self.replicas_to_start.clear()
async def _stop_pending_replicas(self):
"""Stops the pending backend replicas in self.replicas_to_stop.
Stops workers by telling the router to remove them.
Clears self.replicas_to_stop.
"""
for backend_tag, replicas_to_stop in self.replicas_to_stop.items():
for replica_tag in replicas_to_stop:
# NOTE(edoakes): the replicas may already be stopped if we
# failed after stopping them but before writing a checkpoint.
try:
# Remove the replica from router.
# This will also submit __ray_terminate__ on the worker.
# NOTE(edoakes): we currently need to kill the worker from
# the router to guarantee that the router won't submit any
# more requests to it.
await self.router.remove_worker.remote(
backend_tag, replica_tag)
except ValueError:
pass
self.replicas_to_stop.clear()
def _scale_replicas(self, backend_tag, num_replicas):
"""Scale the given backend to the number of replicas.
This requires the master actor to be an async actor because we wait
synchronously for backends to start up and they may make calls into
the master actor while initializing (e.g., by calling get_handle()).
NOTE: this does not actually start or stop the replicas, but instead
adds the intention to start/stop them to self.workers_to_start and
self.workers_to_stop. The caller is responsible for then first writing
a checkpoint and then actually starting/stopping the intended replicas.
This avoids inconsistencies with starting/stopping a worker and then
crashing before writing a checkpoint.
"""
assert (backend_tag in self.backend_table.list_backends()
logger.debug("Scaling backend '{}' to {} replicas".format(
backend_tag, num_replicas))
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
assert num_replicas >= 0, ("Number of replicas must be"
" greater than or equal to 0.")
current_num_replicas = len(self._list_replicas(backend_tag))
current_num_replicas = len(self.replicas[backend_tag])
delta_num_replicas = num_replicas - current_num_replicas
if delta_num_replicas > 0:
logger.debug("Adding {} replicas to backend {}".format(
delta_num_replicas, backend_tag))
for _ in range(delta_num_replicas):
await self._start_backend_replica(backend_tag)
replica_tag = "{}#{}".format(backend_tag, get_random_letters())
self.replicas_to_start[backend_tag].append(replica_tag)
elif delta_num_replicas < 0:
logger.debug("Removing {} replicas from backend {}".format(
-delta_num_replicas, backend_tag))
assert len(self.replicas[backend_tag]) >= delta_num_replicas
for _ in range(-delta_num_replicas):
await self._remove_backend_replica(backend_tag)
replica_tag = self.replicas[backend_tag].pop()
if len(self.replicas[backend_tag]) == 0:
del self.replicas[backend_tag]
del self.workers[backend_tag][replica_tag]
if len(self.workers[backend_tag]) == 0:
del self.workers[backend_tag]
async def get_backend_worker_config(self):
return self.get_router()
async def _start_backend_replica(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
# Register the worker in the DB.
# TODO(edoakes): we should guarantee that if calls to the master
# succeed, the cluster state has changed and if they fail, it hasn't.
# Once we have master actor fault tolerance, this breaks that guarantee
# because this method could fail after writing the replica to the DB.
self.backend_table.add_replica(backend_tag, replica_tag)
# Fetch the info to start the replica from the backend table.
backend_actor = ray.remote(
self.backend_table.get_backend_creator(backend_tag))
backend_config_dict = self.backend_table.get_info(backend_tag)
backend_config = BackendConfig(**backend_config_dict)
init_args = [
backend_tag, replica_tag,
self.backend_table.get_init_args(backend_tag)
]
kwargs = backend_config.get_actor_creation_args(init_args)
kwargs[
"max_reconstructions"] = ray.ray_constants.INFINITE_RECONSTRUCTION
# Start the worker.
worker_handle = backend_actor._remote(**kwargs)
self.workers[backend_tag][replica_tag] = worker_handle
# Wait for the worker to start up.
await worker_handle.ready.remote()
[router] = self.get_router()
await router.add_new_worker.remote(backend_tag, worker_handle)
# Register the worker with the metric monitor.
self.get_metric_monitor()[0].add_target.remote(worker_handle)
async def _remove_backend_replica(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert (len(self._list_replicas(backend_tag)) >
0), "Tried to remove replica from empty backend ({}).".format(
backend_tag)
replica_tag = self.backend_table.remove_replica(backend_tag)
assert backend_tag in self.workers
assert replica_tag in self.workers[backend_tag]
replica_handle = self.workers[backend_tag].pop(replica_tag)
if len(self.workers[backend_tag]) == 0:
del self.workers[backend_tag]
# Remove the replica from metric monitor.
[monitor] = self.get_metric_monitor()
await monitor.remove_target.remote(replica_handle)
# Remove the replica from router.
# This will also destroy the actor handle.
[router] = self.get_router()
await router.remove_worker.remote(backend_tag, replica_handle)
self.replicas_to_stop[backend_tag].append(replica_tag)
def get_all_worker_handles(self):
"""Fetched by the router on startup."""
return self.workers
def get_all_endpoints(self):
return expand(
self.route_table.list_service(include_headless=True).values())
def get_all_routes(self):
return expand(self.route_table.list_service().keys())
def get_all_backends(self):
return self.backend_table.list_backends()
"""Used for validation by the API client."""
return [endpoint for endpoint, methods in self.routes.values()]
async def set_traffic(self, endpoint_name, traffic_policy_dictionary):
assert endpoint_name in self.get_all_endpoints(), \
"Attempted to assign traffic for an endpoint '{}'" \
" that is not registered.".format(endpoint_name)
"""Sets the traffic policy for the specified endpoint."""
async with self.write_lock:
assert endpoint_name in self.get_all_endpoints(), \
"Attempted to assign traffic for an endpoint '{}'" \
" that is not registered.".format(endpoint_name)
assert isinstance(traffic_policy_dictionary,
dict), "Traffic policy must be dictionary"
prob = 0
existing_backends = set(self.get_all_backends())
for backend, weight in traffic_policy_dictionary.items():
prob += weight
assert backend in existing_backends, \
"Attempted to assign traffic to a backend '{}' that " \
"is not registered.".format(backend)
assert isinstance(traffic_policy_dictionary,
dict), "Traffic policy must be dictionary"
prob = 0
for backend, weight in traffic_policy_dictionary.items():
prob += weight
assert backend in self.backends, \
"Attempted to assign traffic to a backend '{}' that " \
"is not registered.".format(backend)
assert np.isclose(
prob, 1, atol=0.02
), "weights must sum to 1, currently they sum to {}".format(prob)
assert np.isclose(
prob, 1, atol=0.02
), "weights must sum to 1, currently they sum to {}".format(prob)
self.policy_table.register_traffic_policy(endpoint_name,
traffic_policy_dictionary)
[router] = self.get_router()
self.traffic_policies[endpoint_name] = traffic_policy_dictionary
await router.set_traffic.remote(endpoint_name,
traffic_policy_dictionary)
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
await self.router.set_traffic.remote(endpoint_name,
traffic_policy_dictionary)
async def create_endpoint(self, route, endpoint_name, methods):
err_prefix = "Cannot create endpoint. "
assert route not in self.get_all_routes(), \
"{} Route '{}' is already registered.".format(err_prefix, route)
assert endpoint_name not in self.get_all_endpoints(), \
"{} Endpoint '{}' is already registered.".format(err_prefix,
endpoint_name)
self.route_table.register_service(
route, endpoint_name, methods=methods)
[http_proxy] = self.get_http_proxy()
async def create_endpoint(self, route, endpoint, methods):
"""Create a new endpoint with the specified route and methods.
await http_proxy.set_route_table.remote(
self.route_table.list_service(
include_methods=True, include_headless=False))
If the route is None, this is a "headless" endpoint that will not
be added to the HTTP proxy (can only be accessed via a handle).
"""
async with self.write_lock:
# If this is a headless service with no route, key the endpoint
# based on its name.
# TODO(edoakes): we should probably just store routes and endpoints
# separately.
if route is None:
route = endpoint
# TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint."
if route in self.routes:
if self.routes[route] == (endpoint, methods):
return
else:
raise ValueError(
"{} Route '{}' is already registered.".format(
err_prefix, route))
if endpoint in self.get_all_endpoints():
raise ValueError(
"{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
logger.info(
"Registering route {} to endpoint {} with methods {}.".format(
route, endpoint, methods))
self.routes[route] = (endpoint, methods)
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
await self.http_proxy.set_route_table.remote(self.routes)
async def create_backend(self, backend_tag, backend_config, func_or_class,
actor_init_args):
backend_config_dict = dict(backend_config)
backend_worker = create_backend_worker(func_or_class)
"""Register a new backend under the specified tag."""
async with self.write_lock:
backend_config_dict = dict(backend_config)
backend_worker = create_backend_worker(func_or_class)
assert backend_tag not in self.get_all_backends(), \
"Cannot create backend '{}' because a backend" \
"with that name already exists.".format(backend_tag)
# Save creator that starts replicas, the arguments to be passed in,
# and the configuration for the backends.
self.backends[backend_tag] = (backend_worker, actor_init_args,
backend_config_dict)
# Save creator which starts replicas.
self.backend_table.register_backend(backend_tag, backend_worker)
self._scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
# Save information about configurations needed to start the replicas.
self.backend_table.register_info(backend_tag, backend_config_dict)
# NOTE(edoakes): we must write a checkpoint before starting new
# replicas or pushing the updated config to avoid inconsistent state
# if we crash while making the change.
self._checkpoint()
await self._start_pending_replicas()
await self._stop_pending_replicas()
# Save the initial arguments needed by replicas.
self.backend_table.save_init_args(backend_tag, actor_init_args)
# Set the backend config inside the router
# (particularly for max-batch-size).
[router] = self.get_router()
await router.set_backend_config.remote(backend_tag,
backend_config_dict)
await self.scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
# Set the backend config inside the router
# (particularly for max-batch-size).
await self.router.set_backend_config.remote(
backend_tag, backend_config_dict)
async def set_backend_config(self, backend_tag, backend_config):
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(backend_config,
BackendConfig), ("backend_config must be"
" of instance BackendConfig")
backend_config_dict = dict(backend_config)
old_backend_config_dict = self.backend_table.get_info(backend_tag)
"""Set the config for the specified backend."""
async with self.write_lock:
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(backend_config,
BackendConfig), ("backend_config must be"
" of instance BackendConfig")
backend_config_dict = dict(backend_config)
backend_worker, init_args, old_backend_config_dict = self.backends[
backend_tag]
if (not old_backend_config_dict["has_accept_batch_annotation"]
and backend_config.max_batch_size is not None):
raise batch_annotation_not_found
if (not old_backend_config_dict["has_accept_batch_annotation"]
and backend_config.max_batch_size is not None):
raise batch_annotation_not_found
self.backend_table.register_info(backend_tag, backend_config_dict)
self.backends[backend_tag] = (backend_worker, init_args,
backend_config_dict)
# Inform the router about change in configuration
# (particularly for setting max_batch_size).
[router] = self.get_router()
await router.set_backend_config.remote(backend_tag,
backend_config_dict)
# Restart replicas if there is a change in the backend config
# related to restart_configs.
need_to_restart_replicas = any(
old_backend_config_dict[k] != backend_config_dict[k]
for k in BackendConfig.restart_on_change_fields)
if need_to_restart_replicas:
# Kill all the replicas for restarting with new configurations.
self._scale_replicas(backend_tag, 0)
# Restart replicas if there is a change in the backend config related
# to restart_configs.
need_to_restart_replicas = any(
old_backend_config_dict[k] != backend_config_dict[k]
for k in BackendConfig.restart_on_change_fields)
if need_to_restart_replicas:
# Kill all the replicas for restarting with new configurations.
await self.scale_replicas(backend_tag, 0)
# Scale the replicas with the new configuration.
self._scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
# Scale the replicas with the new configuration.
await self.scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
# Inform the router about change in configuration
# (particularly for setting max_batch_size).
await self.router.set_backend_config.remote(
backend_tag, backend_config_dict)
await self._start_pending_replicas()
await self._stop_pending_replicas()
def get_backend_config(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()
"""Get the current config for the specified backend."""
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
backend_config_dict = self.backend_table.get_info(backend_tag)
return BackendConfig(**backend_config_dict)
return BackendConfig(**self.backends[backend_tag][2])
+4 -4
View File
@@ -24,9 +24,6 @@ class MetricMonitor:
self.gc_window_seconds = gc_window_seconds
self.latest_gc_time = time.time()
def is_ready(self):
return True
def add_target(self, target_handle):
hex_id = target_handle._actor_id.hex()
self.actor_handles[hex_id] = target_handle
@@ -153,5 +150,8 @@ class MetricMonitor:
@ray.remote(num_cpus=0)
def start_metric_monitor_loop(monitor_handle, duration_s=5):
# Polls the metric monitor forever: each pass blocks on
# monitor_handle.scrape.remote() and sleeps duration_s seconds.
while True:
ray.get(monitor_handle.scrape.remote())
time.sleep(duration_s)
try:
ray.get(monitor_handle.scrape.remote())
except ray.exceptions.RayActorError:
# NOTE(review): the actor error is ignored on purpose, presumably so the
# loop keeps polling while the monitor actor is restarted -- confirm.
pass
+61 -26
View File
@@ -1,6 +1,7 @@
import asyncio
import copy
from collections import defaultdict
import time
from typing import DefaultDict, List
# Note on choosing blist instead of stdlib heapq
@@ -13,7 +14,7 @@ import blist
import ray
import ray.cloudpickle as pickle
from ray.serve.utils import logger
from ray.serve.utils import logger, retry_actor_failures
class Query:
@@ -138,6 +139,8 @@ class Router:
self.traffic = defaultdict(dict)
# backend_name -> backend_config
self.backend_info = dict()
# replica tag -> worker_handle
self.replicas = dict()
# -- Synchronization -- #
@@ -150,15 +153,28 @@ class Router:
# batching policies.
self.flush_lock = asyncio.Lock()
# Fetch the worker handles from the master actor. We use a "pull-based"
# approach instead of pushing them from the master so that the router
# can transparently recover from failure.
# Fetch the worker handles, traffic policies, and backend configs from
# the master actor. We use a "pull-based" approach instead of pushing
# them from the master so that the router can transparently recover
# from failure.
ray.serve.init()
master_actor = ray.serve.api._get_master_actor()
backend_dict = ray.get(master_actor.get_all_worker_handles.remote())
for backend, replica_dict in backend_dict.items():
for worker in replica_dict.values():
await self.add_new_worker(backend, worker)
traffic_policies = retry_actor_failures(
master_actor.get_traffic_policies)
for endpoint, traffic_policy in traffic_policies.items():
await self.set_traffic(endpoint, traffic_policy)
backend_dict = retry_actor_failures(
master_actor.get_all_worker_handles)
for backend_tag, replica_dict in backend_dict.items():
for replica_tag, worker in replica_dict.items():
await self.add_new_worker(backend_tag, replica_tag, worker)
backend_configs = retry_actor_failures(
master_actor.get_backend_configs)
for backend, backend_config_dict in backend_configs.items():
await self.set_backend_config(backend, backend_config_dict)
def is_ready(self):
return True
@@ -199,28 +215,43 @@ class Router:
result = await query.async_future
return result
async def add_new_worker(self, backend, worker_handle):
logger.debug("New worker added for backend '{}'".format(backend))
await self.mark_worker_idle(backend, worker_handle)
async def add_new_worker(self, backend_tag, replica_tag, worker_handle):
backend_replica_tag = backend_tag + ":" + replica_tag
if backend_replica_tag in self.replicas:
return
self.replicas[backend_replica_tag] = worker_handle
async def mark_worker_idle(self, backend, worker_handle):
await self.worker_queues[backend].put(worker_handle)
logger.debug("New worker added for backend '{}'".format(backend_tag))
# await worker_handle.ready.remote()
await self.mark_worker_idle(backend_tag, backend_replica_tag)
async def mark_worker_idle(self, backend_tag, backend_replica_tag):
if backend_replica_tag not in self.replicas:
return
await self.worker_queues[backend_tag].put(backend_replica_tag)
await self.flush()
async def remove_worker(self, backend, worker_handle):
async def remove_worker(self, backend_tag, replica_tag):
backend_replica_tag = backend_tag + ":" + replica_tag
if backend_replica_tag not in self.replicas:
return
worker_handle = self.replicas.pop(backend_replica_tag)
# We need this lock because we modify worker_queue here.
async with self.flush_lock:
old_queue = self.worker_queues[backend]
old_queue = self.worker_queues[backend_tag]
new_queue = asyncio.Queue()
target_id = worker_handle._actor_id
while not old_queue.empty():
worker_handle = await old_queue.get()
if worker_handle._actor_id != target_id:
await new_queue.put(worker_handle)
curr_tag = await old_queue.get()
if curr_tag != backend_replica_tag:
await new_queue.put(curr_tag)
self.worker_queues[backend] = new_queue
# TODO: consider awaiting this on a timeout or using ray.kill().
self.worker_queues[backend_tag] = new_queue
# We need to terminate the worker here instead of from the master
# so we can guarantee that the router won't submit any more tasks
# on it.
worker_handle.__ray_terminate__.remote()
async def link(self, service, backend):
@@ -297,11 +328,15 @@ class Router:
await self._assign_query_to_worker(
backend, buffer_queue, worker_queue, max_batch_size)
async def _do_query(self, backend, backend_replica_tag, req):
    """Send ``req`` to a single replica and return its result.

    Args:
        backend: Tag of the backend the replica belongs to; passed through
            to ``mark_worker_idle`` once the request finishes.
        backend_replica_tag: Key into ``self.replicas`` identifying the
            worker handle that should serve the request.
        req: The request (or batch of requests) forwarded to the
            replica's ``handle_request``.

    Returns:
        Whatever the replica's ``handle_request`` call resolves to.
    """
    # If the worker died, this will be a RayActorError. Just return it and
    # let the HTTP proxy handle the retry logic.
    logger.debug("Sending query to replica:" + backend_replica_tag)
    start = time.time()
    worker = self.replicas[backend_replica_tag]
    result = await worker.handle_request.remote(req)
    # Put the replica back in the idle queue before returning the result.
    await self.mark_worker_idle(backend, backend_replica_tag)
    # The logging module interpolates extra positional args with
    # %-formatting, so a "{}" template has to be rendered with str.format
    # before it is handed to the logger.
    logger.debug("Got result in {:.2f}s".format(time.time() - start))
    return result
async def _assign_query_to_worker(self,
@@ -311,11 +346,11 @@ class Router:
max_batch_size=None):
while len(buffer_queue) and worker_queue.qsize():
worker = await worker_queue.get()
backend_replica_tag = await worker_queue.get()
if max_batch_size is None: # No batching
request = buffer_queue.pop(0)
future = asyncio.get_event_loop().create_task(
self._do_query(backend, worker, request))
self._do_query(backend, backend_replica_tag, request))
# chaining satisfies request.async_future with future result.
asyncio.futures._chain_future(future, request.async_future)
else:
@@ -331,7 +366,7 @@ class Router:
for group in requests_group.values():
future = asyncio.get_event_loop().create_task(
self._do_query(backend, worker, group))
self._do_query(backend, backend_replica_tag, group))
future.add_done_callback(
_make_future_unwrapper(
client_futures=[req.async_future for req in group],
+6
View File
@@ -6,6 +6,12 @@ import pytest
import ray
from ray import serve
# TODO(edoakes): the failure tests currently fail with the GCS service enabled.
os.environ["RAY_GCS_SERVICE_ENABLED"] = "false"
if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False):
serve.master._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
@pytest.fixture(scope="session")
def serve_instance():
+4 -11
View File
@@ -75,24 +75,17 @@ def test_no_route(serve_instance):
assert result == 1
def test_reject_duplicate_backend_tag(serve_instance):
backend_name = "foo"
serve.create_backend(lambda foo: foo, backend_name)
with pytest.raises(AssertionError):
serve.create_backend(lambda foo: foo, backend_name)
def test_reject_duplicate_route(serve_instance):
route = "/foo"
serve.create_endpoint("bar", route=route)
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
serve.create_endpoint("foo", route=route)
def test_reject_duplicate_endpoint(serve_instance):
endpoint_name = "foo"
serve.create_endpoint(endpoint_name, route="/ok")
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
serve.create_endpoint(endpoint_name, route="/different")
@@ -102,9 +95,9 @@ def test_set_traffic_missing_data(serve_instance):
serve.create_endpoint(endpoint_name)
serve.create_backend(lambda: 5, backend_name)
with pytest.raises(AssertionError):
serve.set_traffic(endpoint_name, {"nonexistant_backend": 1.0})
serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
with pytest.raises(AssertionError):
serve.set_traffic("nonexistant_endpoint_name", {backend_name: 1.0})
serve.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
def test_scaling_replicas(serve_instance):
+14 -17
View File
@@ -13,15 +13,15 @@ from ray.serve.backend_config import BackendConfig
pytestmark = pytest.mark.asyncio
def setup_worker(name, func_or_class, router_handle, init_args=None):
def setup_worker(name, func_or_class, init_args=None):
if init_args is None:
init_args = ()
@ray.remote
class WorkerActor:
def __init__(self, router_handle):
def __init__(self):
self.worker = create_backend_worker(func_or_class)(
name, name + ":tag", init_args, router_handle=router_handle[0])
name, name + ":tag", init_args)
def ready(self):
pass
@@ -29,13 +29,10 @@ def setup_worker(name, func_or_class, router_handle, init_args=None):
def get_metrics(self):
return self.worker.get_metrics()
def run(self):
self.worker.backend.mark_idle_in_router()
async def handle_request(self, *args, **kwargs):
return await self.worker.handle_request(*args, **kwargs)
worker = WorkerActor.remote([router_handle])
worker = WorkerActor.remote()
ray.get(worker.ready.remote())
return worker
@@ -54,8 +51,8 @@ async def test_runner_actor(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "prod"
worker = setup_worker(CONSUMER_NAME, echo, q)
await q.add_new_worker.remote(CONSUMER_NAME, worker)
worker = setup_worker(CONSUMER_NAME, echo)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -79,8 +76,8 @@ async def test_ray_serve_mixin(serve_instance):
def __call__(self, flask_request, i=None):
return i + self.increment
worker = setup_worker(CONSUMER_NAME, MyAdder, q, init_args=(3, ))
await q.add_new_worker.remote(CONSUMER_NAME, worker)
worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -101,8 +98,8 @@ async def test_task_runner_check_context(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
worker = setup_worker(CONSUMER_NAME, echo, q)
await q.add_new_worker.remote(CONSUMER_NAME, worker)
worker = setup_worker(CONSUMER_NAME, echo)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
@@ -125,8 +122,8 @@ async def test_task_runner_custom_method_single(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
worker = setup_worker(CONSUMER_NAME, NonBatcher, q)
await q.add_new_worker.remote(CONSUMER_NAME, worker)
worker = setup_worker(CONSUMER_NAME, NonBatcher)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -160,7 +157,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
worker = setup_worker(CONSUMER_NAME, Batcher, q)
worker = setup_worker(CONSUMER_NAME, Batcher)
await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
await q.set_backend_config.remote(
@@ -174,7 +171,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
await q.add_new_worker.remote(CONSUMER_NAME, worker)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
gathered = await asyncio.gather(*futures)
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
+62 -9
View File
@@ -19,6 +19,57 @@ def request_with_retries(endpoint, timeout=30):
time.sleep(0.1)
def test_master_failure(serve_instance):
# Exercises the serve API while repeatedly killing the master actor with
# ray.kill: routes that already exist must keep returning the same
# response, and endpoints/backends created after a kill must be reachable.
serve.init()
serve.create_endpoint("master_failure", "/master_failure")
def function():
return "hello1"
serve.create_backend(function, "master_failure:v1")
serve.set_traffic("master_failure", {"master_failure:v1": 1.0})
assert request_with_retries("/master_failure", timeout=1).text == "hello1"
# Baseline: the route answers "hello1" before any failure is injected.
for _ in range(10):
response = request_with_retries("/master_failure", timeout=30)
assert response.text == "hello1"
# Kill the master; the existing route must still answer "hello1".
ray.kill(serve.api._get_master_actor())
for _ in range(10):
response = request_with_retries("/master_failure", timeout=30)
assert response.text == "hello1"
def function():
return "hello2"
# Kill the master again, then register a second backend and shift all
# traffic for the endpoint onto it.
ray.kill(serve.api._get_master_actor())
serve.create_backend(function, "master_failure:v2")
serve.set_traffic("master_failure", {"master_failure:v2": 1.0})
for _ in range(10):
response = request_with_retries("/master_failure", timeout=30)
assert response.text == "hello2"
def function():
return "hello3"
# Kill the master before each API call used to set up a second endpoint.
ray.kill(serve.api._get_master_actor())
serve.create_endpoint("master_failure_2", "/master_failure_2")
ray.kill(serve.api._get_master_actor())
serve.create_backend(function, "master_failure_2")
ray.kill(serve.api._get_master_actor())
serve.set_traffic("master_failure_2", {"master_failure_2": 1.0})
# Both the original route and the newly created one must respond.
for _ in range(10):
response = request_with_retries("/master_failure", timeout=30)
assert response.text == "hello2"
response = request_with_retries("/master_failure_2", timeout=30)
assert response.text == "hello3"
def _kill_http_proxy():
[http_proxy] = ray.get(
serve.api._get_master_actor().get_http_proxy.remote())
@@ -27,7 +78,7 @@ def _kill_http_proxy():
def test_http_proxy_failure(serve_instance):
serve.init()
serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])
serve.create_endpoint("proxy_failure", "/proxy_failure")
def function():
return "hello1"
@@ -35,7 +86,7 @@ def test_http_proxy_failure(serve_instance):
serve.create_backend(function, "proxy_failure:v1")
serve.set_traffic("proxy_failure", {"proxy_failure:v1": 1.0})
assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"
assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1"
for _ in range(10):
response = request_with_retries("/proxy_failure", timeout=30)
@@ -61,7 +112,7 @@ def _kill_router():
def test_router_failure(serve_instance):
serve.init()
serve.create_endpoint("router_failure", "/router_failure", methods=["GET"])
serve.create_endpoint("router_failure", "/router_failure")
def function():
return "hello1"
@@ -77,6 +128,10 @@ def test_router_failure(serve_instance):
_kill_router()
for _ in range(10):
response = request_with_retries("/router_failure", timeout=30)
assert response.text == "hello1"
def function():
return "hello2"
@@ -99,7 +154,7 @@ def _get_worker_handles(backend):
# serving requests.
def test_worker_restart(serve_instance):
serve.init()
serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])
serve.create_endpoint("worker_failure", "/worker_failure")
class Worker1:
def __call__(self):
@@ -109,7 +164,7 @@ def test_worker_restart(serve_instance):
serve.set_traffic("worker_failure", {"worker_failure:v1": 1.0})
# Get the PID of the worker.
old_pid = request_with_retries("/worker_failure", timeout=0.1).text
old_pid = request_with_retries("/worker_failure", timeout=1).text
# Kill the worker.
handles = _get_worker_handles("worker_failure:v1")
@@ -131,8 +186,7 @@ def test_worker_restart(serve_instance):
def test_worker_replica_failure(serve_instance):
serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
serve.init()
serve.create_endpoint(
"replica_failure", "/replica_failure", methods=["GET"])
serve.create_endpoint("replica_failure", "/replica_failure")
class Worker:
# Assumes that two replicas are started. Will hang forever in the
@@ -169,8 +223,7 @@ def test_worker_replica_failure(serve_instance):
# Wait until both replicas have been started.
responses = set()
while len(responses) == 1:
responses.add(
request_with_retries("/replica_failure", timeout=0.1).text)
responses.add(request_with_retries("/replica_failure", timeout=1).text)
time.sleep(0.1)
# Kill one of the replicas.
@@ -0,0 +1,19 @@
import os
import pytest
from pathlib import Path
import sys
if __name__ == "__main__":
    # Collect every test module below this file's directory (recursively),
    # as absolute path strings in sorted order.
    curr_dir = Path(__file__).parent
    serve_tests_files = sorted(
        str(path.absolute()) for path in curr_dir.rglob("test_*.py"))

    print("Testing the following files")
    for test_file in serve_tests_files:
        # Path.name yields the basename for any path separator, whereas
        # splitting on "/" only works for POSIX-style paths.
        print("->", Path(test_file).name)

    # Set the crash-injection flag in the environment before pytest runs
    # the collected test modules.
    print("Setting RAY_SERVE_INTENTIONALLY_CRASH=1")
    os.environ["RAY_SERVE_INTENTIONALLY_CRASH"] = "1"
    sys.exit(pytest.main(["-v", "-s"] + serve_tests_files))
+4 -4
View File
@@ -6,15 +6,15 @@ from ray import serve
def test_nonblocking():
serve.init()
serve.create_endpoint("endpoint", "/api")
serve.create_endpoint("nonblocking", "/nonblocking")
def function(flask_request):
return {"method": flask_request.method}
serve.create_backend(function, "echo:v1")
serve.set_traffic("endpoint", {"echo:v1": 1.0})
serve.create_backend(function, "nonblocking:v1")
serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0})
resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
assert resp == "GET"
+40 -24
View File
@@ -29,6 +29,9 @@ def make_task_runner_mock():
def get_all_calls(self):
return self.queries
def ready(self):
pass
return TaskRunnerMock.remote()
@@ -39,8 +42,9 @@ def task_runner_mock_actor():
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
q = RandomPolicyQueueActor.remote()
q.link.remote("svc", "backend")
q.add_new_worker.remote("backend", task_runner_mock_actor)
q.link.remote("svc", "backend-single-prod")
q.add_new_worker.remote("backend-single-prod", "replica-1",
task_runner_mock_actor)
# Make sure we get the request result back
result = await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
@@ -54,7 +58,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
async def test_slo(serve_instance, task_runner_mock_actor):
q = RandomPolicyQueueActor.remote()
await q.link.remote("svc", "backend")
await q.link.remote("svc", "backend-slo")
all_request_sent = []
for i in range(10):
@@ -63,7 +67,8 @@ async def test_slo(serve_instance, task_runner_mock_actor):
q.enqueue_request.remote(
RequestMetadata("svc", None, relative_slo_ms=slo_ms), i))
await q.add_new_worker.remote("backend", task_runner_mock_actor)
await q.add_new_worker.remote("backend-slo", "replica-1",
task_runner_mock_actor)
await asyncio.gather(*all_request_sent)
@@ -78,14 +83,16 @@ async def test_slo(serve_instance, task_runner_mock_actor):
async def test_alter_backend(serve_instance, task_runner_mock_actor):
q = RandomPolicyQueueActor.remote()
await q.set_traffic.remote("svc", {"backend-1": 1})
await q.add_new_worker.remote("backend-1", task_runner_mock_actor)
await q.set_traffic.remote("svc", {"backend-alter": 1})
await q.add_new_worker.remote("backend-alter", "replica-1",
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 1
await q.set_traffic.remote("svc", {"backend-2": 1})
await q.add_new_worker.remote("backend-2", task_runner_mock_actor)
await q.set_traffic.remote("svc", {"backend-alter-2": 1})
await q.add_new_worker.remote("backend-alter-2", "replica-1",
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 2
@@ -94,10 +101,13 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
q = RandomPolicyQueueActor.remote()
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
await q.set_traffic.remote("svc", {
"backend-split": 0.5,
"backend-split-2": 0.5
})
runner_1, runner_2 = [make_task_runner_mock() for _ in range(2)]
await q.add_new_worker.remote("backend-1", runner_1)
await q.add_new_worker.remote("backend-2", runner_2)
await q.add_new_worker.remote("backend-split", "replica-1", runner_1)
await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2)
# assume 50% split, the probability of all 20 requests goes to a
# single queue is 0.5^20 ~ 1e-6
@@ -114,13 +124,13 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
async def test_round_robin(serve_instance, task_runner_mock_actor):
q = RoundRobinPolicyQueueActor.remote()
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
await q.set_traffic.remote("svc", {"backend-rr": 0.5, "backend-rr-2": 0.5})
runner_1, runner_2 = [make_task_runner_mock() for _ in range(2)]
# NOTE: this is the only difference between the
# test_split_traffic_random and test_round_robin
await q.add_new_worker.remote("backend-1", runner_1)
await q.add_new_worker.remote("backend-2", runner_2)
await q.add_new_worker.remote("backend-rr", "replica-1", runner_1)
await q.add_new_worker.remote("backend-rr-2", "replica-1", runner_2)
for _ in range(20):
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
@@ -135,13 +145,16 @@ async def test_round_robin(serve_instance, task_runner_mock_actor):
async def test_fixed_packing(serve_instance):
packing_num = 4
q = FixedPackingPolicyQueueActor.remote(packing_num=packing_num)
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
await q.set_traffic.remote("svc", {
"backend-fixed": 0.5,
"backend-fixed-2": 0.5
})
runner_1, runner_2 = (make_task_runner_mock() for _ in range(2))
# both the backends will get equal number of queries
# as it is packed round robin
await q.add_new_worker.remote("backend-1", runner_1)
await q.add_new_worker.remote("backend-2", runner_2)
await q.add_new_worker.remote("backend-fixed", "replica-1", runner_1)
await q.add_new_worker.remote("backend-fixed-2", "replica-1", runner_2)
for backend, runner in zip(["1", "2"], [runner_1, runner_2]):
for _ in range(packing_num):
@@ -158,20 +171,23 @@ async def test_power_of_two_choices(serve_instance):
enqueue_futures = []
# First, fill the queue for backend-1 with 3 requests
await q.set_traffic.remote("svc", {"backend-1": 1.0})
await q.set_traffic.remote("svc", {"backend-pow2": 1.0})
for _ in range(3):
future = q.enqueue_request.remote(RequestMetadata("svc", None), "1")
enqueue_futures.append(future)
# Then, add a new backend, this backend should be filled next
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
await q.set_traffic.remote("svc", {
"backend-pow2": 0.5,
"backend-pow2-2": 0.5
})
for _ in range(2):
future = q.enqueue_request.remote(RequestMetadata("svc", None), "2")
enqueue_futures.append(future)
runner_1, runner_2 = (make_task_runner_mock() for _ in range(2))
await q.add_new_worker.remote("backend-1", runner_1)
await q.add_new_worker.remote("backend-2", runner_2)
await q.add_new_worker.remote("backend-pow2", "replica-1", runner_1)
await q.add_new_worker.remote("backend-pow2-2", "replica-1", runner_2)
await asyncio.gather(*enqueue_futures)
@@ -183,10 +199,10 @@ async def test_queue_remove_replicas(serve_instance):
@ray.remote
class TestRandomPolicyQueueActor(RandomPolicyQueue):
def worker_queue_size(self, backend):
return self.worker_queues["backend"].qsize()
return self.worker_queues["backend-remove"].qsize()
temp_actor = make_task_runner_mock()
q = TestRandomPolicyQueueActor.remote()
await q.add_new_worker.remote("backend", temp_actor)
await q.remove_worker.remote("backend", temp_actor)
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
await q.remove_worker.remote("backend-remove", "replica-1")
assert ray.get(q.worker_queue_size.remote("backend")) == 0
+71 -14
View File
@@ -1,3 +1,6 @@
import asyncio
from functools import wraps
import inspect
import json
import logging
import random
@@ -6,11 +9,11 @@ import time
import io
import os
import ray
import requests
from pygments import formatters, highlight, lexers
from ray.serve.context import FakeFlaskRequest, TaskContext
from ray.serve.http_util import build_flask_request
import itertools
import numpy as np
try:
@@ -18,19 +21,7 @@ try:
except ImportError:
pydantic = None
def expand(l):
    """Flatten one level of list nesting.

    Elements of ``l`` that are ``list`` instances are spliced into the
    result; every other element (including tuples and other iterables) is
    kept as is. Lists nested more than one level deep are not expanded
    further.

    Example:
        >>> serve.utils.expand([1, 2, [3, 4, 5], 6])
        [1, 2, 3, 4, 5, 6]
        >>> serve.utils.expand(["a", ["b", "c"], "d", ["e", "f"]])
        ['a', 'b', 'c', 'd', 'e', 'f']
    """
    flattened = []
    for item in l:
        if isinstance(item, list):
            # Splice the inner list's elements in place of the list itself.
            flattened.extend(item)
        else:
            flattened.append(item)
    return flattened
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
def parse_request_item(request_item):
@@ -116,3 +107,69 @@ def block_until_http_ready(http_endpoint, num_retries=5, backoff_time_s=1):
def get_random_letters(length=6):
return "".join(random.choices(string.ascii_letters, k=length))
def async_retryable(cls):
    """Make all actor method invocations on the class retryable.

    Every plain function found on ``cls`` gets a
    ``__ray_invocation_decorator__`` attribute. That decorator wraps the
    invocation it is given in a coroutine which re-issues the call whenever
    it raises ``ray.exceptions.RayActorError``, sleeping 100ms between
    attempts, until ``ACTOR_FAILURE_RETRY_TIMEOUT_S`` seconds have elapsed.

    Note: This will retry actor_handle.method_name.remote(), but it must
    be invoked in an async context.

    Usage:
        @ray.remote(max_reconstructions=10000)
        @async_retryable
        class A:
            pass

    Args:
        cls: The class whose methods should be made retryable. Its
            function attributes are annotated in place.

    Returns:
        The same ``cls`` object, so this can be used as a class decorator.
    """
    for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
        # Build the decorator through a factory so each one is bound to its
        # own method name instead of sharing the loop variable.
        method.__ray_invocation_decorator__ = _make_retry_decorator(name)
    return cls


def _make_retry_decorator(method_name):
    """Build the retrying invocation decorator for `method_name`.

    The name is taken as an argument rather than closed over from a loop
    variable, so that log and error messages report the method that actually
    failed and not whichever method happened to be iterated last.
    """

    def decorate_with_retry(f):
        @wraps(f)
        async def retry_method(*args, **kwargs):
            # Monotonic clock: the deadline must not move if the wall clock
            # is adjusted while we are retrying.
            start = time.monotonic()
            last_error = None
            while time.monotonic() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S:
                try:
                    return await f(*args, **kwargs)
                except ray.exceptions.RayActorError as e:
                    # Keep the error around; `e` is unbound after this block.
                    last_error = e
                    logger.warning(
                        "Actor method '{}' failed, retrying after 100ms.".
                        format(method_name))
                    await asyncio.sleep(0.1)
            # Chain the last actor failure so the root cause is not lost.
            raise RuntimeError("Timed out after {}s waiting for actor "
                               "method '{}' to succeed.".format(
                                   ACTOR_FAILURE_RETRY_TIMEOUT_S,
                                   method_name)) from last_error

        return retry_method

    return decorate_with_retry
def retry_actor_failures(f, *args, **kwargs):
    """Call an actor method and block on its result, retrying on failure.

    Repeatedly runs ``ray.get(f.remote(*args, **kwargs))``. Each time that
    raises ``ray.exceptions.RayActorError`` a warning is logged and the call
    is re-issued after a 100ms sleep, until ``ACTOR_FAILURE_RETRY_TIMEOUT_S``
    seconds have elapsed. Any other exception propagates immediately.

    Args:
        f: An actor method handle; it must expose ``.remote(...)`` and a
            ``_method_name`` attribute (used in log and error messages).
        *args: Positional arguments forwarded to ``f.remote``.
        **kwargs: Keyword arguments forwarded to ``f.remote``.

    Returns:
        The value returned by ``ray.get`` for the first successful call.

    Raises:
        RuntimeError: If no attempt succeeded within the timeout. The last
            ``RayActorError`` seen, if any, is attached as ``__cause__``.
    """
    # Monotonic clock: the deadline must not move if the wall clock is
    # adjusted while we are retrying.
    start = time.monotonic()
    last_error = None
    while time.monotonic() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S:
        try:
            return ray.get(f.remote(*args, **kwargs))
        except ray.exceptions.RayActorError as e:
            # Keep the error around; `e` is unbound after this block.
            last_error = e
            logger.warning(
                "Actor method '{}' failed, retrying after 100ms".format(
                    f._method_name))
            time.sleep(0.1)
    # Chain the last actor failure so the root cause is not lost.
    raise RuntimeError("Timed out after {}s waiting for actor "
                       "method '{}' to succeed.".format(
                           ACTOR_FAILURE_RETRY_TIMEOUT_S,
                           f._method_name)) from last_error
async def retry_actor_failures_async(f, *args, **kwargs):
    """Await an actor method call, retrying on actor failure.

    Coroutine counterpart of a blocking retry loop: repeatedly awaits
    ``f.remote(*args, **kwargs)``. Each time that raises
    ``ray.exceptions.RayActorError`` a warning is logged and the call is
    re-issued after a non-blocking 100ms sleep, until
    ``ACTOR_FAILURE_RETRY_TIMEOUT_S`` seconds have elapsed. Any other
    exception propagates immediately.

    Args:
        f: An actor method handle; it must expose ``.remote(...)`` returning
            an awaitable, and a ``_method_name`` attribute (used in log and
            error messages).
        *args: Positional arguments forwarded to ``f.remote``.
        **kwargs: Keyword arguments forwarded to ``f.remote``.

    Returns:
        The result of the first successful awaited call.

    Raises:
        RuntimeError: If no attempt succeeded within the timeout. The last
            ``RayActorError`` seen, if any, is attached as ``__cause__``.
    """
    # Monotonic clock: the deadline must not move if the wall clock is
    # adjusted while we are retrying.
    start = time.monotonic()
    last_error = None
    while time.monotonic() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S:
        try:
            return await f.remote(*args, **kwargs)
        except ray.exceptions.RayActorError as e:
            # Keep the error around; `e` is unbound after this block.
            last_error = e
            logger.warning(
                "Actor method '{}' failed, retrying after 100ms".format(
                    f._method_name))
            # asyncio.sleep yields to the event loop instead of blocking it.
            await asyncio.sleep(0.1)
    # Chain the last actor failure so the root cause is not lost.
    raise RuntimeError("Timed out after {}s waiting for actor "
                       "method '{}' to succeed.".format(
                           ACTOR_FAILURE_RETRY_TIMEOUT_S,
                           f._method_name)) from last_error