diff --git a/doc/source/conf.py b/doc/source/conf.py index dca1d11ef..b1bafc66c 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -40,6 +40,7 @@ MOCK_MODULES = [ "horovod", "horovod.ray", "kubernetes", + "mxnet", "mxnet.model", "psutil", "ray._raylet", diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index afa5de622..3c80a9cd6 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -1,9 +1,9 @@ import json import logging import os -from typing import Dict +from numbers import Number +from typing import Any, Dict, List, Optional, Tuple -from ray.tune.checkpoint_manager import Checkpoint from ray.tune.utils import flatten_dict try: @@ -37,7 +37,10 @@ class Analysis: in the respective functions. """ - def __init__(self, experiment_dir, default_metric=None, default_mode=None): + def __init__(self, + experiment_dir: str, + default_metric: Optional[str] = None, + default_mode: Optional[str] = None): experiment_dir = os.path.expanduser(experiment_dir) if not os.path.isdir(experiment_dir): raise ValueError( @@ -59,14 +62,14 @@ class Analysis: else: self.fetch_trial_dataframes() - def _validate_metric(self, metric): + def _validate_metric(self, metric: str) -> str: if not metric and not self.default_metric: raise ValueError( "No `metric` has been passed and `default_metric` has " "not been set. Please specify the `metric` parameter.") return metric or self.default_metric - def _validate_mode(self, mode): + def _validate_mode(self, mode: str) -> str: if not mode and not self.default_mode: raise ValueError( "No `mode` has been passed and `default_mode` has " @@ -75,7 +78,9 @@ class Analysis: raise ValueError("If set, `mode` has to be one of [min, max]") return mode or self.default_mode - def dataframe(self, metric=None, mode=None): + def dataframe(self, + metric: Optional[str] = None, + mode: Optional[str] = None) -> DataFrame: """Returns a pandas.DataFrame object constructed from the trials. Args: @@ -97,7 +102,9 @@ class Analysis: rows[path].update(logdir=path) return pd.DataFrame(list(rows.values())) - def get_best_config(self, metric=None, mode=None): + def get_best_config(self, + metric: Optional[str] = None, + mode: Optional[str] = None) -> Optional[Dict]: """Retrieve the best config corresponding to the trial. Args: @@ -122,7 +129,9 @@ class Analysis: best_path = compare_op(rows, key=lambda k: rows[k][metric]) return all_configs[best_path] - def get_best_logdir(self, metric=None, mode=None): + def get_best_logdir(self, + metric: Optional[str] = None, + mode: Optional[str] = None) -> Optional[str]: """Retrieve the logdir corresponding to the best trial. Args: @@ -148,7 +157,7 @@ class Analysis: self._experiment_dir)) return None - def fetch_trial_dataframes(self): + def fetch_trial_dataframes(self) -> Dict[str, DataFrame]: fail_count = 0 for path in self._get_trial_paths(): try: @@ -162,7 +171,7 @@ class Analysis: "Couldn't read results from {} paths".format(fail_count)) return self.trial_dataframes - def get_all_configs(self, prefix=False): + def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]: """Returns a list of all configurations. Args: @@ -170,7 +179,8 @@ class Analysis: and prepends `config/`. Returns: - List[dict]: List of all configurations of trials, + Dict[str, Dict]: Dict of all configurations of trials, indexed by + their trial dir. """ fail_count = 0 for path in self._get_trial_paths(): @@ -189,7 +199,10 @@ class Analysis: "Couldn't read config from {} paths".format(fail_count)) return self._configs - def get_trial_checkpoints_paths(self, trial, metric=None): + def get_trial_checkpoints_paths(self, + trial: Trial, + metric: Optional[str] = None + ) -> List[Tuple[str, Number]]: """Gets paths and metrics of all persistent checkpoints of a trial. Args: @@ -215,11 +228,14 @@ class Analysis: return path_metric_df[["chkpt_path", metric]].values.tolist() elif isinstance(trial, Trial): checkpoints = trial.checkpoint_manager.best_checkpoints() - return [[c.value, c.result[metric]] for c in checkpoints] + return [(c.value, c.result[metric]) for c in checkpoints] else: raise ValueError("trial should be a string or a Trial instance.") - def get_best_checkpoint(self, trial, metric=None, mode=None): + def get_best_checkpoint(self, + trial: Trial, + metric: Optional[str] = None, + mode: Optional[str] = None) -> Optional[str]: """Gets best persistent checkpoint path of provided trial. Args: @@ -244,7 +260,9 @@ class Analysis: else: return min(checkpoint_paths, key=lambda x: x[1])[0] - def _retrieve_rows(self, metric=None, mode=None): + def _retrieve_rows(self, + metric: Optional[str] = None, + mode: Optional[str] = None) -> Dict[str, Any]: assert mode is None or mode in ["max", "min"] rows = {} for path, df in self.trial_dataframes.items(): @@ -264,7 +282,7 @@ class Analysis: return rows - def _get_trial_paths(self): + def _get_trial_paths(self) -> List[str]: _trial_paths = [] for trial_path, _, files in os.walk(self._experiment_dir): if EXPR_PROGRESS_FILE in files: @@ -276,7 +294,7 @@ class Analysis: return _trial_paths @property - def trial_dataframes(self): + def trial_dataframes(self) -> Dict[str, DataFrame]: """List of all dataframes of the trials.""" return self._trial_dataframes @@ -306,10 +324,10 @@ class ExperimentAnalysis(Analysis): """ def __init__(self, - experiment_checkpoint_path, - trials=None, - default_metric=None, - default_mode=None): + experiment_checkpoint_path: str, + trials: Optional[List[Trial]] = None, + default_metric: Optional[str] = None, + default_mode: Optional[str] = None): experiment_checkpoint_path = os.path.expanduser( experiment_checkpoint_path) if not os.path.isfile(experiment_checkpoint_path): @@ -365,8 +383,8 @@ class ExperimentAnalysis(Analysis): return self.get_best_config(self.default_metric, self.default_mode) @property - def best_checkpoint(self) -> Checkpoint: - """Get the checkpoint of the best trial of the experiment + def best_checkpoint(self) -> str: + """Get the checkpoint path of the best trial of the experiment The best trial is determined by comparing the last trial results using the `metric` and `mode` parameters passed to `tune.run()`. @@ -471,7 +489,10 @@ class ExperimentAnalysis(Analysis): ], index="trial_id") - def get_best_trial(self, metric=None, mode=None, scope="last"): + def get_best_trial(self, + metric: Optional[str] = None, + mode: Optional[str] = None, + scope: str = "last") -> Optional[Trial]: """Retrieve the best trial object. Compares all trials' scores on ``metric``. @@ -535,7 +556,10 @@ class ExperimentAnalysis(Analysis): "parameter?") return best_trial - def get_best_config(self, metric=None, mode=None, scope="last"): + def get_best_config(self, + metric: Optional[str] = None, + mode: Optional[str] = None, + scope: str = "last") -> Optional[Dict]: """Retrieve the best config corresponding to the trial. Compares all trials' scores on `metric`. @@ -562,7 +586,10 @@ class ExperimentAnalysis(Analysis): best_trial = self.get_best_trial(metric, mode, scope) return best_trial.config if best_trial else None - def get_best_logdir(self, metric=None, mode=None, scope="last"): + def get_best_logdir(self, + metric: Optional[str] = None, + mode: Optional[str] = None, + scope: str = "last") -> Optional[str]: """Retrieve the logdir corresponding to the best trial. Compares all trials' scores on `metric`. @@ -589,15 +616,15 @@ class ExperimentAnalysis(Analysis): best_trial = self.get_best_trial(metric, mode, scope) return best_trial.logdir if best_trial else None - def stats(self): + def stats(self) -> Dict: """Returns a dictionary of the statistics of the experiment.""" return self._experiment_state.get("stats") - def runner_data(self): + def runner_data(self) -> Dict: """Returns a dictionary of the TrialRunner data.""" return self._experiment_state.get("runner_data") - def _get_trial_paths(self): + def _get_trial_paths(self) -> List[str]: """Overwrites Analysis to only have trials of one experiment.""" if self.trials: _trial_paths = [t.logdir for t in self.trials] diff --git a/python/ray/tune/integration/horovod.py b/python/ray/tune/integration/horovod.py index b718e4ff7..e3603de2b 100644 --- a/python/ray/tune/integration/horovod.py +++ b/python/ray/tune/integration/horovod.py @@ -1,5 +1,7 @@ import os import logging +from typing import Callable, Dict, Type + from filelock import FileLock import ray @@ -15,11 +17,11 @@ from horovod.ray import RayExecutor logger = logging.getLogger(__name__) -def get_rank(): +def get_rank() -> str: return os.environ["HOROVOD_RANK"] -def logger_creator(log_config, logdir): +def logger_creator(log_config: Dict, logdir: str) -> NoopLogger: """Simple NOOP logger for worker trainables.""" index = get_rank() worker_dir = os.path.join(logdir, "worker_{}".format(index)) @@ -51,7 +53,7 @@ class _HorovodTrainable(tune.Trainable): def num_workers(self): return self._num_hosts * self._num_slots - def setup(self, config): + def setup(self, config: Dict): trainable = wrap_function(self.__class__._function) # We use a filelock here to ensure that the file-writing # process is safe across different trainables. @@ -82,7 +84,7 @@ class _HorovodTrainable(tune.Trainable): "logger_creator": lambda cfg: logger_creator(cfg, logdir_) }) - def step(self): + def step(self) -> Dict: if self._finished: raise RuntimeError("Training has already finished.") result = self.executor.execute(lambda w: w.step())[0] @@ -90,14 +92,14 @@ class _HorovodTrainable(tune.Trainable): self._finished = True return result - def save_checkpoint(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir: str) -> str: # TODO: optimize if colocated save_obj = self.executor.execute_single(lambda w: w.save_to_object()) checkpoint_path = TrainableUtil.create_from_pickle( save_obj, checkpoint_dir) return checkpoint_path - def load_checkpoint(self, checkpoint_dir): + def load_checkpoint(self, checkpoint_dir: str): checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir) x_id = ray.put(checkpoint_obj) return self.executor.execute(lambda w: w.restore_from_object(x_id)) @@ -107,13 +109,14 @@ class _HorovodTrainable(tune.Trainable): self.executor.shutdown() -def DistributedTrainableCreator(func, - use_gpu=False, - num_hosts=1, - num_slots=1, - num_cpus_per_slot=1, - timeout_s=30, - replicate_pem=False): +def DistributedTrainableCreator( + func: Callable, + use_gpu: bool = False, + num_hosts: int = 1, + num_slots: int = 1, + num_cpus_per_slot: int = 1, + timeout_s: int = 30, + replicate_pem: bool = False) -> Type[_HorovodTrainable]: """Converts Horovod functions to be executable by Tune. Requires horovod > 0.19 to work. @@ -198,7 +201,7 @@ def DistributedTrainableCreator(func, _timeout_s = timeout_s @classmethod - def default_resource_request(cls, config): + def default_resource_request(cls, config: Dict): extra_gpu = int(num_hosts * num_slots) * int(use_gpu) extra_cpu = int(num_hosts * num_slots * num_cpus_per_slot) @@ -216,7 +219,7 @@ def DistributedTrainableCreator(func, # that force us to include mocks as part of the module. -def _train_simple(config): +def _train_simple(config: Dict): import horovod.torch as hvd hvd.init() from ray import tune diff --git a/python/ray/tune/integration/keras.py b/python/ray/tune/integration/keras.py index 1a0a85e04..6badf232d 100644 --- a/python/ray/tune/integration/keras.py +++ b/python/ray/tune/integration/keras.py @@ -39,7 +39,7 @@ class TuneCallback(Callback): on, self._allowed)) self._on = on - def _handle(self, logs: Dict): + def _handle(self, logs: Dict, when: str): raise NotImplementedError def on_batch_begin(self, batch, logs=None): diff --git a/python/ray/tune/integration/kubernetes.py b/python/ray/tune/integration/kubernetes.py index 0aa0c07ab..1926e9861 100644 --- a/python/ray/tune/integration/kubernetes.py +++ b/python/ray/tune/integration/kubernetes.py @@ -1,3 +1,5 @@ +from typing import Any, Optional, Tuple + import kubernetes import subprocess @@ -45,7 +47,10 @@ class KubernetesSyncer(NodeSyncer): _namespace = "ray" - def __init__(self, local_dir, remote_dir, sync_client=None): + def __init__(self, + local_dir: str, + remote_dir: str, + sync_client: Optional[SyncClient] = None): self.local_ip = services.get_node_ip_address() self.local_node = self._get_kubernetes_node_by_ip(self.local_ip) self.worker_ip = None @@ -56,11 +61,11 @@ class KubernetesSyncer(NodeSyncer): super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client) - def set_worker_ip(self, worker_ip): + def set_worker_ip(self, worker_ip: str): self.worker_ip = worker_ip self.worker_node = self._get_kubernetes_node_by_ip(worker_ip) - def _get_kubernetes_node_by_ip(self, node_ip): + def _get_kubernetes_node_by_ip(self, node_ip: str) -> Optional[str]: """Return node name by internal or external IP""" kubernetes.config.load_incluster_config() api = kubernetes.client.CoreV1Api() @@ -75,8 +80,8 @@ class KubernetesSyncer(NodeSyncer): return None @property - def _remote_path(self): - return (self.worker_node, self._remote_dir) + def _remote_path(self) -> Tuple[str, str]: + return self.worker_node, self._remote_dir class KubernetesSyncClient(SyncClient): @@ -95,12 +100,12 @@ class KubernetesSyncClient(SyncClient): """ - def __init__(self, namespace, process_runner=subprocess): + def __init__(self, namespace: str, process_runner: Any = subprocess): self.namespace = namespace self._process_runner = process_runner self._command_runners = {} - def _create_command_runner(self, node_id): + def _create_command_runner(self, node_id: str) -> KubernetesCommandRunner: """Create a command runner for one Kubernetes node""" return KubernetesCommandRunner( log_prefix="KubernetesSyncClient: {}:".format(node_id), @@ -109,7 +114,7 @@ class KubernetesSyncClient(SyncClient): auth_config=None, process_runner=self._process_runner) - def _get_command_runner(self, node_id): + def _get_command_runner(self, node_id: str) -> KubernetesCommandRunner: """Create command runner if it doesn't exist""" # Todo(krfricke): These cached runners are currently # never cleaned up. They are cheap so this shouldn't @@ -120,7 +125,7 @@ class KubernetesSyncClient(SyncClient): self._command_runners[node_id] = command_runner return self._command_runners[node_id] - def sync_up(self, source, target): + def sync_up(self, source: str, target: Tuple[str, str]) -> bool: """Here target is a tuple (target_node, target_dir)""" target_node, target_dir = target @@ -132,7 +137,7 @@ class KubernetesSyncClient(SyncClient): command_runner.run_rsync_up(source, target_dir) return True - def sync_down(self, source, target): + def sync_down(self, source: Tuple[str, str], target: str) -> bool: """Here source is a tuple (source_node, source_dir)""" source_node, source_dir = source @@ -144,7 +149,7 @@ class KubernetesSyncClient(SyncClient): command_runner.run_rsync_down(source_dir, target) return True - def delete(self, target): + def delete(self, target: str) -> bool: """No delete function because it is only used by the KubernetesSyncer, which doesn't call delete.""" return True diff --git a/python/ray/tune/integration/mxnet.py b/python/ray/tune/integration/mxnet.py index 435f2c34a..a46593c87 100644 --- a/python/ray/tune/integration/mxnet.py +++ b/python/ray/tune/integration/mxnet.py @@ -2,7 +2,9 @@ from typing import Dict, List, Union from ray import tune -from mxnet.model import save_checkpoint +import mxnet +from mxnet.model import save_checkpoint, BatchEndParam +import numpy as np import os @@ -49,7 +51,7 @@ class TuneReportCallback(TuneCallback): metrics = [metrics] self._metrics = metrics - def __call__(self, param): + def __call__(self, param: BatchEndParam): if not param.eval_metric: return if not self._metrics: @@ -110,7 +112,8 @@ class TuneCheckpointCallback(TuneCallback): self._filename = filename self._frequency = frequency - def __call__(self, epoch, sym, arg, aux): + def __call__(self, epoch: int, sym: mxnet.symbol.Symbol, + arg: Dict[str, np.ndarray], aux: Dict[str, np.ndarray]): if epoch % self._frequency != 0: return with tune.checkpoint_dir(step=epoch) as checkpoint_dir: diff --git a/python/ray/tune/integration/torch.py b/python/ray/tune/integration/torch.py index ba582d233..147734c7e 100644 --- a/python/ray/tune/integration/torch.py +++ b/python/ray/tune/integration/torch.py @@ -5,6 +5,8 @@ import os import logging import shutil import tempfile +from typing import Callable, Dict, Generator, Optional, Type + import torch from datetime import timedelta @@ -34,7 +36,7 @@ def enable_distributed_trainable(): _distributed_enabled = True -def logger_creator(log_config, logdir, rank): +def logger_creator(log_config: Dict, logdir: str, rank: int) -> NoopLogger: worker_dir = os.path.join(logdir, "worker_{}".format(rank)) os.makedirs(worker_dir, exist_ok=True) return NoopLogger(log_config, worker_dir) @@ -54,16 +56,16 @@ class _TorchTrainable(tune.Trainable): __slots__ = ["workers", "_finished"] @classmethod - def default_process_group_parameters(self): + def default_process_group_parameters(self) -> Dict: return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo") @classmethod - def get_remote_worker_options(self): + def get_remote_worker_options(self) -> Dict[str, int]: num_gpus = 1 if self._use_gpu else 0 num_cpus = int(self._num_cpus_per_worker or 1) return dict(num_cpus=num_cpus, num_gpus=num_gpus) - def setup(self, config): + def setup(self, config: Dict): self._finished = False num_workers = self._num_workers logdir = self.logdir @@ -103,7 +105,7 @@ class _TorchTrainable(tune.Trainable): for rank, w in enumerate(self.workers) ]) - def step(self): + def step(self) -> Dict: if self._finished: raise RuntimeError("Training has already finished.") result = ray.get([w.step.remote() for w in self.workers])[0] @@ -111,14 +113,14 @@ class _TorchTrainable(tune.Trainable): self._finished = True return result - def save_checkpoint(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir: str) -> str: # TODO: optimize if colocated save_obj = ray.get(self.workers[0].save_to_object.remote()) checkpoint_path = TrainableUtil.create_from_pickle( save_obj, checkpoint_dir) return checkpoint_path - def load_checkpoint(self, checkpoint_dir): + def load_checkpoint(self, checkpoint_dir: str): checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir) return ray.get( w.restore_from_object.remote(checkpoint_obj) for w in self.workers) @@ -127,12 +129,13 @@ class _TorchTrainable(tune.Trainable): ray.get([worker.stop.remote() for worker in self.workers]) -def DistributedTrainableCreator(func, - use_gpu=False, - num_workers=1, - num_cpus_per_worker=1, - backend="gloo", - timeout_s=NCCL_TIMEOUT_S): +def DistributedTrainableCreator( + func: Callable, + use_gpu: bool = False, + num_workers: int = 1, + num_cpus_per_worker: int = 1, + backend: str = "gloo", + timeout_s: int = NCCL_TIMEOUT_S) -> Type[_TorchTrainable]: """Creates a class that executes distributed training. Similar to running `torch.distributed.launch`. @@ -179,11 +182,11 @@ def DistributedTrainableCreator(func, _num_cpus_per_worker = num_cpus_per_worker @classmethod - def default_process_group_parameters(self): + def default_process_group_parameters(self) -> Dict: return dict(timeout=timedelta(timeout_s), backend=backend) @classmethod - def default_resource_request(cls, config): + def default_resource_request(cls, config: Dict) -> Resources: num_workers_ = int(config.get("num_workers", num_workers)) num_cpus = int( config.get("num_cpus_per_worker", num_cpus_per_worker)) @@ -199,7 +202,8 @@ def DistributedTrainableCreator(func, @contextmanager -def distributed_checkpoint_dir(step, disable=False): +def distributed_checkpoint_dir( + step: int, disable: bool = False) -> Generator[str, None, None]: """ContextManager for creating a distributed checkpoint. Only checkpoints a file on the "main" training actor, avoiding @@ -236,7 +240,7 @@ def distributed_checkpoint_dir(step, disable=False): shutil.rmtree(path) -def _train_check_global(config, checkpoint_dir=None): +def _train_check_global(config: Dict, checkpoint_dir: Optional[str] = None): """For testing only. Putting this here because Ray has problems serializing within the test file.""" assert is_distributed_trainable() @@ -245,7 +249,7 @@ def _train_check_global(config, checkpoint_dir=None): tune.report(is_distributed=True) -def _train_simple(config, checkpoint_dir=None): +def _train_simple(config: Dict, checkpoint_dir: Optional[str] = None): """For testing only. Putting this here because Ray has problems serializing within the test file.""" import torch.nn as nn diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 7f8b35c14..a5f259cbc 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -2,6 +2,7 @@ import os import pickle from multiprocessing import Process, Queue from numbers import Number +from typing import Callable, Dict, List, Tuple import numpy as np from ray import logger @@ -70,7 +71,7 @@ def _clean_log(obj): return fallback -def wandb_mixin(func): +def wandb_mixin(func: Callable): """wandb_mixin Weights and biases (https://www.wandb.com/) is a tool for experiment @@ -142,7 +143,7 @@ def wandb_mixin(func): return func -def _set_api_key(wandb_config): +def _set_api_key(wandb_config: Dict): """Set WandB API key from `wandb_config`. Will pop the `api_key_file` and `api_key` keys from `wandb_config` parameter""" api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", "")) @@ -176,7 +177,8 @@ class _WandbLoggingProcess(Process): wandb logging instances locally. """ - def __init__(self, queue, exclude, to_config, *args, **kwargs): + def __init__(self, queue: Queue, exclude: List[str], to_config: List[str], + *args, **kwargs): super(_WandbLoggingProcess, self).__init__() self.queue = queue self._exclude = set(exclude) @@ -195,7 +197,7 @@ class _WandbLoggingProcess(Process): wandb.log(log) wandb.join() - def _handle_result(self, result): + def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]: config_update = result.get("config", {}).copy() log = {} flat_result = flatten_dict(result, delimiter="/") @@ -377,7 +379,7 @@ class WandbLogger(Logger): **wandb_init_kwargs) self._wandb.start() - def on_result(self, result): + def on_result(self, result: Dict): result = _clean_log(result) self._queue.put(result) @@ -389,7 +391,7 @@ class WandbLogger(Logger): class WandbTrainableMixin: _wandb = wandb - def __init__(self, config, *args, **kwargs): + def __init__(self, config: Dict, *args, **kwargs): if not isinstance(self, Trainable): raise ValueError( "The `WandbTrainableMixin` can only be used as a mixin " diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py index 02e59453d..375245fb9 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -1,7 +1,11 @@ import logging +from typing import Dict, Optional, Union + import numpy as np +from ray.tune import trial_runner from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.tune.trial import Trial logger = logging.getLogger(__name__) @@ -36,14 +40,14 @@ class AsyncHyperBandScheduler(FIFOScheduler): """ def __init__(self, - time_attr="training_iteration", - reward_attr=None, - metric=None, - mode=None, - max_t=100, - grace_period=1, - reduction_factor=4, - brackets=1): + time_attr: str = "training_iteration", + reward_attr: Optional[str] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + max_t: int = 100, + grace_period: int = 1, + reduction_factor: float = 4, + brackets: int = 1): assert max_t > 0, "Max (time_attr) not valid!" assert max_t >= grace_period, "grace_period must be <= max_t!" assert grace_period > 0, "grace_period must be positive!" @@ -82,7 +86,8 @@ class AsyncHyperBandScheduler(FIFOScheduler): self._metric_op = -1. self._time_attr = time_attr - def set_search_properties(self, metric, mode): + def set_search_properties(self, metric: Optional[str], + mode: Optional[str]) -> bool: if self._metric and metric: return False if self._mode and mode: @@ -100,7 +105,8 @@ class AsyncHyperBandScheduler(FIFOScheduler): return True - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): if not self._metric or not self._metric_op: raise ValueError( "{} has been instantiated without a valid `metric` ({}) or " @@ -115,7 +121,8 @@ class AsyncHyperBandScheduler(FIFOScheduler): idx = np.random.choice(len(self._brackets), p=normalized) self._trial_info[trial.trial_id] = self._brackets[idx] - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: action = TrialScheduler.CONTINUE if self._time_attr not in result or self._metric not in result: return action @@ -129,7 +136,8 @@ class AsyncHyperBandScheduler(FIFOScheduler): self._num_stopped += 1 return action - def on_trial_complete(self, trial_runner, trial, result): + def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict): if self._time_attr not in result or self._metric not in result: return bracket = self._trial_info[trial.trial_id] @@ -137,10 +145,11 @@ class AsyncHyperBandScheduler(FIFOScheduler): self._metric_op * result[self._metric]) del self._trial_info[trial.trial_id] - def on_trial_remove(self, trial_runner, trial): + def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): del self._trial_info[trial.trial_id] - def debug_string(self): + def debug_string(self) -> str: out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped) out += "\n" + "\n".join([b.debug_str() for b in self._brackets]) return out @@ -161,19 +170,21 @@ class _Bracket(): >>> b.cutoff(b._rungs[3][1]) == 2.0 """ - def __init__(self, min_t, max_t, reduction_factor, s): + def __init__(self, min_t: int, max_t: int, reduction_factor: float, + s: int): self.rf = reduction_factor MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1) self._rungs = [(min_t * self.rf**(k + s), {}) for k in reversed(range(MAX_RUNGS))] - def cutoff(self, recorded): + def cutoff(self, recorded) -> Union[None, int, float, complex, np.ndarray]: if not recorded: return None return np.nanpercentile( list(recorded.values()), (1 - 1 / self.rf) * 100) - def on_result(self, trial, cur_iter, cur_rew): + def on_result(self, trial: Trial, cur_iter: int, + cur_rew: Optional[float]) -> str: action = TrialScheduler.CONTINUE for milestone, recorded in self._rungs: if cur_iter < milestone or trial.trial_id in recorded: @@ -190,7 +201,7 @@ class _Bracket(): break return action - def debug_str(self): + def debug_str(self) -> str: # TODO: fix up the output for this iters = " | ".join([ "Iter {:.3f}: {}".format(milestone, self.cutoff(recorded)) diff --git a/python/ray/tune/schedulers/hb_bohb.py b/python/ray/tune/schedulers/hb_bohb.py index c8c061034..83d63af65 100644 --- a/python/ray/tune/schedulers/hb_bohb.py +++ b/python/ray/tune/schedulers/hb_bohb.py @@ -1,5 +1,7 @@ import logging +from typing import Dict, Optional +from ray.tune import trial_runner from ray.tune.schedulers.trial_scheduler import TrialScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler, Bracket from ray.tune.trial import Trial @@ -22,7 +24,8 @@ class HyperBandForBOHB(HyperBandScheduler): See ray.tune.schedulers.HyperBandScheduler for parameter docstring. """ - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Adds new trial. On a new trial add, if current bracket is not filled, add to current @@ -67,7 +70,8 @@ class HyperBandForBOHB(HyperBandScheduler): self._state["bracket"].add_trial(trial) self._trial_info[trial] = cur_bracket, self._state["band_idx"] - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: """If bracket is finished, all trials will be stopped. If a given trial finishes and bracket iteration is not done, @@ -96,11 +100,14 @@ class HyperBandForBOHB(HyperBandScheduler): action = self._process_bracket(trial_runner, bracket) return action - def _unpause_trial(self, trial_runner, trial): + def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): trial_runner.trial_executor.unpause_trial(trial) trial_runner._search_alg.searcher.on_unpause(trial.trial_id) - def choose_trial_to_run(self, trial_runner, allow_recurse=True): + def choose_trial_to_run(self, + trial_runner: "trial_runner.TrialRunner", + allow_recurse: bool = True) -> Optional[Trial]: """Fair scheduling within iteration by completion percentage. List of trials not used since all trials are tracked as state diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 3066cf80b..b69400e7a 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -1,7 +1,10 @@ import collections +from typing import Dict, List, Optional, Tuple + import numpy as np import logging +from ray.tune import trial_runner from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler from ray.tune.trial import Trial from ray.tune.error import TuneError @@ -74,12 +77,12 @@ class HyperBandScheduler(FIFOScheduler): """ def __init__(self, - time_attr="training_iteration", - reward_attr=None, - metric=None, - mode=None, - max_t=81, - reduction_factor=3): + time_attr: str = "training_iteration", + reward_attr: Optional[str] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + max_t: int = 81, + reduction_factor: float = 3): assert max_t > 0, "Max (time_attr) not valid!" if mode: assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" @@ -118,7 +121,8 @@ class HyperBandScheduler(FIFOScheduler): self._metric_op = -1. self._time_attr = time_attr - def set_search_properties(self, metric, mode): + def set_search_properties(self, metric: Optional[str], + mode: Optional[str]) -> bool: if self._metric and metric: return False if self._mode and mode: @@ -136,7 +140,8 @@ class HyperBandScheduler(FIFOScheduler): return True - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Adds new trial. On a new trial add, if current bracket is not filled, @@ -179,7 +184,7 @@ class HyperBandScheduler(FIFOScheduler): self._state["bracket"].add_trial(trial) self._trial_info[trial] = cur_bracket, self._state["band_idx"] - def _cur_band_filled(self): + def _cur_band_filled(self) -> bool: """Checks if the current band is filled. The size of the current band should be equal to s_max_1""" @@ -187,7 +192,8 @@ class HyperBandScheduler(FIFOScheduler): cur_band = self._hyperbands[self._state["band_idx"]] return len(cur_band) == self._s_max_1 - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict): """If bracket is finished, all trials will be stopped. If a given trial finishes and bracket iteration is not done, @@ -211,7 +217,8 @@ class HyperBandScheduler(FIFOScheduler): metric_val=result.get(self._time_attr))) return action - def _process_bracket(self, trial_runner, bracket): + def _process_bracket(self, trial_runner: "trial_runner.TrialRunner", + bracket: "Bracket") -> str: """This is called whenever a trial makes progress. When all live trials in the bracket have no more iterations left, @@ -250,7 +257,8 @@ class HyperBandScheduler(FIFOScheduler): action = TrialScheduler.CONTINUE return action - def on_trial_remove(self, trial_runner, trial): + def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Notification when trial terminates. Trial info is removed from bracket. Triggers halving if bracket is @@ -260,15 +268,18 @@ class HyperBandScheduler(FIFOScheduler): if not bracket.finished(): self._process_bracket(trial_runner, bracket) - def on_trial_complete(self, trial_runner, trial, result): + def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict): """Cleans up trial info from bracket if trial completed early.""" self.on_trial_remove(trial_runner, trial) - def on_trial_error(self, trial_runner, trial): + def on_trial_error(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Cleans up trial info from bracket if trial errored early.""" self.on_trial_remove(trial_runner, trial) - def choose_trial_to_run(self, trial_runner): + def choose_trial_to_run( + self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]: """Fair scheduling within iteration by completion percentage. List of trials not used since all trials are tracked as state @@ -288,7 +299,7 @@ class HyperBandScheduler(FIFOScheduler): return trial return None - def debug_string(self): + def debug_string(self) -> str: """This provides a progress notification for the algorithm. For each bracket, the algorithm will output a string as follows: @@ -315,13 +326,14 @@ class HyperBandScheduler(FIFOScheduler): out += "\n {}".format(bracket) return out - def state(self): + def state(self) -> Dict[str, int]: return { "num_brackets": sum(len(band) for band in self._hyperbands), "num_stopped": self._num_stopped } - def _unpause_trial(self, trial_runner, trial): + def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): trial_runner.trial_executor.unpause_trial(trial) @@ -332,7 +344,8 @@ class Bracket: Also keeps track of progress to ensure good scheduling. """ - def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s): + def __init__(self, time_attr: str, max_trials: int, init_t_attr: int, + max_t_attr: int, eta: float, s: int): self._live_trials = {} # maps trial -> current result self._all_trials = [] self._time_attr = time_attr # attribute to @@ -348,7 +361,7 @@ class Bracket: self._total_work = self._calculate_total_work(self._n0, self._r0, s) self._completed_progress = 0 - def add_trial(self, trial): + def add_trial(self, trial: Trial): """Add trial to bracket assuming bracket is not filled. At a later iteration, a newly added trial will be given equal @@ -357,7 +370,7 @@ class Bracket: self._live_trials[trial] = None self._all_trials.append(trial) - def cur_iter_done(self): + def cur_iter_done(self) -> bool: """Checks if all iterations have completed. TODO(rliaw): also check that `t.iterations == self._r`""" @@ -365,20 +378,20 @@ class Bracket: self._get_result_time(result) >= self._cumul_r for result in self._live_trials.values()) - def finished(self): + def finished(self) -> bool: return self._halves == 0 and self.cur_iter_done() - def current_trials(self): + def current_trials(self) -> List[Trial]: return list(self._live_trials) - def continue_trial(self, trial): + def continue_trial(self, trial: Trial) -> bool: result = self._live_trials[trial] if self._get_result_time(result) < self._cumul_r: return True else: return False - def filled(self): + def filled(self) -> bool: """Checks if bracket is filled. Only let new trials be added at current level minimizing the need @@ -386,7 +399,8 @@ class Bracket: return len(self._live_trials) == self._n - def successive_halving(self, metric, metric_op): + def successive_halving(self, metric: str, metric_op: float + ) -> Tuple[List[Trial], List[Trial]]: assert self._halves > 0 self._halves -= 1 self._n /= self._eta @@ -402,7 +416,7 @@ class Bracket: good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n] return good, bad - def update_trial_stats(self, trial, result): + def update_trial_stats(self, trial: Trial, result: Dict): """Update result for trial. Called after trial has finished an iteration - will decrement iteration count. @@ -422,7 +436,7 @@ class Bracket: self._completed_progress += delta self._live_trials[trial] = result - def cleanup_trial(self, trial): + def cleanup_trial(self, trial: Trial): """Clean up statistics tracking for terminated trials (either by force or otherwise). @@ -432,7 +446,7 @@ class Bracket: assert trial in self._live_trials del self._live_trials[trial] - def cleanup_full(self, trial_runner): + def cleanup_full(self, trial_runner: "trial_runner.TrialRunner"): """Cleans up bracket after bracket is completely finished. Lets the last trial continue to run until termination condition @@ -441,7 +455,7 @@ class Bracket: if (trial.status == Trial.PAUSED): trial_runner.stop_trial(trial) - def completion_percentage(self): + def completion_percentage(self) -> float: """Returns a progress metric. This will not be always finish with 100 since dead trials @@ -450,12 +464,12 @@ class Bracket: return 1.0 return self._completed_progress / self._total_work - def _get_result_time(self, result): + def _get_result_time(self, result: Dict) -> float: if result is None: return 0 return result[self._time_attr] - def _calculate_total_work(self, n, r, s): + def _calculate_total_work(self, n: int, r: float, s: int): work = 0 cumulative_r = r for _ in range(s + 1): @@ -466,7 +480,7 @@ class Bracket: r = int(min(r, self._max_t_attr - cumulative_r)) return work - def __repr__(self): + def __repr__(self) -> str: status = ", ".join([ "Max Size (n)={}".format(self._n), "Milestone (r)={}".format(self._cumul_r), diff --git a/python/ray/tune/schedulers/median_stopping_rule.py b/python/ray/tune/schedulers/median_stopping_rule.py index 497c62915..b446547ec 100644 --- a/python/ray/tune/schedulers/median_stopping_rule.py +++ b/python/ray/tune/schedulers/median_stopping_rule.py @@ -1,7 +1,10 @@ import collections import logging +from typing import Dict, List, Optional + import numpy as np +from ray.tune import trial_runner from ray.tune.trial import Trial from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler @@ -38,14 +41,14 @@ class MedianStoppingRule(FIFOScheduler): """ def __init__(self, - time_attr="time_total_s", - reward_attr=None, - metric=None, - mode=None, - grace_period=60.0, - min_samples_required=3, - min_time_slice=0, - hard_stop=True): + time_attr: str = "time_total_s", + reward_attr: Optional[str] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + grace_period: float = 60.0, + min_samples_required: int = 3, + min_time_slice: int = 0, + hard_stop: bool = True): if reward_attr is not None: mode = "max" metric = reward_attr @@ -75,7 +78,8 @@ class MedianStoppingRule(FIFOScheduler): self._last_pause = collections.defaultdict(lambda: float("-inf")) self._results = collections.defaultdict(list) - def set_search_properties(self, metric, mode): + def set_search_properties(self, metric: Optional[str], + mode: Optional[str]) -> bool: if self._metric and metric: return False if self._mode and mode: @@ -91,7 +95,8 @@ class MedianStoppingRule(FIFOScheduler): return True - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): if not self._metric or not self._worst or not self._compare_op: raise ValueError( "{} has been instantiated without a valid `metric` ({}) or " @@ -102,7 +107,8 @@ class MedianStoppingRule(FIFOScheduler): super(MedianStoppingRule, self).on_trial_add(trial_runner, trial) - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: """Callback for early stopping. This stopping rule stops a running trial if the trial's best objective @@ -154,14 +160,17 @@ class MedianStoppingRule(FIFOScheduler): else: return TrialScheduler.CONTINUE - def on_trial_complete(self, trial_runner, trial, result): + def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict): self._results[trial].append(result) - def debug_string(self): + def debug_string(self) -> str: return "Using MedianStoppingRule: num_stopped={}.".format( len(self._stopped_trials)) - def _on_insufficient_samples(self, trial_runner, trial, time): + def _on_insufficient_samples(self, + trial_runner: "trial_runner.TrialRunner", + trial: Trial, time: float) -> str: pause = time - self._last_pause[trial] > self._min_time_slice pause = pause and [ t for t in trial_runner.get_trials() @@ -169,17 +178,17 @@ class MedianStoppingRule(FIFOScheduler): ] return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE - def _trials_beyond_time(self, time): + def _trials_beyond_time(self, time: float) -> List[Trial]: trials = [ trial for trial in self._results if self._results[trial][-1][self._time_attr] >= time ] return trials - def _median_result(self, trials, time): + def _median_result(self, trials: List[Trial], time: float): return np.median([self._running_mean(trial, time) for trial in trials]) - def _running_mean(self, trial, time): + def _running_mean(self, trial: Trial, time: float) -> np.ndarray: results = self._results[trial] # TODO(ekl) we could do interpolation to be more precise, but for now # assume len(results) is large and the time diffs are roughly equal diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 6e6396097..200785bac 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -5,7 +5,10 @@ import math import os import random import shutil +from typing import Callable, Dict, List, Optional, Tuple, Union +from ray.tune import trial_runner +from ray.tune import trial_executor from ray.tune.error import TuneError from ray.tune.result import TRAINING_ITERATION from ray.tune.logger import _SafeFallbackEncoder @@ -13,7 +16,6 @@ from ray.tune.sample import Domain, Function from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars from ray.tune.trial import Trial, Checkpoint - from ray.util.debug import log_once logger = logging.getLogger(__name__) @@ -22,7 +24,7 @@ logger = logging.getLogger(__name__) class PBTTrialState: """Internal PBT state tracked per-trial.""" - def __init__(self, trial): + def __init__(self, trial: Trial): self.orig_tag = trial.experiment_tag self.last_score = None self.last_checkpoint = None @@ -30,12 +32,13 @@ class PBTTrialState: self.last_train_time = 0 # Used for synchronous mode. self.last_result = None # Used for synchronous mode. - def __repr__(self): + def __repr__(self) -> str: return str((self.last_score, self.last_checkpoint, self.last_train_time, self.last_perturbation_time)) -def explore(config, mutations, resample_probability, custom_explore_fn): +def explore(config: Dict, mutations: Dict, resample_probability: float, + custom_explore_fn: Optional[Callable]) -> Dict: """Return a config perturbed as specified. Args: @@ -83,7 +86,7 @@ def explore(config, mutations, resample_probability, custom_explore_fn): return new_config -def make_experiment_tag(orig_tag, config, mutations): +def make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str: """Appends perturbed params to the trial name to show in the console.""" resolved_vars = {} @@ -92,7 +95,8 @@ def make_experiment_tag(orig_tag, config, mutations): return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars)) -def fill_config(config, attr, search_space): +def fill_config(config: Dict, attr: str, + search_space: Union[Callable, Domain, list, dict]): """Add attr to config by sampling from search_space.""" if callable(search_space): config[attr] = search_space() @@ -214,18 +218,19 @@ class PopulationBasedTraining(FIFOScheduler): """ def __init__(self, - time_attr="time_total_s", - reward_attr=None, - metric=None, - mode=None, - perturbation_interval=60.0, - hyperparam_mutations={}, - quantile_fraction=0.25, - resample_probability=0.25, - custom_explore_fn=None, - log_config=True, - require_attrs=True, - synch=False): + time_attr: str = "time_total_s", + reward_attr: Optional[str] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + perturbation_interval: float = 60.0, + hyperparam_mutations: Dict = None, + quantile_fraction: float = 0.25, + resample_probability: float = 0.25, + custom_explore_fn: Optional[Callable] = None, + log_config: bool = True, + require_attrs: bool = True, + synch: bool = False): + hyperparam_mutations = hyperparam_mutations or {} for value in hyperparam_mutations.values(): if not (isinstance(value, (list, dict, Domain)) or callable(value)): @@ -288,7 +293,8 @@ class PopulationBasedTraining(FIFOScheduler): self._num_checkpoints = 0 self._num_perturbations = 0 - def set_search_properties(self, metric, mode): + def set_search_properties(self, metric: Optional[str], + mode: Optional[str]) -> bool: if self._metric and metric: return False if self._mode and mode: @@ -306,7 +312,8 @@ class PopulationBasedTraining(FIFOScheduler): return True - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): if not self._metric or not self._metric_op: raise ValueError( "{} has been instantiated without a valid `metric` ({}) or " @@ -329,7 +336,8 @@ class PopulationBasedTraining(FIFOScheduler): # Make sure this attribute is added to CLI output. trial.evaluated_params[attr] = trial.config[attr] - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: if self._time_attr not in result: time_missing_msg = "Cannot find time_attr {} " \ "in trial result {}. Make sure that this " \ @@ -425,8 +433,9 @@ class PopulationBasedTraining(FIFOScheduler): # the paused trials. return TrialScheduler.PAUSE - def _perturb_trial(self, trial, trial_runner, upper_quantile, - lower_quantile): + def _perturb_trial( + self, trial: Trial, trial_runner: "trial_runner.TrialRunner", + upper_quantile: List[Trial], lower_quantile: List[Trial]): """Checkpoint if in upper quantile, exploits if in lower.""" state = self._trial_state[trial] if trial in upper_quantile: @@ -450,8 +459,9 @@ class PopulationBasedTraining(FIFOScheduler): assert trial is not trial_to_clone self._exploit(trial_runner.trial_executor, trial, trial_to_clone) - def _log_config_on_step(self, trial_state, new_state, trial, - trial_to_clone, new_config): + def _log_config_on_step(self, trial_state: PBTTrialState, + new_state: PBTTrialState, trial: Trial, + trial_to_clone: Trial, new_config: Dict): """Logs transition during exploit/exploit step. For each step, logs: [target trial tag, clone trial tag, target trial @@ -482,7 +492,8 @@ class PopulationBasedTraining(FIFOScheduler): with open(trial_path, "a+") as f: f.write(json.dumps(policy, cls=_SafeFallbackEncoder) + "\n") - def _exploit(self, trial_executor, trial, trial_to_clone): + def _exploit(self, trial_executor: "trial_executor.TrialExecutor", + trial: Trial, trial_to_clone: Trial): """Transfers perturbed state from trial_to_clone -> trial. If specified, also logs the updated hyperparam state. @@ -554,7 +565,7 @@ class PopulationBasedTraining(FIFOScheduler): trial_state.last_perturbation_time = new_state.last_perturbation_time trial_state.last_train_time = new_state.last_train_time - def _quantiles(self): + def _quantiles(self) -> Tuple[List[Trial], List[Trial]]: """Returns trials in the lower and upper `quantile` of the population. If there is not enough data to compute this, returns empty lists. @@ -578,7 +589,8 @@ class PopulationBasedTraining(FIFOScheduler): return (trials[:num_trials_in_quantile], trials[-num_trials_in_quantile:]) - def choose_trial_to_run(self, trial_runner): + def choose_trial_to_run( + self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]: """Ensures all trials get fair share of time (as defined by time_attr). This enables the PBT scheduler to support a greater number of @@ -601,7 +613,7 @@ class PopulationBasedTraining(FIFOScheduler): self._num_perturbations = 0 self._num_checkpoints = 0 - def last_scores(self, trials): + def last_scores(self, trials: List[Trial]) -> List[float]: scores = [] for trial in trials: state = self._trial_state[trial] @@ -609,7 +621,7 @@ class PopulationBasedTraining(FIFOScheduler): scores.append(state.last_score) return scores - def debug_string(self): + def debug_string(self) -> str: return "PopulationBasedTraining: {} checkpoints, {} perturbs".format( self._num_checkpoints, self._num_perturbations) @@ -657,7 +669,7 @@ class PopulationBasedTrainingReplay(FIFOScheduler): """ - def __init__(self, policy_file): + def __init__(self, policy_file: str): policy_file = os.path.expanduser(policy_file) if not os.path.exists(policy_file): raise ValueError("Policy file not found: {}".format(policy_file)) @@ -679,7 +691,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler): self._policy_iter = iter(self._policy) self._next_policy = next(self._policy_iter, None) - def _load_policy(self, policy_file): + def _load_policy(self, + policy_file: str) -> Tuple[Dict, List[Tuple[int, Dict]]]: raw_policy = [] with open(policy_file, "rt") as fp: for row in fp.readlines(): @@ -708,7 +721,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler): return last_old_conf, list(reversed(policy)) - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): if self._trial: raise ValueError( "More than one trial added to PBT replay run. This " @@ -729,7 +743,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler): "or consider not using PBT replay for this run.") self._trial.config = self.config - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: if TRAINING_ITERATION not in result: # No time reported return TrialScheduler.CONTINUE @@ -775,6 +790,6 @@ class PopulationBasedTrainingReplay(FIFOScheduler): return TrialScheduler.CONTINUE - def debug_string(self): + def debug_string(self) -> str: return "PopulationBasedTraining replay: Step {}, perturb {}".format( self._current_step, self._num_perturbations) diff --git a/python/ray/tune/schedulers/trial_scheduler.py b/python/ray/tune/schedulers/trial_scheduler.py index 66ba25904..9b61287b7 100644 --- a/python/ray/tune/schedulers/trial_scheduler.py +++ b/python/ray/tune/schedulers/trial_scheduler.py @@ -1,3 +1,6 @@ +from typing import Dict, Optional + +from ray.tune import trial_runner from ray.tune.trial import Trial @@ -8,7 +11,8 @@ class TrialScheduler: PAUSE = "PAUSE" #: Status for pausing trial execution STOP = "STOP" #: Status for stopping trial execution - def set_search_properties(self, metric, mode): + def set_search_properties(self, metric: Optional[str], + mode: Optional[str]) -> bool: """Pass search properties to scheduler. This method acts as an alternative to instantiating schedulers @@ -20,19 +24,22 @@ class TrialScheduler: """ return True - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Called when a new trial is added to the trial runner.""" raise NotImplementedError - def on_trial_error(self, trial_runner, trial): + def on_trial_error(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Notification for the error of trial. This will only be called when the trial is in the RUNNING state.""" raise NotImplementedError - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: """Called on each intermediate result returned by a trial. At this point, the trial scheduler can make a decision by returning @@ -41,7 +48,8 @@ class TrialScheduler: raise NotImplementedError - def on_trial_complete(self, trial_runner, trial, result): + def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict): """Notification for the completion of trial. This will only be called when the trial is in the RUNNING state and @@ -49,7 +57,8 @@ class TrialScheduler: raise NotImplementedError - def on_trial_remove(self, trial_runner, trial): + def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): """Called to remove trial. This is called when the trial is in PAUSED or PENDING state. Otherwise, @@ -57,7 +66,8 @@ class TrialScheduler: raise NotImplementedError - def choose_trial_to_run(self, trial_runner): + def choose_trial_to_run( + self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]: """Called to choose a new trial to run. This should return one of the trials in trial_runner that is in @@ -67,7 +77,7 @@ class TrialScheduler: raise NotImplementedError - def debug_string(self): + def debug_string(self) -> str: """Returns a human readable message for printing to the console.""" raise NotImplementedError @@ -76,22 +86,28 @@ class TrialScheduler: class FIFOScheduler(TrialScheduler): """Simple scheduler that just runs trials in submission order.""" - def on_trial_add(self, trial_runner, trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): pass - def on_trial_error(self, trial_runner, trial): + def on_trial_error(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): pass - def on_trial_result(self, trial_runner, trial, result): + def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict) -> str: return TrialScheduler.CONTINUE - def on_trial_complete(self, trial_runner, trial, result): + def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial, result: Dict): pass - def on_trial_remove(self, trial_runner, trial): + def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", + trial: Trial): pass - def choose_trial_to_run(self, trial_runner): + def choose_trial_to_run( + self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]: for trial in trial_runner.get_trials(): if (trial.status == Trial.PENDING and trial_runner.has_resources(trial.resources)): @@ -102,5 +118,5 @@ class FIFOScheduler(TrialScheduler): return trial return None - def debug_string(self): + def debug_string(self) -> str: return "Using FIFO scheduling algorithm." diff --git a/python/ray/tune/suggest/_mock.py b/python/ray/tune/suggest/_mock.py index 5522d2f71..ce2d182cf 100644 --- a/python/ray/tune/suggest/_mock.py +++ b/python/ray/tune/suggest/_mock.py @@ -1,5 +1,8 @@ +from typing import Dict, List, Optional + from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter from ray.tune.suggest.search_generator import SearchGenerator +from ray.tune.trial import Trial class _MockSearcher(Searcher): @@ -11,29 +14,32 @@ class _MockSearcher(Searcher): self.results = [] super(_MockSearcher, self).__init__(**kwargs) - def suggest(self, trial_id): + def suggest(self, trial_id: str): if not self.stall: self.live_trials[trial_id] = 1 return {"test_variable": 2} return None - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): self.counter["result"] += 1 self.results += [result] - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): self.counter["complete"] += 1 if result: self._process_result(result) if trial_id in self.live_trials: del self.live_trials[trial_id] - def _process_result(self, result): + def _process_result(self, result: Dict): self.final_results += [result] class _MockSuggestionAlgorithm(SearchGenerator): - def __init__(self, max_concurrent=None, **kwargs): + def __init__(self, max_concurrent: Optional[int] = None, **kwargs): self.searcher = _MockSearcher(**kwargs) if max_concurrent: self.searcher = ConcurrencyLimiter( @@ -41,9 +47,9 @@ class _MockSuggestionAlgorithm(SearchGenerator): super(_MockSuggestionAlgorithm, self).__init__(self.searcher) @property - def live_trials(self): + def live_trials(self) -> List[Trial]: return self.searcher.live_trials @property - def results(self): + def results(self) -> List[Dict]: return self.searcher.results diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py index 28b52a9c6..b164fc311 100644 --- a/python/ray/tune/suggest/ax.py +++ b/python/ray/tune/suggest/ax.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List, Optional from ax.service.ax_client import AxClient from ray.tune.sample import Categorical, Float, Integer, LogUniform, \ @@ -103,14 +103,14 @@ class AxSearch(Searcher): """ def __init__(self, - space=None, - metric=None, - mode=None, - parameter_constraints=None, - outcome_constraints=None, - ax_client=None, - use_early_stopped_trials=None, - max_concurrent=None): + space: Optional[List[Dict]] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + parameter_constraints: Optional[List] = None, + outcome_constraints: Optional[List] = None, + ax_client: Optional[AxClient] = None, + use_early_stopped_trials: Optional[bool] = None, + max_concurrent: Optional[int] = None): assert ax is not None, "Ax must be installed!" if mode: assert mode in ["min", "max"], "`mode` must be 'min' or 'max'." @@ -177,7 +177,8 @@ class AxSearch(Searcher): logger.warning("Detected sequential enforcement. Be sure to use " "a ConcurrencyLimiter.") - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict): if self._ax: return False space = self.convert_search_space(config) @@ -189,7 +190,7 @@ class AxSearch(Searcher): self.setup_experiment() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._ax: raise RuntimeError( "Trying to sample a configuration from {}, but no search " diff --git a/python/ray/tune/suggest/basic_variant.py b/python/ray/tune/suggest/basic_variant.py index 2e8d12c8b..68cf7dcfe 100644 --- a/python/ray/tune/suggest/basic_variant.py +++ b/python/ray/tune/suggest/basic_variant.py @@ -2,9 +2,10 @@ import itertools import os import random import uuid +from typing import Dict, List, Union from ray.tune.error import TuneError -from ray.tune.experiment import convert_to_experiment_list +from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec from ray.tune.suggest.variant_generator import (generate_variants, format_vars, flatten_resolved_vars) @@ -42,7 +43,7 @@ class BasicVariantGenerator(SearchAlgorithm): searcher.is_finished == True """ - def __init__(self, shuffle=False): + def __init__(self, shuffle: bool = False): """Initializes the Variant Generator. """ @@ -60,7 +61,9 @@ class BasicVariantGenerator(SearchAlgorithm): else: self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_" - def add_configurations(self, experiments): + def add_configurations( + self, + experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]): """Chains generator given experiment specifications. Arguments: diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py index d5c7684c1..6a91ba60b 100644 --- a/python/ray/tune/suggest/bayesopt.py +++ b/python/ray/tune/suggest/bayesopt.py @@ -2,9 +2,10 @@ from collections import defaultdict import logging import pickle import json -from typing import Dict +from typing import Dict, Optional, Tuple -from ray.tune.sample import Float, Quantized +from ray.tune import ExperimentAnalysis +from ray.tune.sample import Domain, Float, Quantized from ray.tune.suggest.variant_generator import parse_spec_vars from ray.tune.utils.util import unflatten_dict @@ -100,18 +101,18 @@ class BayesOptSearch(Searcher): optimizer = None def __init__(self, - space=None, - metric=None, - mode=None, - utility_kwargs=None, - random_state=42, - random_search_steps=10, - verbose=0, - patience=5, - skip_duplicate=True, - analysis=None, - max_concurrent=None, - use_early_stopped_trials=None): + space: Optional[Dict] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + utility_kwargs: Optional[Dict] = None, + random_state: int = 42, + random_search_steps: int = 10, + verbose: int = 0, + patience: int = 5, + skip_duplicate: bool = True, + analysis: Optional[ExperimentAnalysis] = None, + max_concurrent: Optional[int] = None, + use_early_stopped_trials: Optional[bool] = None): """Instantiate new BayesOptSearch object. Args: @@ -200,7 +201,8 @@ class BayesOptSearch(Searcher): verbose=self._verbose, random_state=self._random_state) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self.optimizer: return False space = self.convert_search_space(config) @@ -218,7 +220,7 @@ class BayesOptSearch(Searcher): self.setup_optimizer() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: """Return new point to be explored by black box function. Args: @@ -278,7 +280,7 @@ class BayesOptSearch(Searcher): # Return a deep copy of the mapping return unflatten_dict(config) - def register_analysis(self, analysis): + def register_analysis(self, analysis: ExperimentAnalysis): """Integrate the given analysis into the gaussian process. Args: @@ -293,7 +295,10 @@ class BayesOptSearch(Searcher): # gaussian process optimizer self._register_result(params, report) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. Args: @@ -330,18 +335,18 @@ class BayesOptSearch(Searcher): for params, result in self._buffered_trial_results: self._register_result(params, result) - def _register_result(self, params, result): + def _register_result(self, params: Tuple[str], result: Dict): """Register given tuple of params and results.""" self.optimizer.register(params, self._metric_op * result[self.metric]) - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): """Storing current optimizer state.""" with open(checkpoint_path, "wb") as f: pickle.dump( (self.optimizer, self._buffered_trial_results, self._total_random_search_trials, self._config_counter), f) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): """Restoring current optimizer state.""" with open(checkpoint_path, "rb") as f: (self.optimizer, self._buffered_trial_results, @@ -349,7 +354,7 @@ class BayesOptSearch(Searcher): self._config_counter) = pickle.load(f) @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> Dict: spec = flatten_dict(spec, prevent_delimiter=True) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -358,7 +363,7 @@ class BayesOptSearch(Searcher): "Grid search parameters cannot be automatically converted " "to a BayesOpt search space.") - def resolve_value(domain): + def resolve_value(domain: Domain) -> Tuple[float, float]: sampler = domain.get_sampler() if isinstance(sampler, Quantized): logger.warning( diff --git a/python/ray/tune/suggest/bohb.py b/python/ray/tune/suggest/bohb.py index 318e582e0..bea916f1d 100644 --- a/python/ray/tune/suggest/bohb.py +++ b/python/ray/tune/suggest/bohb.py @@ -3,10 +3,11 @@ import copy import logging import math -from typing import Dict +from typing import Dict, Optional import ConfigSpace -from ray.tune.sample import Categorical, Float, Integer, LogUniform, Normal, \ +from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \ + Normal, \ Quantized, \ Uniform from ray.tune.suggest import Searcher @@ -20,7 +21,7 @@ logger = logging.getLogger(__name__) class _BOHBJobWrapper(): """Mock object for HpBandSter to process.""" - def __init__(self, loss, budget, config): + def __init__(self, loss: float, budget: float, config: Dict): self.result = {"loss": loss} self.kwargs = {"budget": budget, "config": config.copy()} self.exception = None @@ -92,11 +93,11 @@ class TuneBOHB(Searcher): """ def __init__(self, - space=None, - bohb_config=None, - max_concurrent=10, - metric=None, - mode=None): + space: Optional[ConfigSpace.ConfigurationSpace] = None, + bohb_config: Optional[Dict] = None, + max_concurrent: int = 10, + metric: Optional[str] = None, + mode: Optional[str] = None): from hpbandster.optimizers.config_generators.bohb import BOHB assert BOHB is not None, "HpBandSter must be installed!" if mode: @@ -126,7 +127,8 @@ class TuneBOHB(Searcher): bohb_config = self._bohb_config or {} self.bohber = BOHB(self._space, **bohb_config) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self._space: return False space = self.convert_search_space(config) @@ -140,7 +142,7 @@ class TuneBOHB(Searcher): self.setup_bohb() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._space: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -156,7 +158,7 @@ class TuneBOHB(Searcher): return unflatten_dict(config) return None - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): if trial_id not in self.paused: self.running.add(trial_id) if "hyperband_info" not in result: @@ -166,28 +168,31 @@ class TuneBOHB(Searcher): hbs_wrapper = self.to_wrapper(trial_id, result) self.bohber.new_result(hbs_wrapper) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): del self.trial_to_params[trial_id] if trial_id in self.paused: self.paused.remove(trial_id) if trial_id in self.running: self.running.remove(trial_id) - def to_wrapper(self, trial_id, result): + def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper: return _BOHBJobWrapper(self._metric_op * result[self.metric], result["hyperband_info"]["budget"], self.trial_to_params[trial_id]) - def on_pause(self, trial_id): + def on_pause(self, trial_id: str): self.paused.add(trial_id) self.running.remove(trial_id) - def on_unpause(self, trial_id): + def on_unpause(self, trial_id: str): self.paused.remove(trial_id) self.running.add(trial_id) @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> ConfigSpace.ConfigurationSpace: spec = flatten_dict(spec, prevent_delimiter=True) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -196,7 +201,8 @@ class TuneBOHB(Searcher): "Grid search parameters cannot be automatically converted " "to a TuneBOHB search space.") - def resolve_value(par, domain): + def resolve_value(par: str, domain: Domain + ) -> ConfigSpace.hyperparameters.Hyperparameter: quantize = None sampler = domain.get_sampler() diff --git a/python/ray/tune/suggest/dragonfly.py b/python/ray/tune/suggest/dragonfly.py index b2da186b0..5626affed 100644 --- a/python/ray/tune/suggest/dragonfly.py +++ b/python/ray/tune/suggest/dragonfly.py @@ -5,16 +5,18 @@ from __future__ import print_function import inspect import logging import pickle -from typing import Dict +from typing import Dict, List, Optional -from ray.tune.sample import Float, Quantized +from ray.tune.sample import Domain, Float, Quantized from ray.tune.suggest.variant_generator import parse_spec_vars from ray.tune.utils.util import flatten_dict try: # Python 3 only -- needed for lint test. import dragonfly + from dragonfly.opt.blackbox_optimiser import BlackboxOptimiser except ImportError: dragonfly = None + BlackboxOptimiser = None from ray.tune.suggest.suggestion import Searcher @@ -127,13 +129,13 @@ class DragonflySearch(Searcher): """ def __init__(self, - optimizer=None, - domain=None, - space=None, - metric=None, - mode=None, - points_to_evaluate=None, - evaluated_rewards=None, + optimizer: Optional[BlackboxOptimiser] = None, + domain: Optional[str] = None, + space: Optional[List[Dict]] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + points_to_evaluate: Optional[List[List]] = None, + evaluated_rewards: Optional[List] = None, **kwargs): assert dragonfly is not None, """dragonfly must be installed! You can install Dragonfly with the command: @@ -144,8 +146,6 @@ class DragonflySearch(Searcher): super(DragonflySearch, self).__init__( metric=metric, mode=mode, **kwargs) - from dragonfly.opt.blackbox_optimiser import BlackboxOptimiser - self._opt_arg = optimizer self._domain = domain self._space = space @@ -245,7 +245,8 @@ class DragonflySearch(Searcher): elif self._mode == "max": self._metric_op = 1. - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self._opt: return False space = self.convert_search_space(config) @@ -258,7 +259,7 @@ class DragonflySearch(Searcher): self.setup_dragonfly() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._opt: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -281,26 +282,29 @@ class DragonflySearch(Searcher): self._live_trial_mapping[trial_id] = suggested_config return {"point": suggested_config} - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Passes result to Dragonfly unless early terminated or errored.""" trial_info = self._live_trial_mapping.pop(trial_id) if result: self._opt.tell([(trial_info, self._metric_op * result[self._metric])]) - def save(self, checkpoint_dir): + def save(self, checkpoint_path: str): trials_object = (self._initial_points, self._opt) - with open(checkpoint_dir, "wb") as outputFile: + with open(checkpoint_path, "wb") as outputFile: pickle.dump(trials_object, outputFile) - def restore(self, checkpoint_dir): + def restore(self, checkpoint_dir: str): with open(checkpoint_dir, "rb") as inputFile: trials_object = pickle.load(inputFile) self._initial_points = trials_object[0] self._opt = trials_object[1] @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> List[Dict]: spec = flatten_dict(spec, prevent_delimiter=True) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -309,7 +313,7 @@ class DragonflySearch(Searcher): "Grid search parameters cannot be automatically converted " "to a Dragonfly search space.") - def resolve_value(par, domain): + def resolve_value(par: str, domain: Domain) -> Dict: sampler = domain.get_sampler() if isinstance(sampler, Quantized): logger.warning( diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index b097cc29f..f129ee129 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict, List, Optional import numpy as np import copy @@ -6,7 +6,8 @@ import logging from functools import partial import pickle -from ray.tune.sample import Categorical, Float, Integer, LogUniform, Normal, \ +from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \ + Normal, \ Quantized, \ Uniform from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars @@ -117,15 +118,15 @@ class HyperOptSearch(Searcher): def __init__( self, - space=None, - metric=None, - mode=None, - points_to_evaluate=None, - n_initial_points=20, - random_state_seed=None, - gamma=0.25, - max_concurrent=None, - use_early_stopped_trials=None, + space: Optional[Dict] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + points_to_evaluate: Optional[List[Dict]] = None, + n_initial_points: int = 20, + random_state_seed: Optional[int] = None, + gamma: float = 0.25, + max_concurrent: Optional[int] = None, + use_early_stopped_trials: Optional[bool] = None, ): assert hpo is not None, ( "HyperOpt must be installed! Run `pip install hyperopt`.") @@ -170,7 +171,8 @@ class HyperOptSearch(Searcher): if space: self.domain = hpo.Domain(lambda spc: spc, space) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self.domain: return False space = self.convert_search_space(config) @@ -188,7 +190,7 @@ class HyperOptSearch(Searcher): return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self.domain: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -226,7 +228,7 @@ class HyperOptSearch(Searcher): print_node_on_error=self.domain.rec_eval_print_node_on_error) return copy.deepcopy(suggested_config) - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): ho_trial = self._get_hyperopt_trial(trial_id) if ho_trial is None: return @@ -234,7 +236,10 @@ class HyperOptSearch(Searcher): ho_trial["book_time"] = now ho_trial["refresh_time"] = now - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. The result is internally negated when interacting with HyperOpt @@ -252,7 +257,7 @@ class HyperOptSearch(Searcher): self._process_result(trial_id, result) del self._live_trial_mapping[trial_id] - def _process_result(self, trial_id, result): + def _process_result(self, trial_id: str, result: Dict): ho_trial = self._get_hyperopt_trial(trial_id) if not ho_trial: return @@ -263,10 +268,10 @@ class HyperOptSearch(Searcher): ho_trial["result"] = hp_result self._hpopt_trials.refresh() - def _to_hyperopt_result(self, result): + def _to_hyperopt_result(self, result: Dict) -> Dict: return {"loss": self.metric_op * result[self.metric], "status": "ok"} - def _get_hyperopt_trial(self, trial_id): + def _get_hyperopt_trial(self, trial_id: str) -> Optional[Dict]: if trial_id not in self._live_trial_mapping: return hyperopt_tid = self._live_trial_mapping[trial_id][0] @@ -274,21 +279,21 @@ class HyperOptSearch(Searcher): t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid ][0] - def get_state(self): + def get_state(self) -> Dict: return { "hyperopt_trials": self._hpopt_trials, "rstate": self.rstate.get_state() } - def set_state(self, state): + def set_state(self, state: Dict): self._hpopt_trials = state["hyperopt_trials"] self.rstate.set_state(state["rstate"]) - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): with open(checkpoint_path, "wb") as outputFile: pickle.dump(self.get_state(), outputFile) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: trials_object = pickle.load(inputFile) @@ -299,19 +304,19 @@ class HyperOptSearch(Searcher): self.set_state(trials_object) @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> Dict: spec = copy.deepcopy(spec) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) if not domain_vars and not grid_vars: - return [] + return {} if grid_vars: raise ValueError( "Grid search parameters cannot be automatically converted " "to a HyperOpt search space.") - def resolve_value(par, domain): + def resolve_value(par: str, domain: Domain) -> Any: quantize = None sampler = domain.get_sampler() diff --git a/python/ray/tune/suggest/nevergrad.py b/python/ray/tune/suggest/nevergrad.py index bee20c814..5746882da 100644 --- a/python/ray/tune/suggest/nevergrad.py +++ b/python/ray/tune/suggest/nevergrad.py @@ -1,16 +1,23 @@ import logging import pickle -from typing import Dict +from typing import Dict, Optional, Union -from ray.tune.sample import Categorical, Float, Integer, LogUniform, Quantized +from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \ + Quantized from ray.tune.suggest.variant_generator import parse_spec_vars from ray.tune.utils import flatten_dict from ray.tune.utils.util import unflatten_dict try: import nevergrad as ng + from nevergrad.optimization import Optimizer + from nevergrad.optimization.base import ConfiguredOptimizer + Parameter = ng.p.Parameter except ImportError: ng = None + Optimizer = None + ConfiguredOptimizer = None + Parameter = None from ray.tune.suggest import Searcher @@ -85,11 +92,11 @@ class NevergradSearch(Searcher): """ def __init__(self, - optimizer=None, - space=None, - metric=None, - mode=None, - max_concurrent=None, + optimizer: Union[None, Optimizer, ConfiguredOptimizer] = None, + space: Optional[Parameter] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + max_concurrent: Optional[int] = None, **kwargs): assert ng is not None, "Nevergrad must be installed!" if mode: @@ -102,7 +109,7 @@ class NevergradSearch(Searcher): self._opt_factory = None self._nevergrad_opt = None - if isinstance(optimizer, ng.optimization.Optimizer): + if isinstance(optimizer, Optimizer): if space is not None or isinstance(space, list): raise ValueError( "If you pass a configured optimizer to Nevergrad, either " @@ -110,7 +117,7 @@ class NevergradSearch(Searcher): "parameter.") self._parameters = space self._nevergrad_opt = optimizer - elif isinstance(optimizer, ng.optimization.base.ConfiguredOptimizer): + elif isinstance(optimizer, ConfiguredOptimizer): self._opt_factory = optimizer self._parameters = None self._space = space @@ -155,7 +162,8 @@ class NevergradSearch(Searcher): raise ValueError("len(parameters_names) must match optimizer " "dimension for non-instrumented optimizers") - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self._nevergrad_opt or self._space: return False space = self.convert_search_space(config) @@ -169,7 +177,7 @@ class NevergradSearch(Searcher): self.setup_nevergrad() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._nevergrad_opt: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -192,7 +200,10 @@ class NevergradSearch(Searcher): else: return unflatten_dict(suggested_config.kwargs) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. The result is internally negated when interacting with Nevergrad @@ -204,24 +215,24 @@ class NevergradSearch(Searcher): self._live_trial_mapping.pop(trial_id) - def _process_result(self, trial_id, result): + def _process_result(self, trial_id: str, result: Dict): ng_trial_info = self._live_trial_mapping[trial_id] self._nevergrad_opt.tell(ng_trial_info, self._metric_op * result[self._metric]) - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): trials_object = (self._nevergrad_opt, self._parameters) with open(checkpoint_path, "wb") as outputFile: pickle.dump(trials_object, outputFile) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: trials_object = pickle.load(inputFile) self._nevergrad_opt = trials_object[0] self._parameters = trials_object[1] @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> Parameter: spec = flatten_dict(spec, prevent_delimiter=True) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -230,7 +241,7 @@ class NevergradSearch(Searcher): "Grid search parameters cannot be automatically converted " "to a Nevergrad search space.") - def resolve_value(domain): + def resolve_value(domain: Domain) -> Parameter: sampler = domain.get_sampler() if isinstance(sampler, Quantized): logger.warning("Nevergrad does not support quantization. " diff --git a/python/ray/tune/suggest/optuna.py b/python/ray/tune/suggest/optuna.py index ae3f1aadb..981ce1741 100644 --- a/python/ray/tune/suggest/optuna.py +++ b/python/ray/tune/suggest/optuna.py @@ -1,9 +1,9 @@ import logging import pickle -from typing import Dict +from typing import Dict, List, Optional, Tuple from ray.tune.result import TRAINING_ITERATION -from ray.tune.sample import Categorical, Float, Integer, LogUniform, \ +from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \ Quantized, Uniform from ray.tune.suggest.variant_generator import parse_spec_vars from ray.tune.utils import flatten_dict @@ -11,8 +11,10 @@ from ray.tune.utils.util import unflatten_dict try: import optuna as ot + from optuna.samplers import BaseSampler except ImportError: ot = None + BaseSampler = None from ray.tune.suggest import Searcher @@ -100,7 +102,11 @@ class OptunaSearch(Searcher): """ - def __init__(self, space=None, metric=None, mode=None, sampler=None): + def __init__(self, + space: Optional[List[Tuple]] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + sampler: Optional[BaseSampler] = None): assert ot is not None, ( "Optuna must be installed! Run `pip install optuna`.") super(OptunaSearch, self).__init__( @@ -113,7 +119,7 @@ class OptunaSearch(Searcher): self._study_name = "optuna" # Fixed study name for in-memory storage self._sampler = sampler or ot.samplers.TPESampler() - assert isinstance(self._sampler, ot.samplers.BaseSampler), \ + assert isinstance(self._sampler, BaseSampler), \ "You can only pass an instance of `optuna.samplers.BaseSampler` " \ "as a sampler to `OptunaSearcher`." @@ -125,7 +131,7 @@ class OptunaSearch(Searcher): if self._space: self.setup_study(mode) - def setup_study(self, mode): + def setup_study(self, mode: str): self._ot_study = ot.study.create_study( storage=self._storage, sampler=self._sampler, @@ -134,7 +140,8 @@ class OptunaSearch(Searcher): direction="minimize" if mode == "min" else "maximize", load_if_exists=True) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self._space: return False space = self.convert_search_space(config) @@ -146,7 +153,7 @@ class OptunaSearch(Searcher): self.setup_study(mode) return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._space: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -169,13 +176,16 @@ class OptunaSearch(Searcher): } return unflatten_dict(params) - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): metric = result[self.metric] step = result[TRAINING_ITERATION] ot_trial = self._ot_trials[trial_id] ot_trial.report(metric, step) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): ot_trial = self._ot_trials[trial_id] ot_trial_id = ot_trial._trial_id self._storage.set_trial_value(ot_trial_id, result.get( @@ -183,20 +193,20 @@ class OptunaSearch(Searcher): self._storage.set_trial_state(ot_trial_id, ot.trial.TrialState.COMPLETE) - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): save_object = (self._storage, self._pruner, self._sampler, self._ot_trials, self._ot_study) with open(checkpoint_path, "wb") as outputFile: pickle.dump(save_object, outputFile) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: save_object = pickle.load(inputFile) self._storage, self._pruner, self._sampler, \ self._ot_trials, self._ot_study = save_object @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> List[Tuple]: spec = flatten_dict(spec, prevent_delimiter=True) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -208,7 +218,7 @@ class OptunaSearch(Searcher): "Grid search parameters cannot be automatically converted " "to an Optuna search space.") - def resolve_value(par, domain): + def resolve_value(par: str, domain: Domain) -> Tuple: quantize = None sampler = domain.get_sampler() diff --git a/python/ray/tune/suggest/repeater.py b/python/ray/tune/suggest/repeater.py index 647b6cdaf..734713fe2 100644 --- a/python/ray/tune/suggest/repeater.py +++ b/python/ray/tune/suggest/repeater.py @@ -1,5 +1,7 @@ import copy import logging +from typing import Dict, List, Optional + import numpy as np from ray.tune.suggest.suggestion import Searcher @@ -10,7 +12,7 @@ TRIAL_INDEX = "__trial_index__" """str: A constant value representing the repeat index of the trial.""" -def _warn_num_samples(searcher, num_samples): +def _warn_num_samples(searcher: Searcher, num_samples: int): if isinstance(searcher, Repeater) and num_samples % searcher.repeat: logger.warning( "`num_samples` is now expected to be the total number of trials, " @@ -34,7 +36,10 @@ class _TrialGroup: """ - def __init__(self, primary_trial_id, config, max_trials=1): + def __init__(self, + primary_trial_id: str, + config: Dict, + max_trials: int = 1): assert type(config) is dict, ( "config is not a dict, got {}".format(config)) self.primary_trial_id = primary_trial_id @@ -42,27 +47,27 @@ class _TrialGroup: self._trials = {primary_trial_id: None} self.max_trials = max_trials - def add(self, trial_id): + def add(self, trial_id: str): assert len(self._trials) < self.max_trials self._trials.setdefault(trial_id, None) - def full(self): + def full(self) -> bool: return len(self._trials) == self.max_trials - def report(self, trial_id, score): + def report(self, trial_id: str, score: float): assert trial_id in self._trials if score is None: raise ValueError("Internal Error: Score cannot be None.") self._trials[trial_id] = score - def finished_reporting(self): + def finished_reporting(self) -> bool: return None not in self._trials.values() and len( self._trials) == self.max_trials - def scores(self): + def scores(self) -> List[Optional[float]]: return list(self._trials.values()) - def count(self): + def count(self) -> int: return len(self._trials) @@ -103,7 +108,10 @@ class Repeater(Searcher): """ - def __init__(self, searcher, repeat=1, set_index=True): + def __init__(self, + searcher: Searcher, + repeat: int = 1, + set_index: bool = True): self.searcher = searcher self.repeat = repeat self._set_index = set_index @@ -113,7 +121,7 @@ class Repeater(Searcher): super(Repeater, self).__init__( metric=self.searcher.metric, mode=self.searcher.mode) - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if self._current_group is None or self._current_group.full(): config = self.searcher.suggest(trial_id) if config is None: @@ -132,7 +140,10 @@ class Repeater(Searcher): self._trial_id_to_group[trial_id] = self._current_group return config - def on_trial_complete(self, trial_id, result=None, **kwargs): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + **kwargs): """Stores the score for and keeps track of a completed trial. Stores the metric of a trial as nan if any of the following conditions @@ -160,13 +171,14 @@ class Repeater(Searcher): result={self.searcher.metric: np.nanmean(scores)}, **kwargs) - def get_state(self): + def get_state(self) -> Dict: self_state = self.__dict__.copy() del self_state["searcher"] return self_state - def set_state(self, state): + def set_state(self, state: Dict): self.__dict__.update(state) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: return self.searcher.set_search_properties(metric, mode, config) diff --git a/python/ray/tune/suggest/search.py b/python/ray/tune/suggest/search.py index a878410a9..87d7da3e7 100644 --- a/python/ray/tune/suggest/search.py +++ b/python/ray/tune/suggest/search.py @@ -1,3 +1,9 @@ +from typing import Dict, List, Optional, Union + +from ray.tune.experiment import Experiment +from ray.tune.trial import Trial + + class SearchAlgorithm: """Interface of an event handler API for hyperparameter search. @@ -12,7 +18,8 @@ class SearchAlgorithm: """ _finished = False - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: """Pass search properties to search algorithm. This method acts as an alternative to instantiating search algorithms @@ -29,7 +36,9 @@ class SearchAlgorithm: """ return True - def add_configurations(self, experiments): + def add_configurations( + self, + experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]): """Tracks given experiment specifications. Arguments: @@ -37,7 +46,7 @@ class SearchAlgorithm: """ raise NotImplementedError - def next_trials(self): + def next_trials(self) -> List[Trial]: """Provides Trial objects to be queued into the TrialRunner. Returns: @@ -45,17 +54,21 @@ class SearchAlgorithm: """ raise NotImplementedError - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): """Called on each intermediate result returned by a trial. This will only be called when the trial is in the RUNNING state. Arguments: trial_id: Identifier for the trial. + result: Result dictionary. """ pass - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. Arguments: @@ -69,7 +82,7 @@ class SearchAlgorithm: """ pass - def is_finished(self): + def is_finished(self) -> bool: """Returns True if no trials left to be queued into TrialRunner. Can return True before all trials have finished executing. @@ -80,14 +93,14 @@ class SearchAlgorithm: """Marks the search algorithm as finished.""" self._finished = True - def has_checkpoint(self, dirpath): + def has_checkpoint(self, dirpath: str) -> bool: """Should return False if not restoring is not implemented.""" return False - def save_to_dir(self, dirpath, **kwargs): + def save_to_dir(self, dirpath: str, **kwargs): """Saves a search algorithm.""" pass - def restore_from_dir(self, dirpath): + def restore_from_dir(self, dirpath: str): """Restores a search algorithm along with its wrapped state.""" pass diff --git a/python/ray/tune/suggest/search_generator.py b/python/ray/tune/suggest/search_generator.py index c5990c63c..7edd26452 100644 --- a/python/ray/tune/suggest/search_generator.py +++ b/python/ray/tune/suggest/search_generator.py @@ -2,10 +2,11 @@ import os import copy import logging import glob +from typing import Dict, List, Optional, Union import ray.cloudpickle as cloudpickle from ray.tune.error import TuneError -from ray.tune.experiment import convert_to_experiment_list +from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec from ray.tune.suggest.search import SearchAlgorithm from ray.tune.suggest.suggestion import Searcher @@ -21,7 +22,7 @@ def _warn_on_repeater(searcher, total_samples): _warn_num_samples(searcher, total_samples) -def _atomic_save(state, checkpoint_dir, file_name): +def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str): """Atomically saves the object to the checkpoint directory This is automatically used by tune.run during a Tune job. @@ -34,7 +35,7 @@ def _atomic_save(state, checkpoint_dir, file_name): os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name)) -def _find_newest_ckpt(dirpath, pattern): +def _find_newest_ckpt(dirpath: str, pattern: str): """Returns path to most recently modified checkpoint.""" full_paths = glob.glob(os.path.join(dirpath, pattern)) if not full_paths: @@ -58,7 +59,7 @@ class SearchGenerator(SearchAlgorithm): """ CKPT_FILE_TMPL = "search_gen_state-{}.json" - def __init__(self, searcher): + def __init__(self, searcher: Searcher): assert issubclass( type(searcher), Searcher), ("Searcher should be subclassing Searcher.") @@ -69,10 +70,13 @@ class SearchGenerator(SearchAlgorithm): self._total_samples = None # int: total samples to evaluate. self._finished = False - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: return self.searcher.set_search_properties(metric, mode, config) - def add_configurations(self, experiments): + def add_configurations( + self, + experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]): """Registers experiment specifications. Arguments: @@ -91,7 +95,7 @@ class SearchGenerator(SearchAlgorithm): if "run" not in experiment_spec: raise TuneError("Must specify `run` in {}".format(experiment_spec)) - def next_trials(self): + def next_trials(self) -> List[Trial]: """Provides a batch of Trial objects to be queued into the TrialRunner. Returns: @@ -106,7 +110,8 @@ class SearchGenerator(SearchAlgorithm): trials.append(trial) return trials - def create_trial_if_possible(self, experiment_spec, output_path): + def create_trial_if_possible(self, experiment_spec: Dict, + output_path: str) -> Optional[Trial]: logger.debug("creating trial") trial_id = Trial.generate_id() suggested_config = self.searcher.suggest(trial_id) @@ -135,18 +140,21 @@ class SearchGenerator(SearchAlgorithm): trial_id=trial_id) return trial - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): """Notifies the underlying searcher.""" self.searcher.on_trial_result(trial_id, result) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): self.searcher.on_trial_complete( trial_id=trial_id, result=result, error=error) - def is_finished(self): + def is_finished(self) -> bool: return self._counter >= self._total_samples or self._finished - def get_state(self): + def get_state(self) -> Dict: return { "counter": self._counter, "total_samples": self._total_samples, @@ -154,17 +162,17 @@ class SearchGenerator(SearchAlgorithm): "experiment": self._experiment } - def set_state(self, state): + def set_state(self, state: Dict): self._counter = state["counter"] self._total_samples = state["total_samples"] self._finished = state["finished"] self._experiment = state["experiment"] - def has_checkpoint(self, dirpath): + def has_checkpoint(self, dirpath: str): return bool( _find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*"))) - def save_to_dir(self, dirpath, session_str): + def save_to_dir(self, dirpath: str, session_str: str): """Saves self + searcher to dir. Separates the "searcher" from its wrappers (concurrency, repeating). @@ -196,7 +204,7 @@ class SearchGenerator(SearchAlgorithm): _atomic_save(search_alg_state, dirpath, self.CKPT_FILE_TMPL.format(session_str)) - def restore_from_dir(self, dirpath): + def restore_from_dir(self, dirpath: str): """Restores self + searcher + search wrappers from dirpath.""" searcher = self.searcher diff --git a/python/ray/tune/suggest/sigopt.py b/python/ray/tune/suggest/sigopt.py index 9e226dce8..b3d0d2122 100644 --- a/python/ray/tune/suggest/sigopt.py +++ b/python/ray/tune/suggest/sigopt.py @@ -2,10 +2,14 @@ import copy import os import logging import pickle +from typing import Dict, List, Optional, Union + try: import sigopt as sgo + Connection = sgo.Connection except ImportError: sgo = None + Connection = None from ray.tune.suggest import Searcher @@ -122,16 +126,16 @@ class SigOptSearch(Searcher): } def __init__(self, - space=None, - name="Default Tune Experiment", - max_concurrent=1, - reward_attr=None, - connection=None, - experiment_id=None, - observation_budget=None, - project=None, - metric="episode_reward_mean", - mode="max", + space: List[Dict] = None, + name: str = "Default Tune Experiment", + max_concurrent: int = 1, + reward_attr: Optional[str] = None, + connection: Optional[Connection] = None, + experiment_id: Optional[str] = None, + observation_budget: Optional[int] = None, + project: Optional[str] = None, + metric: Union[None, str, List[str]] = "episode_reward_mean", + mode: Union[None, str, List[str]] = "max", **kwargs): assert (experiment_id is None) ^ (space is None), "space xor experiment_id must be set" @@ -178,7 +182,7 @@ class SigOptSearch(Searcher): super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs) - def suggest(self, trial_id): + def suggest(self, trial_id: str): if self._max_concurrent: if len(self._live_trial_mapping) >= self._max_concurrent: return None @@ -190,7 +194,10 @@ class SigOptSearch(Searcher): return copy.deepcopy(suggestion.assignments) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. If a trial fails, it will be reported as a failed Observation, telling @@ -214,7 +221,7 @@ class SigOptSearch(Searcher): del self._live_trial_mapping[trial_id] @staticmethod - def serialize_metric(metrics, modes): + def serialize_metric(metrics: List[str], modes: List[str]): """ Converts metrics to https://app.sigopt.com/docs/objects/metric """ @@ -224,7 +231,7 @@ class SigOptSearch(Searcher): dict(name=metric, **SigOptSearch.OBJECTIVE_MAP[mode].copy())) return serialized_metric - def serialize_result(self, result): + def serialize_result(self, result: Dict): """ Converts experiments results to https://app.sigopt.com/docs/objects/metric_evaluation @@ -244,12 +251,12 @@ class SigOptSearch(Searcher): values.append(value) return values - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): trials_object = (self.conn, self.experiment) with open(checkpoint_path, "wb") as outputFile: pickle.dump(trials_object, outputFile) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: trials_object = pickle.load(inputFile) self.conn = trials_object[0] diff --git a/python/ray/tune/suggest/skopt.py b/python/ray/tune/suggest/skopt.py index 67dec2bde..015974837 100644 --- a/python/ray/tune/suggest/skopt.py +++ b/python/ray/tune/suggest/skopt.py @@ -1,8 +1,8 @@ import logging import pickle -from typing import Dict +from typing import Dict, List, Optional, Tuple, Union -from ray.tune.sample import Categorical, Float, Integer, Quantized +from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized from ray.tune.suggest.variant_generator import parse_spec_vars from ray.tune.utils import flatten_dict from ray.tune.utils.util import unflatten_dict @@ -17,8 +17,9 @@ from ray.tune.suggest import Searcher logger = logging.getLogger(__name__) -def _validate_warmstart(parameter_names, points_to_evaluate, - evaluated_rewards): +def _validate_warmstart(parameter_names: List[str], + points_to_evaluate: List[List], + evaluated_rewards: List): if points_to_evaluate: if not isinstance(points_to_evaluate, list): raise TypeError( @@ -125,14 +126,14 @@ class SkOptSearch(Searcher): """ def __init__(self, - optimizer=None, - space=None, - metric=None, - mode=None, - points_to_evaluate=None, - evaluated_rewards=None, - max_concurrent=None, - use_early_stopped_trials=None): + optimizer: Optional[sko.optimizer.Optimizer] = None, + space: Union[List[str], Dict[str, Union[Tuple, List]]] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + points_to_evaluate: Optional[List[List]] = None, + evaluated_rewards: Optional[List] = None, + max_concurrent: Optional[int] = None, + use_early_stopped_trials: Optional[bool] = None): assert sko is not None, """skopt must be installed! You can install Skopt with the command: `pip install scikit-optimize`.""" @@ -162,7 +163,7 @@ class SkOptSearch(Searcher): "names.") self._parameter_names = space else: - self._parameter_names = space.keys() + self._parameter_names = list(space.keys()) self._parameter_ranges = space.values() self._points_to_evaluate = points_to_evaluate @@ -199,7 +200,8 @@ class SkOptSearch(Searcher): elif self._mode == "min": self._metric_op = 1. - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self._skopt_opt: return False space = self.convert_search_space(config) @@ -216,7 +218,7 @@ class SkOptSearch(Searcher): self.setup_skopt() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._skopt_opt: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -235,7 +237,10 @@ class SkOptSearch(Searcher): self._live_trial_mapping[trial_id] = suggested_config return unflatten_dict(dict(zip(self._parameters, suggested_config))) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. The result is internally negated when interacting with Skopt @@ -247,24 +252,24 @@ class SkOptSearch(Searcher): self._process_result(trial_id, result) self._live_trial_mapping.pop(trial_id) - def _process_result(self, trial_id, result): + def _process_result(self, trial_id: str, result: Dict): skopt_trial_info = self._live_trial_mapping[trial_id] self._skopt_opt.tell(skopt_trial_info, self._metric_op * result[self._metric]) - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): trials_object = (self._initial_points, self._skopt_opt) with open(checkpoint_path, "wb") as outputFile: pickle.dump(trials_object, outputFile) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: trials_object = pickle.load(inputFile) self._initial_points = trials_object[0] self._skopt_opt = trials_object[1] @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> Dict: spec = flatten_dict(spec, prevent_delimiter=True) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -273,7 +278,7 @@ class SkOptSearch(Searcher): "Grid search parameters cannot be automatically converted " "to a SkOpt search space.") - def resolve_value(domain): + def resolve_value(domain: Domain) -> Union[Tuple, List]: sampler = domain.get_sampler() if isinstance(sampler, Quantized): logger.warning("SkOpt search does not support quantization. " diff --git a/python/ray/tune/suggest/suggestion.py b/python/ray/tune/suggest/suggestion.py index 2a9793cee..363c3f915 100644 --- a/python/ray/tune/suggest/suggestion.py +++ b/python/ray/tune/suggest/suggestion.py @@ -2,6 +2,7 @@ import copy import glob import logging import os +from typing import Dict, Optional from ray.util.debug import log_once @@ -56,10 +57,10 @@ class Searcher: CKPT_FILE_TMPL = "searcher-state-{}.pkl" def __init__(self, - metric=None, - mode=None, - max_concurrent=None, - use_early_stopped_trials=None): + metric: Optional[str] = None, + mode: Optional[str] = None, + max_concurrent: Optional[int] = None, + use_early_stopped_trials: Optional[bool] = None): if use_early_stopped_trials is False: raise DeprecationWarning( "Early stopped trials are now always used. If this is a " @@ -90,7 +91,8 @@ class Searcher: else: raise ValueError("Mode most either be a list or string") - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: """Pass search properties to searcher. This method acts as an alternative to instantiating search algorithms @@ -106,7 +108,7 @@ class Searcher: """ return False - def on_trial_result(self, trial_id, result): + def on_trial_result(self, trial_id: str, result: Dict): """Optional notification for result during training. Note that by default, the result dict may include NaNs or @@ -124,7 +126,10 @@ class Searcher: """ pass - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial. Typically, this method is used for notifying the underlying @@ -143,7 +148,7 @@ class Searcher: """ raise NotImplementedError - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: """Queries the algorithm to retrieve the next set of parameters. Arguments: @@ -159,7 +164,7 @@ class Searcher: """ raise NotImplementedError - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): """Save state to path for this search algorithm. Args: @@ -190,7 +195,7 @@ class Searcher: """ raise NotImplementedError - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): """Restore state for this search algorithm @@ -213,13 +218,13 @@ class Searcher: """ raise NotImplementedError - def get_state(self): + def get_state(self) -> Dict: raise NotImplementedError - def set_state(self, state): + def set_state(self, state: Dict): raise NotImplementedError - def save_to_dir(self, checkpoint_dir, session_str="default"): + def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"): """Automatically saves the given searcher to the checkpoint_dir. This is automatically used by tune.run during a Tune job. @@ -246,7 +251,7 @@ class Searcher: os.path.join(checkpoint_dir, self.CKPT_FILE_TMPL.format(session_str))) - def restore_from_dir(self, checkpoint_dir): + def restore_from_dir(self, checkpoint_dir: str): """Restores the state of a searcher from a given checkpoint_dir. Typically, you should use this function to restore from an @@ -277,12 +282,12 @@ class Searcher: self.restore(most_recent_checkpoint) @property - def metric(self): + def metric(self) -> str: """The training result objective value attribute.""" return self._metric @property - def mode(self): + def mode(self) -> str: """Specifies if minimizing or maximizing the metric.""" return self._mode @@ -308,7 +313,10 @@ class ConcurrencyLimiter(Searcher): tune.run(trainable, search_alg=search_alg) """ - def __init__(self, searcher, max_concurrent, batch=False): + def __init__(self, + searcher: Searcher, + max_concurrent: int, + batch: bool = False): assert type(max_concurrent) is int and max_concurrent > 0 self.searcher = searcher self.max_concurrent = max_concurrent @@ -318,7 +326,7 @@ class ConcurrencyLimiter(Searcher): super(ConcurrencyLimiter, self).__init__( metric=self.searcher.metric, mode=self.searcher.mode) - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: assert trial_id not in self.live_trials, ( f"Trial ID {trial_id} must be unique: already found in set.") if len(self.live_trials) >= self.max_concurrent: @@ -333,7 +341,10 @@ class ConcurrencyLimiter(Searcher): self.live_trials.add(trial_id) return suggestion - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): if trial_id not in self.live_trials: return elif self.batch: @@ -353,19 +364,20 @@ class ConcurrencyLimiter(Searcher): trial_id, result=result, error=error) self.live_trials.remove(trial_id) - def get_state(self): + def get_state(self) -> Dict: state = self.__dict__.copy() del state["searcher"] return copy.deepcopy(state) - def set_state(self, state): + def set_state(self, state: Dict): self.__dict__.update(state) - def on_pause(self, trial_id): + def on_pause(self, trial_id: str): self.searcher.on_pause(trial_id) - def on_unpause(self, trial_id): + def on_unpause(self, trial_id: str): self.searcher.on_unpause(trial_id) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: return self.searcher.set_search_properties(metric, mode, config) diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index a539c7b2c..294dd3c74 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -1,5 +1,7 @@ import copy import logging +from typing import Any, Dict, Generator, List, Tuple + import numpy import random @@ -9,7 +11,8 @@ from ray.tune.sample import Categorical, Domain, Function logger = logging.getLogger(__name__) -def generate_variants(unresolved_spec): +def generate_variants( + unresolved_spec: Dict) -> Generator[Tuple[Dict, Dict], None, None]: """Generates variants from a spec (dict) with unresolved values. There are two types of unresolved values: @@ -45,7 +48,7 @@ def generate_variants(unresolved_spec): yield resolved_vars, spec -def grid_search(values): +def grid_search(values: List) -> Dict[str, List]: """Convenience method for specifying grid search over a value. Arguments: @@ -63,7 +66,7 @@ _STANDARD_IMPORTS = { _MAX_RESOLUTION_PASSES = 20 -def resolve_nested_dict(nested_dict): +def resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]: """Flattens a nested dict by joining keys into tuple of paths. Can then be passed into `format_vars`. @@ -78,7 +81,7 @@ def resolve_nested_dict(nested_dict): return res -def format_vars(resolved_vars): +def format_vars(resolved_vars: Dict) -> str: """Formats the resolved variable dict into a single string.""" out = [] for path, value in sorted(resolved_vars.items()): @@ -97,7 +100,7 @@ def format_vars(resolved_vars): return ",".join(out) -def flatten_resolved_vars(resolved_vars): +def flatten_resolved_vars(resolved_vars: Dict) -> Dict: """Formats the resolved variable dict into a mapping of (str -> value).""" flattened_resolved_vars_dict = {} for pieces, value in resolved_vars.items(): @@ -108,14 +111,15 @@ def flatten_resolved_vars(resolved_vars): return flattened_resolved_vars_dict -def _clean_value(value): +def _clean_value(value: Any) -> str: if isinstance(value, float): return "{:.5}".format(value) else: return str(value).replace("/", "_") -def parse_spec_vars(spec): +def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[ + Tuple, Any]], List[Tuple[Tuple, Any]]]: resolved, unresolved = _split_resolved_unresolved_values(spec) resolved_vars = list(resolved.items()) @@ -134,7 +138,7 @@ def parse_spec_vars(spec): return resolved_vars, domain_vars, grid_vars -def _generate_variants(spec): +def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]: spec = copy.deepcopy(spec) _, domain_vars, grid_vars = parse_spec_vars(spec) @@ -159,19 +163,20 @@ def _generate_variants(spec): yield resolved_vars, spec -def assign_value(spec, path, value): +def assign_value(spec: Dict, path: Tuple, value: Any): for k in path[:-1]: spec = spec[k] spec[path[-1]] = value -def _get_value(spec, path): +def _get_value(spec: Dict, path: Tuple) -> Any: for k in path: spec = spec[k] return spec -def _resolve_domain_vars(spec, domain_vars): +def _resolve_domain_vars(spec: Dict, + domain_vars: List[Tuple[Tuple, Domain]]) -> Dict: resolved = {} error = True num_passes = 0 @@ -197,7 +202,8 @@ def _resolve_domain_vars(spec, domain_vars): return resolved -def _grid_search_generator(unresolved_spec, grid_vars): +def _grid_search_generator(unresolved_spec: Dict, + grid_vars: List) -> Generator[Dict, None, None]: value_indices = [0] * len(grid_vars) def increment(i): @@ -225,12 +231,12 @@ def _grid_search_generator(unresolved_spec, grid_vars): break -def _is_resolved(v): +def _is_resolved(v) -> bool: resolved, _ = _try_resolve(v) return resolved -def _try_resolve(v): +def _try_resolve(v) -> Tuple[bool, Any]: if isinstance(v, Domain): # Domain to sample from return False, v @@ -249,7 +255,8 @@ def _try_resolve(v): return True, v -def _split_resolved_unresolved_values(spec): +def _split_resolved_unresolved_values( + spec: Dict) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]: resolved_vars = {} unresolved_vars = {} for k, v in spec.items(): @@ -278,11 +285,11 @@ def _split_resolved_unresolved_values(spec): return resolved_vars, unresolved_vars -def _unresolved_values(spec): +def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]: return _split_resolved_unresolved_values(spec)[1] -def has_unresolved_values(spec): +def has_unresolved_values(spec: Dict) -> bool: return True if _unresolved_values(spec) else False @@ -303,5 +310,5 @@ class _UnresolvedAccessGuard(dict): class RecursiveDependencyError(Exception): - def __init__(self, msg): + def __init__(self, msg: str): Exception.__init__(self, msg) diff --git a/python/ray/tune/suggest/zoopt.py b/python/ray/tune/suggest/zoopt.py index 8f3b24531..38351d6cd 100644 --- a/python/ray/tune/suggest/zoopt.py +++ b/python/ray/tune/suggest/zoopt.py @@ -1,9 +1,10 @@ import copy import logging -from typing import Dict +from typing import Dict, Optional, Tuple import ray.cloudpickle as pickle -from ray.tune.sample import Categorical, Float, Integer, Quantized, Uniform +from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized, \ + Uniform from ray.tune.suggest.variant_generator import parse_spec_vars from ray.tune.utils.util import unflatten_dict from zoopt import ValueType @@ -106,11 +107,11 @@ class ZOOptSearch(Searcher): optimizer = None def __init__(self, - algo="asracos", - budget=None, - dim_dict=None, - metric=None, - mode=None, + algo: str = "asracos", + budget: Optional[int] = None, + dim_dict: Optional[Dict] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, **kwargs): assert zoopt is not None, "Zoopt not found - please install zoopt." assert budget is not None, "`budget` should not be None!" @@ -154,7 +155,8 @@ class ZOOptSearch(Searcher): from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune self.optimizer = SRacosTune(dimension=dim, parameter=par) - def set_search_properties(self, metric, mode, config): + def set_search_properties(self, metric: Optional[str], mode: Optional[str], + config: Dict) -> bool: if self._dim_dict: return False space = self.convert_search_space(config) @@ -173,7 +175,7 @@ class ZOOptSearch(Searcher): self.setup_zoopt() return True - def suggest(self, trial_id): + def suggest(self, trial_id: str) -> Optional[Dict]: if not self._dim_dict or not self.optimizer: raise RuntimeError( "Trying to sample a configuration from {}, but no search " @@ -189,7 +191,10 @@ class ZOOptSearch(Searcher): self._live_trial_mapping[trial_id] = new_trial return unflatten_dict(new_trial) - def on_trial_complete(self, trial_id, result=None, error=False): + def on_trial_complete(self, + trial_id: str, + result: Optional[Dict] = None, + error: bool = False): """Notification for the completion of trial.""" if result: _solution = self.solution_dict[str(trial_id)] @@ -200,18 +205,18 @@ class ZOOptSearch(Searcher): del self._live_trial_mapping[trial_id] - def save(self, checkpoint_path): + def save(self, checkpoint_path: str): trials_object = self.optimizer with open(checkpoint_path, "wb") as output: pickle.dump(trials_object, output) - def restore(self, checkpoint_path): + def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as input: trials_object = pickle.load(input) self.optimizer = trials_object @staticmethod - def convert_search_space(spec: Dict): + def convert_search_space(spec: Dict) -> Dict[str, Tuple]: spec = copy.deepcopy(spec) resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec) @@ -223,7 +228,7 @@ class ZOOptSearch(Searcher): "Grid search parameters cannot be automatically converted " "to a ZOOpt search space.") - def resolve_value(domain): + def resolve_value(domain: Domain) -> Tuple: quantize = None sampler = domain.get_sampler()