mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:46:57 +08:00
[serve] Centralize HTTP-related logic in HTTPState (#13020)
This commit is contained in:
@@ -17,7 +17,8 @@ from ray.serve.handle import RayServeHandle, RayServeSyncHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_random_letters, logger, get_conda_env_dir)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
|
||||
HTTPConfig)
|
||||
from ray.serve.env import CondaEnv
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.actor import ActorHandle
|
||||
@@ -93,8 +94,7 @@ class Client:
|
||||
self._controller_name = controller_name
|
||||
self._detached = detached
|
||||
self._shutdown = False
|
||||
self._http_host, self._http_port = ray.get(
|
||||
controller.get_http_config.remote())
|
||||
self._http_config = ray.get(controller.get_http_config.remote())
|
||||
|
||||
self._sync_proxied_router = None
|
||||
self._async_proxied_router = None
|
||||
@@ -237,8 +237,8 @@ class Client:
|
||||
num_cpus=0, resources={
|
||||
node_id: 0.01
|
||||
}).remote(
|
||||
"http://{}:{}/-/routes".format(self._http_host,
|
||||
self._http_port),
|
||||
"http://{}:{}/-/routes".format(self._http_config.host,
|
||||
self._http_config.port),
|
||||
check_ready=check_ready,
|
||||
timeout=HTTP_PROXY_TIMEOUT)
|
||||
futures.append(future)
|
||||
@@ -559,9 +559,7 @@ def start(detached: bool = False,
|
||||
max_task_retries=-1,
|
||||
).remote(
|
||||
controller_name,
|
||||
http_host,
|
||||
http_port,
|
||||
http_middlewares,
|
||||
HTTPConfig(http_host, http_port, http_middlewares),
|
||||
detached=detached)
|
||||
|
||||
if http_host is not None:
|
||||
|
||||
@@ -2,8 +2,8 @@ import inspect
|
||||
|
||||
from pydantic import BaseModel, PositiveInt, validator
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
def _callable_accepts_batch(func_or_class):
|
||||
@@ -191,3 +191,10 @@ class ReplicaConfig:
|
||||
raise TypeError(
|
||||
"resources in ray_actor_options must be a dictionary.")
|
||||
self.resource_dict.update(custom_resources)
|
||||
|
||||
|
||||
@dataclass
class HTTPConfig:
    """Bundle of HTTP proxy settings passed around as a single value.

    All three fields are required and are accepted positionally in the
    order ``host, port, middlewares``.
    """

    # Address the HTTP proxies listen on. ``None`` is a legal value: the
    # code that consumes this config tests ``host is None`` and starts no
    # proxies in that case, so the annotation must allow it.
    host: Optional[str]
    # Port the HTTP proxies listen on.
    port: int
    # Middleware objects handed through to the HTTP proxy actors.
    middlewares: List[Any]
