From 3c44c0d3e4b5b764ab993bca11a1c58f1e43ed2c Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 10 Dec 2020 18:02:02 -0600 Subject: [PATCH] [serve] Long polling for routes in http server (#12724) --- python/ray/serve/api.py | 31 +++- python/ray/serve/backend_worker.py | 13 +- python/ray/serve/constants.py | 12 ++ python/ray/serve/controller.py | 153 ++++++++---------- python/ray/serve/endpoint_policy.py | 3 +- python/ray/serve/http_proxy.py | 40 +++-- python/ray/serve/long_poll.py | 39 ++--- python/ray/serve/router.py | 27 ++-- python/ray/serve/tests/conftest.py | 21 +-- python/ray/serve/tests/test_api.py | 50 ++---- python/ray/serve/tests/test_backend_worker.py | 2 +- python/ray/serve/tests/test_failure.py | 12 +- python/ray/serve/tests/test_long_poll.py | 36 +++-- python/ray/serve/utils.py | 6 +- 14 files changed, 232 insertions(+), 213 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 940b811d8..d5a60f0fc 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -43,6 +43,8 @@ class Client: self._controller_name = controller_name self._detached = detached self._shutdown = False + self._http_host, self._http_port = ray.get( + controller.get_http_config.remote()) # NOTE(simon): Used to cache client.get_handle(endpoint) call. It will # mostly grow in size, it will only shrink when user calls the @@ -149,6 +151,29 @@ class Client: self._controller.create_endpoint.remote( endpoint_name, {backend: 1.0}, route, upper_methods)) + # Block until the route table has been propagated to all HTTP proxies. + if route is not None: + + def check_ready(http_response): + return route in http_response.json() + + futures = [] + for node_id in ray.state.node_ids(): + future = block_until_http_ready.options( + num_cpus=0, resources={ + node_id: 0.01 + }).remote( + "http://{}:{}/-/routes".format(self._http_host, + self._http_port), + check_ready=check_ready, + timeout=HTTP_PROXY_TIMEOUT) + futures.append(future) + try: + ray.get(futures) + except ray.exceptions.RayTaskError: + raise TimeoutError("Route not available at HTTP proxies " + "after {HTTP_PROXY_TIMEOUT}s.") + @_ensure_connected def delete_endpoint(self, endpoint: str) -> None: """Delete the given endpoint. @@ -453,7 +478,11 @@ def start(detached: bool = False, "http://{}:{}/-/routes".format(http_host, http_port), timeout=HTTP_PROXY_TIMEOUT) futures.append(future) - ray.get(futures) + try: + ray.get(futures) + except ray.exceptions.RayTaskError: + raise TimeoutError( + "HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s.") return Client(controller, controller_name, detached=detached) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 46ee66659..73088b558 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -15,10 +15,13 @@ from ray.serve.utils import (parse_request_item, _get_logger, chain_future, from ray.serve.exceptions import RayServeException from ray.util import metrics from ray.serve.config import BackendConfig -from ray.serve.long_poll import LongPollerAsyncClient +from ray.serve.long_poll import LongPollAsyncClient from ray.serve.router import Query -from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS, - BACKEND_RECONFIGURE_METHOD) +from ray.serve.constants import ( + BACKEND_RECONFIGURE_METHOD, + DEFAULT_LATENCY_BUCKET_MS, + LongPollKey, +) from ray.exceptions import RayTaskError logger = _get_logger() @@ -168,8 +171,8 @@ class RayServeReplica: tag_keys=("backend", )) self.request_counter.set_default_tags({"backend": self.backend_tag}) - self.long_poll_client = LongPollerAsyncClient(controller_handle, { - "backend_configs": self._update_backend_configs, + self.long_poll_client = LongPollAsyncClient(controller_handle, { + LongPollKey.BACKEND_CONFIGS: self._update_backend_configs, }) self.error_counter = metrics.Count( diff --git a/python/ray/serve/constants.py b/python/ray/serve/constants.py index 52c76972f..db2b2008a 100644 --- a/python/ray/serve/constants.py +++ b/python/ray/serve/constants.py @@ -1,3 +1,5 @@ +from enum import auto, Enum + #: Actor name used to register controller SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR" @@ -37,3 +39,13 @@ DEFAULT_LATENCY_BUCKET_MS = [ #: Name of backend reconfiguration method implemented by user. BACKEND_RECONFIGURE_METHOD = "reconfigure" + + +class LongPollKey(Enum): + def __repr__(self): + return f"{self.__class__.__name__}.{self.name}" + + REPLICA_HANDLES = auto() + TRAFFIC_POLICIES = auto() + BACKEND_CONFIGS = auto() + ROUTE_TABLE = auto() diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 742a5172a..7330a722d 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -13,14 +13,15 @@ import ray import ray.cloudpickle as pickle from ray.serve.autoscaling_policy import BasicAutoscalingPolicy from ray.serve.backend_worker import create_backend_replica -from ray.serve.constants import ASYNC_CONCURRENCY, SERVE_PROXY_NAME +from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME, + LongPollKey) from ray.serve.http_proxy import HTTPProxyActor from ray.serve.kv_store import RayInternalKVStore from ray.serve.exceptions import RayServeException from ray.serve.utils import (format_actor_name, get_random_letters, logger, try_schedule_resources_on_nodes, get_all_node_ids) from ray.serve.config import BackendConfig, ReplicaConfig -from ray.serve.long_poll import LongPollerHost +from ray.serve.long_poll import LongPollHost from ray.actor import ActorHandle import numpy as np @@ -145,7 +146,7 @@ class ActorStateReconciler: controller_name: str = field(init=True) detached: bool = field(init=True) - routers_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict) + http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict) backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field( default_factory=lambda: defaultdict(dict)) backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field( @@ -157,8 +158,8 @@ class ActorStateReconciler: # TODO(edoakes): consider removing this and just using the names. - def router_handles(self) -> List[ActorHandle]: - return list(self.routers_cache.values()) + def http_proxy_handles(self) -> List[ActorHandle]: + return list(self.http_proxy_cache.values()) def get_replica_handles(self) -> List[ActorHandle]: return list( @@ -303,7 +304,7 @@ class ActorStateReconciler: async def _stop_pending_backend_replicas(self) -> None: """Stops the pending backend replicas in self.backend_replicas_to_stop. - Removes backend_replicas from the router, kills them, and clears + Removes backend_replicas from the http_proxy, kills them, and clears self.backend_replicas_to_stop. """ for backend_tag, replicas_list in self.backend_replicas_to_stop.items( @@ -327,26 +328,26 @@ class ActorStateReconciler: self.backend_replicas_to_stop.clear() - def _start_routers_if_needed(self, http_host: str, http_port: str, - http_middlewares: List[Any]) -> None: - """Start a router on every node if it doesn't already exist.""" + def _start_http_proxies_if_needed(self, http_host: str, http_port: str, + http_middlewares: List[Any]) -> None: + """Start an HTTP proxy on every node if it doesn't already exist.""" if http_host is None: return for node_id, node_resource in get_all_node_ids(): - if node_id in self.routers_cache: + if node_id in self.http_proxy_cache: continue - router_name = format_actor_name(SERVE_PROXY_NAME, - self.controller_name, node_id) + name = format_actor_name(SERVE_PROXY_NAME, self.controller_name, + node_id) try: - router = ray.get_actor(router_name) + proxy = ray.get_actor(name) except ValueError: - logger.info("Starting router with name '{}' on node '{}' " + logger.info("Starting HTTP proxy with name '{}' on node '{}' " "listening on '{}:{}'".format( - router_name, node_id, http_host, http_port)) - router = HTTPProxyActor.options( - name=router_name, + name, node_id, http_host, http_port)) + proxy = HTTPProxyActor.options( + name=name, lifetime="detached" if self.detached else None, max_concurrency=ASYNC_CONCURRENCY, max_restarts=-1, @@ -360,10 +361,10 @@ class ActorStateReconciler: controller_name=self.controller_name, http_middlewares=http_middlewares) - self.routers_cache[node_id] = router + self.http_proxy_cache[node_id] = proxy - def _stop_routers_if_needed(self) -> bool: - """Removes router actors from any nodes that no longer exist. + def _stop_http_proxies_if_needed(self) -> bool: + """Removes HTTP proxy actors from any nodes that no longer exist. Returns whether or not any actors were removed (a checkpoint should be taken). @@ -371,25 +372,25 @@ class ActorStateReconciler: actor_stopped = False all_node_ids = {node_id for node_id, _ in get_all_node_ids()} to_stop = [] - for node_id in self.routers_cache: + for node_id in self.http_proxy_cache: if node_id not in all_node_ids: - logger.info( - "Removing router on removed node '{}'.".format(node_id)) + logger.info("Removing HTTP proxy on removed node '{}'.".format( + node_id)) to_stop.append(node_id) for node_id in to_stop: - router_handle = self.routers_cache.pop(node_id) - ray.kill(router_handle, no_restart=True) + proxy = self.http_proxy_cache.pop(node_id) + ray.kill(proxy, no_restart=True) actor_stopped = True return actor_stopped def _recover_actor_handles(self) -> None: # Refresh the RouterCache - for node_id in self.routers_cache.keys(): - router_name = format_actor_name(SERVE_PROXY_NAME, - self.controller_name, node_id) - self.routers_cache[node_id] = ray.get_actor(router_name) + for node_id in self.http_proxy_cache.keys(): + name = format_actor_name(SERVE_PROXY_NAME, self.controller_name, + node_id) + self.http_proxy_cache[node_id] = ray.get_actor(name) # Fetch actor handles for all of the backend replicas in the system. # All of these backend_replicas are guaranteed to already exist because @@ -482,7 +483,7 @@ class ServeController: # backend -> AutoscalingPolicy self.autoscaling_policies = dict() - # Dictionary of backend_tag -> router_name -> most recent queue length. + # Dictionary of backend_tag -> proxy_name -> most recent queue length. self.backend_stats = defaultdict(lambda: defaultdict(dict)) # Used to ensure that only a single state-changing operation happens @@ -495,7 +496,7 @@ class ServeController: # If starting the actor for the first time, starts up the other system # components. If recovering, fetches their actor handles. - self.actor_reconciler._start_routers_if_needed( + self.actor_reconciler._start_http_proxies_if_needed( self.http_host, self.http_port, self.http_middlewares) # Map of awaiting results @@ -528,10 +529,11 @@ class ServeController: # can be problem at scale, e.g. updating a single backend config # will send over the entire configs. In the future, we should # optimize the logic to support subscription by key. - self.long_poll_host = LongPollerHost() + self.long_poll_host = LongPollHost() self.notify_backend_configs_changed() self.notify_replica_handles_changed() self.notify_traffic_policies_changed() + self.notify_route_table_changed() asyncio.get_event_loop().create_task(self.run_control_loop()) @@ -565,19 +567,26 @@ class ServeController: def notify_replica_handles_changed(self): self.long_poll_host.notify_changed( - "worker_handles", { + LongPollKey.REPLICA_HANDLES, { backend_tag: list(replica_dict.values()) for backend_tag, replica_dict in self.actor_reconciler.backend_replicas.items() }) def notify_traffic_policies_changed(self): - self.long_poll_host.notify_changed("traffic_policies", - self.current_state.traffic_policies) + self.long_poll_host.notify_changed( + LongPollKey.TRAFFIC_POLICIES, + self.current_state.traffic_policies, + ) def notify_backend_configs_changed(self): self.long_poll_host.notify_changed( - "backend_configs", self.current_state.get_backend_configs()) + LongPollKey.BACKEND_CONFIGS, + self.current_state.get_backend_configs()) + + def notify_route_table_changed(self): + self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE, + self.current_state.routes) async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): """Proxy long pull client's listen request. @@ -590,13 +599,9 @@ class ServeController: return await ( self.long_poll_host.listen_for_change(keys_to_snapshot_ids)) - def get_routers(self) -> Dict[str, ActorHandle]: - """Returns a dictionary of node ID to router actor handles.""" - return self.actor_reconciler.routers_cache - - def get_router_config(self) -> Dict[str, Tuple[str, List[str]]]: - """Called by the router on startup to fetch required state.""" - return self.current_state.routes + def get_http_proxies(self) -> Dict[str, ActorHandle]: + """Returns a dictionary of node ID to http_proxy actor handles.""" + return self.actor_reconciler.http_proxy_cache def _checkpoint(self) -> None: """Checkpoint internal state and write it to the KV store.""" @@ -622,7 +627,7 @@ class ServeController: Performs the following operations: 1) Deserializes the internal state from the checkpoint. - 2) Pushes the latest configuration to the routers + 2) Pushes the latest configuration to the HTTP proxies in case we crashed before updating them. 3) Starts/stops any replicas that are pending creation or deletion. @@ -670,40 +675,26 @@ class ServeController: while True: await self.do_autoscale() async with self.write_lock: - self.actor_reconciler._start_routers_if_needed( + self.actor_reconciler._start_http_proxies_if_needed( self.http_host, self.http_port, self.http_middlewares) checkpoint_required = self.actor_reconciler.\ - _stop_routers_if_needed() + _stop_http_proxies_if_needed() if checkpoint_required: self._checkpoint() await asyncio.sleep(CONTROL_LOOP_PERIOD_S) - def get_backend_configs(self) -> Dict[str, BackendConfig]: - """Fetched by the router on startup.""" - return self.current_state.get_backend_configs() - - def get_traffic_policies(self) -> Dict[str, TrafficPolicy]: - """Fetched by the router on startup.""" - return self.current_state.traffic_policies - - def _list_replicas(self, backend_tag: BackendTag) -> List[ReplicaTag]: - """Used only for testing.""" - return list(self.actor_reconciler.backend_replicas[backend_tag].keys()) - - def get_traffic_policy(self, endpoint: str) -> TrafficPolicy: - """Fetched by serve handles.""" - return self.current_state.traffic_policies[endpoint] - - def get_all_replica_handles(self) -> Dict[str, Dict[str, ActorHandle]]: - """Fetched by the router on startup.""" + def _all_replica_handles( + self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: + """Used for testing.""" return self.actor_reconciler.backend_replicas - def get_all_backends(self) -> Dict[str, BackendConfig]: + def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: """Returns a dictionary of backend tag to backend config.""" return self.current_state.get_backend_configs() - def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]: + def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]: + """Returns a dictionary of backend tag to backend config.""" return self.current_state.get_endpoints() async def _set_traffic(self, endpoint_name: str, @@ -731,7 +722,6 @@ class ServeController: # update to avoid inconsistent state if we crash after pushing the # update. self._checkpoint() - self.notify_traffic_policies_changed() return return_uuid @@ -814,10 +804,7 @@ class ServeController: # NOTE(edoakes): checkpoint is written in self._set_traffic. return_uuid = await self._set_traffic(endpoint, traffic_dict) - await asyncio.gather(*[ - router.set_route_table.remote(self.current_state.routes) - for router in self.actor_reconciler.router_handles() - ]) + self.notify_route_table_changed() return return_uuid async def delete_endpoint(self, endpoint: str) -> UUID: @@ -852,14 +839,10 @@ class ServeController: endpoint: None }) # NOTE(edoakes): we must write a checkpoint before pushing the - # updates to the routers to avoid inconsistent state if we crash + # updates to the proxies to avoid inconsistent state if we crash # after pushing the update. self._checkpoint() - - await asyncio.gather(*[ - router.set_route_table.remote(self.current_state.routes) - for router in self.actor_reconciler.router_handles() - ]) + self.notify_route_table_changed() return return_uuid async def create_backend(self, backend_tag: BackendTag, @@ -910,7 +893,7 @@ class ServeController: self.notify_replica_handles_changed() - # Set the backend config inside the router + # Set the backend config inside routers # (particularly for max_concurrent_queries). self.notify_backend_configs_changed() return return_uuid @@ -943,12 +926,12 @@ class ServeController: if backend_tag in self.autoscaling_policies: del self.autoscaling_policies[backend_tag] - # Add the intention to remove the backend from the router. + # Add the intention to remove the backend from the routers. self.actor_reconciler.backends_to_remove.append(backend_tag) return_uuid = self._create_event_with_result({backend_tag: None}) # NOTE(edoakes): we must write a checkpoint before removing the - # backend from the router to avoid inconsistent state if we crash + # backend from the routers to avoid inconsistent state if we crash # after pushing the update. self._checkpoint() await self.actor_reconciler._stop_pending_backend_replicas() @@ -986,7 +969,7 @@ class ServeController: # update. self._checkpoint() - # Inform the router about change in configuration + # Inform the routers about change in configuration # (particularly for setting max_batch_size). await self.actor_reconciler._start_pending_backend_replicas( @@ -1003,11 +986,15 @@ class ServeController: ), "Backend {} is not registered.".format(backend_tag) return self.current_state.get_backend(backend_tag).backend_config + def get_http_config(self): + """Return the HTTP proxy configuration.""" + return self.http_host, self.http_port + async def shutdown(self) -> None: """Shuts down the serve instance completely.""" async with self.write_lock: - for router in self.actor_reconciler.router_handles(): - ray.kill(router, no_restart=True) + for http_proxy in self.actor_reconciler.http_proxy_handles(): + ray.kill(http_proxy, no_restart=True) for replica in self.actor_reconciler.get_replica_handles(): ray.kill(replica, no_restart=True) self.kv_store.delete(CHECKPOINT_KEY) diff --git a/python/ray/serve/endpoint_policy.py b/python/ray/serve/endpoint_policy.py index 555c4f58a..c757f20c3 100644 --- a/python/ray/serve/endpoint_policy.py +++ b/python/ray/serve/endpoint_policy.py @@ -89,5 +89,6 @@ class RandomEndpointPolicy(EndpointPolicy): query.metadata.shard_key.encode("utf-8")) chosen_backend, shadow_backends = self._select_backends(value) - logger.debug(f"Chosen backend {chosen_backend} for query {query}") + logger.debug(f"Assigning query {query.metadata.request_id} " + f"to backend {chosen_backend}.") return [chosen_backend] + shadow_backends diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 5dfb003c6..01c8d8e02 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -6,43 +6,40 @@ import uvicorn import ray from ray.exceptions import RayTaskError +from ray.serve.constants import LongPollKey from ray.serve.context import TaskContext from ray.util import metrics from ray.serve.utils import _get_logger, get_random_letters from ray.serve.http_util import Response +from ray.serve.long_poll import LongPollAsyncClient from ray.serve.router import Router, RequestMetadata -# The maximum number of times to retry a request due to actor failure. -# TODO(edoakes): this should probably be configurable. -MAX_ACTOR_DEAD_RETRIES = 10 - logger = _get_logger() class HTTPProxy: - """ - This class should be instantiated and ran by ASGI server. + """This class is meant to be instantiated and run by an ASGI HTTP server. >>> import uvicorn >>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle)) - # blocks forever """ - async def fetch_config_from_controller(self, controller_name): - assert ray.is_initialized() + def __init__(self, controller_name): controller = ray.get_actor(controller_name) - - self.route_table = await controller.get_router_config.remote() + self.router = Router(controller) + self.long_poll_client = LongPollAsyncClient(controller, { + LongPollKey.ROUTE_TABLE: self._update_route_table, + }) self.request_counter = metrics.Count( "num_http_requests", description="The number of HTTP requests processed", tag_keys=("route", )) - self.router = Router(controller) + async def setup(self): await self.router.setup_in_async_loop() - def set_route_table(self, route_table): + async def _update_route_table(self, route_table): self.route_table = route_table async def receive_http_body(self, scope, receive, send): @@ -74,8 +71,11 @@ class HTTPProxy: status_code=404).send(scope, receive, send) async def __call__(self, scope, receive, send): - # NOTE: This implements ASGI protocol specified in - # https://asgi.readthedocs.io/en/latest/specs/index.html + """Implements the ASGI protocol. + + See details at: + https://asgi.readthedocs.io/en/latest/specs/index.html. + """ error_sender = self._make_error_sender(scope, receive, send) @@ -137,12 +137,13 @@ class HTTPProxyActor: host, port, controller_name, - http_middlewares: List["starlette.middleware.Middleware"] = []): + http_middlewares: List[ + "starlette.middleware.Middleware"] = []): # noqa: F821 self.host = host self.port = port - self.app = HTTPProxy() - await self.app.fetch_config_from_controller(controller_name) + self.app = HTTPProxy(controller_name) + await self.app.setup() self.wrapped_app = self.app for middleware in http_middlewares: @@ -180,6 +181,3 @@ class HTTPProxyActor: # the main thread and uvicorn doesn't expose a way to configure it. server.install_signal_handlers = lambda: None await server.serve(sockets=[sock]) - - async def set_route_table(self, route_table): - self.app.set_route_table(route_table) diff --git a/python/ray/serve/long_poll.py b/python/ray/serve/long_poll.py index 2d747e94e..01df2cfe0 100644 --- a/python/ray/serve/long_poll.py +++ b/python/ray/serve/long_poll.py @@ -1,4 +1,5 @@ import asyncio +from inspect import iscoroutinefunction import random from collections import defaultdict from dataclasses import dataclass @@ -22,7 +23,7 @@ class UpdatedObject: UpdateStateAsyncCallable = Callable[[Any], Awaitable[None]] -class LongPollerAsyncClient: +class LongPollAsyncClient: """The asynchronous long polling client. Internally, it runs `await object_ref` in a `while True` loop. When a @@ -31,7 +32,7 @@ class LongPollerAsyncClient: the next poll. Args: - host_actor(ray.ActorHandle): handle to actor embedding LongPollerHost. + host_actor(ray.ActorHandle): handle to actor embedding LongPollHost. key_listeners(Dict[str, AsyncCallable]): a dictionary mapping keys to callbacks to be called on state update for the corresponding keys. """ @@ -40,6 +41,10 @@ class LongPollerAsyncClient: key_listeners: Dict[str, UpdateStateAsyncCallable]) -> None: self.host_actor = host_actor self.key_listeners = key_listeners + for callback in key_listeners.values(): + if not iscoroutinefunction(callback): + raise ValueError( + "Callbacks to async long poller must be 'async def'.") self.snapshot_ids: Dict[str, int] = { key: -1 @@ -56,34 +61,31 @@ class LongPollerAsyncClient: self.snapshot_ids) return object_ref - def _update(self, updates: Dict[str, UpdatedObject]): - for key, update in updates.items(): - self.object_snapshots[key] = update.object_snapshot - self.snapshot_ids[key] = update.snapshot_id - async def _do_long_poll(self): while True: try: updates: Dict[str, UpdatedObject] = await self._poll_once() - self._update(updates) - logger.debug(f"LongPollerClient received udpates: {updates}") - for key, updated_object in updates.items(): + logger.debug("LongPollClient received updates for keys: " + f"{list(updates.keys())}.") + for key, update in updates.items(): + self.object_snapshots[key] = update.object_snapshot + self.snapshot_ids[key] = update.snapshot_id # NOTE(simon): # This blocks the loop from doing another poll. Consider # use loop.create_task here or poll first then call the # callbacks. callback = self.key_listeners[key] - await callback(updated_object.object_snapshot) + await callback(update.object_snapshot) except ray.exceptions.RayActorError: # This can happen during shutdown where the controller is # intentionally killed, the client should just gracefully # exit. - logger.debug("LongPollerClient failed to connect to host. " + logger.debug("LongPollClient failed to connect to host. " "Shutting down.") break -class LongPollerHost: +class LongPollHost: """The server side object that manages long pulling requests. The desired use case is to embed this in an Ray actor. Client will be @@ -115,11 +117,10 @@ class LongPollerHost: immediately if the snapshot_ids are outdated, otherwise it will block until there's one updates. """ - # 1. Figure out which keys do we care about - watched_keys = set(self.snapshot_ids.keys()).intersection( - keys_to_snapshot_ids.keys()) - if len(watched_keys) == 0: - raise ValueError("Keys not found.") + watched_keys = keys_to_snapshot_ids.keys() + nonexistent_keys = set(watched_keys) - set(self.snapshot_ids.keys()) + if len(nonexistent_keys) > 0: + raise ValueError(f"Keys not found: {nonexistent_keys}.") # 2. If there are any outdated keys (by comparing snapshot ids) # return immediately. @@ -159,7 +160,7 @@ class LongPollerHost: def notify_changed(self, object_key: str, updated_object: Any): self.snapshot_ids[object_key] += 1 self.object_snapshots[object_key] = updated_object - logger.debug(f"LongPollerHost: {object_key} = {updated_object}") + logger.debug(f"LongPollHost: Notify change for key {object_key}.") if object_key in self.notifier_events: for event in self.notifier_events.pop(object_key): diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 9808ad9c9..8276e6cd2 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -6,9 +6,10 @@ from typing import Any, DefaultDict, Dict, Iterable, List, Optional import ray from ray.actor import ActorHandle +from ray.serve.constants import LongPollKey from ray.serve.context import TaskContext from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy -from ray.serve.long_poll import LongPollerAsyncClient +from ray.serve.long_poll import LongPollAsyncClient from ray.serve.utils import logger from ray.util import metrics @@ -106,7 +107,8 @@ class ReplicaSet: ) >= self.max_concurrent_queries: # This replica is overloaded, try next one continue - logger.debug(f"Replica set assigned {query} to {replica}") + logger.debug(f"Assigned query {query.metadata.request_id} " + f"to replica {replica}.") ref = replica.handle_request.remote(query) self.in_flight_queries[replica].add(ref) return ref @@ -133,7 +135,8 @@ class ReplicaSet: """ assigned_ref = self._try_assign_replica(query) while assigned_ref is None: # Can't assign a replica right now. - logger.debug(f"Failed to assign a replica for query {query}") + logger.debug("Failed to assign a replica for " + f"query {query.metadata.request_id}") # Maybe there exists a free replica, we just need to refresh our # query tracker. num_finished = self._drain_completed_object_refs() @@ -141,7 +144,7 @@ class ReplicaSet: # config to be updated. if num_finished == 0: logger.debug( - f"All replicas are busy, waiting for a free replica.") + "All replicas are busy, waiting for a free replica.") await asyncio.wait( self._all_query_refs + [self.config_updated_event.wait()], return_when=asyncio.FIRST_COMPLETED) @@ -176,14 +179,14 @@ class Router: async def setup_in_async_loop(self): # NOTE(simon): Instead of performing initialization in __init__, - # We separated the init of LongPollerAsyncClient to this method because - # __init__ might be called in sync context. LongPollerAsyncClient + # We separated the init of LongPollAsyncClient to this method because + # __init__ might be called in sync context. LongPollAsyncClient # requires async context. - self.long_pull_client = LongPollerAsyncClient( + self.long_poll_client = LongPollAsyncClient( self.controller, { - "traffic_policies": self._update_traffic_policies, - "worker_handles": self._update_worker_handles, - "backend_configs": self._update_backend_configs, + LongPollKey.TRAFFIC_POLICIES: self._update_traffic_policies, + LongPollKey.REPLICA_HANDLES: self._update_replica_handles, + LongPollKey.BACKEND_CONFIGS: self._update_backend_configs, }) async def _update_traffic_policies(self, traffic_policies): @@ -194,8 +197,8 @@ class Router: event = self._pending_endpoints.pop(endpoint) event.set() - async def _update_worker_handles(self, worker_handles): - for backend_tag, replica_handles in worker_handles.items(): + async def _update_replica_handles(self, replica_handles): + for backend_tag, replica_handles in replica_handles.items(): self.backend_replicas[backend_tag].update_worker_replicas( replica_handles) diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index 740d82271..76a67c190 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -7,6 +7,7 @@ import pytest import ray from ray import serve from ray.serve.config import BackendConfig +from ray.serve.constants import LongPollKey if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1: serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5 @@ -42,22 +43,22 @@ def mock_controller_with_name(): @ray.remote(num_cpus=0) class MockControllerActor: def __init__(self): - from ray.serve.long_poll import LongPollerHost - self.host = LongPollerHost() + from ray.serve.long_poll import LongPollHost + self.host = LongPollHost() self.backend_replicas = defaultdict(list) self.backend_configs = dict() self.clear() def clear(self): - self.host.notify_changed("worker_handles", {}) - self.host.notify_changed("traffic_policies", {}) - self.host.notify_changed("backend_configs", {}) + self.host.notify_changed(LongPollKey.REPLICA_HANDLES, {}) + self.host.notify_changed(LongPollKey.TRAFFIC_POLICIES, {}) + self.host.notify_changed(LongPollKey.BACKEND_CONFIGS, {}) async def listen_for_change(self, snapshot_ids): return await self.host.listen_for_change(snapshot_ids) def set_traffic(self, endpoint, traffic_policy): - self.host.notify_changed("traffic_policies", + self.host.notify_changed(LongPollKey.TRAFFIC_POLICIES, {endpoint: traffic_policy}) def add_new_replica(self, @@ -68,15 +69,17 @@ def mock_controller_with_name(): self.backend_configs[backend_tag] = backend_config self.host.notify_changed( - "worker_handles", + LongPollKey.REPLICA_HANDLES, self.backend_replicas, ) - self.host.notify_changed("backend_configs", self.backend_configs) + self.host.notify_changed(LongPollKey.BACKEND_CONFIGS, + self.backend_configs) def update_backend(self, backend_tag: str, backend_config: BackendConfig): self.backend_configs[backend_tag] = backend_config - self.host.notify_changed("backend_configs", self.backend_configs) + self.host.notify_changed(LongPollKey.BACKEND_CONFIGS, + self.backend_configs) name = f"MockController{random.randint(0,10e4)}" yield name, MockControllerActor.options(name=name).remote() diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 5d2bc76c0..4ad1cb554 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -25,22 +25,6 @@ def test_e2e(serve_instance): client.create_endpoint( "endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"]) - retry_count = 5 - timeout_sleep = 0.5 - while True: - try: - resp = requests.get( - "http://127.0.0.1:8000/-/routes", timeout=0.5).json() - assert resp == {"/api": ["endpoint", ["GET", "POST"]]} - break - except Exception as e: - time.sleep(timeout_sleep) - timeout_sleep *= 2 - retry_count -= 1 - if retry_count == 0: - assert False, ("Route table hasn't been updated after 3 tries." - "The latest error was {}").format(e) - resp = requests.get("http://127.0.0.1:8000/api").json()["method"] assert resp == "GET" @@ -63,7 +47,7 @@ def test_backend_user_config(serve_instance): config = BackendConfig(num_replicas=2, user_config={"count": 123, "b": 2}) client.create_backend("counter", Counter, config=config) - client.create_endpoint("counter", backend="counter", route="/counter") + client.create_endpoint("counter", backend="counter") handle = client.get_handle("counter") def check(val, num_replicas): @@ -183,7 +167,7 @@ def test_reject_duplicate_endpoint_and_route(serve_instance): def test_no_http(serve_instance): client = serve.start(http_host=None) - assert len(ray.get(client._controller.get_routers.remote())) == 0 + assert len(ray.get(client._controller.get_http_proxies.remote())) == 0 def hello(*args): return "hello" @@ -223,11 +207,6 @@ def test_scaling_replicas(serve_instance): client.create_endpoint("counter", backend="counter:v1", route="/increment") - # Keep checking the routing table until /increment is populated - while "/increment" not in requests.get( - "http://127.0.0.1:8000/-/routes").json(): - time.sleep(0.2) - counter_result = [] for _ in range(10): resp = requests.get("http://127.0.0.1:8000/increment").json() @@ -267,11 +246,6 @@ def test_batching(serve_instance): client.create_endpoint( "counter1", backend="counter:v11", route="/increment2") - # Keep checking the routing table until /increment is populated - while "/increment2" not in requests.get( - "http://127.0.0.1:8000/-/routes").json(): - time.sleep(0.2) - future_list = [] handle = client.get_handle("counter1") for _ in range(20): @@ -299,8 +273,7 @@ def test_batching_exception(serve_instance): # Set the max batch size. config = BackendConfig(max_batch_size=5) client.create_backend("exception:v1", NoListReturned, config=config) - client.create_endpoint( - "exception-test", backend="exception:v1", route="/noListReturned") + client.create_endpoint("exception-test", backend="exception:v1") handle = client.get_handle("exception-test") with pytest.raises(ray.exceptions.RayTaskError): @@ -323,16 +296,16 @@ def test_updating_config(serve_instance): client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple") controller = client._controller - old_replica_tag_list = ray.get( - controller._list_replicas.remote("bsimple:v1")) + old_replica_tag_list = list( + ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys()) update_config = BackendConfig(max_batch_size=5) client.update_backend_config("bsimple:v1", update_config) - new_replica_tag_list = ray.get( - controller._list_replicas.remote("bsimple:v1")) + new_replica_tag_list = list( + ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys()) new_all_tag_list = [] for worker_dict in ray.get( - controller.get_all_replica_handles.remote()).values(): + controller._all_replica_handles.remote()).values(): new_all_tag_list.extend(list(worker_dict.keys())) # the old and new replica tag list should be identical @@ -648,7 +621,7 @@ def test_create_infeasible_error(serve_instance): "MagicMLResource": 100 }}) - # Even each replica might be feasible, the total might not be. + # Even though each replica might be feasible, the total might not be. current_cpus = int(ray.nodes()[0]["Resources"]["CPU"]) num_replicas = current_cpus + 20 config = BackendConfig(num_replicas=num_replicas) @@ -661,10 +634,6 @@ def test_create_infeasible_error(serve_instance): }}, config=config) - # No replica should be created! - replicas = ray.get(client._controller._list_replicas.remote("f1")) - assert len(replicas) == 0 - def test_shutdown(): def f(): @@ -797,6 +766,7 @@ def test_serve_metrics(serve_instance): client.create_backend("metrics", batcher) client.create_endpoint("metrics", backend="metrics", route="/metrics") + # send 10 concurrent requests url = "http://127.0.0.1:8000/metrics" ray.get([block_until_http_ready.remote(url) for _ in range(10)]) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 4233740fa..1b03c0835 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -48,7 +48,7 @@ def setup_worker(name, async def add_servable_to_router(servable, router, controller_name, **kwargs): worker = setup_worker( "backend", servable, controller_name=controller_name, **kwargs) - await router._update_worker_handles.remote({"backend": [worker]}) + await router._update_replica_handles.remote({"backend": [worker]}) await router._update_traffic_policies.remote({ "endpoint": TrafficPolicy({ "backend": 1.0 diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 99a05ca39..6312e56f2 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -76,10 +76,10 @@ def test_controller_failure(serve_instance): assert response.text == "hello3" -def _kill_routers(client): - routers = ray.get(client._controller.get_routers.remote()) - for router in routers.values(): - ray.kill(router, no_restart=False) +def _kill_http_proxies(client): + http_proxies = ray.get(client._controller.get_http_proxies.remote()) + for http_proxy in http_proxies.values(): + ray.kill(http_proxy, no_restart=False) def test_http_proxy_failure(serve_instance): @@ -98,7 +98,7 @@ def test_http_proxy_failure(serve_instance): response = request_with_retries("/proxy_failure", timeout=30) assert response.text == "hello1" - _kill_routers(client) + _kill_http_proxies(client) def function(_): return "hello2" @@ -113,7 +113,7 @@ def test_http_proxy_failure(serve_instance): def _get_worker_handles(client, backend): controller = client._controller - backend_dict = ray.get(controller.get_all_replica_handles.remote()) + backend_dict = ray.get(controller._all_replica_handles.remote()) return list(backend_dict[backend].values()) diff --git a/python/ray/serve/tests/test_long_poll.py b/python/ray/serve/tests/test_long_poll.py index 040219527..5916eadcb 100644 --- a/python/ray/serve/tests/test_long_poll.py +++ b/python/ray/serve/tests/test_long_poll.py @@ -1,5 +1,4 @@ import sys -import functools import time import asyncio import os @@ -8,12 +7,12 @@ from typing import Dict import pytest import ray -from ray.serve.long_poll import (LongPollerAsyncClient, LongPollerHost, +from ray.serve.long_poll import (LongPollAsyncClient, LongPollHost, UpdatedObject) def test_host_standalone(serve_instance): - host = ray.remote(LongPollerHost).remote() + host = ray.remote(LongPollHost).remote() # Write two values ray.get(host.notify_changed.remote("key_1", 999)) @@ -44,10 +43,10 @@ def test_long_poll_restarts(serve_instance): max_restarts=-1, max_task_retries=-1, ) - class RestartableLongPollerHost: + class RestartableLongPollHost: def __init__(self) -> None: print("actor started") - self.host = LongPollerHost() + self.host = LongPollHost() self.host.notify_changed("timer", time.time()) self.should_exit = False @@ -63,7 +62,7 @@ def test_long_poll_restarts(serve_instance): print("actor exit") os._exit(1) - host = RestartableLongPollerHost.remote() + host = RestartableLongPollHost.remote() updated_values = ray.get(host.listen_for_change.remote({"timer": -1})) timer: UpdatedObject = updated_values["timer"] @@ -81,22 +80,31 @@ def test_long_poll_restarts(serve_instance): @pytest.mark.asyncio async def test_async_client(serve_instance): - host = ray.remote(LongPollerHost).remote() + host = ray.remote(LongPollHost).remote() # Write two values ray.get(host.notify_changed.remote("key_1", 100)) ray.get(host.notify_changed.remote("key_2", 999)) + # Check that construction fails with a sync callback. + def callback(result, key): + pass + + with pytest.raises(ValueError): + client = LongPollAsyncClient(host, {"key": callback}) + callback_results = dict() - async def callback(result, key): - callback_results[key] = result + async def key_1_callback(result): + callback_results["key_1"] = result - client = LongPollerAsyncClient( - host, { - "key_1": functools.partial(callback, key="key_1"), - "key_2": functools.partial(callback, key="key_2") - }) + async def key_2_callback(result): + callback_results["key_2"] = result + + client = LongPollAsyncClient(host, { + "key_1": key_1_callback, + "key_2": key_2_callback, + }) while len(client.object_snapshots) == 0: # Yield the loop for client to get the result diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index e9c95925d..efa6f1b6b 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -144,6 +144,7 @@ class ServeEncoder(json.JSONEncoder): @ray.remote(num_cpus=0) def block_until_http_ready(http_endpoint, backoff_time_s=1, + check_ready=None, timeout=HTTP_PROXY_TIMEOUT): http_is_ready = False start_time = time.time() @@ -152,7 +153,10 @@ def block_until_http_ready(http_endpoint, try: resp = requests.get(http_endpoint) assert resp.status_code == 200 - http_is_ready = True + if check_ready is None: + http_is_ready = True + else: + http_is_ready = check_ready(resp) except Exception: pass