Type annotations added to node_provider.py (#11221)

This commit is contained in:
Gekho457
2020-10-06 01:03:04 -04:00
committed by GitHub
parent 8ec044f1f5
commit 66e265fdb9
+32 -25
View File
@@ -1,6 +1,8 @@
import logging
from typing import Any, Dict
from types import ModuleType
from typing import Any, Dict, List, Optional
from ray.autoscaler.command_runner import CommandRunnerInterface
from ray.autoscaler._private.command_runner import \
SSHCommandRunner, DockerCommandRunner
@@ -24,13 +26,14 @@ class NodeProvider:
immediately to terminated when `terminate_node` is called.
"""
def __init__(self, provider_config, cluster_name):
def __init__(self, provider_config: Dict[str, Any],
cluster_name: str) -> None:
self.provider_config = provider_config
self.cluster_name = cluster_name
self._internal_ip_cache = {}
self._external_ip_cache = {}
self._internal_ip_cache: Dict[str, str] = {}
self._external_ip_cache: Dict[str, str] = {}
def non_terminated_nodes(self, tag_filters):
def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
"""Return a list of node ids filtered by the specified tags dict.
This list must not include terminated nodes. For performance reasons,
@@ -44,27 +47,28 @@ class NodeProvider:
"""
raise NotImplementedError
def is_running(self, node_id):
def is_running(self, node_id: str) -> bool:
"""Return whether the specified node is running."""
raise NotImplementedError
def is_terminated(self, node_id):
def is_terminated(self, node_id: str) -> bool:
"""Return whether the specified node is terminated."""
raise NotImplementedError
def node_tags(self, node_id):
def node_tags(self, node_id: str) -> Dict[str, str]:
"""Returns the tags of the given node (string dict)."""
raise NotImplementedError
def external_ip(self, node_id):
def external_ip(self, node_id: str) -> str:
"""Returns the external ip of the given node."""
raise NotImplementedError
def internal_ip(self, node_id):
def internal_ip(self, node_id: str) -> str:
"""Returns the internal ip (Ray ip) of the given node."""
raise NotImplementedError
def get_node_id(self, ip_address, use_internal_ip=False) -> str:
def get_node_id(self, ip_address: str,
use_internal_ip: bool = False) -> str:
"""Returns the node_id given an IP address.
Assumes ip-address is unique per node.
@@ -105,42 +109,44 @@ class NodeProvider:
return find_node_id()
def create_node(self, node_config, tags, count):
def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
count: int) -> None:
"""Creates a number of nodes within the namespace."""
raise NotImplementedError
def set_node_tags(self, node_id, tags):
def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None:
"""Sets the tag values (string dict) for the specified node."""
raise NotImplementedError
def terminate_node(self, node_id):
def terminate_node(self, node_id: str) -> None:
"""Terminates the specified node."""
raise NotImplementedError
def terminate_nodes(self, node_ids):
def terminate_nodes(self, node_ids: List[str]) -> None:
"""Terminates a set of nodes. May be overridden with a batch method."""
for node_id in node_ids:
logger.info("NodeProvider: "
"{}: Terminating node".format(node_id))
self.terminate_node(node_id)
def cleanup(self):
def cleanup(self) -> None:
"""Clean-up when a Provider is no longer required."""
pass
@staticmethod
def bootstrap_config(cluster_config):
def bootstrap_config(cluster_config: Dict[str, Any]) -> Dict[str, Any]:
"""Bootstraps the cluster config by adding env defaults if needed."""
return cluster_config
def get_command_runner(self,
log_prefix,
node_id,
auth_config,
cluster_name,
process_runner,
use_internal_ip,
docker_config=None) -> Any:
log_prefix: str,
node_id: str,
auth_config: Dict[str, Any],
cluster_name: str,
process_runner: ModuleType,
use_internal_ip: bool,
docker_config: Optional[Dict[str, Any]] = None
) -> CommandRunnerInterface:
"""Returns the CommandRunner class used to perform SSH commands.
Args:
@@ -171,7 +177,8 @@ class NodeProvider:
else:
return SSHCommandRunner(**common_args)
def prepare_for_head_node(self, cluster_config):
def prepare_for_head_node(
self, cluster_config: Dict[str, Any]) -> Dict[str, Any]:
"""Returns a new cluster config with custom configs for head node."""
return cluster_config