Refactor Ray Serve HTTP proxy state (reviewed copy of the patch).

What this patch does:
  * config.py: adds an HTTPConfig dataclass (host, port, middlewares).
  * api.py: Client keeps the HTTPConfig returned by the controller, and
    start() passes one HTTPConfig to the controller instead of three
    separate arguments.
  * controller.py: adds HTTPState, which owns the per-node HTTP proxy actor
    handles; removes the equivalent code from ActorStateReconciler and
    ServeController; moves BackendState above EndpointState unchanged.

Review fixes (in-place edits on added lines only; no diff line was added or
removed, so the hunk header counts are unchanged):
  * HTTPState._stop_proxies_if_needed was annotated "-> bool" but returns
    nothing; it is now "-> None".
  * The comment in HTTPState.__init__ named "self.proxy_actors"; the
    attribute is "self._proxy_actors".
  * Added return annotations to HTTPState.get_config and HTTPState.update.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch; blank context lines were inferred
from the hunk counts. Confirm against the original before applying.

diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py
index 3e4b53b28..5e8de83e7 100644
--- a/python/ray/serve/api.py
+++ b/python/ray/serve/api.py
@@ -17,7 +17,8 @@ from ray.serve.handle import RayServeHandle, RayServeSyncHandle
 from ray.serve.utils import (block_until_http_ready, format_actor_name,
                              get_random_letters, logger, get_conda_env_dir)
 from ray.serve.exceptions import RayServeException
-from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
+from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
+                              HTTPConfig)
 from ray.serve.env import CondaEnv
 from ray.serve.router import RequestMetadata, Router
 from ray.actor import ActorHandle
@@ -93,8 +94,7 @@ class Client:
         self._controller_name = controller_name
         self._detached = detached
         self._shutdown = False
-        self._http_host, self._http_port = ray.get(
-            controller.get_http_config.remote())
+        self._http_config = ray.get(controller.get_http_config.remote())
         self._sync_proxied_router = None
         self._async_proxied_router = None
 
@@ -237,8 +237,8 @@ class Client:
                     num_cpus=0, resources={
                         node_id: 0.01
                     }).remote(
-                        "http://{}:{}/-/routes".format(self._http_host,
-                                                       self._http_port),
+                        "http://{}:{}/-/routes".format(self._http_config.host,
+                                                       self._http_config.port),
                         check_ready=check_ready,
                         timeout=HTTP_PROXY_TIMEOUT)
                 futures.append(future)
@@ -559,9 +559,7 @@ def start(detached: bool = False,
         max_task_retries=-1,
     ).remote(
         controller_name,
-        http_host,
-        http_port,
-        http_middlewares,
+        HTTPConfig(http_host, http_port, http_middlewares),
         detached=detached)
 
     if http_host is not None:
diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py
index 0a8070d9e..104d6da7c 100644
--- a/python/ray/serve/config.py
+++ b/python/ray/serve/config.py
@@ -2,8 +2,8 @@ import inspect
 
 from pydantic import BaseModel, PositiveInt, validator
 from ray.serve.constants import ASYNC_CONCURRENCY
-from typing import Optional, Dict, Any
-from dataclasses import dataclass
+from typing import Optional, Dict, Any, List
+from dataclasses import dataclass, field
 
 
 def _callable_accepts_batch(func_or_class):
@@ -191,3 +191,10 @@ class ReplicaConfig:
                 raise TypeError(
                     "resources in ray_actor_options must be a dictionary.")
             self.resource_dict.update(custom_resources)
+
+
+@dataclass
+class HTTPConfig:
+    host: str = field(init=True)
+    port: int = field(init=True)
+    middlewares: List[Any] = field(init=True)
diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py
index 4a4b754ff..62ccb87b2 100644
--- a/python/ray/serve/controller.py
+++ b/python/ray/serve/controller.py
@@ -20,7 +20,7 @@ from ray.serve.kv_store import RayInternalKVStore
 from ray.serve.exceptions import RayServeException
 from ray.serve.utils import (format_actor_name, get_random_letters, logger,
                              try_schedule_resources_on_nodes, get_all_node_ids)
-from ray.serve.config import BackendConfig, ReplicaConfig
+from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig
 from ray.serve.long_poll import LongPollHost
 from ray.actor import ActorHandle
 
@@ -80,6 +80,77 @@ class TrafficPolicy:
         return f""
 
 
