mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[Serve] Merge router with HTTPProxy (#9225)
This commit is contained in:
@@ -366,7 +366,7 @@ def get_handle(endpoint_name,
|
||||
master_actor.get_all_endpoints.remote())
|
||||
|
||||
return RayServeHandle(
|
||||
ray.get(master_actor.get_router.remote())[0],
|
||||
ray.get(master_actor.get_http_proxy.remote())[0],
|
||||
endpoint_name,
|
||||
relative_slo_ms,
|
||||
absolute_slo_ms,
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
#: Actor name used to register master actor
|
||||
SERVE_MASTER_NAME = "SERVE_MASTER_ACTOR"
|
||||
|
||||
#: Actor name used to register router actor
|
||||
SERVE_ROUTER_NAME = "SERVE_ROUTER_ACTOR"
|
||||
|
||||
#: Actor name used to register HTTP proxy actor
|
||||
SERVE_PROXY_NAME = "SERVE_PROXY_ACTOR"
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import uvicorn
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray import serve
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.metric import MetricClient
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.http_util import Response
|
||||
from ray.serve.utils import logger
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
from ray.serve.router import Router
|
||||
|
||||
# The maximum number of times to retry a request due to actor failure.
|
||||
# TODO(edoakes): this should probably be configurable.
|
||||
@@ -26,12 +26,11 @@ class HTTPProxy:
|
||||
# blocks forever
|
||||
"""
|
||||
|
||||
async def fetch_config_from_master(self):
|
||||
async def fetch_config_from_master(self, instance_name=None):
|
||||
assert ray.is_initialized()
|
||||
master = serve.api._get_master_actor()
|
||||
|
||||
self.route_table, [self.router_handle
|
||||
] = await master.get_http_proxy_config.remote()
|
||||
self.route_table = await master.get_http_proxy_config.remote()
|
||||
|
||||
# The exporter is required to return results for /-/metrics endpoint.
|
||||
[self.metric_exporter] = await master.get_metric_exporter.remote()
|
||||
@@ -42,6 +41,9 @@ class HTTPProxy:
|
||||
description="The number of requests processed",
|
||||
label_names=("route", ))
|
||||
|
||||
self.router = Router()
|
||||
await self.router.setup(instance_name)
|
||||
|
||||
def set_route_table(self, route_table):
|
||||
self.route_table = route_table
|
||||
|
||||
@@ -155,24 +157,14 @@ class HTTPProxy:
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
|
||||
)
|
||||
|
||||
retries = 0
|
||||
while retries <= MAX_ACTOR_DEAD_RETRIES:
|
||||
try:
|
||||
result = await self.router_handle.enqueue_request.remote(
|
||||
request_metadata, scope, http_body_bytes)
|
||||
if not isinstance(result, ray.exceptions.RayActorError):
|
||||
await Response(result).send(scope, receive, send)
|
||||
break
|
||||
logger.warning("Got RayActorError: {}".format(str(result)))
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
error_message = "Internal Error. Traceback: {}.".format(e)
|
||||
await error_sender(error_message, 500)
|
||||
break
|
||||
result = await self.router.enqueue_request(request_metadata, scope,
|
||||
http_body_bytes)
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
error_message = "Task Error. Traceback: {}.".format(result)
|
||||
await error_sender(error_message, 500)
|
||||
else:
|
||||
logger.debug("Maximum actor death retries exceeded")
|
||||
await error_sender(
|
||||
"Internal Error. Maximum actor death retries exceeded", 500)
|
||||
await Response(result).send(scope, receive, send)
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -180,7 +172,7 @@ class HTTPProxyActor:
|
||||
async def __init__(self, host, port, instance_name=None):
|
||||
serve.init(name=instance_name)
|
||||
self.app = HTTPProxy()
|
||||
await self.app.fetch_config_from_master()
|
||||
await self.app.fetch_config_from_master(instance_name)
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
@@ -206,3 +198,28 @@ class HTTPProxyActor:
|
||||
|
||||
async def set_route_table(self, route_table):
|
||||
self.app.set_route_table(route_table)
|
||||
|
||||
# ------ Proxy router logic ------ #
|
||||
async def add_new_worker(self, backend_tag, replica_tag, worker_handle):
|
||||
return await self.app.router.add_new_worker(backend_tag, replica_tag,
|
||||
worker_handle)
|
||||
|
||||
async def set_traffic(self, endpoint, traffic_policy):
|
||||
return await self.app.router.set_traffic(endpoint, traffic_policy)
|
||||
|
||||
async def set_backend_config(self, backend, config):
|
||||
return await self.app.router.set_backend_config(backend, config)
|
||||
|
||||
async def remove_backend(self, backend):
|
||||
return await self.app.router.remove_backend(backend)
|
||||
|
||||
async def remove_endpoint(self, endpoint):
|
||||
return await self.app.router.remove_endpoint(endpoint)
|
||||
|
||||
async def remove_worker(self, backend_tag, replica_tag):
|
||||
return await self.app.router.remove_worker(backend_tag, replica_tag)
|
||||
|
||||
async def enqueue_request(self, request_meta, *request_args,
|
||||
**request_kwargs):
|
||||
return await self.app.router.enqueue_request(
|
||||
request_meta, *request_args, **request_kwargs)
|
||||
|
||||
@@ -7,12 +7,11 @@ import time
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.serve.backend_worker import create_backend_worker
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_ROUTER_NAME,
|
||||
SERVE_PROXY_NAME, SERVE_METRIC_SINK_NAME)
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
|
||||
SERVE_METRIC_SINK_NAME)
|
||||
from ray.serve.http_proxy import HTTPProxyActor
|
||||
from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.metric.exporter import MetricExporterActor
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes)
|
||||
@@ -129,7 +128,6 @@ class ServeMaster:
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self._get_or_start_metric_exporter(metric_exporter_class)
|
||||
self._get_or_start_router()
|
||||
self._get_or_start_http_proxy(http_node_id, http_proxy_host,
|
||||
http_proxy_port)
|
||||
|
||||
@@ -153,27 +151,6 @@ class ServeMaster:
|
||||
asyncio.get_event_loop().create_task(
|
||||
self._recover_from_checkpoint(checkpoint))
|
||||
|
||||
def _get_or_start_router(self):
|
||||
"""Get the router belonging to this serve instance.
|
||||
|
||||
If the router does not already exist, it will be started.
|
||||
"""
|
||||
router_name = format_actor_name(SERVE_ROUTER_NAME, self.instance_name)
|
||||
try:
|
||||
self.router = ray.get_actor(router_name)
|
||||
except ValueError:
|
||||
logger.info("Starting router with name '{}'".format(router_name))
|
||||
self.router = ray.remote(Router).options(
|
||||
name=router_name,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
).remote(instance_name=self.instance_name)
|
||||
|
||||
def get_router(self):
|
||||
"""Returns a handle to the router managed by this actor."""
|
||||
return [self.router]
|
||||
|
||||
def _get_or_start_http_proxy(self, node_id, host, port):
|
||||
"""Get the HTTP proxy belonging to this serve instance.
|
||||
|
||||
@@ -197,13 +174,19 @@ class ServeMaster:
|
||||
).remote(
|
||||
host, port, instance_name=self.instance_name)
|
||||
|
||||
# Since router is a merged with HTTP proxy actor, the router will be
|
||||
# proxied via the HTTP actor. Even though the two variable names are
|
||||
# pointing to the same object, their semantic differences make the code
|
||||
# more readable. (e.g. http_proxy.set_route_table, router.add_worker)
|
||||
self.router = self.http_proxy
|
||||
|
||||
def get_http_proxy(self):
|
||||
"""Returns a handle to the HTTP proxy managed by this actor."""
|
||||
return [self.http_proxy]
|
||||
|
||||
def get_http_proxy_config(self):
|
||||
"""Called by the HTTP proxy on startup to fetch required state."""
|
||||
return self.routes, self.get_router()
|
||||
return self.routes
|
||||
|
||||
def _get_or_start_metric_exporter(self, metric_exporter_class):
|
||||
"""Get the metric exporter belonging to this serve instance.
|
||||
@@ -766,7 +749,6 @@ class ServeMaster:
|
||||
"""Shuts down the serve instance completely."""
|
||||
async with self.write_lock:
|
||||
ray.kill(self.http_proxy, no_restart=True)
|
||||
ray.kill(self.router, no_restart=True)
|
||||
ray.kill(self.metric_exporter, no_restart=True)
|
||||
for replica_dict in self.workers.values():
|
||||
for replica in replica_dict.values():
|
||||
|
||||
@@ -87,7 +87,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future],
|
||||
class Router:
|
||||
"""A router that routes request to available workers."""
|
||||
|
||||
async def __init__(self, instance_name=None):
|
||||
async def setup(self, instance_name=None):
|
||||
# Note: Several queues are used in the router
|
||||
# - When a request come in, it's placed inside its corresponding
|
||||
# endpoint_queue.
|
||||
@@ -198,8 +198,6 @@ class Router:
|
||||
self.endpoint_queues[endpoint].appendleft(query)
|
||||
self.flush_endpoint_queue(endpoint)
|
||||
|
||||
# Note: a future change can be to directly return the ObjectRef from
|
||||
# replica task submission
|
||||
try:
|
||||
result = await query.async_future
|
||||
except RayTaskError as e:
|
||||
|
||||
@@ -570,7 +570,7 @@ def test_shutdown(serve_instance):
|
||||
def check_dead():
|
||||
for actor_name in [
|
||||
constants.SERVE_MASTER_NAME, constants.SERVE_PROXY_NAME,
|
||||
constants.SERVE_ROUTER_NAME, constants.SERVE_METRIC_SINK_NAME
|
||||
constants.SERVE_METRIC_SINK_NAME
|
||||
]:
|
||||
try:
|
||||
ray.get_actor(format_actor_name(actor_name, instance_name))
|
||||
|
||||
@@ -50,6 +50,7 @@ async def test_runner_wraps_error():
|
||||
|
||||
async def test_runner_actor(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
return i
|
||||
@@ -71,6 +72,7 @@ async def test_runner_actor(serve_instance):
|
||||
|
||||
async def test_ray_serve_mixin(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
CONSUMER_NAME = "runner-cls"
|
||||
PRODUCER_NAME = "prod-cls"
|
||||
@@ -96,6 +98,7 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
|
||||
async def test_task_runner_check_context(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
# Accessing the flask_request without web context should throw.
|
||||
@@ -117,6 +120,7 @@ async def test_task_runner_check_context(serve_instance):
|
||||
|
||||
async def test_task_runner_custom_method_single(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
class NonBatcher:
|
||||
def a(self, _):
|
||||
@@ -151,6 +155,7 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
|
||||
async def test_task_runner_custom_method_batch(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
@serve.accept_batch
|
||||
class Batcher:
|
||||
@@ -216,6 +221,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
|
||||
async def test_task_runner_perform_batch(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
def batcher(*args, **kwargs):
|
||||
return [serve.context.batch_size] * serve.context.batch_size
|
||||
@@ -246,6 +252,7 @@ async def test_task_runner_perform_batch(serve_instance):
|
||||
|
||||
async def test_task_runner_perform_async(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
@ray.remote
|
||||
class Barrier:
|
||||
|
||||
@@ -107,44 +107,6 @@ def test_http_proxy_failure(serve_instance):
|
||||
assert response.text == "hello2"
|
||||
|
||||
|
||||
def _kill_router():
|
||||
[router] = ray.get(serve.api._get_master_actor().get_router.remote())
|
||||
ray.kill(router, no_restart=False)
|
||||
|
||||
|
||||
def test_router_failure(serve_instance):
|
||||
serve.init()
|
||||
|
||||
def function():
|
||||
return "hello1"
|
||||
|
||||
serve.create_backend("router_failure:v1", function)
|
||||
serve.create_endpoint(
|
||||
"router_failure", backend="router_failure:v1", route="/router_failure")
|
||||
|
||||
assert request_with_retries("/router_failure", timeout=5).text == "hello1"
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/router_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
_kill_router()
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/router_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
def function():
|
||||
return "hello2"
|
||||
|
||||
serve.create_backend("router_failure:v2", function)
|
||||
serve.set_traffic("router_failure", {"router_failure:v2": 1.0})
|
||||
|
||||
for _ in range(10):
|
||||
response = request_with_retries("/router_failure", timeout=30)
|
||||
assert response.text == "hello2"
|
||||
|
||||
|
||||
def _get_worker_handles(backend):
|
||||
master_actor = serve.api._get_master_actor()
|
||||
backend_dict = ray.get(master_actor.get_all_worker_handles.remote())
|
||||
|
||||
@@ -48,6 +48,8 @@ def task_runner_mock_actor():
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
|
||||
q.add_new_worker.remote("backend-single-prod", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
@@ -64,6 +66,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0}))
|
||||
|
||||
all_request_sent = []
|
||||
@@ -88,6 +91,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
|
||||
await q.add_new_worker.remote("backend-alter", "replica-1",
|
||||
@@ -106,6 +110,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
await q.set_traffic.remote(
|
||||
"svc", TrafficPolicy({
|
||||
@@ -135,6 +140,7 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
temp_actor = mock_task_runner()
|
||||
q = ray.remote(TestRouter).remote()
|
||||
await q.setup.remote()
|
||||
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
|
||||
await q.remove_worker.remote("backend-remove", "replica-1")
|
||||
assert ray.get(q.worker_queue_size.remote("backend")) == 0
|
||||
@@ -142,6 +148,7 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
@@ -196,6 +203,7 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
|
||||
worker = MockWorker.remote()
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
await q.setup.remote()
|
||||
BACKEND_NAME = "max-concurrent-test"
|
||||
config = BackendConfig({"max_concurrent_queries": 1})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
|
||||
|
||||
Reference in New Issue
Block a user