|
||||
|
||||
+108
-117
@@ -20,7 +20,7 @@ from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig
|
||||
from ray.serve.long_poll import LongPollHost
|
||||
from ray.actor import ActorHandle
|
||||
|
||||
@@ -80,6 +80,77 @@ class TrafficPolicy:
|
||||
return f"<Traffic {self.traffic_dict}; Shadow {self.shadow_dict}>"
|
||||
|
||||
|
||||
class HTTPState:
    """Owns the HTTP proxy actors, aiming for one proxy per cluster node.

    The node list comes from ``get_all_node_ids()``. Proxies are looked up
    (or created) under a name derived from ``SERVE_PROXY_NAME``, the
    controller name and the node id, and are killed once their node is no
    longer reported.

    Args:
        controller_name: Name of the controller; used to build proxy actor
            names and passed to each proxy.
        detached: When true, new proxies are created with
            ``lifetime="detached"``.
        config: Host, port and middlewares for the proxies. If
            ``config.host`` is ``None`` no proxy is ever started.
    """

    def __init__(self, controller_name: str, detached: bool,
                 config: HTTPConfig):
        self._controller_name = controller_name
        self._detached = detached
        self._config = config
        self._proxy_actors: Dict[NodeId, ActorHandle] = dict()

        # Will populate self._proxy_actors with existing actors.
        self._start_proxies_if_needed()

    def get_config(self) -> HTTPConfig:
        """Return the HTTPConfig this state was constructed with."""
        return self._config

    def get_http_proxy_handles(self) -> Dict[NodeId, ActorHandle]:
        """Return the node-id -> proxy-handle mapping.

        This is the internal dict itself, not a copy.
        """
        return self._proxy_actors

    def update(self) -> None:
        """Reconcile proxies with the current node list (start, then stop)."""
        self._start_proxies_if_needed()
        self._stop_proxies_if_needed()

    def _start_proxies_if_needed(self) -> None:
        """Start a proxy on every node if it doesn't already exist."""
        # A host of None means HTTP is disabled: nothing to start.
        if self._config.host is None:
            return

        for node_id, node_resource in get_all_node_ids():
            if node_id in self._proxy_actors:
                continue

            name = format_actor_name(SERVE_PROXY_NAME, self._controller_name,
                                     node_id)
            try:
                # Reuse an actor already registered under this name.
                proxy = ray.get_actor(name)
            except ValueError:
                # ValueError is treated as "no actor with this name yet".
                logger.info("Starting HTTP proxy with name '{}' on node '{}' "
                            "listening on '{}:{}'".format(
                                name, node_id, self._config.host,
                                self._config.port))
                proxy = HTTPProxyActor.options(
                    name=name,
                    lifetime="detached" if self._detached else None,
                    max_concurrency=ASYNC_CONCURRENCY,
                    max_restarts=-1,
                    max_task_retries=-1,
                    # NOTE(review): presumably a per-node resource key used
                    # to pin the proxy to this node -- confirm against
                    # get_all_node_ids().
                    resources={
                        node_resource: 0.01
                    },
                ).remote(
                    self._config.host,
                    self._config.port,
                    controller_name=self._controller_name,
                    http_middlewares=self._config.middlewares)

            self._proxy_actors[node_id] = proxy

    def _stop_proxies_if_needed(self) -> None:
        """Removes proxy actors from any nodes that no longer exist."""
        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
        # Collect first, then pop: the dict must not change size while it
        # is being iterated.
        to_stop = []
        for node_id in self._proxy_actors:
            if node_id not in all_node_ids:
                logger.info("Removing HTTP proxy on removed node '{}'.".format(
                    node_id))
                to_stop.append(node_id)

        for node_id in to_stop:
            proxy = self._proxy_actors.pop(node_id)
            ray.kill(proxy, no_restart=True)
