diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index dd287a743..008602ef6 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -1,6 +1,8 @@ import logging -from typing import Any, Dict +from types import ModuleType +from typing import Any, Dict, List, Optional +from ray.autoscaler.command_runner import CommandRunnerInterface from ray.autoscaler._private.command_runner import \ SSHCommandRunner, DockerCommandRunner @@ -24,13 +26,14 @@ class NodeProvider: immediately to terminated when `terminate_node` is called. """ - def __init__(self, provider_config, cluster_name): + def __init__(self, provider_config: Dict[str, Any], + cluster_name: str) -> None: self.provider_config = provider_config self.cluster_name = cluster_name - self._internal_ip_cache = {} - self._external_ip_cache = {} + self._internal_ip_cache: Dict[str, str] = {} + self._external_ip_cache: Dict[str, str] = {} - def non_terminated_nodes(self, tag_filters): + def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]: """Return a list of node ids filtered by the specified tags dict. This list must not include terminated nodes. For performance reasons, @@ -44,27 +47,28 @@ class NodeProvider: """ raise NotImplementedError - def is_running(self, node_id): + def is_running(self, node_id: str) -> bool: """Return whether the specified node is running.""" raise NotImplementedError - def is_terminated(self, node_id): + def is_terminated(self, node_id: str) -> bool: """Return whether the specified node is terminated.""" raise NotImplementedError - def node_tags(self, node_id): + def node_tags(self, node_id: str) -> Dict[str, str]: """Returns the tags of the given node (string dict).""" raise NotImplementedError - def external_ip(self, node_id): + def external_ip(self, node_id: str) -> str: """Returns the external ip of the given node.""" raise NotImplementedError - def internal_ip(self, node_id): + def internal_ip(self, node_id: str) -> str: """Returns the internal ip (Ray ip) of the given node.""" raise NotImplementedError - def get_node_id(self, ip_address, use_internal_ip=False) -> str: + def get_node_id(self, ip_address: str, + use_internal_ip: bool = False) -> str: """Returns the node_id given an IP address. Assumes ip-address is unique per node. @@ -105,42 +109,44 @@ class NodeProvider: return find_node_id() - def create_node(self, node_config, tags, count): + def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], + count: int) -> None: """Creates a number of nodes within the namespace.""" raise NotImplementedError - def set_node_tags(self, node_id, tags): + def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None: """Sets the tag values (string dict) for the specified node.""" raise NotImplementedError - def terminate_node(self, node_id): + def terminate_node(self, node_id: str) -> None: """Terminates the specified node.""" raise NotImplementedError - def terminate_nodes(self, node_ids): + def terminate_nodes(self, node_ids: List[str]) -> None: """Terminates a set of nodes. May be overridden with a batch method.""" for node_id in node_ids: logger.info("NodeProvider: " "{}: Terminating node".format(node_id)) self.terminate_node(node_id) - def cleanup(self): + def cleanup(self) -> None: """Clean-up when a Provider is no longer required.""" pass @staticmethod - def bootstrap_config(cluster_config): + def bootstrap_config(cluster_config: Dict[str, Any]) -> Dict[str, Any]: """Bootstraps the cluster config by adding env defaults if needed.""" return cluster_config def get_command_runner(self, - log_prefix, - node_id, - auth_config, - cluster_name, - process_runner, - use_internal_ip, - docker_config=None) -> Any: + log_prefix: str, + node_id: str, + auth_config: Dict[str, Any], + cluster_name: str, + process_runner: ModuleType, + use_internal_ip: bool, + docker_config: Optional[Dict[str, Any]] = None + ) -> CommandRunnerInterface: """Returns the CommandRunner class used to perform SSH commands. Args: @@ -171,7 +177,8 @@ class NodeProvider: else: return SSHCommandRunner(**common_args) - def prepare_for_head_node(self, cluster_config): + def prepare_for_head_node( + self, cluster_config: Dict[str, Any]) -> Dict[str, Any]: """Returns a new cluster config with custom configs for head node.""" return cluster_config