[serve] Centralize HTTP-related logic in HTTPState (#13020)

This commit is contained in:
Edward Oakes
2020-12-23 18:00:02 -06:00
committed by GitHub
parent 668ea0bc26
commit 3cc213ddf6
3 changed files with 123 additions and 127 deletions
+6 -8
View File
@@ -17,7 +17,8 @@ from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters, logger, get_conda_env_dir)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
HTTPConfig)
from ray.serve.env import CondaEnv
from ray.serve.router import RequestMetadata, Router
from ray.actor import ActorHandle
@@ -93,8 +94,7 @@ class Client:
self._controller_name = controller_name
self._detached = detached
self._shutdown = False
self._http_host, self._http_port = ray.get(
controller.get_http_config.remote())
self._http_config = ray.get(controller.get_http_config.remote())
self._sync_proxied_router = None
self._async_proxied_router = None
@@ -237,8 +237,8 @@ class Client:
num_cpus=0, resources={
node_id: 0.01
}).remote(
"http://{}:{}/-/routes".format(self._http_host,
self._http_port),
"http://{}:{}/-/routes".format(self._http_config.host,
self._http_config.port),
check_ready=check_ready,
timeout=HTTP_PROXY_TIMEOUT)
futures.append(future)
@@ -559,9 +559,7 @@ def start(detached: bool = False,
max_task_retries=-1,
).remote(
controller_name,
http_host,
http_port,
http_middlewares,
HTTPConfig(http_host, http_port, http_middlewares),
detached=detached)
if http_host is not None:
+9 -2
View File
@@ -2,8 +2,8 @@ import inspect
from pydantic import BaseModel, PositiveInt, validator
from ray.serve.constants import ASYNC_CONCURRENCY
from typing import Optional, Dict, Any
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
def _callable_accepts_batch(func_or_class):
@@ -191,3 +191,10 @@ class ReplicaConfig:
raise TypeError(
"resources in ray_actor_options must be a dictionary.")
self.resource_dict.update(custom_resources)
@dataclass
class HTTPConfig:
    """HTTP settings that are passed around as a single unit.

    All three fields are required and are accepted positionally in the
    order ``host, port, middlewares``.
    """

    # May be None: callers compare the host against None before acting on
    # the HTTP settings, so the annotation must allow it.
    host: Optional[str]
    # Port used together with ``host`` to build "host:port" addresses.
    port: int
    # Middleware objects forwarded unchanged to whoever consumes the config.
    middlewares: List[Any]
+108 -117
View File
@@ -20,7 +20,7 @@ from ray.serve.kv_store import RayInternalKVStore
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes, get_all_node_ids)
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig
from ray.serve.long_poll import LongPollHost
from ray.actor import ActorHandle
@@ -80,6 +80,77 @@ class TrafficPolicy:
return f"<Traffic {self.traffic_dict}; Shadow {self.shadow_dict}>"
class HTTPState:
    """Tracks the HTTP proxy actors, keyed by the node they run on.

    On construction and on every ``update()`` call, a proxy is looked up by
    name (or created if no actor with that name exists) for each node
    reported by ``get_all_node_ids()``. ``update()`` additionally kills
    proxies whose node is no longer reported.
    """

    def __init__(self, controller_name: str, detached: bool,
                 config: HTTPConfig):
        """Store the settings and attach to / start the proxies.

        Args:
            controller_name: Used to derive each proxy actor's name and
                passed to newly created proxies.
            detached: If true, new proxies are created with
                ``lifetime="detached"``.
            config: HTTP host, port and middlewares for the proxies.
        """
        self._controller_name = controller_name
        self._detached = detached
        self._config = config
        self._proxy_actors: Dict[NodeId, ActorHandle] = dict()

        # Will populate self._proxy_actors with existing actors.
        self._start_proxies_if_needed()

    def get_config(self) -> HTTPConfig:
        """Return the HTTP config this state was constructed with."""
        return self._config

    def get_http_proxy_handles(self) -> Dict[NodeId, ActorHandle]:
        """Return the live node-ID -> proxy-handle mapping (not a copy)."""
        return self._proxy_actors

    def update(self) -> None:
        """Reconcile proxies with the current set of nodes."""
        self._start_proxies_if_needed()
        self._stop_proxies_if_needed()

    def _start_proxies_if_needed(self) -> None:
        """Start a proxy on every node if it doesn't already exist."""
        # A host of None means no proxies are started at all.
        if self._config.host is None:
            return

        for node_id, node_resource in get_all_node_ids():
            if node_id in self._proxy_actors:
                continue

            name = format_actor_name(SERVE_PROXY_NAME, self._controller_name,
                                     node_id)
            try:
                # Re-attach to an actor that already exists under this name.
                proxy = ray.get_actor(name)
            except ValueError:
                # No actor with that name: create one.
                logger.info("Starting HTTP proxy with name '{}' on node '{}' "
                            "listening on '{}:{}'".format(
                                name, node_id, self._config.host,
                                self._config.port))
                proxy = HTTPProxyActor.options(
                    name=name,
                    lifetime="detached" if self._detached else None,
                    max_concurrency=ASYNC_CONCURRENCY,
                    max_restarts=-1,
                    max_task_retries=-1,
                    # Requests a small amount of the node's own resource
                    # entry as yielded by get_all_node_ids().
                    resources={
                        node_resource: 0.01
                    },
                ).remote(
                    self._config.host,
                    self._config.port,
                    controller_name=self._controller_name,
                    http_middlewares=self._config.middlewares)

            self._proxy_actors[node_id] = proxy

    def _stop_proxies_if_needed(self) -> None:
        """Removes proxy actors from any nodes that no longer exist."""
        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
        # Collect first: the dict must not be mutated while iterating it.
        to_stop = []
        for node_id in self._proxy_actors:
            if node_id not in all_node_ids:
                logger.info("Removing HTTP proxy on removed node '{}'.".format(
                    node_id))
                to_stop.append(node_id)

        for node_id in to_stop:
            proxy = self._proxy_actors.pop(node_id)
            ray.kill(proxy, no_restart=True)
