From 2cb9cfb2b6b1d2b5a14b543310761d64b76b8508 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Sun, 12 Apr 2020 11:48:08 -0500 Subject: [PATCH] [serve] Make workers fault tolerant (#7970) --- python/ray/serve/http_proxy.py | 37 ++++-- python/ray/serve/master.py | 14 ++- python/ray/serve/router.py | 2 + python/ray/serve/tests/test_api.py | 4 +- python/ray/serve/tests/test_failure.py | 164 ++++++++++++++++++++----- 5 files changed, 171 insertions(+), 50 deletions(-) diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 0974a9be8..b99832319 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -8,9 +8,14 @@ from ray.serve.constants import SERVE_MASTER_NAME from ray.serve.context import TaskContext from ray.serve.request_params import RequestMetadata from ray.serve.http_util import Response +from ray.serve.utils import logger from urllib.parse import parse_qs +# The maximum number of times to retry a request due to actor failure. +# TODO(edoakes): this should probably be configurable. +MAX_ACTOR_DEAD_RETRIES = 10 + class HTTPProxy: """ @@ -102,8 +107,9 @@ class HTTPProxy: await Response(self.route_table).send(scope, receive, send) return - # TODO(simon): Use werkzeug route mapper to support variable path - if current_path not in self.route_table: + try: + endpoint_name, methods_allowed = self.route_table[current_path] + except KeyError: error_message = ( "Path {} not found. " "Please ping http://.../-/routes for routing table" @@ -111,8 +117,6 @@ class HTTPProxy: await error_sender(error_message, 404) return - endpoint_name, methods_allowed = self.route_table[current_path] - if scope["method"] not in methods_allowed: error_message = ("Methods {} not allowed. " "Avaiable HTTP methods are {}.").format( @@ -137,13 +141,24 @@ class HTTPProxy: absolute_slo_ms=absolute_slo_ms, call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__")) - try: - result = await self.router_handle.enqueue_request.remote( - request_metadata, scope, http_body_bytes) - await Response(result).send(scope, receive, send) - except Exception as e: - error_message = "Internal Error. Traceback: {}.".format(e) - await error_sender(error_message, 500) + retries = 0 + while retries <= MAX_ACTOR_DEAD_RETRIES: + try: + result = await self.router_handle.enqueue_request.remote( + request_metadata, scope, http_body_bytes) + if not isinstance(result, ray.exceptions.RayActorError): + await Response(result).send(scope, receive, send) + break + logger.warning("Got RayActorError:", str(result)) + await asyncio.sleep(0.1) + except Exception as e: + error_message = "Internal Error. Traceback: {}.".format(e) + await error_sender(error_message, 500) + break + else: + logger.debug("Maximum actor death retries exceeded") + await error_sender( + "Internal Error. Maximum actor death retries exceeded", 500) @ray.remote diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 1004fb2fb..a5ace3a50 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -27,7 +27,7 @@ class ServeMaster: self.route_table = RoutingTable(kv_store_connector) self.backend_table = BackendTable(kv_store_connector) self.policy_table = TrafficPolicyTable(kv_store_connector) - self.tag_to_actor_handles = dict() + self.replica_tag_to_workers = dict() self.router = None self.http_proxy = None @@ -131,10 +131,12 @@ class ServeMaster: self.backend_table.get_init_args(backend_tag) ] kwargs = backend_config.get_actor_creation_args(init_args) + kwargs[ + "max_reconstructions"] = ray.ray_constants.INFINITE_RECONSTRUCTION # Start the worker. worker_handle = backend_actor._remote(**kwargs) - self.tag_to_actor_handles[replica_tag] = worker_handle + self.replica_tag_to_workers[replica_tag] = worker_handle # Wait for the worker to start up. await worker_handle.ready.remote() @@ -152,8 +154,8 @@ class ServeMaster: backend_tag) replica_tag = self.backend_table.remove_replica(backend_tag) - assert replica_tag in self.tag_to_actor_handles - replica_handle = self.tag_to_actor_handles.pop(replica_tag) + assert replica_tag in self.replica_tag_to_workers + replica_handle = self.replica_tag_to_workers.pop(replica_tag) # Remove the replica from metric monitor. [monitor] = self.get_metric_monitor() @@ -166,8 +168,8 @@ class ServeMaster: router.remove_and_destroy_replica.remote(backend_tag, replica_handle)) - def get_all_handles(self): - return self.tag_to_actor_handles + def get_all_worker_handles(self): + return self.replica_tag_to_workers def get_all_endpoints(self): return expand( diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 0136c5838..2f1519f8e 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -288,6 +288,8 @@ class Router: backend, buffer_queue, worker_queue, max_batch_size) async def _do_query(self, backend, worker, req): + # If the worker died, this will be a RayActorError. Just return it and + # let the HTTP proxy handle the retry logic. result = await worker.handle_request.remote(req) await self.mark_worker_idle(backend, worker) return result diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 5cf201524..ec6f75062 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -193,7 +193,7 @@ def test_killing_replicas(serve_instance): new_replica_tag_list = ray.get( master_actor._list_replicas.remote("simple:v1")) new_all_tag_list = list( - ray.get(master_actor.get_all_handles.remote()).keys()) + ray.get(master_actor.get_all_worker_handles.remote()).keys()) # the new_replica_tag_list must be subset of all_tag_list assert set(new_replica_tag_list) <= set(new_all_tag_list) @@ -227,7 +227,7 @@ def test_not_killing_replicas(serve_instance): new_replica_tag_list = ray.get( master_actor._list_replicas.remote("bsimple:v1")) new_all_tag_list = list( - ray.get(master_actor.get_all_handles.remote()).keys()) + ray.get(master_actor.get_all_worker_handles.remote()).keys()) # the old and new replica tag list should be identical # and should be subset of all_tag_list diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 1e26a625d..f9b40f98f 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -1,8 +1,22 @@ -import time +import os import requests +import tempfile +import time -from ray import serve import ray +from ray import serve + + +def request_with_retries(endpoint, timeout=30): + start = time.time() + while True: + try: + return requests.get( + "http://127.0.0.1:8000" + endpoint, timeout=timeout) + except requests.RequestException: + if time.time() - start > timeout: + raise TimeoutError + time.sleep(0.1) def _kill_http_proxy(): @@ -11,47 +25,135 @@ def _kill_http_proxy(): ray.kill(http_proxy) -def request_with_retries(endpoint, verify_response, timeout=30): - start = time.time() - while True: - try: - verify_response(requests.get("http://127.0.0.1:8000" + endpoint)) - break - except requests.RequestException: - if time.time() - start > timeout: - raise TimeoutError - time.sleep(0.1) - - def test_http_proxy_failure(serve_instance): serve.init() - serve.create_endpoint( - "failure_endpoint", "/failure_endpoint", methods=["GET"]) + serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"]) - def function(flask_request): + def function(): return "hello1" - serve.create_backend(function, "failure:v1") - serve.link("failure_endpoint", "failure:v1") + serve.create_backend(function, "proxy_failure:v1") + serve.link("proxy_failure", "proxy_failure:v1") - def verify_response(response): + assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1" + + for _ in range(10): + response = request_with_retries("/proxy_failure", timeout=30) assert response.text == "hello1" - request_with_retries("/failure_endpoint", verify_response, timeout=0) - _kill_http_proxy() - request_with_retries("/failure_endpoint", verify_response, timeout=30) - - _kill_http_proxy() - - def function(flask_request): + def function(): return "hello2" - serve.create_backend(function, "failure:v2") - serve.link("failure_endpoint", "failure:v2") + serve.create_backend(function, "proxy_failure:v2") + serve.link("proxy_failure", "proxy_failure:v2") - def verify_response(response): + for _ in range(10): + response = request_with_retries("/proxy_failure", timeout=30) assert response.text == "hello2" - request_with_retries("/failure_endpoint", verify_response, timeout=30) + +def _get_worker_handles(backend): + handles = {} + for tag, handle in ray.get(serve.api._get_master_actor() + .get_all_worker_handles.remote()).items(): + if tag.startswith(backend): + handles[tag] = handle + + return handles + + +# Test that a worker dying unexpectedly causes it to restart and continue +# serving requests. +def test_worker_restart(serve_instance): + serve.init() + serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"]) + + class Worker1: + def __call__(self): + return os.getpid() + + serve.create_backend(Worker1, "worker_failure:v1") + serve.link("worker_failure", "worker_failure:v1") + + # Get the PID of the worker. + old_pid = request_with_retries("/worker_failure", timeout=0.1).text + + # Kill the worker. + handles = _get_worker_handles("worker_failure:v1") + assert len(handles) == 1 + ray.kill(list(handles.values())[0]) + + # Wait until the worker is killed and a one is started. + start = time.time() + while time.time() - start < 30: + response = request_with_retries("/worker_failure", timeout=30) + if response.text != old_pid: + break + else: + assert False, "Timed out waiting for worker to die." + + +# Test that if there are multiple replicas for a worker and one dies +# unexpectedly, the others continue to serve requests. +def test_worker_replica_failure(serve_instance): + serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0 + serve.init() + serve.create_endpoint( + "replica_failure", "/replica_failure", methods=["GET"]) + + class Worker: + # Assumes that two replicas are started. Will hang forever in the + # constructor for any workers that are restarted. + def __init__(self, path): + self.should_hang = False + if not os.path.exists(path): + with open(path, "w") as f: + f.write("1") + else: + with open(path, "r") as f: + num = int(f.read()) + + with open(path, "w") as f: + if num == 2: + self.should_hang = True + else: + f.write(str(num + 1)) + + if self.should_hang: + while True: + pass + + def __call__(self): + pass + + temp_path = tempfile.gettempdir() + "/" + serve.utils.get_random_letters() + serve.create_backend(Worker, "replica_failure", temp_path) + backend_config = serve.get_backend_config("replica_failure") + backend_config.num_replicas = 2 + serve.set_backend_config("replica_failure", backend_config) + serve.link("replica_failure", "replica_failure") + + # Wait until both replicas have been started. + responses = set() + while len(responses) == 1: + responses.add( + request_with_retries("/replica_failure", timeout=0.1).text) + time.sleep(0.1) + + # Kill one of the replicas. + handles = _get_worker_handles("replica_failure") + assert len(handles) == 2 + ray.kill(list(handles.values())[0]) + + # Check that the other replica still serves requests. + for _ in range(10): + while True: + try: + # The timeout needs to be small here because the request to + # the restarting worker will hang. + request_with_retries("/replica_failure", timeout=0.1) + break + except TimeoutError: + time.sleep(0.1)