[serve] Make workers fault tolerant (#7970)

This commit is contained in:
Edward Oakes
2020-04-12 11:48:08 -05:00
committed by GitHub
parent 98bfcd53bc
commit 2cb9cfb2b6
5 changed files with 171 additions and 50 deletions
+26 -11
View File
@@ -8,9 +8,14 @@ from ray.serve.constants import SERVE_MASTER_NAME
from ray.serve.context import TaskContext
from ray.serve.request_params import RequestMetadata
from ray.serve.http_util import Response
from ray.serve.utils import logger
from urllib.parse import parse_qs
# The maximum number of times to retry a request due to actor failure.
# TODO(edoakes): this should probably be configurable.
MAX_ACTOR_DEAD_RETRIES = 10
class HTTPProxy:
"""
@@ -102,8 +107,9 @@ class HTTPProxy:
await Response(self.route_table).send(scope, receive, send)
return
# TODO(simon): Use werkzeug route mapper to support variable path
if current_path not in self.route_table:
try:
endpoint_name, methods_allowed = self.route_table[current_path]
except KeyError:
error_message = (
"Path {} not found. "
"Please ping http://.../-/routes for routing table"
@@ -111,8 +117,6 @@ class HTTPProxy:
await error_sender(error_message, 404)
return
endpoint_name, methods_allowed = self.route_table[current_path]
if scope["method"] not in methods_allowed:
error_message = ("Methods {} not allowed. "
"Avaiable HTTP methods are {}.").format(
@@ -137,13 +141,24 @@ class HTTPProxy:
absolute_slo_ms=absolute_slo_ms,
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"))
try:
result = await self.router_handle.enqueue_request.remote(
request_metadata, scope, http_body_bytes)
await Response(result).send(scope, receive, send)
except Exception as e:
error_message = "Internal Error. Traceback: {}.".format(e)
await error_sender(error_message, 500)
retries = 0
while retries <= MAX_ACTOR_DEAD_RETRIES:
try:
result = await self.router_handle.enqueue_request.remote(
request_metadata, scope, http_body_bytes)
if not isinstance(result, ray.exceptions.RayActorError):
await Response(result).send(scope, receive, send)
break
logger.warning("Got RayActorError:", str(result))
await asyncio.sleep(0.1)
except Exception as e:
error_message = "Internal Error. Traceback: {}.".format(e)
await error_sender(error_message, 500)
break
else:
logger.debug("Maximum actor death retries exceeded")
await error_sender(
"Internal Error. Maximum actor death retries exceeded", 500)
@ray.remote
+8 -6
View File
@@ -27,7 +27,7 @@ class ServeMaster:
self.route_table = RoutingTable(kv_store_connector)
self.backend_table = BackendTable(kv_store_connector)
self.policy_table = TrafficPolicyTable(kv_store_connector)
self.tag_to_actor_handles = dict()
self.replica_tag_to_workers = dict()
self.router = None
self.http_proxy = None
@@ -131,10 +131,12 @@ class ServeMaster:
self.backend_table.get_init_args(backend_tag)
]
kwargs = backend_config.get_actor_creation_args(init_args)
kwargs[
"max_reconstructions"] = ray.ray_constants.INFINITE_RECONSTRUCTION
# Start the worker.
worker_handle = backend_actor._remote(**kwargs)
self.tag_to_actor_handles[replica_tag] = worker_handle
self.replica_tag_to_workers[replica_tag] = worker_handle
# Wait for the worker to start up.
await worker_handle.ready.remote()
@@ -152,8 +154,8 @@ class ServeMaster:
backend_tag)
replica_tag = self.backend_table.remove_replica(backend_tag)
assert replica_tag in self.tag_to_actor_handles
replica_handle = self.tag_to_actor_handles.pop(replica_tag)
assert replica_tag in self.replica_tag_to_workers
replica_handle = self.replica_tag_to_workers.pop(replica_tag)
# Remove the replica from metric monitor.
[monitor] = self.get_metric_monitor()
@@ -166,8 +168,8 @@ class ServeMaster:
router.remove_and_destroy_replica.remote(backend_tag,
replica_handle))
def get_all_handles(self):
return self.tag_to_actor_handles
def get_all_worker_handles(self):
return self.replica_tag_to_workers
def get_all_endpoints(self):
return expand(
+2
View File
@@ -288,6 +288,8 @@ class Router:
backend, buffer_queue, worker_queue, max_batch_size)
async def _do_query(self, backend, worker, req):
# If the worker died, this will be a RayActorError. Just return it and
# let the HTTP proxy handle the retry logic.
result = await worker.handle_request.remote(req)
await self.mark_worker_idle(backend, worker)
return result
+2 -2
View File
@@ -193,7 +193,7 @@ def test_killing_replicas(serve_instance):
new_replica_tag_list = ray.get(
master_actor._list_replicas.remote("simple:v1"))
new_all_tag_list = list(
ray.get(master_actor.get_all_handles.remote()).keys())
ray.get(master_actor.get_all_worker_handles.remote()).keys())
# the new_replica_tag_list must be subset of all_tag_list
assert set(new_replica_tag_list) <= set(new_all_tag_list)
@@ -227,7 +227,7 @@ def test_not_killing_replicas(serve_instance):
new_replica_tag_list = ray.get(
master_actor._list_replicas.remote("bsimple:v1"))
new_all_tag_list = list(
ray.get(master_actor.get_all_handles.remote()).keys())
ray.get(master_actor.get_all_worker_handles.remote()).keys())
# the old and new replica tag list should be identical
# and should be subset of all_tag_list
+133 -31
View File
@@ -1,8 +1,22 @@
import time
import os
import requests
import tempfile
import time
from ray import serve
import ray
from ray import serve
def request_with_retries(endpoint, timeout=30):
start = time.time()
while True:
try:
return requests.get(
"http://127.0.0.1:8000" + endpoint, timeout=timeout)
except requests.RequestException:
if time.time() - start > timeout:
raise TimeoutError
time.sleep(0.1)
def _kill_http_proxy():
@@ -11,47 +25,135 @@ def _kill_http_proxy():
ray.kill(http_proxy)
def request_with_retries(endpoint, verify_response, timeout=30):
start = time.time()
while True:
try:
verify_response(requests.get("http://127.0.0.1:8000" + endpoint))
break
except requests.RequestException:
if time.time() - start > timeout:
raise TimeoutError
time.sleep(0.1)
def test_http_proxy_failure(serve_instance):
serve.init()
serve.create_endpoint(
"failure_endpoint", "/failure_endpoint", methods=["GET"])
serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])
def function(flask_request):
def function():
return "hello1"
serve.create_backend(function, "failure:v1")
serve.link("failure_endpoint", "failure:v1")
serve.create_backend(function, "proxy_failure:v1")
serve.link("proxy_failure", "proxy_failure:v1")
def verify_response(response):
assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"
for _ in range(10):
response = request_with_retries("/proxy_failure", timeout=30)
assert response.text == "hello1"
request_with_retries("/failure_endpoint", verify_response, timeout=0)
_kill_http_proxy()
request_with_retries("/failure_endpoint", verify_response, timeout=30)
_kill_http_proxy()
def function(flask_request):
def function():
return "hello2"
serve.create_backend(function, "failure:v2")
serve.link("failure_endpoint", "failure:v2")
serve.create_backend(function, "proxy_failure:v2")
serve.link("proxy_failure", "proxy_failure:v2")
def verify_response(response):
for _ in range(10):
response = request_with_retries("/proxy_failure", timeout=30)
assert response.text == "hello2"
request_with_retries("/failure_endpoint", verify_response, timeout=30)
def _get_worker_handles(backend):
handles = {}
for tag, handle in ray.get(serve.api._get_master_actor()
.get_all_worker_handles.remote()).items():
if tag.startswith(backend):
handles[tag] = handle
return handles
# Test that a worker dying unexpectedly causes it to restart and continue
# serving requests.
def test_worker_restart(serve_instance):
serve.init()
serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])
class Worker1:
def __call__(self):
return os.getpid()
serve.create_backend(Worker1, "worker_failure:v1")
serve.link("worker_failure", "worker_failure:v1")
# Get the PID of the worker.
old_pid = request_with_retries("/worker_failure", timeout=0.1).text
# Kill the worker.
handles = _get_worker_handles("worker_failure:v1")
assert len(handles) == 1
ray.kill(list(handles.values())[0])
# Wait until the worker is killed and a one is started.
start = time.time()
while time.time() - start < 30:
response = request_with_retries("/worker_failure", timeout=30)
if response.text != old_pid:
break
else:
assert False, "Timed out waiting for worker to die."
# Test that if there are multiple replicas for a worker and one dies
# unexpectedly, the others continue to serve requests.
def test_worker_replica_failure(serve_instance):
serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
serve.init()
serve.create_endpoint(
"replica_failure", "/replica_failure", methods=["GET"])
class Worker:
# Assumes that two replicas are started. Will hang forever in the
# constructor for any workers that are restarted.
def __init__(self, path):
self.should_hang = False
if not os.path.exists(path):
with open(path, "w") as f:
f.write("1")
else:
with open(path, "r") as f:
num = int(f.read())
with open(path, "w") as f:
if num == 2:
self.should_hang = True
else:
f.write(str(num + 1))
if self.should_hang:
while True:
pass
def __call__(self):
pass
temp_path = tempfile.gettempdir() + "/" + serve.utils.get_random_letters()
serve.create_backend(Worker, "replica_failure", temp_path)
backend_config = serve.get_backend_config("replica_failure")
backend_config.num_replicas = 2
serve.set_backend_config("replica_failure", backend_config)
serve.link("replica_failure", "replica_failure")
# Wait until both replicas have been started.
responses = set()
while len(responses) == 1:
responses.add(
request_with_retries("/replica_failure", timeout=0.1).text)
time.sleep(0.1)
# Kill one of the replicas.
handles = _get_worker_handles("replica_failure")
assert len(handles) == 2
ray.kill(list(handles.values())[0])
# Check that the other replica still serves requests.
for _ in range(10):
while True:
try:
# The timeout needs to be small here because the request to
# the restarting worker will hang.
request_with_retries("/replica_failure", timeout=0.1)
break
except TimeoutError:
time.sleep(0.1)