From 50be2970dcf329100e9519fd35d17726524f3ef3 Mon Sep 17 00:00:00 2001 From: Gekho457 <62982571+Gekho457@users.noreply.github.com> Date: Fri, 16 Oct 2020 16:45:36 -0400 Subject: [PATCH] [autoscaler]Type hints for commands.py and sdk.py. (#11354) --- ci/travis/format.sh | 10 ++- python/ray/autoscaler/_private/commands.py | 93 ++++++++++++---------- python/ray/autoscaler/sdk.py | 17 ++-- 3 files changed, 68 insertions(+), 52 deletions(-) diff --git a/ci/travis/format.sh b/ci/travis/format.sh index 35add5cbd..c79ff5a69 100755 --- a/ci/travis/format.sh +++ b/ci/travis/format.sh @@ -95,10 +95,14 @@ YAPF_FLAGS=( # should be set to do a more stringent check. MYPY_FLAGS=( '--follow-imports=skip' + '--ignore-missing-imports' ) MYPY_FILES=( - 'python/ray/autoscaler/node_provider.py' + # Relative to python/ray + 'autoscaler/node_provider.py' + 'autoscaler/sdk.py' + 'autoscaler/_private/commands.py' ) YAPF_EXCLUDES=( @@ -126,10 +130,12 @@ shellcheck_scripts() { # Runs mypy on each argument in sequence. This is different than running mypy # once on the list of arguments. mypy_on_each() { + pushd python/ray for file in "$@"; do echo "Running mypy on $file" mypy ${MYPY_FLAGS[@]+"${MYPY_FLAGS[@]}"} "$file" done + popd } @@ -166,8 +172,6 @@ format_files() { if [ 0 -lt "${#python_files[@]}" ]; then yapf --in-place "${YAPF_FLAGS[@]}" -- "${python_files[@]}" - echo "Running mypy on provided python files:" - mypy_on_each "${python_files[@]}" fi if shellcheck --shell=sh --format=diff - < /dev/null; then diff --git a/python/ray/autoscaler/_private/commands.py b/python/ray/autoscaler/_private/commands.py index eed234593..5e5488ff4 100644 --- a/python/ray/autoscaler/_private/commands.py +++ b/python/ray/autoscaler/_private/commands.py @@ -8,9 +8,11 @@ import sys import subprocess import tempfile import time -from typing import Any, Dict, Optional, List +from types import ModuleType +from typing import Any, Dict, List, Optional, Tuple, Union import click +import redis import yaml try: # py3 from shlex import quote @@ -19,6 +21,7 @@ except ImportError: # py2 from ray.experimental.internal_kv import _internal_kv_get import ray._private.services as services +from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler._private.constants import \ AUTOSCALER_RESOURCE_REQUEST_CHANNEL from ray.autoscaler._private.util import validate_config, hash_runtime_conf, \ @@ -33,7 +36,7 @@ from ray.autoscaler._private.updater import NodeUpdaterThread from ray.autoscaler._private.command_runner import set_using_login_shells, \ set_rsync_silent from ray.autoscaler._private.log_timer import LogTimer -from ray.worker import global_worker +from ray.worker import global_worker # type: ignore from ray.util.debug import log_once import ray.autoscaler._private.subprocess_output_util as cmd_output_util @@ -46,8 +49,10 @@ RUN_ENV_TYPES = ["auto", "host", "docker"] POLL_INTERVAL = 5 +Port_forward = Union[Tuple[int, int], List[Tuple[int, int]]] -def _redis(): + +def _redis() -> redis.StrictRedis: global redis_client if redis_client is None: redis_client = services.create_redis_client( @@ -56,19 +61,21 @@ def _redis(): return redis_client -def try_logging_config(config): +def try_logging_config(config: Dict[str, Any]) -> None: if config["provider"]["type"] == "aws": from ray.autoscaler._private.aws.config import log_to_cli log_to_cli(config) -def try_get_log_state(provider_config): +def try_get_log_state(provider_config: Dict[str, Any]) -> Optional[dict]: if provider_config["type"] == "aws": from ray.autoscaler._private.aws.config import get_log_state return get_log_state() + return None -def try_reload_log_state(provider_config, log_state): +def try_reload_log_state(provider_config: Dict[str, Any], + log_state: dict) -> None: if not log_state: return if provider_config["type"] == "aws": @@ -76,7 +83,7 @@ def try_reload_log_state(provider_config, log_state): return reload_log_state(log_state) -def debug_status(): +def debug_status() -> str: """Return a debug string for the autoscaler.""" status = _internal_kv_get(DEBUG_AUTOSCALING_STATUS) error = _internal_kv_get(DEBUG_AUTOSCALING_ERROR) @@ -90,7 +97,8 @@ def debug_status(): return status -def request_resources(num_cpus=None, bundles=None): +def request_resources(num_cpus: Optional[int] = None, + bundles: Optional[List[dict]] = None) -> None: """Remotely request some CPU or GPU resources from the autoscaler. This function is to be called e.g. on a node before submitting a bunch of @@ -122,7 +130,7 @@ def create_or_update_cluster(config_file: str, yes: bool, override_cluster_name: Optional[str] = None, no_config_cache: bool = False, - redirect_command_output: bool = False, + redirect_command_output: Optional[bool] = False, use_login_shells: bool = True) -> None: """Create or updates an autoscaling Ray cluster from a config json.""" set_using_login_shells(use_login_shells) @@ -277,7 +285,7 @@ def _bootstrap_config(config: Dict[str, Any], def teardown_cluster(config_file: str, yes: bool, workers_only: bool, override_cluster_name: Optional[str], - keep_min_workers: bool): + keep_min_workers: bool) -> None: """Destroys all nodes of a Ray cluster described by a config json.""" config = yaml.safe_load(open(config_file).read()) if override_cluster_name is not None: @@ -407,7 +415,8 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool, provider.cleanup() -def kill_node(config_file, yes, hard, override_cluster_name): +def kill_node(config_file: str, yes: bool, hard: bool, + override_cluster_name: Optional[str]) -> str: """Kills a random Raylet worker.""" config = yaml.safe_load(open(config_file).read()) @@ -458,7 +467,8 @@ def kill_node(config_file, yes, hard, override_cluster_name): return node_ip -def monitor_cluster(cluster_config_file, num_lines, override_cluster_name): +def monitor_cluster(cluster_config_file: str, num_lines: int, + override_cluster_name: Optional[str]) -> None: """Tails the autoscaler logs of a Ray cluster.""" cmd = f"tail -n {num_lines} -f /tmp/ray/session_latest/logs/monitor*" exec_cluster( @@ -473,7 +483,7 @@ def monitor_cluster(cluster_config_file, num_lines, override_cluster_name): port_forward=None) -def warn_about_bad_start_command(start_commands): +def warn_about_bad_start_command(start_commands: List[str]) -> None: ray_start_cmd = list(filter(lambda x: "ray start" in x, start_commands)) if len(ray_start_cmd) == 0: cli_logger.warning( @@ -498,14 +508,14 @@ def warn_about_bad_start_command(start_commands): "to ray start in the head_start_ray_commands section.") -def get_or_create_head_node(config, - config_file, - no_restart, - restart_only, - yes, - override_cluster_name, - _provider=None, - _runner=subprocess): +def get_or_create_head_node(config: Dict[str, Any], + config_file: str, + no_restart: bool, + restart_only: bool, + yes: bool, + override_cluster_name: Optional[str], + _provider: Optional[NodeProvider] = None, + _runner: ModuleType = subprocess) -> None: """Create the cluster head node, which in turn creates the workers.""" provider = (_provider or _get_node_provider(config["provider"], config["cluster_name"])) @@ -766,7 +776,7 @@ def attach_cluster(config_file: str, override_cluster_name: Optional[str], no_config_cache: bool = False, new: bool = False, - port_forward: Any = None): + port_forward: Optional[Port_forward] = None) -> None: """Attaches to a screen for the specified cluster. Arguments: @@ -776,7 +786,7 @@ def attach_cluster(config_file: str, use_tmux: whether to use tmux as multiplexer override_cluster_name: set the name of the cluster new: whether to force a new screen - port_forward (int or list[int]): port(s) to forward + port_forward ( (int,int) or list[(int,int)] ): port(s) to forward """ if use_tmux: @@ -819,8 +829,8 @@ def exec_cluster(config_file: str, start: bool = False, override_cluster_name: Optional[str] = None, no_config_cache: bool = False, - port_forward: Any = None, - with_output: bool = False): + port_forward: Optional[Port_forward] = None, + with_output: bool = False) -> str: """Runs a command on the specified cluster. Arguments: @@ -833,7 +843,7 @@ def exec_cluster(config_file: str, stop: whether to stop the cluster after command run start: whether to start the cluster if it isn't up override_cluster_name: set the name of the cluster - port_forward (int or list[int]): port(s) to forward + port_forward ( (int, int) or list[(int, int)] ): port(s) to forward """ assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." assert run_env in RUN_ENV_TYPES, "--run_env must be in {}".format( @@ -906,28 +916,28 @@ def exec_cluster(config_file: str, provider.cleanup() -def _exec(updater, - cmd, - screen, - tmux, - port_forward=None, - with_output=False, - run_env="auto", - shutdown_after_run=False): +def _exec(updater: NodeUpdaterThread, + cmd: Optional[str] = None, + screen: bool = False, + tmux: bool = False, + port_forward: Optional[Port_forward] = None, + with_output: bool = False, + run_env: str = "auto", + shutdown_after_run: bool = False) -> str: if cmd: if screen: - cmd = [ + wrapped_cmd = [ "screen", "-L", "-dm", "bash", "-c", quote(cmd + "; exec bash") ] - cmd = " ".join(cmd) + cmd = " ".join(wrapped_cmd) elif tmux: # TODO: Consider providing named session functionality - cmd = [ + wrapped_cmd = [ "tmux", "new", "-d", "bash", "-c", quote(cmd + "; exec bash") ] - cmd = " ".join(cmd) + cmd = " ".join(wrapped_cmd) return updater.cmd_runner.run( cmd, exit_on_fail=True, @@ -946,7 +956,7 @@ def rsync(config_file: str, use_internal_ip: bool = False, no_config_cache: bool = False, all_nodes: bool = False, - _runner=subprocess): + _runner: ModuleType = subprocess) -> None: """Rsyncs files. Arguments: @@ -1080,7 +1090,8 @@ def get_worker_node_ips(config_file: str, provider.cleanup() -def _get_worker_nodes(config, override_cluster_name): +def _get_worker_nodes(config: Dict[str, Any], + override_cluster_name: Optional[str]) -> List[str]: """Returns worker node ids for given configuration.""" # todo: technically could be reused in get_worker_node_ips if override_cluster_name is not None: @@ -1126,5 +1137,5 @@ def _get_head_node(config: Dict[str, Any], config["cluster_name"])) -def confirm(msg, yes): +def confirm(msg: str, yes: bool) -> Optional[bool]: return None if yes else click.confirm(msg, abort=True) diff --git a/python/ray/autoscaler/sdk.py b/python/ray/autoscaler/sdk.py index 0a48d6533..334f45b3c 100644 --- a/python/ray/autoscaler/sdk.py +++ b/python/ray/autoscaler/sdk.py @@ -1,7 +1,7 @@ """IMPORTANT: this is an experimental interface and not currently stable.""" from contextlib import contextmanager -from typing import Any, Dict, Optional, List, Union +from typing import Any, Dict, Iterator, List, Optional, Union import json import os import tempfile @@ -62,8 +62,8 @@ def run_on_cluster(cluster_config: Union[dict, str], cmd: Optional[str] = None, run_env: str = "auto", no_config_cache: bool = False, - port_forward: Union[int, List[int]] = None, - with_output: bool = False) -> str: + port_forward: Optional[commands.Port_forward] = None, + with_output: bool = False) -> Optional[str]: """Runs a command on the specified cluster. Args: @@ -74,7 +74,7 @@ def run_on_cluster(cluster_config: Union[dict, str], container. Select between "auto", "host" and "docker". no_config_cache (bool): Whether to disable the config cache and fully resolve all environment settings from the Cloud provider again. - port_forward (int or list[int]): port(s) to forward. + port_forward ( (int,int) or list[(int,int)]): port(s) to forward. with_output (bool): Whether to capture command output. Returns: @@ -167,7 +167,8 @@ def get_worker_node_ips(cluster_config: Union[dict, str]) -> List[str]: return commands.get_worker_node_ips(config_file) -def request_resources(num_cpus=None, bundles=None): +def request_resources(num_cpus: Optional[int] = None, + bundles: Optional[List[dict]] = None) -> None: """Remotely request some CPU or GPU resources from the autoscaler. This function is to be called e.g. on a node before submitting a bunch of @@ -185,7 +186,7 @@ def request_resources(num_cpus=None, bundles=None): @contextmanager -def _as_config_file(cluster_config: Union[dict, str]): +def _as_config_file(cluster_config: Union[dict, str]) -> Iterator[str]: if isinstance(cluster_config, dict): tmp = tempfile.NamedTemporaryFile("w", prefix="autoscaler-sdk-tmp-") tmp.write(json.dumps(cluster_config)) @@ -196,8 +197,8 @@ def _as_config_file(cluster_config: Union[dict, str]): yield cluster_config -def bootstrap_config(cluster_config: Dict[str, any], - no_config_cache: bool = False) -> bool: +def bootstrap_config(cluster_config: Dict[str, Any], + no_config_cache: bool = False) -> Dict[str, Any]: """Validate and add provider-specific fields to the config. For example, IAM/authentication may be added here.""" return commands._bootstrap_config(cluster_config, no_config_cache)