mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:10:40 +08:00
[serve] Master actor fault tolerance (#8116)
This commit is contained in:
+14
-1
@@ -13,11 +13,24 @@ py_library(
|
||||
py_test(
|
||||
name = "test_serve",
|
||||
size = "medium",
|
||||
srcs = glob(["tests/*.py"], exclude=["tests/test_nonblocking.py"]),
|
||||
srcs = glob(["tests/*.py"],
|
||||
exclude=["tests/test_nonblocking.py",
|
||||
"tests/test_master_crashes.py"]),
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
# Runs test_api and test_failure with injected failures in the master actor.
|
||||
py_test(
|
||||
name = "test_master_crashes",
|
||||
size = "medium",
|
||||
srcs = glob(["tests/test_master_crashes.py",
|
||||
"tests/test_api.py",
|
||||
"tests/test_failure.py"],
|
||||
exclude=["tests/test_nonblocking.py",
|
||||
"tests/test_serve.py"]),
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "echo_full",
|
||||
size = "small",
|
||||
|
||||
+19
-25
@@ -10,7 +10,7 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
from ray.serve.master import ServeMaster
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.kv_store_service import SQLiteKVStore
|
||||
from ray.serve.utils import block_until_http_ready
|
||||
from ray.serve.utils import block_until_http_ready, retry_actor_failures
|
||||
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
|
||||
from ray.serve.backend_config import BackendConfig
|
||||
from ray.serve.policy import RoutePolicy
|
||||
@@ -139,14 +139,11 @@ def init(
|
||||
return SQLiteKVStore(namespace, db_path=kv_store_path)
|
||||
|
||||
master_actor = ServeMaster.options(
|
||||
detached=True, name=SERVE_MASTER_NAME).remote(kv_store_connector)
|
||||
|
||||
ray.get(
|
||||
master_actor.start_router.remote(queueing_policy.value, policy_kwargs))
|
||||
|
||||
ray.get(master_actor.start_metric_monitor.remote(gc_window_seconds))
|
||||
if start_server:
|
||||
ray.get(master_actor.start_http_proxy.remote(http_host, http_port))
|
||||
detached=True,
|
||||
name=SERVE_MASTER_NAME,
|
||||
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
|
||||
).remote(kv_store_connector, queueing_policy.value, policy_kwargs,
|
||||
start_server, http_host, http_port, gc_window_seconds)
|
||||
|
||||
if start_server and blocking:
|
||||
block_until_http_ready("http://{}:{}/-/routes".format(
|
||||
@@ -165,9 +162,8 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]):
|
||||
blocking (bool): If true, the function will wait for service to be
|
||||
registered before returning
|
||||
"""
|
||||
ray.get(
|
||||
master_actor.create_endpoint.remote(route, endpoint_name,
|
||||
[m.upper() for m in methods]))
|
||||
retry_actor_failures(master_actor.create_endpoint, route, endpoint_name,
|
||||
[m.upper() for m in methods])
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -178,8 +174,8 @@ def set_backend_config(backend_tag, backend_config):
|
||||
backend_tag(str): A registered backend.
|
||||
backend_config(BackendConfig) : Desired backend configuration.
|
||||
"""
|
||||
ray.get(
|
||||
master_actor.set_backend_config.remote(backend_tag, backend_config))
|
||||
retry_actor_failures(master_actor.set_backend_config, backend_tag,
|
||||
backend_config)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -189,7 +185,7 @@ def get_backend_config(backend_tag):
|
||||
Args:
|
||||
backend_tag(str): A registered backend.
|
||||
"""
|
||||
return ray.get(master_actor.get_backend_config.remote(backend_tag))
|
||||
return retry_actor_failures(master_actor.get_backend_config, backend_tag)
|
||||
|
||||
|
||||
def _backend_accept_batch(func_or_class):
|
||||
@@ -240,9 +236,8 @@ def create_backend(func_or_class,
|
||||
if _backend_accept_batch(func_or_class):
|
||||
backend_config.has_accept_batch_annotation = True
|
||||
|
||||
ray.get(
|
||||
master_actor.create_backend.remote(backend_tag, backend_config,
|
||||
func_or_class, actor_init_args))
|
||||
retry_actor_failures(master_actor.create_backend, backend_tag,
|
||||
backend_config, func_or_class, actor_init_args)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -261,9 +256,8 @@ def set_traffic(endpoint_name, traffic_policy_dictionary):
|
||||
traffic_policy_dictionary (dict): a dictionary maps backend names
|
||||
to their traffic weights. The weights must sum to 1.
|
||||
"""
|
||||
ray.get(
|
||||
master_actor.set_traffic.remote(endpoint_name,
|
||||
traffic_policy_dictionary))
|
||||
retry_actor_failures(master_actor.set_traffic, endpoint_name,
|
||||
traffic_policy_dictionary)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -286,11 +280,11 @@ def get_handle(endpoint_name,
|
||||
RayServeHandle
|
||||
"""
|
||||
if not missing_ok:
|
||||
assert endpoint_name in ray.get(
|
||||
master_actor.get_all_endpoints.remote())
|
||||
assert endpoint_name in retry_actor_failures(
|
||||
master_actor.get_all_endpoints)
|
||||
|
||||
return RayServeHandle(
|
||||
ray.get(master_actor.get_router.remote())[0],
|
||||
retry_actor_failures(master_actor.get_router)[0],
|
||||
endpoint_name,
|
||||
relative_slo_ms,
|
||||
absolute_slo_ms,
|
||||
@@ -309,5 +303,5 @@ def stat(percentiles=[50, 90, 95],
|
||||
The longest aggregation window must be shorter or equal to the
|
||||
gc_window_seconds.
|
||||
"""
|
||||
[monitor] = ray.get(master_actor.get_metric_monitor.remote())
|
||||
[monitor] = retry_actor_failures(master_actor.get_metric_monitor)
|
||||
return ray.get(monitor.collect.remote(percentiles, agg_windows_seconds))
|
||||
|
||||
@@ -55,9 +55,8 @@ class BackendConfig:
|
||||
key = "num_replicas"
|
||||
yield key, val
|
||||
|
||||
def get_actor_creation_args(self, init_args):
|
||||
def get_actor_creation_args(self):
|
||||
ret_d = deepcopy(self.__dict__)
|
||||
for k in self._serve_configs:
|
||||
ret_d.pop(k)
|
||||
ret_d["args"] = init_args
|
||||
return ret_d
|
||||
|
||||
@@ -23,24 +23,14 @@ def create_backend_worker(func_or_class):
|
||||
assert False, "func_or_class must be function or class."
|
||||
|
||||
class RayServeWrappedWorker(object):
|
||||
def __init__(self,
|
||||
backend_tag,
|
||||
replica_tag,
|
||||
init_args,
|
||||
router_handle=None):
|
||||
def __init__(self, backend_tag, replica_tag, init_args):
|
||||
serve.init()
|
||||
if is_function:
|
||||
_callable = func_or_class
|
||||
else:
|
||||
_callable = func_or_class(*init_args)
|
||||
|
||||
if router_handle is None:
|
||||
master_actor = serve.api._get_master_actor()
|
||||
[router_handle] = ray.get(
|
||||
master_actor.get_backend_worker_config.remote())
|
||||
|
||||
self.backend = RayServeWorker(backend_tag, _callable,
|
||||
router_handle, is_function)
|
||||
self.backend = RayServeWorker(backend_tag, _callable, is_function)
|
||||
|
||||
def get_metrics(self):
|
||||
return self.backend.get_metrics()
|
||||
@@ -76,10 +66,9 @@ def ensure_async(func):
|
||||
class RayServeWorker:
|
||||
"""Handles requests with the provided callable."""
|
||||
|
||||
def __init__(self, name, _callable, router_handle, is_function):
|
||||
def __init__(self, name, _callable, is_function):
|
||||
self.name = name
|
||||
self.callable = _callable
|
||||
self.router_handle = router_handle
|
||||
self.is_function = is_function
|
||||
|
||||
self.error_counter = 0
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
#: Actor name used to register master actor
|
||||
SERVE_MASTER_NAME = "SERVE_MASTER_ACTOR"
|
||||
|
||||
#: Actor name used to register router actor
|
||||
SERVE_ROUTER_NAME = "SERVE_ROUTER_ACTOR"
|
||||
|
||||
#: Actor name used to register HTTP proxy actor
|
||||
SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR"
|
||||
|
||||
#: Actor name used to register metric monitor actor
|
||||
SERVE_METRIC_MONITOR_NAME = "SERVE_METRIC_MONITOR_ACTOR"
|
||||
|
||||
#: HTTP Address
|
||||
DEFAULT_HTTP_ADDRESS = "http://127.0.0.1:8000"
|
||||
|
||||
@@ -15,6 +24,3 @@ ASYNC_CONCURRENCY = int(1e6)
|
||||
|
||||
#: Default latency SLO
|
||||
DEFAULT_LATENCY_SLO_MS = 1e9
|
||||
|
||||
#: Key for storing no http route services
|
||||
NO_ROUTE_KEY = "NO_ROUTE"
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.serve.constants import SERVE_MASTER_NAME
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.http_util import Response
|
||||
from ray.serve.utils import logger
|
||||
from ray.serve.utils import logger, retry_actor_failures_async
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -29,8 +29,9 @@ class HTTPProxy:
|
||||
async def fetch_config_from_master(self):
|
||||
assert ray.is_initialized()
|
||||
master = ray.util.get_actor(SERVE_MASTER_NAME)
|
||||
self.route_table, [self.router_handle
|
||||
] = await master.get_http_proxy_config.remote()
|
||||
self.route_table, [
|
||||
self.router_handle
|
||||
] = await retry_actor_failures_async(master.get_http_proxy_config)
|
||||
|
||||
def set_route_table(self, route_table):
|
||||
self.route_table = route_table
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from abc import ABC
|
||||
from typing import Union, List
|
||||
|
||||
from ray import cloudpickle as pickle
|
||||
import ray.experimental.internal_kv as ray_kv
|
||||
from ray.serve.utils import logger
|
||||
from ray.serve.constants import NO_ROUTE_KEY
|
||||
|
||||
|
||||
class NamespacedKVStore(ABC):
|
||||
@@ -167,129 +163,3 @@ class SQLiteKVStore(NamespacedKVStore):
|
||||
result = list(
|
||||
cursor.execute("SELECT key, value FROM {}".format(self.namespace)))
|
||||
return dict(result)
|
||||
|
||||
|
||||
# Tables
|
||||
class RoutingTable:
|
||||
def __init__(self, kv_connector):
|
||||
self.routing_table = kv_connector("routing_table")
|
||||
self.methods_table = kv_connector("methods_table")
|
||||
self.request_count = 0
|
||||
|
||||
def register_service(self, route: Union[str, None], service: str,
|
||||
methods: List[str]):
|
||||
"""Create an entry in the routing table
|
||||
|
||||
Args:
|
||||
route: http path name. Must begin with '/'.
|
||||
service: service name. This is the name http actor will push
|
||||
the request to.
|
||||
"""
|
||||
logger.debug(
|
||||
"[KV] Registering route {} to service {} with methods {}.".format(
|
||||
route, service, methods))
|
||||
|
||||
# put no route services in default key
|
||||
if route is None:
|
||||
no_http_services = json.loads(
|
||||
self.routing_table.get(NO_ROUTE_KEY, "[]"))
|
||||
no_http_services.append(service)
|
||||
self.routing_table.put(NO_ROUTE_KEY, json.dumps(no_http_services))
|
||||
else:
|
||||
self.routing_table.put(route, service)
|
||||
self.methods_table.put(route, json.dumps(methods))
|
||||
|
||||
def list_service(self, include_headless=False, include_methods=False):
|
||||
"""Returns the routing table.
|
||||
Args:
|
||||
include_headless: If True, returns a no route services (headless)
|
||||
services with normal services. (Default: False)
|
||||
include_methods: If True, returns a mapping include the methods
|
||||
list for each route.
|
||||
"""
|
||||
table = self.routing_table.as_dict()
|
||||
if include_methods:
|
||||
methods_table = self.methods_table.as_dict()
|
||||
for route, methods in methods_table.items():
|
||||
if route in table:
|
||||
table[route] = (table[route], json.loads(methods))
|
||||
|
||||
if include_headless:
|
||||
table[NO_ROUTE_KEY] = json.loads(table.get(NO_ROUTE_KEY, "[]"))
|
||||
else:
|
||||
table.pop(NO_ROUTE_KEY, None)
|
||||
return table
|
||||
|
||||
def get_request_count(self):
|
||||
"""Return the number of requests that fetched the routing table.
|
||||
|
||||
This method is used for two purpose:
|
||||
|
||||
1. Make sure HTTP proxy has started and healthy. Incremented request
|
||||
count means HTTP proxy is actively fetching routing table.
|
||||
|
||||
2. Make sure HTTP proxy does not have stale routing table. This number
|
||||
should be incremented every HTTP_ROUTER_CHECKER_INTERVAL_S seconds.
|
||||
Supervisor should check this number as indirect indicator of http
|
||||
proxy's health.
|
||||
"""
|
||||
return self.request_count
|
||||
|
||||
|
||||
class BackendTable:
|
||||
def __init__(self, kv_connector):
|
||||
self.backend_table = kv_connector("backend_creator")
|
||||
self.replica_table = kv_connector("replica_table")
|
||||
self.backend_info = kv_connector("backend_info")
|
||||
self.backend_init_args = kv_connector("backend_init_args")
|
||||
|
||||
def register_backend(self, backend_tag: str, backend_creator):
|
||||
backend_creator_serialized = pickle.dumps(backend_creator)
|
||||
self.backend_table.put(backend_tag, backend_creator_serialized)
|
||||
|
||||
def save_init_args(self, backend_tag: str, arg_list):
|
||||
serialized_arg_list = pickle.dumps(arg_list)
|
||||
self.backend_init_args.put(backend_tag, serialized_arg_list)
|
||||
|
||||
def get_init_args(self, backend_tag):
|
||||
return pickle.loads(self.backend_init_args.get(backend_tag))
|
||||
|
||||
def register_info(self, backend_tag: str, backend_info_d):
|
||||
self.backend_info.put(backend_tag, json.dumps(backend_info_d))
|
||||
|
||||
def get_info(self, backend_tag):
|
||||
return json.loads(self.backend_info.get(backend_tag, "{}"))
|
||||
|
||||
def get_backend_creator(self, backend_tag):
|
||||
return pickle.loads(self.backend_table.get(backend_tag))
|
||||
|
||||
def list_backends(self):
|
||||
return list(self.backend_table.as_dict().keys())
|
||||
|
||||
def list_replicas(self, backend_tag: str):
|
||||
return json.loads(self.replica_table.get(backend_tag, "[]"))
|
||||
|
||||
def add_replica(self, backend_tag: str, new_replica_tag: str):
|
||||
replica_tags = self.list_replicas(backend_tag)
|
||||
replica_tags.append(new_replica_tag)
|
||||
self.replica_table.put(backend_tag, json.dumps(replica_tags))
|
||||
|
||||
def remove_replica(self, backend_tag):
|
||||
replica_tags = self.list_replicas(backend_tag)
|
||||
removed_replica = replica_tags.pop()
|
||||
self.replica_table.put(backend_tag, json.dumps(replica_tags))
|
||||
return removed_replica
|
||||
|
||||
|
||||
class TrafficPolicyTable:
|
||||
def __init__(self, kv_connector):
|
||||
self.traffic_policy_table = kv_connector("traffic_policy")
|
||||
|
||||
def register_traffic_policy(self, service_name, policy_dict):
|
||||
self.traffic_policy_table.put(service_name, json.dumps(policy_dict))
|
||||
|
||||
def list_traffic_policy(self):
|
||||
return {
|
||||
service: json.loads(policy)
|
||||
for service, policy in self.traffic_policy_table.as_dict().items()
|
||||
}
|
||||
|
||||
+454
-248
@@ -1,334 +1,540 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.serve.backend_config import BackendConfig
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME,
|
||||
SERVE_PROXY_NAME, SERVE_METRIC_MONITOR_NAME)
|
||||
from ray.serve.exceptions import batch_annotation_not_found
|
||||
from ray.serve.http_proxy import HTTPProxyActor
|
||||
from ray.serve.kv_store_service import (BackendTable, RoutingTable,
|
||||
TrafficPolicyTable)
|
||||
from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop)
|
||||
from ray.serve.backend_worker import create_backend_worker
|
||||
from ray.serve.utils import expand, get_random_letters, logger
|
||||
from ray.serve.utils import async_retryable, get_random_letters, logger
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def async_retryable(cls):
|
||||
"""Make all actor method invocations on the class retryable.
|
||||
|
||||
Note: This will retry actor_handle.method_name.remote(), but it must
|
||||
be invoked in an async context.
|
||||
|
||||
Usage:
|
||||
@ray.remote(max_reconstructions=10000)
|
||||
@async_retryable
|
||||
class A:
|
||||
pass
|
||||
"""
|
||||
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||||
|
||||
def decorate_with_retry(f):
|
||||
@wraps(f)
|
||||
async def retry_method(*args, **kwargs):
|
||||
while True:
|
||||
result = await f(*args, **kwargs)
|
||||
if isinstance(result, ray.exceptions.RayActorError):
|
||||
logger.warning(
|
||||
"Actor method '{}' failed, retrying after 100ms.".
|
||||
format(name))
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
return result
|
||||
|
||||
return retry_method
|
||||
|
||||
method.__ray_invocation_decorator__ = decorate_with_retry
|
||||
return cls
|
||||
# Used for testing purposes only. If this is set, the master actor will crash
|
||||
# after writing each checkpoint with the specified probability.
|
||||
_CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.0
|
||||
|
||||
|
||||
@ray.remote
|
||||
class ServeMaster:
|
||||
"""Initialize and store all actor handles.
|
||||
"""Responsible for managing the state of the serving system.
|
||||
|
||||
Note:
|
||||
This actor is necessary because ray will destroy actors when the
|
||||
original actor handle goes out of scope (when driver exit). Therefore
|
||||
we need to initialize and store actor handles in a seperate actor.
|
||||
The master actor implements fault tolerance by persisting its state in
|
||||
a new checkpoint each time a state change is made. If the actor crashes,
|
||||
the latest checkpoint is loaded and the state is recovered. Checkpoints
|
||||
are written/read using a provided KV-store interface.
|
||||
|
||||
All hard state in the system is maintained by this actor and persisted via
|
||||
these checkpoints. Soft state required by other components is fetched by
|
||||
those actors from this actor on startup and updates are pushed out from
|
||||
this actor.
|
||||
|
||||
All other actors started by the master actor are named, detached actors
|
||||
so they will not fate share with the master if it crashes.
|
||||
|
||||
The following guarantees are provided for state-changing calls to the
|
||||
master actor:
|
||||
- If the call succeeds, the change was made and will be reflected in
|
||||
the system even if the master actor or other actors die unexpectedly.
|
||||
- If the call fails, the change may have been made but isn't guaranteed
|
||||
to have been. The client should retry in this case. Note that this
|
||||
requires all implementations here to be idempotent.
|
||||
"""
|
||||
|
||||
def __init__(self, kv_store_connector):
|
||||
self.kv_store_connector = kv_store_connector
|
||||
self.route_table = RoutingTable(kv_store_connector)
|
||||
self.backend_table = BackendTable(kv_store_connector)
|
||||
self.policy_table = TrafficPolicyTable(kv_store_connector)
|
||||
async def __init__(self, kv_store_connector, router_class, router_kwargs,
|
||||
start_http_proxy, http_proxy_host, http_proxy_port,
|
||||
metric_gc_window_s):
|
||||
# Used to read/write checkpoints.
|
||||
# TODO(edoakes): namespace the master actor and its checkpoints.
|
||||
self.kv_store_client = kv_store_connector("serve_checkpoints")
|
||||
# path -> (endpoint, methods).
|
||||
self.routes = {}
|
||||
# backend -> (worker_creator, init_args, backend_config).
|
||||
self.backends = {}
|
||||
# backend -> replica_tags.
|
||||
self.replicas = defaultdict(list)
|
||||
# replicas that should be started if recovering from a checkpoint.
|
||||
self.replicas_to_start = defaultdict(list)
|
||||
# replicas that should be stopped if recovering from a checkpoint.
|
||||
self.replicas_to_stop = defaultdict(list)
|
||||
# endpoint -> traffic_dict
|
||||
self.traffic_policies = dict()
|
||||
# Dictionary of backend tag to dictionaries of replica tag to worker.
|
||||
# TODO(edoakes): consider removing this and just using the names.
|
||||
self.workers = defaultdict(dict)
|
||||
|
||||
# Used to ensure that only a single state-changing operation happens
|
||||
# at any given time.
|
||||
self.write_lock = asyncio.Lock()
|
||||
|
||||
# Cached handles to actors in the system.
|
||||
self.router = None
|
||||
self.http_proxy = None
|
||||
self.metric_monitor = None
|
||||
|
||||
def get_traffic_policy(self, endpoint_name):
|
||||
return self.policy_table.list_traffic_policy()[endpoint_name]
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self._get_or_start_router(router_class, router_kwargs)
|
||||
if start_http_proxy:
|
||||
self._get_or_start_http_proxy(http_proxy_host, http_proxy_port)
|
||||
self._get_or_start_metric_monitor(metric_gc_window_s)
|
||||
|
||||
def start_router(self, router_class, init_kwargs):
|
||||
assert self.router is None, "Router already started."
|
||||
self.router = async_retryable(router_class).options(
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
|
||||
).remote(**init_kwargs)
|
||||
# NOTE(edoakes): unfortunately, we can't completely recover from a
|
||||
# checkpoint in the constructor because we block while waiting for
|
||||
# other actors to start up, and those actors fetch soft state from
|
||||
# this actor. Because no other tasks will start executing until after
|
||||
# the constructor finishes, if we were to run this logic in the
|
||||
# constructor it could lead to deadlock between this actor and a child.
|
||||
# However we do need to guarantee that we have fully recovered from a
|
||||
# checkpoint before any other state-changing calls run. We address this
|
||||
# by acquiring the write_lock and then posting the task to recover from
|
||||
# a checkpoint to the event loop. Other state-changing calls acquire
|
||||
# this lock and will be blocked until recovering from the checkpoint
|
||||
# finishes.
|
||||
checkpoint = self.kv_store_client.get("checkpoint")
|
||||
if checkpoint is None:
|
||||
logger.debug("No checkpoint found")
|
||||
else:
|
||||
await self.write_lock.acquire()
|
||||
asyncio.get_event_loop().create_task(
|
||||
self._recover_from_checkpoint(checkpoint))
|
||||
|
||||
def _get_or_start_router(self, router_class, router_kwargs):
|
||||
"""Get the router belonging to this serve cluster.
|
||||
|
||||
If the router does not already exist, it will be started.
|
||||
"""
|
||||
try:
|
||||
self.router = ray.util.get_actor(SERVE_ROUTER_NAME)
|
||||
except ValueError:
|
||||
logger.info(
|
||||
"Starting router with name '{}'".format(SERVE_ROUTER_NAME))
|
||||
self.router = async_retryable(router_class).options(
|
||||
detached=True,
|
||||
name=SERVE_ROUTER_NAME,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
|
||||
).remote(**router_kwargs)
|
||||
|
||||
def get_router(self):
|
||||
assert self.router is not None, "Router not started yet."
|
||||
"""Returns a handle to the router managed by this actor."""
|
||||
return [self.router]
|
||||
|
||||
def start_http_proxy(self, host, port):
|
||||
"""Start the HTTP proxy on the given host:port.
|
||||
def _get_or_start_http_proxy(self, host, port):
|
||||
"""Get the HTTP proxy belonging to this serve cluster.
|
||||
|
||||
On startup (or restart), the HTTP proxy will fetch its config via
|
||||
get_http_proxy_config.
|
||||
If the HTTP proxy does not already exist, it will be started.
|
||||
"""
|
||||
assert self.http_proxy is None, "HTTP proxy already started."
|
||||
assert self.router is not None, (
|
||||
"Router must be started before HTTP proxy.")
|
||||
self.http_proxy = async_retryable(HTTPProxyActor).options(
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
|
||||
).remote(host, port)
|
||||
|
||||
async def get_http_proxy_config(self):
|
||||
route_table = self.route_table.list_service(
|
||||
include_methods=True, include_headless=False)
|
||||
return route_table, self.get_router()
|
||||
try:
|
||||
self.http_proxy = ray.util.get_actor(SERVE_PROXY_NAME)
|
||||
except ValueError:
|
||||
logger.info(
|
||||
"Starting HTTP proxy with name '{}'".format(SERVE_PROXY_NAME))
|
||||
self.http_proxy = async_retryable(HTTPProxyActor).options(
|
||||
detached=True,
|
||||
name=SERVE_PROXY_NAME,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
|
||||
).remote(host, port)
|
||||
|
||||
def get_http_proxy(self):
|
||||
assert self.http_proxy is not None, "HTTP proxy not started yet."
|
||||
"""Returns a handle to the HTTP proxy managed by this actor."""
|
||||
return [self.http_proxy]
|
||||
|
||||
def start_metric_monitor(self, gc_window_seconds):
|
||||
assert self.metric_monitor is None, "Metric monitor already started."
|
||||
self.metric_monitor = MetricMonitor.remote(gc_window_seconds)
|
||||
# TODO(edoakes): this should be an actor method, not a separate task.
|
||||
start_metric_monitor_loop.remote(self.metric_monitor)
|
||||
self.metric_monitor.add_target.remote(self.router)
|
||||
def get_http_proxy_config(self):
|
||||
"""Called by the HTTP proxy on startup to fetch required state."""
|
||||
return self.routes, self.get_router()
|
||||
|
||||
def _get_or_start_metric_monitor(self, gc_window_s):
|
||||
"""Get the metric monitor belonging to this serve cluster.
|
||||
|
||||
If the metric monitor does not already exist, it will be started.
|
||||
"""
|
||||
try:
|
||||
self.metric_monitor = ray.util.get_actor(SERVE_METRIC_MONITOR_NAME)
|
||||
except ValueError:
|
||||
logger.info("Starting metric monitor with name '{}'".format(
|
||||
SERVE_METRIC_MONITOR_NAME))
|
||||
self.metric_monitor = MetricMonitor.options(
|
||||
detached=True,
|
||||
name=SERVE_METRIC_MONITOR_NAME).remote(gc_window_s)
|
||||
# TODO(edoakes): move these into the constructor.
|
||||
start_metric_monitor_loop.remote(self.metric_monitor)
|
||||
self.metric_monitor.add_target.remote(self.router)
|
||||
|
||||
def get_metric_monitor(self):
|
||||
assert self.metric_monitor is not None, (
|
||||
"Metric monitor not started yet.")
|
||||
"""Returns a handle to the metric monitor managed by this actor."""
|
||||
return [self.metric_monitor]
|
||||
|
||||
def _list_replicas(self, backend_tag):
|
||||
return self.backend_table.list_replicas(backend_tag)
|
||||
def _checkpoint(self):
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
logger.debug("Writing checkpoint")
|
||||
start = time.time()
|
||||
checkpoint = pickle.dumps(
|
||||
(self.routes, self.backends, self.traffic_policies, self.replicas,
|
||||
self.replicas_to_start, self.replicas_to_stop))
|
||||
|
||||
async def scale_replicas(self, backend_tag, num_replicas):
|
||||
self.kv_store_client.put("checkpoint", checkpoint)
|
||||
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
|
||||
|
||||
if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY:
|
||||
logger.warning("Intentionally crashing after checkpoint")
|
||||
os._exit(0)
|
||||
|
||||
async def _recover_from_checkpoint(self, checkpoint_bytes):
|
||||
"""Recover the cluster state from the provided checkpoint.
|
||||
|
||||
Performs the following operations:
|
||||
1) Deserializes the internal state from the checkpoint.
|
||||
2) Pushes the latest configuration to the HTTP proxy and router
|
||||
in case we crashed before updating them.
|
||||
3) Starts/stops any worker replicas that are pending creation or
|
||||
deletion.
|
||||
|
||||
NOTE: this requires that self.write_lock is already acquired and will
|
||||
release it before returning.
|
||||
"""
|
||||
assert self.write_lock.locked()
|
||||
|
||||
start = time.time()
|
||||
logger.info("Recovering from checkpoint")
|
||||
|
||||
# Load internal state from the checkpoint data.
|
||||
(self.routes, self.backends, self.traffic_policies, self.replicas,
|
||||
self.replicas_to_start,
|
||||
self.replicas_to_stop) = pickle.loads(checkpoint_bytes)
|
||||
|
||||
# Fetch actor handles for all of the backend replicas in the system.
|
||||
# All of these workers are guaranteed to already exist because they
|
||||
# would not be written to a checkpoint in self.workers until they
|
||||
# were created.
|
||||
for backend_tag, replica_tags in self.replicas.items():
|
||||
for replica_tag in replica_tags:
|
||||
self.workers[backend_tag][replica_tag] = ray.util.get_actor(
|
||||
replica_tag)
|
||||
|
||||
# Push configuration state to the router.
|
||||
# TODO(edoakes): should we make this a pull-only model for simplicity?
|
||||
for endpoint, traffic_policy in self.traffic_policies.items():
|
||||
await self.router.set_traffic.remote(endpoint, traffic_policy)
|
||||
|
||||
for backend_tag, replica_dict in self.workers.items():
|
||||
for replica_tag, worker in replica_dict.items():
|
||||
await self.router.add_new_worker.remote(
|
||||
backend_tag, replica_tag, worker)
|
||||
|
||||
for backend, (_, _, backend_config_dict) in self.backends.items():
|
||||
await self.router.set_backend_config.remote(
|
||||
backend, backend_config_dict)
|
||||
|
||||
# Push configuration state to the HTTP proxy.
|
||||
await self.http_proxy.set_route_table.remote(self.routes)
|
||||
|
||||
# Start/stop any pending backend replicas.
|
||||
await self._start_pending_replicas()
|
||||
await self._stop_pending_replicas()
|
||||
|
||||
logger.info(
|
||||
"Recovered from checkpoint in {:.3f}s".format(time.time() - start))
|
||||
|
||||
self.write_lock.release()
|
||||
|
||||
def get_backend_configs(self):
|
||||
"""Fetched by the router on startup."""
|
||||
backend_configs = {}
|
||||
for backend, (_, _, backend_config_dict) in self.backends.items():
|
||||
backend_configs[backend] = backend_config_dict
|
||||
return backend_configs
|
||||
|
||||
def get_traffic_policies(self):
|
||||
"""Fetched by the router on startup."""
|
||||
return self.traffic_policies
|
||||
|
||||
def _list_replicas(self, backend_tag):
|
||||
"""Used only for testing."""
|
||||
return self.replicas[backend_tag]
|
||||
|
||||
def get_traffic_policy(self, endpoint):
|
||||
"""Fetched by serve handles."""
|
||||
return self.traffic_policies[endpoint]
|
||||
|
||||
async def _start_backend_worker(self, backend_tag, replica_tag):
|
||||
"""Creates a backend worker and waits for it to start up.
|
||||
|
||||
Assumes that the backend configuration has already been registered
|
||||
in self.backends.
|
||||
"""
|
||||
logger.debug("Starting worker '{}' for backend '{}'.".format(
|
||||
replica_tag, backend_tag))
|
||||
worker_creator, init_args, config_dict = self.backends[backend_tag]
|
||||
# TODO(edoakes): just store the BackendConfig in self.backends.
|
||||
backend_config = BackendConfig(**config_dict)
|
||||
kwargs = backend_config.get_actor_creation_args()
|
||||
|
||||
worker_handle = async_retryable(ray.remote(worker_creator)).options(
|
||||
detached=True,
|
||||
name=replica_tag,
|
||||
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION,
|
||||
**kwargs).remote(backend_tag, replica_tag, init_args)
|
||||
# TODO(edoakes): we should probably have a timeout here.
|
||||
await worker_handle.ready.remote()
|
||||
return worker_handle
|
||||
|
||||
async def _start_pending_replicas(self):
|
||||
"""Starts the pending backend replicas in self.replicas_to_start.
|
||||
|
||||
Starts the worker, then pushes an update to the router to add it to
|
||||
the proper backend. If the worker has already been started, only
|
||||
updates the router.
|
||||
|
||||
Clears self.replicas_to_start.
|
||||
"""
|
||||
for backend_tag, replicas_to_create in self.replicas_to_start.items():
|
||||
for replica_tag in replicas_to_create:
|
||||
# NOTE(edoakes): the replicas may already be created if we
|
||||
# failed after creating them but before writing a checkpoint.
|
||||
try:
|
||||
worker_handle = ray.util.get_actor(replica_tag)
|
||||
except ValueError:
|
||||
worker_handle = await self._start_backend_worker(
|
||||
backend_tag, replica_tag)
|
||||
|
||||
self.replicas[backend_tag].append(replica_tag)
|
||||
self.workers[backend_tag][replica_tag] = worker_handle
|
||||
|
||||
# Register the worker with the router.
|
||||
await self.router.add_new_worker.remote(
|
||||
backend_tag, replica_tag, worker_handle)
|
||||
|
||||
# Register the worker with the metric monitor.
|
||||
self.metric_monitor.add_target.remote(worker_handle)
|
||||
|
||||
self.replicas_to_start.clear()
|
||||
|
||||
async def _stop_pending_replicas(self):
|
||||
"""Stops the pending backend replicas in self.replicas_to_stop.
|
||||
|
||||
Stops workers by telling the router to remove them.
|
||||
|
||||
Clears self.replicas_to_stop.
|
||||
"""
|
||||
for backend_tag, replicas_to_stop in self.replicas_to_stop.items():
|
||||
for replica_tag in replicas_to_stop:
|
||||
# NOTE(edoakes): the replicas may already be stopped if we
|
||||
# failed after stopping them but before writing a checkpoint.
|
||||
try:
|
||||
# Remove the replica from router.
|
||||
# This will also submit __ray_terminate__ on the worker.
|
||||
# NOTE(edoakes): we currently need to kill the worker from
|
||||
# the router to guarantee that the router won't submit any
|
||||
# more requests to it.
|
||||
await self.router.remove_worker.remote(
|
||||
backend_tag, replica_tag)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self.replicas_to_stop.clear()
|
||||
|
||||
def _scale_replicas(self, backend_tag, num_replicas):
|
||||
"""Scale the given backend to the number of replicas.
|
||||
|
||||
This requires the master actor to be an async actor because we wait
|
||||
synchronously for backends to start up and they may make calls into
|
||||
the master actor while initializing (e.g., by calling get_handle()).
|
||||
NOTE: this does not actually start or stop the replicas, but instead
|
||||
adds the intention to start/stop them to self.workers_to_start and
|
||||
self.workers_to_stop. The caller is responsible for then first writing
|
||||
a checkpoint and then actually starting/stopping the intended replicas.
|
||||
This avoids inconsistencies with starting/stopping a worker and then
|
||||
crashing before writing a checkpoint.
|
||||
"""
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
logger.debug("Scaling backend '{}' to {} replicas".format(
|
||||
backend_tag, num_replicas))
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert num_replicas >= 0, ("Number of replicas must be"
|
||||
" greater than or equal to 0.")
|
||||
|
||||
current_num_replicas = len(self._list_replicas(backend_tag))
|
||||
current_num_replicas = len(self.replicas[backend_tag])
|
||||
delta_num_replicas = num_replicas - current_num_replicas
|
||||
|
||||
if delta_num_replicas > 0:
|
||||
logger.debug("Adding {} replicas to backend {}".format(
|
||||
delta_num_replicas, backend_tag))
|
||||
for _ in range(delta_num_replicas):
|
||||
await self._start_backend_replica(backend_tag)
|
||||
replica_tag = "{}#{}".format(backend_tag, get_random_letters())
|
||||
self.replicas_to_start[backend_tag].append(replica_tag)
|
||||
|
||||
elif delta_num_replicas < 0:
|
||||
logger.debug("Removing {} replicas from backend {}".format(
|
||||
-delta_num_replicas, backend_tag))
|
||||
assert len(self.replicas[backend_tag]) >= delta_num_replicas
|
||||
for _ in range(-delta_num_replicas):
|
||||
await self._remove_backend_replica(backend_tag)
|
||||
replica_tag = self.replicas[backend_tag].pop()
|
||||
if len(self.replicas[backend_tag]) == 0:
|
||||
del self.replicas[backend_tag]
|
||||
del self.workers[backend_tag][replica_tag]
|
||||
if len(self.workers[backend_tag]) == 0:
|
||||
del self.workers[backend_tag]
|
||||
|
||||
async def get_backend_worker_config(self):
|
||||
return self.get_router()
|
||||
|
||||
async def _start_backend_replica(self, backend_tag):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
|
||||
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
|
||||
|
||||
# Register the worker in the DB.
|
||||
# TODO(edoakes): we should guarantee that if calls to the master
|
||||
# succeed, the cluster state has changed and if they fail, it hasn't.
|
||||
# Once we have master actor fault tolerance, this breaks that guarantee
|
||||
|
||||
# because this method could fail after writing the replica to the DB.
|
||||
self.backend_table.add_replica(backend_tag, replica_tag)
|
||||
|
||||
# Fetch the info to start the replica from the backend table.
|
||||
backend_actor = ray.remote(
|
||||
self.backend_table.get_backend_creator(backend_tag))
|
||||
backend_config_dict = self.backend_table.get_info(backend_tag)
|
||||
backend_config = BackendConfig(**backend_config_dict)
|
||||
init_args = [
|
||||
backend_tag, replica_tag,
|
||||
self.backend_table.get_init_args(backend_tag)
|
||||
]
|
||||
kwargs = backend_config.get_actor_creation_args(init_args)
|
||||
kwargs[
|
||||
"max_reconstructions"] = ray.ray_constants.INFINITE_RECONSTRUCTION
|
||||
|
||||
# Start the worker.
|
||||
worker_handle = backend_actor._remote(**kwargs)
|
||||
self.workers[backend_tag][replica_tag] = worker_handle
|
||||
|
||||
# Wait for the worker to start up.
|
||||
await worker_handle.ready.remote()
|
||||
|
||||
[router] = self.get_router()
|
||||
await router.add_new_worker.remote(backend_tag, worker_handle)
|
||||
|
||||
# Register the worker with the metric monitor.
|
||||
self.get_metric_monitor()[0].add_target.remote(worker_handle)
|
||||
|
||||
async def _remove_backend_replica(self, backend_tag):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert (len(self._list_replicas(backend_tag)) >
|
||||
0), "Tried to remove replica from empty backend ({}).".format(
|
||||
backend_tag)
|
||||
|
||||
replica_tag = self.backend_table.remove_replica(backend_tag)
|
||||
assert backend_tag in self.workers
|
||||
assert replica_tag in self.workers[backend_tag]
|
||||
replica_handle = self.workers[backend_tag].pop(replica_tag)
|
||||
if len(self.workers[backend_tag]) == 0:
|
||||
del self.workers[backend_tag]
|
||||
|
||||
# Remove the replica from metric monitor.
|
||||
[monitor] = self.get_metric_monitor()
|
||||
await monitor.remove_target.remote(replica_handle)
|
||||
|
||||
# Remove the replica from router.
|
||||
# This will also destroy the actor handle.
|
||||
[router] = self.get_router()
|
||||
await router.remove_worker.remote(backend_tag, replica_handle)
|
||||
self.replicas_to_stop[backend_tag].append(replica_tag)
|
||||
|
||||
def get_all_worker_handles(self):
|
||||
"""Fetched by the router on startup."""
|
||||
return self.workers
|
||||
|
||||
def get_all_endpoints(self):
|
||||
return expand(
|
||||
self.route_table.list_service(include_headless=True).values())
|
||||
|
||||
def get_all_routes(self):
|
||||
return expand(self.route_table.list_service().keys())
|
||||
|
||||
def get_all_backends(self):
|
||||
return self.backend_table.list_backends()
|
||||
"""Used for validation by the API client."""
|
||||
return [endpoint for endpoint, methods in self.routes.values()]
|
||||
|
||||
async def set_traffic(self, endpoint_name, traffic_policy_dictionary):
|
||||
assert endpoint_name in self.get_all_endpoints(), \
|
||||
"Attempted to assign traffic for an endpoint '{}'" \
|
||||
" that is not registered.".format(endpoint_name)
|
||||
"""Sets the traffic policy for the specified endpoint."""
|
||||
async with self.write_lock:
|
||||
assert endpoint_name in self.get_all_endpoints(), \
|
||||
"Attempted to assign traffic for an endpoint '{}'" \
|
||||
" that is not registered.".format(endpoint_name)
|
||||
|
||||
assert isinstance(traffic_policy_dictionary,
|
||||
dict), "Traffic policy must be dictionary"
|
||||
prob = 0
|
||||
existing_backends = set(self.get_all_backends())
|
||||
for backend, weight in traffic_policy_dictionary.items():
|
||||
prob += weight
|
||||
assert backend in existing_backends, \
|
||||
"Attempted to assign traffic to a backend '{}' that " \
|
||||
"is not registered.".format(backend)
|
||||
assert isinstance(traffic_policy_dictionary,
|
||||
dict), "Traffic policy must be dictionary"
|
||||
prob = 0
|
||||
for backend, weight in traffic_policy_dictionary.items():
|
||||
prob += weight
|
||||
assert backend in self.backends, \
|
||||
"Attempted to assign traffic to a backend '{}' that " \
|
||||
"is not registered.".format(backend)
|
||||
|
||||
assert np.isclose(
|
||||
prob, 1, atol=0.02
|
||||
), "weights must sum to 1, currently they sum to {}".format(prob)
|
||||
assert np.isclose(
|
||||
prob, 1, atol=0.02
|
||||
), "weights must sum to 1, currently they sum to {}".format(prob)
|
||||
|
||||
self.policy_table.register_traffic_policy(endpoint_name,
|
||||
traffic_policy_dictionary)
|
||||
[router] = self.get_router()
|
||||
self.traffic_policies[endpoint_name] = traffic_policy_dictionary
|
||||
|
||||
await router.set_traffic.remote(endpoint_name,
|
||||
traffic_policy_dictionary)
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
await self.router.set_traffic.remote(endpoint_name,
|
||||
traffic_policy_dictionary)
|
||||
|
||||
async def create_endpoint(self, route, endpoint_name, methods):
|
||||
err_prefix = "Cannot create endpoint. "
|
||||
assert route not in self.get_all_routes(), \
|
||||
"{} Route '{}' is already registered.".format(err_prefix, route)
|
||||
assert endpoint_name not in self.get_all_endpoints(), \
|
||||
"{} Endpoint '{}' is already registered.".format(err_prefix,
|
||||
endpoint_name)
|
||||
self.route_table.register_service(
|
||||
route, endpoint_name, methods=methods)
|
||||
[http_proxy] = self.get_http_proxy()
|
||||
async def create_endpoint(self, route, endpoint, methods):
|
||||
"""Create a new endpoint with the specified route and methods.
|
||||
|
||||
await http_proxy.set_route_table.remote(
|
||||
self.route_table.list_service(
|
||||
include_methods=True, include_headless=False))
|
||||
If the route is None, this is a "headless" endpoint that will not
|
||||
be added to the HTTP proxy (can only be accessed via a handle).
|
||||
"""
|
||||
async with self.write_lock:
|
||||
# If this is a headless service with no route, key the endpoint
|
||||
# based on its name.
|
||||
# TODO(edoakes): we should probably just store routes and endpoints
|
||||
# separately.
|
||||
if route is None:
|
||||
route = endpoint
|
||||
|
||||
# TODO(edoakes): move this to client side.
|
||||
err_prefix = "Cannot create endpoint."
|
||||
if route in self.routes:
|
||||
if self.routes[route] == (endpoint, methods):
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
"{} Route '{}' is already registered.".format(
|
||||
err_prefix, route))
|
||||
|
||||
if endpoint in self.get_all_endpoints():
|
||||
raise ValueError(
|
||||
"{} Endpoint '{}' is already registered.".format(
|
||||
err_prefix, endpoint))
|
||||
|
||||
logger.info(
|
||||
"Registering route {} to endpoint {} with methods {}.".format(
|
||||
route, endpoint, methods))
|
||||
|
||||
self.routes[route] = (endpoint, methods)
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
await self.http_proxy.set_route_table.remote(self.routes)
|
||||
|
||||
async def create_backend(self, backend_tag, backend_config, func_or_class,
|
||||
actor_init_args):
|
||||
backend_config_dict = dict(backend_config)
|
||||
backend_worker = create_backend_worker(func_or_class)
|
||||
"""Register a new backend under the specified tag."""
|
||||
async with self.write_lock:
|
||||
backend_config_dict = dict(backend_config)
|
||||
backend_worker = create_backend_worker(func_or_class)
|
||||
|
||||
assert backend_tag not in self.get_all_backends(), \
|
||||
"Cannot create backend '{}' because a backend" \
|
||||
"with that name already exists.".format(backend_tag)
|
||||
# Save creator that starts replicas, the arguments to be passed in,
|
||||
# and the configuration for the backends.
|
||||
self.backends[backend_tag] = (backend_worker, actor_init_args,
|
||||
backend_config_dict)
|
||||
|
||||
# Save creator which starts replicas.
|
||||
self.backend_table.register_backend(backend_tag, backend_worker)
|
||||
self._scale_replicas(backend_tag,
|
||||
backend_config_dict["num_replicas"])
|
||||
|
||||
# Save information about configurations needed to start the replicas.
|
||||
self.backend_table.register_info(backend_tag, backend_config_dict)
|
||||
# NOTE(edoakes): we must write a checkpoint before starting new
|
||||
# or pushing the updated config to avoid inconsistent state if we
|
||||
# crash while making the change.
|
||||
self._checkpoint()
|
||||
await self._start_pending_replicas()
|
||||
await self._stop_pending_replicas()
|
||||
|
||||
# Save the initial arguments needed by replicas.
|
||||
self.backend_table.save_init_args(backend_tag, actor_init_args)
|
||||
|
||||
# Set the backend config inside the router
|
||||
# (particularly for max-batch-size).
|
||||
[router] = self.get_router()
|
||||
await router.set_backend_config.remote(backend_tag,
|
||||
backend_config_dict)
|
||||
|
||||
await self.scale_replicas(backend_tag,
|
||||
backend_config_dict["num_replicas"])
|
||||
# Set the backend config inside the router
|
||||
# (particularly for max-batch-size).
|
||||
await self.router.set_backend_config.remote(
|
||||
backend_tag, backend_config_dict)
|
||||
|
||||
async def set_backend_config(self, backend_tag, backend_config):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(backend_config,
|
||||
BackendConfig), ("backend_config must be"
|
||||
" of instance BackendConfig")
|
||||
backend_config_dict = dict(backend_config)
|
||||
old_backend_config_dict = self.backend_table.get_info(backend_tag)
|
||||
"""Set the config for the specified backend."""
|
||||
async with self.write_lock:
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(backend_config,
|
||||
BackendConfig), ("backend_config must be"
|
||||
" of instance BackendConfig")
|
||||
backend_config_dict = dict(backend_config)
|
||||
backend_worker, init_args, old_backend_config_dict = self.backends[
|
||||
backend_tag]
|
||||
|
||||
if (not old_backend_config_dict["has_accept_batch_annotation"]
|
||||
and backend_config.max_batch_size is not None):
|
||||
raise batch_annotation_not_found
|
||||
if (not old_backend_config_dict["has_accept_batch_annotation"]
|
||||
and backend_config.max_batch_size is not None):
|
||||
raise batch_annotation_not_found
|
||||
|
||||
self.backend_table.register_info(backend_tag, backend_config_dict)
|
||||
self.backends[backend_tag] = (backend_worker, init_args,
|
||||
backend_config_dict)
|
||||
|
||||
# Inform the router about change in configuration
|
||||
# (particularly for setting max_batch_size).
|
||||
[router] = self.get_router()
|
||||
await router.set_backend_config.remote(backend_tag,
|
||||
backend_config_dict)
|
||||
# Restart replicas if there is a change in the backend config
|
||||
# related to restart_configs.
|
||||
need_to_restart_replicas = any(
|
||||
old_backend_config_dict[k] != backend_config_dict[k]
|
||||
for k in BackendConfig.restart_on_change_fields)
|
||||
if need_to_restart_replicas:
|
||||
# Kill all the replicas for restarting with new configurations.
|
||||
self._scale_replicas(backend_tag, 0)
|
||||
|
||||
# Restart replicas if there is a change in the backend config related
|
||||
# to restart_configs.
|
||||
need_to_restart_replicas = any(
|
||||
old_backend_config_dict[k] != backend_config_dict[k]
|
||||
for k in BackendConfig.restart_on_change_fields)
|
||||
if need_to_restart_replicas:
|
||||
# Kill all the replicas for restarting with new configurations.
|
||||
await self.scale_replicas(backend_tag, 0)
|
||||
# Scale the replicas with the new configuration.
|
||||
self._scale_replicas(backend_tag,
|
||||
backend_config_dict["num_replicas"])
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
await self.scale_replicas(backend_tag,
|
||||
backend_config_dict["num_replicas"])
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
|
||||
# Inform the router about change in configuration
|
||||
# (particularly for setting max_batch_size).
|
||||
await self.router.set_backend_config.remote(
|
||||
backend_tag, backend_config_dict)
|
||||
|
||||
await self._start_pending_replicas()
|
||||
await self._stop_pending_replicas()
|
||||
|
||||
def get_backend_config(self, backend_tag):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
"""Get the current config for the specified backend."""
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
backend_config_dict = self.backend_table.get_info(backend_tag)
|
||||
return BackendConfig(**backend_config_dict)
|
||||
return BackendConfig(**self.backends[backend_tag][2])
|
||||
|
||||
@@ -24,9 +24,6 @@ class MetricMonitor:
|
||||
self.gc_window_seconds = gc_window_seconds
|
||||
self.latest_gc_time = time.time()
|
||||
|
||||
def is_ready(self):
|
||||
return True
|
||||
|
||||
def add_target(self, target_handle):
|
||||
hex_id = target_handle._actor_id.hex()
|
||||
self.actor_handles[hex_id] = target_handle
|
||||
@@ -153,5 +150,8 @@ class MetricMonitor:
|
||||
@ray.remote(num_cpus=0)
|
||||
def start_metric_monitor_loop(monitor_handle, duration_s=5):
|
||||
while True:
|
||||
ray.get(monitor_handle.scrape.remote())
|
||||
time.sleep(duration_s)
|
||||
try:
|
||||
ray.get(monitor_handle.scrape.remote())
|
||||
except ray.exceptions.RayActorError:
|
||||
pass
|
||||
|
||||
+61
-26
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
import time
|
||||
from typing import DefaultDict, List
|
||||
|
||||
# Note on choosing blist instead of stdlib heapq
|
||||
@@ -13,7 +14,7 @@ import blist
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.serve.utils import logger
|
||||
from ray.serve.utils import logger, retry_actor_failures
|
||||
|
||||
|
||||
class Query:
|
||||
@@ -138,6 +139,8 @@ class Router:
|
||||
self.traffic = defaultdict(dict)
|
||||
# backend_name -> backend_config
|
||||
self.backend_info = dict()
|
||||
# replica tag -> worker_handle
|
||||
self.replicas = dict()
|
||||
|
||||
# -- Synchronization -- #
|
||||
|
||||
@@ -150,15 +153,28 @@ class Router:
|
||||
# batching polcies.
|
||||
self.flush_lock = asyncio.Lock()
|
||||
|
||||
# Fetch the worker handles from the master actor. We use a "pull-based"
|
||||
# approach instead of pushing them from the master so that the router
|
||||
# can transparently recover from failure.
|
||||
# Fetch the worker handles, traffic policies, and backend configs from
|
||||
# the master actor. We use a "pull-based" approach instead of pushing
|
||||
# them from the master so that the router can transparently recover
|
||||
# from failure.
|
||||
ray.serve.init()
|
||||
master_actor = ray.serve.api._get_master_actor()
|
||||
backend_dict = ray.get(master_actor.get_all_worker_handles.remote())
|
||||
for backend, replica_dict in backend_dict.items():
|
||||
for worker in replica_dict.values():
|
||||
await self.add_new_worker(backend, worker)
|
||||
|
||||
traffic_policies = retry_actor_failures(
|
||||
master_actor.get_traffic_policies)
|
||||
for endpoint, traffic_policy in traffic_policies.items():
|
||||
await self.set_traffic(endpoint, traffic_policy)
|
||||
|
||||
backend_dict = retry_actor_failures(
|
||||
master_actor.get_all_worker_handles)
|
||||
for backend_tag, replica_dict in backend_dict.items():
|
||||
for replica_tag, worker in replica_dict.items():
|
||||
await self.add_new_worker(backend_tag, replica_tag, worker)
|
||||
|
||||
backend_configs = retry_actor_failures(
|
||||
master_actor.get_backend_configs)
|
||||
for backend, backend_config_dict in backend_configs.items():
|
||||
await self.set_backend_config(backend, backend_config_dict)
|
||||
|
||||
def is_ready(self):
|
||||
return True
|
||||
@@ -199,28 +215,43 @@ class Router:
|
||||
result = await query.async_future
|
||||
return result
|
||||
|
||||
async def add_new_worker(self, backend, worker_handle):
|
||||
logger.debug("New worker added for backend '{}'".format(backend))
|
||||
await self.mark_worker_idle(backend, worker_handle)
|
||||
async def add_new_worker(self, backend_tag, replica_tag, worker_handle):
|
||||
backend_replica_tag = backend_tag + ":" + replica_tag
|
||||
if backend_replica_tag in self.replicas:
|
||||
return
|
||||
self.replicas[backend_replica_tag] = worker_handle
|
||||
|
||||
async def mark_worker_idle(self, backend, worker_handle):
|
||||
await self.worker_queues[backend].put(worker_handle)
|
||||
logger.debug("New worker added for backend '{}'".format(backend_tag))
|
||||
# await worker_handle.ready.remote()
|
||||
await self.mark_worker_idle(backend_tag, backend_replica_tag)
|
||||
|
||||
async def mark_worker_idle(self, backend_tag, backend_replica_tag):
|
||||
if backend_replica_tag not in self.replicas:
|
||||
return
|
||||
|
||||
await self.worker_queues[backend_tag].put(backend_replica_tag)
|
||||
await self.flush()
|
||||
|
||||
async def remove_worker(self, backend, worker_handle):
|
||||
async def remove_worker(self, backend_tag, replica_tag):
|
||||
backend_replica_tag = backend_tag + ":" + replica_tag
|
||||
if backend_replica_tag not in self.replicas:
|
||||
return
|
||||
worker_handle = self.replicas.pop(backend_replica_tag)
|
||||
|
||||
# We need this lock because we modify worker_queue here.
|
||||
async with self.flush_lock:
|
||||
old_queue = self.worker_queues[backend]
|
||||
old_queue = self.worker_queues[backend_tag]
|
||||
new_queue = asyncio.Queue()
|
||||
target_id = worker_handle._actor_id
|
||||
|
||||
while not old_queue.empty():
|
||||
worker_handle = await old_queue.get()
|
||||
if worker_handle._actor_id != target_id:
|
||||
await new_queue.put(worker_handle)
|
||||
curr_tag = await old_queue.get()
|
||||
if curr_tag != backend_replica_tag:
|
||||
await new_queue.put(curr_tag)
|
||||
|
||||
self.worker_queues[backend] = new_queue
|
||||
# TODO: consider awaiting this on a timeout or using ray.kill().
|
||||
self.worker_queues[backend_tag] = new_queue
|
||||
# We need to terminate the worker here instead of from the master
|
||||
# so we can guarantee that the router won't submit any more tasks
|
||||
# on it.
|
||||
worker_handle.__ray_terminate__.remote()
|
||||
|
||||
async def link(self, service, backend):
|
||||
@@ -297,11 +328,15 @@ class Router:
|
||||
await self._assign_query_to_worker(
|
||||
backend, buffer_queue, worker_queue, max_batch_size)
|
||||
|
||||
async def _do_query(self, backend, worker, req):
|
||||
async def _do_query(self, backend, backend_replica_tag, req):
|
||||
# If the worker died, this will be a RayActorError. Just return it and
|
||||
# let the HTTP proxy handle the retry logic.
|
||||
logger.debug("Sending query to replica:" + backend_replica_tag)
|
||||
start = time.time()
|
||||
worker = self.replicas[backend_replica_tag]
|
||||
result = await worker.handle_request.remote(req)
|
||||
await self.mark_worker_idle(backend, worker)
|
||||
await self.mark_worker_idle(backend, backend_replica_tag)
|
||||
logger.debug("Got result in {:.2f}s", time.time() - start)
|
||||
return result
|
||||
|
||||
async def _assign_query_to_worker(self,
|
||||
@@ -311,11 +346,11 @@ class Router:
|
||||
max_batch_size=None):
|
||||
|
||||
while len(buffer_queue) and worker_queue.qsize():
|
||||
worker = await worker_queue.get()
|
||||
backend_replica_tag = await worker_queue.get()
|
||||
if max_batch_size is None: # No batching
|
||||
request = buffer_queue.pop(0)
|
||||
future = asyncio.get_event_loop().create_task(
|
||||
self._do_query(backend, worker, request))
|
||||
self._do_query(backend, backend_replica_tag, request))
|
||||
# chaining satisfies request.async_future with future result.
|
||||
asyncio.futures._chain_future(future, request.async_future)
|
||||
else:
|
||||
@@ -331,7 +366,7 @@ class Router:
|
||||
|
||||
for group in requests_group.values():
|
||||
future = asyncio.get_event_loop().create_task(
|
||||
self._do_query(backend, worker, group))
|
||||
self._do_query(backend, backend_replica_tag, group))
|
||||
future.add_done_callback(
|
||||
_make_future_unwrapper(
|
||||
client_futures=[req.async_future for req in group],
|
||||
|
||||
@@ -6,6 +6,12 @@ import pytest
|
||||
import ray
|
||||
from ray import serve
|
||||
|
||||
# TODO(edoakes): the failure tests currently fail with the GCS service enabled.
|
||||
os.environ["RAY_GCS_SERVICE_ENABLED"] = "false"
|
||||
|
||||
if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False):
|
||||
serve.master._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def serve_instance():
|
||||
|
||||
@@ -75,24 +75,17 @@ def test_no_route(serve_instance):
|
||||
assert result == 1
|
||||
|
||||
|
||||
def test_reject_duplicate_backend_tag(serve_instance):
|
||||
backend_name = "foo"
|
||||
serve.create_backend(lambda foo: foo, backend_name)
|
||||
with pytest.raises(AssertionError):
|
||||
serve.create_backend(lambda foo: foo, backend_name)
|
||||
|
||||
|
||||
def test_reject_duplicate_route(serve_instance):
|
||||
route = "/foo"
|
||||
serve.create_endpoint("bar", route=route)
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(ValueError):
|
||||
serve.create_endpoint("foo", route=route)
|
||||
|
||||
|
||||
def test_reject_duplicate_endpoint(serve_instance):
|
||||
endpoint_name = "foo"
|
||||
serve.create_endpoint(endpoint_name, route="/ok")
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(ValueError):
|
||||
serve.create_endpoint(endpoint_name, route="/different")
|
||||
|
||||
|
||||
@@ -102,9 +95,9 @@ def test_set_traffic_missing_data(serve_instance):
|
||||
serve.create_endpoint(endpoint_name)
|
||||
serve.create_backend(lambda: 5, backend_name)
|
||||
with pytest.raises(AssertionError):
|
||||
serve.set_traffic(endpoint_name, {"nonexistant_backend": 1.0})
|
||||
serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0})
|
||||
with pytest.raises(AssertionError):
|
||||
serve.set_traffic("nonexistant_endpoint_name", {backend_name: 1.0})
|
||||
serve.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0})
|
||||
|
||||
|
||||
def test_scaling_replicas(serve_instance):
|
||||
|
||||
@@ -13,15 +13,15 @@ from ray.serve.backend_config import BackendConfig
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def setup_worker(name, func_or_class, router_handle, init_args=None):
|
||||
def setup_worker(name, func_or_class, init_args=None):
|
||||
if init_args is None:
|
||||
init_args = ()
|
||||
|
||||
@ray.remote
|
||||
class WorkerActor:
|
||||
def __init__(self, router_handle):
|
||||
def __init__(self):
|
||||
self.worker = create_backend_worker(func_or_class)(
|
||||
name, name + ":tag", init_args, router_handle=router_handle[0])
|
||||
name, name + ":tag", init_args)
|
||||
|
||||
def ready(self):
|
||||
pass
|
||||
@@ -29,13 +29,10 @@ def setup_worker(name, func_or_class, router_handle, init_args=None):
|
||||
def get_metrics(self):
|
||||
return self.worker.get_metrics()
|
||||
|
||||
def run(self):
|
||||
self.worker.backend.mark_idle_in_router()
|
||||
|
||||
async def handle_request(self, *args, **kwargs):
|
||||
return await self.worker.handle_request(*args, **kwargs)
|
||||
|
||||
worker = WorkerActor.remote([router_handle])
|
||||
worker = WorkerActor.remote()
|
||||
ray.get(worker.ready.remote())
|
||||
return worker
|
||||
|
||||
@@ -54,8 +51,8 @@ async def test_runner_actor(serve_instance):
|
||||
CONSUMER_NAME = "runner"
|
||||
PRODUCER_NAME = "prod"
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, echo, q)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, worker)
|
||||
worker = setup_worker(CONSUMER_NAME, echo)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
|
||||
@@ -79,8 +76,8 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
def __call__(self, flask_request, i=None):
|
||||
return i + self.increment
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, MyAdder, q, init_args=(3, ))
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, worker)
|
||||
worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
|
||||
@@ -101,8 +98,8 @@ async def test_task_runner_check_context(serve_instance):
|
||||
CONSUMER_NAME = "runner"
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, echo, q)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, worker)
|
||||
worker = setup_worker(CONSUMER_NAME, echo)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
|
||||
@@ -125,8 +122,8 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
CONSUMER_NAME = "runner"
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, NonBatcher, q)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, worker)
|
||||
worker = setup_worker(CONSUMER_NAME, NonBatcher)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
|
||||
@@ -160,7 +157,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
CONSUMER_NAME = "runner"
|
||||
PRODUCER_NAME = "producer"
|
||||
|
||||
worker = setup_worker(CONSUMER_NAME, Batcher, q)
|
||||
worker = setup_worker(CONSUMER_NAME, Batcher)
|
||||
|
||||
await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
await q.set_backend_config.remote(
|
||||
@@ -174,7 +171,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
|
||||
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
|
||||
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, worker)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
gathered = await asyncio.gather(*futures)
|
||||
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
|
||||
|
||||
@@ -19,6 +19,57 @@ def request_with_retries(endpoint, timeout=30):
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def test_master_failure(serve_instance):
|
||||
serve.init()
|
||||
serve.create_endpoint("master_failure", "/master_failure")
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
|
||||
serve.create_backend(function, "master_failure:v1")
|
||||
serve.set_traffic("master_failure", {"master_failure:v1": 1.0})
|
||||
|
||||
assert request_with_retries("/master_failure", timeout=1).text == "hello1"
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/master_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
ray.kill(serve.api._get_master_actor())
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/master_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
ray.kill(serve.api._get_master_actor())
|
||||
|
||||
serve.create_backend(function, "master_failure:v2")
|
||||
serve.set_traffic("master_failure", {"master_failure:v2": 1.0})
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/master_failure", timeout=30)
|
||||
assert response.text == "hello2"
|
||||
|
||||
def function():
|
||||
return "hello3"
|
||||
|
||||
ray.kill(serve.api._get_master_actor())
|
||||
serve.create_endpoint("master_failure_2", "/master_failure_2")
|
||||
ray.kill(serve.api._get_master_actor())
|
||||
serve.create_backend(function, "master_failure_2")
|
||||
ray.kill(serve.api._get_master_actor())
|
||||
serve.set_traffic("master_failure_2", {"master_failure_2": 1.0})
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/master_failure", timeout=30)
|
||||
assert response.text == "hello2"
|
||||
response = request_with_retries("/master_failure_2", timeout=30)
|
||||
assert response.text == "hello3"
|
||||
|
||||
|
||||
def _kill_http_proxy():
|
||||
[http_proxy] = ray.get(
|
||||
serve.api._get_master_actor().get_http_proxy.remote())
|
||||
@@ -27,7 +78,7 @@ def _kill_http_proxy():
|
||||
|
||||
def test_http_proxy_failure(serve_instance):
|
||||
serve.init()
|
||||
serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])
|
||||
serve.create_endpoint("proxy_failure", "/proxy_failure")
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
@@ -35,7 +86,7 @@ def test_http_proxy_failure(serve_instance):
|
||||
serve.create_backend(function, "proxy_failure:v1")
|
||||
serve.set_traffic("proxy_failure", {"proxy_failure:v1": 1.0})
|
||||
|
||||
assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"
|
||||
assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1"
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/proxy_failure", timeout=30)
|
||||
@@ -61,7 +112,7 @@ def _kill_router():
|
||||
|
||||
def test_router_failure(serve_instance):
|
||||
serve.init()
|
||||
serve.create_endpoint("router_failure", "/router_failure", methods=["GET"])
|
||||
serve.create_endpoint("router_failure", "/router_failure")
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
@@ -77,6 +128,10 @@ def test_router_failure(serve_instance):
|
||||
|
||||
_kill_router()
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/router_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
@@ -99,7 +154,7 @@ def _get_worker_handles(backend):
|
||||
# serving requests.
|
||||
def test_worker_restart(serve_instance):
|
||||
serve.init()
|
||||
serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])
|
||||
serve.create_endpoint("worker_failure", "/worker_failure")
|
||||
|
||||
class Worker1:
|
||||
def __call__(self):
|
||||
@@ -109,7 +164,7 @@ def test_worker_restart(serve_instance):
|
||||
serve.set_traffic("worker_failure", {"worker_failure:v1": 1.0})
|
||||
|
||||
# Get the PID of the worker.
|
||||
old_pid = request_with_retries("/worker_failure", timeout=0.1).text
|
||||
old_pid = request_with_retries("/worker_failure", timeout=1).text
|
||||
|
||||
# Kill the worker.
|
||||
handles = _get_worker_handles("worker_failure:v1")
|
||||
@@ -131,8 +186,7 @@ def test_worker_restart(serve_instance):
|
||||
def test_worker_replica_failure(serve_instance):
|
||||
serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
|
||||
serve.init()
|
||||
serve.create_endpoint(
|
||||
"replica_failure", "/replica_failure", methods=["GET"])
|
||||
serve.create_endpoint("replica_failure", "/replica_failure")
|
||||
|
||||
class Worker:
|
||||
# Assumes that two replicas are started. Will hang forever in the
|
||||
@@ -169,8 +223,7 @@ def test_worker_replica_failure(serve_instance):
|
||||
# Wait until both replicas have been started.
|
||||
responses = set()
|
||||
while len(responses) == 1:
|
||||
responses.add(
|
||||
request_with_retries("/replica_failure", timeout=0.1).text)
|
||||
responses.add(request_with_retries("/replica_failure", timeout=1).text)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Kill one of the replicas.
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
curr_dir = Path(__file__).parent
|
||||
test_paths = curr_dir.rglob("test_*.py")
|
||||
sorted_path = sorted(map(lambda path: str(path.absolute()), test_paths))
|
||||
serve_tests_files = list(sorted_path)
|
||||
|
||||
print("Testing the following files")
|
||||
for test_file in serve_tests_files:
|
||||
print("->", test_file.split("/")[-1])
|
||||
|
||||
print("Setting RAY_SERVE_INTENTIONALLY_CRASH=1")
|
||||
os.environ["RAY_SERVE_INTENTIONALLY_CRASH"] = "1"
|
||||
|
||||
sys.exit(pytest.main(["-v", "-s"] + serve_tests_files))
|
||||
@@ -6,15 +6,15 @@ from ray import serve
|
||||
|
||||
def test_nonblocking():
|
||||
serve.init()
|
||||
serve.create_endpoint("endpoint", "/api")
|
||||
serve.create_endpoint("nonblocking", "/nonblocking")
|
||||
|
||||
def function(flask_request):
|
||||
return {"method": flask_request.method}
|
||||
|
||||
serve.create_backend(function, "echo:v1")
|
||||
serve.set_traffic("endpoint", {"echo:v1": 1.0})
|
||||
serve.create_backend(function, "nonblocking:v1")
|
||||
serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0})
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
|
||||
resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"]
|
||||
assert resp == "GET"
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ def make_task_runner_mock():
|
||||
def get_all_calls(self):
|
||||
return self.queries
|
||||
|
||||
def ready(self):
|
||||
pass
|
||||
|
||||
return TaskRunnerMock.remote()
|
||||
|
||||
|
||||
@@ -39,8 +42,9 @@ def task_runner_mock_actor():
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = RandomPolicyQueueActor.remote()
|
||||
q.link.remote("svc", "backend")
|
||||
q.add_new_worker.remote("backend", task_runner_mock_actor)
|
||||
q.link.remote("svc", "backend-single-prod")
|
||||
q.add_new_worker.remote("backend-single-prod", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
|
||||
# Make sure we get the request result back
|
||||
result = await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
@@ -54,7 +58,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
q = RandomPolicyQueueActor.remote()
|
||||
await q.link.remote("svc", "backend")
|
||||
await q.link.remote("svc", "backend-slo")
|
||||
|
||||
all_request_sent = []
|
||||
for i in range(10):
|
||||
@@ -63,7 +67,8 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
q.enqueue_request.remote(
|
||||
RequestMetadata("svc", None, relative_slo_ms=slo_ms), i))
|
||||
|
||||
await q.add_new_worker.remote("backend", task_runner_mock_actor)
|
||||
await q.add_new_worker.remote("backend-slo", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
|
||||
await asyncio.gather(*all_request_sent)
|
||||
|
||||
@@ -78,14 +83,16 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = RandomPolicyQueueActor.remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-1": 1})
|
||||
await q.add_new_worker.remote("backend-1", task_runner_mock_actor)
|
||||
await q.set_traffic.remote("svc", {"backend-alter": 1})
|
||||
await q.add_new_worker.remote("backend-alter", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args[0] == 1
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-2": 1})
|
||||
await q.add_new_worker.remote("backend-2", task_runner_mock_actor)
|
||||
await q.set_traffic.remote("svc", {"backend-alter-2": 1})
|
||||
await q.add_new_worker.remote("backend-alter-2", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args[0] == 2
|
||||
@@ -94,10 +101,13 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = RandomPolicyQueueActor.remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-split": 0.5,
|
||||
"backend-split-2": 0.5
|
||||
})
|
||||
runner_1, runner_2 = [make_task_runner_mock() for _ in range(2)]
|
||||
await q.add_new_worker.remote("backend-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-2", runner_2)
|
||||
await q.add_new_worker.remote("backend-split", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2)
|
||||
|
||||
# assume 50% split, the probability of all 20 requests goes to a
|
||||
# single queue is 0.5^20 ~ 1-6
|
||||
@@ -114,13 +124,13 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
async def test_round_robin(serve_instance, task_runner_mock_actor):
|
||||
q = RoundRobinPolicyQueueActor.remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
|
||||
await q.set_traffic.remote("svc", {"backend-rr": 0.5, "backend-rr-2": 0.5})
|
||||
runner_1, runner_2 = [make_task_runner_mock() for _ in range(2)]
|
||||
|
||||
# NOTE: this is the only difference between the
|
||||
# test_split_traffic_random and test_round_robin
|
||||
await q.add_new_worker.remote("backend-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-2", runner_2)
|
||||
await q.add_new_worker.remote("backend-rr", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-rr-2", "replica-1", runner_2)
|
||||
|
||||
for _ in range(20):
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
@@ -135,13 +145,16 @@ async def test_round_robin(serve_instance, task_runner_mock_actor):
|
||||
async def test_fixed_packing(serve_instance):
|
||||
packing_num = 4
|
||||
q = FixedPackingPolicyQueueActor.remote(packing_num=packing_num)
|
||||
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-fixed": 0.5,
|
||||
"backend-fixed-2": 0.5
|
||||
})
|
||||
|
||||
runner_1, runner_2 = (make_task_runner_mock() for _ in range(2))
|
||||
# both the backends will get equal number of queries
|
||||
# as it is packed round robin
|
||||
await q.add_new_worker.remote("backend-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-2", runner_2)
|
||||
await q.add_new_worker.remote("backend-fixed", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-fixed-2", "replica-1", runner_2)
|
||||
|
||||
for backend, runner in zip(["1", "2"], [runner_1, runner_2]):
|
||||
for _ in range(packing_num):
|
||||
@@ -158,20 +171,23 @@ async def test_power_of_two_choices(serve_instance):
|
||||
enqueue_futures = []
|
||||
|
||||
# First, fill the queue for backend-1 with 3 requests
|
||||
await q.set_traffic.remote("svc", {"backend-1": 1.0})
|
||||
await q.set_traffic.remote("svc", {"backend-pow2": 1.0})
|
||||
for _ in range(3):
|
||||
future = q.enqueue_request.remote(RequestMetadata("svc", None), "1")
|
||||
enqueue_futures.append(future)
|
||||
|
||||
# Then, add a new backend, this backend should be filled next
|
||||
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-pow2": 0.5,
|
||||
"backend-pow2-2": 0.5
|
||||
})
|
||||
for _ in range(2):
|
||||
future = q.enqueue_request.remote(RequestMetadata("svc", None), "2")
|
||||
enqueue_futures.append(future)
|
||||
|
||||
runner_1, runner_2 = (make_task_runner_mock() for _ in range(2))
|
||||
await q.add_new_worker.remote("backend-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-2", runner_2)
|
||||
await q.add_new_worker.remote("backend-pow2", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-pow2-2", "replica-1", runner_2)
|
||||
|
||||
await asyncio.gather(*enqueue_futures)
|
||||
|
||||
@@ -183,10 +199,10 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
@ray.remote
|
||||
class TestRandomPolicyQueueActor(RandomPolicyQueue):
|
||||
def worker_queue_size(self, backend):
|
||||
return self.worker_queues["backend"].qsize()
|
||||
return self.worker_queues["backend-remove"].qsize()
|
||||
|
||||
temp_actor = make_task_runner_mock()
|
||||
q = TestRandomPolicyQueueActor.remote()
|
||||
await q.add_new_worker.remote("backend", temp_actor)
|
||||
await q.remove_worker.remote("backend", temp_actor)
|
||||
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
|
||||
await q.remove_worker.remote("backend-remove", "replica-1")
|
||||
assert ray.get(q.worker_queue_size.remote("backend")) == 0
|
||||
|
||||
+71
-14
@@ -1,3 +1,6 @@
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
@@ -6,11 +9,11 @@ import time
|
||||
import io
|
||||
import os
|
||||
|
||||
import ray
|
||||
import requests
|
||||
from pygments import formatters, highlight, lexers
|
||||
from ray.serve.context import FakeFlaskRequest, TaskContext
|
||||
from ray.serve.http_util import build_flask_request
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
@@ -18,19 +21,7 @@ try:
|
||||
except ImportError:
|
||||
pydantic = None
|
||||
|
||||
|
||||
def expand(l):
|
||||
"""
|
||||
Implements a nested flattening of a list.
|
||||
Example:
|
||||
>>> serve.utils.expand([1,2,[3,4,5],6])
|
||||
[1,2,3,4,5,6]
|
||||
>>> serve.utils.expand(["a", ["b", "c"], "d", ["e", "f"]])
|
||||
["a", "b", "c", "d", "e", "f"]
|
||||
"""
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
[x if isinstance(x, list) else [x] for x in l]))
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
||||
def parse_request_item(request_item):
|
||||
@@ -116,3 +107,69 @@ def block_until_http_ready(http_endpoint, num_retries=5, backoff_time_s=1):
|
||||
|
||||
def get_random_letters(length=6):
|
||||
return "".join(random.choices(string.ascii_letters, k=length))
|
||||
|
||||
|
||||
def async_retryable(cls):
|
||||
"""Make all actor method invocations on the class retryable.
|
||||
|
||||
Note: This will retry actor_handle.method_name.remote(), but it must
|
||||
be invoked in an async context.
|
||||
|
||||
Usage:
|
||||
@ray.remote(max_reconstructions=10000)
|
||||
@async_retryable
|
||||
class A:
|
||||
pass
|
||||
"""
|
||||
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||||
|
||||
def decorate_with_retry(f):
|
||||
@wraps(f)
|
||||
async def retry_method(*args, **kwargs):
|
||||
start = time.time()
|
||||
while time.time() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S:
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except ray.exceptions.RayActorError:
|
||||
logger.warning(
|
||||
"Actor method '{}' failed, retrying after 100ms.".
|
||||
format(name))
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError("Timed out after {}s waiting for actor "
|
||||
"method '{}' to succeed.".format(
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S, name))
|
||||
|
||||
return retry_method
|
||||
|
||||
method.__ray_invocation_decorator__ = decorate_with_retry
|
||||
return cls
|
||||
|
||||
|
||||
def retry_actor_failures(f, *args, **kwargs):
|
||||
start = time.time()
|
||||
while time.time() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S:
|
||||
try:
|
||||
return ray.get(f.remote(*args, **kwargs))
|
||||
except ray.exceptions.RayActorError:
|
||||
logger.warning(
|
||||
"Actor method '{}' failed, retrying after 100ms".format(
|
||||
f._method_name))
|
||||
time.sleep(0.1)
|
||||
raise RuntimeError("Timed out after {}s waiting for actor "
|
||||
"method '{}' to succeed.".format(
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S, f._method_name))
|
||||
|
||||
|
||||
async def retry_actor_failures_async(f, *args, **kwargs):
|
||||
start = time.time()
|
||||
while time.time() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S:
|
||||
try:
|
||||
return await f.remote(*args, **kwargs)
|
||||
except ray.exceptions.RayActorError:
|
||||
logger.warning(
|
||||
"Actor method '{}' failed, retrying after 100ms".format(
|
||||
f._method_name))
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError("Timed out after {}s waiting for actor "
|
||||
"method '{}' to succeed.".format(
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S, f._method_name))
|
||||
|
||||
Reference in New Issue
Block a user