+class HTTPState:
+    def __init__(self, controller_name: str, detached: bool,
+                 config: HTTPConfig):
+        self._controller_name = controller_name
+        self._detached = detached
+        self._config = config
+        self._proxy_actors: Dict[NodeId, ActorHandle] = dict()
+
+        # Will populate self._proxy_actors with existing actors.
+        self._start_proxies_if_needed()
+
+    def get_config(self) -> HTTPConfig:
+        return self._config
+
+    def get_http_proxy_handles(self) -> Dict[NodeId, ActorHandle]:
+        return self._proxy_actors
+
+    def update(self) -> None:
+        self._start_proxies_if_needed()
+        self._stop_proxies_if_needed()
+
+    def _start_proxies_if_needed(self) -> None:
+        """Start a proxy on every node if it doesn't already exist."""
+        if self._config.host is None:
+            return
+
+        for node_id, node_resource in get_all_node_ids():
+            if node_id in self._proxy_actors:
+                continue
+
+            name = format_actor_name(SERVE_PROXY_NAME, self._controller_name,
+                                     node_id)
+            try:
+                proxy = ray.get_actor(name)
+            except ValueError:
+                logger.info("Starting HTTP proxy with name '{}' on node '{}' "
+                            "listening on '{}:{}'".format(
+                                name, node_id, self._config.host,
+                                self._config.port))
+                proxy = HTTPProxyActor.options(
+                    name=name,
+                    lifetime="detached" if self._detached else None,
+                    max_concurrency=ASYNC_CONCURRENCY,
+                    max_restarts=-1,
+                    max_task_retries=-1,
+                    resources={
+                        node_resource: 0.01
+                    },
+                ).remote(
+                    self._config.host,
+                    self._config.port,
+                    controller_name=self._controller_name,
+                    http_middlewares=self._config.middlewares)
+
+            self._proxy_actors[node_id] = proxy
+
+    def _stop_proxies_if_needed(self) -> None:
+        """Removes proxy actors from any nodes that no longer exist."""
+        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
+        to_stop = []
+        for node_id in self._proxy_actors:
+            if node_id not in all_node_ids:
+                logger.info("Removing HTTP proxy on removed node '{}'.".format(
+                    node_id))
+                to_stop.append(node_id)
+
+        for node_id in to_stop:
+            proxy = self._proxy_actors.pop(node_id)
+            ray.kill(proxy, no_restart=True)
+
+
 class BackendInfo(BaseModel):
     # TODO(architkulkarni): Add type hint for worker_class after upgrading
     # cloudpickle and adding types to RayServeWrappedReplica
@@ -93,6 +164,32 @@ class BackendInfo(BaseModel):
         arbitrary_types_allowed = True
 
 
+class BackendState:
+    def __init__(self, checkpoint: bytes = None):
+        self.backends: Dict[BackendTag, BackendInfo] = dict()
+
+        if checkpoint is not None:
+            self.backends = pickle.loads(checkpoint)
+
+    def checkpoint(self):
+        return pickle.dumps(self.backends)
+
+    def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
+        return {
+            tag: info.backend_config
+            for tag, info in self.backends.items()
+        }
+
+    def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
+        return self.backends.get(backend_tag)
+
+    def add_backend(self,
+                    backend_tag: BackendTag,
+                    backend_info: BackendInfo,
+                    goal_id: GoalId = 0) -> None:
+        self.backends[backend_tag] = backend_info
+
+
 class EndpointState:
     def __init__(self, checkpoint: bytes = None):
         self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
@@ -124,38 +221,11 @@ class EndpointState:
         return endpoints
 
 
-class BackendState:
-    def __init__(self, checkpoint: bytes = None):
-        self.backends: Dict[BackendTag, BackendInfo] = dict()
-
-        if checkpoint is not None:
-            self.backends = pickle.loads(checkpoint)
-
-    def checkpoint(self):
-        return pickle.dumps(self.backends)
-
-    def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
-        return {
-            tag: info.backend_config
-            for tag, info in self.backends.items()
-        }
-
-    def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
-        return self.backends.get(backend_tag)
-
-    def add_backend(self,
-                    backend_tag: BackendTag,
-                    backend_info: BackendInfo,
-                    goal_id: GoalId = 0) -> None:
-        self.backends[backend_tag] = backend_info
-
-
 @dataclass
 class ActorStateReconciler:
     controller_name: str = field(init=True)
     detached: bool = field(init=True)
 
-    http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
     backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
         default_factory=lambda: defaultdict(dict))
     backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
@@ -184,9 +254,6 @@ class ActorStateReconciler:
 
     # TODO(edoakes): consider removing this and just using the names.
 
