mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 01:27:43 +08:00
[autoscaler]Type hints for commands.py and sdk.py. (#11354)
This commit is contained in:
@@ -8,9 +8,11 @@ import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, Optional, List
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import click
|
||||
import redis
|
||||
import yaml
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
@@ -19,6 +21,7 @@ except ImportError: # py2
|
||||
|
||||
from ray.experimental.internal_kv import _internal_kv_get
|
||||
import ray._private.services as services
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler._private.constants import \
|
||||
AUTOSCALER_RESOURCE_REQUEST_CHANNEL
|
||||
from ray.autoscaler._private.util import validate_config, hash_runtime_conf, \
|
||||
@@ -33,7 +36,7 @@ from ray.autoscaler._private.updater import NodeUpdaterThread
|
||||
from ray.autoscaler._private.command_runner import set_using_login_shells, \
|
||||
set_rsync_silent
|
||||
from ray.autoscaler._private.log_timer import LogTimer
|
||||
from ray.worker import global_worker
|
||||
from ray.worker import global_worker # type: ignore
|
||||
from ray.util.debug import log_once
|
||||
|
||||
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
|
||||
@@ -46,8 +49,10 @@ RUN_ENV_TYPES = ["auto", "host", "docker"]
|
||||
|
||||
POLL_INTERVAL = 5
|
||||
|
||||
Port_forward = Union[Tuple[int, int], List[Tuple[int, int]]]
|
||||
|
||||
def _redis():
|
||||
|
||||
def _redis() -> redis.StrictRedis:
|
||||
global redis_client
|
||||
if redis_client is None:
|
||||
redis_client = services.create_redis_client(
|
||||
@@ -56,19 +61,21 @@ def _redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
def try_logging_config(config):
|
||||
def try_logging_config(config: Dict[str, Any]) -> None:
|
||||
if config["provider"]["type"] == "aws":
|
||||
from ray.autoscaler._private.aws.config import log_to_cli
|
||||
log_to_cli(config)
|
||||
|
||||
|
||||
def try_get_log_state(provider_config):
|
||||
def try_get_log_state(provider_config: Dict[str, Any]) -> Optional[dict]:
|
||||
if provider_config["type"] == "aws":
|
||||
from ray.autoscaler._private.aws.config import get_log_state
|
||||
return get_log_state()
|
||||
return None
|
||||
|
||||
|
||||
def try_reload_log_state(provider_config, log_state):
|
||||
def try_reload_log_state(provider_config: Dict[str, Any],
|
||||
log_state: dict) -> None:
|
||||
if not log_state:
|
||||
return
|
||||
if provider_config["type"] == "aws":
|
||||
@@ -76,7 +83,7 @@ def try_reload_log_state(provider_config, log_state):
|
||||
return reload_log_state(log_state)
|
||||
|
||||
|
||||
def debug_status():
|
||||
def debug_status() -> str:
|
||||
"""Return a debug string for the autoscaler."""
|
||||
status = _internal_kv_get(DEBUG_AUTOSCALING_STATUS)
|
||||
error = _internal_kv_get(DEBUG_AUTOSCALING_ERROR)
|
||||
@@ -90,7 +97,8 @@ def debug_status():
|
||||
return status
|
||||
|
||||
|
||||
def request_resources(num_cpus=None, bundles=None):
|
||||
def request_resources(num_cpus: Optional[int] = None,
|
||||
bundles: Optional[List[dict]] = None) -> None:
|
||||
"""Remotely request some CPU or GPU resources from the autoscaler.
|
||||
|
||||
This function is to be called e.g. on a node before submitting a bunch of
|
||||
@@ -122,7 +130,7 @@ def create_or_update_cluster(config_file: str,
|
||||
yes: bool,
|
||||
override_cluster_name: Optional[str] = None,
|
||||
no_config_cache: bool = False,
|
||||
redirect_command_output: bool = False,
|
||||
redirect_command_output: Optional[bool] = False,
|
||||
use_login_shells: bool = True) -> None:
|
||||
"""Create or updates an autoscaling Ray cluster from a config json."""
|
||||
set_using_login_shells(use_login_shells)
|
||||
@@ -277,7 +285,7 @@ def _bootstrap_config(config: Dict[str, Any],
|
||||
|
||||
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
||||
override_cluster_name: Optional[str],
|
||||
keep_min_workers: bool):
|
||||
keep_min_workers: bool) -> None:
|
||||
"""Destroys all nodes of a Ray cluster described by a config json."""
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
if override_cluster_name is not None:
|
||||
@@ -407,7 +415,8 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def kill_node(config_file, yes, hard, override_cluster_name):
|
||||
def kill_node(config_file: str, yes: bool, hard: bool,
|
||||
override_cluster_name: Optional[str]) -> str:
|
||||
"""Kills a random Raylet worker."""
|
||||
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
@@ -458,7 +467,8 @@ def kill_node(config_file, yes, hard, override_cluster_name):
|
||||
return node_ip
|
||||
|
||||
|
||||
def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
|
||||
def monitor_cluster(cluster_config_file: str, num_lines: int,
|
||||
override_cluster_name: Optional[str]) -> None:
|
||||
"""Tails the autoscaler logs of a Ray cluster."""
|
||||
cmd = f"tail -n {num_lines} -f /tmp/ray/session_latest/logs/monitor*"
|
||||
exec_cluster(
|
||||
@@ -473,7 +483,7 @@ def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
|
||||
port_forward=None)
|
||||
|
||||
|
||||
def warn_about_bad_start_command(start_commands):
|
||||
def warn_about_bad_start_command(start_commands: List[str]) -> None:
|
||||
ray_start_cmd = list(filter(lambda x: "ray start" in x, start_commands))
|
||||
if len(ray_start_cmd) == 0:
|
||||
cli_logger.warning(
|
||||
@@ -498,14 +508,14 @@ def warn_about_bad_start_command(start_commands):
|
||||
"to ray start in the head_start_ray_commands section.")
|
||||
|
||||
|
||||
def get_or_create_head_node(config,
|
||||
config_file,
|
||||
no_restart,
|
||||
restart_only,
|
||||
yes,
|
||||
override_cluster_name,
|
||||
_provider=None,
|
||||
_runner=subprocess):
|
||||
def get_or_create_head_node(config: Dict[str, Any],
|
||||
config_file: str,
|
||||
no_restart: bool,
|
||||
restart_only: bool,
|
||||
yes: bool,
|
||||
override_cluster_name: Optional[str],
|
||||
_provider: Optional[NodeProvider] = None,
|
||||
_runner: ModuleType = subprocess) -> None:
|
||||
"""Create the cluster head node, which in turn creates the workers."""
|
||||
provider = (_provider or _get_node_provider(config["provider"],
|
||||
config["cluster_name"]))
|
||||
@@ -766,7 +776,7 @@ def attach_cluster(config_file: str,
|
||||
override_cluster_name: Optional[str],
|
||||
no_config_cache: bool = False,
|
||||
new: bool = False,
|
||||
port_forward: Any = None):
|
||||
port_forward: Optional[Port_forward] = None) -> None:
|
||||
"""Attaches to a screen for the specified cluster.
|
||||
|
||||
Arguments:
|
||||
@@ -776,7 +786,7 @@ def attach_cluster(config_file: str,
|
||||
use_tmux: whether to use tmux as multiplexer
|
||||
override_cluster_name: set the name of the cluster
|
||||
new: whether to force a new screen
|
||||
port_forward (int or list[int]): port(s) to forward
|
||||
port_forward ( (int,int) or list[(int,int)] ): port(s) to forward
|
||||
"""
|
||||
|
||||
if use_tmux:
|
||||
@@ -819,8 +829,8 @@ def exec_cluster(config_file: str,
|
||||
start: bool = False,
|
||||
override_cluster_name: Optional[str] = None,
|
||||
no_config_cache: bool = False,
|
||||
port_forward: Any = None,
|
||||
with_output: bool = False):
|
||||
port_forward: Optional[Port_forward] = None,
|
||||
with_output: bool = False) -> str:
|
||||
"""Runs a command on the specified cluster.
|
||||
|
||||
Arguments:
|
||||
@@ -833,7 +843,7 @@ def exec_cluster(config_file: str,
|
||||
stop: whether to stop the cluster after command run
|
||||
start: whether to start the cluster if it isn't up
|
||||
override_cluster_name: set the name of the cluster
|
||||
port_forward (int or list[int]): port(s) to forward
|
||||
port_forward ( (int, int) or list[(int, int)] ): port(s) to forward
|
||||
"""
|
||||
assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
|
||||
assert run_env in RUN_ENV_TYPES, "--run_env must be in {}".format(
|
||||
@@ -906,28 +916,28 @@ def exec_cluster(config_file: str,
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def _exec(updater,
|
||||
cmd,
|
||||
screen,
|
||||
tmux,
|
||||
port_forward=None,
|
||||
with_output=False,
|
||||
run_env="auto",
|
||||
shutdown_after_run=False):
|
||||
def _exec(updater: NodeUpdaterThread,
|
||||
cmd: Optional[str] = None,
|
||||
screen: bool = False,
|
||||
tmux: bool = False,
|
||||
port_forward: Optional[Port_forward] = None,
|
||||
with_output: bool = False,
|
||||
run_env: str = "auto",
|
||||
shutdown_after_run: bool = False) -> str:
|
||||
if cmd:
|
||||
if screen:
|
||||
cmd = [
|
||||
wrapped_cmd = [
|
||||
"screen", "-L", "-dm", "bash", "-c",
|
||||
quote(cmd + "; exec bash")
|
||||
]
|
||||
cmd = " ".join(cmd)
|
||||
cmd = " ".join(wrapped_cmd)
|
||||
elif tmux:
|
||||
# TODO: Consider providing named session functionality
|
||||
cmd = [
|
||||
wrapped_cmd = [
|
||||
"tmux", "new", "-d", "bash", "-c",
|
||||
quote(cmd + "; exec bash")
|
||||
]
|
||||
cmd = " ".join(cmd)
|
||||
cmd = " ".join(wrapped_cmd)
|
||||
return updater.cmd_runner.run(
|
||||
cmd,
|
||||
exit_on_fail=True,
|
||||
@@ -946,7 +956,7 @@ def rsync(config_file: str,
|
||||
use_internal_ip: bool = False,
|
||||
no_config_cache: bool = False,
|
||||
all_nodes: bool = False,
|
||||
_runner=subprocess):
|
||||
_runner: ModuleType = subprocess) -> None:
|
||||
"""Rsyncs files.
|
||||
|
||||
Arguments:
|
||||
@@ -1080,7 +1090,8 @@ def get_worker_node_ips(config_file: str,
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def _get_worker_nodes(config, override_cluster_name):
|
||||
def _get_worker_nodes(config: Dict[str, Any],
|
||||
override_cluster_name: Optional[str]) -> List[str]:
|
||||
"""Returns worker node ids for given configuration."""
|
||||
# todo: technically could be reused in get_worker_node_ips
|
||||
if override_cluster_name is not None:
|
||||
@@ -1126,5 +1137,5 @@ def _get_head_node(config: Dict[str, Any],
|
||||
config["cluster_name"]))
|
||||
|
||||
|
||||
def confirm(msg, yes):
|
||||
def confirm(msg: str, yes: bool) -> Optional[bool]:
|
||||
return None if yes else click.confirm(msg, abort=True)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""IMPORTANT: this is an experimental interface and not currently stable."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Optional, List, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@@ -62,8 +62,8 @@ def run_on_cluster(cluster_config: Union[dict, str],
|
||||
cmd: Optional[str] = None,
|
||||
run_env: str = "auto",
|
||||
no_config_cache: bool = False,
|
||||
port_forward: Union[int, List[int]] = None,
|
||||
with_output: bool = False) -> str:
|
||||
port_forward: Optional[commands.Port_forward] = None,
|
||||
with_output: bool = False) -> Optional[str]:
|
||||
"""Runs a command on the specified cluster.
|
||||
|
||||
Args:
|
||||
@@ -74,7 +74,7 @@ def run_on_cluster(cluster_config: Union[dict, str],
|
||||
container. Select between "auto", "host" and "docker".
|
||||
no_config_cache (bool): Whether to disable the config cache and fully
|
||||
resolve all environment settings from the Cloud provider again.
|
||||
port_forward (int or list[int]): port(s) to forward.
|
||||
port_forward ( (int,int) or list[(int,int)]): port(s) to forward.
|
||||
with_output (bool): Whether to capture command output.
|
||||
|
||||
Returns:
|
||||
@@ -167,7 +167,8 @@ def get_worker_node_ips(cluster_config: Union[dict, str]) -> List[str]:
|
||||
return commands.get_worker_node_ips(config_file)
|
||||
|
||||
|
||||
def request_resources(num_cpus=None, bundles=None):
|
||||
def request_resources(num_cpus: Optional[int] = None,
|
||||
bundles: Optional[List[dict]] = None) -> None:
|
||||
"""Remotely request some CPU or GPU resources from the autoscaler.
|
||||
|
||||
This function is to be called e.g. on a node before submitting a bunch of
|
||||
@@ -185,7 +186,7 @@ def request_resources(num_cpus=None, bundles=None):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _as_config_file(cluster_config: Union[dict, str]):
|
||||
def _as_config_file(cluster_config: Union[dict, str]) -> Iterator[str]:
|
||||
if isinstance(cluster_config, dict):
|
||||
tmp = tempfile.NamedTemporaryFile("w", prefix="autoscaler-sdk-tmp-")
|
||||
tmp.write(json.dumps(cluster_config))
|
||||
@@ -196,8 +197,8 @@ def _as_config_file(cluster_config: Union[dict, str]):
|
||||
yield cluster_config
|
||||
|
||||
|
||||
def bootstrap_config(cluster_config: Dict[str, any],
|
||||
no_config_cache: bool = False) -> bool:
|
||||
def bootstrap_config(cluster_config: Dict[str, Any],
|
||||
no_config_cache: bool = False) -> Dict[str, Any]:
|
||||
"""Validate and add provider-specific fields to the config. For example,
|
||||
IAM/authentication may be added here."""
|
||||
return commands._bootstrap_config(cluster_config, no_config_cache)
|
||||
|
||||
Reference in New Issue
Block a user