[serve] Long polling for routes in http server (#12724)

This commit is contained in:
Edward Oakes
2020-12-10 18:02:02 -06:00
committed by GitHub
parent 006856b9a1
commit 3c44c0d3e4
14 changed files with 232 additions and 213 deletions
+30 -1
View File
@@ -43,6 +43,8 @@ class Client:
self._controller_name = controller_name
self._detached = detached
self._shutdown = False
self._http_host, self._http_port = ray.get(
controller.get_http_config.remote())
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
# mostly grow in size, it will only shrink when user calls the
@@ -149,6 +151,29 @@ class Client:
self._controller.create_endpoint.remote(
endpoint_name, {backend: 1.0}, route, upper_methods))
# Block until the route table has been propagated to all HTTP proxies.
if route is not None:
def check_ready(http_response):
return route in http_response.json()
futures = []
for node_id in ray.state.node_ids():
future = block_until_http_ready.options(
num_cpus=0, resources={
node_id: 0.01
}).remote(
"http://{}:{}/-/routes".format(self._http_host,
self._http_port),
check_ready=check_ready,
timeout=HTTP_PROXY_TIMEOUT)
futures.append(future)
try:
    ray.get(futures)
except ray.exceptions.RayTaskError:
    # The second fragment must be an f-string; without the prefix the
    # placeholder is emitted literally instead of the timeout value.
    raise TimeoutError("Route not available at HTTP proxies "
                       f"after {HTTP_PROXY_TIMEOUT}s.")
