mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
[serve] Make workers fault tolerant (#7970)
This commit is contained in:
@@ -8,9 +8,14 @@ from ray.serve.constants import SERVE_MASTER_NAME
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.http_util import Response
|
||||
from ray.serve.utils import logger
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
# The maximum number of times to retry a request due to actor failure.
|
||||
# TODO(edoakes): this should probably be configurable.
|
||||
MAX_ACTOR_DEAD_RETRIES = 10
|
||||
|
||||
|
||||
class HTTPProxy:
|
||||
"""
|
||||
@@ -102,8 +107,9 @@ class HTTPProxy:
|
||||
await Response(self.route_table).send(scope, receive, send)
|
||||
return
|
||||
|
||||
# TODO(simon): Use werkzeug route mapper to support variable path
|
||||
if current_path not in self.route_table:
|
||||
try:
|
||||
endpoint_name, methods_allowed = self.route_table[current_path]
|
||||
except KeyError:
|
||||
error_message = (
|
||||
"Path {} not found. "
|
||||
"Please ping http://.../-/routes for routing table"
|
||||
@@ -111,8 +117,6 @@ class HTTPProxy:
|
||||
await error_sender(error_message, 404)
|
||||
return
|
||||
|
||||
endpoint_name, methods_allowed = self.route_table[current_path]
|
||||
|
||||
if scope["method"] not in methods_allowed:
|
||||
error_message = ("Methods {} not allowed. "
|
||||
"Avaiable HTTP methods are {}.").format(
|
||||
@@ -137,13 +141,24 @@ class HTTPProxy:
|
||||
absolute_slo_ms=absolute_slo_ms,
|
||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"))
|
||||
|
||||
try:
|
||||
result = await self.router_handle.enqueue_request.remote(
|
||||
request_metadata, scope, http_body_bytes)
|
||||
await Response(result).send(scope, receive, send)
|
||||
except Exception as e:
|
||||
error_message = "Internal Error. Traceback: {}.".format(e)
|
||||
await error_sender(error_message, 500)
|
||||
retries = 0
|
||||
while retries <= MAX_ACTOR_DEAD_RETRIES:
|
||||
try:
|
||||
result = await self.router_handle.enqueue_request.remote(
|
||||
request_metadata, scope, http_body_bytes)
|
||||
if not isinstance(result, ray.exceptions.RayActorError):
|
||||
await Response(result).send(scope, receive, send)
|
||||
break
|
||||
logger.warning("Got RayActorError:", str(result))
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
error_message = "Internal Error. Traceback: {}.".format(e)
|
||||
await error_sender(error_message, 500)
|
||||
break
|
||||
else:
|
||||
logger.debug("Maximum actor death retries exceeded")
|
||||
await error_sender(
|
||||
"Internal Error. Maximum actor death retries exceeded", 500)
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
||||
@@ -27,7 +27,7 @@ class ServeMaster:
|
||||
self.route_table = RoutingTable(kv_store_connector)
|
||||
self.backend_table = BackendTable(kv_store_connector)
|
||||
self.policy_table = TrafficPolicyTable(kv_store_connector)
|
||||
self.tag_to_actor_handles = dict()
|
||||
self.replica_tag_to_workers = dict()
|
||||
|
||||
self.router = None
|
||||
self.http_proxy = None
|
||||
@@ -131,10 +131,12 @@ class ServeMaster:
|
||||
self.backend_table.get_init_args(backend_tag)
|
||||
]
|
||||
kwargs = backend_config.get_actor_creation_args(init_args)
|
||||
kwargs[
|
||||
"max_reconstructions"] = ray.ray_constants.INFINITE_RECONSTRUCTION
|
||||
|
||||
# Start the worker.
|
||||
worker_handle = backend_actor._remote(**kwargs)
|
||||
self.tag_to_actor_handles[replica_tag] = worker_handle
|
||||
self.replica_tag_to_workers[replica_tag] = worker_handle
|
||||
|
||||
# Wait for the worker to start up.
|
||||
await worker_handle.ready.remote()
|
||||
@@ -152,8 +154,8 @@ class ServeMaster:
|
||||
backend_tag)
|
||||
|
||||
replica_tag = self.backend_table.remove_replica(backend_tag)
|
||||
assert replica_tag in self.tag_to_actor_handles
|
||||
replica_handle = self.tag_to_actor_handles.pop(replica_tag)
|
||||
assert replica_tag in self.replica_tag_to_workers
|
||||
replica_handle = self.replica_tag_to_workers.pop(replica_tag)
|
||||
|
||||
# Remove the replica from metric monitor.
|
||||
[monitor] = self.get_metric_monitor()
|
||||
@@ -166,8 +168,8 @@ class ServeMaster:
|
||||
router.remove_and_destroy_replica.remote(backend_tag,
|
||||
replica_handle))
|
||||
|
||||
def get_all_handles(self):
|
||||
return self.tag_to_actor_handles
|
||||
def get_all_worker_handles(self):
|
||||
return self.replica_tag_to_workers
|
||||
|
||||
def get_all_endpoints(self):
|
||||
return expand(
|
||||
|
||||
@@ -288,6 +288,8 @@ class Router:
|
||||
backend, buffer_queue, worker_queue, max_batch_size)
|
||||
|
||||
async def _do_query(self, backend, worker, req):
    """Forward ``req`` to ``worker`` and then mark that worker idle.

    Args:
        backend: Backend identifier, passed through to
            ``mark_worker_idle``.
        worker: Handle whose ``handle_request.remote`` is awaited.
        req: The request object handed to the worker.

    Returns:
        The awaited result of the worker call, returned unchanged.
    """
    # If the worker died, this will be a RayActorError. Just return it and
    # let the HTTP proxy handle the retry logic.
    result = await worker.handle_request.remote(req)
    # The worker is released even when the result is an error value, so it
    # can be scheduled again.
    await self.mark_worker_idle(backend, worker)
    return result
|
||||
|
||||
@@ -193,7 +193,7 @@ def test_killing_replicas(serve_instance):
|
||||
new_replica_tag_list = ray.get(
|
||||
master_actor._list_replicas.remote("simple:v1"))
|
||||
new_all_tag_list = list(
|
||||
ray.get(master_actor.get_all_handles.remote()).keys())
|
||||
ray.get(master_actor.get_all_worker_handles.remote()).keys())
|
||||
|
||||
# the new_replica_tag_list must be subset of all_tag_list
|
||||
assert set(new_replica_tag_list) <= set(new_all_tag_list)
|
||||
@@ -227,7 +227,7 @@ def test_not_killing_replicas(serve_instance):
|
||||
new_replica_tag_list = ray.get(
|
||||
master_actor._list_replicas.remote("bsimple:v1"))
|
||||
new_all_tag_list = list(
|
||||
ray.get(master_actor.get_all_handles.remote()).keys())
|
||||
ray.get(master_actor.get_all_worker_handles.remote()).keys())
|
||||
|
||||
# the old and new replica tag list should be identical
|
||||
# and should be subset of all_tag_list
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
import time
|
||||
import os
|
||||
import requests
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from ray import serve
|
||||
import ray
|
||||
from ray import serve
|
||||
|
||||
|
||||
def request_with_retries(endpoint, timeout=30):
    """GET ``endpoint`` from the local HTTP proxy, retrying on request errors.

    Args:
        endpoint: URL path (for example ``"/proxy_failure"``) appended to
            ``http://127.0.0.1:8000``.
        timeout: Seconds. Used both as the per-request timeout handed to
            ``requests.get`` and as the overall retry budget.

    Returns:
        The response object of the first request that completes.

    Raises:
        TimeoutError: If requests are still failing once more than
            ``timeout`` seconds have elapsed since the first attempt.
    """
    url = "http://127.0.0.1:8000" + endpoint
    # A monotonic clock keeps the retry budget immune to wall-clock jumps.
    deadline = time.monotonic() + timeout
    while True:
        try:
            return requests.get(url, timeout=timeout)
        except requests.RequestException as exc:
            # The budget is only checked after a failure, so at least one
            # attempt is always made, even with a tiny timeout.
            if time.monotonic() > deadline:
                raise TimeoutError(
                    "Request to {} did not succeed within {}s".format(
                        url, timeout)) from exc
            time.sleep(0.1)
|
||||
|
||||
|
||||
def _kill_http_proxy():
|
||||
@@ -11,47 +25,135 @@ def _kill_http_proxy():
|
||||
ray.kill(http_proxy)
|
||||
|
||||
|
||||
def request_with_retries(endpoint, verify_response, timeout=30):
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
verify_response(requests.get("http://127.0.0.1:8000" + endpoint))
|
||||
break
|
||||
except requests.RequestException:
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def test_http_proxy_failure(serve_instance):
|
||||
serve.init()
|
||||
serve.create_endpoint(
|
||||
"failure_endpoint", "/failure_endpoint", methods=["GET"])
|
||||
serve.create_endpoint("proxy_failure", "/proxy_failure", methods=["GET"])
|
||||
|
||||
def function(flask_request):
|
||||
def function():
|
||||
return "hello1"
|
||||
|
||||
serve.create_backend(function, "failure:v1")
|
||||
serve.link("failure_endpoint", "failure:v1")
|
||||
serve.create_backend(function, "proxy_failure:v1")
|
||||
serve.link("proxy_failure", "proxy_failure:v1")
|
||||
|
||||
def verify_response(response):
|
||||
assert request_with_retries("/proxy_failure", timeout=0.1).text == "hello1"
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/proxy_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
request_with_retries("/failure_endpoint", verify_response, timeout=0)
|
||||
|
||||
_kill_http_proxy()
|
||||
|
||||
request_with_retries("/failure_endpoint", verify_response, timeout=30)
|
||||
|
||||
_kill_http_proxy()
|
||||
|
||||
def function(flask_request):
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
serve.create_backend(function, "failure:v2")
|
||||
serve.link("failure_endpoint", "failure:v2")
|
||||
serve.create_backend(function, "proxy_failure:v2")
|
||||
serve.link("proxy_failure", "proxy_failure:v2")
|
||||
|
||||
def verify_response(response):
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/proxy_failure", timeout=30)
|
||||
assert response.text == "hello2"
|
||||
|
||||
request_with_retries("/failure_endpoint", verify_response, timeout=30)
|
||||
|
||||
def _get_worker_handles(backend):
    """Return the worker handles whose replica tag starts with ``backend``.

    Args:
        backend: Backend tag used as a prefix filter on replica tags. Note
            that this is a plain prefix match, so a backend tag that is a
            prefix of another backend's tag will match both.

    Returns:
        A dict mapping replica tag to worker handle, fetched from the master
        actor's ``get_all_worker_handles``.
    """
    master_actor = serve.api._get_master_actor()
    all_handles = ray.get(master_actor.get_all_worker_handles.remote())
    return {
        tag: handle
        for tag, handle in all_handles.items() if tag.startswith(backend)
    }
|
||||
|
||||
|
||||
# Test that a worker dying unexpectedly causes it to restart and continue
|
||||
# serving requests.
|
||||
def test_worker_restart(serve_instance):
    """Kill the only worker of a backend and check that a new one serves.

    The backend replies with its process ID, so a changed reply after the
    kill shows that a different process is now answering requests.
    """
    serve.init()
    serve.create_endpoint("worker_failure", "/worker_failure", methods=["GET"])

    class Worker1:
        def __call__(self):
            return os.getpid()

    serve.create_backend(Worker1, "worker_failure:v1")
    serve.link("worker_failure", "worker_failure:v1")

    # Get the PID of the worker.
    old_pid = request_with_retries("/worker_failure", timeout=0.1).text

    # Kill the worker.
    handles = _get_worker_handles("worker_failure:v1")
    assert len(handles) == 1
    ray.kill(next(iter(handles.values())))

    # Wait until the worker is killed and a new one is started.
    start = time.time()
    while time.time() - start < 30:
        response = request_with_retries("/worker_failure", timeout=30)
        if response.text != old_pid:
            break
    else:
        # Raised explicitly (not via ``assert False``) so the failure is
        # still reported when Python runs with assertions disabled.
        raise AssertionError("Timed out waiting for worker to die.")
|
||||
|
||||
|
||||
# Test that if there are multiple replicas for a worker and one dies
|
||||
# unexpectedly, the others continue to serve requests.
|
||||
def test_worker_replica_failure(serve_instance):
    """Kill one of two replicas and check the other keeps serving.

    A counter file shared by all replicas records how many constructors
    have run. The first two start normally; any later (restarted) replica
    spins forever in its constructor, so only the surviving replica can
    answer requests after the kill.
    """
    # Set the proxy module's actor-death retry limit to 0.
    # NOTE(review): this assignment happens in the test process; confirm it
    # is visible to the running proxy.
    serve.http_proxy.MAX_ACTOR_DEAD_RETRIES = 0
    serve.init()
    serve.create_endpoint(
        "replica_failure", "/replica_failure", methods=["GET"])

    class Worker:
        # Assumes that two replicas are started. Will hang forever in the
        # constructor for any workers that are restarted.
        def __init__(self, path):
            self.should_hang = False
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write("1")
            else:
                with open(path, "r") as f:
                    num = int(f.read())

                if num == 2:
                    # Leave the counter file untouched: reopening it for
                    # writing would truncate it, and a further restart
                    # would then fail on int("") instead of hanging.
                    self.should_hang = True
                else:
                    with open(path, "w") as f:
                        f.write(str(num + 1))

            if self.should_hang:
                while True:
                    pass

        def __call__(self):
            # Reply with the PID so the two replicas can be told apart.
            return os.getpid()

    temp_path = os.path.join(tempfile.gettempdir(),
                             serve.utils.get_random_letters())
    serve.create_backend(Worker, "replica_failure", temp_path)
    backend_config = serve.get_backend_config("replica_failure")
    backend_config.num_replicas = 2
    serve.set_backend_config("replica_failure", backend_config)
    serve.link("replica_failure", "replica_failure")

    # Wait until both replicas have been started, i.e. until replies from
    # two distinct processes have been seen. The previous condition
    # (``len(responses) == 1`` on an empty set) never entered the loop.
    responses = set()
    start = time.time()
    while len(responses) < 2:
        if time.time() - start > 30:
            raise AssertionError(
                "Timed out waiting for both replicas to serve requests.")
        responses.add(
            request_with_retries("/replica_failure", timeout=0.1).text)
        time.sleep(0.1)

    # Kill one of the replicas.
    handles = _get_worker_handles("replica_failure")
    assert len(handles) == 2
    ray.kill(next(iter(handles.values())))

    # Check that the other replica still serves requests.
    for _ in range(10):
        while True:
            try:
                # The timeout needs to be small here because the request to
                # the restarting worker will hang.
                request_with_retries("/replica_failure", timeout=0.1)
                break
            except TimeoutError:
                time.sleep(0.1)
|
||||
|
||||
Reference in New Issue
Block a user