mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 10:35:55 +08:00
Type annotations added to node_provider.py (#11221)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ray.autoscaler.command_runner import CommandRunnerInterface
|
||||
from ray.autoscaler._private.command_runner import \
|
||||
SSHCommandRunner, DockerCommandRunner
|
||||
|
||||
@@ -24,13 +26,14 @@ class NodeProvider:
|
||||
immediately to terminated when `terminate_node` is called.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
def __init__(self, provider_config: Dict[str, Any],
|
||||
cluster_name: str) -> None:
|
||||
self.provider_config = provider_config
|
||||
self.cluster_name = cluster_name
|
||||
self._internal_ip_cache = {}
|
||||
self._external_ip_cache = {}
|
||||
self._internal_ip_cache: Dict[str, str] = {}
|
||||
self._external_ip_cache: Dict[str, str] = {}
|
||||
|
||||
def non_terminated_nodes(self, tag_filters):
|
||||
def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
|
||||
"""Return a list of node ids filtered by the specified tags dict.
|
||||
|
||||
This list must not include terminated nodes. For performance reasons,
|
||||
@@ -44,27 +47,28 @@ class NodeProvider:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_running(self, node_id):
|
||||
def is_running(self, node_id: str) -> bool:
|
||||
"""Return whether the specified node is running."""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_terminated(self, node_id):
|
||||
def is_terminated(self, node_id: str) -> bool:
|
||||
"""Return whether the specified node is terminated."""
|
||||
raise NotImplementedError
|
||||
|
||||
def node_tags(self, node_id):
|
||||
def node_tags(self, node_id: str) -> Dict[str, str]:
|
||||
"""Returns the tags of the given node (string dict)."""
|
||||
raise NotImplementedError
|
||||
|
||||
def external_ip(self, node_id):
|
||||
def external_ip(self, node_id: str) -> str:
|
||||
"""Returns the external ip of the given node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def internal_ip(self, node_id):
|
||||
def internal_ip(self, node_id: str) -> str:
|
||||
"""Returns the internal ip (Ray ip) of the given node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_node_id(self, ip_address, use_internal_ip=False) -> str:
|
||||
def get_node_id(self, ip_address: str,
|
||||
use_internal_ip: bool = False) -> str:
|
||||
"""Returns the node_id given an IP address.
|
||||
|
||||
Assumes ip-address is unique per node.
|
||||
@@ -105,42 +109,44 @@ class NodeProvider:
|
||||
|
||||
return find_node_id()
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
|
||||
count: int) -> None:
|
||||
"""Creates a number of nodes within the namespace."""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None:
|
||||
"""Sets the tag values (string dict) for the specified node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
def terminate_node(self, node_id: str) -> None:
|
||||
"""Terminates the specified node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def terminate_nodes(self, node_ids):
|
||||
def terminate_nodes(self, node_ids: List[str]) -> None:
|
||||
"""Terminates a set of nodes. May be overridden with a batch method."""
|
||||
for node_id in node_ids:
|
||||
logger.info("NodeProvider: "
|
||||
"{}: Terminating node".format(node_id))
|
||||
self.terminate_node(node_id)
|
||||
|
||||
def cleanup(self):
|
||||
def cleanup(self) -> None:
|
||||
"""Clean-up when a Provider is no longer required."""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
def bootstrap_config(cluster_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Bootstraps the cluster config by adding env defaults if needed."""
|
||||
return cluster_config
|
||||
|
||||
def get_command_runner(self,
|
||||
log_prefix,
|
||||
node_id,
|
||||
auth_config,
|
||||
cluster_name,
|
||||
process_runner,
|
||||
use_internal_ip,
|
||||
docker_config=None) -> Any:
|
||||
log_prefix: str,
|
||||
node_id: str,
|
||||
auth_config: Dict[str, Any],
|
||||
cluster_name: str,
|
||||
process_runner: ModuleType,
|
||||
use_internal_ip: bool,
|
||||
docker_config: Optional[Dict[str, Any]] = None
|
||||
) -> CommandRunnerInterface:
|
||||
"""Returns the CommandRunner class used to perform SSH commands.
|
||||
|
||||
Args:
|
||||
@@ -171,7 +177,8 @@ class NodeProvider:
|
||||
else:
|
||||
return SSHCommandRunner(**common_args)
|
||||
|
||||
def prepare_for_head_node(self, cluster_config):
|
||||
def prepare_for_head_node(
|
||||
self, cluster_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Returns a new cluster config with custom configs for head node."""
|
||||
return cluster_config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user