@_ensure_connected
def delete_endpoint(self, endpoint: str) -> None:
"""Delete the given endpoint.
@@ -453,7 +478,11 @@ def start(detached: bool = False,
"http://{}:{}/-/routes".format(http_host, http_port),
timeout=HTTP_PROXY_TIMEOUT)
futures.append(future)
ray.get(futures)
try:
    ray.get(futures)
except ray.exceptions.RayTaskError:
    # Must be an f-string so the timeout value is interpolated rather than
    # printing the literal "{HTTP_PROXY_TIMEOUT}" placeholder.
    raise TimeoutError(
        f"HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s.")
return Client(controller, controller_name, detached=detached)
+8 -5
View File
@@ -15,10 +15,13 @@ from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
from ray.serve.exceptions import RayServeException
from ray.util import metrics
from ray.serve.config import BackendConfig
from ray.serve.long_poll import LongPollerAsyncClient
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.router import Query
from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS,
BACKEND_RECONFIGURE_METHOD)
from ray.serve.constants import (
BACKEND_RECONFIGURE_METHOD,
DEFAULT_LATENCY_BUCKET_MS,
LongPollKey,
)
from ray.exceptions import RayTaskError
logger = _get_logger()
@@ -168,8 +171,8 @@ class RayServeReplica:
tag_keys=("backend", ))
self.request_counter.set_default_tags({"backend": self.backend_tag})
self.long_poll_client = LongPollerAsyncClient(controller_handle, {
"backend_configs": self._update_backend_configs,
self.long_poll_client = LongPollAsyncClient(controller_handle, {
LongPollKey.BACKEND_CONFIGS: self._update_backend_configs,
})
self.error_counter = metrics.Count(
+12
View File
@@ -1,3 +1,5 @@
from enum import auto, Enum
#: Actor name used to register controller
SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR"
@@ -37,3 +39,13 @@ DEFAULT_LATENCY_BUCKET_MS = [
#: Name of backend reconfiguration method implemented by user.
BACKEND_RECONFIGURE_METHOD = "reconfigure"
class LongPollKey(Enum):
    """Keys naming the pieces of state exchanged through long polling."""

    REPLICA_HANDLES = auto()
    TRAFFIC_POLICIES = auto()
    BACKEND_CONFIGS = auto()
    ROUTE_TABLE = auto()

    def __repr__(self):
        # Render as "LongPollKey.<NAME>" instead of Enum's default
        # "<LongPollKey.NAME: value>" form.
        return "{}.{}".format(type(self).__name__, self.name)
+70 -83
View File
@@ -13,14 +13,15 @@ import ray
import ray.cloudpickle as pickle
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
from ray.serve.backend_worker import create_backend_replica
from ray.serve.constants import ASYNC_CONCURRENCY, SERVE_PROXY_NAME
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
LongPollKey)
from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes, get_all_node_ids)
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.long_poll import LongPollerHost
from ray.serve.long_poll import LongPollHost
from ray.actor import ActorHandle
import numpy as np
@@ -145,7 +146,7 @@ class ActorStateReconciler:
controller_name: str = field(init=True)
detached: bool = field(init=True)
routers_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
default_factory=lambda: defaultdict(dict))
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
@@ -157,8 +158,8 @@ class ActorStateReconciler:
# TODO(edoakes): consider removing this and just using the names.
def router_handles(self) -> List[ActorHandle]:
return list(self.routers_cache.values())
def http_proxy_handles(self) -> List[ActorHandle]:
    """Return the cached HTTP proxy actor handles as a new list."""
    return [*self.http_proxy_cache.values()]
def get_replica_handles(self) -> List[ActorHandle]:
return list(
@@ -303,7 +304,7 @@ class ActorStateReconciler:
async def _stop_pending_backend_replicas(self) -> None:
"""Stops the pending backend replicas in self.backend_replicas_to_stop.
Removes backend_replicas from the router, kills them, and clears
Removes backend_replicas from the http_proxy, kills them, and clears
self.backend_replicas_to_stop.
"""
for backend_tag, replicas_list in self.backend_replicas_to_stop.items(
@@ -327,26 +328,26 @@ class ActorStateReconciler:
self.backend_replicas_to_stop.clear()
def _start_routers_if_needed(self, http_host: str, http_port: str,
http_middlewares: List[Any]) -> None:
"""Start a router on every node if it doesn't already exist."""
def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
http_middlewares: List[Any]) -> None:
"""Start an HTTP proxy on every node if it doesn't already exist."""
if http_host is None:
return
for node_id, node_resource in get_all_node_ids():
if node_id in self.routers_cache:
if node_id in self.http_proxy_cache:
continue
router_name = format_actor_name(SERVE_PROXY_NAME,
self.controller_name, node_id)
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
node_id)
try:
router = ray.get_actor(router_name)
proxy = ray.get_actor(name)
except ValueError:
logger.info("Starting router with name '{}' on node '{}' "
logger.info("Starting HTTP proxy with name '{}' on node '{}' "
"listening on '{}:{}'".format(
router_name, node_id, http_host, http_port))
router = HTTPProxyActor.options(
name=router_name,
name, node_id, http_host, http_port))
proxy = HTTPProxyActor.options(
name=name,
lifetime="detached" if self.detached else None,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
@@ -360,10 +361,10 @@ class ActorStateReconciler:
controller_name=self.controller_name,
http_middlewares=http_middlewares)
self.routers_cache[node_id] = router
self.http_proxy_cache[node_id] = proxy
def _stop_routers_if_needed(self) -> bool:
"""Removes router actors from any nodes that no longer exist.
def _stop_http_proxies_if_needed(self) -> bool:
"""Removes HTTP proxy actors from any nodes that no longer exist.
Returns whether or not any actors were removed (a checkpoint should
be taken).
@@ -371,25 +372,25 @@ class ActorStateReconciler:
actor_stopped = False
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
to_stop = []
for node_id in self.routers_cache:
for node_id in self.http_proxy_cache:
if node_id not in all_node_ids:
logger.info(
"Removing router on removed node '{}'.".format(node_id))
logger.info("Removing HTTP proxy on removed node '{}'.".format(
node_id))
to_stop.append(node_id)
for node_id in to_stop:
router_handle = self.routers_cache.pop(node_id)
ray.kill(router_handle, no_restart=True)
proxy = self.http_proxy_cache.pop(node_id)
ray.kill(proxy, no_restart=True)
actor_stopped = True
return actor_stopped
def _recover_actor_handles(self) -> None:
# Refresh the HTTP proxy cache.
for node_id in self.routers_cache.keys():
router_name = format_actor_name(SERVE_PROXY_NAME,
self.controller_name, node_id)
self.routers_cache[node_id] = ray.get_actor(router_name)
for node_id in self.http_proxy_cache.keys():
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
node_id)
self.http_proxy_cache[node_id] = ray.get_actor(name)
# Fetch actor handles for all of the backend replicas in the system.
# All of these backend_replicas are guaranteed to already exist because
@@ -482,7 +483,7 @@ class ServeController:
# backend -> AutoscalingPolicy
self.autoscaling_policies = dict()
# Dictionary of backend_tag -> router_name -> most recent queue length.
# Dictionary of backend_tag -> proxy_name -> most recent queue length.
self.backend_stats = defaultdict(lambda: defaultdict(dict))
# Used to ensure that only a single state-changing operation happens
@@ -495,7 +496,7 @@ class ServeController:
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self.actor_reconciler._start_routers_if_needed(
self.actor_reconciler._start_http_proxies_if_needed(
self.http_host, self.http_port, self.http_middlewares)
# Map of awaiting results
@@ -528,10 +529,11 @@ class ServeController:
# can be problem at scale, e.g. updating a single backend config
# will send over the entire configs. In the future, we should
# optimize the logic to support subscription by key.
self.long_poll_host = LongPollerHost()
self.long_poll_host = LongPollHost()
self.notify_backend_configs_changed()
self.notify_replica_handles_changed()
self.notify_traffic_policies_changed()
self.notify_route_table_changed()
asyncio.get_event_loop().create_task(self.run_control_loop())
@@ -565,19 +567,26 @@ class ServeController:
def notify_replica_handles_changed(self):
self.long_poll_host.notify_changed(
"worker_handles", {
LongPollKey.REPLICA_HANDLES, {
backend_tag: list(replica_dict.values())
for backend_tag, replica_dict in
self.actor_reconciler.backend_replicas.items()
})
def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed("traffic_policies",
self.current_state.traffic_policies)
self.long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES,
self.current_state.traffic_policies,
)
def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed(
"backend_configs", self.current_state.get_backend_configs())
LongPollKey.BACKEND_CONFIGS,
self.current_state.get_backend_configs())
def notify_route_table_changed(self):
    """Notify the long poll host that the route table has changed.

    Publishes ``self.current_state.routes`` under
    ``LongPollKey.ROUTE_TABLE``.
    """
    host = self.long_poll_host
    host.notify_changed(LongPollKey.ROUTE_TABLE, self.current_state.routes)
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy long poll client's listen request.
@@ -590,13 +599,9 @@ class ServeController:
return await (
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
def get_routers(self) -> Dict[str, ActorHandle]:
"""Returns a dictionary of node ID to router actor handles."""
return self.actor_reconciler.routers_cache
def get_router_config(self) -> Dict[str, Tuple[str, List[str]]]:
"""Called by the router on startup to fetch required state."""
return self.current_state.routes
def get_http_proxies(self) -> Dict[str, ActorHandle]:
    """Returns the mapping of node ID to HTTP proxy actor handle."""
    reconciler = self.actor_reconciler
    return reconciler.http_proxy_cache
def _checkpoint(self) -> None:
"""Checkpoint internal state and write it to the KV store."""
@@ -622,7 +627,7 @@ class ServeController:
Performs the following operations:
1) Deserializes the internal state from the checkpoint.
2) Pushes the latest configuration to the routers
2) Pushes the latest configuration to the HTTP proxies
in case we crashed before updating them.
3) Starts/stops any replicas that are pending creation or
deletion.
@@ -670,40 +675,26 @@ class ServeController:
while True:
await self.do_autoscale()
async with self.write_lock:
self.actor_reconciler._start_routers_if_needed(
self.actor_reconciler._start_http_proxies_if_needed(
self.http_host, self.http_port, self.http_middlewares)
checkpoint_required = self.actor_reconciler.\
_stop_routers_if_needed()
_stop_http_proxies_if_needed()
if checkpoint_required:
self._checkpoint()
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
def get_backend_configs(self) -> Dict[str, BackendConfig]:
"""Fetched by the router on startup."""
return self.current_state.get_backend_configs()
def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
"""Fetched by the router on startup."""
return self.current_state.traffic_policies
def _list_replicas(self, backend_tag: BackendTag) -> List[ReplicaTag]:
"""Used only for testing."""
return list(self.actor_reconciler.backend_replicas[backend_tag].keys())
def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
"""Fetched by serve handles."""
return self.current_state.traffic_policies[endpoint]
def get_all_replica_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
"""Fetched by the router on startup."""
def _all_replica_handles(
        self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
    """Return the reconciler's backend replica mapping (used for testing)."""
    reconciler = self.actor_reconciler
    return reconciler.backend_replicas
def get_all_backends(self) -> Dict[str, BackendConfig]:
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
    """Return the mapping of backend tag to backend config."""
    state = self.current_state
    return state.get_backend_configs()
def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
    """Returns a dictionary of endpoint tag to endpoint info.

    The value for each endpoint is whatever
    ``self.current_state.get_endpoints()`` reports for it.
    """
    return self.current_state.get_endpoints()
async def _set_traffic(self, endpoint_name: str,
@@ -731,7 +722,6 @@ class ServeController:
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
return return_uuid
@@ -814,10 +804,7 @@ class ServeController:
# NOTE(edoakes): checkpoint is written in self._set_traffic.
return_uuid = await self._set_traffic(endpoint, traffic_dict)
await asyncio.gather(*[
router.set_route_table.remote(self.current_state.routes)
for router in self.actor_reconciler.router_handles()
])
self.notify_route_table_changed()
return return_uuid
async def delete_endpoint(self, endpoint: str) -> UUID:
@@ -852,14 +839,10 @@ class ServeController:
endpoint: None
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# updates to the routers to avoid inconsistent state if we crash
# updates to the proxies to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
await asyncio.gather(*[
router.set_route_table.remote(self.current_state.routes)
for router in self.actor_reconciler.router_handles()
])
self.notify_route_table_changed()
return return_uuid
async def create_backend(self, backend_tag: BackendTag,
@@ -910,7 +893,7 @@ class ServeController:
self.notify_replica_handles_changed()
# Set the backend config inside the router
# Set the backend config inside routers
# (particularly for max_concurrent_queries).
self.notify_backend_configs_changed()
return return_uuid
@@ -943,12 +926,12 @@ class ServeController:
if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag]
# Add the intention to remove the backend from the router.
# Add the intention to remove the backend from the routers.
self.actor_reconciler.backends_to_remove.append(backend_tag)
return_uuid = self._create_event_with_result({backend_tag: None})
# NOTE(edoakes): we must write a checkpoint before removing the
# backend from the router to avoid inconsistent state if we crash
# backend from the routers to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
await self.actor_reconciler._stop_pending_backend_replicas()
@@ -986,7 +969,7 @@ class ServeController:
# update.
self._checkpoint()
# Inform the router about change in configuration
# Inform the routers about change in configuration
# (particularly for setting max_batch_size).
await self.actor_reconciler._start_pending_backend_replicas(
@@ -1003,11 +986,15 @@ class ServeController:
), "Backend {} is not registered.".format(backend_tag)
return self.current_state.get_backend(backend_tag).backend_config
def get_http_config(self):
    """Return the HTTP proxy configuration as a ``(host, port)`` tuple."""
    host, port = self.http_host, self.http_port
    return host, port
async def shutdown(self) -> None:
"""Shuts down the serve instance completely."""
async with self.write_lock:
for router in self.actor_reconciler.router_handles():
ray.kill(router, no_restart=True)
for http_proxy in self.actor_reconciler.http_proxy_handles():
ray.kill(http_proxy, no_restart=True)
for replica in self.actor_reconciler.get_replica_handles():
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)
+2 -1
View File
@@ -89,5 +89,6 @@ class RandomEndpointPolicy(EndpointPolicy):
query.metadata.shard_key.encode("utf-8"))
chosen_backend, shadow_backends = self._select_backends(value)
logger.debug(f"Chosen backend {chosen_backend} for query {query}")
logger.debug(f"Assigning query {query.metadata.request_id} "
f"to backend {chosen_backend}.")
return [chosen_backend] + shadow_backends
+19 -21
View File
@@ -6,43 +6,40 @@ import uvicorn
import ray
from ray.exceptions import RayTaskError
from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.util import metrics
from ray.serve.utils import _get_logger, get_random_letters
from ray.serve.http_util import Response
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.router import Router, RequestMetadata
# The maximum number of times to retry a request due to actor failure.
# TODO(edoakes): this should probably be configurable.
MAX_ACTOR_DEAD_RETRIES = 10
logger = _get_logger()
class HTTPProxy:
"""
This class should be instantiated and ran by ASGI server.
"""This class is meant to be instantiated and run by an ASGI HTTP server.
>>> import uvicorn
>>> uvicorn.run(HTTPProxy(controller_name))
# blocks forever
"""
async def fetch_config_from_controller(self, controller_name):
assert ray.is_initialized()
def __init__(self, controller_name):
controller = ray.get_actor(controller_name)
self.route_table = await controller.get_router_config.remote()
self.router = Router(controller)
self.long_poll_client = LongPollAsyncClient(controller, {
LongPollKey.ROUTE_TABLE: self._update_route_table,
})
self.request_counter = metrics.Count(
"num_http_requests",
description="The number of HTTP requests processed",
tag_keys=("route", ))
self.router = Router(controller)
async def setup(self):
await self.router.setup_in_async_loop()
def set_route_table(self, route_table):
async def _update_route_table(self, route_table):
    """Long poll callback: remember the most recent route table."""
    # Declared ``async`` because the long poll client awaits its callbacks.
    self.route_table = route_table
async def receive_http_body(self, scope, receive, send):
@@ -74,8 +71,11 @@ class HTTPProxy:
status_code=404).send(scope, receive, send)
async def __call__(self, scope, receive, send):
# NOTE: This implements ASGI protocol specified in
# https://asgi.readthedocs.io/en/latest/specs/index.html
"""Implements the ASGI protocol.
See details at:
https://asgi.readthedocs.io/en/latest/specs/index.html.
"""
error_sender = self._make_error_sender(scope, receive, send)
@@ -137,12 +137,13 @@ class HTTPProxyActor:
host,
port,
controller_name,
http_middlewares: List["starlette.middleware.Middleware"] = []):
http_middlewares: List[
"starlette.middleware.Middleware"] = []): # noqa: F821
self.host = host
self.port = port
self.app = HTTPProxy()
await self.app.fetch_config_from_controller(controller_name)
self.app = HTTPProxy(controller_name)
await self.app.setup()
self.wrapped_app = self.app
for middleware in http_middlewares:
@@ -180,6 +181,3 @@ class HTTPProxyActor:
# the main thread and uvicorn doesn't expose a way to configure it.
server.install_signal_handlers = lambda: None
await server.serve(sockets=[sock])
async def set_route_table(self, route_table):
self.app.set_route_table(route_table)
+20 -19
View File
@@ -1,4 +1,5 @@
import asyncio
from inspect import iscoroutinefunction
import random
from collections import defaultdict
from dataclasses import dataclass
@@ -22,7 +23,7 @@ class UpdatedObject:
UpdateStateAsyncCallable = Callable[[Any], Awaitable[None]]
class LongPollerAsyncClient:
class LongPollAsyncClient:
"""The asynchronous long polling client.
Internally, it runs `await object_ref` in a `while True` loop. When a
@@ -31,7 +32,7 @@ class LongPollerAsyncClient:
the next poll.
Args:
host_actor(ray.ActorHandle): handle to actor embedding LongPollerHost.
host_actor(ray.ActorHandle): handle to actor embedding LongPollHost.
key_listeners(Dict[str, AsyncCallable]): a dictionary mapping keys to
callbacks to be called on state update for the corresponding keys.
"""
@@ -40,6 +41,10 @@ class LongPollerAsyncClient:
key_listeners: Dict[str, UpdateStateAsyncCallable]) -> None:
self.host_actor = host_actor
self.key_listeners = key_listeners
for callback in key_listeners.values():
if not iscoroutinefunction(callback):
raise ValueError(
"Callbacks to async long poller must be 'async def'.")
self.snapshot_ids: Dict[str, int] = {
key: -1
@@ -56,34 +61,31 @@ class LongPollerAsyncClient:
self.snapshot_ids)
return object_ref
def _update(self, updates: Dict[str, UpdatedObject]):
for key, update in updates.items():
self.object_snapshots[key] = update.object_snapshot
self.snapshot_ids[key] = update.snapshot_id
async def _do_long_poll(self):
while True:
try:
updates: Dict[str, UpdatedObject] = await self._poll_once()
self._update(updates)
logger.debug(f"LongPollerClient received udpates: {updates}")
for key, updated_object in updates.items():
logger.debug("LongPollClient received updates for keys: "
f"{list(updates.keys())}.")
for key, update in updates.items():
self.object_snapshots[key] = update.object_snapshot
self.snapshot_ids[key] = update.snapshot_id
# NOTE(simon):
# This blocks the loop from doing another poll. Consider
# using loop.create_task here, or polling first and then calling the
# callbacks.
callback = self.key_listeners[key]
await callback(updated_object.object_snapshot)
await callback(update.object_snapshot)
except ray.exceptions.RayActorError:
# This can happen during shutdown where the controller is
# intentionally killed, the client should just gracefully
# exit.
logger.debug("LongPollerClient failed to connect to host. "
logger.debug("LongPollClient failed to connect to host. "
"Shutting down.")
break
class LongPollerHost:
class LongPollHost:
"""The server side object that manages long polling requests.
The desired use case is to embed this in an Ray actor. Client will be
@@ -115,11 +117,10 @@ class LongPollerHost:
immediately if the snapshot_ids are outdated, otherwise it will block
until there is at least one update.
"""
# 1. Figure out which keys do we care about
watched_keys = set(self.snapshot_ids.keys()).intersection(
keys_to_snapshot_ids.keys())
if len(watched_keys) == 0:
raise ValueError("Keys not found.")
watched_keys = keys_to_snapshot_ids.keys()
nonexistent_keys = set(watched_keys) - set(self.snapshot_ids.keys())
if len(nonexistent_keys) > 0:
raise ValueError(f"Keys not found: {nonexistent_keys}.")
# 2. If there are any outdated keys (by comparing snapshot ids)
# return immediately.
@@ -159,7 +160,7 @@ class LongPollerHost:
def notify_changed(self, object_key: str, updated_object: Any):
self.snapshot_ids[object_key] += 1
self.object_snapshots[object_key] = updated_object
logger.debug(f"LongPollerHost: {object_key} = {updated_object}")
logger.debug(f"LongPollHost: Notify change for key {object_key}.")
if object_key in self.notifier_events:
for event in self.notifier_events.pop(object_key):
+15 -12
View File
@@ -6,9 +6,10 @@ from typing import Any, DefaultDict, Dict, Iterable, List, Optional
import ray
from ray.actor import ActorHandle
from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
from ray.serve.long_poll import LongPollerAsyncClient
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.utils import logger
from ray.util import metrics
@@ -106,7 +107,8 @@ class ReplicaSet:
) >= self.max_concurrent_queries:
# This replica is overloaded, try next one
continue
logger.debug(f"Replica set assigned {query} to {replica}")
logger.debug(f"Assigned query {query.metadata.request_id} "
f"to replica {replica}.")
ref = replica.handle_request.remote(query)
self.in_flight_queries[replica].add(ref)
return ref
@@ -133,7 +135,8 @@ class ReplicaSet:
"""
assigned_ref = self._try_assign_replica(query)
while assigned_ref is None: # Can't assign a replica right now.
logger.debug(f"Failed to assign a replica for query {query}")
logger.debug("Failed to assign a replica for "
f"query {query.metadata.request_id}")
# Maybe there exists a free replica, we just need to refresh our
# query tracker.
num_finished = self._drain_completed_object_refs()
@@ -141,7 +144,7 @@ class ReplicaSet:
# config to be updated.
if num_finished == 0:
logger.debug(
f"All replicas are busy, waiting for a free replica.")
"All replicas are busy, waiting for a free replica.")
await asyncio.wait(
self._all_query_refs + [self.config_updated_event.wait()],
return_when=asyncio.FIRST_COMPLETED)
@@ -176,14 +179,14 @@ class Router:
async def setup_in_async_loop(self):
# NOTE(simon): Instead of performing initialization in __init__,
# We separated the init of LongPollerAsyncClient to this method because
# __init__ might be called in sync context. LongPollerAsyncClient
# We separated the init of LongPollAsyncClient to this method because
# __init__ might be called in sync context. LongPollAsyncClient
# requires async context.
self.long_pull_client = LongPollerAsyncClient(
self.long_poll_client = LongPollAsyncClient(
self.controller, {
"traffic_policies": self._update_traffic_policies,
"worker_handles": self._update_worker_handles,
"backend_configs": self._update_backend_configs,
LongPollKey.TRAFFIC_POLICIES: self._update_traffic_policies,
LongPollKey.REPLICA_HANDLES: self._update_replica_handles,
LongPollKey.BACKEND_CONFIGS: self._update_backend_configs,
})
async def _update_traffic_policies(self, traffic_policies):
@@ -194,8 +197,8 @@ class Router:
event = self._pending_endpoints.pop(endpoint)
event.set()
async def _update_worker_handles(self, worker_handles):
for backend_tag, replica_handles in worker_handles.items():
async def _update_replica_handles(self, replica_handles):
    """Long poll callback: push new replica handles to each backend.

    Args:
        replica_handles: mapping of backend tag to that backend's replica
            handles.
    """
    # Use a distinct loop variable so the ``replica_handles`` argument is
    # not rebound while its items are being iterated.
    for backend_tag, handles in replica_handles.items():
        self.backend_replicas[backend_tag].update_worker_replicas(handles)
+12 -9
View File
@@ -7,6 +7,7 @@ import pytest
import ray
from ray import serve
from ray.serve.config import BackendConfig
from ray.serve.constants import LongPollKey
# Environment variable values are always strings, so compare against "1";
# the previous comparison with the integer 1 could never be true.
if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH") == "1":
    serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
@@ -42,22 +43,22 @@ def mock_controller_with_name():
@ray.remote(num_cpus=0)
class MockControllerActor:
def __init__(self):
from ray.serve.long_poll import LongPollerHost
self.host = LongPollerHost()
from ray.serve.long_poll import LongPollHost
self.host = LongPollHost()
self.backend_replicas = defaultdict(list)
self.backend_configs = dict()
self.clear()
def clear(self):
self.host.notify_changed("worker_handles", {})
self.host.notify_changed("traffic_policies", {})
self.host.notify_changed("backend_configs", {})
self.host.notify_changed(LongPollKey.REPLICA_HANDLES, {})
self.host.notify_changed(LongPollKey.TRAFFIC_POLICIES, {})
self.host.notify_changed(LongPollKey.BACKEND_CONFIGS, {})
async def listen_for_change(self, snapshot_ids):
    """Forward the listen request to the embedded long poll host."""
    updates = await self.host.listen_for_change(snapshot_ids)
    return updates
def set_traffic(self, endpoint, traffic_policy):
self.host.notify_changed("traffic_policies",
self.host.notify_changed(LongPollKey.TRAFFIC_POLICIES,
{endpoint: traffic_policy})
def add_new_replica(self,
@@ -68,15 +69,17 @@ def mock_controller_with_name():
self.backend_configs[backend_tag] = backend_config
self.host.notify_changed(
"worker_handles",
LongPollKey.REPLICA_HANDLES,
self.backend_replicas,
)
self.host.notify_changed("backend_configs", self.backend_configs)
self.host.notify_changed(LongPollKey.BACKEND_CONFIGS,
self.backend_configs)
def update_backend(self, backend_tag: str,
backend_config: BackendConfig):
self.backend_configs[backend_tag] = backend_config
self.host.notify_changed("backend_configs", self.backend_configs)
self.host.notify_changed(LongPollKey.BACKEND_CONFIGS,
self.backend_configs)
# randint requires integer bounds; the float literal 10e4 is rejected by
# newer Python versions, so spell the same upper bound as an int.
name = f"MockController{random.randint(0, 100_000)}"
yield name, MockControllerActor.options(name=name).remote()
+10 -40
View File
@@ -25,22 +25,6 @@ def test_e2e(serve_instance):
client.create_endpoint(
"endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])
retry_count = 5
timeout_sleep = 0.5
while True:
try:
resp = requests.get(
"http://127.0.0.1:8000/-/routes", timeout=0.5).json()
assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
break
except Exception as e:
time.sleep(timeout_sleep)
timeout_sleep *= 2
retry_count -= 1
if retry_count == 0:
assert False, ("Route table hasn't been updated after 3 tries."
"The latest error was {}").format(e)
resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
assert resp == "GET"
@@ -63,7 +47,7 @@ def test_backend_user_config(serve_instance):
config = BackendConfig(num_replicas=2, user_config={"count": 123, "b": 2})
client.create_backend("counter", Counter, config=config)
client.create_endpoint("counter", backend="counter", route="/counter")
client.create_endpoint("counter", backend="counter")
handle = client.get_handle("counter")
def check(val, num_replicas):
@@ -183,7 +167,7 @@ def test_reject_duplicate_endpoint_and_route(serve_instance):
def test_no_http(serve_instance):
client = serve.start(http_host=None)
assert len(ray.get(client._controller.get_routers.remote())) == 0
assert len(ray.get(client._controller.get_http_proxies.remote())) == 0
def hello(*args):
return "hello"
@@ -223,11 +207,6 @@ def test_scaling_replicas(serve_instance):
client.create_endpoint("counter", backend="counter:v1", route="/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()
@@ -267,11 +246,6 @@ def test_batching(serve_instance):
client.create_endpoint(
"counter1", backend="counter:v11", route="/increment2")
# Keep checking the routing table until /increment is populated
while "/increment2" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)
future_list = []
handle = client.get_handle("counter1")
for _ in range(20):
@@ -299,8 +273,7 @@ def test_batching_exception(serve_instance):
# Set the max batch size.
config = BackendConfig(max_batch_size=5)
client.create_backend("exception:v1", NoListReturned, config=config)
client.create_endpoint(
"exception-test", backend="exception:v1", route="/noListReturned")
client.create_endpoint("exception-test", backend="exception:v1")
handle = client.get_handle("exception-test")
with pytest.raises(ray.exceptions.RayTaskError):
@@ -323,16 +296,16 @@ def test_updating_config(serve_instance):
client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
controller = client._controller
old_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
old_replica_tag_list = list(
ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())
update_config = BackendConfig(max_batch_size=5)
client.update_backend_config("bsimple:v1", update_config)
new_replica_tag_list = ray.get(
controller._list_replicas.remote("bsimple:v1"))
new_replica_tag_list = list(
ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())
new_all_tag_list = []
for worker_dict in ray.get(
controller.get_all_replica_handles.remote()).values():
controller._all_replica_handles.remote()).values():
new_all_tag_list.extend(list(worker_dict.keys()))
# the old and new replica tag list should be identical
@@ -648,7 +621,7 @@ def test_create_infeasible_error(serve_instance):
"MagicMLResource": 100
}})
# Even each replica might be feasible, the total might not be.
# Even though each replica might be feasible, the total might not be.
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
num_replicas = current_cpus + 20
config = BackendConfig(num_replicas=num_replicas)
@@ -661,10 +634,6 @@ def test_create_infeasible_error(serve_instance):
}},
config=config)
# No replica should be created!
replicas = ray.get(client._controller._list_replicas.remote("f1"))
assert len(replicas) == 0
def test_shutdown():
def f():
@@ -797,6 +766,7 @@ def test_serve_metrics(serve_instance):
client.create_backend("metrics", batcher)
client.create_endpoint("metrics", backend="metrics", route="/metrics")
# send 10 concurrent requests
url = "http://127.0.0.1:8000/metrics"
ray.get([block_until_http_ready.remote(url) for _ in range(10)])
@@ -48,7 +48,7 @@ def setup_worker(name,
async def add_servable_to_router(servable, router, controller_name, **kwargs):
worker = setup_worker(
"backend", servable, controller_name=controller_name, **kwargs)
await router._update_worker_handles.remote({"backend": [worker]})
await router._update_replica_handles.remote({"backend": [worker]})
await router._update_traffic_policies.remote({
"endpoint": TrafficPolicy({
"backend": 1.0
+6 -6
View File
@@ -76,10 +76,10 @@ def test_controller_failure(serve_instance):
assert response.text == "hello3"
def _kill_routers(client):
    """Kill every router returned by the Serve controller.

    Asks ``client._controller`` for its routers via ``get_routers`` and
    calls ``ray.kill`` on each value of the returned mapping.

    Args:
        client: Serve client exposing the controller handle as
            ``_controller``.
    """
    router_handles = ray.get(client._controller.get_routers.remote()).values()
    for handle in router_handles:
        # no_restart=False: per the Ray API, a restartable actor may be
        # brought back up after being killed.
        ray.kill(handle, no_restart=False)
def _kill_http_proxies(client):
    """Kill every HTTP proxy returned by the Serve controller.

    Looks up the proxies through ``client._controller.get_http_proxies``
    and calls ``ray.kill`` on each value of the returned mapping.

    Args:
        client: Serve client exposing the controller handle as
            ``_controller``.
    """
    controller = client._controller
    proxies_by_key = ray.get(controller.get_http_proxies.remote())
    for proxy_handle in proxies_by_key.values():
        # no_restart=False: per the Ray API, a restartable actor may be
        # brought back up after being killed.
        ray.kill(proxy_handle, no_restart=False)
def test_http_proxy_failure(serve_instance):
@@ -98,7 +98,7 @@ def test_http_proxy_failure(serve_instance):
response = request_with_retries("/proxy_failure", timeout=30)
assert response.text == "hello1"
_kill_routers(client)
_kill_http_proxies(client)
def function(_):
return "hello2"
@@ -113,7 +113,7 @@ def test_http_proxy_failure(serve_instance):
def _get_worker_handles(client, backend):
    """Return the replica handles the controller holds for ``backend``.

    Fetches the controller's full replica-handle table with a single
    ``_all_replica_handles`` call and returns the entries for ``backend``.
    An earlier, redundant lookup through ``get_all_replica_handles`` was
    removed: its result was overwritten before ever being read.

    Args:
        client: Serve client exposing the controller handle as
            ``_controller``.
        backend: Key of the backend whose replica handles are wanted.

    Returns:
        A list of the values stored under ``backend`` in the controller's
        replica-handle table.

    Raises:
        KeyError: If ``backend`` has no entry in the table.
    """
    replicas_by_backend = ray.get(
        client._controller._all_replica_handles.remote())
    return list(replicas_by_backend[backend].values())
+22 -14
View File
@@ -1,5 +1,4 @@
import sys
import functools
import time
import asyncio
import os
@@ -8,12 +7,12 @@ from typing import Dict
import pytest
import ray
from ray.serve.long_poll import (LongPollerAsyncClient, LongPollerHost,
from ray.serve.long_poll import (LongPollAsyncClient, LongPollHost,
UpdatedObject)
def test_host_standalone(serve_instance):
host = ray.remote(LongPollerHost).remote()
host = ray.remote(LongPollHost).remote()
# Write two values
ray.get(host.notify_changed.remote("key_1", 999))
@@ -44,10 +43,10 @@ def test_long_poll_restarts(serve_instance):
max_restarts=-1,
max_task_retries=-1,
)
class RestartableLongPollerHost:
class RestartableLongPollHost:
def __init__(self) -> None:
print("actor started")
self.host = LongPollerHost()
self.host = LongPollHost()
self.host.notify_changed("timer", time.time())
self.should_exit = False
@@ -63,7 +62,7 @@ def test_long_poll_restarts(serve_instance):
print("actor exit")
os._exit(1)
host = RestartableLongPollerHost.remote()
host = RestartableLongPollHost.remote()
updated_values = ray.get(host.listen_for_change.remote({"timer": -1}))
timer: UpdatedObject = updated_values["timer"]
@@ -81,22 +80,31 @@ def test_long_poll_restarts(serve_instance):
@pytest.mark.asyncio
async def test_async_client(serve_instance):
host = ray.remote(LongPollerHost).remote()
host = ray.remote(LongPollHost).remote()
# Write two values
ray.get(host.notify_changed.remote("key_1", 100))
ray.get(host.notify_changed.remote("key_2", 999))
# Check that construction fails with a sync callback.
def callback(result, key):
pass
with pytest.raises(ValueError):
client = LongPollAsyncClient(host, {"key": callback})
callback_results = dict()
async def callback(result, key):
callback_results[key] = result
async def key_1_callback(result):
callback_results["key_1"] = result
client = LongPollerAsyncClient(
host, {
"key_1": functools.partial(callback, key="key_1"),
"key_2": functools.partial(callback, key="key_2")
})
async def key_2_callback(result):
callback_results["key_2"] = result
client = LongPollAsyncClient(host, {
"key_1": key_1_callback,
"key_2": key_2_callback,
})
while len(client.object_snapshots) == 0:
# Yield the loop for client to get the result
+5 -1
View File
@@ -144,6 +144,7 @@ class ServeEncoder(json.JSONEncoder):
@ray.remote(num_cpus=0)
def block_until_http_ready(http_endpoint,
backoff_time_s=1,
check_ready=None,
timeout=HTTP_PROXY_TIMEOUT):
http_is_ready = False
start_time = time.time()
@@ -152,7 +153,10 @@ def block_until_http_ready(http_endpoint,
try:
resp = requests.get(http_endpoint)
assert resp.status_code == 200
http_is_ready = True
if check_ready is None:
http_is_ready = True
else:
http_is_ready = check_ready(resp)
except Exception:
pass