class BackendInfo(BaseModel):
# TODO(architkulkarni): Add type hint for worker_class after upgrading
# cloudpickle and adding types to RayServeWrappedReplica
@@ -93,6 +164,32 @@ class BackendInfo(BaseModel):
arbitrary_types_allowed = True
class BackendState:
    """In-memory table of backends keyed by backend tag.

    The table can be serialized with ``checkpoint()`` and restored by
    passing those bytes back to the constructor.
    """

    def __init__(self, checkpoint: Optional[bytes] = None):
        """Start empty, or restore the table from ``checkpoint`` bytes."""
        self.backends: Dict[BackendTag, BackendInfo] = dict()

        if checkpoint is not None:
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # assumes the bytes were produced by checkpoint() -- confirm
            # that the checkpoint store is trusted.
            self.backends = pickle.loads(checkpoint)

    def checkpoint(self) -> bytes:
        """Return the backend table serialized with pickle."""
        return pickle.dumps(self.backends)

    def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
        """Return a new dict mapping each tag to its ``backend_config``."""
        return {
            tag: info.backend_config
            for tag, info in self.backends.items()
        }

    def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
        """Return the info stored for ``backend_tag``, or None if absent."""
        return self.backends.get(backend_tag)

    def add_backend(self,
                    backend_tag: BackendTag,
                    backend_info: BackendInfo,
                    goal_id: GoalId = 0) -> None:
        """Store ``backend_info`` under ``backend_tag``, replacing any entry.

        ``goal_id`` is accepted but not used by this method.
        """
        self.backends[backend_tag] = backend_info
class EndpointState:
def __init__(self, checkpoint: bytes = None):
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
@@ -124,38 +221,11 @@ class EndpointState:
return endpoints
class BackendState:
    """In-memory table of backends keyed by backend tag.

    The table can be serialized with ``checkpoint()`` and restored by
    passing those bytes back to the constructor.
    """

    def __init__(self, checkpoint: Optional[bytes] = None):
        """Start empty, or restore the table from ``checkpoint`` bytes."""
        self.backends: Dict[BackendTag, BackendInfo] = dict()

        if checkpoint is not None:
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # assumes the bytes were produced by checkpoint() -- confirm
            # that the checkpoint store is trusted.
            self.backends = pickle.loads(checkpoint)

    def checkpoint(self) -> bytes:
        """Return the backend table serialized with pickle."""
        return pickle.dumps(self.backends)

    def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
        """Return a new dict mapping each tag to its ``backend_config``."""
        return {
            tag: info.backend_config
            for tag, info in self.backends.items()
        }

    def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
        """Return the info stored for ``backend_tag``, or None if absent."""
        return self.backends.get(backend_tag)

    def add_backend(self,
                    backend_tag: BackendTag,
                    backend_info: BackendInfo,
                    goal_id: GoalId = 0) -> None:
        """Store ``backend_info`` under ``backend_tag``, replacing any entry.

        ``goal_id`` is accepted but not used by this method.
        """
        self.backends[backend_tag] = backend_info
