mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 00:52:45 +08:00
[serve] Long polling for routes in http server (#12724)
This commit is contained in:
+30
-1
@@ -43,6 +43,8 @@ class Client:
|
||||
self._controller_name = controller_name
|
||||
self._detached = detached
|
||||
self._shutdown = False
|
||||
self._http_host, self._http_port = ray.get(
|
||||
controller.get_http_config.remote())
|
||||
|
||||
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
|
||||
# mostly grow in size, it will only shrink when user calls the
|
||||
@@ -149,6 +151,29 @@ class Client:
|
||||
self._controller.create_endpoint.remote(
|
||||
endpoint_name, {backend: 1.0}, route, upper_methods))
|
||||
|
||||
# Block until the route table has been propagated to all HTTP proxies.
|
||||
if route is not None:
|
||||
|
||||
def check_ready(http_response):
|
||||
return route in http_response.json()
|
||||
|
||||
futures = []
|
||||
for node_id in ray.state.node_ids():
|
||||
future = block_until_http_ready.options(
|
||||
num_cpus=0, resources={
|
||||
node_id: 0.01
|
||||
}).remote(
|
||||
"http://{}:{}/-/routes".format(self._http_host,
|
||||
self._http_port),
|
||||
check_ready=check_ready,
|
||||
timeout=HTTP_PROXY_TIMEOUT)
|
||||
futures.append(future)
|
||||
try:
|
||||
ray.get(futures)
|
||||
except ray.exceptions.RayTaskError:
|
||||
raise TimeoutError("Route not available at HTTP proxies "
|
||||
"after {HTTP_PROXY_TIMEOUT}s.")
|
||||
|
||||
@_ensure_connected
|
||||
def delete_endpoint(self, endpoint: str) -> None:
|
||||
"""Delete the given endpoint.
|
||||
@@ -453,7 +478,11 @@ def start(detached: bool = False,
|
||||
"http://{}:{}/-/routes".format(http_host, http_port),
|
||||
timeout=HTTP_PROXY_TIMEOUT)
|
||||
futures.append(future)
|
||||
ray.get(futures)
|
||||
try:
|
||||
ray.get(futures)
|
||||
except ray.exceptions.RayTaskError:
|
||||
raise TimeoutError(
|
||||
"HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s.")
|
||||
|
||||
return Client(controller, controller_name, detached=detached)
|
||||
|
||||
|
||||
@@ -15,10 +15,13 @@ from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.util import metrics
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.long_poll import LongPollerAsyncClient
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.router import Query
|
||||
from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS,
|
||||
BACKEND_RECONFIGURE_METHOD)
|
||||
from ray.serve.constants import (
|
||||
BACKEND_RECONFIGURE_METHOD,
|
||||
DEFAULT_LATENCY_BUCKET_MS,
|
||||
LongPollKey,
|
||||
)
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
logger = _get_logger()
|
||||
@@ -168,8 +171,8 @@ class RayServeReplica:
|
||||
tag_keys=("backend", ))
|
||||
self.request_counter.set_default_tags({"backend": self.backend_tag})
|
||||
|
||||
self.long_poll_client = LongPollerAsyncClient(controller_handle, {
|
||||
"backend_configs": self._update_backend_configs,
|
||||
self.long_poll_client = LongPollAsyncClient(controller_handle, {
|
||||
LongPollKey.BACKEND_CONFIGS: self._update_backend_configs,
|
||||
})
|
||||
|
||||
self.error_counter = metrics.Count(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from enum import auto, Enum
|
||||
|
||||
#: Actor name used to register controller
|
||||
SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR"
|
||||
|
||||
@@ -37,3 +39,13 @@ DEFAULT_LATENCY_BUCKET_MS = [
|
||||
|
||||
#: Name of backend reconfiguration method implemented by user.
|
||||
BACKEND_RECONFIGURE_METHOD = "reconfigure"
|
||||
|
||||
|
||||
class LongPollKey(Enum):
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
REPLICA_HANDLES = auto()
|
||||
TRAFFIC_POLICIES = auto()
|
||||
BACKEND_CONFIGS = auto()
|
||||
ROUTE_TABLE = auto()
|
||||
|
||||
@@ -13,14 +13,15 @@ import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
|
||||
from ray.serve.backend_worker import create_backend_replica
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY, SERVE_PROXY_NAME
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
|
||||
LongPollKey)
|
||||
from ray.serve.http_proxy import HTTPProxyActor
|
||||
from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.long_poll import LongPollerHost
|
||||
from ray.serve.long_poll import LongPollHost
|
||||
from ray.actor import ActorHandle
|
||||
|
||||
import numpy as np
|
||||
@@ -145,7 +146,7 @@ class ActorStateReconciler:
|
||||
controller_name: str = field(init=True)
|
||||
detached: bool = field(init=True)
|
||||
|
||||
routers_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
|
||||
http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
|
||||
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
|
||||
default_factory=lambda: defaultdict(dict))
|
||||
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
|
||||
@@ -157,8 +158,8 @@ class ActorStateReconciler:
|
||||
|
||||
# TODO(edoakes): consider removing this and just using the names.
|
||||
|
||||
def router_handles(self) -> List[ActorHandle]:
|
||||
return list(self.routers_cache.values())
|
||||
def http_proxy_handles(self) -> List[ActorHandle]:
|
||||
return list(self.http_proxy_cache.values())
|
||||
|
||||
def get_replica_handles(self) -> List[ActorHandle]:
|
||||
return list(
|
||||
@@ -303,7 +304,7 @@ class ActorStateReconciler:
|
||||
async def _stop_pending_backend_replicas(self) -> None:
|
||||
"""Stops the pending backend replicas in self.backend_replicas_to_stop.
|
||||
|
||||
Removes backend_replicas from the router, kills them, and clears
|
||||
Removes backend_replicas from the http_proxy, kills them, and clears
|
||||
self.backend_replicas_to_stop.
|
||||
"""
|
||||
for backend_tag, replicas_list in self.backend_replicas_to_stop.items(
|
||||
@@ -327,26 +328,26 @@ class ActorStateReconciler:
|
||||
|
||||
self.backend_replicas_to_stop.clear()
|
||||
|
||||
def _start_routers_if_needed(self, http_host: str, http_port: str,
|
||||
http_middlewares: List[Any]) -> None:
|
||||
"""Start a router on every node if it doesn't already exist."""
|
||||
def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
|
||||
http_middlewares: List[Any]) -> None:
|
||||
"""Start an HTTP proxy on every node if it doesn't already exist."""
|
||||
if http_host is None:
|
||||
return
|
||||
|
||||
for node_id, node_resource in get_all_node_ids():
|
||||
if node_id in self.routers_cache:
|
||||
if node_id in self.http_proxy_cache:
|
||||
continue
|
||||
|
||||
router_name = format_actor_name(SERVE_PROXY_NAME,
|
||||
self.controller_name, node_id)
|
||||
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
|
||||
node_id)
|
||||
try:
|
||||
router = ray.get_actor(router_name)
|
||||
proxy = ray.get_actor(name)
|
||||
except ValueError:
|
||||
logger.info("Starting router with name '{}' on node '{}' "
|
||||
logger.info("Starting HTTP proxy with name '{}' on node '{}' "
|
||||
"listening on '{}:{}'".format(
|
||||
router_name, node_id, http_host, http_port))
|
||||
router = HTTPProxyActor.options(
|
||||
name=router_name,
|
||||
name, node_id, http_host, http_port))
|
||||
proxy = HTTPProxyActor.options(
|
||||
name=name,
|
||||
lifetime="detached" if self.detached else None,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_restarts=-1,
|
||||
@@ -360,10 +361,10 @@ class ActorStateReconciler:
|
||||
controller_name=self.controller_name,
|
||||
http_middlewares=http_middlewares)
|
||||
|
||||
self.routers_cache[node_id] = router
|
||||
self.http_proxy_cache[node_id] = proxy
|
||||
|
||||
def _stop_routers_if_needed(self) -> bool:
|
||||
"""Removes router actors from any nodes that no longer exist.
|
||||
def _stop_http_proxies_if_needed(self) -> bool:
|
||||
"""Removes HTTP proxy actors from any nodes that no longer exist.
|
||||
|
||||
Returns whether or not any actors were removed (a checkpoint should
|
||||
be taken).
|
||||
@@ -371,25 +372,25 @@ class ActorStateReconciler:
|
||||
actor_stopped = False
|
||||
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
|
||||
to_stop = []
|
||||
for node_id in self.routers_cache:
|
||||
for node_id in self.http_proxy_cache:
|
||||
if node_id not in all_node_ids:
|
||||
logger.info(
|
||||
"Removing router on removed node '{}'.".format(node_id))
|
||||
logger.info("Removing HTTP proxy on removed node '{}'.".format(
|
||||
node_id))
|
||||
to_stop.append(node_id)
|
||||
|
||||
for node_id in to_stop:
|
||||
router_handle = self.routers_cache.pop(node_id)
|
||||
ray.kill(router_handle, no_restart=True)
|
||||
proxy = self.http_proxy_cache.pop(node_id)
|
||||
ray.kill(proxy, no_restart=True)
|
||||
actor_stopped = True
|
||||
|
||||
return actor_stopped
|
||||
|
||||
def _recover_actor_handles(self) -> None:
|
||||
# Refresh the RouterCache
|
||||
for node_id in self.routers_cache.keys():
|
||||
router_name = format_actor_name(SERVE_PROXY_NAME,
|
||||
self.controller_name, node_id)
|
||||
self.routers_cache[node_id] = ray.get_actor(router_name)
|
||||
for node_id in self.http_proxy_cache.keys():
|
||||
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
|
||||
node_id)
|
||||
self.http_proxy_cache[node_id] = ray.get_actor(name)
|
||||
|
||||
# Fetch actor handles for all of the backend replicas in the system.
|
||||
# All of these backend_replicas are guaranteed to already exist because
|
||||
@@ -482,7 +483,7 @@ class ServeController:
|
||||
# backend -> AutoscalingPolicy
|
||||
self.autoscaling_policies = dict()
|
||||
|
||||
# Dictionary of backend_tag -> router_name -> most recent queue length.
|
||||
# Dictionary of backend_tag -> proxy_name -> most recent queue length.
|
||||
self.backend_stats = defaultdict(lambda: defaultdict(dict))
|
||||
|
||||
# Used to ensure that only a single state-changing operation happens
|
||||
@@ -495,7 +496,7 @@ class ServeController:
|
||||
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self.actor_reconciler._start_routers_if_needed(
|
||||
self.actor_reconciler._start_http_proxies_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
|
||||
# Map of awaiting results
|
||||
@@ -528,10 +529,11 @@ class ServeController:
|
||||
# can be problem at scale, e.g. updating a single backend config
|
||||
# will send over the entire configs. In the future, we should
|
||||
# optimize the logic to support subscription by key.
|
||||
self.long_poll_host = LongPollerHost()
|
||||
self.long_poll_host = LongPollHost()
|
||||
self.notify_backend_configs_changed()
|
||||
self.notify_replica_handles_changed()
|
||||
self.notify_traffic_policies_changed()
|
||||
self.notify_route_table_changed()
|
||||
|
||||
asyncio.get_event_loop().create_task(self.run_control_loop())
|
||||
|
||||
@@ -565,19 +567,26 @@ class ServeController:
|
||||
|
||||
def notify_replica_handles_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"worker_handles", {
|
||||
LongPollKey.REPLICA_HANDLES, {
|
||||
backend_tag: list(replica_dict.values())
|
||||
for backend_tag, replica_dict in
|
||||
self.actor_reconciler.backend_replicas.items()
|
||||
})
|
||||
|
||||
def notify_traffic_policies_changed(self):
|
||||
self.long_poll_host.notify_changed("traffic_policies",
|
||||
self.current_state.traffic_policies)
|
||||
self.long_poll_host.notify_changed(
|
||||
LongPollKey.TRAFFIC_POLICIES,
|
||||
self.current_state.traffic_policies,
|
||||
)
|
||||
|
||||
def notify_backend_configs_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"backend_configs", self.current_state.get_backend_configs())
|
||||
LongPollKey.BACKEND_CONFIGS,
|
||||
self.current_state.get_backend_configs())
|
||||
|
||||
def notify_route_table_changed(self):
|
||||
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
|
||||
self.current_state.routes)
|
||||
|
||||
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
|
||||
"""Proxy long pull client's listen request.
|
||||
@@ -590,13 +599,9 @@ class ServeController:
|
||||
return await (
|
||||
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
|
||||
|
||||
def get_routers(self) -> Dict[str, ActorHandle]:
|
||||
"""Returns a dictionary of node ID to router actor handles."""
|
||||
return self.actor_reconciler.routers_cache
|
||||
|
||||
def get_router_config(self) -> Dict[str, Tuple[str, List[str]]]:
|
||||
"""Called by the router on startup to fetch required state."""
|
||||
return self.current_state.routes
|
||||
def get_http_proxies(self) -> Dict[str, ActorHandle]:
|
||||
"""Returns a dictionary of node ID to http_proxy actor handles."""
|
||||
return self.actor_reconciler.http_proxy_cache
|
||||
|
||||
def _checkpoint(self) -> None:
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
@@ -622,7 +627,7 @@ class ServeController:
|
||||
|
||||
Performs the following operations:
|
||||
1) Deserializes the internal state from the checkpoint.
|
||||
2) Pushes the latest configuration to the routers
|
||||
2) Pushes the latest configuration to the HTTP proxies
|
||||
in case we crashed before updating them.
|
||||
3) Starts/stops any replicas that are pending creation or
|
||||
deletion.
|
||||
@@ -670,40 +675,26 @@ class ServeController:
|
||||
while True:
|
||||
await self.do_autoscale()
|
||||
async with self.write_lock:
|
||||
self.actor_reconciler._start_routers_if_needed(
|
||||
self.actor_reconciler._start_http_proxies_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
checkpoint_required = self.actor_reconciler.\
|
||||
_stop_routers_if_needed()
|
||||
_stop_http_proxies_if_needed()
|
||||
if checkpoint_required:
|
||||
self._checkpoint()
|
||||
|
||||
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
|
||||
|
||||
def get_backend_configs(self) -> Dict[str, BackendConfig]:
|
||||
"""Fetched by the router on startup."""
|
||||
return self.current_state.get_backend_configs()
|
||||
|
||||
def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
|
||||
"""Fetched by the router on startup."""
|
||||
return self.current_state.traffic_policies
|
||||
|
||||
def _list_replicas(self, backend_tag: BackendTag) -> List[ReplicaTag]:
|
||||
"""Used only for testing."""
|
||||
return list(self.actor_reconciler.backend_replicas[backend_tag].keys())
|
||||
|
||||
def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
|
||||
"""Fetched by serve handles."""
|
||||
return self.current_state.traffic_policies[endpoint]
|
||||
|
||||
def get_all_replica_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
|
||||
"""Fetched by the router on startup."""
|
||||
def _all_replica_handles(
|
||||
self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
|
||||
"""Used for testing."""
|
||||
return self.actor_reconciler.backend_replicas
|
||||
|
||||
def get_all_backends(self) -> Dict[str, BackendConfig]:
|
||||
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
|
||||
"""Returns a dictionary of backend tag to backend config."""
|
||||
return self.current_state.get_backend_configs()
|
||||
|
||||
def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
|
||||
def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
|
||||
"""Returns a dictionary of backend tag to backend config."""
|
||||
return self.current_state.get_endpoints()
|
||||
|
||||
async def _set_traffic(self, endpoint_name: str,
|
||||
@@ -731,7 +722,6 @@ class ServeController:
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
|
||||
self.notify_traffic_policies_changed()
|
||||
return return_uuid
|
||||
|
||||
@@ -814,10 +804,7 @@ class ServeController:
|
||||
|
||||
# NOTE(edoakes): checkpoint is written in self._set_traffic.
|
||||
return_uuid = await self._set_traffic(endpoint, traffic_dict)
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.current_state.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
self.notify_route_table_changed()
|
||||
return return_uuid
|
||||
|
||||
async def delete_endpoint(self, endpoint: str) -> UUID:
|
||||
@@ -852,14 +839,10 @@ class ServeController:
|
||||
endpoint: None
|
||||
})
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# updates to the routers to avoid inconsistent state if we crash
|
||||
# updates to the proxies to avoid inconsistent state if we crash
|
||||
# after pushing the update.
|
||||
self._checkpoint()
|
||||
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.current_state.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
self.notify_route_table_changed()
|
||||
return return_uuid
|
||||
|
||||
async def create_backend(self, backend_tag: BackendTag,
|
||||
@@ -910,7 +893,7 @@ class ServeController:
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
|
||||
# Set the backend config inside the router
|
||||
# Set the backend config inside routers
|
||||
# (particularly for max_concurrent_queries).
|
||||
self.notify_backend_configs_changed()
|
||||
return return_uuid
|
||||
@@ -943,12 +926,12 @@ class ServeController:
|
||||
if backend_tag in self.autoscaling_policies:
|
||||
del self.autoscaling_policies[backend_tag]
|
||||
|
||||
# Add the intention to remove the backend from the router.
|
||||
# Add the intention to remove the backend from the routers.
|
||||
self.actor_reconciler.backends_to_remove.append(backend_tag)
|
||||
|
||||
return_uuid = self._create_event_with_result({backend_tag: None})
|
||||
# NOTE(edoakes): we must write a checkpoint before removing the
|
||||
# backend from the router to avoid inconsistent state if we crash
|
||||
# backend from the routers to avoid inconsistent state if we crash
|
||||
# after pushing the update.
|
||||
self._checkpoint()
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
@@ -986,7 +969,7 @@ class ServeController:
|
||||
# update.
|
||||
self._checkpoint()
|
||||
|
||||
# Inform the router about change in configuration
|
||||
# Inform the routers about change in configuration
|
||||
# (particularly for setting max_batch_size).
|
||||
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
@@ -1003,11 +986,15 @@ class ServeController:
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
return self.current_state.get_backend(backend_tag).backend_config
|
||||
|
||||
def get_http_config(self):
|
||||
"""Return the HTTP proxy configuration."""
|
||||
return self.http_host, self.http_port
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shuts down the serve instance completely."""
|
||||
async with self.write_lock:
|
||||
for router in self.actor_reconciler.router_handles():
|
||||
ray.kill(router, no_restart=True)
|
||||
for http_proxy in self.actor_reconciler.http_proxy_handles():
|
||||
ray.kill(http_proxy, no_restart=True)
|
||||
for replica in self.actor_reconciler.get_replica_handles():
|
||||
ray.kill(replica, no_restart=True)
|
||||
self.kv_store.delete(CHECKPOINT_KEY)
|
||||
|
||||
@@ -89,5 +89,6 @@ class RandomEndpointPolicy(EndpointPolicy):
|
||||
query.metadata.shard_key.encode("utf-8"))
|
||||
|
||||
chosen_backend, shadow_backends = self._select_backends(value)
|
||||
logger.debug(f"Chosen backend {chosen_backend} for query {query}")
|
||||
logger.debug(f"Assigning query {query.metadata.request_id} "
|
||||
f"to backend {chosen_backend}.")
|
||||
return [chosen_backend] + shadow_backends
|
||||
|
||||
@@ -6,43 +6,40 @@ import uvicorn
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.serve.constants import LongPollKey
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.util import metrics
|
||||
from ray.serve.utils import _get_logger, get_random_letters
|
||||
from ray.serve.http_util import Response
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.router import Router, RequestMetadata
|
||||
|
||||
# The maximum number of times to retry a request due to actor failure.
|
||||
# TODO(edoakes): this should probably be configurable.
|
||||
MAX_ACTOR_DEAD_RETRIES = 10
|
||||
|
||||
logger = _get_logger()
|
||||
|
||||
|
||||
class HTTPProxy:
|
||||
"""
|
||||
This class should be instantiated and ran by ASGI server.
|
||||
"""This class is meant to be instantiated and run by an ASGI HTTP server.
|
||||
|
||||
>>> import uvicorn
|
||||
>>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle))
|
||||
# blocks forever
|
||||
"""
|
||||
|
||||
async def fetch_config_from_controller(self, controller_name):
|
||||
assert ray.is_initialized()
|
||||
def __init__(self, controller_name):
|
||||
controller = ray.get_actor(controller_name)
|
||||
|
||||
self.route_table = await controller.get_router_config.remote()
|
||||
self.router = Router(controller)
|
||||
self.long_poll_client = LongPollAsyncClient(controller, {
|
||||
LongPollKey.ROUTE_TABLE: self._update_route_table,
|
||||
})
|
||||
|
||||
self.request_counter = metrics.Count(
|
||||
"num_http_requests",
|
||||
description="The number of HTTP requests processed",
|
||||
tag_keys=("route", ))
|
||||
|
||||
self.router = Router(controller)
|
||||
async def setup(self):
|
||||
await self.router.setup_in_async_loop()
|
||||
|
||||
def set_route_table(self, route_table):
|
||||
async def _update_route_table(self, route_table):
|
||||
self.route_table = route_table
|
||||
|
||||
async def receive_http_body(self, scope, receive, send):
|
||||
@@ -74,8 +71,11 @@ class HTTPProxy:
|
||||
status_code=404).send(scope, receive, send)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# NOTE: This implements ASGI protocol specified in
|
||||
# https://asgi.readthedocs.io/en/latest/specs/index.html
|
||||
"""Implements the ASGI protocol.
|
||||
|
||||
See details at:
|
||||
https://asgi.readthedocs.io/en/latest/specs/index.html.
|
||||
"""
|
||||
|
||||
error_sender = self._make_error_sender(scope, receive, send)
|
||||
|
||||
@@ -137,12 +137,13 @@ class HTTPProxyActor:
|
||||
host,
|
||||
port,
|
||||
controller_name,
|
||||
http_middlewares: List["starlette.middleware.Middleware"] = []):
|
||||
http_middlewares: List[
|
||||
"starlette.middleware.Middleware"] = []): # noqa: F821
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.app = HTTPProxy()
|
||||
await self.app.fetch_config_from_controller(controller_name)
|
||||
self.app = HTTPProxy(controller_name)
|
||||
await self.app.setup()
|
||||
|
||||
self.wrapped_app = self.app
|
||||
for middleware in http_middlewares:
|
||||
@@ -180,6 +181,3 @@ class HTTPProxyActor:
|
||||
# the main thread and uvicorn doesn't expose a way to configure it.
|
||||
server.install_signal_handlers = lambda: None
|
||||
await server.serve(sockets=[sock])
|
||||
|
||||
async def set_route_table(self, route_table):
|
||||
self.app.set_route_table(route_table)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from inspect import iscoroutinefunction
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
@@ -22,7 +23,7 @@ class UpdatedObject:
|
||||
UpdateStateAsyncCallable = Callable[[Any], Awaitable[None]]
|
||||
|
||||
|
||||
class LongPollerAsyncClient:
|
||||
class LongPollAsyncClient:
|
||||
"""The asynchronous long polling client.
|
||||
|
||||
Internally, it runs `await object_ref` in a `while True` loop. When a
|
||||
@@ -31,7 +32,7 @@ class LongPollerAsyncClient:
|
||||
the next poll.
|
||||
|
||||
Args:
|
||||
host_actor(ray.ActorHandle): handle to actor embedding LongPollerHost.
|
||||
host_actor(ray.ActorHandle): handle to actor embedding LongPollHost.
|
||||
key_listeners(Dict[str, AsyncCallable]): a dictionary mapping keys to
|
||||
callbacks to be called on state update for the corresponding keys.
|
||||
"""
|
||||
@@ -40,6 +41,10 @@ class LongPollerAsyncClient:
|
||||
key_listeners: Dict[str, UpdateStateAsyncCallable]) -> None:
|
||||
self.host_actor = host_actor
|
||||
self.key_listeners = key_listeners
|
||||
for callback in key_listeners.values():
|
||||
if not iscoroutinefunction(callback):
|
||||
raise ValueError(
|
||||
"Callbacks to async long poller must be 'async def'.")
|
||||
|
||||
self.snapshot_ids: Dict[str, int] = {
|
||||
key: -1
|
||||
@@ -56,34 +61,31 @@ class LongPollerAsyncClient:
|
||||
self.snapshot_ids)
|
||||
return object_ref
|
||||
|
||||
def _update(self, updates: Dict[str, UpdatedObject]):
|
||||
for key, update in updates.items():
|
||||
self.object_snapshots[key] = update.object_snapshot
|
||||
self.snapshot_ids[key] = update.snapshot_id
|
||||
|
||||
async def _do_long_poll(self):
|
||||
while True:
|
||||
try:
|
||||
updates: Dict[str, UpdatedObject] = await self._poll_once()
|
||||
self._update(updates)
|
||||
logger.debug(f"LongPollerClient received udpates: {updates}")
|
||||
for key, updated_object in updates.items():
|
||||
logger.debug("LongPollClient received updates for keys: "
|
||||
f"{list(updates.keys())}.")
|
||||
for key, update in updates.items():
|
||||
self.object_snapshots[key] = update.object_snapshot
|
||||
self.snapshot_ids[key] = update.snapshot_id
|
||||
# NOTE(simon):
|
||||
# This blocks the loop from doing another poll. Consider
|
||||
# use loop.create_task here or poll first then call the
|
||||
# callbacks.
|
||||
callback = self.key_listeners[key]
|
||||
await callback(updated_object.object_snapshot)
|
||||
await callback(update.object_snapshot)
|
||||
except ray.exceptions.RayActorError:
|
||||
# This can happen during shutdown where the controller is
|
||||
# intentionally killed, the client should just gracefully
|
||||
# exit.
|
||||
logger.debug("LongPollerClient failed to connect to host. "
|
||||
logger.debug("LongPollClient failed to connect to host. "
|
||||
"Shutting down.")
|
||||
break
|
||||
|
||||
|
||||
class LongPollerHost:
|
||||
class LongPollHost:
|
||||
"""The server side object that manages long pulling requests.
|
||||
|
||||
The desired use case is to embed this in an Ray actor. Client will be
|
||||
@@ -115,11 +117,10 @@ class LongPollerHost:
|
||||
immediately if the snapshot_ids are outdated, otherwise it will block
|
||||
until there's one updates.
|
||||
"""
|
||||
# 1. Figure out which keys do we care about
|
||||
watched_keys = set(self.snapshot_ids.keys()).intersection(
|
||||
keys_to_snapshot_ids.keys())
|
||||
if len(watched_keys) == 0:
|
||||
raise ValueError("Keys not found.")
|
||||
watched_keys = keys_to_snapshot_ids.keys()
|
||||
nonexistent_keys = set(watched_keys) - set(self.snapshot_ids.keys())
|
||||
if len(nonexistent_keys) > 0:
|
||||
raise ValueError(f"Keys not found: {nonexistent_keys}.")
|
||||
|
||||
# 2. If there are any outdated keys (by comparing snapshot ids)
|
||||
# return immediately.
|
||||
@@ -159,7 +160,7 @@ class LongPollerHost:
|
||||
def notify_changed(self, object_key: str, updated_object: Any):
|
||||
self.snapshot_ids[object_key] += 1
|
||||
self.object_snapshots[object_key] = updated_object
|
||||
logger.debug(f"LongPollerHost: {object_key} = {updated_object}")
|
||||
logger.debug(f"LongPollHost: Notify change for key {object_key}.")
|
||||
|
||||
if object_key in self.notifier_events:
|
||||
for event in self.notifier_events.pop(object_key):
|
||||
|
||||
+15
-12
@@ -6,9 +6,10 @@ from typing import Any, DefaultDict, Dict, Iterable, List, Optional
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.serve.constants import LongPollKey
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
|
||||
from ray.serve.long_poll import LongPollerAsyncClient
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.utils import logger
|
||||
from ray.util import metrics
|
||||
|
||||
@@ -106,7 +107,8 @@ class ReplicaSet:
|
||||
) >= self.max_concurrent_queries:
|
||||
# This replica is overloaded, try next one
|
||||
continue
|
||||
logger.debug(f"Replica set assigned {query} to {replica}")
|
||||
logger.debug(f"Assigned query {query.metadata.request_id} "
|
||||
f"to replica {replica}.")
|
||||
ref = replica.handle_request.remote(query)
|
||||
self.in_flight_queries[replica].add(ref)
|
||||
return ref
|
||||
@@ -133,7 +135,8 @@ class ReplicaSet:
|
||||
"""
|
||||
assigned_ref = self._try_assign_replica(query)
|
||||
while assigned_ref is None: # Can't assign a replica right now.
|
||||
logger.debug(f"Failed to assign a replica for query {query}")
|
||||
logger.debug("Failed to assign a replica for "
|
||||
f"query {query.metadata.request_id}")
|
||||
# Maybe there exists a free replica, we just need to refresh our
|
||||
# query tracker.
|
||||
num_finished = self._drain_completed_object_refs()
|
||||
@@ -141,7 +144,7 @@ class ReplicaSet:
|
||||
# config to be updated.
|
||||
if num_finished == 0:
|
||||
logger.debug(
|
||||
f"All replicas are busy, waiting for a free replica.")
|
||||
"All replicas are busy, waiting for a free replica.")
|
||||
await asyncio.wait(
|
||||
self._all_query_refs + [self.config_updated_event.wait()],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
@@ -176,14 +179,14 @@ class Router:
|
||||
|
||||
async def setup_in_async_loop(self):
|
||||
# NOTE(simon): Instead of performing initialization in __init__,
|
||||
# We separated the init of LongPollerAsyncClient to this method because
|
||||
# __init__ might be called in sync context. LongPollerAsyncClient
|
||||
# We separated the init of LongPollAsyncClient to this method because
|
||||
# __init__ might be called in sync context. LongPollAsyncClient
|
||||
# requires async context.
|
||||
self.long_pull_client = LongPollerAsyncClient(
|
||||
self.long_poll_client = LongPollAsyncClient(
|
||||
self.controller, {
|
||||
"traffic_policies": self._update_traffic_policies,
|
||||
"worker_handles": self._update_worker_handles,
|
||||
"backend_configs": self._update_backend_configs,
|
||||
LongPollKey.TRAFFIC_POLICIES: self._update_traffic_policies,
|
||||
LongPollKey.REPLICA_HANDLES: self._update_replica_handles,
|
||||
LongPollKey.BACKEND_CONFIGS: self._update_backend_configs,
|
||||
})
|
||||
|
||||
async def _update_traffic_policies(self, traffic_policies):
|
||||
@@ -194,8 +197,8 @@ class Router:
|
||||
event = self._pending_endpoints.pop(endpoint)
|
||||
event.set()
|
||||
|
||||
async def _update_worker_handles(self, worker_handles):
|
||||
for backend_tag, replica_handles in worker_handles.items():
|
||||
async def _update_replica_handles(self, replica_handles):
|
||||
for backend_tag, replica_handles in replica_handles.items():
|
||||
self.backend_replicas[backend_tag].update_worker_replicas(
|
||||
replica_handles)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.constants import LongPollKey
|
||||
|
||||
if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
|
||||
serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
|
||||
@@ -42,22 +43,22 @@ def mock_controller_with_name():
|
||||
@ray.remote(num_cpus=0)
|
||||
class MockControllerActor:
|
||||
def __init__(self):
|
||||
from ray.serve.long_poll import LongPollerHost
|
||||
self.host = LongPollerHost()
|
||||
from ray.serve.long_poll import LongPollHost
|
||||
self.host = LongPollHost()
|
||||
self.backend_replicas = defaultdict(list)
|
||||
self.backend_configs = dict()
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.host.notify_changed("worker_handles", {})
|
||||
self.host.notify_changed("traffic_policies", {})
|
||||
self.host.notify_changed("backend_configs", {})
|
||||
self.host.notify_changed(LongPollKey.REPLICA_HANDLES, {})
|
||||
self.host.notify_changed(LongPollKey.TRAFFIC_POLICIES, {})
|
||||
self.host.notify_changed(LongPollKey.BACKEND_CONFIGS, {})
|
||||
|
||||
async def listen_for_change(self, snapshot_ids):
|
||||
return await self.host.listen_for_change(snapshot_ids)
|
||||
|
||||
def set_traffic(self, endpoint, traffic_policy):
|
||||
self.host.notify_changed("traffic_policies",
|
||||
self.host.notify_changed(LongPollKey.TRAFFIC_POLICIES,
|
||||
{endpoint: traffic_policy})
|
||||
|
||||
def add_new_replica(self,
|
||||
@@ -68,15 +69,17 @@ def mock_controller_with_name():
|
||||
self.backend_configs[backend_tag] = backend_config
|
||||
|
||||
self.host.notify_changed(
|
||||
"worker_handles",
|
||||
LongPollKey.REPLICA_HANDLES,
|
||||
self.backend_replicas,
|
||||
)
|
||||
self.host.notify_changed("backend_configs", self.backend_configs)
|
||||
self.host.notify_changed(LongPollKey.BACKEND_CONFIGS,
|
||||
self.backend_configs)
|
||||
|
||||
def update_backend(self, backend_tag: str,
|
||||
backend_config: BackendConfig):
|
||||
self.backend_configs[backend_tag] = backend_config
|
||||
self.host.notify_changed("backend_configs", self.backend_configs)
|
||||
self.host.notify_changed(LongPollKey.BACKEND_CONFIGS,
|
||||
self.backend_configs)
|
||||
|
||||
name = f"MockController{random.randint(0,10e4)}"
|
||||
yield name, MockControllerActor.options(name=name).remote()
|
||||
|
||||
@@ -25,22 +25,6 @@ def test_e2e(serve_instance):
|
||||
client.create_endpoint(
|
||||
"endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])
|
||||
|
||||
retry_count = 5
|
||||
timeout_sleep = 0.5
|
||||
while True:
|
||||
try:
|
||||
resp = requests.get(
|
||||
"http://127.0.0.1:8000/-/routes", timeout=0.5).json()
|
||||
assert resp == {"/api": ["endpoint", ["GET", "POST"]]}
|
||||
break
|
||||
except Exception as e:
|
||||
time.sleep(timeout_sleep)
|
||||
timeout_sleep *= 2
|
||||
retry_count -= 1
|
||||
if retry_count == 0:
|
||||
assert False, ("Route table hasn't been updated after 3 tries."
|
||||
"The latest error was {}").format(e)
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/api").json()["method"]
|
||||
assert resp == "GET"
|
||||
|
||||
@@ -63,7 +47,7 @@ def test_backend_user_config(serve_instance):
|
||||
|
||||
config = BackendConfig(num_replicas=2, user_config={"count": 123, "b": 2})
|
||||
client.create_backend("counter", Counter, config=config)
|
||||
client.create_endpoint("counter", backend="counter", route="/counter")
|
||||
client.create_endpoint("counter", backend="counter")
|
||||
handle = client.get_handle("counter")
|
||||
|
||||
def check(val, num_replicas):
|
||||
@@ -183,7 +167,7 @@ def test_reject_duplicate_endpoint_and_route(serve_instance):
|
||||
def test_no_http(serve_instance):
|
||||
client = serve.start(http_host=None)
|
||||
|
||||
assert len(ray.get(client._controller.get_routers.remote())) == 0
|
||||
assert len(ray.get(client._controller.get_http_proxies.remote())) == 0
|
||||
|
||||
def hello(*args):
|
||||
return "hello"
|
||||
@@ -223,11 +207,6 @@ def test_scaling_replicas(serve_instance):
|
||||
|
||||
client.create_endpoint("counter", backend="counter:v1", route="/increment")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
counter_result = []
|
||||
for _ in range(10):
|
||||
resp = requests.get("http://127.0.0.1:8000/increment").json()
|
||||
@@ -267,11 +246,6 @@ def test_batching(serve_instance):
|
||||
client.create_endpoint(
|
||||
"counter1", backend="counter:v11", route="/increment2")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment2" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
future_list = []
|
||||
handle = client.get_handle("counter1")
|
||||
for _ in range(20):
|
||||
@@ -299,8 +273,7 @@ def test_batching_exception(serve_instance):
|
||||
# Set the max batch size.
|
||||
config = BackendConfig(max_batch_size=5)
|
||||
client.create_backend("exception:v1", NoListReturned, config=config)
|
||||
client.create_endpoint(
|
||||
"exception-test", backend="exception:v1", route="/noListReturned")
|
||||
client.create_endpoint("exception-test", backend="exception:v1")
|
||||
|
||||
handle = client.get_handle("exception-test")
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
@@ -323,16 +296,16 @@ def test_updating_config(serve_instance):
|
||||
client.create_endpoint("bsimple", backend="bsimple:v1", route="/bsimple")
|
||||
|
||||
controller = client._controller
|
||||
old_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
old_replica_tag_list = list(
|
||||
ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())
|
||||
|
||||
update_config = BackendConfig(max_batch_size=5)
|
||||
client.update_backend_config("bsimple:v1", update_config)
|
||||
new_replica_tag_list = ray.get(
|
||||
controller._list_replicas.remote("bsimple:v1"))
|
||||
new_replica_tag_list = list(
|
||||
ray.get(controller._all_replica_handles.remote())["bsimple:v1"].keys())
|
||||
new_all_tag_list = []
|
||||
for worker_dict in ray.get(
|
||||
controller.get_all_replica_handles.remote()).values():
|
||||
controller._all_replica_handles.remote()).values():
|
||||
new_all_tag_list.extend(list(worker_dict.keys()))
|
||||
|
||||
# the old and new replica tag list should be identical
|
||||
@@ -648,7 +621,7 @@ def test_create_infeasible_error(serve_instance):
|
||||
"MagicMLResource": 100
|
||||
}})
|
||||
|
||||
# Even each replica might be feasible, the total might not be.
|
||||
# Even though each replica might be feasible, the total might not be.
|
||||
current_cpus = int(ray.nodes()[0]["Resources"]["CPU"])
|
||||
num_replicas = current_cpus + 20
|
||||
config = BackendConfig(num_replicas=num_replicas)
|
||||
@@ -661,10 +634,6 @@ def test_create_infeasible_error(serve_instance):
|
||||
}},
|
||||
config=config)
|
||||
|
||||
# No replica should be created!
|
||||
replicas = ray.get(client._controller._list_replicas.remote("f1"))
|
||||
assert len(replicas) == 0
|
||||
|
||||
|
||||
def test_shutdown():
|
||||
def f():
|
||||
@@ -797,6 +766,7 @@ def test_serve_metrics(serve_instance):
|
||||
|
||||
client.create_backend("metrics", batcher)
|
||||
client.create_endpoint("metrics", backend="metrics", route="/metrics")
|
||||
|
||||
# send 10 concurrent requests
|
||||
url = "http://127.0.0.1:8000/metrics"
|
||||
ray.get([block_until_http_ready.remote(url) for _ in range(10)])
|
||||
|
||||
@@ -48,7 +48,7 @@ def setup_worker(name,
|
||||
async def add_servable_to_router(servable, router, controller_name, **kwargs):
|
||||
worker = setup_worker(
|
||||
"backend", servable, controller_name=controller_name, **kwargs)
|
||||
await router._update_worker_handles.remote({"backend": [worker]})
|
||||
await router._update_replica_handles.remote({"backend": [worker]})
|
||||
await router._update_traffic_policies.remote({
|
||||
"endpoint": TrafficPolicy({
|
||||
"backend": 1.0
|
||||
|
||||
@@ -76,10 +76,10 @@ def test_controller_failure(serve_instance):
|
||||
assert response.text == "hello3"
|
||||
|
||||
|
||||
def _kill_routers(client):
|
||||
routers = ray.get(client._controller.get_routers.remote())
|
||||
for router in routers.values():
|
||||
ray.kill(router, no_restart=False)
|
||||
def _kill_http_proxies(client):
|
||||
http_proxies = ray.get(client._controller.get_http_proxies.remote())
|
||||
for http_proxy in http_proxies.values():
|
||||
ray.kill(http_proxy, no_restart=False)
|
||||
|
||||
|
||||
def test_http_proxy_failure(serve_instance):
|
||||
@@ -98,7 +98,7 @@ def test_http_proxy_failure(serve_instance):
|
||||
response = request_with_retries("/proxy_failure", timeout=30)
|
||||
assert response.text == "hello1"
|
||||
|
||||
_kill_routers(client)
|
||||
_kill_http_proxies(client)
|
||||
|
||||
def function(_):
|
||||
return "hello2"
|
||||
@@ -113,7 +113,7 @@ def test_http_proxy_failure(serve_instance):
|
||||
|
||||
def _get_worker_handles(client, backend):
|
||||
controller = client._controller
|
||||
backend_dict = ray.get(controller.get_all_replica_handles.remote())
|
||||
backend_dict = ray.get(controller._all_replica_handles.remote())
|
||||
|
||||
return list(backend_dict[backend].values())
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sys
|
||||
import functools
|
||||
import time
|
||||
import asyncio
|
||||
import os
|
||||
@@ -8,12 +7,12 @@ from typing import Dict
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.serve.long_poll import (LongPollerAsyncClient, LongPollerHost,
|
||||
from ray.serve.long_poll import (LongPollAsyncClient, LongPollHost,
|
||||
UpdatedObject)
|
||||
|
||||
|
||||
def test_host_standalone(serve_instance):
|
||||
host = ray.remote(LongPollerHost).remote()
|
||||
host = ray.remote(LongPollHost).remote()
|
||||
|
||||
# Write two values
|
||||
ray.get(host.notify_changed.remote("key_1", 999))
|
||||
@@ -44,10 +43,10 @@ def test_long_poll_restarts(serve_instance):
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
)
|
||||
class RestartableLongPollerHost:
|
||||
class RestartableLongPollHost:
|
||||
def __init__(self) -> None:
|
||||
print("actor started")
|
||||
self.host = LongPollerHost()
|
||||
self.host = LongPollHost()
|
||||
self.host.notify_changed("timer", time.time())
|
||||
self.should_exit = False
|
||||
|
||||
@@ -63,7 +62,7 @@ def test_long_poll_restarts(serve_instance):
|
||||
print("actor exit")
|
||||
os._exit(1)
|
||||
|
||||
host = RestartableLongPollerHost.remote()
|
||||
host = RestartableLongPollHost.remote()
|
||||
updated_values = ray.get(host.listen_for_change.remote({"timer": -1}))
|
||||
timer: UpdatedObject = updated_values["timer"]
|
||||
|
||||
@@ -81,22 +80,31 @@ def test_long_poll_restarts(serve_instance):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client(serve_instance):
|
||||
host = ray.remote(LongPollerHost).remote()
|
||||
host = ray.remote(LongPollHost).remote()
|
||||
|
||||
# Write two values
|
||||
ray.get(host.notify_changed.remote("key_1", 100))
|
||||
ray.get(host.notify_changed.remote("key_2", 999))
|
||||
|
||||
# Check that construction fails with a sync callback.
|
||||
def callback(result, key):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
client = LongPollAsyncClient(host, {"key": callback})
|
||||
|
||||
callback_results = dict()
|
||||
|
||||
async def callback(result, key):
|
||||
callback_results[key] = result
|
||||
async def key_1_callback(result):
|
||||
callback_results["key_1"] = result
|
||||
|
||||
client = LongPollerAsyncClient(
|
||||
host, {
|
||||
"key_1": functools.partial(callback, key="key_1"),
|
||||
"key_2": functools.partial(callback, key="key_2")
|
||||
})
|
||||
async def key_2_callback(result):
|
||||
callback_results["key_2"] = result
|
||||
|
||||
client = LongPollAsyncClient(host, {
|
||||
"key_1": key_1_callback,
|
||||
"key_2": key_2_callback,
|
||||
})
|
||||
|
||||
while len(client.object_snapshots) == 0:
|
||||
# Yield the loop for client to get the result
|
||||
|
||||
@@ -144,6 +144,7 @@ class ServeEncoder(json.JSONEncoder):
|
||||
@ray.remote(num_cpus=0)
|
||||
def block_until_http_ready(http_endpoint,
|
||||
backoff_time_s=1,
|
||||
check_ready=None,
|
||||
timeout=HTTP_PROXY_TIMEOUT):
|
||||
http_is_ready = False
|
||||
start_time = time.time()
|
||||
@@ -152,7 +153,10 @@ def block_until_http_ready(http_endpoint,
|
||||
try:
|
||||
resp = requests.get(http_endpoint)
|
||||
assert resp.status_code == 200
|
||||
http_is_ready = True
|
||||
if check_ready is None:
|
||||
http_is_ready = True
|
||||
else:
|
||||
http_is_ready = check_ready(resp)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user