[tune] added type hints (#10806)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke
2020-09-16 05:03:56 +01:00
committed by GitHub
parent 5e030db8a5
commit c9fafe7733
31 changed files with 709 additions and 472 deletions
+1
View File
@@ -40,6 +40,7 @@ MOCK_MODULES = [
"horovod",
"horovod.ray",
"kubernetes",
"mxnet",
"mxnet.model",
"psutil",
"ray._raylet",
+56 -29
View File
@@ -1,9 +1,9 @@
import json
import logging
import os
from typing import Dict
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.utils import flatten_dict
try:
@@ -37,7 +37,10 @@ class Analysis:
in the respective functions.
"""
def __init__(self, experiment_dir, default_metric=None, default_mode=None):
def __init__(self,
experiment_dir: str,
default_metric: Optional[str] = None,
default_mode: Optional[str] = None):
experiment_dir = os.path.expanduser(experiment_dir)
if not os.path.isdir(experiment_dir):
raise ValueError(
@@ -59,14 +62,14 @@ class Analysis:
else:
self.fetch_trial_dataframes()
def _validate_metric(self, metric):
def _validate_metric(self, metric: str) -> str:
if not metric and not self.default_metric:
raise ValueError(
"No `metric` has been passed and `default_metric` has "
"not been set. Please specify the `metric` parameter.")
return metric or self.default_metric
def _validate_mode(self, mode):
def _validate_mode(self, mode: str) -> str:
if not mode and not self.default_mode:
raise ValueError(
"No `mode` has been passed and `default_mode` has "
@@ -75,7 +78,9 @@ class Analysis:
raise ValueError("If set, `mode` has to be one of [min, max]")
return mode or self.default_mode
def dataframe(self, metric=None, mode=None):
def dataframe(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> DataFrame:
"""Returns a pandas.DataFrame object constructed from the trials.
Args:
@@ -97,7 +102,9 @@ class Analysis:
rows[path].update(logdir=path)
return pd.DataFrame(list(rows.values()))
def get_best_config(self, metric=None, mode=None):
def get_best_config(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Optional[Dict]:
"""Retrieve the best config corresponding to the trial.
Args:
@@ -122,7 +129,9 @@ class Analysis:
best_path = compare_op(rows, key=lambda k: rows[k][metric])
return all_configs[best_path]
def get_best_logdir(self, metric=None, mode=None):
def get_best_logdir(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Optional[str]:
"""Retrieve the logdir corresponding to the best trial.
Args:
@@ -148,7 +157,7 @@ class Analysis:
self._experiment_dir))
return None
def fetch_trial_dataframes(self):
def fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
fail_count = 0
for path in self._get_trial_paths():
try:
@@ -162,7 +171,7 @@ class Analysis:
"Couldn't read results from {} paths".format(fail_count))
return self.trial_dataframes
def get_all_configs(self, prefix=False):
def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
"""Returns a list of all configurations.
Args:
@@ -170,7 +179,8 @@ class Analysis:
and prepends `config/`.
Returns:
List[dict]: List of all configurations of trials,
Dict[str, Dict]: Dict of all configurations of trials, indexed by
their trial dir.
"""
fail_count = 0
for path in self._get_trial_paths():
@@ -189,7 +199,10 @@ class Analysis:
"Couldn't read config from {} paths".format(fail_count))
return self._configs
def get_trial_checkpoints_paths(self, trial, metric=None):
def get_trial_checkpoints_paths(self,
trial: Trial,
metric: Optional[str] = None
) -> List[Tuple[str, Number]]:
"""Gets paths and metrics of all persistent checkpoints of a trial.
Args:
@@ -215,11 +228,14 @@ class Analysis:
return path_metric_df[["chkpt_path", metric]].values.tolist()
elif isinstance(trial, Trial):
checkpoints = trial.checkpoint_manager.best_checkpoints()
return [[c.value, c.result[metric]] for c in checkpoints]
return [(c.value, c.result[metric]) for c in checkpoints]
else:
raise ValueError("trial should be a string or a Trial instance.")
def get_best_checkpoint(self, trial, metric=None, mode=None):
def get_best_checkpoint(self,
trial: Trial,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Optional[str]:
"""Gets best persistent checkpoint path of provided trial.
Args:
@@ -244,7 +260,9 @@ class Analysis:
else:
return min(checkpoint_paths, key=lambda x: x[1])[0]
def _retrieve_rows(self, metric=None, mode=None):
def _retrieve_rows(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Dict[str, Any]:
assert mode is None or mode in ["max", "min"]
rows = {}
for path, df in self.trial_dataframes.items():
@@ -264,7 +282,7 @@ class Analysis:
return rows
def _get_trial_paths(self):
def _get_trial_paths(self) -> List[str]:
_trial_paths = []
for trial_path, _, files in os.walk(self._experiment_dir):
if EXPR_PROGRESS_FILE in files:
@@ -276,7 +294,7 @@ class Analysis:
return _trial_paths
@property
def trial_dataframes(self):
def trial_dataframes(self) -> Dict[str, DataFrame]:
"""List of all dataframes of the trials."""
return self._trial_dataframes
@@ -306,10 +324,10 @@ class ExperimentAnalysis(Analysis):
"""
def __init__(self,
experiment_checkpoint_path,
trials=None,
default_metric=None,
default_mode=None):
experiment_checkpoint_path: str,
trials: Optional[List[Trial]] = None,
default_metric: Optional[str] = None,
default_mode: Optional[str] = None):
experiment_checkpoint_path = os.path.expanduser(
experiment_checkpoint_path)
if not os.path.isfile(experiment_checkpoint_path):
@@ -365,8 +383,8 @@ class ExperimentAnalysis(Analysis):
return self.get_best_config(self.default_metric, self.default_mode)
@property
def best_checkpoint(self) -> Checkpoint:
"""Get the checkpoint of the best trial of the experiment
def best_checkpoint(self) -> str:
"""Get the checkpoint path of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
@@ -471,7 +489,10 @@ class ExperimentAnalysis(Analysis):
],
index="trial_id")
def get_best_trial(self, metric=None, mode=None, scope="last"):
def get_best_trial(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last") -> Optional[Trial]:
"""Retrieve the best trial object.
Compares all trials' scores on ``metric``.
@@ -535,7 +556,10 @@ class ExperimentAnalysis(Analysis):
"parameter?")
return best_trial
def get_best_config(self, metric=None, mode=None, scope="last"):
def get_best_config(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last") -> Optional[Dict]:
"""Retrieve the best config corresponding to the trial.
Compares all trials' scores on `metric`.
@@ -562,7 +586,10 @@ class ExperimentAnalysis(Analysis):
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.config if best_trial else None
def get_best_logdir(self, metric=None, mode=None, scope="last"):
def get_best_logdir(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last") -> Optional[str]:
"""Retrieve the logdir corresponding to the best trial.
Compares all trials' scores on `metric`.
@@ -589,15 +616,15 @@ class ExperimentAnalysis(Analysis):
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.logdir if best_trial else None
def stats(self):
def stats(self) -> Dict:
"""Returns a dictionary of the statistics of the experiment."""
return self._experiment_state.get("stats")
def runner_data(self):
def runner_data(self) -> Dict:
"""Returns a dictionary of the TrialRunner data."""
return self._experiment_state.get("runner_data")
def _get_trial_paths(self):
def _get_trial_paths(self) -> List[str]:
"""Overwrites Analysis to only have trials of one experiment."""
if self.trials:
_trial_paths = [t.logdir for t in self.trials]
+18 -15
View File
@@ -1,5 +1,7 @@
import os
import logging
from typing import Callable, Dict, Type
from filelock import FileLock
import ray
@@ -15,11 +17,11 @@ from horovod.ray import RayExecutor
logger = logging.getLogger(__name__)
def get_rank():
def get_rank() -> str:
return os.environ["HOROVOD_RANK"]
def logger_creator(log_config, logdir):
def logger_creator(log_config: Dict, logdir: str) -> NoopLogger:
"""Simple NOOP logger for worker trainables."""
index = get_rank()
worker_dir = os.path.join(logdir, "worker_{}".format(index))
@@ -51,7 +53,7 @@ class _HorovodTrainable(tune.Trainable):
def num_workers(self):
return self._num_hosts * self._num_slots
def setup(self, config):
def setup(self, config: Dict):
trainable = wrap_function(self.__class__._function)
# We use a filelock here to ensure that the file-writing
# process is safe across different trainables.
@@ -82,7 +84,7 @@ class _HorovodTrainable(tune.Trainable):
"logger_creator": lambda cfg: logger_creator(cfg, logdir_)
})
def step(self):
def step(self) -> Dict:
if self._finished:
raise RuntimeError("Training has already finished.")
result = self.executor.execute(lambda w: w.step())[0]
@@ -90,14 +92,14 @@ class _HorovodTrainable(tune.Trainable):
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir: str) -> str:
# TODO: optimize if colocated
save_obj = self.executor.execute_single(lambda w: w.save_to_object())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir):
def load_checkpoint(self, checkpoint_dir: str):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
x_id = ray.put(checkpoint_obj)
return self.executor.execute(lambda w: w.restore_from_object(x_id))
@@ -107,13 +109,14 @@ class _HorovodTrainable(tune.Trainable):
self.executor.shutdown()
def DistributedTrainableCreator(func,
use_gpu=False,
num_hosts=1,
num_slots=1,
num_cpus_per_slot=1,
timeout_s=30,
replicate_pem=False):
def DistributedTrainableCreator(
func: Callable,
use_gpu: bool = False,
num_hosts: int = 1,
num_slots: int = 1,
num_cpus_per_slot: int = 1,
timeout_s: int = 30,
replicate_pem: bool = False) -> Type[_HorovodTrainable]:
"""Converts Horovod functions to be executable by Tune.
Requires horovod > 0.19 to work.
@@ -198,7 +201,7 @@ def DistributedTrainableCreator(func,
_timeout_s = timeout_s
@classmethod
def default_resource_request(cls, config):
def default_resource_request(cls, config: Dict):
extra_gpu = int(num_hosts * num_slots) * int(use_gpu)
extra_cpu = int(num_hosts * num_slots * num_cpus_per_slot)
@@ -216,7 +219,7 @@ def DistributedTrainableCreator(func,
# that force us to include mocks as part of the module.
def _train_simple(config):
def _train_simple(config: Dict):
import horovod.torch as hvd
hvd.init()
from ray import tune
+1 -1
View File
@@ -39,7 +39,7 @@ class TuneCallback(Callback):
on, self._allowed))
self._on = on
def _handle(self, logs: Dict):
def _handle(self, logs: Dict, when: str):
raise NotImplementedError
def on_batch_begin(self, batch, logs=None):
+16 -11
View File
@@ -1,3 +1,5 @@
from typing import Any, Optional, Tuple
import kubernetes
import subprocess
@@ -45,7 +47,10 @@ class KubernetesSyncer(NodeSyncer):
_namespace = "ray"
def __init__(self, local_dir, remote_dir, sync_client=None):
def __init__(self,
local_dir: str,
remote_dir: str,
sync_client: Optional[SyncClient] = None):
self.local_ip = services.get_node_ip_address()
self.local_node = self._get_kubernetes_node_by_ip(self.local_ip)
self.worker_ip = None
@@ -56,11 +61,11 @@ class KubernetesSyncer(NodeSyncer):
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
def set_worker_ip(self, worker_ip):
def set_worker_ip(self, worker_ip: str):
self.worker_ip = worker_ip
self.worker_node = self._get_kubernetes_node_by_ip(worker_ip)
def _get_kubernetes_node_by_ip(self, node_ip):
def _get_kubernetes_node_by_ip(self, node_ip: str) -> Optional[str]:
"""Return node name by internal or external IP"""
kubernetes.config.load_incluster_config()
api = kubernetes.client.CoreV1Api()
@@ -75,8 +80,8 @@ class KubernetesSyncer(NodeSyncer):
return None
@property
def _remote_path(self):
return (self.worker_node, self._remote_dir)
def _remote_path(self) -> Tuple[str, str]:
return self.worker_node, self._remote_dir
class KubernetesSyncClient(SyncClient):
@@ -95,12 +100,12 @@ class KubernetesSyncClient(SyncClient):
"""
def __init__(self, namespace, process_runner=subprocess):
def __init__(self, namespace: str, process_runner: Any = subprocess):
self.namespace = namespace
self._process_runner = process_runner
self._command_runners = {}
def _create_command_runner(self, node_id):
def _create_command_runner(self, node_id: str) -> KubernetesCommandRunner:
"""Create a command runner for one Kubernetes node"""
return KubernetesCommandRunner(
log_prefix="KubernetesSyncClient: {}:".format(node_id),
@@ -109,7 +114,7 @@ class KubernetesSyncClient(SyncClient):
auth_config=None,
process_runner=self._process_runner)
def _get_command_runner(self, node_id):
def _get_command_runner(self, node_id: str) -> KubernetesCommandRunner:
"""Create command runner if it doesn't exist"""
# Todo(krfricke): These cached runners are currently
# never cleaned up. They are cheap so this shouldn't
@@ -120,7 +125,7 @@ class KubernetesSyncClient(SyncClient):
self._command_runners[node_id] = command_runner
return self._command_runners[node_id]
def sync_up(self, source, target):
def sync_up(self, source: str, target: Tuple[str, str]) -> bool:
"""Here target is a tuple (target_node, target_dir)"""
target_node, target_dir = target
@@ -132,7 +137,7 @@ class KubernetesSyncClient(SyncClient):
command_runner.run_rsync_up(source, target_dir)
return True
def sync_down(self, source, target):
def sync_down(self, source: Tuple[str, str], target: str) -> bool:
"""Here source is a tuple (source_node, source_dir)"""
source_node, source_dir = source
@@ -144,7 +149,7 @@ class KubernetesSyncClient(SyncClient):
command_runner.run_rsync_down(source_dir, target)
return True
def delete(self, target):
def delete(self, target: str) -> bool:
"""No delete function because it is only used by
the KubernetesSyncer, which doesn't call delete."""
return True
+6 -3
View File
@@ -2,7 +2,9 @@ from typing import Dict, List, Union
from ray import tune
from mxnet.model import save_checkpoint
import mxnet
from mxnet.model import save_checkpoint, BatchEndParam
import numpy as np
import os
@@ -49,7 +51,7 @@ class TuneReportCallback(TuneCallback):
metrics = [metrics]
self._metrics = metrics
def __call__(self, param):
def __call__(self, param: BatchEndParam):
if not param.eval_metric:
return
if not self._metrics:
@@ -110,7 +112,8 @@ class TuneCheckpointCallback(TuneCallback):
self._filename = filename
self._frequency = frequency
def __call__(self, epoch, sym, arg, aux):
def __call__(self, epoch: int, sym: mxnet.symbol.Symbol,
arg: Dict[str, np.ndarray], aux: Dict[str, np.ndarray]):
if epoch % self._frequency != 0:
return
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
+22 -18
View File
@@ -5,6 +5,8 @@ import os
import logging
import shutil
import tempfile
from typing import Callable, Dict, Generator, Optional, Type
import torch
from datetime import timedelta
@@ -34,7 +36,7 @@ def enable_distributed_trainable():
_distributed_enabled = True
def logger_creator(log_config, logdir, rank):
def logger_creator(log_config: Dict, logdir: str, rank: int) -> NoopLogger:
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
os.makedirs(worker_dir, exist_ok=True)
return NoopLogger(log_config, worker_dir)
@@ -54,16 +56,16 @@ class _TorchTrainable(tune.Trainable):
__slots__ = ["workers", "_finished"]
@classmethod
def default_process_group_parameters(self):
def default_process_group_parameters(self) -> Dict:
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
@classmethod
def get_remote_worker_options(self):
def get_remote_worker_options(self) -> Dict[str, int]:
num_gpus = 1 if self._use_gpu else 0
num_cpus = int(self._num_cpus_per_worker or 1)
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
def setup(self, config):
def setup(self, config: Dict):
self._finished = False
num_workers = self._num_workers
logdir = self.logdir
@@ -103,7 +105,7 @@ class _TorchTrainable(tune.Trainable):
for rank, w in enumerate(self.workers)
])
def step(self):
def step(self) -> Dict:
if self._finished:
raise RuntimeError("Training has already finished.")
result = ray.get([w.step.remote() for w in self.workers])[0]
@@ -111,14 +113,14 @@ class _TorchTrainable(tune.Trainable):
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir: str) -> str:
# TODO: optimize if colocated
save_obj = ray.get(self.workers[0].save_to_object.remote())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir):
def load_checkpoint(self, checkpoint_dir: str):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
return ray.get(
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
@@ -127,12 +129,13 @@ class _TorchTrainable(tune.Trainable):
ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(func,
use_gpu=False,
num_workers=1,
num_cpus_per_worker=1,
backend="gloo",
timeout_s=NCCL_TIMEOUT_S):
def DistributedTrainableCreator(
func: Callable,
use_gpu: bool = False,
num_workers: int = 1,
num_cpus_per_worker: int = 1,
backend: str = "gloo",
timeout_s: int = NCCL_TIMEOUT_S) -> Type[_TorchTrainable]:
"""Creates a class that executes distributed training.
Similar to running `torch.distributed.launch`.
@@ -179,11 +182,11 @@ def DistributedTrainableCreator(func,
_num_cpus_per_worker = num_cpus_per_worker
@classmethod
def default_process_group_parameters(self):
def default_process_group_parameters(self) -> Dict:
return dict(timeout=timedelta(timeout_s), backend=backend)
@classmethod
def default_resource_request(cls, config):
def default_resource_request(cls, config: Dict) -> Resources:
num_workers_ = int(config.get("num_workers", num_workers))
num_cpus = int(
config.get("num_cpus_per_worker", num_cpus_per_worker))
@@ -199,7 +202,8 @@ def DistributedTrainableCreator(func,
@contextmanager
def distributed_checkpoint_dir(step, disable=False):
def distributed_checkpoint_dir(
step: int, disable: bool = False) -> Generator[str, None, None]:
"""ContextManager for creating a distributed checkpoint.
Only checkpoints a file on the "main" training actor, avoiding
@@ -236,7 +240,7 @@ def distributed_checkpoint_dir(step, disable=False):
shutil.rmtree(path)
def _train_check_global(config, checkpoint_dir=None):
def _train_check_global(config: Dict, checkpoint_dir: Optional[str] = None):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
assert is_distributed_trainable()
@@ -245,7 +249,7 @@ def _train_check_global(config, checkpoint_dir=None):
tune.report(is_distributed=True)
def _train_simple(config, checkpoint_dir=None):
def _train_simple(config: Dict, checkpoint_dir: Optional[str] = None):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
import torch.nn as nn
+8 -6
View File
@@ -2,6 +2,7 @@ import os
import pickle
from multiprocessing import Process, Queue
from numbers import Number
from typing import Callable, Dict, List, Tuple
import numpy as np
from ray import logger
@@ -70,7 +71,7 @@ def _clean_log(obj):
return fallback
def wandb_mixin(func):
def wandb_mixin(func: Callable):
"""wandb_mixin
Weights and biases (https://www.wandb.com/) is a tool for experiment
@@ -142,7 +143,7 @@ def wandb_mixin(func):
return func
def _set_api_key(wandb_config):
def _set_api_key(wandb_config: Dict):
"""Set WandB API key from `wandb_config`. Will pop the
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", ""))
@@ -176,7 +177,8 @@ class _WandbLoggingProcess(Process):
wandb logging instances locally.
"""
def __init__(self, queue, exclude, to_config, *args, **kwargs):
def __init__(self, queue: Queue, exclude: List[str], to_config: List[str],
*args, **kwargs):
super(_WandbLoggingProcess, self).__init__()
self.queue = queue
self._exclude = set(exclude)
@@ -195,7 +197,7 @@ class _WandbLoggingProcess(Process):
wandb.log(log)
wandb.join()
def _handle_result(self, result):
def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
config_update = result.get("config", {}).copy()
log = {}
flat_result = flatten_dict(result, delimiter="/")
@@ -377,7 +379,7 @@ class WandbLogger(Logger):
**wandb_init_kwargs)
self._wandb.start()
def on_result(self, result):
def on_result(self, result: Dict):
result = _clean_log(result)
self._queue.put(result)
@@ -389,7 +391,7 @@ class WandbLogger(Logger):
class WandbTrainableMixin:
_wandb = wandb
def __init__(self, config, *args, **kwargs):
def __init__(self, config: Dict, *args, **kwargs):
if not isinstance(self, Trainable):
raise ValueError(
"The `WandbTrainableMixin` can only be used as a mixin "
+29 -18
View File
@@ -1,7 +1,11 @@
import logging
from typing import Dict, Optional, Union
import numpy as np
from ray.tune import trial_runner
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.trial import Trial
logger = logging.getLogger(__name__)
@@ -36,14 +40,14 @@ class AsyncHyperBandScheduler(FIFOScheduler):
"""
def __init__(self,
time_attr="training_iteration",
reward_attr=None,
metric=None,
mode=None,
max_t=100,
grace_period=1,
reduction_factor=4,
brackets=1):
time_attr: str = "training_iteration",
reward_attr: Optional[str] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
max_t: int = 100,
grace_period: int = 1,
reduction_factor: float = 4,
brackets: int = 1):
assert max_t > 0, "Max (time_attr) not valid!"
assert max_t >= grace_period, "grace_period must be <= max_t!"
assert grace_period > 0, "grace_period must be positive!"
@@ -82,7 +86,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._metric_op = -1.
self._time_attr = time_attr
def set_search_properties(self, metric, mode):
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]) -> bool:
if self._metric and metric:
return False
if self._mode and mode:
@@ -100,7 +105,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
return True
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
if not self._metric or not self._metric_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
@@ -115,7 +121,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
idx = np.random.choice(len(self._brackets), p=normalized)
self._trial_info[trial.trial_id] = self._brackets[idx]
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
action = TrialScheduler.CONTINUE
if self._time_attr not in result or self._metric not in result:
return action
@@ -129,7 +136,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._num_stopped += 1
return action
def on_trial_complete(self, trial_runner, trial, result):
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict):
if self._time_attr not in result or self._metric not in result:
return
bracket = self._trial_info[trial.trial_id]
@@ -137,10 +145,11 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._metric_op * result[self._metric])
del self._trial_info[trial.trial_id]
def on_trial_remove(self, trial_runner, trial):
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
del self._trial_info[trial.trial_id]
def debug_string(self):
def debug_string(self) -> str:
out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
return out
@@ -161,19 +170,21 @@ class _Bracket():
>>> b.cutoff(b._rungs[3][1]) == 2.0
"""
def __init__(self, min_t, max_t, reduction_factor, s):
def __init__(self, min_t: int, max_t: int, reduction_factor: float,
s: int):
self.rf = reduction_factor
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
self._rungs = [(min_t * self.rf**(k + s), {})
for k in reversed(range(MAX_RUNGS))]
def cutoff(self, recorded):
def cutoff(self, recorded) -> Union[None, int, float, complex, np.ndarray]:
if not recorded:
return None
return np.nanpercentile(
list(recorded.values()), (1 - 1 / self.rf) * 100)
def on_result(self, trial, cur_iter, cur_rew):
def on_result(self, trial: Trial, cur_iter: int,
cur_rew: Optional[float]) -> str:
action = TrialScheduler.CONTINUE
for milestone, recorded in self._rungs:
if cur_iter < milestone or trial.trial_id in recorded:
@@ -190,7 +201,7 @@ class _Bracket():
break
return action
def debug_str(self):
def debug_str(self) -> str:
# TODO: fix up the output for this
iters = " | ".join([
"Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
+11 -4
View File
@@ -1,5 +1,7 @@
import logging
from typing import Dict, Optional
from ray.tune import trial_runner
from ray.tune.schedulers.trial_scheduler import TrialScheduler
from ray.tune.schedulers.hyperband import HyperBandScheduler, Bracket
from ray.tune.trial import Trial
@@ -22,7 +24,8 @@ class HyperBandForBOHB(HyperBandScheduler):
See ray.tune.schedulers.HyperBandScheduler for parameter docstring.
"""
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Adds new trial.
On a new trial add, if current bracket is not filled, add to current
@@ -67,7 +70,8 @@ class HyperBandForBOHB(HyperBandScheduler):
self._state["bracket"].add_trial(trial)
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
"""If bracket is finished, all trials will be stopped.
If a given trial finishes and bracket iteration is not done,
@@ -96,11 +100,14 @@ class HyperBandForBOHB(HyperBandScheduler):
action = self._process_bracket(trial_runner, bracket)
return action
def _unpause_trial(self, trial_runner, trial):
def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
trial_runner.trial_executor.unpause_trial(trial)
trial_runner._search_alg.searcher.on_unpause(trial.trial_id)
def choose_trial_to_run(self, trial_runner, allow_recurse=True):
def choose_trial_to_run(self,
trial_runner: "trial_runner.TrialRunner",
allow_recurse: bool = True) -> Optional[Trial]:
"""Fair scheduling within iteration by completion percentage.
List of trials not used since all trials are tracked as state
+47 -33
View File
@@ -1,7 +1,10 @@
import collections
from typing import Dict, List, Optional, Tuple
import numpy as np
import logging
from ray.tune import trial_runner
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.trial import Trial
from ray.tune.error import TuneError
@@ -74,12 +77,12 @@ class HyperBandScheduler(FIFOScheduler):
"""
def __init__(self,
time_attr="training_iteration",
reward_attr=None,
metric=None,
mode=None,
max_t=81,
reduction_factor=3):
time_attr: str = "training_iteration",
reward_attr: Optional[str] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
max_t: int = 81,
reduction_factor: float = 3):
assert max_t > 0, "Max (time_attr) not valid!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
@@ -118,7 +121,8 @@ class HyperBandScheduler(FIFOScheduler):
self._metric_op = -1.
self._time_attr = time_attr
def set_search_properties(self, metric, mode):
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]) -> bool:
if self._metric and metric:
return False
if self._mode and mode:
@@ -136,7 +140,8 @@ class HyperBandScheduler(FIFOScheduler):
return True
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Adds new trial.
On a new trial add, if current bracket is not filled,
@@ -179,7 +184,7 @@ class HyperBandScheduler(FIFOScheduler):
self._state["bracket"].add_trial(trial)
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
def _cur_band_filled(self):
def _cur_band_filled(self) -> bool:
"""Checks if the current band is filled.
The size of the current band should be equal to s_max_1"""
@@ -187,7 +192,8 @@ class HyperBandScheduler(FIFOScheduler):
cur_band = self._hyperbands[self._state["band_idx"]]
return len(cur_band) == self._s_max_1
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict):
"""If bracket is finished, all trials will be stopped.
If a given trial finishes and bracket iteration is not done,
@@ -211,7 +217,8 @@ class HyperBandScheduler(FIFOScheduler):
metric_val=result.get(self._time_attr)))
return action
def _process_bracket(self, trial_runner, bracket):
def _process_bracket(self, trial_runner: "trial_runner.TrialRunner",
bracket: "Bracket") -> str:
"""This is called whenever a trial makes progress.
When all live trials in the bracket have no more iterations left,
@@ -250,7 +257,8 @@ class HyperBandScheduler(FIFOScheduler):
action = TrialScheduler.CONTINUE
return action
def on_trial_remove(self, trial_runner, trial):
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Notification when trial terminates.
Trial info is removed from bracket. Triggers halving if bracket is
@@ -260,15 +268,18 @@ class HyperBandScheduler(FIFOScheduler):
if not bracket.finished():
self._process_bracket(trial_runner, bracket)
def on_trial_complete(self, trial_runner, trial, result):
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict):
"""Cleans up trial info from bracket if trial completed early."""
self.on_trial_remove(trial_runner, trial)
def on_trial_error(self, trial_runner, trial):
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Cleans up trial info from bracket if trial errored early."""
self.on_trial_remove(trial_runner, trial)
def choose_trial_to_run(self, trial_runner):
def choose_trial_to_run(
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
"""Fair scheduling within iteration by completion percentage.
List of trials not used since all trials are tracked as state
@@ -288,7 +299,7 @@ class HyperBandScheduler(FIFOScheduler):
return trial
return None
def debug_string(self):
def debug_string(self) -> str:
"""This provides a progress notification for the algorithm.
For each bracket, the algorithm will output a string as follows:
@@ -315,13 +326,14 @@ class HyperBandScheduler(FIFOScheduler):
out += "\n {}".format(bracket)
return out
def state(self):
def state(self) -> Dict[str, int]:
return {
"num_brackets": sum(len(band) for band in self._hyperbands),
"num_stopped": self._num_stopped
}
def _unpause_trial(self, trial_runner, trial):
def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
trial_runner.trial_executor.unpause_trial(trial)
@@ -332,7 +344,8 @@ class Bracket:
Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
def __init__(self, time_attr: str, max_trials: int, init_t_attr: int,
max_t_attr: int, eta: float, s: int):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
self._time_attr = time_attr # attribute to
@@ -348,7 +361,7 @@ class Bracket:
self._total_work = self._calculate_total_work(self._n0, self._r0, s)
self._completed_progress = 0
def add_trial(self, trial):
def add_trial(self, trial: Trial):
"""Add trial to bracket assuming bracket is not filled.
At a later iteration, a newly added trial will be given equal
@@ -357,7 +370,7 @@ class Bracket:
self._live_trials[trial] = None
self._all_trials.append(trial)
def cur_iter_done(self):
def cur_iter_done(self) -> bool:
"""Checks if all iterations have completed.
TODO(rliaw): also check that `t.iterations == self._r`"""
@@ -365,20 +378,20 @@ class Bracket:
self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
def finished(self):
def finished(self) -> bool:
return self._halves == 0 and self.cur_iter_done()
def current_trials(self):
def current_trials(self) -> List[Trial]:
return list(self._live_trials)
def continue_trial(self, trial):
def continue_trial(self, trial: Trial) -> bool:
result = self._live_trials[trial]
if self._get_result_time(result) < self._cumul_r:
return True
else:
return False
def filled(self):
def filled(self) -> bool:
"""Checks if bracket is filled.
Only let new trials be added at current level minimizing the need
@@ -386,7 +399,8 @@ class Bracket:
return len(self._live_trials) == self._n
def successive_halving(self, metric, metric_op):
def successive_halving(self, metric: str, metric_op: float
) -> Tuple[List[Trial], List[Trial]]:
assert self._halves > 0
self._halves -= 1
self._n /= self._eta
@@ -402,7 +416,7 @@ class Bracket:
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
return good, bad
def update_trial_stats(self, trial, result):
def update_trial_stats(self, trial: Trial, result: Dict):
"""Update result for trial. Called after trial has finished
an iteration - will decrement iteration count.
@@ -422,7 +436,7 @@ class Bracket:
self._completed_progress += delta
self._live_trials[trial] = result
def cleanup_trial(self, trial):
def cleanup_trial(self, trial: Trial):
"""Clean up statistics tracking for terminated trials (either by force
or otherwise).
@@ -432,7 +446,7 @@ class Bracket:
assert trial in self._live_trials
del self._live_trials[trial]
def cleanup_full(self, trial_runner):
def cleanup_full(self, trial_runner: "trial_runner.TrialRunner"):
"""Cleans up bracket after bracket is completely finished.
Lets the last trial continue to run until termination condition
@@ -441,7 +455,7 @@ class Bracket:
if (trial.status == Trial.PAUSED):
trial_runner.stop_trial(trial)
def completion_percentage(self):
def completion_percentage(self) -> float:
"""Returns a progress metric.
This will not be always finish with 100 since dead trials
@@ -450,12 +464,12 @@ class Bracket:
return 1.0
return self._completed_progress / self._total_work
def _get_result_time(self, result):
def _get_result_time(self, result: Dict) -> float:
if result is None:
return 0
return result[self._time_attr]
def _calculate_total_work(self, n, r, s):
def _calculate_total_work(self, n: int, r: float, s: int):
work = 0
cumulative_r = r
for _ in range(s + 1):
@@ -466,7 +480,7 @@ class Bracket:
r = int(min(r, self._max_t_attr - cumulative_r))
return work
def __repr__(self):
def __repr__(self) -> str:
status = ", ".join([
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._cumul_r),
@@ -1,7 +1,10 @@
import collections
import logging
from typing import Dict, List, Optional
import numpy as np
from ray.tune import trial_runner
from ray.tune.trial import Trial
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
@@ -38,14 +41,14 @@ class MedianStoppingRule(FIFOScheduler):
"""
def __init__(self,
time_attr="time_total_s",
reward_attr=None,
metric=None,
mode=None,
grace_period=60.0,
min_samples_required=3,
min_time_slice=0,
hard_stop=True):
time_attr: str = "time_total_s",
reward_attr: Optional[str] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
grace_period: float = 60.0,
min_samples_required: int = 3,
min_time_slice: int = 0,
hard_stop: bool = True):
if reward_attr is not None:
mode = "max"
metric = reward_attr
@@ -75,7 +78,8 @@ class MedianStoppingRule(FIFOScheduler):
self._last_pause = collections.defaultdict(lambda: float("-inf"))
self._results = collections.defaultdict(list)
def set_search_properties(self, metric, mode):
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]) -> bool:
if self._metric and metric:
return False
if self._mode and mode:
@@ -91,7 +95,8 @@ class MedianStoppingRule(FIFOScheduler):
return True
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
if not self._metric or not self._worst or not self._compare_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
@@ -102,7 +107,8 @@ class MedianStoppingRule(FIFOScheduler):
super(MedianStoppingRule, self).on_trial_add(trial_runner, trial)
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
"""Callback for early stopping.
This stopping rule stops a running trial if the trial's best objective
@@ -154,14 +160,17 @@ class MedianStoppingRule(FIFOScheduler):
else:
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict):
self._results[trial].append(result)
def debug_string(self):
def debug_string(self) -> str:
return "Using MedianStoppingRule: num_stopped={}.".format(
len(self._stopped_trials))
def _on_insufficient_samples(self, trial_runner, trial, time):
def _on_insufficient_samples(self,
trial_runner: "trial_runner.TrialRunner",
trial: Trial, time: float) -> str:
pause = time - self._last_pause[trial] > self._min_time_slice
pause = pause and [
t for t in trial_runner.get_trials()
@@ -169,17 +178,17 @@ class MedianStoppingRule(FIFOScheduler):
]
return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE
def _trials_beyond_time(self, time):
def _trials_beyond_time(self, time: float) -> List[Trial]:
trials = [
trial for trial in self._results
if self._results[trial][-1][self._time_attr] >= time
]
return trials
def _median_result(self, trials, time):
def _median_result(self, trials: List[Trial], time: float):
return np.median([self._running_mean(trial, time) for trial in trials])
def _running_mean(self, trial, time):
def _running_mean(self, trial: Trial, time: float) -> np.ndarray:
results = self._results[trial]
# TODO(ekl) we could do interpolation to be more precise, but for now
# assume len(results) is large and the time diffs are roughly equal
+50 -35
View File
@@ -5,7 +5,10 @@ import math
import os
import random
import shutil
from typing import Callable, Dict, List, Optional, Tuple, Union
from ray.tune import trial_runner
from ray.tune import trial_executor
from ray.tune.error import TuneError
from ray.tune.result import TRAINING_ITERATION
from ray.tune.logger import _SafeFallbackEncoder
@@ -13,7 +16,6 @@ from ray.tune.sample import Domain, Function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest.variant_generator import format_vars
from ray.tune.trial import Trial, Checkpoint
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@@ -22,7 +24,7 @@ logger = logging.getLogger(__name__)
class PBTTrialState:
"""Internal PBT state tracked per-trial."""
def __init__(self, trial):
def __init__(self, trial: Trial):
self.orig_tag = trial.experiment_tag
self.last_score = None
self.last_checkpoint = None
@@ -30,12 +32,13 @@ class PBTTrialState:
self.last_train_time = 0 # Used for synchronous mode.
self.last_result = None # Used for synchronous mode.
def __repr__(self):
def __repr__(self) -> str:
return str((self.last_score, self.last_checkpoint,
self.last_train_time, self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
def explore(config: Dict, mutations: Dict, resample_probability: float,
custom_explore_fn: Optional[Callable]) -> Dict:
"""Return a config perturbed as specified.
Args:
@@ -83,7 +86,7 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
return new_config
def make_experiment_tag(orig_tag, config, mutations):
def make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str:
"""Appends perturbed params to the trial name to show in the console."""
resolved_vars = {}
@@ -92,7 +95,8 @@ def make_experiment_tag(orig_tag, config, mutations):
return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))
def fill_config(config, attr, search_space):
def fill_config(config: Dict, attr: str,
search_space: Union[Callable, Domain, list, dict]):
"""Add attr to config by sampling from search_space."""
if callable(search_space):
config[attr] = search_space()
@@ -214,18 +218,19 @@ class PopulationBasedTraining(FIFOScheduler):
"""
def __init__(self,
time_attr="time_total_s",
reward_attr=None,
metric=None,
mode=None,
perturbation_interval=60.0,
hyperparam_mutations={},
quantile_fraction=0.25,
resample_probability=0.25,
custom_explore_fn=None,
log_config=True,
require_attrs=True,
synch=False):
time_attr: str = "time_total_s",
reward_attr: Optional[str] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
perturbation_interval: float = 60.0,
hyperparam_mutations: Dict = None,
quantile_fraction: float = 0.25,
resample_probability: float = 0.25,
custom_explore_fn: Optional[Callable] = None,
log_config: bool = True,
require_attrs: bool = True,
synch: bool = False):
hyperparam_mutations = hyperparam_mutations or {}
for value in hyperparam_mutations.values():
if not (isinstance(value,
(list, dict, Domain)) or callable(value)):
@@ -288,7 +293,8 @@ class PopulationBasedTraining(FIFOScheduler):
self._num_checkpoints = 0
self._num_perturbations = 0
def set_search_properties(self, metric, mode):
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]) -> bool:
if self._metric and metric:
return False
if self._mode and mode:
@@ -306,7 +312,8 @@ class PopulationBasedTraining(FIFOScheduler):
return True
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
if not self._metric or not self._metric_op:
raise ValueError(
"{} has been instantiated without a valid `metric` ({}) or "
@@ -329,7 +336,8 @@ class PopulationBasedTraining(FIFOScheduler):
# Make sure this attribute is added to CLI output.
trial.evaluated_params[attr] = trial.config[attr]
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
if self._time_attr not in result:
time_missing_msg = "Cannot find time_attr {} " \
"in trial result {}. Make sure that this " \
@@ -425,8 +433,9 @@ class PopulationBasedTraining(FIFOScheduler):
# the paused trials.
return TrialScheduler.PAUSE
def _perturb_trial(self, trial, trial_runner, upper_quantile,
lower_quantile):
def _perturb_trial(
self, trial: Trial, trial_runner: "trial_runner.TrialRunner",
upper_quantile: List[Trial], lower_quantile: List[Trial]):
"""Checkpoint if in upper quantile, exploits if in lower."""
state = self._trial_state[trial]
if trial in upper_quantile:
@@ -450,8 +459,9 @@ class PopulationBasedTraining(FIFOScheduler):
assert trial is not trial_to_clone
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
def _log_config_on_step(self, trial_state, new_state, trial,
trial_to_clone, new_config):
def _log_config_on_step(self, trial_state: PBTTrialState,
new_state: PBTTrialState, trial: Trial,
trial_to_clone: Trial, new_config: Dict):
"""Logs transition during exploit/exploit step.
For each step, logs: [target trial tag, clone trial tag, target trial
@@ -482,7 +492,8 @@ class PopulationBasedTraining(FIFOScheduler):
with open(trial_path, "a+") as f:
f.write(json.dumps(policy, cls=_SafeFallbackEncoder) + "\n")
def _exploit(self, trial_executor, trial, trial_to_clone):
def _exploit(self, trial_executor: "trial_executor.TrialExecutor",
trial: Trial, trial_to_clone: Trial):
"""Transfers perturbed state from trial_to_clone -> trial.
If specified, also logs the updated hyperparam state.
@@ -554,7 +565,7 @@ class PopulationBasedTraining(FIFOScheduler):
trial_state.last_perturbation_time = new_state.last_perturbation_time
trial_state.last_train_time = new_state.last_train_time
def _quantiles(self):
def _quantiles(self) -> Tuple[List[Trial], List[Trial]]:
"""Returns trials in the lower and upper `quantile` of the population.
If there is not enough data to compute this, returns empty lists.
@@ -578,7 +589,8 @@ class PopulationBasedTraining(FIFOScheduler):
return (trials[:num_trials_in_quantile],
trials[-num_trials_in_quantile:])
def choose_trial_to_run(self, trial_runner):
def choose_trial_to_run(
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
"""Ensures all trials get fair share of time (as defined by time_attr).
This enables the PBT scheduler to support a greater number of
@@ -601,7 +613,7 @@ class PopulationBasedTraining(FIFOScheduler):
self._num_perturbations = 0
self._num_checkpoints = 0
def last_scores(self, trials):
def last_scores(self, trials: List[Trial]) -> List[float]:
scores = []
for trial in trials:
state = self._trial_state[trial]
@@ -609,7 +621,7 @@ class PopulationBasedTraining(FIFOScheduler):
scores.append(state.last_score)
return scores
def debug_string(self):
def debug_string(self) -> str:
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
self._num_checkpoints, self._num_perturbations)
@@ -657,7 +669,7 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
"""
def __init__(self, policy_file):
def __init__(self, policy_file: str):
policy_file = os.path.expanduser(policy_file)
if not os.path.exists(policy_file):
raise ValueError("Policy file not found: {}".format(policy_file))
@@ -679,7 +691,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
self._policy_iter = iter(self._policy)
self._next_policy = next(self._policy_iter, None)
def _load_policy(self, policy_file):
def _load_policy(self,
policy_file: str) -> Tuple[Dict, List[Tuple[int, Dict]]]:
raw_policy = []
with open(policy_file, "rt") as fp:
for row in fp.readlines():
@@ -708,7 +721,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
return last_old_conf, list(reversed(policy))
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
if self._trial:
raise ValueError(
"More than one trial added to PBT replay run. This "
@@ -729,7 +743,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
"or consider not using PBT replay for this run.")
self._trial.config = self.config
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
if TRAINING_ITERATION not in result:
# No time reported
return TrialScheduler.CONTINUE
@@ -775,6 +790,6 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
return TrialScheduler.CONTINUE
def debug_string(self):
def debug_string(self) -> str:
return "PopulationBasedTraining replay: Step {}, perturb {}".format(
self._current_step, self._num_perturbations)
+31 -15
View File
@@ -1,3 +1,6 @@
from typing import Dict, Optional
from ray.tune import trial_runner
from ray.tune.trial import Trial
@@ -8,7 +11,8 @@ class TrialScheduler:
PAUSE = "PAUSE" #: Status for pausing trial execution
STOP = "STOP" #: Status for stopping trial execution
def set_search_properties(self, metric, mode):
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]) -> bool:
"""Pass search properties to scheduler.
This method acts as an alternative to instantiating schedulers
@@ -20,19 +24,22 @@ class TrialScheduler:
"""
return True
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Called when a new trial is added to the trial runner."""
raise NotImplementedError
def on_trial_error(self, trial_runner, trial):
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Notification for the error of trial.
This will only be called when the trial is in the RUNNING state."""
raise NotImplementedError
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
"""Called on each intermediate result returned by a trial.
At this point, the trial scheduler can make a decision by returning
@@ -41,7 +48,8 @@ class TrialScheduler:
raise NotImplementedError
def on_trial_complete(self, trial_runner, trial, result):
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict):
"""Notification for the completion of trial.
This will only be called when the trial is in the RUNNING state and
@@ -49,7 +57,8 @@ class TrialScheduler:
raise NotImplementedError
def on_trial_remove(self, trial_runner, trial):
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
"""Called to remove trial.
This is called when the trial is in PAUSED or PENDING state. Otherwise,
@@ -57,7 +66,8 @@ class TrialScheduler:
raise NotImplementedError
def choose_trial_to_run(self, trial_runner):
def choose_trial_to_run(
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
"""Called to choose a new trial to run.
This should return one of the trials in trial_runner that is in
@@ -67,7 +77,7 @@ class TrialScheduler:
raise NotImplementedError
def debug_string(self):
def debug_string(self) -> str:
"""Returns a human readable message for printing to the console."""
raise NotImplementedError
@@ -76,22 +86,28 @@ class TrialScheduler:
class FIFOScheduler(TrialScheduler):
"""Simple scheduler that just runs trials in submission order."""
def on_trial_add(self, trial_runner, trial):
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
pass
def on_trial_error(self, trial_runner, trial):
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
pass
def on_trial_result(self, trial_runner, trial, result):
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict):
pass
def on_trial_remove(self, trial_runner, trial):
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial):
pass
def choose_trial_to_run(self, trial_runner):
def choose_trial_to_run(
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
for trial in trial_runner.get_trials():
if (trial.status == Trial.PENDING
and trial_runner.has_resources(trial.resources)):
@@ -102,5 +118,5 @@ class FIFOScheduler(TrialScheduler):
return trial
return None
def debug_string(self):
def debug_string(self) -> str:
return "Using FIFO scheduling algorithm."
+13 -7
View File
@@ -1,5 +1,8 @@
from typing import Dict, List, Optional
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
from ray.tune.suggest.search_generator import SearchGenerator
from ray.tune.trial import Trial
class _MockSearcher(Searcher):
@@ -11,29 +14,32 @@ class _MockSearcher(Searcher):
self.results = []
super(_MockSearcher, self).__init__(**kwargs)
def suggest(self, trial_id):
def suggest(self, trial_id: str):
if not self.stall:
self.live_trials[trial_id] = 1
return {"test_variable": 2}
return None
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
self.counter["result"] += 1
self.results += [result]
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
self.counter["complete"] += 1
if result:
self._process_result(result)
if trial_id in self.live_trials:
del self.live_trials[trial_id]
def _process_result(self, result):
def _process_result(self, result: Dict):
self.final_results += [result]
class _MockSuggestionAlgorithm(SearchGenerator):
def __init__(self, max_concurrent=None, **kwargs):
def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
self.searcher = _MockSearcher(**kwargs)
if max_concurrent:
self.searcher = ConcurrencyLimiter(
@@ -41,9 +47,9 @@ class _MockSuggestionAlgorithm(SearchGenerator):
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
@property
def live_trials(self):
def live_trials(self) -> List[Trial]:
return self.searcher.live_trials
@property
def results(self):
def results(self) -> List[Dict]:
return self.searcher.results
+12 -11
View File
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, List, Optional
from ax.service.ax_client import AxClient
from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
@@ -103,14 +103,14 @@ class AxSearch(Searcher):
"""
def __init__(self,
space=None,
metric=None,
mode=None,
parameter_constraints=None,
outcome_constraints=None,
ax_client=None,
use_early_stopped_trials=None,
max_concurrent=None):
space: Optional[List[Dict]] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
parameter_constraints: Optional[List] = None,
outcome_constraints: Optional[List] = None,
ax_client: Optional[AxClient] = None,
use_early_stopped_trials: Optional[bool] = None,
max_concurrent: Optional[int] = None):
assert ax is not None, "Ax must be installed!"
if mode:
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
@@ -177,7 +177,8 @@ class AxSearch(Searcher):
logger.warning("Detected sequential enforcement. Be sure to use "
"a ConcurrencyLimiter.")
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict):
if self._ax:
return False
space = self.convert_search_space(config)
@@ -189,7 +190,7 @@ class AxSearch(Searcher):
self.setup_experiment()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._ax:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
+6 -3
View File
@@ -2,9 +2,10 @@ import itertools
import os
import random
import uuid
from typing import Dict, List, Union
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.variant_generator import (generate_variants, format_vars,
flatten_resolved_vars)
@@ -42,7 +43,7 @@ class BasicVariantGenerator(SearchAlgorithm):
searcher.is_finished == True
"""
def __init__(self, shuffle=False):
def __init__(self, shuffle: bool = False):
"""Initializes the Variant Generator.
"""
@@ -60,7 +61,9 @@ class BasicVariantGenerator(SearchAlgorithm):
else:
self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"
def add_configurations(self, experiments):
def add_configurations(
self,
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
"""Chains generator given experiment specifications.
Arguments:
+28 -23
View File
@@ -2,9 +2,10 @@ from collections import defaultdict
import logging
import pickle
import json
from typing import Dict
from typing import Dict, Optional, Tuple
from ray.tune.sample import Float, Quantized
from ray.tune import ExperimentAnalysis
from ray.tune.sample import Domain, Float, Quantized
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils.util import unflatten_dict
@@ -100,18 +101,18 @@ class BayesOptSearch(Searcher):
optimizer = None
def __init__(self,
space=None,
metric=None,
mode=None,
utility_kwargs=None,
random_state=42,
random_search_steps=10,
verbose=0,
patience=5,
skip_duplicate=True,
analysis=None,
max_concurrent=None,
use_early_stopped_trials=None):
space: Optional[Dict] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
utility_kwargs: Optional[Dict] = None,
random_state: int = 42,
random_search_steps: int = 10,
verbose: int = 0,
patience: int = 5,
skip_duplicate: bool = True,
analysis: Optional[ExperimentAnalysis] = None,
max_concurrent: Optional[int] = None,
use_early_stopped_trials: Optional[bool] = None):
"""Instantiate new BayesOptSearch object.
Args:
@@ -200,7 +201,8 @@ class BayesOptSearch(Searcher):
verbose=self._verbose,
random_state=self._random_state)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self.optimizer:
return False
space = self.convert_search_space(config)
@@ -218,7 +220,7 @@ class BayesOptSearch(Searcher):
self.setup_optimizer()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
"""Return new point to be explored by black box function.
Args:
@@ -278,7 +280,7 @@ class BayesOptSearch(Searcher):
# Return a deep copy of the mapping
return unflatten_dict(config)
def register_analysis(self, analysis):
def register_analysis(self, analysis: ExperimentAnalysis):
"""Integrate the given analysis into the gaussian process.
Args:
@@ -293,7 +295,10 @@ class BayesOptSearch(Searcher):
# gaussian process optimizer
self._register_result(params, report)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
Args:
@@ -330,18 +335,18 @@ class BayesOptSearch(Searcher):
for params, result in self._buffered_trial_results:
self._register_result(params, result)
def _register_result(self, params, result):
def _register_result(self, params: Tuple[str], result: Dict):
"""Register given tuple of params and results."""
self.optimizer.register(params, self._metric_op * result[self.metric])
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
"""Storing current optimizer state."""
with open(checkpoint_path, "wb") as f:
pickle.dump(
(self.optimizer, self._buffered_trial_results,
self._total_random_search_trials, self._config_counter), f)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
"""Restoring current optimizer state."""
with open(checkpoint_path, "rb") as f:
(self.optimizer, self._buffered_trial_results,
@@ -349,7 +354,7 @@ class BayesOptSearch(Searcher):
self._config_counter) = pickle.load(f)
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> Dict:
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -358,7 +363,7 @@ class BayesOptSearch(Searcher):
"Grid search parameters cannot be automatically converted "
"to a BayesOpt search space.")
def resolve_value(domain):
def resolve_value(domain: Domain) -> Tuple[float, float]:
sampler = domain.get_sampler()
if isinstance(sampler, Quantized):
logger.warning(
+23 -17
View File
@@ -3,10 +3,11 @@
import copy
import logging
import math
from typing import Dict
from typing import Dict, Optional
import ConfigSpace
from ray.tune.sample import Categorical, Float, Integer, LogUniform, Normal, \
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
Normal, \
Quantized, \
Uniform
from ray.tune.suggest import Searcher
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
class _BOHBJobWrapper():
"""Mock object for HpBandSter to process."""
def __init__(self, loss, budget, config):
def __init__(self, loss: float, budget: float, config: Dict):
self.result = {"loss": loss}
self.kwargs = {"budget": budget, "config": config.copy()}
self.exception = None
@@ -92,11 +93,11 @@ class TuneBOHB(Searcher):
"""
def __init__(self,
space=None,
bohb_config=None,
max_concurrent=10,
metric=None,
mode=None):
space: Optional[ConfigSpace.ConfigurationSpace] = None,
bohb_config: Optional[Dict] = None,
max_concurrent: int = 10,
metric: Optional[str] = None,
mode: Optional[str] = None):
from hpbandster.optimizers.config_generators.bohb import BOHB
assert BOHB is not None, "HpBandSter must be installed!"
if mode:
@@ -126,7 +127,8 @@ class TuneBOHB(Searcher):
bohb_config = self._bohb_config or {}
self.bohber = BOHB(self._space, **bohb_config)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._space:
return False
space = self.convert_search_space(config)
@@ -140,7 +142,7 @@ class TuneBOHB(Searcher):
self.setup_bohb()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._space:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -156,7 +158,7 @@ class TuneBOHB(Searcher):
return unflatten_dict(config)
return None
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
if trial_id not in self.paused:
self.running.add(trial_id)
if "hyperband_info" not in result:
@@ -166,28 +168,31 @@ class TuneBOHB(Searcher):
hbs_wrapper = self.to_wrapper(trial_id, result)
self.bohber.new_result(hbs_wrapper)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
del self.trial_to_params[trial_id]
if trial_id in self.paused:
self.paused.remove(trial_id)
if trial_id in self.running:
self.running.remove(trial_id)
def to_wrapper(self, trial_id, result):
def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper:
return _BOHBJobWrapper(self._metric_op * result[self.metric],
result["hyperband_info"]["budget"],
self.trial_to_params[trial_id])
def on_pause(self, trial_id):
def on_pause(self, trial_id: str):
self.paused.add(trial_id)
self.running.remove(trial_id)
def on_unpause(self, trial_id):
def on_unpause(self, trial_id: str):
self.paused.remove(trial_id)
self.running.add(trial_id)
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> ConfigSpace.ConfigurationSpace:
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -196,7 +201,8 @@ class TuneBOHB(Searcher):
"Grid search parameters cannot be automatically converted "
"to a TuneBOHB search space.")
def resolve_value(par, domain):
def resolve_value(par: str, domain: Domain
) -> ConfigSpace.hyperparameters.Hyperparameter:
quantize = None
sampler = domain.get_sampler()
+23 -19
View File
@@ -5,16 +5,18 @@ from __future__ import print_function
import inspect
import logging
import pickle
from typing import Dict
from typing import Dict, List, Optional
from ray.tune.sample import Float, Quantized
from ray.tune.sample import Domain, Float, Quantized
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils.util import flatten_dict
try: # Python 3 only -- needed for lint test.
import dragonfly
from dragonfly.opt.blackbox_optimiser import BlackboxOptimiser
except ImportError:
dragonfly = None
BlackboxOptimiser = None
from ray.tune.suggest.suggestion import Searcher
@@ -127,13 +129,13 @@ class DragonflySearch(Searcher):
"""
def __init__(self,
optimizer=None,
domain=None,
space=None,
metric=None,
mode=None,
points_to_evaluate=None,
evaluated_rewards=None,
optimizer: Optional[BlackboxOptimiser] = None,
domain: Optional[str] = None,
space: Optional[List[Dict]] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
points_to_evaluate: Optional[List[List]] = None,
evaluated_rewards: Optional[List] = None,
**kwargs):
assert dragonfly is not None, """dragonfly must be installed!
You can install Dragonfly with the command:
@@ -144,8 +146,6 @@ class DragonflySearch(Searcher):
super(DragonflySearch, self).__init__(
metric=metric, mode=mode, **kwargs)
from dragonfly.opt.blackbox_optimiser import BlackboxOptimiser
self._opt_arg = optimizer
self._domain = domain
self._space = space
@@ -245,7 +245,8 @@ class DragonflySearch(Searcher):
elif self._mode == "max":
self._metric_op = 1.
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._opt:
return False
space = self.convert_search_space(config)
@@ -258,7 +259,7 @@ class DragonflySearch(Searcher):
self.setup_dragonfly()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._opt:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -281,26 +282,29 @@ class DragonflySearch(Searcher):
self._live_trial_mapping[trial_id] = suggested_config
return {"point": suggested_config}
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Passes result to Dragonfly unless early terminated or errored."""
trial_info = self._live_trial_mapping.pop(trial_id)
if result:
self._opt.tell([(trial_info,
self._metric_op * result[self._metric])])
def save(self, checkpoint_dir):
def save(self, checkpoint_path: str):
trials_object = (self._initial_points, self._opt)
with open(checkpoint_dir, "wb") as outputFile:
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_dir):
def restore(self, checkpoint_dir: str):
with open(checkpoint_dir, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._initial_points = trials_object[0]
self._opt = trials_object[1]
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> List[Dict]:
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -309,7 +313,7 @@ class DragonflySearch(Searcher):
"Grid search parameters cannot be automatically converted "
"to a Dragonfly search space.")
def resolve_value(par, domain):
def resolve_value(par: str, domain: Domain) -> Dict:
sampler = domain.get_sampler()
if isinstance(sampler, Quantized):
logger.warning(
+30 -25
View File
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Any, Dict, List, Optional
import numpy as np
import copy
@@ -6,7 +6,8 @@ import logging
from functools import partial
import pickle
from ray.tune.sample import Categorical, Float, Integer, LogUniform, Normal, \
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
Normal, \
Quantized, \
Uniform
from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
@@ -117,15 +118,15 @@ class HyperOptSearch(Searcher):
def __init__(
self,
space=None,
metric=None,
mode=None,
points_to_evaluate=None,
n_initial_points=20,
random_state_seed=None,
gamma=0.25,
max_concurrent=None,
use_early_stopped_trials=None,
space: Optional[Dict] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
points_to_evaluate: Optional[List[Dict]] = None,
n_initial_points: int = 20,
random_state_seed: Optional[int] = None,
gamma: float = 0.25,
max_concurrent: Optional[int] = None,
use_early_stopped_trials: Optional[bool] = None,
):
assert hpo is not None, (
"HyperOpt must be installed! Run `pip install hyperopt`.")
@@ -170,7 +171,8 @@ class HyperOptSearch(Searcher):
if space:
self.domain = hpo.Domain(lambda spc: spc, space)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self.domain:
return False
space = self.convert_search_space(config)
@@ -188,7 +190,7 @@ class HyperOptSearch(Searcher):
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self.domain:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -226,7 +228,7 @@ class HyperOptSearch(Searcher):
print_node_on_error=self.domain.rec_eval_print_node_on_error)
return copy.deepcopy(suggested_config)
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
ho_trial = self._get_hyperopt_trial(trial_id)
if ho_trial is None:
return
@@ -234,7 +236,10 @@ class HyperOptSearch(Searcher):
ho_trial["book_time"] = now
ho_trial["refresh_time"] = now
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
The result is internally negated when interacting with HyperOpt
@@ -252,7 +257,7 @@ class HyperOptSearch(Searcher):
self._process_result(trial_id, result)
del self._live_trial_mapping[trial_id]
def _process_result(self, trial_id, result):
def _process_result(self, trial_id: str, result: Dict):
ho_trial = self._get_hyperopt_trial(trial_id)
if not ho_trial:
return
@@ -263,10 +268,10 @@ class HyperOptSearch(Searcher):
ho_trial["result"] = hp_result
self._hpopt_trials.refresh()
def _to_hyperopt_result(self, result):
def _to_hyperopt_result(self, result: Dict) -> Dict:
return {"loss": self.metric_op * result[self.metric], "status": "ok"}
def _get_hyperopt_trial(self, trial_id):
def _get_hyperopt_trial(self, trial_id: str) -> Optional[Dict]:
if trial_id not in self._live_trial_mapping:
return
hyperopt_tid = self._live_trial_mapping[trial_id][0]
@@ -274,21 +279,21 @@ class HyperOptSearch(Searcher):
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
][0]
def get_state(self):
def get_state(self) -> Dict:
return {
"hyperopt_trials": self._hpopt_trials,
"rstate": self.rstate.get_state()
}
def set_state(self, state):
def set_state(self, state: Dict):
self._hpopt_trials = state["hyperopt_trials"]
self.rstate.set_state(state["rstate"])
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(self.get_state(), outputFile)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
@@ -299,19 +304,19 @@ class HyperOptSearch(Searcher):
self.set_state(trials_object)
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> Dict:
spec = copy.deepcopy(spec)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
if not domain_vars and not grid_vars:
return []
return {}
if grid_vars:
raise ValueError(
"Grid search parameters cannot be automatically converted "
"to a HyperOpt search space.")
def resolve_value(par, domain):
def resolve_value(par: str, domain: Domain) -> Any:
quantize = None
sampler = domain.get_sampler()
+28 -17
View File
@@ -1,16 +1,23 @@
import logging
import pickle
from typing import Dict
from typing import Dict, Optional, Union
from ray.tune.sample import Categorical, Float, Integer, LogUniform, Quantized
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
Quantized
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils import flatten_dict
from ray.tune.utils.util import unflatten_dict
try:
import nevergrad as ng
from nevergrad.optimization import Optimizer
from nevergrad.optimization.base import ConfiguredOptimizer
Parameter = ng.p.Parameter
except ImportError:
ng = None
Optimizer = None
ConfiguredOptimizer = None
Parameter = None
from ray.tune.suggest import Searcher
@@ -85,11 +92,11 @@ class NevergradSearch(Searcher):
"""
def __init__(self,
optimizer=None,
space=None,
metric=None,
mode=None,
max_concurrent=None,
optimizer: Union[None, Optimizer, ConfiguredOptimizer] = None,
space: Optional[Parameter] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
max_concurrent: Optional[int] = None,
**kwargs):
assert ng is not None, "Nevergrad must be installed!"
if mode:
@@ -102,7 +109,7 @@ class NevergradSearch(Searcher):
self._opt_factory = None
self._nevergrad_opt = None
if isinstance(optimizer, ng.optimization.Optimizer):
if isinstance(optimizer, Optimizer):
if space is not None or isinstance(space, list):
raise ValueError(
"If you pass a configured optimizer to Nevergrad, either "
@@ -110,7 +117,7 @@ class NevergradSearch(Searcher):
"parameter.")
self._parameters = space
self._nevergrad_opt = optimizer
elif isinstance(optimizer, ng.optimization.base.ConfiguredOptimizer):
elif isinstance(optimizer, ConfiguredOptimizer):
self._opt_factory = optimizer
self._parameters = None
self._space = space
@@ -155,7 +162,8 @@ class NevergradSearch(Searcher):
raise ValueError("len(parameters_names) must match optimizer "
"dimension for non-instrumented optimizers")
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._nevergrad_opt or self._space:
return False
space = self.convert_search_space(config)
@@ -169,7 +177,7 @@ class NevergradSearch(Searcher):
self.setup_nevergrad()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._nevergrad_opt:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -192,7 +200,10 @@ class NevergradSearch(Searcher):
else:
return unflatten_dict(suggested_config.kwargs)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
The result is internally negated when interacting with Nevergrad
@@ -204,24 +215,24 @@ class NevergradSearch(Searcher):
self._live_trial_mapping.pop(trial_id)
def _process_result(self, trial_id, result):
def _process_result(self, trial_id: str, result: Dict):
ng_trial_info = self._live_trial_mapping[trial_id]
self._nevergrad_opt.tell(ng_trial_info,
self._metric_op * result[self._metric])
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
trials_object = (self._nevergrad_opt, self._parameters)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._nevergrad_opt = trials_object[0]
self._parameters = trials_object[1]
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> Parameter:
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -230,7 +241,7 @@ class NevergradSearch(Searcher):
"Grid search parameters cannot be automatically converted "
"to a Nevergrad search space.")
def resolve_value(domain):
def resolve_value(domain: Domain) -> Parameter:
sampler = domain.get_sampler()
if isinstance(sampler, Quantized):
logger.warning("Nevergrad does not support quantization. "
+23 -13
View File
@@ -1,9 +1,9 @@
import logging
import pickle
from typing import Dict
from typing import Dict, List, Optional, Tuple
from ray.tune.result import TRAINING_ITERATION
from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
Quantized, Uniform
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils import flatten_dict
@@ -11,8 +11,10 @@ from ray.tune.utils.util import unflatten_dict
try:
import optuna as ot
from optuna.samplers import BaseSampler
except ImportError:
ot = None
BaseSampler = None
from ray.tune.suggest import Searcher
@@ -100,7 +102,11 @@ class OptunaSearch(Searcher):
"""
def __init__(self, space=None, metric=None, mode=None, sampler=None):
def __init__(self,
space: Optional[List[Tuple]] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
sampler: Optional[BaseSampler] = None):
assert ot is not None, (
"Optuna must be installed! Run `pip install optuna`.")
super(OptunaSearch, self).__init__(
@@ -113,7 +119,7 @@ class OptunaSearch(Searcher):
self._study_name = "optuna" # Fixed study name for in-memory storage
self._sampler = sampler or ot.samplers.TPESampler()
assert isinstance(self._sampler, ot.samplers.BaseSampler), \
assert isinstance(self._sampler, BaseSampler), \
"You can only pass an instance of `optuna.samplers.BaseSampler` " \
"as a sampler to `OptunaSearcher`."
@@ -125,7 +131,7 @@ class OptunaSearch(Searcher):
if self._space:
self.setup_study(mode)
def setup_study(self, mode):
def setup_study(self, mode: str):
self._ot_study = ot.study.create_study(
storage=self._storage,
sampler=self._sampler,
@@ -134,7 +140,8 @@ class OptunaSearch(Searcher):
direction="minimize" if mode == "min" else "maximize",
load_if_exists=True)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._space:
return False
space = self.convert_search_space(config)
@@ -146,7 +153,7 @@ class OptunaSearch(Searcher):
self.setup_study(mode)
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._space:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -169,13 +176,16 @@ class OptunaSearch(Searcher):
}
return unflatten_dict(params)
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
metric = result[self.metric]
step = result[TRAINING_ITERATION]
ot_trial = self._ot_trials[trial_id]
ot_trial.report(metric, step)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
ot_trial = self._ot_trials[trial_id]
ot_trial_id = ot_trial._trial_id
self._storage.set_trial_value(ot_trial_id, result.get(
@@ -183,20 +193,20 @@ class OptunaSearch(Searcher):
self._storage.set_trial_state(ot_trial_id,
ot.trial.TrialState.COMPLETE)
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
save_object = (self._storage, self._pruner, self._sampler,
self._ot_trials, self._ot_study)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(save_object, outputFile)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
save_object = pickle.load(inputFile)
self._storage, self._pruner, self._sampler, \
self._ot_trials, self._ot_study = save_object
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> List[Tuple]:
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -208,7 +218,7 @@ class OptunaSearch(Searcher):
"Grid search parameters cannot be automatically converted "
"to an Optuna search space.")
def resolve_value(par, domain):
def resolve_value(par: str, domain: Domain) -> Tuple:
quantize = None
sampler = domain.get_sampler()
+26 -14
View File
@@ -1,5 +1,7 @@
import copy
import logging
from typing import Dict, List, Optional
import numpy as np
from ray.tune.suggest.suggestion import Searcher
@@ -10,7 +12,7 @@ TRIAL_INDEX = "__trial_index__"
"""str: A constant value representing the repeat index of the trial."""
def _warn_num_samples(searcher, num_samples):
def _warn_num_samples(searcher: Searcher, num_samples: int):
if isinstance(searcher, Repeater) and num_samples % searcher.repeat:
logger.warning(
"`num_samples` is now expected to be the total number of trials, "
@@ -34,7 +36,10 @@ class _TrialGroup:
"""
def __init__(self, primary_trial_id, config, max_trials=1):
def __init__(self,
primary_trial_id: str,
config: Dict,
max_trials: int = 1):
assert type(config) is dict, (
"config is not a dict, got {}".format(config))
self.primary_trial_id = primary_trial_id
@@ -42,27 +47,27 @@ class _TrialGroup:
self._trials = {primary_trial_id: None}
self.max_trials = max_trials
def add(self, trial_id):
def add(self, trial_id: str):
assert len(self._trials) < self.max_trials
self._trials.setdefault(trial_id, None)
def full(self):
def full(self) -> bool:
return len(self._trials) == self.max_trials
def report(self, trial_id, score):
def report(self, trial_id: str, score: float):
assert trial_id in self._trials
if score is None:
raise ValueError("Internal Error: Score cannot be None.")
self._trials[trial_id] = score
def finished_reporting(self):
def finished_reporting(self) -> bool:
return None not in self._trials.values() and len(
self._trials) == self.max_trials
def scores(self):
def scores(self) -> List[Optional[float]]:
return list(self._trials.values())
def count(self):
def count(self) -> int:
return len(self._trials)
@@ -103,7 +108,10 @@ class Repeater(Searcher):
"""
def __init__(self, searcher, repeat=1, set_index=True):
def __init__(self,
searcher: Searcher,
repeat: int = 1,
set_index: bool = True):
self.searcher = searcher
self.repeat = repeat
self._set_index = set_index
@@ -113,7 +121,7 @@ class Repeater(Searcher):
super(Repeater, self).__init__(
metric=self.searcher.metric, mode=self.searcher.mode)
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if self._current_group is None or self._current_group.full():
config = self.searcher.suggest(trial_id)
if config is None:
@@ -132,7 +140,10 @@ class Repeater(Searcher):
self._trial_id_to_group[trial_id] = self._current_group
return config
def on_trial_complete(self, trial_id, result=None, **kwargs):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
**kwargs):
"""Stores the score for and keeps track of a completed trial.
Stores the metric of a trial as nan if any of the following conditions
@@ -160,13 +171,14 @@ class Repeater(Searcher):
result={self.searcher.metric: np.nanmean(scores)},
**kwargs)
def get_state(self):
def get_state(self) -> Dict:
self_state = self.__dict__.copy()
del self_state["searcher"]
return self_state
def set_state(self, state):
def set_state(self, state: Dict):
self.__dict__.update(state)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
return self.searcher.set_search_properties(metric, mode, config)
+22 -9
View File
@@ -1,3 +1,9 @@
from typing import Dict, List, Optional, Union
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
class SearchAlgorithm:
"""Interface of an event handler API for hyperparameter search.
@@ -12,7 +18,8 @@ class SearchAlgorithm:
"""
_finished = False
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
"""Pass search properties to search algorithm.
This method acts as an alternative to instantiating search algorithms
@@ -29,7 +36,9 @@ class SearchAlgorithm:
"""
return True
def add_configurations(self, experiments):
def add_configurations(
self,
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
"""Tracks given experiment specifications.
Arguments:
@@ -37,7 +46,7 @@ class SearchAlgorithm:
"""
raise NotImplementedError
def next_trials(self):
def next_trials(self) -> List[Trial]:
"""Provides Trial objects to be queued into the TrialRunner.
Returns:
@@ -45,17 +54,21 @@ class SearchAlgorithm:
"""
raise NotImplementedError
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
"""Called on each intermediate result returned by a trial.
This will only be called when the trial is in the RUNNING state.
Arguments:
trial_id: Identifier for the trial.
result: Result dictionary.
"""
pass
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
Arguments:
@@ -69,7 +82,7 @@ class SearchAlgorithm:
"""
pass
def is_finished(self):
def is_finished(self) -> bool:
"""Returns True if no trials left to be queued into TrialRunner.
Can return True before all trials have finished executing.
@@ -80,14 +93,14 @@ class SearchAlgorithm:
"""Marks the search algorithm as finished."""
self._finished = True
def has_checkpoint(self, dirpath):
def has_checkpoint(self, dirpath: str) -> bool:
"""Should return False if not restoring is not implemented."""
return False
def save_to_dir(self, dirpath, **kwargs):
def save_to_dir(self, dirpath: str, **kwargs):
"""Saves a search algorithm."""
pass
def restore_from_dir(self, dirpath):
def restore_from_dir(self, dirpath: str):
"""Restores a search algorithm along with its wrapped state."""
pass
+24 -16
View File
@@ -2,10 +2,11 @@ import os
import copy
import logging
import glob
from typing import Dict, List, Optional, Union
import ray.cloudpickle as cloudpickle
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.suggestion import Searcher
@@ -21,7 +22,7 @@ def _warn_on_repeater(searcher, total_samples):
_warn_num_samples(searcher, total_samples)
def _atomic_save(state, checkpoint_dir, file_name):
def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str):
"""Atomically saves the object to the checkpoint directory
This is automatically used by tune.run during a Tune job.
@@ -34,7 +35,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
def _find_newest_ckpt(dirpath, pattern):
def _find_newest_ckpt(dirpath: str, pattern: str):
"""Returns path to most recently modified checkpoint."""
full_paths = glob.glob(os.path.join(dirpath, pattern))
if not full_paths:
@@ -58,7 +59,7 @@ class SearchGenerator(SearchAlgorithm):
"""
CKPT_FILE_TMPL = "search_gen_state-{}.json"
def __init__(self, searcher):
def __init__(self, searcher: Searcher):
assert issubclass(
type(searcher),
Searcher), ("Searcher should be subclassing Searcher.")
@@ -69,10 +70,13 @@ class SearchGenerator(SearchAlgorithm):
self._total_samples = None # int: total samples to evaluate.
self._finished = False
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
return self.searcher.set_search_properties(metric, mode, config)
def add_configurations(self, experiments):
def add_configurations(
self,
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
"""Registers experiment specifications.
Arguments:
@@ -91,7 +95,7 @@ class SearchGenerator(SearchAlgorithm):
if "run" not in experiment_spec:
raise TuneError("Must specify `run` in {}".format(experiment_spec))
def next_trials(self):
def next_trials(self) -> List[Trial]:
"""Provides a batch of Trial objects to be queued into the TrialRunner.
Returns:
@@ -106,7 +110,8 @@ class SearchGenerator(SearchAlgorithm):
trials.append(trial)
return trials
def create_trial_if_possible(self, experiment_spec, output_path):
def create_trial_if_possible(self, experiment_spec: Dict,
output_path: str) -> Optional[Trial]:
logger.debug("creating trial")
trial_id = Trial.generate_id()
suggested_config = self.searcher.suggest(trial_id)
@@ -135,18 +140,21 @@ class SearchGenerator(SearchAlgorithm):
trial_id=trial_id)
return trial
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
"""Notifies the underlying searcher."""
self.searcher.on_trial_result(trial_id, result)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
self.searcher.on_trial_complete(
trial_id=trial_id, result=result, error=error)
def is_finished(self):
def is_finished(self) -> bool:
return self._counter >= self._total_samples or self._finished
def get_state(self):
def get_state(self) -> Dict:
return {
"counter": self._counter,
"total_samples": self._total_samples,
@@ -154,17 +162,17 @@ class SearchGenerator(SearchAlgorithm):
"experiment": self._experiment
}
def set_state(self, state):
def set_state(self, state: Dict):
self._counter = state["counter"]
self._total_samples = state["total_samples"]
self._finished = state["finished"]
self._experiment = state["experiment"]
def has_checkpoint(self, dirpath):
def has_checkpoint(self, dirpath: str):
return bool(
_find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*")))
def save_to_dir(self, dirpath, session_str):
def save_to_dir(self, dirpath: str, session_str: str):
"""Saves self + searcher to dir.
Separates the "searcher" from its wrappers (concurrency, repeating).
@@ -196,7 +204,7 @@ class SearchGenerator(SearchAlgorithm):
_atomic_save(search_alg_state, dirpath,
self.CKPT_FILE_TMPL.format(session_str))
def restore_from_dir(self, dirpath):
def restore_from_dir(self, dirpath: str):
"""Restores self + searcher + search wrappers from dirpath."""
searcher = self.searcher
+23 -16
View File
@@ -2,10 +2,14 @@ import copy
import os
import logging
import pickle
from typing import Dict, List, Optional, Union
try:
import sigopt as sgo
Connection = sgo.Connection
except ImportError:
sgo = None
Connection = None
from ray.tune.suggest import Searcher
@@ -122,16 +126,16 @@ class SigOptSearch(Searcher):
}
def __init__(self,
space=None,
name="Default Tune Experiment",
max_concurrent=1,
reward_attr=None,
connection=None,
experiment_id=None,
observation_budget=None,
project=None,
metric="episode_reward_mean",
mode="max",
space: List[Dict] = None,
name: str = "Default Tune Experiment",
max_concurrent: int = 1,
reward_attr: Optional[str] = None,
connection: Optional[Connection] = None,
experiment_id: Optional[str] = None,
observation_budget: Optional[int] = None,
project: Optional[str] = None,
metric: Union[None, str, List[str]] = "episode_reward_mean",
mode: Union[None, str, List[str]] = "max",
**kwargs):
assert (experiment_id is
None) ^ (space is None), "space xor experiment_id must be set"
@@ -178,7 +182,7 @@ class SigOptSearch(Searcher):
super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs)
def suggest(self, trial_id):
def suggest(self, trial_id: str):
if self._max_concurrent:
if len(self._live_trial_mapping) >= self._max_concurrent:
return None
@@ -190,7 +194,10 @@ class SigOptSearch(Searcher):
return copy.deepcopy(suggestion.assignments)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
If a trial fails, it will be reported as a failed Observation, telling
@@ -214,7 +221,7 @@ class SigOptSearch(Searcher):
del self._live_trial_mapping[trial_id]
@staticmethod
def serialize_metric(metrics, modes):
def serialize_metric(metrics: List[str], modes: List[str]):
"""
Converts metrics to https://app.sigopt.com/docs/objects/metric
"""
@@ -224,7 +231,7 @@ class SigOptSearch(Searcher):
dict(name=metric, **SigOptSearch.OBJECTIVE_MAP[mode].copy()))
return serialized_metric
def serialize_result(self, result):
def serialize_result(self, result: Dict):
"""
Converts experiments results to
https://app.sigopt.com/docs/objects/metric_evaluation
@@ -244,12 +251,12 @@ class SigOptSearch(Searcher):
values.append(value)
return values
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
trials_object = (self.conn, self.experiment)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self.conn = trials_object[0]
+26 -21
View File
@@ -1,8 +1,8 @@
import logging
import pickle
from typing import Dict
from typing import Dict, List, Optional, Tuple, Union
from ray.tune.sample import Categorical, Float, Integer, Quantized
from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils import flatten_dict
from ray.tune.utils.util import unflatten_dict
@@ -17,8 +17,9 @@ from ray.tune.suggest import Searcher
logger = logging.getLogger(__name__)
def _validate_warmstart(parameter_names, points_to_evaluate,
evaluated_rewards):
def _validate_warmstart(parameter_names: List[str],
points_to_evaluate: List[List],
evaluated_rewards: List):
if points_to_evaluate:
if not isinstance(points_to_evaluate, list):
raise TypeError(
@@ -125,14 +126,14 @@ class SkOptSearch(Searcher):
"""
def __init__(self,
optimizer=None,
space=None,
metric=None,
mode=None,
points_to_evaluate=None,
evaluated_rewards=None,
max_concurrent=None,
use_early_stopped_trials=None):
optimizer: Optional[sko.optimizer.Optimizer] = None,
space: Union[List[str], Dict[str, Union[Tuple, List]]] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
points_to_evaluate: Optional[List[List]] = None,
evaluated_rewards: Optional[List] = None,
max_concurrent: Optional[int] = None,
use_early_stopped_trials: Optional[bool] = None):
assert sko is not None, """skopt must be installed!
You can install Skopt with the command:
`pip install scikit-optimize`."""
@@ -162,7 +163,7 @@ class SkOptSearch(Searcher):
"names.")
self._parameter_names = space
else:
self._parameter_names = space.keys()
self._parameter_names = list(space.keys())
self._parameter_ranges = space.values()
self._points_to_evaluate = points_to_evaluate
@@ -199,7 +200,8 @@ class SkOptSearch(Searcher):
elif self._mode == "min":
self._metric_op = 1.
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._skopt_opt:
return False
space = self.convert_search_space(config)
@@ -216,7 +218,7 @@ class SkOptSearch(Searcher):
self.setup_skopt()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._skopt_opt:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -235,7 +237,10 @@ class SkOptSearch(Searcher):
self._live_trial_mapping[trial_id] = suggested_config
return unflatten_dict(dict(zip(self._parameters, suggested_config)))
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
The result is internally negated when interacting with Skopt
@@ -247,24 +252,24 @@ class SkOptSearch(Searcher):
self._process_result(trial_id, result)
self._live_trial_mapping.pop(trial_id)
def _process_result(self, trial_id, result):
def _process_result(self, trial_id: str, result: Dict):
skopt_trial_info = self._live_trial_mapping[trial_id]
self._skopt_opt.tell(skopt_trial_info,
self._metric_op * result[self._metric])
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
trials_object = (self._initial_points, self._skopt_opt)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._initial_points = trials_object[0]
self._skopt_opt = trials_object[1]
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> Dict:
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -273,7 +278,7 @@ class SkOptSearch(Searcher):
"Grid search parameters cannot be automatically converted "
"to a SkOpt search space.")
def resolve_value(domain):
def resolve_value(domain: Domain) -> Union[Tuple, List]:
sampler = domain.get_sampler()
if isinstance(sampler, Quantized):
logger.warning("SkOpt search does not support quantization. "
+36 -24
View File
@@ -2,6 +2,7 @@ import copy
import glob
import logging
import os
from typing import Dict, Optional
from ray.util.debug import log_once
@@ -56,10 +57,10 @@ class Searcher:
CKPT_FILE_TMPL = "searcher-state-{}.pkl"
def __init__(self,
metric=None,
mode=None,
max_concurrent=None,
use_early_stopped_trials=None):
metric: Optional[str] = None,
mode: Optional[str] = None,
max_concurrent: Optional[int] = None,
use_early_stopped_trials: Optional[bool] = None):
if use_early_stopped_trials is False:
raise DeprecationWarning(
"Early stopped trials are now always used. If this is a "
@@ -90,7 +91,8 @@ class Searcher:
else:
raise ValueError("Mode most either be a list or string")
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
"""Pass search properties to searcher.
This method acts as an alternative to instantiating search algorithms
@@ -106,7 +108,7 @@ class Searcher:
"""
return False
def on_trial_result(self, trial_id, result):
def on_trial_result(self, trial_id: str, result: Dict):
"""Optional notification for result during training.
Note that by default, the result dict may include NaNs or
@@ -124,7 +126,10 @@ class Searcher:
"""
pass
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial.
Typically, this method is used for notifying the underlying
@@ -143,7 +148,7 @@ class Searcher:
"""
raise NotImplementedError
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
"""Queries the algorithm to retrieve the next set of parameters.
Arguments:
@@ -159,7 +164,7 @@ class Searcher:
"""
raise NotImplementedError
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
"""Save state to path for this search algorithm.
Args:
@@ -190,7 +195,7 @@ class Searcher:
"""
raise NotImplementedError
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
"""Restore state for this search algorithm
@@ -213,13 +218,13 @@ class Searcher:
"""
raise NotImplementedError
def get_state(self):
def get_state(self) -> Dict:
raise NotImplementedError
def set_state(self, state):
def set_state(self, state: Dict):
raise NotImplementedError
def save_to_dir(self, checkpoint_dir, session_str="default"):
def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
"""Automatically saves the given searcher to the checkpoint_dir.
This is automatically used by tune.run during a Tune job.
@@ -246,7 +251,7 @@ class Searcher:
os.path.join(checkpoint_dir,
self.CKPT_FILE_TMPL.format(session_str)))
def restore_from_dir(self, checkpoint_dir):
def restore_from_dir(self, checkpoint_dir: str):
"""Restores the state of a searcher from a given checkpoint_dir.
Typically, you should use this function to restore from an
@@ -277,12 +282,12 @@ class Searcher:
self.restore(most_recent_checkpoint)
@property
def metric(self):
def metric(self) -> str:
"""The training result objective value attribute."""
return self._metric
@property
def mode(self):
def mode(self) -> str:
"""Specifies if minimizing or maximizing the metric."""
return self._mode
@@ -308,7 +313,10 @@ class ConcurrencyLimiter(Searcher):
tune.run(trainable, search_alg=search_alg)
"""
def __init__(self, searcher, max_concurrent, batch=False):
def __init__(self,
searcher: Searcher,
max_concurrent: int,
batch: bool = False):
assert type(max_concurrent) is int and max_concurrent > 0
self.searcher = searcher
self.max_concurrent = max_concurrent
@@ -318,7 +326,7 @@ class ConcurrencyLimiter(Searcher):
super(ConcurrencyLimiter, self).__init__(
metric=self.searcher.metric, mode=self.searcher.mode)
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
assert trial_id not in self.live_trials, (
f"Trial ID {trial_id} must be unique: already found in set.")
if len(self.live_trials) >= self.max_concurrent:
@@ -333,7 +341,10 @@ class ConcurrencyLimiter(Searcher):
self.live_trials.add(trial_id)
return suggestion
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
if trial_id not in self.live_trials:
return
elif self.batch:
@@ -353,19 +364,20 @@ class ConcurrencyLimiter(Searcher):
trial_id, result=result, error=error)
self.live_trials.remove(trial_id)
def get_state(self):
def get_state(self) -> Dict:
state = self.__dict__.copy()
del state["searcher"]
return copy.deepcopy(state)
def set_state(self, state):
def set_state(self, state: Dict):
self.__dict__.update(state)
def on_pause(self, trial_id):
def on_pause(self, trial_id: str):
self.searcher.on_pause(trial_id)
def on_unpause(self, trial_id):
def on_unpause(self, trial_id: str):
self.searcher.on_unpause(trial_id)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
return self.searcher.set_search_properties(metric, mode, config)
+25 -18
View File
@@ -1,5 +1,7 @@
import copy
import logging
from typing import Any, Dict, Generator, List, Tuple
import numpy
import random
@@ -9,7 +11,8 @@ from ray.tune.sample import Categorical, Domain, Function
logger = logging.getLogger(__name__)
def generate_variants(unresolved_spec):
def generate_variants(
unresolved_spec: Dict) -> Generator[Tuple[Dict, Dict], None, None]:
"""Generates variants from a spec (dict) with unresolved values.
There are two types of unresolved values:
@@ -45,7 +48,7 @@ def generate_variants(unresolved_spec):
yield resolved_vars, spec
def grid_search(values):
def grid_search(values: List) -> Dict[str, List]:
"""Convenience method for specifying grid search over a value.
Arguments:
@@ -63,7 +66,7 @@ _STANDARD_IMPORTS = {
_MAX_RESOLUTION_PASSES = 20
def resolve_nested_dict(nested_dict):
def resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]:
"""Flattens a nested dict by joining keys into tuple of paths.
Can then be passed into `format_vars`.
@@ -78,7 +81,7 @@ def resolve_nested_dict(nested_dict):
return res
def format_vars(resolved_vars):
def format_vars(resolved_vars: Dict) -> str:
"""Formats the resolved variable dict into a single string."""
out = []
for path, value in sorted(resolved_vars.items()):
@@ -97,7 +100,7 @@ def format_vars(resolved_vars):
return ",".join(out)
def flatten_resolved_vars(resolved_vars):
def flatten_resolved_vars(resolved_vars: Dict) -> Dict:
"""Formats the resolved variable dict into a mapping of (str -> value)."""
flattened_resolved_vars_dict = {}
for pieces, value in resolved_vars.items():
@@ -108,14 +111,15 @@ def flatten_resolved_vars(resolved_vars):
return flattened_resolved_vars_dict
def _clean_value(value):
def _clean_value(value: Any) -> str:
if isinstance(value, float):
return "{:.5}".format(value)
else:
return str(value).replace("/", "_")
def parse_spec_vars(spec):
def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
Tuple, Any]], List[Tuple[Tuple, Any]]]:
resolved, unresolved = _split_resolved_unresolved_values(spec)
resolved_vars = list(resolved.items())
@@ -134,7 +138,7 @@ def parse_spec_vars(spec):
return resolved_vars, domain_vars, grid_vars
def _generate_variants(spec):
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
spec = copy.deepcopy(spec)
_, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -159,19 +163,20 @@ def _generate_variants(spec):
yield resolved_vars, spec
def assign_value(spec, path, value):
def assign_value(spec: Dict, path: Tuple, value: Any):
for k in path[:-1]:
spec = spec[k]
spec[path[-1]] = value
def _get_value(spec, path):
def _get_value(spec: Dict, path: Tuple) -> Any:
for k in path:
spec = spec[k]
return spec
def _resolve_domain_vars(spec, domain_vars):
def _resolve_domain_vars(spec: Dict,
domain_vars: List[Tuple[Tuple, Domain]]) -> Dict:
resolved = {}
error = True
num_passes = 0
@@ -197,7 +202,8 @@ def _resolve_domain_vars(spec, domain_vars):
return resolved
def _grid_search_generator(unresolved_spec, grid_vars):
def _grid_search_generator(unresolved_spec: Dict,
grid_vars: List) -> Generator[Dict, None, None]:
value_indices = [0] * len(grid_vars)
def increment(i):
@@ -225,12 +231,12 @@ def _grid_search_generator(unresolved_spec, grid_vars):
break
def _is_resolved(v):
def _is_resolved(v) -> bool:
resolved, _ = _try_resolve(v)
return resolved
def _try_resolve(v):
def _try_resolve(v) -> Tuple[bool, Any]:
if isinstance(v, Domain):
# Domain to sample from
return False, v
@@ -249,7 +255,8 @@ def _try_resolve(v):
return True, v
def _split_resolved_unresolved_values(spec):
def _split_resolved_unresolved_values(
spec: Dict) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]:
resolved_vars = {}
unresolved_vars = {}
for k, v in spec.items():
@@ -278,11 +285,11 @@ def _split_resolved_unresolved_values(spec):
return resolved_vars, unresolved_vars
def _unresolved_values(spec):
def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]:
return _split_resolved_unresolved_values(spec)[1]
def has_unresolved_values(spec):
def has_unresolved_values(spec: Dict) -> bool:
return True if _unresolved_values(spec) else False
@@ -303,5 +310,5 @@ class _UnresolvedAccessGuard(dict):
class RecursiveDependencyError(Exception):
def __init__(self, msg):
def __init__(self, msg: str):
Exception.__init__(self, msg)
+19 -14
View File
@@ -1,9 +1,10 @@
import copy
import logging
from typing import Dict
from typing import Dict, Optional, Tuple
import ray.cloudpickle as pickle
from ray.tune.sample import Categorical, Float, Integer, Quantized, Uniform
from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized, \
Uniform
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils.util import unflatten_dict
from zoopt import ValueType
@@ -106,11 +107,11 @@ class ZOOptSearch(Searcher):
optimizer = None
def __init__(self,
algo="asracos",
budget=None,
dim_dict=None,
metric=None,
mode=None,
algo: str = "asracos",
budget: Optional[int] = None,
dim_dict: Optional[Dict] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
**kwargs):
assert zoopt is not None, "Zoopt not found - please install zoopt."
assert budget is not None, "`budget` should not be None!"
@@ -154,7 +155,8 @@ class ZOOptSearch(Searcher):
from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune
self.optimizer = SRacosTune(dimension=dim, parameter=par)
def set_search_properties(self, metric, mode, config):
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._dim_dict:
return False
space = self.convert_search_space(config)
@@ -173,7 +175,7 @@ class ZOOptSearch(Searcher):
self.setup_zoopt()
return True
def suggest(self, trial_id):
def suggest(self, trial_id: str) -> Optional[Dict]:
if not self._dim_dict or not self.optimizer:
raise RuntimeError(
"Trying to sample a configuration from {}, but no search "
@@ -189,7 +191,10 @@ class ZOOptSearch(Searcher):
self._live_trial_mapping[trial_id] = new_trial
return unflatten_dict(new_trial)
def on_trial_complete(self, trial_id, result=None, error=False):
def on_trial_complete(self,
trial_id: str,
result: Optional[Dict] = None,
error: bool = False):
"""Notification for the completion of trial."""
if result:
_solution = self.solution_dict[str(trial_id)]
@@ -200,18 +205,18 @@ class ZOOptSearch(Searcher):
del self._live_trial_mapping[trial_id]
def save(self, checkpoint_path):
def save(self, checkpoint_path: str):
trials_object = self.optimizer
with open(checkpoint_path, "wb") as output:
pickle.dump(trials_object, output)
def restore(self, checkpoint_path):
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as input:
trials_object = pickle.load(input)
self.optimizer = trials_object
@staticmethod
def convert_search_space(spec: Dict):
def convert_search_space(spec: Dict) -> Dict[str, Tuple]:
spec = copy.deepcopy(spec)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
@@ -223,7 +228,7 @@ class ZOOptSearch(Searcher):
"Grid search parameters cannot be automatically converted "
"to a ZOOpt search space.")
def resolve_value(domain):
def resolve_value(domain: Domain) -> Tuple:
quantize = None
sampler = domain.get_sampler()