[autoscaler]Type hints for commands.py and sdk.py. (#11354)

This commit is contained in:
Gekho457
2020-10-16 16:45:36 -04:00
committed by GitHub
parent 2aec77e305
commit 50be2970dc
3 changed files with 68 additions and 52 deletions
+52 -41
View File
@@ -8,9 +8,11 @@ import sys
import subprocess
import tempfile
import time
from typing import Any, Dict, Optional, List
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple, Union
import click
import redis
import yaml
try: # py3
from shlex import quote
@@ -19,6 +21,7 @@ except ImportError: # py2
from ray.experimental.internal_kv import _internal_kv_get
import ray._private.services as services
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler._private.constants import \
AUTOSCALER_RESOURCE_REQUEST_CHANNEL
from ray.autoscaler._private.util import validate_config, hash_runtime_conf, \
@@ -33,7 +36,7 @@ from ray.autoscaler._private.updater import NodeUpdaterThread
from ray.autoscaler._private.command_runner import set_using_login_shells, \
set_rsync_silent
from ray.autoscaler._private.log_timer import LogTimer
from ray.worker import global_worker
from ray.worker import global_worker # type: ignore
from ray.util.debug import log_once
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
@@ -46,8 +49,10 @@ RUN_ENV_TYPES = ["auto", "host", "docker"]
POLL_INTERVAL = 5
Port_forward = Union[Tuple[int, int], List[Tuple[int, int]]]
def _redis():
def _redis() -> redis.StrictRedis:
global redis_client
if redis_client is None:
redis_client = services.create_redis_client(
@@ -56,19 +61,21 @@ def _redis():
return redis_client
def try_logging_config(config):
def try_logging_config(config: Dict[str, Any]) -> None:
if config["provider"]["type"] == "aws":
from ray.autoscaler._private.aws.config import log_to_cli
log_to_cli(config)
def try_get_log_state(provider_config):
def try_get_log_state(provider_config: Dict[str, Any]) -> Optional[dict]:
if provider_config["type"] == "aws":
from ray.autoscaler._private.aws.config import get_log_state
return get_log_state()
return None
def try_reload_log_state(provider_config, log_state):
def try_reload_log_state(provider_config: Dict[str, Any],
log_state: dict) -> None:
if not log_state:
return
if provider_config["type"] == "aws":
@@ -76,7 +83,7 @@ def try_reload_log_state(provider_config, log_state):
return reload_log_state(log_state)
def debug_status():
def debug_status() -> str:
"""Return a debug string for the autoscaler."""
status = _internal_kv_get(DEBUG_AUTOSCALING_STATUS)
error = _internal_kv_get(DEBUG_AUTOSCALING_ERROR)
@@ -90,7 +97,8 @@ def debug_status():
return status
def request_resources(num_cpus=None, bundles=None):
def request_resources(num_cpus: Optional[int] = None,
bundles: Optional[List[dict]] = None) -> None:
"""Remotely request some CPU or GPU resources from the autoscaler.
This function is to be called e.g. on a node before submitting a bunch of
@@ -122,7 +130,7 @@ def create_or_update_cluster(config_file: str,
yes: bool,
override_cluster_name: Optional[str] = None,
no_config_cache: bool = False,
redirect_command_output: bool = False,
redirect_command_output: Optional[bool] = False,
use_login_shells: bool = True) -> None:
"""Create or updates an autoscaling Ray cluster from a config json."""
set_using_login_shells(use_login_shells)
@@ -277,7 +285,7 @@ def _bootstrap_config(config: Dict[str, Any],
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
override_cluster_name: Optional[str],
keep_min_workers: bool):
keep_min_workers: bool) -> None:
"""Destroys all nodes of a Ray cluster described by a config json."""
config = yaml.safe_load(open(config_file).read())
if override_cluster_name is not None:
@@ -407,7 +415,8 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
provider.cleanup()
def kill_node(config_file, yes, hard, override_cluster_name):
def kill_node(config_file: str, yes: bool, hard: bool,
override_cluster_name: Optional[str]) -> str:
"""Kills a random Raylet worker."""
config = yaml.safe_load(open(config_file).read())
@@ -458,7 +467,8 @@ def kill_node(config_file, yes, hard, override_cluster_name):
return node_ip
def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
def monitor_cluster(cluster_config_file: str, num_lines: int,
override_cluster_name: Optional[str]) -> None:
"""Tails the autoscaler logs of a Ray cluster."""
cmd = f"tail -n {num_lines} -f /tmp/ray/session_latest/logs/monitor*"
exec_cluster(
@@ -473,7 +483,7 @@ def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
port_forward=None)
def warn_about_bad_start_command(start_commands):
def warn_about_bad_start_command(start_commands: List[str]) -> None:
ray_start_cmd = list(filter(lambda x: "ray start" in x, start_commands))
if len(ray_start_cmd) == 0:
cli_logger.warning(
@@ -498,14 +508,14 @@ def warn_about_bad_start_command(start_commands):
"to ray start in the head_start_ray_commands section.")
def get_or_create_head_node(config,
config_file,
no_restart,
restart_only,
yes,
override_cluster_name,
_provider=None,
_runner=subprocess):
def get_or_create_head_node(config: Dict[str, Any],
config_file: str,
no_restart: bool,
restart_only: bool,
yes: bool,
override_cluster_name: Optional[str],
_provider: Optional[NodeProvider] = None,
_runner: ModuleType = subprocess) -> None:
"""Create the cluster head node, which in turn creates the workers."""
provider = (_provider or _get_node_provider(config["provider"],
config["cluster_name"]))
@@ -766,7 +776,7 @@ def attach_cluster(config_file: str,
override_cluster_name: Optional[str],
no_config_cache: bool = False,
new: bool = False,
port_forward: Any = None):
port_forward: Optional[Port_forward] = None) -> None:
"""Attaches to a screen for the specified cluster.
Arguments:
@@ -776,7 +786,7 @@ def attach_cluster(config_file: str,
use_tmux: whether to use tmux as multiplexer
override_cluster_name: set the name of the cluster
new: whether to force a new screen
port_forward (int or list[int]): port(s) to forward
port_forward ( (int,int) or list[(int,int)] ): port(s) to forward
"""
if use_tmux:
@@ -819,8 +829,8 @@ def exec_cluster(config_file: str,
start: bool = False,
override_cluster_name: Optional[str] = None,
no_config_cache: bool = False,
port_forward: Any = None,
with_output: bool = False):
port_forward: Optional[Port_forward] = None,
with_output: bool = False) -> str:
"""Runs a command on the specified cluster.
Arguments:
@@ -833,7 +843,7 @@ def exec_cluster(config_file: str,
stop: whether to stop the cluster after command run
start: whether to start the cluster if it isn't up
override_cluster_name: set the name of the cluster
port_forward (int or list[int]): port(s) to forward
port_forward ( (int, int) or list[(int, int)] ): port(s) to forward
"""
assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
assert run_env in RUN_ENV_TYPES, "--run_env must be in {}".format(
@@ -906,28 +916,28 @@ def exec_cluster(config_file: str,
provider.cleanup()
def _exec(updater,
cmd,
screen,
tmux,
port_forward=None,
with_output=False,
run_env="auto",
shutdown_after_run=False):
def _exec(updater: NodeUpdaterThread,
cmd: Optional[str] = None,
screen: bool = False,
tmux: bool = False,
port_forward: Optional[Port_forward] = None,
with_output: bool = False,
run_env: str = "auto",
shutdown_after_run: bool = False) -> str:
if cmd:
if screen:
cmd = [
wrapped_cmd = [
"screen", "-L", "-dm", "bash", "-c",
quote(cmd + "; exec bash")
]
cmd = " ".join(cmd)
cmd = " ".join(wrapped_cmd)
elif tmux:
# TODO: Consider providing named session functionality
cmd = [
wrapped_cmd = [
"tmux", "new", "-d", "bash", "-c",
quote(cmd + "; exec bash")
]
cmd = " ".join(cmd)
cmd = " ".join(wrapped_cmd)
return updater.cmd_runner.run(
cmd,
exit_on_fail=True,
@@ -946,7 +956,7 @@ def rsync(config_file: str,
use_internal_ip: bool = False,
no_config_cache: bool = False,
all_nodes: bool = False,
_runner=subprocess):
_runner: ModuleType = subprocess) -> None:
"""Rsyncs files.
Arguments:
@@ -1080,7 +1090,8 @@ def get_worker_node_ips(config_file: str,
provider.cleanup()
def _get_worker_nodes(config, override_cluster_name):
def _get_worker_nodes(config: Dict[str, Any],
override_cluster_name: Optional[str]) -> List[str]:
"""Returns worker node ids for given configuration."""
# todo: technically could be reused in get_worker_node_ips
if override_cluster_name is not None:
@@ -1126,5 +1137,5 @@ def _get_head_node(config: Dict[str, Any],
config["cluster_name"]))
def confirm(msg, yes):
def confirm(msg: str, yes: bool) -> Optional[bool]:
return None if yes else click.confirm(msg, abort=True)
+9 -8
View File
@@ -1,7 +1,7 @@
"""IMPORTANT: this is an experimental interface and not currently stable."""
from contextlib import contextmanager
from typing import Any, Dict, Optional, List, Union
from typing import Any, Dict, Iterator, List, Optional, Union
import json
import os
import tempfile
@@ -62,8 +62,8 @@ def run_on_cluster(cluster_config: Union[dict, str],
cmd: Optional[str] = None,
run_env: str = "auto",
no_config_cache: bool = False,
port_forward: Union[int, List[int]] = None,
with_output: bool = False) -> str:
port_forward: Optional[commands.Port_forward] = None,
with_output: bool = False) -> Optional[str]:
"""Runs a command on the specified cluster.
Args:
@@ -74,7 +74,7 @@ def run_on_cluster(cluster_config: Union[dict, str],
container. Select between "auto", "host" and "docker".
no_config_cache (bool): Whether to disable the config cache and fully
resolve all environment settings from the Cloud provider again.
port_forward (int or list[int]): port(s) to forward.
port_forward ( (int,int) or list[(int,int)]): port(s) to forward.
with_output (bool): Whether to capture command output.
Returns:
@@ -167,7 +167,8 @@ def get_worker_node_ips(cluster_config: Union[dict, str]) -> List[str]:
return commands.get_worker_node_ips(config_file)
def request_resources(num_cpus=None, bundles=None):
def request_resources(num_cpus: Optional[int] = None,
bundles: Optional[List[dict]] = None) -> None:
"""Remotely request some CPU or GPU resources from the autoscaler.
This function is to be called e.g. on a node before submitting a bunch of
@@ -185,7 +186,7 @@ def request_resources(num_cpus=None, bundles=None):
@contextmanager
def _as_config_file(cluster_config: Union[dict, str]):
def _as_config_file(cluster_config: Union[dict, str]) -> Iterator[str]:
if isinstance(cluster_config, dict):
tmp = tempfile.NamedTemporaryFile("w", prefix="autoscaler-sdk-tmp-")
tmp.write(json.dumps(cluster_config))
@@ -196,8 +197,8 @@ def _as_config_file(cluster_config: Union[dict, str]):
yield cluster_config
def bootstrap_config(cluster_config: Dict[str, any],
no_config_cache: bool = False) -> bool:
def bootstrap_config(cluster_config: Dict[str, Any],
no_config_cache: bool = False) -> Dict[str, Any]:
"""Validate and add provider-specific fields to the config. For example,
IAM/authentication may be added here."""
return commands._bootstrap_config(cluster_config, no_config_cache)