diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 348a01394..b818e06bb 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -13,11 +13,24 @@ py_library( py_test( name = "test_serve", size = "medium", - srcs = glob(["tests/*.py"], exclude=["tests/test_nonblocking.py"]), + srcs = glob(["tests/*.py"], + exclude=["tests/test_nonblocking.py", + "tests/test_master_crashes.py"]), tags = ["exclusive"], deps = [":serve_lib"], ) +# Runs test_api and test_failure with injected failures in the master actor. +py_test( + name = "test_master_crashes", + size = "medium", + srcs = glob(["tests/test_master_crashes.py", + "tests/test_api.py", + "tests/test_failure.py"], + exclude=["tests/test_nonblocking.py", + "tests/test_serve.py"]), +) + py_test( name = "echo_full", size = "small", diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 5d32b94b9..76fa12c31 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -10,7 +10,7 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, from ray.serve.master import ServeMaster from ray.serve.handle import RayServeHandle from ray.serve.kv_store_service import SQLiteKVStore -from ray.serve.utils import block_until_http_ready +from ray.serve.utils import block_until_http_ready, retry_actor_failures from ray.serve.exceptions import RayServeException, batch_annotation_not_found from ray.serve.backend_config import BackendConfig from ray.serve.policy import RoutePolicy @@ -139,14 +139,11 @@ def init( return SQLiteKVStore(namespace, db_path=kv_store_path) master_actor = ServeMaster.options( - detached=True, name=SERVE_MASTER_NAME).remote(kv_store_connector) - - ray.get( - master_actor.start_router.remote(queueing_policy.value, policy_kwargs)) - - ray.get(master_actor.start_metric_monitor.remote(gc_window_seconds)) - if start_server: - ray.get(master_actor.start_http_proxy.remote(http_host, http_port)) + detached=True, + name=SERVE_MASTER_NAME, + 
max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION, + ).remote(kv_store_connector, queueing_policy.value, policy_kwargs, + start_server, http_host, http_port, gc_window_seconds) if start_server and blocking: block_until_http_ready("http://{}:{}/-/routes".format( @@ -165,9 +162,8 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]): blocking (bool): If true, the function will wait for service to be registered before returning """ - ray.get( - master_actor.create_endpoint.remote(route, endpoint_name, - [m.upper() for m in methods])) + retry_actor_failures(master_actor.create_endpoint, route, endpoint_name, + [m.upper() for m in methods]) @_ensure_connected @@ -178,8 +174,8 @@ def set_backend_config(backend_tag, backend_config): backend_tag(str): A registered backend. backend_config(BackendConfig) : Desired backend configuration. """ - ray.get( - master_actor.set_backend_config.remote(backend_tag, backend_config)) + retry_actor_failures(master_actor.set_backend_config, backend_tag, + backend_config) @_ensure_connected @@ -189,7 +185,7 @@ def get_backend_config(backend_tag): Args: backend_tag(str): A registered backend. """ - return ray.get(master_actor.get_backend_config.remote(backend_tag)) + return retry_actor_failures(master_actor.get_backend_config, backend_tag) def _backend_accept_batch(func_or_class): @@ -240,9 +236,8 @@ def create_backend(func_or_class, if _backend_accept_batch(func_or_class): backend_config.has_accept_batch_annotation = True - ray.get( - master_actor.create_backend.remote(backend_tag, backend_config, - func_or_class, actor_init_args)) + retry_actor_failures(master_actor.create_backend, backend_tag, + backend_config, func_or_class, actor_init_args) @_ensure_connected @@ -261,9 +256,8 @@ def set_traffic(endpoint_name, traffic_policy_dictionary): traffic_policy_dictionary (dict): a dictionary maps backend names to their traffic weights. The weights must sum to 1. 
""" - ray.get( - master_actor.set_traffic.remote(endpoint_name, - traffic_policy_dictionary)) + retry_actor_failures(master_actor.set_traffic, endpoint_name, + traffic_policy_dictionary) @_ensure_connected @@ -286,11 +280,11 @@ def get_handle(endpoint_name, RayServeHandle """ if not missing_ok: - assert endpoint_name in ray.get( - master_actor.get_all_endpoints.remote()) + assert endpoint_name in retry_actor_failures( + master_actor.get_all_endpoints) return RayServeHandle( - ray.get(master_actor.get_router.remote())[0], + retry_actor_failures(master_actor.get_router)[0], endpoint_name, relative_slo_ms, absolute_slo_ms, @@ -309,5 +303,5 @@ def stat(percentiles=[50, 90, 95], The longest aggregation window must be shorter or equal to the gc_window_seconds. """ - [monitor] = ray.get(master_actor.get_metric_monitor.remote()) + [monitor] = retry_actor_failures(master_actor.get_metric_monitor) return ray.get(monitor.collect.remote(percentiles, agg_windows_seconds)) diff --git a/python/ray/serve/backend_config.py b/python/ray/serve/backend_config.py index 7bd6f53fa..6c5db6af2 100644 --- a/python/ray/serve/backend_config.py +++ b/python/ray/serve/backend_config.py @@ -55,9 +55,8 @@ class BackendConfig: key = "num_replicas" yield key, val - def get_actor_creation_args(self, init_args): + def get_actor_creation_args(self): ret_d = deepcopy(self.__dict__) for k in self._serve_configs: ret_d.pop(k) - ret_d["args"] = init_args return ret_d diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index d14263dbe..49f2a3d42 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -23,24 +23,14 @@ def create_backend_worker(func_or_class): assert False, "func_or_class must be function or class." 
class RayServeWrappedWorker(object): - def __init__(self, - backend_tag, - replica_tag, - init_args, - router_handle=None): + def __init__(self, backend_tag, replica_tag, init_args): serve.init() if is_function: _callable = func_or_class else: _callable = func_or_class(*init_args) - if router_handle is None: - master_actor = serve.api._get_master_actor() - [router_handle] = ray.get( - master_actor.get_backend_worker_config.remote()) - - self.backend = RayServeWorker(backend_tag, _callable, - router_handle, is_function) + self.backend = RayServeWorker(backend_tag, _callable, is_function) def get_metrics(self): return self.backend.get_metrics() @@ -76,10 +66,9 @@ def ensure_async(func): class RayServeWorker: """Handles requests with the provided callable.""" - def __init__(self, name, _callable, router_handle, is_function): + def __init__(self, name, _callable, is_function): self.name = name self.callable = _callable - self.router_handle = router_handle self.is_function = is_function self.error_counter = 0 diff --git a/python/ray/serve/constants.py b/python/ray/serve/constants.py index bc1faf2ad..a229c28df 100644 --- a/python/ray/serve/constants.py +++ b/python/ray/serve/constants.py @@ -1,6 +1,15 @@ #: Actor name used to register master actor SERVE_MASTER_NAME = "SERVE_MASTER_ACTOR" +#: Actor name used to register router actor +SERVE_ROUTER_NAME = "SERVE_ROUTER_ACTOR" + +#: Actor name used to register HTTP proxy actor +SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR" + +#: Actor name used to register metric monitor actor +SERVE_METRIC_MONITOR_NAME = "SERVE_METRIC_MONITOR_ACTOR" + #: HTTP Address DEFAULT_HTTP_ADDRESS = "http://127.0.0.1:8000" @@ -15,6 +24,3 @@ ASYNC_CONCURRENCY = int(1e6) #: Default latency SLO DEFAULT_LATENCY_SLO_MS = 1e9 - -#: Key for storing no http route services -NO_ROUTE_KEY = "NO_ROUTE" diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index a7d9a2a33..865f83d52 100644 --- a/python/ray/serve/http_proxy.py +++ 
b/python/ray/serve/http_proxy.py @@ -8,7 +8,7 @@ from ray.serve.constants import SERVE_MASTER_NAME from ray.serve.context import TaskContext from ray.serve.request_params import RequestMetadata from ray.serve.http_util import Response -from ray.serve.utils import logger +from ray.serve.utils import logger, retry_actor_failures_async from urllib.parse import parse_qs @@ -29,8 +29,9 @@ class HTTPProxy: async def fetch_config_from_master(self): assert ray.is_initialized() master = ray.util.get_actor(SERVE_MASTER_NAME) - self.route_table, [self.router_handle - ] = await master.get_http_proxy_config.remote() + self.route_table, [ + self.router_handle + ] = await retry_actor_failures_async(master.get_http_proxy_config) def set_route_table(self, route_table): self.route_table = route_table diff --git a/python/ray/serve/kv_store_service.py b/python/ray/serve/kv_store_service.py index 49cd69ec1..80e1d82a8 100644 --- a/python/ray/serve/kv_store_service.py +++ b/python/ray/serve/kv_store_service.py @@ -1,12 +1,8 @@ import json import sqlite3 from abc import ABC -from typing import Union, List -from ray import cloudpickle as pickle import ray.experimental.internal_kv as ray_kv -from ray.serve.utils import logger -from ray.serve.constants import NO_ROUTE_KEY class NamespacedKVStore(ABC): @@ -167,129 +163,3 @@ class SQLiteKVStore(NamespacedKVStore): result = list( cursor.execute("SELECT key, value FROM {}".format(self.namespace))) return dict(result) - - -# Tables -class RoutingTable: - def __init__(self, kv_connector): - self.routing_table = kv_connector("routing_table") - self.methods_table = kv_connector("methods_table") - self.request_count = 0 - - def register_service(self, route: Union[str, None], service: str, - methods: List[str]): - """Create an entry in the routing table - - Args: - route: http path name. Must begin with '/'. - service: service name. This is the name http actor will push - the request to. 
- """ - logger.debug( - "[KV] Registering route {} to service {} with methods {}.".format( - route, service, methods)) - - # put no route services in default key - if route is None: - no_http_services = json.loads( - self.routing_table.get(NO_ROUTE_KEY, "[]")) - no_http_services.append(service) - self.routing_table.put(NO_ROUTE_KEY, json.dumps(no_http_services)) - else: - self.routing_table.put(route, service) - self.methods_table.put(route, json.dumps(methods)) - - def list_service(self, include_headless=False, include_methods=False): - """Returns the routing table. - Args: - include_headless: If True, returns a no route services (headless) - services with normal services. (Default: False) - include_methods: If True, returns a mapping include the methods - list for each route. - """ - table = self.routing_table.as_dict() - if include_methods: - methods_table = self.methods_table.as_dict() - for route, methods in methods_table.items(): - if route in table: - table[route] = (table[route], json.loads(methods)) - - if include_headless: - table[NO_ROUTE_KEY] = json.loads(table.get(NO_ROUTE_KEY, "[]")) - else: - table.pop(NO_ROUTE_KEY, None) - return table - - def get_request_count(self): - """Return the number of requests that fetched the routing table. - - This method is used for two purpose: - - 1. Make sure HTTP proxy has started and healthy. Incremented request - count means HTTP proxy is actively fetching routing table. - - 2. Make sure HTTP proxy does not have stale routing table. This number - should be incremented every HTTP_ROUTER_CHECKER_INTERVAL_S seconds. - Supervisor should check this number as indirect indicator of http - proxy's health. 
- """ - return self.request_count - - -class BackendTable: - def __init__(self, kv_connector): - self.backend_table = kv_connector("backend_creator") - self.replica_table = kv_connector("replica_table") - self.backend_info = kv_connector("backend_info") - self.backend_init_args = kv_connector("backend_init_args") - - def register_backend(self, backend_tag: str, backend_creator): - backend_creator_serialized = pickle.dumps(backend_creator) - self.backend_table.put(backend_tag, backend_creator_serialized) - - def save_init_args(self, backend_tag: str, arg_list): - serialized_arg_list = pickle.dumps(arg_list) - self.backend_init_args.put(backend_tag, serialized_arg_list) - - def get_init_args(self, backend_tag): - return pickle.loads(self.backend_init_args.get(backend_tag)) - - def register_info(self, backend_tag: str, backend_info_d): - self.backend_info.put(backend_tag, json.dumps(backend_info_d)) - - def get_info(self, backend_tag): - return json.loads(self.backend_info.get(backend_tag, "{}")) - - def get_backend_creator(self, backend_tag): - return pickle.loads(self.backend_table.get(backend_tag)) - - def list_backends(self): - return list(self.backend_table.as_dict().keys()) - - def list_replicas(self, backend_tag: str): - return json.loads(self.replica_table.get(backend_tag, "[]")) - - def add_replica(self, backend_tag: str, new_replica_tag: str): - replica_tags = self.list_replicas(backend_tag) - replica_tags.append(new_replica_tag) - self.replica_table.put(backend_tag, json.dumps(replica_tags)) - - def remove_replica(self, backend_tag): - replica_tags = self.list_replicas(backend_tag) - removed_replica = replica_tags.pop() - self.replica_table.put(backend_tag, json.dumps(replica_tags)) - return removed_replica - - -class TrafficPolicyTable: - def __init__(self, kv_connector): - self.traffic_policy_table = kv_connector("traffic_policy") - - def register_traffic_policy(self, service_name, policy_dict): - self.traffic_policy_table.put(service_name, 
json.dumps(policy_dict)) - - def list_traffic_policy(self): - return { - service: json.loads(policy) - for service, policy in self.traffic_policy_table.as_dict().items() - } diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 3e3f5cf55..9c1fc9dd5 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -1,334 +1,540 @@ import asyncio from collections import defaultdict -from functools import wraps -import inspect +import os +import random +import time import ray +import ray.cloudpickle as pickle from ray.serve.backend_config import BackendConfig -from ray.serve.constants import ASYNC_CONCURRENCY +from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME, + SERVE_PROXY_NAME, SERVE_METRIC_MONITOR_NAME) from ray.serve.exceptions import batch_annotation_not_found from ray.serve.http_proxy import HTTPProxyActor -from ray.serve.kv_store_service import (BackendTable, RoutingTable, - TrafficPolicyTable) from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop) from ray.serve.backend_worker import create_backend_worker -from ray.serve.utils import expand, get_random_letters, logger +from ray.serve.utils import async_retryable, get_random_letters, logger import numpy as np - -def async_retryable(cls): - """Make all actor method invocations on the class retryable. - - Note: This will retry actor_handle.method_name.remote(), but it must - be invoked in an async context. - - Usage: - @ray.remote(max_reconstructions=10000) - @async_retryable - class A: - pass - """ - for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): - - def decorate_with_retry(f): - @wraps(f) - async def retry_method(*args, **kwargs): - while True: - result = await f(*args, **kwargs) - if isinstance(result, ray.exceptions.RayActorError): - logger.warning( - "Actor method '{}' failed, retrying after 100ms.". 
- format(name)) - await asyncio.sleep(0.1) - else: - return result - - return retry_method - - method.__ray_invocation_decorator__ = decorate_with_retry - return cls +# Used for testing purposes only. If this is set, the master actor will crash +# after writing each checkpoint with the specified probability. +_CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.0 @ray.remote class ServeMaster: - """Initialize and store all actor handles. + """Responsible for managing the state of the serving system. - Note: - This actor is necessary because ray will destroy actors when the - original actor handle goes out of scope (when driver exit). Therefore - we need to initialize and store actor handles in a seperate actor. + The master actor implements fault tolerance by persisting its state in + a new checkpoint each time a state change is made. If the actor crashes, + the latest checkpoint is loaded and the state is recovered. Checkpoints + are written/read using a provided KV-store interface. + + All hard state in the system is maintained by this actor and persisted via + these checkpoints. Soft state required by other components is fetched by + those actors from this actor on startup and updates are pushed out from + this actor. + + All other actors started by the master actor are named, detached actors + so they will not fate share with the master if it crashes. + + The following guarantees are provided for state-changing calls to the + master actor: + - If the call succeeds, the change was made and will be reflected in + the system even if the master actor or other actors die unexpectedly. + - If the call fails, the change may have been made but isn't guaranteed + to have been. The client should retry in this case. Note that this + requires all implementations here to be idempotent. 
""" - def __init__(self, kv_store_connector): - self.kv_store_connector = kv_store_connector - self.route_table = RoutingTable(kv_store_connector) - self.backend_table = BackendTable(kv_store_connector) - self.policy_table = TrafficPolicyTable(kv_store_connector) + async def __init__(self, kv_store_connector, router_class, router_kwargs, + start_http_proxy, http_proxy_host, http_proxy_port, + metric_gc_window_s): + # Used to read/write checkpoints. + # TODO(edoakes): namespace the master actor and its checkpoints. + self.kv_store_client = kv_store_connector("serve_checkpoints") + # path -> (endpoint, methods). + self.routes = {} + # backend -> (worker_creator, init_args, backend_config). + self.backends = {} + # backend -> replica_tags. + self.replicas = defaultdict(list) + # replicas that should be started if recovering from a checkpoint. + self.replicas_to_start = defaultdict(list) + # replicas that should be stopped if recovering from a checkpoint. + self.replicas_to_stop = defaultdict(list) + # endpoint -> traffic_dict + self.traffic_policies = dict() # Dictionary of backend tag to dictionaries of replica tag to worker. + # TODO(edoakes): consider removing this and just using the names. self.workers = defaultdict(dict) + # Used to ensure that only a single state-changing operation happens + # at any given time. + self.write_lock = asyncio.Lock() + + # Cached handles to actors in the system. self.router = None self.http_proxy = None self.metric_monitor = None - def get_traffic_policy(self, endpoint_name): - return self.policy_table.list_traffic_policy()[endpoint_name] + # If starting the actor for the first time, starts up the other system + # components. If recovering, fetches their actor handles. 
+ self._get_or_start_router(router_class, router_kwargs) + if start_http_proxy: + self._get_or_start_http_proxy(http_proxy_host, http_proxy_port) + self._get_or_start_metric_monitor(metric_gc_window_s) - def start_router(self, router_class, init_kwargs): - assert self.router is None, "Router already started." - self.router = async_retryable(router_class).options( - max_concurrency=ASYNC_CONCURRENCY, - max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION, - ).remote(**init_kwargs) + # NOTE(edoakes): unfortunately, we can't completely recover from a + # checkpoint in the constructor because we block while waiting for + # other actors to start up, and those actors fetch soft state from + # this actor. Because no other tasks will start executing until after + # the constructor finishes, if we were to run this logic in the + # constructor it could lead to deadlock between this actor and a child. + # However we do need to guarantee that we have fully recovered from a + # checkpoint before any other state-changing calls run. We address this + # by acquiring the write_lock and then posting the task to recover from + # a checkpoint to the event loop. Other state-changing calls acquire + # this lock and will be blocked until recovering from the checkpoint + # finishes. + checkpoint = self.kv_store_client.get("checkpoint") + if checkpoint is None: + logger.debug("No checkpoint found") + else: + await self.write_lock.acquire() + asyncio.get_event_loop().create_task( + self._recover_from_checkpoint(checkpoint)) + + def _get_or_start_router(self, router_class, router_kwargs): + """Get the router belonging to this serve cluster. + + If the router does not already exist, it will be started. 
+ """ + try: + self.router = ray.util.get_actor(SERVE_ROUTER_NAME) + except ValueError: + logger.info( + "Starting router with name '{}'".format(SERVE_ROUTER_NAME)) + self.router = async_retryable(router_class).options( + detached=True, + name=SERVE_ROUTER_NAME, + max_concurrency=ASYNC_CONCURRENCY, + max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION, + ).remote(**router_kwargs) def get_router(self): - assert self.router is not None, "Router not started yet." + """Returns a handle to the router managed by this actor.""" return [self.router] - def start_http_proxy(self, host, port): - """Start the HTTP proxy on the given host:port. + def _get_or_start_http_proxy(self, host, port): + """Get the HTTP proxy belonging to this serve cluster. - On startup (or restart), the HTTP proxy will fetch its config via - get_http_proxy_config. + If the HTTP proxy does not already exist, it will be started. """ - assert self.http_proxy is None, "HTTP proxy already started." - assert self.router is not None, ( - "Router must be started before HTTP proxy.") - self.http_proxy = async_retryable(HTTPProxyActor).options( - max_concurrency=ASYNC_CONCURRENCY, - max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION, - ).remote(host, port) - - async def get_http_proxy_config(self): - route_table = self.route_table.list_service( - include_methods=True, include_headless=False) - return route_table, self.get_router() + try: + self.http_proxy = ray.util.get_actor(SERVE_PROXY_NAME) + except ValueError: + logger.info( + "Starting HTTP proxy with name '{}'".format(SERVE_PROXY_NAME)) + self.http_proxy = async_retryable(HTTPProxyActor).options( + detached=True, + name=SERVE_PROXY_NAME, + max_concurrency=ASYNC_CONCURRENCY, + max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION, + ).remote(host, port) def get_http_proxy(self): - assert self.http_proxy is not None, "HTTP proxy not started yet." 
+ """Returns a handle to the HTTP proxy managed by this actor.""" return [self.http_proxy] - def start_metric_monitor(self, gc_window_seconds): - assert self.metric_monitor is None, "Metric monitor already started." - self.metric_monitor = MetricMonitor.remote(gc_window_seconds) - # TODO(edoakes): this should be an actor method, not a separate task. - start_metric_monitor_loop.remote(self.metric_monitor) - self.metric_monitor.add_target.remote(self.router) + def get_http_proxy_config(self): + """Called by the HTTP proxy on startup to fetch required state.""" + return self.routes, self.get_router() + + def _get_or_start_metric_monitor(self, gc_window_s): + """Get the metric monitor belonging to this serve cluster. + + If the metric monitor does not already exist, it will be started. + """ + try: + self.metric_monitor = ray.util.get_actor(SERVE_METRIC_MONITOR_NAME) + except ValueError: + logger.info("Starting metric monitor with name '{}'".format( + SERVE_METRIC_MONITOR_NAME)) + self.metric_monitor = MetricMonitor.options( + detached=True, + name=SERVE_METRIC_MONITOR_NAME).remote(gc_window_s) + # TODO(edoakes): move these into the constructor. 
+ start_metric_monitor_loop.remote(self.metric_monitor) + self.metric_monitor.add_target.remote(self.router) def get_metric_monitor(self): - assert self.metric_monitor is not None, ( - "Metric monitor not started yet.") + """Returns a handle to the metric monitor managed by this actor.""" return [self.metric_monitor] - def _list_replicas(self, backend_tag): - return self.backend_table.list_replicas(backend_tag) + def _checkpoint(self): + """Checkpoint internal state and write it to the KV store.""" + logger.debug("Writing checkpoint") + start = time.time() + checkpoint = pickle.dumps( + (self.routes, self.backends, self.traffic_policies, self.replicas, + self.replicas_to_start, self.replicas_to_stop)) - async def scale_replicas(self, backend_tag, num_replicas): + self.kv_store_client.put("checkpoint", checkpoint) + logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start)) + + if random.random() < _CRASH_AFTER_CHECKPOINT_PROBABILITY: + logger.warning("Intentionally crashing after checkpoint") + os._exit(0) + + async def _recover_from_checkpoint(self, checkpoint_bytes): + """Recover the cluster state from the provided checkpoint. + + Performs the following operations: + 1) Deserializes the internal state from the checkpoint. + 2) Pushes the latest configuration to the HTTP proxy and router + in case we crashed before updating them. + 3) Starts/stops any worker replicas that are pending creation or + deletion. + + NOTE: this requires that self.write_lock is already acquired and will + release it before returning. + """ + assert self.write_lock.locked() + + start = time.time() + logger.info("Recovering from checkpoint") + + # Load internal state from the checkpoint data. + (self.routes, self.backends, self.traffic_policies, self.replicas, + self.replicas_to_start, + self.replicas_to_stop) = pickle.loads(checkpoint_bytes) + + # Fetch actor handles for all of the backend replicas in the system. 
+ # All of these workers are guaranteed to already exist because they + # would not be written to a checkpoint in self.replicas until they + # were created. + for backend_tag, replica_tags in self.replicas.items(): + for replica_tag in replica_tags: + self.workers[backend_tag][replica_tag] = ray.util.get_actor( + replica_tag) + + # Push configuration state to the router. + # TODO(edoakes): should we make this a pull-only model for simplicity? + for endpoint, traffic_policy in self.traffic_policies.items(): + await self.router.set_traffic.remote(endpoint, traffic_policy) + + for backend_tag, replica_dict in self.workers.items(): + for replica_tag, worker in replica_dict.items(): + await self.router.add_new_worker.remote( + backend_tag, replica_tag, worker) + + for backend, (_, _, backend_config_dict) in self.backends.items(): + await self.router.set_backend_config.remote( + backend, backend_config_dict) + + # Push configuration state to the HTTP proxy. + await self.http_proxy.set_route_table.remote(self.routes) + + # Start/stop any pending backend replicas. + await self._start_pending_replicas() + await self._stop_pending_replicas() + + logger.info( + "Recovered from checkpoint in {:.3f}s".format(time.time() - start)) + + self.write_lock.release() + + def get_backend_configs(self): + """Fetched by the router on startup.""" + backend_configs = {} + for backend, (_, _, backend_config_dict) in self.backends.items(): + backend_configs[backend] = backend_config_dict + return backend_configs + + def get_traffic_policies(self): + """Fetched by the router on startup.""" + return self.traffic_policies + + def _list_replicas(self, backend_tag): + """Used only for testing.""" + return self.replicas[backend_tag] + + def get_traffic_policy(self, endpoint): + """Fetched by serve handles.""" + return self.traffic_policies[endpoint] + + async def _start_backend_worker(self, backend_tag, replica_tag): + """Creates a backend worker and waits for it to start up.
+ + Assumes that the backend configuration has already been registered + in self.backends. + """ + logger.debug("Starting worker '{}' for backend '{}'.".format( + replica_tag, backend_tag)) + worker_creator, init_args, config_dict = self.backends[backend_tag] + # TODO(edoakes): just store the BackendConfig in self.backends. + backend_config = BackendConfig(**config_dict) + kwargs = backend_config.get_actor_creation_args() + + worker_handle = async_retryable(ray.remote(worker_creator)).options( + detached=True, + name=replica_tag, + max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION, + **kwargs).remote(backend_tag, replica_tag, init_args) + # TODO(edoakes): we should probably have a timeout here. + await worker_handle.ready.remote() + return worker_handle + + async def _start_pending_replicas(self): + """Starts the pending backend replicas in self.replicas_to_start. + + Starts the worker, then pushes an update to the router to add it to + the proper backend. If the worker has already been started, only + updates the router. + + Clears self.replicas_to_start. + """ + for backend_tag, replicas_to_create in self.replicas_to_start.items(): + for replica_tag in replicas_to_create: + # NOTE(edoakes): the replicas may already be created if we + # failed after creating them but before writing a checkpoint. + try: + worker_handle = ray.util.get_actor(replica_tag) + except ValueError: + worker_handle = await self._start_backend_worker( + backend_tag, replica_tag) + + self.replicas[backend_tag].append(replica_tag) + self.workers[backend_tag][replica_tag] = worker_handle + + # Register the worker with the router. + await self.router.add_new_worker.remote( + backend_tag, replica_tag, worker_handle) + + # Register the worker with the metric monitor. + self.metric_monitor.add_target.remote(worker_handle) + + self.replicas_to_start.clear() + + async def _stop_pending_replicas(self): + """Stops the pending backend replicas in self.replicas_to_stop. 
+ + Stops workers by telling the router to remove them. + + Clears self.replicas_to_stop. + """ + for backend_tag, replicas_to_stop in self.replicas_to_stop.items(): + for replica_tag in replicas_to_stop: + # NOTE(edoakes): the replicas may already be stopped if we + # failed after stopping them but before writing a checkpoint. + try: + # Remove the replica from router. + # This will also submit __ray_terminate__ on the worker. + # NOTE(edoakes): we currently need to kill the worker from + # the router to guarantee that the router won't submit any + # more requests to it. + await self.router.remove_worker.remote( + backend_tag, replica_tag) + except ValueError: + pass + + self.replicas_to_stop.clear() + + def _scale_replicas(self, backend_tag, num_replicas): """Scale the given backend to the number of replicas. - This requires the master actor to be an async actor because we wait - synchronously for backends to start up and they may make calls into - the master actor while initializing (e.g., by calling get_handle()). + NOTE: this does not actually start or stop the replicas, but instead + adds the intention to start/stop them to self.replicas_to_start and + self.replicas_to_stop. The caller is responsible for then first writing + a checkpoint and then actually starting/stopping the intended replicas. + This avoids inconsistencies with starting/stopping a worker and then + crashing before writing a checkpoint.
""" - assert (backend_tag in self.backend_table.list_backends() + logger.debug("Scaling backend '{}' to {} replicas".format( + backend_tag, num_replicas)) + assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) assert num_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") - current_num_replicas = len(self._list_replicas(backend_tag)) + current_num_replicas = len(self.replicas[backend_tag]) delta_num_replicas = num_replicas - current_num_replicas if delta_num_replicas > 0: + logger.debug("Adding {} replicas to backend {}".format( + delta_num_replicas, backend_tag)) for _ in range(delta_num_replicas): - await self._start_backend_replica(backend_tag) + replica_tag = "{}#{}".format(backend_tag, get_random_letters()) + self.replicas_to_start[backend_tag].append(replica_tag) + elif delta_num_replicas < 0: + logger.debug("Removing {} replicas from backend {}".format( + -delta_num_replicas, backend_tag)) + assert len(self.replicas[backend_tag]) >= delta_num_replicas for _ in range(-delta_num_replicas): - await self._remove_backend_replica(backend_tag) + replica_tag = self.replicas[backend_tag].pop() + if len(self.replicas[backend_tag]) == 0: + del self.replicas[backend_tag] + del self.workers[backend_tag][replica_tag] + if len(self.workers[backend_tag]) == 0: + del self.workers[backend_tag] - async def get_backend_worker_config(self): - return self.get_router() - - async def _start_backend_replica(self, backend_tag): - assert (backend_tag in self.backend_table.list_backends() - ), "Backend {} is not registered.".format(backend_tag) - - replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6)) - - # Register the worker in the DB. - # TODO(edoakes): we should guarantee that if calls to the master - # succeed, the cluster state has changed and if they fail, it hasn't. 
- # Once we have master actor fault tolerance, this breaks that guarantee - - # because this method could fail after writing the replica to the DB. - self.backend_table.add_replica(backend_tag, replica_tag) - - # Fetch the info to start the replica from the backend table. - backend_actor = ray.remote( - self.backend_table.get_backend_creator(backend_tag)) - backend_config_dict = self.backend_table.get_info(backend_tag) - backend_config = BackendConfig(**backend_config_dict) - init_args = [ - backend_tag, replica_tag, - self.backend_table.get_init_args(backend_tag) - ] - kwargs = backend_config.get_actor_creation_args(init_args) - kwargs[ - "max_reconstructions"] = ray.ray_constants.INFINITE_RECONSTRUCTION - - # Start the worker. - worker_handle = backend_actor._remote(**kwargs) - self.workers[backend_tag][replica_tag] = worker_handle - - # Wait for the worker to start up. - await worker_handle.ready.remote() - - [router] = self.get_router() - await router.add_new_worker.remote(backend_tag, worker_handle) - - # Register the worker with the metric monitor. - self.get_metric_monitor()[0].add_target.remote(worker_handle) - - async def _remove_backend_replica(self, backend_tag): - assert (backend_tag in self.backend_table.list_backends() - ), "Backend {} is not registered.".format(backend_tag) - assert (len(self._list_replicas(backend_tag)) > - 0), "Tried to remove replica from empty backend ({}).".format( - backend_tag) - - replica_tag = self.backend_table.remove_replica(backend_tag) - assert backend_tag in self.workers - assert replica_tag in self.workers[backend_tag] - replica_handle = self.workers[backend_tag].pop(replica_tag) - if len(self.workers[backend_tag]) == 0: - del self.workers[backend_tag] - - # Remove the replica from metric monitor. - [monitor] = self.get_metric_monitor() - await monitor.remove_target.remote(replica_handle) - - # Remove the replica from router. - # This will also destroy the actor handle. 
- [router] = self.get_router() - await router.remove_worker.remote(backend_tag, replica_handle) + self.replicas_to_stop[backend_tag].append(replica_tag) def get_all_worker_handles(self): + """Fetched by the router on startup.""" return self.workers def get_all_endpoints(self): - return expand( - self.route_table.list_service(include_headless=True).values()) - - def get_all_routes(self): - return expand(self.route_table.list_service().keys()) - - def get_all_backends(self): - return self.backend_table.list_backends() + """Used for validation by the API client.""" + return [endpoint for endpoint, methods in self.routes.values()] async def set_traffic(self, endpoint_name, traffic_policy_dictionary): - assert endpoint_name in self.get_all_endpoints(), \ - "Attempted to assign traffic for an endpoint '{}'" \ - " that is not registered.".format(endpoint_name) + """Sets the traffic policy for the specified endpoint.""" + async with self.write_lock: + assert endpoint_name in self.get_all_endpoints(), \ + "Attempted to assign traffic for an endpoint '{}'" \ + " that is not registered.".format(endpoint_name) - assert isinstance(traffic_policy_dictionary, - dict), "Traffic policy must be dictionary" - prob = 0 - existing_backends = set(self.get_all_backends()) - for backend, weight in traffic_policy_dictionary.items(): - prob += weight - assert backend in existing_backends, \ - "Attempted to assign traffic to a backend '{}' that " \ - "is not registered.".format(backend) + assert isinstance(traffic_policy_dictionary, + dict), "Traffic policy must be dictionary" + prob = 0 + for backend, weight in traffic_policy_dictionary.items(): + prob += weight + assert backend in self.backends, \ + "Attempted to assign traffic to a backend '{}' that " \ + "is not registered.".format(backend) - assert np.isclose( - prob, 1, atol=0.02 - ), "weights must sum to 1, currently they sum to {}".format(prob) + assert np.isclose( + prob, 1, atol=0.02 + ), "weights must sum to 1, currently they sum 
to {}".format(prob) - self.policy_table.register_traffic_policy(endpoint_name, - traffic_policy_dictionary) - [router] = self.get_router() + self.traffic_policies[endpoint_name] = traffic_policy_dictionary - await router.set_traffic.remote(endpoint_name, - traffic_policy_dictionary) + # NOTE(edoakes): we must write a checkpoint before pushing the + # update to avoid inconsistent state if we crash after pushing the + # update. + self._checkpoint() + await self.router.set_traffic.remote(endpoint_name, + traffic_policy_dictionary) - async def create_endpoint(self, route, endpoint_name, methods): - err_prefix = "Cannot create endpoint. " - assert route not in self.get_all_routes(), \ - "{} Route '{}' is already registered.".format(err_prefix, route) - assert endpoint_name not in self.get_all_endpoints(), \ - "{} Endpoint '{}' is already registered.".format(err_prefix, - endpoint_name) - self.route_table.register_service( - route, endpoint_name, methods=methods) - [http_proxy] = self.get_http_proxy() + async def create_endpoint(self, route, endpoint, methods): + """Create a new endpoint with the specified route and methods. - await http_proxy.set_route_table.remote( - self.route_table.list_service( - include_methods=True, include_headless=False)) + If the route is None, this is a "headless" endpoint that will not + be added to the HTTP proxy (can only be accessed via a handle). + """ + async with self.write_lock: + # If this is a headless service with no route, key the endpoint + # based on its name. + # TODO(edoakes): we should probably just store routes and endpoints + # separately. + if route is None: + route = endpoint + + # TODO(edoakes): move this to client side. + err_prefix = "Cannot create endpoint." 
+ if route in self.routes: + if self.routes[route] == (endpoint, methods): + return + else: + raise ValueError( + "{} Route '{}' is already registered.".format( + err_prefix, route)) + + if endpoint in self.get_all_endpoints(): + raise ValueError( + "{} Endpoint '{}' is already registered.".format( + err_prefix, endpoint)) + + logger.info( + "Registering route {} to endpoint {} with methods {}.".format( + route, endpoint, methods)) + + self.routes[route] = (endpoint, methods) + + # NOTE(edoakes): we must write a checkpoint before pushing the + # update to avoid inconsistent state if we crash after pushing the + # update. + self._checkpoint() + await self.http_proxy.set_route_table.remote(self.routes) async def create_backend(self, backend_tag, backend_config, func_or_class, actor_init_args): - backend_config_dict = dict(backend_config) - backend_worker = create_backend_worker(func_or_class) + """Register a new backend under the specified tag.""" + async with self.write_lock: + backend_config_dict = dict(backend_config) + backend_worker = create_backend_worker(func_or_class) - assert backend_tag not in self.get_all_backends(), \ - "Cannot create backend '{}' because a backend" \ - "with that name already exists.".format(backend_tag) + # Save creator that starts replicas, the arguments to be passed in, + # and the configuration for the backends. + self.backends[backend_tag] = (backend_worker, actor_init_args, + backend_config_dict) - # Save creator which starts replicas. - self.backend_table.register_backend(backend_tag, backend_worker) + self._scale_replicas(backend_tag, + backend_config_dict["num_replicas"]) - # Save information about configurations needed to start the replicas. - self.backend_table.register_info(backend_tag, backend_config_dict) + # NOTE(edoakes): we must write a checkpoint before starting new + # or pushing the updated config to avoid inconsistent state if we + # crash while making the change. 
+ self._checkpoint() + await self._start_pending_replicas() + await self._stop_pending_replicas() - # Save the initial arguments needed by replicas. - self.backend_table.save_init_args(backend_tag, actor_init_args) - - # Set the backend config inside the router - # (particularly for max-batch-size). - [router] = self.get_router() - await router.set_backend_config.remote(backend_tag, - backend_config_dict) - - await self.scale_replicas(backend_tag, - backend_config_dict["num_replicas"]) + # Set the backend config inside the router + # (particularly for max-batch-size). + await self.router.set_backend_config.remote( + backend_tag, backend_config_dict) async def set_backend_config(self, backend_tag, backend_config): - assert (backend_tag in self.backend_table.list_backends() - ), "Backend {} is not registered.".format(backend_tag) - assert isinstance(backend_config, - BackendConfig), ("backend_config must be" - " of instance BackendConfig") - backend_config_dict = dict(backend_config) - old_backend_config_dict = self.backend_table.get_info(backend_tag) + """Set the config for the specified backend.""" + async with self.write_lock: + assert (backend_tag in self.backends + ), "Backend {} is not registered.".format(backend_tag) + assert isinstance(backend_config, + BackendConfig), ("backend_config must be" + " of instance BackendConfig") + backend_config_dict = dict(backend_config) + backend_worker, init_args, old_backend_config_dict = self.backends[ + backend_tag] - if (not old_backend_config_dict["has_accept_batch_annotation"] - and backend_config.max_batch_size is not None): - raise batch_annotation_not_found + if (not old_backend_config_dict["has_accept_batch_annotation"] + and backend_config.max_batch_size is not None): + raise batch_annotation_not_found - self.backend_table.register_info(backend_tag, backend_config_dict) + self.backends[backend_tag] = (backend_worker, init_args, + backend_config_dict) - # Inform the router about change in configuration - # 
(particularly for setting max_batch_size). - [router] = self.get_router() - await router.set_backend_config.remote(backend_tag, - backend_config_dict) + # Restart replicas if there is a change in the backend config + # related to restart_configs. + need_to_restart_replicas = any( + old_backend_config_dict[k] != backend_config_dict[k] + for k in BackendConfig.restart_on_change_fields) + if need_to_restart_replicas: + # Kill all the replicas for restarting with new configurations. + self._scale_replicas(backend_tag, 0) - # Restart replicas if there is a change in the backend config related - # to restart_configs. - need_to_restart_replicas = any( - old_backend_config_dict[k] != backend_config_dict[k] - for k in BackendConfig.restart_on_change_fields) - if need_to_restart_replicas: - # Kill all the replicas for restarting with new configurations. - await self.scale_replicas(backend_tag, 0) + # Scale the replicas with the new configuration. + self._scale_replicas(backend_tag, + backend_config_dict["num_replicas"]) - # Scale the replicas with the new configuration. - await self.scale_replicas(backend_tag, - backend_config_dict["num_replicas"]) + # NOTE(edoakes): we must write a checkpoint before pushing the + # update to avoid inconsistent state if we crash after pushing the + # update. + self._checkpoint() + + # Inform the router about change in configuration + # (particularly for setting max_batch_size). 
+ await self.router.set_backend_config.remote( + backend_tag, backend_config_dict) + + await self._start_pending_replicas() + await self._stop_pending_replicas() def get_backend_config(self, backend_tag): - assert (backend_tag in self.backend_table.list_backends() + """Get the current config for the specified backend.""" + assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) - backend_config_dict = self.backend_table.get_info(backend_tag) - return BackendConfig(**backend_config_dict) + return BackendConfig(**self.backends[backend_tag][2]) diff --git a/python/ray/serve/metric.py b/python/ray/serve/metric.py index 9ce1ce7b2..84d89c839 100644 --- a/python/ray/serve/metric.py +++ b/python/ray/serve/metric.py @@ -24,9 +24,6 @@ class MetricMonitor: self.gc_window_seconds = gc_window_seconds self.latest_gc_time = time.time() - def is_ready(self): - return True - def add_target(self, target_handle): hex_id = target_handle._actor_id.hex() self.actor_handles[hex_id] = target_handle @@ -153,5 +150,8 @@ class MetricMonitor: @ray.remote(num_cpus=0) def start_metric_monitor_loop(monitor_handle, duration_s=5): while True: - ray.get(monitor_handle.scrape.remote()) time.sleep(duration_s) + try: + ray.get(monitor_handle.scrape.remote()) + except ray.exceptions.RayActorError: + pass diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 9031343ac..f4333b203 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -1,6 +1,7 @@ import asyncio import copy from collections import defaultdict +import time from typing import DefaultDict, List # Note on choosing blist instead of stdlib heapq @@ -13,7 +14,7 @@ import blist import ray import ray.cloudpickle as pickle -from ray.serve.utils import logger +from ray.serve.utils import logger, retry_actor_failures class Query: @@ -138,6 +139,8 @@ class Router: self.traffic = defaultdict(dict) # backend_name -> backend_config self.backend_info = dict() + # replica tag -> 
worker_handle + self.replicas = dict() # -- Synchronization -- # @@ -150,15 +153,28 @@ class Router: # batching polcies. self.flush_lock = asyncio.Lock() - # Fetch the worker handles from the master actor. We use a "pull-based" - # approach instead of pushing them from the master so that the router - # can transparently recover from failure. + # Fetch the worker handles, traffic policies, and backend configs from + # the master actor. We use a "pull-based" approach instead of pushing + # them from the master so that the router can transparently recover + # from failure. ray.serve.init() master_actor = ray.serve.api._get_master_actor() - backend_dict = ray.get(master_actor.get_all_worker_handles.remote()) - for backend, replica_dict in backend_dict.items(): - for worker in replica_dict.values(): - await self.add_new_worker(backend, worker) + + traffic_policies = retry_actor_failures( + master_actor.get_traffic_policies) + for endpoint, traffic_policy in traffic_policies.items(): + await self.set_traffic(endpoint, traffic_policy) + + backend_dict = retry_actor_failures( + master_actor.get_all_worker_handles) + for backend_tag, replica_dict in backend_dict.items(): + for replica_tag, worker in replica_dict.items(): + await self.add_new_worker(backend_tag, replica_tag, worker) + + backend_configs = retry_actor_failures( + master_actor.get_backend_configs) + for backend, backend_config_dict in backend_configs.items(): + await self.set_backend_config(backend, backend_config_dict) def is_ready(self): return True @@ -199,28 +215,43 @@ class Router: result = await query.async_future return result - async def add_new_worker(self, backend, worker_handle): - logger.debug("New worker added for backend '{}'".format(backend)) - await self.mark_worker_idle(backend, worker_handle) + async def add_new_worker(self, backend_tag, replica_tag, worker_handle): + backend_replica_tag = backend_tag + ":" + replica_tag + if backend_replica_tag in self.replicas: + return + 
self.replicas[backend_replica_tag] = worker_handle - async def mark_worker_idle(self, backend, worker_handle): - await self.worker_queues[backend].put(worker_handle) + logger.debug("New worker added for backend '{}'".format(backend_tag)) + # await worker_handle.ready.remote() + await self.mark_worker_idle(backend_tag, backend_replica_tag) + + async def mark_worker_idle(self, backend_tag, backend_replica_tag): + if backend_replica_tag not in self.replicas: + return + + await self.worker_queues[backend_tag].put(backend_replica_tag) await self.flush() - async def remove_worker(self, backend, worker_handle): + async def remove_worker(self, backend_tag, replica_tag): + backend_replica_tag = backend_tag + ":" + replica_tag + if backend_replica_tag not in self.replicas: + return + worker_handle = self.replicas.pop(backend_replica_tag) + # We need this lock because we modify worker_queue here. async with self.flush_lock: - old_queue = self.worker_queues[backend] + old_queue = self.worker_queues[backend_tag] new_queue = asyncio.Queue() - target_id = worker_handle._actor_id while not old_queue.empty(): - worker_handle = await old_queue.get() - if worker_handle._actor_id != target_id: - await new_queue.put(worker_handle) + curr_tag = await old_queue.get() + if curr_tag != backend_replica_tag: + await new_queue.put(curr_tag) - self.worker_queues[backend] = new_queue - # TODO: consider awaiting this on a timeout or using ray.kill(). + self.worker_queues[backend_tag] = new_queue + # We need to terminate the worker here instead of from the master + # so we can guarantee that the router won't submit any more tasks + # on it. 
worker_handle.__ray_terminate__.remote() async def link(self, service, backend): @@ -297,11 +328,15 @@ class Router: await self._assign_query_to_worker( backend, buffer_queue, worker_queue, max_batch_size) - async def _do_query(self, backend, worker, req): + async def _do_query(self, backend, backend_replica_tag, req): # If the worker died, this will be a RayActorError. Just return it and # let the HTTP proxy handle the retry logic. + logger.debug("Sending query to replica:" + backend_replica_tag) + start = time.time() + worker = self.replicas[backend_replica_tag] result = await worker.handle_request.remote(req) - await self.mark_worker_idle(backend, worker) + await self.mark_worker_idle(backend, backend_replica_tag) + logger.debug("Got result in {:.2f}s", time.time() - start) return result async def _assign_query_to_worker(self, @@ -311,11 +346,11 @@ class Router: max_batch_size=None): while len(buffer_queue) and worker_queue.qsize(): - worker = await worker_queue.get() + backend_replica_tag = await worker_queue.get() if max_batch_size is None: # No batching request = buffer_queue.pop(0) future = asyncio.get_event_loop().create_task( - self._do_query(backend, worker, request)) + self._do_query(backend, backend_replica_tag, request)) # chaining satisfies request.async_future with future result. 
asyncio.futures._chain_future(future, request.async_future) else: @@ -331,7 +366,7 @@ class Router: for group in requests_group.values(): future = asyncio.get_event_loop().create_task( - self._do_query(backend, worker, group)) + self._do_query(backend, backend_replica_tag, group)) future.add_done_callback( _make_future_unwrapper( client_futures=[req.async_future for req in group], diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index 2f6d629bb..5e5de4e14 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -6,6 +6,12 @@ import pytest import ray from ray import serve +# TODO(edoakes): the failure tests currently fail with the GCS service enabled. +os.environ["RAY_GCS_SERVICE_ENABLED"] = "false" + +if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False): + serve.master._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5 + @pytest.fixture(scope="session") def serve_instance(): diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 5b57c91fc..011202630 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -75,24 +75,17 @@ def test_no_route(serve_instance): assert result == 1 -def test_reject_duplicate_backend_tag(serve_instance): - backend_name = "foo" - serve.create_backend(lambda foo: foo, backend_name) - with pytest.raises(AssertionError): - serve.create_backend(lambda foo: foo, backend_name) - - def test_reject_duplicate_route(serve_instance): route = "/foo" serve.create_endpoint("bar", route=route) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): serve.create_endpoint("foo", route=route) def test_reject_duplicate_endpoint(serve_instance): endpoint_name = "foo" serve.create_endpoint(endpoint_name, route="/ok") - with pytest.raises(AssertionError): + with pytest.raises(ValueError): serve.create_endpoint(endpoint_name, route="/different") @@ -102,9 +95,9 @@ def 
test_set_traffic_missing_data(serve_instance): serve.create_endpoint(endpoint_name) serve.create_backend(lambda: 5, backend_name) with pytest.raises(AssertionError): - serve.set_traffic(endpoint_name, {"nonexistant_backend": 1.0}) + serve.set_traffic(endpoint_name, {"nonexistent_backend": 1.0}) with pytest.raises(AssertionError): - serve.set_traffic("nonexistant_endpoint_name", {backend_name: 1.0}) + serve.set_traffic("nonexistent_endpoint_name", {backend_name: 1.0}) def test_scaling_replicas(serve_instance): diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index de0ea4fb9..a5f9b7d59 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -13,15 +13,15 @@ from ray.serve.backend_config import BackendConfig pytestmark = pytest.mark.asyncio -def setup_worker(name, func_or_class, router_handle, init_args=None): +def setup_worker(name, func_or_class, init_args=None): if init_args is None: init_args = () @ray.remote class WorkerActor: - def __init__(self, router_handle): + def __init__(self): self.worker = create_backend_worker(func_or_class)( - name, name + ":tag", init_args, router_handle=router_handle[0]) + name, name + ":tag", init_args) def ready(self): pass @@ -29,13 +29,10 @@ def setup_worker(name, func_or_class, router_handle, init_args=None): def get_metrics(self): return self.worker.get_metrics() - def run(self): - self.worker.backend.mark_idle_in_router() - async def handle_request(self, *args, **kwargs): return await self.worker.handle_request(*args, **kwargs) - worker = WorkerActor.remote([router_handle]) + worker = WorkerActor.remote() ray.get(worker.ready.remote()) return worker @@ -54,8 +51,8 @@ async def test_runner_actor(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "prod" - worker = setup_worker(CONSUMER_NAME, echo, q) - await q.add_new_worker.remote(CONSUMER_NAME, worker) + worker = setup_worker(CONSUMER_NAME, echo) + await 
q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -79,8 +76,8 @@ async def test_ray_serve_mixin(serve_instance): def __call__(self, flask_request, i=None): return i + self.increment - worker = setup_worker(CONSUMER_NAME, MyAdder, q, init_args=(3, )) - await q.add_new_worker.remote(CONSUMER_NAME, worker) + worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, )) + await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -101,8 +98,8 @@ async def test_task_runner_check_context(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - worker = setup_worker(CONSUMER_NAME, echo, q) - await q.add_new_worker.remote(CONSUMER_NAME, worker) + worker = setup_worker(CONSUMER_NAME, echo) + await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) q.link.remote(PRODUCER_NAME, CONSUMER_NAME) query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) @@ -125,8 +122,8 @@ async def test_task_runner_custom_method_single(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - worker = setup_worker(CONSUMER_NAME, NonBatcher, q) - await q.add_new_worker.remote(CONSUMER_NAME, worker) + worker = setup_worker(CONSUMER_NAME, NonBatcher) + await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -160,7 +157,7 @@ async def test_task_runner_custom_method_batch(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - worker = setup_worker(CONSUMER_NAME, Batcher, q) + worker = setup_worker(CONSUMER_NAME, Batcher) await q.link.remote(PRODUCER_NAME, CONSUMER_NAME) await q.set_backend_config.remote( @@ -174,7 +171,7 @@ async def test_task_runner_custom_method_batch(serve_instance): futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)] futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)] - await q.add_new_worker.remote(CONSUMER_NAME, 
worker) + await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) gathered = await asyncio.gather(*futures) assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"} diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index ced123750..3ba2b4804 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -19,6 +19,57 @@ def request_with_retries(endpoint, timeout=30): time.sleep(0.1) +def test_master_failure(serve_instance): + serve.init() + serve.create_endpoint("master_failure", "/master_failure") + + def function(): + return "hello1" + + serve.create_backend(function, "master_failure:v1") + serve.set_traffic("master_failure", {"master_failure:v1": 1.0}) + + assert request_with_retries("/master_failure", timeout=1).text == "hello1" + + for _ in range(10): + response = request_with_retries("/master_failure", timeout=30) + assert response.text == "hello1" + + ray.kill(serve.api._get_master_actor()) + + for _ in range(10): + response = request_with_retries("/master_failure", timeout=30) + assert response.text == "hello1" + + def function(): + return "hello2" + + ray.kill(serve.api._get_master_actor()) + + serve.create_backend(function, "master_failure:v2") + serve.set_traffic("master_failure", {"master_failure:v2": 1.0}) + + for _ in range(10): + response = request_with_retries("/master_failure", timeout=30) + assert response.text == "hello2" + + def function(): + return "hello3" + + ray.kill(serve.api._get_master_actor()) + serve.create_endpoint("master_failure_2", "/master_failure_2") + ray.kill(serve.api._get_master_actor()) + serve.create_backend(function, "master_failure_2") + ray.kill(serve.api._get_master_actor()) + serve.set_traffic("master_failure_2", {"master_failure_2": 1.0}) + + for _ in range(10): + response = request_with_retries("/master_failure", timeout=30) + assert response.text == "hello2" + response = request_with_retries("/master_failure_2", timeout=30) + assert 
response.text == "hello3" + + def _kill_http_proxy(): [http_proxy] = ray.get( serve.api._get_master_actor().get_http_proxy.remote()) @@ -27,7 +78,7 @@ def _kill_http_proxy(): def test_http_proxy_failure(serve_instance): serve.init() - serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"]) + serve.create_endpoint("proxy_failure", "/proxy_failure") def function(): return "hello1" @@ -35,7 +86,7 @@ def test_http_proxy_failure(serve_instance): serve.create_backend(function, "proxy_failure:v1") serve.set_traffic("proxy_failure", {"proxy_failure:v1": 1.0}) - assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1" + assert request_with_retries("/proxy_failure", timeout=1.0).text == "hello1" for _ in range(10): response = request_with_retries("/proxy_failure", timeout=30) @@ -61,7 +112,7 @@ def _kill_router(): def test_router_failure(serve_instance): serve.init() - serve.create_endpoint("router_failure", "/router_failure", methods=["GET"]) + serve.create_endpoint("router_failure", "/router_failure") def function(): return "hello1" @@ -77,6 +128,10 @@ def test_router_failure(serve_instance): _kill_router() + for _ in range(10): + response = request_with_retries("/router_failure", timeout=30) + assert response.text == "hello1" + def function(): return "hello2" @@ -99,7 +154,7 @@ def _get_worker_handles(backend): # serving requests. def test_worker_restart(serve_instance): serve.init() - serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"]) + serve.create_endpoint("worker_failure", "/worker_failure") class Worker1: def __call__(self): @@ -109,7 +164,7 @@ def test_worker_restart(serve_instance): serve.set_traffic("worker_failure", {"worker_failure:v1": 1.0}) # Get the PID of the worker. - old_pid = request_with_retries("/worker_failure", timeout=0.1).text + old_pid = request_with_retries("/worker_failure", timeout=1).text # Kill the worker. 
handles = _get_worker_handles("worker_failure:v1") @@ -131,8 +186,7 @@ def test_worker_restart(serve_instance): def test_worker_replica_failure(serve_instance): serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0 serve.init() - serve.create_endpoint( - "replica_failure", "/replica_failure", methods=["GET"]) + serve.create_endpoint("replica_failure", "/replica_failure") class Worker: # Assumes that two replicas are started. Will hang forever in the @@ -169,8 +223,7 @@ def test_worker_replica_failure(serve_instance): # Wait until both replicas have been started. responses = set() while len(responses) == 1: - responses.add( - request_with_retries("/replica_failure", timeout=0.1).text) + responses.add(request_with_retries("/replica_failure", timeout=1).text) time.sleep(0.1) # Kill one of the replicas. diff --git a/python/ray/serve/tests/test_master_crashes.py b/python/ray/serve/tests/test_master_crashes.py new file mode 100644 index 000000000..5bf363235 --- /dev/null +++ b/python/ray/serve/tests/test_master_crashes.py @@ -0,0 +1,19 @@ +import os +import pytest +from pathlib import Path +import sys + +if __name__ == "__main__": + curr_dir = Path(__file__).parent + test_paths = curr_dir.rglob("test_*.py") + sorted_path = sorted(map(lambda path: str(path.absolute()), test_paths)) + serve_tests_files = list(sorted_path) + + print("Testing the following files") + for test_file in serve_tests_files: + print("->", test_file.split("/")[-1]) + + print("Setting RAY_SERVE_INTENTIONALLY_CRASH=1") + os.environ["RAY_SERVE_INTENTIONALLY_CRASH"] = "1" + + sys.exit(pytest.main(["-v", "-s"] + serve_tests_files)) diff --git a/python/ray/serve/tests/test_nonblocking.py b/python/ray/serve/tests/test_nonblocking.py index 11b0f2b4f..a0da5078d 100644 --- a/python/ray/serve/tests/test_nonblocking.py +++ b/python/ray/serve/tests/test_nonblocking.py @@ -6,15 +6,15 @@ from ray import serve def test_nonblocking(): serve.init() - serve.create_endpoint("endpoint", "/api") + 
serve.create_endpoint("nonblocking", "/nonblocking") def function(flask_request): return {"method": flask_request.method} - serve.create_backend(function, "echo:v1") - serve.set_traffic("endpoint", {"echo:v1": 1.0}) + serve.create_backend(function, "nonblocking:v1") + serve.set_traffic("nonblocking", {"nonblocking:v1": 1.0}) - resp = requests.get("http://127.0.0.1:8000/api").json()["method"] + resp = requests.get("http://127.0.0.1:8000/nonblocking").json()["method"] assert resp == "GET" diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index b737adf51..c6b752c3e 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -29,6 +29,9 @@ def make_task_runner_mock(): def get_all_calls(self): return self.queries + def ready(self): + pass + return TaskRunnerMock.remote() @@ -39,8 +42,9 @@ def task_runner_mock_actor(): async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): q = RandomPolicyQueueActor.remote() - q.link.remote("svc", "backend") - q.add_new_worker.remote("backend", task_runner_mock_actor) + q.link.remote("svc", "backend-single-prod") + q.add_new_worker.remote("backend-single-prod", "replica-1", + task_runner_mock_actor) # Make sure we get the request result back result = await q.enqueue_request.remote(RequestMetadata("svc", None), 1) @@ -54,7 +58,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): async def test_slo(serve_instance, task_runner_mock_actor): q = RandomPolicyQueueActor.remote() - await q.link.remote("svc", "backend") + await q.link.remote("svc", "backend-slo") all_request_sent = [] for i in range(10): @@ -63,7 +67,8 @@ async def test_slo(serve_instance, task_runner_mock_actor): q.enqueue_request.remote( RequestMetadata("svc", None, relative_slo_ms=slo_ms), i)) - await q.add_new_worker.remote("backend", task_runner_mock_actor) + await q.add_new_worker.remote("backend-slo", "replica-1", + task_runner_mock_actor) 
await asyncio.gather(*all_request_sent) @@ -78,14 +83,16 @@ async def test_slo(serve_instance, task_runner_mock_actor): async def test_alter_backend(serve_instance, task_runner_mock_actor): q = RandomPolicyQueueActor.remote() - await q.set_traffic.remote("svc", {"backend-1": 1}) - await q.add_new_worker.remote("backend-1", task_runner_mock_actor) + await q.set_traffic.remote("svc", {"backend-alter": 1}) + await q.add_new_worker.remote("backend-alter", "replica-1", + task_runner_mock_actor) await q.enqueue_request.remote(RequestMetadata("svc", None), 1) got_work = await task_runner_mock_actor.get_recent_call.remote() assert got_work.request_args[0] == 1 - await q.set_traffic.remote("svc", {"backend-2": 1}) - await q.add_new_worker.remote("backend-2", task_runner_mock_actor) + await q.set_traffic.remote("svc", {"backend-alter-2": 1}) + await q.add_new_worker.remote("backend-alter-2", "replica-1", + task_runner_mock_actor) await q.enqueue_request.remote(RequestMetadata("svc", None), 2) got_work = await task_runner_mock_actor.get_recent_call.remote() assert got_work.request_args[0] == 2 @@ -94,10 +101,13 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor): async def test_split_traffic_random(serve_instance, task_runner_mock_actor): q = RandomPolicyQueueActor.remote() - await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5}) + await q.set_traffic.remote("svc", { + "backend-split": 0.5, + "backend-split-2": 0.5 + }) runner_1, runner_2 = [make_task_runner_mock() for _ in range(2)] - await q.add_new_worker.remote("backend-1", runner_1) - await q.add_new_worker.remote("backend-2", runner_2) + await q.add_new_worker.remote("backend-split", "replica-1", runner_1) + await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2) # assume 50% split, the probability of all 20 requests goes to a # single queue is 0.5^20 ~ 1-6 @@ -114,13 +124,13 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor): async def 
test_round_robin(serve_instance, task_runner_mock_actor): q = RoundRobinPolicyQueueActor.remote() - await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5}) + await q.set_traffic.remote("svc", {"backend-rr": 0.5, "backend-rr-2": 0.5}) runner_1, runner_2 = [make_task_runner_mock() for _ in range(2)] # NOTE: this is the only difference between the # test_split_traffic_random and test_round_robin - await q.add_new_worker.remote("backend-1", runner_1) - await q.add_new_worker.remote("backend-2", runner_2) + await q.add_new_worker.remote("backend-rr", "replica-1", runner_1) + await q.add_new_worker.remote("backend-rr-2", "replica-1", runner_2) for _ in range(20): await q.enqueue_request.remote(RequestMetadata("svc", None), 1) @@ -135,13 +145,16 @@ async def test_round_robin(serve_instance, task_runner_mock_actor): async def test_fixed_packing(serve_instance): packing_num = 4 q = FixedPackingPolicyQueueActor.remote(packing_num=packing_num) - await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5}) + await q.set_traffic.remote("svc", { + "backend-fixed": 0.5, + "backend-fixed-2": 0.5 + }) runner_1, runner_2 = (make_task_runner_mock() for _ in range(2)) # both the backends will get equal number of queries # as it is packed round robin - await q.add_new_worker.remote("backend-1", runner_1) - await q.add_new_worker.remote("backend-2", runner_2) + await q.add_new_worker.remote("backend-fixed", "replica-1", runner_1) + await q.add_new_worker.remote("backend-fixed-2", "replica-1", runner_2) for backend, runner in zip(["1", "2"], [runner_1, runner_2]): for _ in range(packing_num): @@ -158,20 +171,23 @@ async def test_power_of_two_choices(serve_instance): enqueue_futures = [] # First, fill the queue for backend-1 with 3 requests - await q.set_traffic.remote("svc", {"backend-1": 1.0}) + await q.set_traffic.remote("svc", {"backend-pow2": 1.0}) for _ in range(3): future = q.enqueue_request.remote(RequestMetadata("svc", None), "1") 
enqueue_futures.append(future) # Then, add a new backend, this backend should be filled next - await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5}) + await q.set_traffic.remote("svc", { + "backend-pow2": 0.5, + "backend-pow2-2": 0.5 + }) for _ in range(2): future = q.enqueue_request.remote(RequestMetadata("svc", None), "2") enqueue_futures.append(future) runner_1, runner_2 = (make_task_runner_mock() for _ in range(2)) - await q.add_new_worker.remote("backend-1", runner_1) - await q.add_new_worker.remote("backend-2", runner_2) + await q.add_new_worker.remote("backend-pow2", "replica-1", runner_1) + await q.add_new_worker.remote("backend-pow2-2", "replica-1", runner_2) await asyncio.gather(*enqueue_futures) @@ -183,10 +199,10 @@ async def test_queue_remove_replicas(serve_instance): @ray.remote class TestRandomPolicyQueueActor(RandomPolicyQueue): def worker_queue_size(self, backend): - return self.worker_queues["backend"].qsize() + return self.worker_queues["backend-remove"].qsize() temp_actor = make_task_runner_mock() q = TestRandomPolicyQueueActor.remote() - await q.add_new_worker.remote("backend", temp_actor) - await q.remove_worker.remote("backend", temp_actor) + await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor) + await q.remove_worker.remote("backend-remove", "replica-1") assert ray.get(q.worker_queue_size.remote("backend")) == 0 diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 8d0aa771d..9d0a70d41 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -1,3 +1,6 @@ +import asyncio +from functools import wraps +import inspect import json import logging import random @@ -6,11 +9,11 @@ import time import io import os +import ray import requests from pygments import formatters, highlight, lexers from ray.serve.context import FakeFlaskRequest, TaskContext from ray.serve.http_util import build_flask_request -import itertools import numpy as np try: @@ -18,19 +21,7 @@ try: except 
ImportError: pydantic = None - -def expand(l): - """ - Implements a nested flattening of a list. - Example: - >>> serve.utils.expand([1,2,[3,4,5],6]) - [1,2,3,4,5,6] - >>> serve.utils.expand(["a", ["b", "c"], "d", ["e", "f"]]) - ["a", "b", "c", "d", "e", "f"] - """ - return list( - itertools.chain.from_iterable( - [x if isinstance(x, list) else [x] for x in l])) +ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 def parse_request_item(request_item): @@ -116,3 +107,69 @@ def block_until_http_ready(http_endpoint, num_retries=5, backoff_time_s=1): def get_random_letters(length=6): return "".join(random.choices(string.ascii_letters, k=length)) + + +def async_retryable(cls): + """Make all actor method invocations on the class retryable. + + Note: This will retry actor_handle.method_name.remote(), but it must + be invoked in an async context. + + Usage: + @ray.remote(max_reconstructions=10000) + @async_retryable + class A: + pass + """ + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): + + def decorate_with_retry(f): + @wraps(f) + async def retry_method(*args, **kwargs): + start = time.time() + while time.time() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S: + try: + return await f(*args, **kwargs) + except ray.exceptions.RayActorError: + logger.warning( + "Actor method '{}' failed, retrying after 100ms.". 
+ format(name)) + await asyncio.sleep(0.1) + raise RuntimeError("Timed out after {}s waiting for actor " + "method '{}' to succeed.".format( + ACTOR_FAILURE_RETRY_TIMEOUT_S, name)) + + return retry_method + + method.__ray_invocation_decorator__ = decorate_with_retry + return cls + + +def retry_actor_failures(f, *args, **kwargs): + start = time.time() + while time.time() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S: + try: + return ray.get(f.remote(*args, **kwargs)) + except ray.exceptions.RayActorError: + logger.warning( + "Actor method '{}' failed, retrying after 100ms".format( + f._method_name)) + time.sleep(0.1) + raise RuntimeError("Timed out after {}s waiting for actor " + "method '{}' to succeed.".format( + ACTOR_FAILURE_RETRY_TIMEOUT_S, f._method_name)) + + +async def retry_actor_failures_async(f, *args, **kwargs): + start = time.time() + while time.time() - start < ACTOR_FAILURE_RETRY_TIMEOUT_S: + try: + return await f.remote(*args, **kwargs) + except ray.exceptions.RayActorError: + logger.warning( + "Actor method '{}' failed, retrying after 100ms".format( + f._method_name)) + await asyncio.sleep(0.1) + raise RuntimeError("Timed out after {}s waiting for actor " + "method '{}' to succeed.".format( + ACTOR_FAILURE_RETRY_TIMEOUT_S, f._method_name))