[Serve] Merge router with HTTPProxy (#9225)

This commit is contained in:
Simon Mo
2020-07-10 13:52:48 -07:00
committed by GitHub
parent 1798deae94
commit d4a5d09dab
9 changed files with 68 additions and 97 deletions
+1 -1
View File
@@ -366,7 +366,7 @@ def get_handle(endpoint_name,
master_actor.get_all_endpoints.remote())
return RayServeHandle(
ray.get(master_actor.get_router.remote())[0],
ray.get(master_actor.get_http_proxy.remote())[0],
endpoint_name,
relative_slo_ms,
absolute_slo_ms,
-3
View File
@@ -1,9 +1,6 @@
#: Actor name used to register master actor
SERVE_MASTER_NAME = "SERVE_MASTER_ACTOR"
#: Actor name used to register router actor
SERVE_ROUTER_NAME = "SERVE_ROUTER_ACTOR"
#: Actor name used to register HTTP proxy actor
SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR"
+41 -24
View File
@@ -1,16 +1,16 @@
import asyncio
from urllib.parse import parse_qs
import uvicorn
import ray
from ray.exceptions import RayTaskError
from ray import serve
from ray.serve.context import TaskContext
from ray.serve.metric import MetricClient
from ray.serve.request_params import RequestMetadata
from ray.serve.http_util import Response
from ray.serve.utils import logger
from urllib.parse import parse_qs
from ray.serve.router import Router
# The maximum number of times to retry a request due to actor failure.
# TODO(edoakes): this should probably be configurable.
@@ -26,12 +26,11 @@ class HTTPProxy:
# blocks forever
"""
async def fetch_config_from_master(self):
async def fetch_config_from_master(self, instance_name=None):
assert ray.is_initialized()
master = serve.api._get_master_actor()
self.route_table, [self.router_handle
] = await master.get_http_proxy_config.remote()
self.route_table = await master.get_http_proxy_config.remote()
# The exporter is required to return results for /-/metrics endpoint.
[self.metric_exporter] = await master.get_metric_exporter.remote()
@@ -42,6 +41,9 @@ class HTTPProxy:
description="The number of requests processed",
label_names=("route", ))
self.router = Router()
await self.router.setup(instance_name)
def set_route_table(self, route_table):
self.route_table = route_table
@@ -155,24 +157,14 @@ class HTTPProxy:
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
)
retries = 0
while retries <= MAX_ACTOR_DEAD_RETRIES:
try:
result = await self.router_handle.enqueue_request.remote(
request_metadata, scope, http_body_bytes)
if not isinstance(result, ray.exceptions.RayActorError):
await Response(result).send(scope, receive, send)
break
logger.warning("Got RayActorError: {}".format(str(result)))
await asyncio.sleep(0.1)
except Exception as e:
error_message = "Internal Error. Traceback: {}.".format(e)
await error_sender(error_message, 500)
break
result = await self.router.enqueue_request(request_metadata, scope,
http_body_bytes)
if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result)
await error_sender(error_message, 500)
else:
logger.debug("Maximum actor death retries exceeded")
await error_sender(
"Internal Error. Maximum actor death retries exceeded", 500)
await Response(result).send(scope, receive, send)
@ray.remote
@@ -180,7 +172,7 @@ class HTTPProxyActor:
async def __init__(self, host, port, instance_name=None):
serve.init(name=instance_name)
self.app = HTTPProxy()
await self.app.fetch_config_from_master()
await self.app.fetch_config_from_master(instance_name)
self.host = host
self.port = port
@@ -206,3 +198,28 @@ class HTTPProxyActor:
async def set_route_table(self, route_table):
self.app.set_route_table(route_table)
# ------ Proxy router logic ------ #
async def add_new_worker(self, backend_tag, replica_tag, worker_handle):
return await self.app.router.add_new_worker(backend_tag, replica_tag,
worker_handle)
async def set_traffic(self, endpoint, traffic_policy):
return await self.app.router.set_traffic(endpoint, traffic_policy)
async def set_backend_config(self, backend, config):
return await self.app.router.set_backend_config(backend, config)
async def remove_backend(self, backend):
return await self.app.router.remove_backend(backend)
async def remove_endpoint(self, endpoint):
return await self.app.router.remove_endpoint(endpoint)
async def remove_worker(self, backend_tag, replica_tag):
return await self.app.router.remove_worker(backend_tag, replica_tag)
async def enqueue_request(self, request_meta, *request_args,
**request_kwargs):
return await self.app.router.enqueue_request(
request_meta, *request_args, **request_kwargs)
+9 -27
View File
@@ -7,12 +7,11 @@ import time
import ray
import ray.cloudpickle as pickle
from ray.serve.backend_worker import create_backend_worker
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME,
SERVE_PROXY_NAME, SERVE_METRIC_SINK_NAME)
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
SERVE_METRIC_SINK_NAME)
from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.metric.exporter import MetricExporterActor
from ray.serve.router import Router
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes)
@@ -129,7 +128,6 @@ class ServeMaster:
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self._get_or_start_metric_exporter(metric_exporter_class)
self._get_or_start_router()
self._get_or_start_http_proxy(http_node_id, http_proxy_host,
http_proxy_port)
@@ -153,27 +151,6 @@ class ServeMaster:
asyncio.get_event_loop().create_task(
self._recover_from_checkpoint(checkpoint))
def _get_or_start_router(self):
"""Get the router belonging to this serve instance.
If the router does not already exist, it will be started.
"""
router_name = format_actor_name(SERVE_ROUTER_NAME, self.instance_name)
try:
self.router = ray.get_actor(router_name)
except ValueError:
logger.info("Starting router with name '{}'".format(router_name))
self.router = ray.remote(Router).options(
name=router_name,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
).remote(instance_name=self.instance_name)
def get_router(self):
"""Returns a handle to the router managed by this actor."""
return [self.router]
def _get_or_start_http_proxy(self, node_id, host, port):
"""Get the HTTP proxy belonging to this serve instance.
@@ -197,13 +174,19 @@ class ServeMaster:
).remote(
host, port, instance_name=self.instance_name)
# Since router is a merged with HTTP proxy actor, the router will be
# proxied via the HTTP actor. Even though the two variable names are
# pointing to the same object, their semantic differences make the code
# more readable. (e.g. http_proxy.set_route_table, router.add_worker)
self.router = self.http_proxy
def get_http_proxy(self):
"""Returns a handle to the HTTP proxy managed by this actor."""
return [self.http_proxy]
def get_http_proxy_config(self):
"""Called by the HTTP proxy on startup to fetch required state."""
return self.routes, self.get_router()
return self.routes
def _get_or_start_metric_exporter(self, metric_exporter_class):
"""Get the metric exporter belonging to this serve instance.
@@ -766,7 +749,6 @@ class ServeMaster:
"""Shuts down the serve instance completely."""
async with self.write_lock:
ray.kill(self.http_proxy, no_restart=True)
ray.kill(self.router, no_restart=True)
ray.kill(self.metric_exporter, no_restart=True)
for replica_dict in self.workers.values():
for replica in replica_dict.values():
+1 -3
View File
@@ -87,7 +87,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future],
class Router:
"""A router that routes request to available workers."""
async def __init__(self, instance_name=None):
async def setup(self, instance_name=None):
# Note: Several queues are used in the router
# - When a request come in, it's placed inside its corresponding
# endpoint_queue.
@@ -198,8 +198,6 @@ class Router:
self.endpoint_queues[endpoint].appendleft(query)
self.flush_endpoint_queue(endpoint)
# Note: a future change can be to directly return the ObjectRef from
# replica task submission
try:
result = await query.async_future
except RayTaskError as e:
+1 -1
View File
@@ -570,7 +570,7 @@ def test_shutdown(serve_instance):
def check_dead():
for actor_name in [
constants.SERVE_MASTER_NAME, constants.SERVE_PROXY_NAME,
constants.SERVE_ROUTER_NAME, constants.SERVE_METRIC_SINK_NAME
constants.SERVE_METRIC_SINK_NAME
]:
try:
ray.get_actor(format_actor_name(actor_name, instance_name))
@@ -50,6 +50,7 @@ async def test_runner_wraps_error():
async def test_runner_actor(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
def echo(flask_request, i=None):
return i
@@ -71,6 +72,7 @@ async def test_runner_actor(serve_instance):
async def test_ray_serve_mixin(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
CONSUMER_NAME = "runner-cls"
PRODUCER_NAME = "prod-cls"
@@ -96,6 +98,7 @@ async def test_ray_serve_mixin(serve_instance):
async def test_task_runner_check_context(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
def echo(flask_request, i=None):
# Accessing the flask_request without web context should throw.
@@ -117,6 +120,7 @@ async def test_task_runner_check_context(serve_instance):
async def test_task_runner_custom_method_single(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
class NonBatcher:
def a(self, _):
@@ -151,6 +155,7 @@ async def test_task_runner_custom_method_single(serve_instance):
async def test_task_runner_custom_method_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
@serve.accept_batch
class Batcher:
@@ -216,6 +221,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
async def test_task_runner_perform_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
def batcher(*args, **kwargs):
return [serve.context.batch_size] * serve.context.batch_size
@@ -246,6 +252,7 @@ async def test_task_runner_perform_batch(serve_instance):
async def test_task_runner_perform_async(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
@ray.remote
class Barrier:
-38
View File
@@ -107,44 +107,6 @@ def test_http_proxy_failure(serve_instance):
assert response.text == "hello2"
def _kill_router():
[router] = ray.get(serve.api._get_master_actor().get_router.remote())
ray.kill(router, no_restart=False)
def test_router_failure(serve_instance):
serve.init()
def function():
return "hello1"
serve.create_backend("router_failure:v1", function)
serve.create_endpoint(
"router_failure", backend="router_failure:v1", route="/router_failure")
assert request_with_retries("/router_failure", timeout=5).text == "hello1"
for _ in range(10):
response = request_with_retries("/router_failure", timeout=30)
assert response.text == "hello1"
_kill_router()
for _ in range(10):
response = request_with_retries("/router_failure", timeout=30)
assert response.text == "hello1"
def function():
return "hello2"
serve.create_backend("router_failure:v2", function)
serve.set_traffic("router_failure", {"router_failure:v2": 1.0})
for _ in range(10):
response = request_with_retries("/router_failure", timeout=30)
assert response.text == "hello2"
def _get_worker_handles(backend):
master_actor = serve.api._get_master_actor()
backend_dict = ray.get(master_actor.get_all_worker_handles.remote())
+8
View File
@@ -48,6 +48,8 @@ def task_runner_mock_actor():
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
q.add_new_worker.remote("backend-single-prod", "replica-1",
task_runner_mock_actor)
@@ -64,6 +66,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
async def test_slo(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0}))
all_request_sent = []
@@ -88,6 +91,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
async def test_alter_backend(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
await q.add_new_worker.remote("backend-alter", "replica-1",
@@ -106,6 +110,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.set_traffic.remote(
"svc", TrafficPolicy({
@@ -135,6 +140,7 @@ async def test_queue_remove_replicas(serve_instance):
temp_actor = mock_task_runner()
q = ray.remote(TestRouter).remote()
await q.setup.remote()
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
await q.remove_worker.remote("backend-remove", "replica-1")
assert ray.get(q.worker_queue_size.remote("backend")) == 0
@@ -142,6 +148,7 @@ async def test_queue_remove_replicas(serve_instance):
async def test_shard_key(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
num_backends = 5
traffic_dict = {}
@@ -196,6 +203,7 @@ async def test_router_use_max_concurrency(serve_instance):
worker = MockWorker.remote()
q = ray.remote(VisibleRouter).remote()
await q.setup.remote()
BACKEND_NAME = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))