|
||||
|
||||
|
||||
class BackendInfo(BaseModel):
|
||||
# TODO(architkulkarni): Add type hint for worker_class after upgrading
|
||||
# cloudpickle and adding types to RayServeWrappedReplica
|
||||
@@ -93,6 +164,32 @@ class BackendInfo(BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class BackendState:
    """In-memory table of backends keyed by tag, restorable from a pickle.

    Args:
        checkpoint: Optional bytes previously produced by ``checkpoint()``.
            When given, the backend table is loaded from it; otherwise the
            table starts empty.
    """

    def __init__(self, checkpoint: bytes = None):
        # NOTE(review): pickle.loads runs arbitrary code if the bytes are
        # attacker-controlled; the origin of `checkpoint` is not visible
        # here -- confirm it is only ever trusted, self-written data.
        restored = dict() if checkpoint is None else pickle.loads(checkpoint)
        self.backends: Dict[BackendTag, BackendInfo] = restored

    def checkpoint(self):
        """Serialize the backend table with pickle and return the bytes."""
        snapshot = pickle.dumps(self.backends)
        return snapshot

    def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
        """Return a new dict mapping each tag to its ``backend_config``."""
        configs = {}
        for tag, info in self.backends.items():
            configs[tag] = info.backend_config
        return configs

    def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
        """Return the info stored for ``backend_tag``, or None if absent."""
        info = self.backends.get(backend_tag)
        return info

    def add_backend(self,
                    backend_tag: BackendTag,
                    backend_info: BackendInfo,
                    goal_id: GoalId = 0) -> None:
        """Store ``backend_info`` under ``backend_tag``, replacing any entry.

        ``goal_id`` is accepted but not used by this method.
        """
        self.backends[backend_tag] = backend_info
|
||||
|
||||
|
||||
class EndpointState:
|
||||
def __init__(self, checkpoint: bytes = None):
|
||||
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
|
||||
@@ -124,38 +221,11 @@ class EndpointState:
|
||||
return endpoints
|
||||
|
||||
|
||||
class BackendState:
|
||||
def __init__(self, checkpoint: bytes = None):
|
||||
self.backends: Dict[BackendTag, BackendInfo] = dict()
|
||||
|
||||
if checkpoint is not None:
|
||||
self.backends = pickle.loads(checkpoint)
|
||||
|
||||
def checkpoint(self):
|
||||
return pickle.dumps(self.backends)
|
||||
|
||||
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
|
||||
return {
|
||||
tag: info.backend_config
|
||||
for tag, info in self.backends.items()
|
||||
}
|
||||
|
||||
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
|
||||
return self.backends.get(backend_tag)
|
||||
|
||||
def add_backend(self,
|
||||
backend_tag: BackendTag,
|
||||
backend_info: BackendInfo,
|
||||
goal_id: GoalId = 0) -> None:
|
||||
self.backends[backend_tag] = backend_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorStateReconciler:
|
||||
controller_name: str = field(init=True)
|
||||
detached: bool = field(init=True)
|
||||
|
||||
http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
|
||||
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
|
||||
default_factory=lambda: defaultdict(dict))
|
||||
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
|
||||
@@ -184,9 +254,6 @@ class ActorStateReconciler:
|
||||
|
||||
# TODO(edoakes): consider removing this and just using the names.
|
||||
|
||||
def http_proxy_handles(self) -> List[ActorHandle]:
|
||||
return list(self.http_proxy_cache.values())
|
||||
|
||||
def get_replica_handles(self) -> List[ActorHandle]:
|
||||
return list(
|
||||
chain.from_iterable([
|
||||
@@ -389,70 +456,7 @@ class ActorStateReconciler:
|
||||
|
||||
asyncio.sleep(1)
|
||||
|
||||
def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
|
||||
http_middlewares: List[Any]) -> None:
|
||||
"""Start an HTTP proxy on every node if it doesn't already exist."""
|
||||
if http_host is None:
|
||||
return
|
||||
|
||||
for node_id, node_resource in get_all_node_ids():
|
||||
if node_id in self.http_proxy_cache:
|
||||
continue
|
||||
|
||||
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
|
||||
node_id)
|
||||
try:
|
||||
proxy = ray.get_actor(name)
|
||||
except ValueError:
|
||||
logger.info("Starting HTTP proxy with name '{}' on node '{}' "
|
||||
"listening on '{}:{}'".format(
|
||||
name, node_id, http_host, http_port))
|
||||
proxy = HTTPProxyActor.options(
|
||||
name=name,
|
||||
lifetime="detached" if self.detached else None,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
resources={
|
||||
node_resource: 0.01
|
||||
},
|
||||
).remote(
|
||||
http_host,
|
||||
http_port,
|
||||
controller_name=self.controller_name,
|
||||
http_middlewares=http_middlewares)
|
||||
|
||||
self.http_proxy_cache[node_id] = proxy
|
||||
|
||||
def _stop_http_proxies_if_needed(self) -> bool:
|
||||
"""Removes HTTP proxy actors from any nodes that no longer exist.
|
||||
|
||||
Returns whether or not any actors were removed (a checkpoint should
|
||||
be taken).
|
||||
"""
|
||||
actor_stopped = False
|
||||
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
|
||||
to_stop = []
|
||||
for node_id in self.http_proxy_cache:
|
||||
if node_id not in all_node_ids:
|
||||
logger.info("Removing HTTP proxy on removed node '{}'.".format(
|
||||
node_id))
|
||||
to_stop.append(node_id)
|
||||
|
||||
for node_id in to_stop:
|
||||
proxy = self.http_proxy_cache.pop(node_id)
|
||||
ray.kill(proxy, no_restart=True)
|
||||
actor_stopped = True
|
||||
|
||||
return actor_stopped
|
||||
|
||||
def _recover_actor_handles(self) -> None:
|
||||
# Refresh the RouterCache
|
||||
for node_id in self.http_proxy_cache.keys():
|
||||
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
|
||||
node_id)
|
||||
self.http_proxy_cache[node_id] = ray.get_actor(name)
|
||||
|
||||
# Fetch actor handles for all of the backend replicas in the system.
|
||||
# All of these backend_replicas are guaranteed to already exist because
|
||||
# they would not be written to a checkpoint in self.backend_replicas
|
||||
@@ -526,9 +530,7 @@ class ServeController:
|
||||
|
||||
async def __init__(self,
|
||||
controller_name: str,
|
||||
http_host: str,
|
||||
http_port: str,
|
||||
http_middlewares: List[Any],
|
||||
http_config: HTTPConfig,
|
||||
detached: bool = False):
|
||||
# Used to read/write checkpoints.
|
||||
self.kv_store = RayInternalKVStore(namespace=controller_name)
|
||||
@@ -544,20 +546,14 @@ class ServeController:
|
||||
# at any given time.
|
||||
self.write_lock = asyncio.Lock()
|
||||
|
||||
self.http_host = http_host
|
||||
self.http_port = http_port
|
||||
self.http_middlewares = http_middlewares
|
||||
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self.actor_reconciler._start_http_proxies_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
|
||||
# Map of awaiting results
|
||||
# TODO(ilr): Checkpoint this once this becomes asynchronous
|
||||
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
|
||||
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
|
||||
|
||||
# HTTP state doesn't currently require a checkpoint.
|
||||
self.http_state = HTTPState(controller_name, detached, http_config)
|
||||
|
||||
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
|
||||
if checkpoint_bytes is None:
|
||||
logger.debug("No checkpoint found")
|
||||
@@ -650,9 +646,9 @@ class ServeController:
|
||||
return await (
|
||||
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
|
||||
|
||||
def get_http_proxies(self) -> Dict[str, ActorHandle]:
|
||||
def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
|
||||
"""Returns a dictionary of node ID to http_proxy actor handles."""
|
||||
return self.actor_reconciler.http_proxy_cache
|
||||
return self.http_state.get_http_proxy_handles()
|
||||
|
||||
def _checkpoint(self) -> None:
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
@@ -737,12 +733,7 @@ class ServeController:
|
||||
while True:
|
||||
await self.do_autoscale()
|
||||
async with self.write_lock:
|
||||
self.actor_reconciler._start_http_proxies_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
checkpoint_required = self.actor_reconciler.\
|
||||
_stop_http_proxies_if_needed()
|
||||
if checkpoint_required:
|
||||
self._checkpoint()
|
||||
self.http_state.update()
|
||||
|
||||
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
|
||||
|
||||
@@ -1057,13 +1048,13 @@ class ServeController:
|
||||
|
||||
def get_http_config(self):
|
||||
"""Return the HTTP proxy configuration."""
|
||||
return self.http_host, self.http_port
|
||||
return self.http_state.get_config()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shuts down the serve instance completely."""
|
||||
async with self.write_lock:
|
||||
for http_proxy in self.actor_reconciler.http_proxy_handles():
|
||||
ray.kill(http_proxy, no_restart=True)
|
||||
for proxy in self.http_state.get_http_proxy_handles().values():
|
||||
ray.kill(proxy, no_restart=True)
|
||||
for replica in self.actor_reconciler.get_replica_handles():
|
||||
ray.kill(replica, no_restart=True)
|
||||
self.kv_store.delete(CHECKPOINT_KEY)
|
||||
|
||||
Reference in New Issue
Block a user