-    def http_proxy_handles(self) -> List[ActorHandle]:
-        return list(self.http_proxy_cache.values())
-
     def get_replica_handles(self) -> List[ActorHandle]:
         return list(
             chain.from_iterable([
@@ -389,70 +456,7 @@ class ActorStateReconciler:
 
             asyncio.sleep(1)
 
-    def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
-                                      http_middlewares: List[Any]) -> None:
-        """Start an HTTP proxy on every node if it doesn't already exist."""
-        if http_host is None:
-            return
-
-        for node_id, node_resource in get_all_node_ids():
-            if node_id in self.http_proxy_cache:
-                continue
-
-            name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
-                                     node_id)
-            try:
-                proxy = ray.get_actor(name)
-            except ValueError:
-                logger.info("Starting HTTP proxy with name '{}' on node '{}' "
-                            "listening on '{}:{}'".format(
-                                name, node_id, http_host, http_port))
-                proxy = HTTPProxyActor.options(
-                    name=name,
-                    lifetime="detached" if self.detached else None,
-                    max_concurrency=ASYNC_CONCURRENCY,
-                    max_restarts=-1,
-                    max_task_retries=-1,
-                    resources={
-                        node_resource: 0.01
-                    },
-                ).remote(
-                    http_host,
-                    http_port,
-                    controller_name=self.controller_name,
-                    http_middlewares=http_middlewares)
-
-            self.http_proxy_cache[node_id] = proxy
-
-    def _stop_http_proxies_if_needed(self) -> bool:
-        """Removes HTTP proxy actors from any nodes that no longer exist.
-
-        Returns whether or not any actors were removed (a checkpoint should
-        be taken).
-        """
-        actor_stopped = False
-        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
-        to_stop = []
-        for node_id in self.http_proxy_cache:
-            if node_id not in all_node_ids:
-                logger.info("Removing HTTP proxy on removed node '{}'.".format(
-                    node_id))
-                to_stop.append(node_id)
-
-        for node_id in to_stop:
-            proxy = self.http_proxy_cache.pop(node_id)
-            ray.kill(proxy, no_restart=True)
-            actor_stopped = True
-
-        return actor_stopped
-
     def _recover_actor_handles(self) -> None:
-        # Refresh the RouterCache
-        for node_id in self.http_proxy_cache.keys():
-            name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
-                                     node_id)
-            self.http_proxy_cache[node_id] = ray.get_actor(name)
-
         # Fetch actor handles for all of the backend replicas in the system.
         # All of these backend_replicas are guaranteed to already exist because
         # they would not be written to a checkpoint in self.backend_replicas
@@ -526,9 +530,7 @@ class ServeController:
 
     async def __init__(self,
                        controller_name: str,
-                       http_host: str,
-                       http_port: str,
-                       http_middlewares: List[Any],
+                       http_config: HTTPConfig,
                        detached: bool = False):
         # Used to read/write checkpoints.
         self.kv_store = RayInternalKVStore(namespace=controller_name)
@@ -544,20 +546,14 @@ class ServeController:
         # at any given time.
         self.write_lock = asyncio.Lock()
 
-        self.http_host = http_host
-        self.http_port = http_port
-        self.http_middlewares = http_middlewares
-
-        # If starting the actor for the first time, starts up the other system
-        # components. If recovering, fetches their actor handles.
-        self.actor_reconciler._start_http_proxies_if_needed(
-            self.http_host, self.http_port, self.http_middlewares)
-
         # Map of awaiting results
         # TODO(ilr): Checkpoint this once this becomes asynchronous
         self.inflight_results: Dict[UUID, asyncio.Event] = dict()
         self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
 
+        # HTTP state doesn't currently require a checkpoint.
+        self.http_state = HTTPState(controller_name, detached, http_config)
+
         checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
         if checkpoint_bytes is None:
             logger.debug("No checkpoint found")
@@ -650,9 +646,9 @@ class ServeController:
         return await (
             self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
 
-    def get_http_proxies(self) -> Dict[str, ActorHandle]:
+    def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
         """Returns a dictionary of node ID to http_proxy actor handles."""
-        return self.actor_reconciler.http_proxy_cache
+        return self.http_state.get_http_proxy_handles()
 
     def _checkpoint(self) -> None:
         """Checkpoint internal state and write it to the KV store."""
@@ -737,12 +733,7 @@ class ServeController:
         while True:
             await self.do_autoscale()
             async with self.write_lock:
-                self.actor_reconciler._start_http_proxies_if_needed(
-                    self.http_host, self.http_port, self.http_middlewares)
-                checkpoint_required = self.actor_reconciler.\
-                    _stop_http_proxies_if_needed()
-                if checkpoint_required:
-                    self._checkpoint()
+                self.http_state.update()
 
             await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
 
@@ -1057,13 +1048,13 @@ class ServeController:
 
     def get_http_config(self):
         """Return the HTTP proxy configuration."""
-        return self.http_host, self.http_port
+        return self.http_state.get_config()
 
     async def shutdown(self) -> None:
         """Shuts down the serve instance completely."""
         async with self.write_lock:
-            for http_proxy in self.actor_reconciler.http_proxy_handles():
-                ray.kill(http_proxy, no_restart=True)
+            for proxy in self.http_state.get_http_proxy_handles().values():
+                ray.kill(proxy, no_restart=True)
             for replica in self.actor_reconciler.get_replica_handles():
                 ray.kill(replica, no_restart=True)
             self.kv_store.delete(CHECKPOINT_KEY)