diff --git a/python/ray/autoscaler/_private/commands.py b/python/ray/autoscaler/_private/commands.py index 9f5a2d0f1..eed234593 100644 --- a/python/ray/autoscaler/_private/commands.py +++ b/python/ray/autoscaler/_private/commands.py @@ -120,7 +120,7 @@ def create_or_update_cluster(config_file: str, no_restart: bool, restart_only: bool, yes: bool, - override_cluster_name: Optional[str], + override_cluster_name: Optional[str] = None, no_config_cache: bool = False, redirect_command_output: bool = False, use_login_shells: bool = True) -> None: @@ -1037,7 +1037,7 @@ def rsync(config_file: str, def get_head_node_ip(config_file: str, - override_cluster_name: Optional[str]) -> str: + override_cluster_name: Optional[str] = None) -> str: """Returns head node IP for given configuration file if exists.""" config = yaml.safe_load(open(config_file).read()) @@ -1058,7 +1058,8 @@ def get_head_node_ip(config_file: str, def get_worker_node_ips(config_file: str, - override_cluster_name: Optional[str]) -> List[str]: + override_cluster_name: Optional[str] = None + ) -> List[str]: """Returns worker node IPs for given configuration file.""" config = yaml.safe_load(open(config_file).read()) diff --git a/python/ray/autoscaler/sdk.py b/python/ray/autoscaler/sdk.py index 123058097..0a48d6533 100644 --- a/python/ray/autoscaler/sdk.py +++ b/python/ray/autoscaler/sdk.py @@ -1,5 +1,6 @@ """IMPORTANT: this is an experimental interface and not currently stable.""" +from contextlib import contextmanager from typing import Any, Dict, Optional, List, Union import json import os @@ -26,17 +27,18 @@ def create_or_update_cluster(cluster_config: Union[dict, str], no_config_cache (bool): Whether to disable the config cache and fully resolve all environment settings from the Cloud provider again. """ - return commands.create_or_update_cluster( - config_file=_as_config_file(cluster_config), - override_min_workers=None, - override_max_workers=None, - no_restart=no_restart, - restart_only=restart_only, - yes=True, - override_cluster_name=None, - no_config_cache=no_config_cache, - redirect_command_output=None, - use_login_shells=True) + with _as_config_file(cluster_config) as config_file: + return commands.create_or_update_cluster( + config_file=config_file, + override_min_workers=None, + override_max_workers=None, + no_restart=no_restart, + restart_only=restart_only, + yes=True, + override_cluster_name=None, + no_config_cache=no_config_cache, + redirect_command_output=None, + use_login_shells=True) def teardown_cluster(cluster_config: Union[dict, str]) -> None: @@ -46,12 +48,13 @@ def teardown_cluster(cluster_config: Union[dict, str]) -> None: cluster_config (Union[str, dict]): Either the config dict of the cluster, or a path pointing to a file containing the config. """ - return commands.teardown_cluster( - config_file=_as_config_file(cluster_config), - yes=True, - workers_only=False, - override_cluster_name=None, - keep_min_workers=False) + with _as_config_file(cluster_config) as config_file: + return commands.teardown_cluster( + config_file=config_file, + yes=True, + workers_only=False, + override_cluster_name=None, + keep_min_workers=False) def run_on_cluster(cluster_config: Union[dict, str], @@ -77,18 +80,19 @@ def run_on_cluster(cluster_config: Union[dict, str], Returns: The output of the command as a string. """ - return commands.exec_cluster( - _as_config_file(cluster_config), - cmd=cmd, - run_env=run_env, - screen=False, - tmux=False, - stop=False, - start=False, - override_cluster_name=None, - no_config_cache=no_config_cache, - port_forward=port_forward, - with_output=with_output) + with _as_config_file(cluster_config) as config_file: + return commands.exec_cluster( + config_file, + cmd=cmd, + run_env=run_env, + screen=False, + tmux=False, + stop=False, + start=False, + override_cluster_name=None, + no_config_cache=no_config_cache, + port_forward=port_forward, + with_output=with_output) def rsync(cluster_config: Union[dict, str], @@ -116,16 +120,17 @@ def rsync(cluster_config: Union[dict, str], Raises: RuntimeError if the cluster head node is not found. """ - return commands.rsync( - config_file=_as_config_file(cluster_config), - source=source, - target=target, - override_cluster_name=None, - down=down, - ip_address=ip_address, - use_internal_ip=use_internal_ip, - no_config_cache=no_config_cache, - all_nodes=False) + with _as_config_file(cluster_config) as config_file: + return commands.rsync( + config_file=config_file, + source=source, + target=target, + override_cluster_name=None, + down=down, + ip_address=ip_address, + use_internal_ip=use_internal_ip, + no_config_cache=no_config_cache, + all_nodes=False) def get_head_node_ip(cluster_config: Union[dict, str]) -> str: @@ -141,7 +146,8 @@ def get_head_node_ip(cluster_config: Union[dict, str]) -> str: Raises: RuntimeError if the cluster is not found. """ - return commands.get_head_node_ip(_as_config_file(cluster_config)) + with _as_config_file(cluster_config) as config_file: + return commands.get_head_node_ip(config_file) def get_worker_node_ips(cluster_config: Union[dict, str]) -> List[str]: @@ -157,7 +163,8 @@ def get_worker_node_ips(cluster_config: Union[dict, str]) -> List[str]: Raises: RuntimeError if the cluster is not found. """ - return commands.get_worker_node_ips(_as_config_file(cluster_config)) + with _as_config_file(cluster_config) as config_file: + return commands.get_worker_node_ips(config_file) def request_resources(num_cpus=None, bundles=None): @@ -177,6 +184,7 @@ def request_resources(num_cpus=None, bundles=None): return commands.request_resources(num_cpus, bundles) +@contextmanager def _as_config_file(cluster_config: Union[dict, str]): if isinstance(cluster_config, dict): tmp = tempfile.NamedTemporaryFile("w", prefix="autoscaler-sdk-tmp-") @@ -185,7 +193,7 @@ def _as_config_file(cluster_config: Union[dict, str]): cluster_config = tmp.name if not os.path.exists(cluster_config): raise ValueError("Cluster config not found {}".format(cluster_config)) - return cluster_config + yield cluster_config def bootstrap_config(cluster_config: Dict[str, any],