@dataclass
class ActorStateReconciler:
controller_name: str = field(init=True)
detached: bool = field(init=True)
http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
default_factory=lambda: defaultdict(dict))
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
@@ -184,9 +254,6 @@ class ActorStateReconciler:
# TODO(edoakes): consider removing this and just using the names.
def http_proxy_handles(self) -> List[ActorHandle]:
return list(self.http_proxy_cache.values())
def get_replica_handles(self) -> List[ActorHandle]:
return list(
chain.from_iterable([
@@ -389,70 +456,7 @@ class ActorStateReconciler:
asyncio.sleep(1)
def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
                                  http_middlewares: List[Any]) -> None:
    """Ensure each known node has an HTTP proxy recorded in the cache.

    Does nothing when ``http_host`` is None. For every node not yet in
    ``self.http_proxy_cache``, the proxy actor is fetched by name and,
    if no such actor exists, created.
    """
    if http_host is None:
        return

    for node_id, node_resource in get_all_node_ids():
        if node_id in self.http_proxy_cache:
            continue

        proxy_name = format_actor_name(SERVE_PROXY_NAME,
                                       self.controller_name, node_id)
        try:
            handle = ray.get_actor(proxy_name)
        except ValueError:
            # Lookup by name failed, so the proxy has to be created.
            logger.info("Starting HTTP proxy with name '{}' on node '{}' "
                        "listening on '{}:{}'".format(
                            proxy_name, node_id, http_host, http_port))
            lifetime = "detached" if self.detached else None
            proxy_cls = HTTPProxyActor.options(
                name=proxy_name,
                lifetime=lifetime,
                max_concurrency=ASYNC_CONCURRENCY,
                max_restarts=-1,
                max_task_retries=-1,
                resources={node_resource: 0.01},
            )
            handle = proxy_cls.remote(
                http_host,
                http_port,
                controller_name=self.controller_name,
                http_middlewares=http_middlewares)

        self.http_proxy_cache[node_id] = handle
def _stop_http_proxies_if_needed(self) -> bool:
    """Kill cached HTTP proxies whose node is no longer reported.

    Returns:
        True if at least one proxy was removed (the caller should then
        take a checkpoint), False otherwise.
    """
    live_nodes = {node_id for node_id, _ in get_all_node_ids()}

    # Gather the stale entries before touching the cache, since it cannot
    # be mutated while it is being iterated.
    stale_nodes = [
        node_id for node_id in self.http_proxy_cache
        if node_id not in live_nodes
    ]
    for node_id in stale_nodes:
        logger.info("Removing HTTP proxy on removed node '{}'.".format(
            node_id))

    for node_id in stale_nodes:
        handle = self.http_proxy_cache.pop(node_id)
        ray.kill(handle, no_restart=True)

    return bool(stale_nodes)
def _recover_actor_handles(self) -> None:
# Refresh the RouterCache
for node_id in self.http_proxy_cache.keys():
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
node_id)
self.http_proxy_cache[node_id] = ray.get_actor(name)
# Fetch actor handles for all of the backend replicas in the system.
# All of these backend_replicas are guaranteed to already exist because
# they would not be written to a checkpoint in self.backend_replicas
@@ -526,9 +530,7 @@ class ServeController:
async def __init__(self,
controller_name: str,
http_host: str,
http_port: str,
http_middlewares: List[Any],
http_config: HTTPConfig,
detached: bool = False):
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name)
@@ -544,20 +546,14 @@ class ServeController:
# at any given time.
self.write_lock = asyncio.Lock()
self.http_host = http_host
self.http_port = http_port
self.http_middlewares = http_middlewares
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self.actor_reconciler._start_http_proxies_if_needed(
self.http_host, self.http_port, self.http_middlewares)
# Map of awaiting results
# TODO(ilr): Checkpoint this once this becomes asynchronous
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
# HTTP state doesn't currently require a checkpoint.
self.http_state = HTTPState(controller_name, detached, http_config)
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
@@ -650,9 +646,9 @@ class ServeController:
return await (
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
def get_http_proxies(self) -> Dict[str, ActorHandle]:
def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
"""Returns a dictionary of node ID to http_proxy actor handles."""
return self.actor_reconciler.http_proxy_cache
return self.http_state.get_http_proxy_handles()
def _checkpoint(self) -> None:
"""Checkpoint internal state and write it to the KV store."""
@@ -737,12 +733,7 @@ class ServeController:
while True:
await self.do_autoscale()
async with self.write_lock:
self.actor_reconciler._start_http_proxies_if_needed(
self.http_host, self.http_port, self.http_middlewares)
checkpoint_required = self.actor_reconciler.\
_stop_http_proxies_if_needed()
if checkpoint_required:
self._checkpoint()
self.http_state.update()
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
@@ -1057,13 +1048,13 @@ class ServeController:
def get_http_config(self):
"""Return the HTTP proxy configuration."""
return self.http_host, self.http_port
return self.http_state.get_config()
async def shutdown(self) -> None:
"""Shuts down the serve instance completely."""
async with self.write_lock:
for http_proxy in self.actor_reconciler.http_proxy_handles():
ray.kill(http_proxy, no_restart=True)
for proxy in self.http_state.get_http_proxy_handles().values():
ray.kill(proxy, no_restart=True)
for replica in self.actor_reconciler.get_replica_handles():
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)