mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[tune] added type hints (#10806)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -40,6 +40,7 @@ MOCK_MODULES = [
|
||||
"horovod",
|
||||
"horovod.ray",
|
||||
"kubernetes",
|
||||
"mxnet",
|
||||
"mxnet.model",
|
||||
"psutil",
|
||||
"ray._raylet",
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ray.tune.checkpoint_manager import Checkpoint
|
||||
from ray.tune.utils import flatten_dict
|
||||
|
||||
try:
|
||||
@@ -37,7 +37,10 @@ class Analysis:
|
||||
in the respective functions.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_dir, default_metric=None, default_mode=None):
|
||||
def __init__(self,
|
||||
experiment_dir: str,
|
||||
default_metric: Optional[str] = None,
|
||||
default_mode: Optional[str] = None):
|
||||
experiment_dir = os.path.expanduser(experiment_dir)
|
||||
if not os.path.isdir(experiment_dir):
|
||||
raise ValueError(
|
||||
@@ -59,14 +62,14 @@ class Analysis:
|
||||
else:
|
||||
self.fetch_trial_dataframes()
|
||||
|
||||
def _validate_metric(self, metric):
|
||||
def _validate_metric(self, metric: str) -> str:
|
||||
if not metric and not self.default_metric:
|
||||
raise ValueError(
|
||||
"No `metric` has been passed and `default_metric` has "
|
||||
"not been set. Please specify the `metric` parameter.")
|
||||
return metric or self.default_metric
|
||||
|
||||
def _validate_mode(self, mode):
|
||||
def _validate_mode(self, mode: str) -> str:
|
||||
if not mode and not self.default_mode:
|
||||
raise ValueError(
|
||||
"No `mode` has been passed and `default_mode` has "
|
||||
@@ -75,7 +78,9 @@ class Analysis:
|
||||
raise ValueError("If set, `mode` has to be one of [min, max]")
|
||||
return mode or self.default_mode
|
||||
|
||||
def dataframe(self, metric=None, mode=None):
|
||||
def dataframe(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None) -> DataFrame:
|
||||
"""Returns a pandas.DataFrame object constructed from the trials.
|
||||
|
||||
Args:
|
||||
@@ -97,7 +102,9 @@ class Analysis:
|
||||
rows[path].update(logdir=path)
|
||||
return pd.DataFrame(list(rows.values()))
|
||||
|
||||
def get_best_config(self, metric=None, mode=None):
|
||||
def get_best_config(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None) -> Optional[Dict]:
|
||||
"""Retrieve the best config corresponding to the trial.
|
||||
|
||||
Args:
|
||||
@@ -122,7 +129,9 @@ class Analysis:
|
||||
best_path = compare_op(rows, key=lambda k: rows[k][metric])
|
||||
return all_configs[best_path]
|
||||
|
||||
def get_best_logdir(self, metric=None, mode=None):
|
||||
def get_best_logdir(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None) -> Optional[str]:
|
||||
"""Retrieve the logdir corresponding to the best trial.
|
||||
|
||||
Args:
|
||||
@@ -148,7 +157,7 @@ class Analysis:
|
||||
self._experiment_dir))
|
||||
return None
|
||||
|
||||
def fetch_trial_dataframes(self):
|
||||
def fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
|
||||
fail_count = 0
|
||||
for path in self._get_trial_paths():
|
||||
try:
|
||||
@@ -162,7 +171,7 @@ class Analysis:
|
||||
"Couldn't read results from {} paths".format(fail_count))
|
||||
return self.trial_dataframes
|
||||
|
||||
def get_all_configs(self, prefix=False):
|
||||
def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
|
||||
"""Returns a list of all configurations.
|
||||
|
||||
Args:
|
||||
@@ -170,7 +179,8 @@ class Analysis:
|
||||
and prepends `config/`.
|
||||
|
||||
Returns:
|
||||
List[dict]: List of all configurations of trials,
|
||||
Dict[str, Dict]: Dict of all configurations of trials, indexed by
|
||||
their trial dir.
|
||||
"""
|
||||
fail_count = 0
|
||||
for path in self._get_trial_paths():
|
||||
@@ -189,7 +199,10 @@ class Analysis:
|
||||
"Couldn't read config from {} paths".format(fail_count))
|
||||
return self._configs
|
||||
|
||||
def get_trial_checkpoints_paths(self, trial, metric=None):
|
||||
def get_trial_checkpoints_paths(self,
|
||||
trial: Trial,
|
||||
metric: Optional[str] = None
|
||||
) -> List[Tuple[str, Number]]:
|
||||
"""Gets paths and metrics of all persistent checkpoints of a trial.
|
||||
|
||||
Args:
|
||||
@@ -215,11 +228,14 @@ class Analysis:
|
||||
return path_metric_df[["chkpt_path", metric]].values.tolist()
|
||||
elif isinstance(trial, Trial):
|
||||
checkpoints = trial.checkpoint_manager.best_checkpoints()
|
||||
return [[c.value, c.result[metric]] for c in checkpoints]
|
||||
return [(c.value, c.result[metric]) for c in checkpoints]
|
||||
else:
|
||||
raise ValueError("trial should be a string or a Trial instance.")
|
||||
|
||||
def get_best_checkpoint(self, trial, metric=None, mode=None):
|
||||
def get_best_checkpoint(self,
|
||||
trial: Trial,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None) -> Optional[str]:
|
||||
"""Gets best persistent checkpoint path of provided trial.
|
||||
|
||||
Args:
|
||||
@@ -244,7 +260,9 @@ class Analysis:
|
||||
else:
|
||||
return min(checkpoint_paths, key=lambda x: x[1])[0]
|
||||
|
||||
def _retrieve_rows(self, metric=None, mode=None):
|
||||
def _retrieve_rows(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None) -> Dict[str, Any]:
|
||||
assert mode is None or mode in ["max", "min"]
|
||||
rows = {}
|
||||
for path, df in self.trial_dataframes.items():
|
||||
@@ -264,7 +282,7 @@ class Analysis:
|
||||
|
||||
return rows
|
||||
|
||||
def _get_trial_paths(self):
|
||||
def _get_trial_paths(self) -> List[str]:
|
||||
_trial_paths = []
|
||||
for trial_path, _, files in os.walk(self._experiment_dir):
|
||||
if EXPR_PROGRESS_FILE in files:
|
||||
@@ -276,7 +294,7 @@ class Analysis:
|
||||
return _trial_paths
|
||||
|
||||
@property
|
||||
def trial_dataframes(self):
|
||||
def trial_dataframes(self) -> Dict[str, DataFrame]:
|
||||
"""List of all dataframes of the trials."""
|
||||
return self._trial_dataframes
|
||||
|
||||
@@ -306,10 +324,10 @@ class ExperimentAnalysis(Analysis):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
experiment_checkpoint_path,
|
||||
trials=None,
|
||||
default_metric=None,
|
||||
default_mode=None):
|
||||
experiment_checkpoint_path: str,
|
||||
trials: Optional[List[Trial]] = None,
|
||||
default_metric: Optional[str] = None,
|
||||
default_mode: Optional[str] = None):
|
||||
experiment_checkpoint_path = os.path.expanduser(
|
||||
experiment_checkpoint_path)
|
||||
if not os.path.isfile(experiment_checkpoint_path):
|
||||
@@ -365,8 +383,8 @@ class ExperimentAnalysis(Analysis):
|
||||
return self.get_best_config(self.default_metric, self.default_mode)
|
||||
|
||||
@property
|
||||
def best_checkpoint(self) -> Checkpoint:
|
||||
"""Get the checkpoint of the best trial of the experiment
|
||||
def best_checkpoint(self) -> str:
|
||||
"""Get the checkpoint path of the best trial of the experiment
|
||||
|
||||
The best trial is determined by comparing the last trial results
|
||||
using the `metric` and `mode` parameters passed to `tune.run()`.
|
||||
@@ -471,7 +489,10 @@ class ExperimentAnalysis(Analysis):
|
||||
],
|
||||
index="trial_id")
|
||||
|
||||
def get_best_trial(self, metric=None, mode=None, scope="last"):
|
||||
def get_best_trial(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
scope: str = "last") -> Optional[Trial]:
|
||||
"""Retrieve the best trial object.
|
||||
|
||||
Compares all trials' scores on ``metric``.
|
||||
@@ -535,7 +556,10 @@ class ExperimentAnalysis(Analysis):
|
||||
"parameter?")
|
||||
return best_trial
|
||||
|
||||
def get_best_config(self, metric=None, mode=None, scope="last"):
|
||||
def get_best_config(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
scope: str = "last") -> Optional[Dict]:
|
||||
"""Retrieve the best config corresponding to the trial.
|
||||
|
||||
Compares all trials' scores on `metric`.
|
||||
@@ -562,7 +586,10 @@ class ExperimentAnalysis(Analysis):
|
||||
best_trial = self.get_best_trial(metric, mode, scope)
|
||||
return best_trial.config if best_trial else None
|
||||
|
||||
def get_best_logdir(self, metric=None, mode=None, scope="last"):
|
||||
def get_best_logdir(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
scope: str = "last") -> Optional[str]:
|
||||
"""Retrieve the logdir corresponding to the best trial.
|
||||
|
||||
Compares all trials' scores on `metric`.
|
||||
@@ -589,15 +616,15 @@ class ExperimentAnalysis(Analysis):
|
||||
best_trial = self.get_best_trial(metric, mode, scope)
|
||||
return best_trial.logdir if best_trial else None
|
||||
|
||||
def stats(self):
|
||||
def stats(self) -> Dict:
|
||||
"""Returns a dictionary of the statistics of the experiment."""
|
||||
return self._experiment_state.get("stats")
|
||||
|
||||
def runner_data(self):
|
||||
def runner_data(self) -> Dict:
|
||||
"""Returns a dictionary of the TrialRunner data."""
|
||||
return self._experiment_state.get("runner_data")
|
||||
|
||||
def _get_trial_paths(self):
|
||||
def _get_trial_paths(self) -> List[str]:
|
||||
"""Overwrites Analysis to only have trials of one experiment."""
|
||||
if self.trials:
|
||||
_trial_paths = [t.logdir for t in self.trials]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
import ray
|
||||
@@ -15,11 +17,11 @@ from horovod.ray import RayExecutor
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_rank():
|
||||
def get_rank() -> str:
|
||||
return os.environ["HOROVOD_RANK"]
|
||||
|
||||
|
||||
def logger_creator(log_config, logdir):
|
||||
def logger_creator(log_config: Dict, logdir: str) -> NoopLogger:
|
||||
"""Simple NOOP logger for worker trainables."""
|
||||
index = get_rank()
|
||||
worker_dir = os.path.join(logdir, "worker_{}".format(index))
|
||||
@@ -51,7 +53,7 @@ class _HorovodTrainable(tune.Trainable):
|
||||
def num_workers(self):
|
||||
return self._num_hosts * self._num_slots
|
||||
|
||||
def setup(self, config):
|
||||
def setup(self, config: Dict):
|
||||
trainable = wrap_function(self.__class__._function)
|
||||
# We use a filelock here to ensure that the file-writing
|
||||
# process is safe across different trainables.
|
||||
@@ -82,7 +84,7 @@ class _HorovodTrainable(tune.Trainable):
|
||||
"logger_creator": lambda cfg: logger_creator(cfg, logdir_)
|
||||
})
|
||||
|
||||
def step(self):
|
||||
def step(self) -> Dict:
|
||||
if self._finished:
|
||||
raise RuntimeError("Training has already finished.")
|
||||
result = self.executor.execute(lambda w: w.step())[0]
|
||||
@@ -90,14 +92,14 @@ class _HorovodTrainable(tune.Trainable):
|
||||
self._finished = True
|
||||
return result
|
||||
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir: str) -> str:
|
||||
# TODO: optimize if colocated
|
||||
save_obj = self.executor.execute_single(lambda w: w.save_to_object())
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(
|
||||
save_obj, checkpoint_dir)
|
||||
return checkpoint_path
|
||||
|
||||
def load_checkpoint(self, checkpoint_dir):
|
||||
def load_checkpoint(self, checkpoint_dir: str):
|
||||
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
|
||||
x_id = ray.put(checkpoint_obj)
|
||||
return self.executor.execute(lambda w: w.restore_from_object(x_id))
|
||||
@@ -107,13 +109,14 @@ class _HorovodTrainable(tune.Trainable):
|
||||
self.executor.shutdown()
|
||||
|
||||
|
||||
def DistributedTrainableCreator(func,
|
||||
use_gpu=False,
|
||||
num_hosts=1,
|
||||
num_slots=1,
|
||||
num_cpus_per_slot=1,
|
||||
timeout_s=30,
|
||||
replicate_pem=False):
|
||||
def DistributedTrainableCreator(
|
||||
func: Callable,
|
||||
use_gpu: bool = False,
|
||||
num_hosts: int = 1,
|
||||
num_slots: int = 1,
|
||||
num_cpus_per_slot: int = 1,
|
||||
timeout_s: int = 30,
|
||||
replicate_pem: bool = False) -> Type[_HorovodTrainable]:
|
||||
"""Converts Horovod functions to be executable by Tune.
|
||||
|
||||
Requires horovod > 0.19 to work.
|
||||
@@ -198,7 +201,7 @@ def DistributedTrainableCreator(func,
|
||||
_timeout_s = timeout_s
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
def default_resource_request(cls, config: Dict):
|
||||
extra_gpu = int(num_hosts * num_slots) * int(use_gpu)
|
||||
extra_cpu = int(num_hosts * num_slots * num_cpus_per_slot)
|
||||
|
||||
@@ -216,7 +219,7 @@ def DistributedTrainableCreator(func,
|
||||
# that force us to include mocks as part of the module.
|
||||
|
||||
|
||||
def _train_simple(config):
|
||||
def _train_simple(config: Dict):
|
||||
import horovod.torch as hvd
|
||||
hvd.init()
|
||||
from ray import tune
|
||||
|
||||
@@ -39,7 +39,7 @@ class TuneCallback(Callback):
|
||||
on, self._allowed))
|
||||
self._on = on
|
||||
|
||||
def _handle(self, logs: Dict):
|
||||
def _handle(self, logs: Dict, when: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import kubernetes
|
||||
import subprocess
|
||||
|
||||
@@ -45,7 +47,10 @@ class KubernetesSyncer(NodeSyncer):
|
||||
|
||||
_namespace = "ray"
|
||||
|
||||
def __init__(self, local_dir, remote_dir, sync_client=None):
|
||||
def __init__(self,
|
||||
local_dir: str,
|
||||
remote_dir: str,
|
||||
sync_client: Optional[SyncClient] = None):
|
||||
self.local_ip = services.get_node_ip_address()
|
||||
self.local_node = self._get_kubernetes_node_by_ip(self.local_ip)
|
||||
self.worker_ip = None
|
||||
@@ -56,11 +61,11 @@ class KubernetesSyncer(NodeSyncer):
|
||||
|
||||
super(NodeSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
def set_worker_ip(self, worker_ip):
|
||||
def set_worker_ip(self, worker_ip: str):
|
||||
self.worker_ip = worker_ip
|
||||
self.worker_node = self._get_kubernetes_node_by_ip(worker_ip)
|
||||
|
||||
def _get_kubernetes_node_by_ip(self, node_ip):
|
||||
def _get_kubernetes_node_by_ip(self, node_ip: str) -> Optional[str]:
|
||||
"""Return node name by internal or external IP"""
|
||||
kubernetes.config.load_incluster_config()
|
||||
api = kubernetes.client.CoreV1Api()
|
||||
@@ -75,8 +80,8 @@ class KubernetesSyncer(NodeSyncer):
|
||||
return None
|
||||
|
||||
@property
|
||||
def _remote_path(self):
|
||||
return (self.worker_node, self._remote_dir)
|
||||
def _remote_path(self) -> Tuple[str, str]:
|
||||
return self.worker_node, self._remote_dir
|
||||
|
||||
|
||||
class KubernetesSyncClient(SyncClient):
|
||||
@@ -95,12 +100,12 @@ class KubernetesSyncClient(SyncClient):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, namespace, process_runner=subprocess):
|
||||
def __init__(self, namespace: str, process_runner: Any = subprocess):
|
||||
self.namespace = namespace
|
||||
self._process_runner = process_runner
|
||||
self._command_runners = {}
|
||||
|
||||
def _create_command_runner(self, node_id):
|
||||
def _create_command_runner(self, node_id: str) -> KubernetesCommandRunner:
|
||||
"""Create a command runner for one Kubernetes node"""
|
||||
return KubernetesCommandRunner(
|
||||
log_prefix="KubernetesSyncClient: {}:".format(node_id),
|
||||
@@ -109,7 +114,7 @@ class KubernetesSyncClient(SyncClient):
|
||||
auth_config=None,
|
||||
process_runner=self._process_runner)
|
||||
|
||||
def _get_command_runner(self, node_id):
|
||||
def _get_command_runner(self, node_id: str) -> KubernetesCommandRunner:
|
||||
"""Create command runner if it doesn't exist"""
|
||||
# Todo(krfricke): These cached runners are currently
|
||||
# never cleaned up. They are cheap so this shouldn't
|
||||
@@ -120,7 +125,7 @@ class KubernetesSyncClient(SyncClient):
|
||||
self._command_runners[node_id] = command_runner
|
||||
return self._command_runners[node_id]
|
||||
|
||||
def sync_up(self, source, target):
|
||||
def sync_up(self, source: str, target: Tuple[str, str]) -> bool:
|
||||
"""Here target is a tuple (target_node, target_dir)"""
|
||||
target_node, target_dir = target
|
||||
|
||||
@@ -132,7 +137,7 @@ class KubernetesSyncClient(SyncClient):
|
||||
command_runner.run_rsync_up(source, target_dir)
|
||||
return True
|
||||
|
||||
def sync_down(self, source, target):
|
||||
def sync_down(self, source: Tuple[str, str], target: str) -> bool:
|
||||
"""Here source is a tuple (source_node, source_dir)"""
|
||||
source_node, source_dir = source
|
||||
|
||||
@@ -144,7 +149,7 @@ class KubernetesSyncClient(SyncClient):
|
||||
command_runner.run_rsync_down(source_dir, target)
|
||||
return True
|
||||
|
||||
def delete(self, target):
|
||||
def delete(self, target: str) -> bool:
|
||||
"""No delete function because it is only used by
|
||||
the KubernetesSyncer, which doesn't call delete."""
|
||||
return True
|
||||
|
||||
@@ -2,7 +2,9 @@ from typing import Dict, List, Union
|
||||
|
||||
from ray import tune
|
||||
|
||||
from mxnet.model import save_checkpoint
|
||||
import mxnet
|
||||
from mxnet.model import save_checkpoint, BatchEndParam
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
|
||||
@@ -49,7 +51,7 @@ class TuneReportCallback(TuneCallback):
|
||||
metrics = [metrics]
|
||||
self._metrics = metrics
|
||||
|
||||
def __call__(self, param):
|
||||
def __call__(self, param: BatchEndParam):
|
||||
if not param.eval_metric:
|
||||
return
|
||||
if not self._metrics:
|
||||
@@ -110,7 +112,8 @@ class TuneCheckpointCallback(TuneCallback):
|
||||
self._filename = filename
|
||||
self._frequency = frequency
|
||||
|
||||
def __call__(self, epoch, sym, arg, aux):
|
||||
def __call__(self, epoch: int, sym: mxnet.symbol.Symbol,
|
||||
arg: Dict[str, np.ndarray], aux: Dict[str, np.ndarray]):
|
||||
if epoch % self._frequency != 0:
|
||||
return
|
||||
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
|
||||
@@ -5,6 +5,8 @@ import os
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Callable, Dict, Generator, Optional, Type
|
||||
|
||||
import torch
|
||||
from datetime import timedelta
|
||||
|
||||
@@ -34,7 +36,7 @@ def enable_distributed_trainable():
|
||||
_distributed_enabled = True
|
||||
|
||||
|
||||
def logger_creator(log_config, logdir, rank):
|
||||
def logger_creator(log_config: Dict, logdir: str, rank: int) -> NoopLogger:
|
||||
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
|
||||
os.makedirs(worker_dir, exist_ok=True)
|
||||
return NoopLogger(log_config, worker_dir)
|
||||
@@ -54,16 +56,16 @@ class _TorchTrainable(tune.Trainable):
|
||||
__slots__ = ["workers", "_finished"]
|
||||
|
||||
@classmethod
|
||||
def default_process_group_parameters(self):
|
||||
def default_process_group_parameters(self) -> Dict:
|
||||
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
|
||||
|
||||
@classmethod
|
||||
def get_remote_worker_options(self):
|
||||
def get_remote_worker_options(self) -> Dict[str, int]:
|
||||
num_gpus = 1 if self._use_gpu else 0
|
||||
num_cpus = int(self._num_cpus_per_worker or 1)
|
||||
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
|
||||
|
||||
def setup(self, config):
|
||||
def setup(self, config: Dict):
|
||||
self._finished = False
|
||||
num_workers = self._num_workers
|
||||
logdir = self.logdir
|
||||
@@ -103,7 +105,7 @@ class _TorchTrainable(tune.Trainable):
|
||||
for rank, w in enumerate(self.workers)
|
||||
])
|
||||
|
||||
def step(self):
|
||||
def step(self) -> Dict:
|
||||
if self._finished:
|
||||
raise RuntimeError("Training has already finished.")
|
||||
result = ray.get([w.step.remote() for w in self.workers])[0]
|
||||
@@ -111,14 +113,14 @@ class _TorchTrainable(tune.Trainable):
|
||||
self._finished = True
|
||||
return result
|
||||
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
def save_checkpoint(self, checkpoint_dir: str) -> str:
|
||||
# TODO: optimize if colocated
|
||||
save_obj = ray.get(self.workers[0].save_to_object.remote())
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(
|
||||
save_obj, checkpoint_dir)
|
||||
return checkpoint_path
|
||||
|
||||
def load_checkpoint(self, checkpoint_dir):
|
||||
def load_checkpoint(self, checkpoint_dir: str):
|
||||
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
|
||||
return ray.get(
|
||||
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
|
||||
@@ -127,12 +129,13 @@ class _TorchTrainable(tune.Trainable):
|
||||
ray.get([worker.stop.remote() for worker in self.workers])
|
||||
|
||||
|
||||
def DistributedTrainableCreator(func,
|
||||
use_gpu=False,
|
||||
num_workers=1,
|
||||
num_cpus_per_worker=1,
|
||||
backend="gloo",
|
||||
timeout_s=NCCL_TIMEOUT_S):
|
||||
def DistributedTrainableCreator(
|
||||
func: Callable,
|
||||
use_gpu: bool = False,
|
||||
num_workers: int = 1,
|
||||
num_cpus_per_worker: int = 1,
|
||||
backend: str = "gloo",
|
||||
timeout_s: int = NCCL_TIMEOUT_S) -> Type[_TorchTrainable]:
|
||||
"""Creates a class that executes distributed training.
|
||||
|
||||
Similar to running `torch.distributed.launch`.
|
||||
@@ -179,11 +182,11 @@ def DistributedTrainableCreator(func,
|
||||
_num_cpus_per_worker = num_cpus_per_worker
|
||||
|
||||
@classmethod
|
||||
def default_process_group_parameters(self):
|
||||
def default_process_group_parameters(self) -> Dict:
|
||||
return dict(timeout=timedelta(timeout_s), backend=backend)
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
def default_resource_request(cls, config: Dict) -> Resources:
|
||||
num_workers_ = int(config.get("num_workers", num_workers))
|
||||
num_cpus = int(
|
||||
config.get("num_cpus_per_worker", num_cpus_per_worker))
|
||||
@@ -199,7 +202,8 @@ def DistributedTrainableCreator(func,
|
||||
|
||||
|
||||
@contextmanager
|
||||
def distributed_checkpoint_dir(step, disable=False):
|
||||
def distributed_checkpoint_dir(
|
||||
step: int, disable: bool = False) -> Generator[str, None, None]:
|
||||
"""ContextManager for creating a distributed checkpoint.
|
||||
|
||||
Only checkpoints a file on the "main" training actor, avoiding
|
||||
@@ -236,7 +240,7 @@ def distributed_checkpoint_dir(step, disable=False):
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
def _train_check_global(config, checkpoint_dir=None):
|
||||
def _train_check_global(config: Dict, checkpoint_dir: Optional[str] = None):
|
||||
"""For testing only. Putting this here because Ray has problems
|
||||
serializing within the test file."""
|
||||
assert is_distributed_trainable()
|
||||
@@ -245,7 +249,7 @@ def _train_check_global(config, checkpoint_dir=None):
|
||||
tune.report(is_distributed=True)
|
||||
|
||||
|
||||
def _train_simple(config, checkpoint_dir=None):
|
||||
def _train_simple(config: Dict, checkpoint_dir: Optional[str] = None):
|
||||
"""For testing only. Putting this here because Ray has problems
|
||||
serializing within the test file."""
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import pickle
|
||||
from multiprocessing import Process, Queue
|
||||
from numbers import Number
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
import numpy as np
|
||||
|
||||
from ray import logger
|
||||
@@ -70,7 +71,7 @@ def _clean_log(obj):
|
||||
return fallback
|
||||
|
||||
|
||||
def wandb_mixin(func):
|
||||
def wandb_mixin(func: Callable):
|
||||
"""wandb_mixin
|
||||
|
||||
Weights and biases (https://www.wandb.com/) is a tool for experiment
|
||||
@@ -142,7 +143,7 @@ def wandb_mixin(func):
|
||||
return func
|
||||
|
||||
|
||||
def _set_api_key(wandb_config):
|
||||
def _set_api_key(wandb_config: Dict):
|
||||
"""Set WandB API key from `wandb_config`. Will pop the
|
||||
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
|
||||
api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", ""))
|
||||
@@ -176,7 +177,8 @@ class _WandbLoggingProcess(Process):
|
||||
wandb logging instances locally.
|
||||
"""
|
||||
|
||||
def __init__(self, queue, exclude, to_config, *args, **kwargs):
|
||||
def __init__(self, queue: Queue, exclude: List[str], to_config: List[str],
|
||||
*args, **kwargs):
|
||||
super(_WandbLoggingProcess, self).__init__()
|
||||
self.queue = queue
|
||||
self._exclude = set(exclude)
|
||||
@@ -195,7 +197,7 @@ class _WandbLoggingProcess(Process):
|
||||
wandb.log(log)
|
||||
wandb.join()
|
||||
|
||||
def _handle_result(self, result):
|
||||
def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
|
||||
config_update = result.get("config", {}).copy()
|
||||
log = {}
|
||||
flat_result = flatten_dict(result, delimiter="/")
|
||||
@@ -377,7 +379,7 @@ class WandbLogger(Logger):
|
||||
**wandb_init_kwargs)
|
||||
self._wandb.start()
|
||||
|
||||
def on_result(self, result):
|
||||
def on_result(self, result: Dict):
|
||||
result = _clean_log(result)
|
||||
self._queue.put(result)
|
||||
|
||||
@@ -389,7 +391,7 @@ class WandbLogger(Logger):
|
||||
class WandbTrainableMixin:
|
||||
_wandb = wandb
|
||||
|
||||
def __init__(self, config, *args, **kwargs):
|
||||
def __init__(self, config: Dict, *args, **kwargs):
|
||||
if not isinstance(self, Trainable):
|
||||
raise ValueError(
|
||||
"The `WandbTrainableMixin` can only be used as a mixin "
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,14 +40,14 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
time_attr="training_iteration",
|
||||
reward_attr=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
max_t=100,
|
||||
grace_period=1,
|
||||
reduction_factor=4,
|
||||
brackets=1):
|
||||
time_attr: str = "training_iteration",
|
||||
reward_attr: Optional[str] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
max_t: int = 100,
|
||||
grace_period: int = 1,
|
||||
reduction_factor: float = 4,
|
||||
brackets: int = 1):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
assert max_t >= grace_period, "grace_period must be <= max_t!"
|
||||
assert grace_period > 0, "grace_period must be positive!"
|
||||
@@ -82,7 +86,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
self._metric_op = -1.
|
||||
self._time_attr = time_attr
|
||||
|
||||
def set_search_properties(self, metric, mode):
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if self._mode and mode:
|
||||
@@ -100,7 +105,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
if not self._metric or not self._metric_op:
|
||||
raise ValueError(
|
||||
"{} has been instantiated without a valid `metric` ({}) or "
|
||||
@@ -115,7 +121,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
idx = np.random.choice(len(self._brackets), p=normalized)
|
||||
self._trial_info[trial.trial_id] = self._brackets[idx]
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
action = TrialScheduler.CONTINUE
|
||||
if self._time_attr not in result or self._metric not in result:
|
||||
return action
|
||||
@@ -129,7 +136,8 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
self._num_stopped += 1
|
||||
return action
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
if self._time_attr not in result or self._metric not in result:
|
||||
return
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
@@ -137,10 +145,11 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
self._metric_op * result[self._metric])
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
|
||||
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
|
||||
return out
|
||||
@@ -161,19 +170,21 @@ class _Bracket():
|
||||
>>> b.cutoff(b._rungs[3][1]) == 2.0
|
||||
"""
|
||||
|
||||
def __init__(self, min_t, max_t, reduction_factor, s):
|
||||
def __init__(self, min_t: int, max_t: int, reduction_factor: float,
|
||||
s: int):
|
||||
self.rf = reduction_factor
|
||||
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
|
||||
self._rungs = [(min_t * self.rf**(k + s), {})
|
||||
for k in reversed(range(MAX_RUNGS))]
|
||||
|
||||
def cutoff(self, recorded):
|
||||
def cutoff(self, recorded) -> Union[None, int, float, complex, np.ndarray]:
|
||||
if not recorded:
|
||||
return None
|
||||
return np.nanpercentile(
|
||||
list(recorded.values()), (1 - 1 / self.rf) * 100)
|
||||
|
||||
def on_result(self, trial, cur_iter, cur_rew):
|
||||
def on_result(self, trial: Trial, cur_iter: int,
|
||||
cur_rew: Optional[float]) -> str:
|
||||
action = TrialScheduler.CONTINUE
|
||||
for milestone, recorded in self._rungs:
|
||||
if cur_iter < milestone or trial.trial_id in recorded:
|
||||
@@ -190,7 +201,7 @@ class _Bracket():
|
||||
break
|
||||
return action
|
||||
|
||||
def debug_str(self):
|
||||
def debug_str(self) -> str:
|
||||
# TODO: fix up the output for this
|
||||
iters = " | ".join([
|
||||
"Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune.schedulers.trial_scheduler import TrialScheduler
|
||||
from ray.tune.schedulers.hyperband import HyperBandScheduler, Bracket
|
||||
from ray.tune.trial import Trial
|
||||
@@ -22,7 +24,8 @@ class HyperBandForBOHB(HyperBandScheduler):
|
||||
See ray.tune.schedulers.HyperBandScheduler for parameter docstring.
|
||||
"""
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Adds new trial.
|
||||
|
||||
On a new trial add, if current bracket is not filled, add to current
|
||||
@@ -67,7 +70,8 @@ class HyperBandForBOHB(HyperBandScheduler):
|
||||
self._state["bracket"].add_trial(trial)
|
||||
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
"""If bracket is finished, all trials will be stopped.
|
||||
|
||||
If a given trial finishes and bracket iteration is not done,
|
||||
@@ -96,11 +100,14 @@ class HyperBandForBOHB(HyperBandScheduler):
|
||||
action = self._process_bracket(trial_runner, bracket)
|
||||
return action
|
||||
|
||||
def _unpause_trial(self, trial_runner, trial):
|
||||
def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
trial_runner.trial_executor.unpause_trial(trial)
|
||||
trial_runner._search_alg.searcher.on_unpause(trial.trial_id)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, allow_recurse=True):
|
||||
def choose_trial_to_run(self,
|
||||
trial_runner: "trial_runner.TrialRunner",
|
||||
allow_recurse: bool = True) -> Optional[Trial]:
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
|
||||
List of trials not used since all trials are tracked as state
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import collections
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.error import TuneError
|
||||
@@ -74,12 +77,12 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
time_attr="training_iteration",
|
||||
reward_attr=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
max_t=81,
|
||||
reduction_factor=3):
|
||||
time_attr: str = "training_iteration",
|
||||
reward_attr: Optional[str] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
max_t: int = 81,
|
||||
reduction_factor: float = 3):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
@@ -118,7 +121,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
self._metric_op = -1.
|
||||
self._time_attr = time_attr
|
||||
|
||||
def set_search_properties(self, metric, mode):
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if self._mode and mode:
|
||||
@@ -136,7 +140,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Adds new trial.
|
||||
|
||||
On a new trial add, if current bracket is not filled,
|
||||
@@ -179,7 +184,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
self._state["bracket"].add_trial(trial)
|
||||
self._trial_info[trial] = cur_bracket, self._state["band_idx"]
|
||||
|
||||
def _cur_band_filled(self):
|
||||
def _cur_band_filled(self) -> bool:
|
||||
"""Checks if the current band is filled.
|
||||
|
||||
The size of the current band should be equal to s_max_1"""
|
||||
@@ -187,7 +192,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
cur_band = self._hyperbands[self._state["band_idx"]]
|
||||
return len(cur_band) == self._s_max_1
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
"""If bracket is finished, all trials will be stopped.
|
||||
|
||||
If a given trial finishes and bracket iteration is not done,
|
||||
@@ -211,7 +217,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
metric_val=result.get(self._time_attr)))
|
||||
return action
|
||||
|
||||
def _process_bracket(self, trial_runner, bracket):
|
||||
def _process_bracket(self, trial_runner: "trial_runner.TrialRunner",
|
||||
bracket: "Bracket") -> str:
|
||||
"""This is called whenever a trial makes progress.
|
||||
|
||||
When all live trials in the bracket have no more iterations left,
|
||||
@@ -250,7 +257,8 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
action = TrialScheduler.CONTINUE
|
||||
return action
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Notification when trial terminates.
|
||||
|
||||
Trial info is removed from bracket. Triggers halving if bracket is
|
||||
@@ -260,15 +268,18 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
if not bracket.finished():
|
||||
self._process_bracket(trial_runner, bracket)
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
"""Cleans up trial info from bracket if trial completed early."""
|
||||
self.on_trial_remove(trial_runner, trial)
|
||||
|
||||
def on_trial_error(self, trial_runner, trial):
|
||||
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Cleans up trial info from bracket if trial errored early."""
|
||||
self.on_trial_remove(trial_runner, trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
def choose_trial_to_run(
|
||||
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
|
||||
List of trials not used since all trials are tracked as state
|
||||
@@ -288,7 +299,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
return trial
|
||||
return None
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
"""This provides a progress notification for the algorithm.
|
||||
|
||||
For each bracket, the algorithm will output a string as follows:
|
||||
@@ -315,13 +326,14 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
out += "\n {}".format(bracket)
|
||||
return out
|
||||
|
||||
def state(self):
|
||||
def state(self) -> Dict[str, int]:
|
||||
return {
|
||||
"num_brackets": sum(len(band) for band in self._hyperbands),
|
||||
"num_stopped": self._num_stopped
|
||||
}
|
||||
|
||||
def _unpause_trial(self, trial_runner, trial):
|
||||
def _unpause_trial(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
trial_runner.trial_executor.unpause_trial(trial)
|
||||
|
||||
|
||||
@@ -332,7 +344,8 @@ class Bracket:
|
||||
Also keeps track of progress to ensure good scheduling.
|
||||
"""
|
||||
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
|
||||
def __init__(self, time_attr: str, max_trials: int, init_t_attr: int,
|
||||
max_t_attr: int, eta: float, s: int):
|
||||
self._live_trials = {} # maps trial -> current result
|
||||
self._all_trials = []
|
||||
self._time_attr = time_attr # attribute to
|
||||
@@ -348,7 +361,7 @@ class Bracket:
|
||||
self._total_work = self._calculate_total_work(self._n0, self._r0, s)
|
||||
self._completed_progress = 0
|
||||
|
||||
def add_trial(self, trial):
|
||||
def add_trial(self, trial: Trial):
|
||||
"""Add trial to bracket assuming bracket is not filled.
|
||||
|
||||
At a later iteration, a newly added trial will be given equal
|
||||
@@ -357,7 +370,7 @@ class Bracket:
|
||||
self._live_trials[trial] = None
|
||||
self._all_trials.append(trial)
|
||||
|
||||
def cur_iter_done(self):
|
||||
def cur_iter_done(self) -> bool:
|
||||
"""Checks if all iterations have completed.
|
||||
|
||||
TODO(rliaw): also check that `t.iterations == self._r`"""
|
||||
@@ -365,20 +378,20 @@ class Bracket:
|
||||
self._get_result_time(result) >= self._cumul_r
|
||||
for result in self._live_trials.values())
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
return self._halves == 0 and self.cur_iter_done()
|
||||
|
||||
def current_trials(self):
|
||||
def current_trials(self) -> List[Trial]:
|
||||
return list(self._live_trials)
|
||||
|
||||
def continue_trial(self, trial):
|
||||
def continue_trial(self, trial: Trial) -> bool:
|
||||
result = self._live_trials[trial]
|
||||
if self._get_result_time(result) < self._cumul_r:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def filled(self):
|
||||
def filled(self) -> bool:
|
||||
"""Checks if bracket is filled.
|
||||
|
||||
Only let new trials be added at current level minimizing the need
|
||||
@@ -386,7 +399,8 @@ class Bracket:
|
||||
|
||||
return len(self._live_trials) == self._n
|
||||
|
||||
def successive_halving(self, metric, metric_op):
|
||||
def successive_halving(self, metric: str, metric_op: float
|
||||
) -> Tuple[List[Trial], List[Trial]]:
|
||||
assert self._halves > 0
|
||||
self._halves -= 1
|
||||
self._n /= self._eta
|
||||
@@ -402,7 +416,7 @@ class Bracket:
|
||||
good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n]
|
||||
return good, bad
|
||||
|
||||
def update_trial_stats(self, trial, result):
|
||||
def update_trial_stats(self, trial: Trial, result: Dict):
|
||||
"""Update result for trial. Called after trial has finished
|
||||
an iteration - will decrement iteration count.
|
||||
|
||||
@@ -422,7 +436,7 @@ class Bracket:
|
||||
self._completed_progress += delta
|
||||
self._live_trials[trial] = result
|
||||
|
||||
def cleanup_trial(self, trial):
|
||||
def cleanup_trial(self, trial: Trial):
|
||||
"""Clean up statistics tracking for terminated trials (either by force
|
||||
or otherwise).
|
||||
|
||||
@@ -432,7 +446,7 @@ class Bracket:
|
||||
assert trial in self._live_trials
|
||||
del self._live_trials[trial]
|
||||
|
||||
def cleanup_full(self, trial_runner):
|
||||
def cleanup_full(self, trial_runner: "trial_runner.TrialRunner"):
|
||||
"""Cleans up bracket after bracket is completely finished.
|
||||
|
||||
Lets the last trial continue to run until termination condition
|
||||
@@ -441,7 +455,7 @@ class Bracket:
|
||||
if (trial.status == Trial.PAUSED):
|
||||
trial_runner.stop_trial(trial)
|
||||
|
||||
def completion_percentage(self):
|
||||
def completion_percentage(self) -> float:
|
||||
"""Returns a progress metric.
|
||||
|
||||
This will not be always finish with 100 since dead trials
|
||||
@@ -450,12 +464,12 @@ class Bracket:
|
||||
return 1.0
|
||||
return self._completed_progress / self._total_work
|
||||
|
||||
def _get_result_time(self, result):
|
||||
def _get_result_time(self, result: Dict) -> float:
|
||||
if result is None:
|
||||
return 0
|
||||
return result[self._time_attr]
|
||||
|
||||
def _calculate_total_work(self, n, r, s):
|
||||
def _calculate_total_work(self, n: int, r: float, s: int):
|
||||
work = 0
|
||||
cumulative_r = r
|
||||
for _ in range(s + 1):
|
||||
@@ -466,7 +480,7 @@ class Bracket:
|
||||
r = int(min(r, self._max_t_attr - cumulative_r))
|
||||
return work
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
status = ", ".join([
|
||||
"Max Size (n)={}".format(self._n),
|
||||
"Milestone (r)={}".format(self._cumul_r),
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import collections
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
@@ -38,14 +41,14 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
time_attr="time_total_s",
|
||||
reward_attr=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
grace_period=60.0,
|
||||
min_samples_required=3,
|
||||
min_time_slice=0,
|
||||
hard_stop=True):
|
||||
time_attr: str = "time_total_s",
|
||||
reward_attr: Optional[str] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
grace_period: float = 60.0,
|
||||
min_samples_required: int = 3,
|
||||
min_time_slice: int = 0,
|
||||
hard_stop: bool = True):
|
||||
if reward_attr is not None:
|
||||
mode = "max"
|
||||
metric = reward_attr
|
||||
@@ -75,7 +78,8 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
self._last_pause = collections.defaultdict(lambda: float("-inf"))
|
||||
self._results = collections.defaultdict(list)
|
||||
|
||||
def set_search_properties(self, metric, mode):
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if self._mode and mode:
|
||||
@@ -91,7 +95,8 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
if not self._metric or not self._worst or not self._compare_op:
|
||||
raise ValueError(
|
||||
"{} has been instantiated without a valid `metric` ({}) or "
|
||||
@@ -102,7 +107,8 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
|
||||
super(MedianStoppingRule, self).on_trial_add(trial_runner, trial)
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
"""Callback for early stopping.
|
||||
|
||||
This stopping rule stops a running trial if the trial's best objective
|
||||
@@ -154,14 +160,17 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
else:
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
self._results[trial].append(result)
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
return "Using MedianStoppingRule: num_stopped={}.".format(
|
||||
len(self._stopped_trials))
|
||||
|
||||
def _on_insufficient_samples(self, trial_runner, trial, time):
|
||||
def _on_insufficient_samples(self,
|
||||
trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, time: float) -> str:
|
||||
pause = time - self._last_pause[trial] > self._min_time_slice
|
||||
pause = pause and [
|
||||
t for t in trial_runner.get_trials()
|
||||
@@ -169,17 +178,17 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
]
|
||||
return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE
|
||||
|
||||
def _trials_beyond_time(self, time):
|
||||
def _trials_beyond_time(self, time: float) -> List[Trial]:
|
||||
trials = [
|
||||
trial for trial in self._results
|
||||
if self._results[trial][-1][self._time_attr] >= time
|
||||
]
|
||||
return trials
|
||||
|
||||
def _median_result(self, trials, time):
|
||||
def _median_result(self, trials: List[Trial], time: float):
|
||||
return np.median([self._running_mean(trial, time) for trial in trials])
|
||||
|
||||
def _running_mean(self, trial, time):
|
||||
def _running_mean(self, trial: Trial, time: float) -> np.ndarray:
|
||||
results = self._results[trial]
|
||||
# TODO(ekl) we could do interpolation to be more precise, but for now
|
||||
# assume len(results) is large and the time diffs are roughly equal
|
||||
|
||||
@@ -5,7 +5,10 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune import trial_executor
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
@@ -13,7 +16,6 @@ from ray.tune.sample import Domain, Function
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
|
||||
from ray.util.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,7 +24,7 @@ logger = logging.getLogger(__name__)
|
||||
class PBTTrialState:
|
||||
"""Internal PBT state tracked per-trial."""
|
||||
|
||||
def __init__(self, trial):
|
||||
def __init__(self, trial: Trial):
|
||||
self.orig_tag = trial.experiment_tag
|
||||
self.last_score = None
|
||||
self.last_checkpoint = None
|
||||
@@ -30,12 +32,13 @@ class PBTTrialState:
|
||||
self.last_train_time = 0 # Used for synchronous mode.
|
||||
self.last_result = None # Used for synchronous mode.
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return str((self.last_score, self.last_checkpoint,
|
||||
self.last_train_time, self.last_perturbation_time))
|
||||
|
||||
|
||||
def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
def explore(config: Dict, mutations: Dict, resample_probability: float,
|
||||
custom_explore_fn: Optional[Callable]) -> Dict:
|
||||
"""Return a config perturbed as specified.
|
||||
|
||||
Args:
|
||||
@@ -83,7 +86,7 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
return new_config
|
||||
|
||||
|
||||
def make_experiment_tag(orig_tag, config, mutations):
|
||||
def make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str:
|
||||
"""Appends perturbed params to the trial name to show in the console."""
|
||||
|
||||
resolved_vars = {}
|
||||
@@ -92,7 +95,8 @@ def make_experiment_tag(orig_tag, config, mutations):
|
||||
return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))
|
||||
|
||||
|
||||
def fill_config(config, attr, search_space):
|
||||
def fill_config(config: Dict, attr: str,
|
||||
search_space: Union[Callable, Domain, list, dict]):
|
||||
"""Add attr to config by sampling from search_space."""
|
||||
if callable(search_space):
|
||||
config[attr] = search_space()
|
||||
@@ -214,18 +218,19 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
time_attr="time_total_s",
|
||||
reward_attr=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
perturbation_interval=60.0,
|
||||
hyperparam_mutations={},
|
||||
quantile_fraction=0.25,
|
||||
resample_probability=0.25,
|
||||
custom_explore_fn=None,
|
||||
log_config=True,
|
||||
require_attrs=True,
|
||||
synch=False):
|
||||
time_attr: str = "time_total_s",
|
||||
reward_attr: Optional[str] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
perturbation_interval: float = 60.0,
|
||||
hyperparam_mutations: Dict = None,
|
||||
quantile_fraction: float = 0.25,
|
||||
resample_probability: float = 0.25,
|
||||
custom_explore_fn: Optional[Callable] = None,
|
||||
log_config: bool = True,
|
||||
require_attrs: bool = True,
|
||||
synch: bool = False):
|
||||
hyperparam_mutations = hyperparam_mutations or {}
|
||||
for value in hyperparam_mutations.values():
|
||||
if not (isinstance(value,
|
||||
(list, dict, Domain)) or callable(value)):
|
||||
@@ -288,7 +293,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._num_checkpoints = 0
|
||||
self._num_perturbations = 0
|
||||
|
||||
def set_search_properties(self, metric, mode):
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if self._mode and mode:
|
||||
@@ -306,7 +312,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
if not self._metric or not self._metric_op:
|
||||
raise ValueError(
|
||||
"{} has been instantiated without a valid `metric` ({}) or "
|
||||
@@ -329,7 +336,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
# Make sure this attribute is added to CLI output.
|
||||
trial.evaluated_params[attr] = trial.config[attr]
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
if self._time_attr not in result:
|
||||
time_missing_msg = "Cannot find time_attr {} " \
|
||||
"in trial result {}. Make sure that this " \
|
||||
@@ -425,8 +433,9 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
# the paused trials.
|
||||
return TrialScheduler.PAUSE
|
||||
|
||||
def _perturb_trial(self, trial, trial_runner, upper_quantile,
|
||||
lower_quantile):
|
||||
def _perturb_trial(
|
||||
self, trial: Trial, trial_runner: "trial_runner.TrialRunner",
|
||||
upper_quantile: List[Trial], lower_quantile: List[Trial]):
|
||||
"""Checkpoint if in upper quantile, exploits if in lower."""
|
||||
state = self._trial_state[trial]
|
||||
if trial in upper_quantile:
|
||||
@@ -450,8 +459,9 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
assert trial is not trial_to_clone
|
||||
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
|
||||
|
||||
def _log_config_on_step(self, trial_state, new_state, trial,
|
||||
trial_to_clone, new_config):
|
||||
def _log_config_on_step(self, trial_state: PBTTrialState,
|
||||
new_state: PBTTrialState, trial: Trial,
|
||||
trial_to_clone: Trial, new_config: Dict):
|
||||
"""Logs transition during exploit/exploit step.
|
||||
|
||||
For each step, logs: [target trial tag, clone trial tag, target trial
|
||||
@@ -482,7 +492,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
with open(trial_path, "a+") as f:
|
||||
f.write(json.dumps(policy, cls=_SafeFallbackEncoder) + "\n")
|
||||
|
||||
def _exploit(self, trial_executor, trial, trial_to_clone):
|
||||
def _exploit(self, trial_executor: "trial_executor.TrialExecutor",
|
||||
trial: Trial, trial_to_clone: Trial):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial.
|
||||
|
||||
If specified, also logs the updated hyperparam state.
|
||||
@@ -554,7 +565,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||
trial_state.last_train_time = new_state.last_train_time
|
||||
|
||||
def _quantiles(self):
|
||||
def _quantiles(self) -> Tuple[List[Trial], List[Trial]]:
|
||||
"""Returns trials in the lower and upper `quantile` of the population.
|
||||
|
||||
If there is not enough data to compute this, returns empty lists.
|
||||
@@ -578,7 +589,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
return (trials[:num_trials_in_quantile],
|
||||
trials[-num_trials_in_quantile:])
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
def choose_trial_to_run(
|
||||
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
|
||||
"""Ensures all trials get fair share of time (as defined by time_attr).
|
||||
|
||||
This enables the PBT scheduler to support a greater number of
|
||||
@@ -601,7 +613,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._num_perturbations = 0
|
||||
self._num_checkpoints = 0
|
||||
|
||||
def last_scores(self, trials):
|
||||
def last_scores(self, trials: List[Trial]) -> List[float]:
|
||||
scores = []
|
||||
for trial in trials:
|
||||
state = self._trial_state[trial]
|
||||
@@ -609,7 +621,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
scores.append(state.last_score)
|
||||
return scores
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
|
||||
self._num_checkpoints, self._num_perturbations)
|
||||
|
||||
@@ -657,7 +669,7 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, policy_file):
|
||||
def __init__(self, policy_file: str):
|
||||
policy_file = os.path.expanduser(policy_file)
|
||||
if not os.path.exists(policy_file):
|
||||
raise ValueError("Policy file not found: {}".format(policy_file))
|
||||
@@ -679,7 +691,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
self._policy_iter = iter(self._policy)
|
||||
self._next_policy = next(self._policy_iter, None)
|
||||
|
||||
def _load_policy(self, policy_file):
|
||||
def _load_policy(self,
|
||||
policy_file: str) -> Tuple[Dict, List[Tuple[int, Dict]]]:
|
||||
raw_policy = []
|
||||
with open(policy_file, "rt") as fp:
|
||||
for row in fp.readlines():
|
||||
@@ -708,7 +721,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
|
||||
return last_old_conf, list(reversed(policy))
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
if self._trial:
|
||||
raise ValueError(
|
||||
"More than one trial added to PBT replay run. This "
|
||||
@@ -729,7 +743,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
"or consider not using PBT replay for this run.")
|
||||
self._trial.config = self.config
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
if TRAINING_ITERATION not in result:
|
||||
# No time reported
|
||||
return TrialScheduler.CONTINUE
|
||||
@@ -775,6 +790,6 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
return "PopulationBasedTraining replay: Step {}, perturb {}".format(
|
||||
self._current_step, self._num_perturbations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
@@ -8,7 +11,8 @@ class TrialScheduler:
|
||||
PAUSE = "PAUSE" #: Status for pausing trial execution
|
||||
STOP = "STOP" #: Status for stopping trial execution
|
||||
|
||||
def set_search_properties(self, metric, mode):
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
"""Pass search properties to scheduler.
|
||||
|
||||
This method acts as an alternative to instantiating schedulers
|
||||
@@ -20,19 +24,22 @@ class TrialScheduler:
|
||||
"""
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Called when a new trial is added to the trial runner."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_error(self, trial_runner, trial):
|
||||
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Notification for the error of trial.
|
||||
|
||||
This will only be called when the trial is in the RUNNING state."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
"""Called on each intermediate result returned by a trial.
|
||||
|
||||
At this point, the trial scheduler can make a decision by returning
|
||||
@@ -41,7 +48,8 @@ class TrialScheduler:
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
This will only be called when the trial is in the RUNNING state and
|
||||
@@ -49,7 +57,8 @@ class TrialScheduler:
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Called to remove trial.
|
||||
|
||||
This is called when the trial is in PAUSED or PENDING state. Otherwise,
|
||||
@@ -57,7 +66,8 @@ class TrialScheduler:
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
def choose_trial_to_run(
|
||||
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
|
||||
"""Called to choose a new trial to run.
|
||||
|
||||
This should return one of the trials in trial_runner that is in
|
||||
@@ -67,7 +77,7 @@ class TrialScheduler:
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -76,22 +86,28 @@ class TrialScheduler:
|
||||
class FIFOScheduler(TrialScheduler):
|
||||
"""Simple scheduler that just runs trials in submission order."""
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
pass
|
||||
|
||||
def on_trial_error(self, trial_runner, trial):
|
||||
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
pass
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
pass
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
pass
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
def choose_trial_to_run(
|
||||
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
@@ -102,5 +118,5 @@ class FIFOScheduler(TrialScheduler):
|
||||
return trial
|
||||
return None
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self) -> str:
|
||||
return "Using FIFO scheduling algorithm."
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||
from ray.tune.suggest.search_generator import SearchGenerator
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
class _MockSearcher(Searcher):
|
||||
@@ -11,29 +14,32 @@ class _MockSearcher(Searcher):
|
||||
self.results = []
|
||||
super(_MockSearcher, self).__init__(**kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str):
|
||||
if not self.stall:
|
||||
self.live_trials[trial_id] = 1
|
||||
return {"test_variable": 2}
|
||||
return None
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
self.counter["result"] += 1
|
||||
self.results += [result]
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
self.counter["complete"] += 1
|
||||
if result:
|
||||
self._process_result(result)
|
||||
if trial_id in self.live_trials:
|
||||
del self.live_trials[trial_id]
|
||||
|
||||
def _process_result(self, result):
|
||||
def _process_result(self, result: Dict):
|
||||
self.final_results += [result]
|
||||
|
||||
|
||||
class _MockSuggestionAlgorithm(SearchGenerator):
|
||||
def __init__(self, max_concurrent=None, **kwargs):
|
||||
def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
|
||||
self.searcher = _MockSearcher(**kwargs)
|
||||
if max_concurrent:
|
||||
self.searcher = ConcurrencyLimiter(
|
||||
@@ -41,9 +47,9 @@ class _MockSuggestionAlgorithm(SearchGenerator):
|
||||
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
|
||||
|
||||
@property
|
||||
def live_trials(self):
|
||||
def live_trials(self) -> List[Trial]:
|
||||
return self.searcher.live_trials
|
||||
|
||||
@property
|
||||
def results(self):
|
||||
def results(self) -> List[Dict]:
|
||||
return self.searcher.results
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ax.service.ax_client import AxClient
|
||||
from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
|
||||
@@ -103,14 +103,14 @@ class AxSearch(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
space=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
parameter_constraints=None,
|
||||
outcome_constraints=None,
|
||||
ax_client=None,
|
||||
use_early_stopped_trials=None,
|
||||
max_concurrent=None):
|
||||
space: Optional[List[Dict]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
parameter_constraints: Optional[List] = None,
|
||||
outcome_constraints: Optional[List] = None,
|
||||
ax_client: Optional[AxClient] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
max_concurrent: Optional[int] = None):
|
||||
assert ax is not None, "Ax must be installed!"
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
@@ -177,7 +177,8 @@ class AxSearch(Searcher):
|
||||
logger.warning("Detected sequential enforcement. Be sure to use "
|
||||
"a ConcurrencyLimiter.")
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict):
|
||||
if self._ax:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -189,7 +190,7 @@ class AxSearch(Searcher):
|
||||
self.setup_experiment()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._ax:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
|
||||
@@ -2,9 +2,10 @@ import itertools
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.variant_generator import (generate_variants, format_vars,
|
||||
flatten_resolved_vars)
|
||||
@@ -42,7 +43,7 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
searcher.is_finished == True
|
||||
"""
|
||||
|
||||
def __init__(self, shuffle=False):
|
||||
def __init__(self, shuffle: bool = False):
|
||||
"""Initializes the Variant Generator.
|
||||
|
||||
"""
|
||||
@@ -60,7 +61,9 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
else:
|
||||
self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
def add_configurations(
|
||||
self,
|
||||
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
|
||||
"""Chains generator given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
|
||||
@@ -2,9 +2,10 @@ from collections import defaultdict
|
||||
import logging
|
||||
import pickle
|
||||
import json
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from ray.tune.sample import Float, Quantized
|
||||
from ray.tune import ExperimentAnalysis
|
||||
from ray.tune.sample import Domain, Float, Quantized
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
|
||||
@@ -100,18 +101,18 @@ class BayesOptSearch(Searcher):
|
||||
optimizer = None
|
||||
|
||||
def __init__(self,
|
||||
space=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
utility_kwargs=None,
|
||||
random_state=42,
|
||||
random_search_steps=10,
|
||||
verbose=0,
|
||||
patience=5,
|
||||
skip_duplicate=True,
|
||||
analysis=None,
|
||||
max_concurrent=None,
|
||||
use_early_stopped_trials=None):
|
||||
space: Optional[Dict] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
utility_kwargs: Optional[Dict] = None,
|
||||
random_state: int = 42,
|
||||
random_search_steps: int = 10,
|
||||
verbose: int = 0,
|
||||
patience: int = 5,
|
||||
skip_duplicate: bool = True,
|
||||
analysis: Optional[ExperimentAnalysis] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None):
|
||||
"""Instantiate new BayesOptSearch object.
|
||||
|
||||
Args:
|
||||
@@ -200,7 +201,8 @@ class BayesOptSearch(Searcher):
|
||||
verbose=self._verbose,
|
||||
random_state=self._random_state)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self.optimizer:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -218,7 +220,7 @@ class BayesOptSearch(Searcher):
|
||||
self.setup_optimizer()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
"""Return new point to be explored by black box function.
|
||||
|
||||
Args:
|
||||
@@ -278,7 +280,7 @@ class BayesOptSearch(Searcher):
|
||||
# Return a deep copy of the mapping
|
||||
return unflatten_dict(config)
|
||||
|
||||
def register_analysis(self, analysis):
|
||||
def register_analysis(self, analysis: ExperimentAnalysis):
|
||||
"""Integrate the given analysis into the gaussian process.
|
||||
|
||||
Args:
|
||||
@@ -293,7 +295,10 @@ class BayesOptSearch(Searcher):
|
||||
# gaussian process optimizer
|
||||
self._register_result(params, report)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
Args:
|
||||
@@ -330,18 +335,18 @@ class BayesOptSearch(Searcher):
|
||||
for params, result in self._buffered_trial_results:
|
||||
self._register_result(params, result)
|
||||
|
||||
def _register_result(self, params, result):
|
||||
def _register_result(self, params: Tuple[str], result: Dict):
|
||||
"""Register given tuple of params and results."""
|
||||
self.optimizer.register(params, self._metric_op * result[self.metric])
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
"""Storing current optimizer state."""
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(
|
||||
(self.optimizer, self._buffered_trial_results,
|
||||
self._total_random_search_trials, self._config_counter), f)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
"""Restoring current optimizer state."""
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
(self.optimizer, self._buffered_trial_results,
|
||||
@@ -349,7 +354,7 @@ class BayesOptSearch(Searcher):
|
||||
self._config_counter) = pickle.load(f)
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> Dict:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -358,7 +363,7 @@ class BayesOptSearch(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a BayesOpt search space.")
|
||||
|
||||
def resolve_value(domain):
|
||||
def resolve_value(domain: Domain) -> Tuple[float, float]:
|
||||
sampler = domain.get_sampler()
|
||||
if isinstance(sampler, Quantized):
|
||||
logger.warning(
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import ConfigSpace
|
||||
from ray.tune.sample import Categorical, Float, Integer, LogUniform, Normal, \
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Normal, \
|
||||
Quantized, \
|
||||
Uniform
|
||||
from ray.tune.suggest import Searcher
|
||||
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
class _BOHBJobWrapper():
|
||||
"""Mock object for HpBandSter to process."""
|
||||
|
||||
def __init__(self, loss, budget, config):
|
||||
def __init__(self, loss: float, budget: float, config: Dict):
|
||||
self.result = {"loss": loss}
|
||||
self.kwargs = {"budget": budget, "config": config.copy()}
|
||||
self.exception = None
|
||||
@@ -92,11 +93,11 @@ class TuneBOHB(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
space=None,
|
||||
bohb_config=None,
|
||||
max_concurrent=10,
|
||||
metric=None,
|
||||
mode=None):
|
||||
space: Optional[ConfigSpace.ConfigurationSpace] = None,
|
||||
bohb_config: Optional[Dict] = None,
|
||||
max_concurrent: int = 10,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None):
|
||||
from hpbandster.optimizers.config_generators.bohb import BOHB
|
||||
assert BOHB is not None, "HpBandSter must be installed!"
|
||||
if mode:
|
||||
@@ -126,7 +127,8 @@ class TuneBOHB(Searcher):
|
||||
bohb_config = self._bohb_config or {}
|
||||
self.bohber = BOHB(self._space, **bohb_config)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._space:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -140,7 +142,7 @@ class TuneBOHB(Searcher):
|
||||
self.setup_bohb()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._space:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -156,7 +158,7 @@ class TuneBOHB(Searcher):
|
||||
return unflatten_dict(config)
|
||||
return None
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
if trial_id not in self.paused:
|
||||
self.running.add(trial_id)
|
||||
if "hyperband_info" not in result:
|
||||
@@ -166,28 +168,31 @@ class TuneBOHB(Searcher):
|
||||
hbs_wrapper = self.to_wrapper(trial_id, result)
|
||||
self.bohber.new_result(hbs_wrapper)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
del self.trial_to_params[trial_id]
|
||||
if trial_id in self.paused:
|
||||
self.paused.remove(trial_id)
|
||||
if trial_id in self.running:
|
||||
self.running.remove(trial_id)
|
||||
|
||||
def to_wrapper(self, trial_id, result):
|
||||
def to_wrapper(self, trial_id: str, result: Dict) -> _BOHBJobWrapper:
|
||||
return _BOHBJobWrapper(self._metric_op * result[self.metric],
|
||||
result["hyperband_info"]["budget"],
|
||||
self.trial_to_params[trial_id])
|
||||
|
||||
def on_pause(self, trial_id):
|
||||
def on_pause(self, trial_id: str):
|
||||
self.paused.add(trial_id)
|
||||
self.running.remove(trial_id)
|
||||
|
||||
def on_unpause(self, trial_id):
|
||||
def on_unpause(self, trial_id: str):
|
||||
self.paused.remove(trial_id)
|
||||
self.running.add(trial_id)
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> ConfigSpace.ConfigurationSpace:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -196,7 +201,8 @@ class TuneBOHB(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a TuneBOHB search space.")
|
||||
|
||||
def resolve_value(par, domain):
|
||||
def resolve_value(par: str, domain: Domain
|
||||
) -> ConfigSpace.hyperparameters.Hyperparameter:
|
||||
quantize = None
|
||||
|
||||
sampler = domain.get_sampler()
|
||||
|
||||
@@ -5,16 +5,18 @@ from __future__ import print_function
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.tune.sample import Float, Quantized
|
||||
from ray.tune.sample import Domain, Float, Quantized
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import flatten_dict
|
||||
|
||||
try: # Python 3 only -- needed for lint test.
|
||||
import dragonfly
|
||||
from dragonfly.opt.blackbox_optimiser import BlackboxOptimiser
|
||||
except ImportError:
|
||||
dragonfly = None
|
||||
BlackboxOptimiser = None
|
||||
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
|
||||
@@ -127,13 +129,13 @@ class DragonflySearch(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer=None,
|
||||
domain=None,
|
||||
space=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
points_to_evaluate=None,
|
||||
evaluated_rewards=None,
|
||||
optimizer: Optional[BlackboxOptimiser] = None,
|
||||
domain: Optional[str] = None,
|
||||
space: Optional[List[Dict]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[List]] = None,
|
||||
evaluated_rewards: Optional[List] = None,
|
||||
**kwargs):
|
||||
assert dragonfly is not None, """dragonfly must be installed!
|
||||
You can install Dragonfly with the command:
|
||||
@@ -144,8 +146,6 @@ class DragonflySearch(Searcher):
|
||||
super(DragonflySearch, self).__init__(
|
||||
metric=metric, mode=mode, **kwargs)
|
||||
|
||||
from dragonfly.opt.blackbox_optimiser import BlackboxOptimiser
|
||||
|
||||
self._opt_arg = optimizer
|
||||
self._domain = domain
|
||||
self._space = space
|
||||
@@ -245,7 +245,8 @@ class DragonflySearch(Searcher):
|
||||
elif self._mode == "max":
|
||||
self._metric_op = 1.
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._opt:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -258,7 +259,7 @@ class DragonflySearch(Searcher):
|
||||
self.setup_dragonfly()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._opt:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -281,26 +282,29 @@ class DragonflySearch(Searcher):
|
||||
self._live_trial_mapping[trial_id] = suggested_config
|
||||
return {"point": suggested_config}
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Passes result to Dragonfly unless early terminated or errored."""
|
||||
trial_info = self._live_trial_mapping.pop(trial_id)
|
||||
if result:
|
||||
self._opt.tell([(trial_info,
|
||||
self._metric_op * result[self._metric])])
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path: str):
|
||||
trials_object = (self._initial_points, self._opt)
|
||||
with open(checkpoint_dir, "wb") as outputFile:
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
def restore(self, checkpoint_dir: str):
|
||||
with open(checkpoint_dir, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._initial_points = trials_object[0]
|
||||
self._opt = trials_object[1]
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> List[Dict]:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -309,7 +313,7 @@ class DragonflySearch(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a Dragonfly search space.")
|
||||
|
||||
def resolve_value(par, domain):
|
||||
def resolve_value(par: str, domain: Domain) -> Dict:
|
||||
sampler = domain.get_sampler()
|
||||
if isinstance(sampler, Quantized):
|
||||
logger.warning(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
@@ -6,7 +6,8 @@ import logging
|
||||
from functools import partial
|
||||
import pickle
|
||||
|
||||
from ray.tune.sample import Categorical, Float, Integer, LogUniform, Normal, \
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Normal, \
|
||||
Quantized, \
|
||||
Uniform
|
||||
from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
|
||||
@@ -117,15 +118,15 @@ class HyperOptSearch(Searcher):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
space=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
points_to_evaluate=None,
|
||||
n_initial_points=20,
|
||||
random_state_seed=None,
|
||||
gamma=0.25,
|
||||
max_concurrent=None,
|
||||
use_early_stopped_trials=None,
|
||||
space: Optional[Dict] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[Dict]] = None,
|
||||
n_initial_points: int = 20,
|
||||
random_state_seed: Optional[int] = None,
|
||||
gamma: float = 0.25,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
):
|
||||
assert hpo is not None, (
|
||||
"HyperOpt must be installed! Run `pip install hyperopt`.")
|
||||
@@ -170,7 +171,8 @@ class HyperOptSearch(Searcher):
|
||||
if space:
|
||||
self.domain = hpo.Domain(lambda spc: spc, space)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self.domain:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -188,7 +190,7 @@ class HyperOptSearch(Searcher):
|
||||
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self.domain:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -226,7 +228,7 @@ class HyperOptSearch(Searcher):
|
||||
print_node_on_error=self.domain.rec_eval_print_node_on_error)
|
||||
return copy.deepcopy(suggested_config)
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
ho_trial = self._get_hyperopt_trial(trial_id)
|
||||
if ho_trial is None:
|
||||
return
|
||||
@@ -234,7 +236,10 @@ class HyperOptSearch(Searcher):
|
||||
ho_trial["book_time"] = now
|
||||
ho_trial["refresh_time"] = now
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
The result is internally negated when interacting with HyperOpt
|
||||
@@ -252,7 +257,7 @@ class HyperOptSearch(Searcher):
|
||||
self._process_result(trial_id, result)
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def _process_result(self, trial_id, result):
|
||||
def _process_result(self, trial_id: str, result: Dict):
|
||||
ho_trial = self._get_hyperopt_trial(trial_id)
|
||||
if not ho_trial:
|
||||
return
|
||||
@@ -263,10 +268,10 @@ class HyperOptSearch(Searcher):
|
||||
ho_trial["result"] = hp_result
|
||||
self._hpopt_trials.refresh()
|
||||
|
||||
def _to_hyperopt_result(self, result):
|
||||
def _to_hyperopt_result(self, result: Dict) -> Dict:
|
||||
return {"loss": self.metric_op * result[self.metric], "status": "ok"}
|
||||
|
||||
def _get_hyperopt_trial(self, trial_id):
|
||||
def _get_hyperopt_trial(self, trial_id: str) -> Optional[Dict]:
|
||||
if trial_id not in self._live_trial_mapping:
|
||||
return
|
||||
hyperopt_tid = self._live_trial_mapping[trial_id][0]
|
||||
@@ -274,21 +279,21 @@ class HyperOptSearch(Searcher):
|
||||
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
|
||||
][0]
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Dict:
|
||||
return {
|
||||
"hyperopt_trials": self._hpopt_trials,
|
||||
"rstate": self.rstate.get_state()
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: Dict):
|
||||
self._hpopt_trials = state["hyperopt_trials"]
|
||||
self.rstate.set_state(state["rstate"])
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(self.get_state(), outputFile)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
|
||||
@@ -299,19 +304,19 @@ class HyperOptSearch(Searcher):
|
||||
self.set_state(trials_object)
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> Dict:
|
||||
spec = copy.deepcopy(spec)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
if not domain_vars and not grid_vars:
|
||||
return []
|
||||
return {}
|
||||
|
||||
if grid_vars:
|
||||
raise ValueError(
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a HyperOpt search space.")
|
||||
|
||||
def resolve_value(par, domain):
|
||||
def resolve_value(par: str, domain: Domain) -> Any:
|
||||
quantize = None
|
||||
|
||||
sampler = domain.get_sampler()
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from ray.tune.sample import Categorical, Float, Integer, LogUniform, Quantized
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Quantized
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
|
||||
try:
|
||||
import nevergrad as ng
|
||||
from nevergrad.optimization import Optimizer
|
||||
from nevergrad.optimization.base import ConfiguredOptimizer
|
||||
Parameter = ng.p.Parameter
|
||||
except ImportError:
|
||||
ng = None
|
||||
Optimizer = None
|
||||
ConfiguredOptimizer = None
|
||||
Parameter = None
|
||||
|
||||
from ray.tune.suggest import Searcher
|
||||
|
||||
@@ -85,11 +92,11 @@ class NevergradSearch(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer=None,
|
||||
space=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
max_concurrent=None,
|
||||
optimizer: Union[None, Optimizer, ConfiguredOptimizer] = None,
|
||||
space: Optional[Parameter] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
**kwargs):
|
||||
assert ng is not None, "Nevergrad must be installed!"
|
||||
if mode:
|
||||
@@ -102,7 +109,7 @@ class NevergradSearch(Searcher):
|
||||
self._opt_factory = None
|
||||
self._nevergrad_opt = None
|
||||
|
||||
if isinstance(optimizer, ng.optimization.Optimizer):
|
||||
if isinstance(optimizer, Optimizer):
|
||||
if space is not None or isinstance(space, list):
|
||||
raise ValueError(
|
||||
"If you pass a configured optimizer to Nevergrad, either "
|
||||
@@ -110,7 +117,7 @@ class NevergradSearch(Searcher):
|
||||
"parameter.")
|
||||
self._parameters = space
|
||||
self._nevergrad_opt = optimizer
|
||||
elif isinstance(optimizer, ng.optimization.base.ConfiguredOptimizer):
|
||||
elif isinstance(optimizer, ConfiguredOptimizer):
|
||||
self._opt_factory = optimizer
|
||||
self._parameters = None
|
||||
self._space = space
|
||||
@@ -155,7 +162,8 @@ class NevergradSearch(Searcher):
|
||||
raise ValueError("len(parameters_names) must match optimizer "
|
||||
"dimension for non-instrumented optimizers")
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._nevergrad_opt or self._space:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -169,7 +177,7 @@ class NevergradSearch(Searcher):
|
||||
self.setup_nevergrad()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._nevergrad_opt:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -192,7 +200,10 @@ class NevergradSearch(Searcher):
|
||||
else:
|
||||
return unflatten_dict(suggested_config.kwargs)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
The result is internally negated when interacting with Nevergrad
|
||||
@@ -204,24 +215,24 @@ class NevergradSearch(Searcher):
|
||||
|
||||
self._live_trial_mapping.pop(trial_id)
|
||||
|
||||
def _process_result(self, trial_id, result):
|
||||
def _process_result(self, trial_id: str, result: Dict):
|
||||
ng_trial_info = self._live_trial_mapping[trial_id]
|
||||
self._nevergrad_opt.tell(ng_trial_info,
|
||||
self._metric_op * result[self._metric])
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
trials_object = (self._nevergrad_opt, self._parameters)
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._nevergrad_opt = trials_object[0]
|
||||
self._parameters = trials_object[1]
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> Parameter:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -230,7 +241,7 @@ class NevergradSearch(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a Nevergrad search space.")
|
||||
|
||||
def resolve_value(domain):
|
||||
def resolve_value(domain: Domain) -> Parameter:
|
||||
sampler = domain.get_sampler()
|
||||
if isinstance(sampler, Quantized):
|
||||
logger.warning("Nevergrad does not support quantization. "
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Quantized, Uniform
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
@@ -11,8 +11,10 @@ from ray.tune.utils.util import unflatten_dict
|
||||
|
||||
try:
|
||||
import optuna as ot
|
||||
from optuna.samplers import BaseSampler
|
||||
except ImportError:
|
||||
ot = None
|
||||
BaseSampler = None
|
||||
|
||||
from ray.tune.suggest import Searcher
|
||||
|
||||
@@ -100,7 +102,11 @@ class OptunaSearch(Searcher):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, space=None, metric=None, mode=None, sampler=None):
|
||||
def __init__(self,
|
||||
space: Optional[List[Tuple]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
sampler: Optional[BaseSampler] = None):
|
||||
assert ot is not None, (
|
||||
"Optuna must be installed! Run `pip install optuna`.")
|
||||
super(OptunaSearch, self).__init__(
|
||||
@@ -113,7 +119,7 @@ class OptunaSearch(Searcher):
|
||||
|
||||
self._study_name = "optuna" # Fixed study name for in-memory storage
|
||||
self._sampler = sampler or ot.samplers.TPESampler()
|
||||
assert isinstance(self._sampler, ot.samplers.BaseSampler), \
|
||||
assert isinstance(self._sampler, BaseSampler), \
|
||||
"You can only pass an instance of `optuna.samplers.BaseSampler` " \
|
||||
"as a sampler to `OptunaSearcher`."
|
||||
|
||||
@@ -125,7 +131,7 @@ class OptunaSearch(Searcher):
|
||||
if self._space:
|
||||
self.setup_study(mode)
|
||||
|
||||
def setup_study(self, mode):
|
||||
def setup_study(self, mode: str):
|
||||
self._ot_study = ot.study.create_study(
|
||||
storage=self._storage,
|
||||
sampler=self._sampler,
|
||||
@@ -134,7 +140,8 @@ class OptunaSearch(Searcher):
|
||||
direction="minimize" if mode == "min" else "maximize",
|
||||
load_if_exists=True)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._space:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -146,7 +153,7 @@ class OptunaSearch(Searcher):
|
||||
self.setup_study(mode)
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._space:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -169,13 +176,16 @@ class OptunaSearch(Searcher):
|
||||
}
|
||||
return unflatten_dict(params)
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
metric = result[self.metric]
|
||||
step = result[TRAINING_ITERATION]
|
||||
ot_trial = self._ot_trials[trial_id]
|
||||
ot_trial.report(metric, step)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
ot_trial = self._ot_trials[trial_id]
|
||||
ot_trial_id = ot_trial._trial_id
|
||||
self._storage.set_trial_value(ot_trial_id, result.get(
|
||||
@@ -183,20 +193,20 @@ class OptunaSearch(Searcher):
|
||||
self._storage.set_trial_state(ot_trial_id,
|
||||
ot.trial.TrialState.COMPLETE)
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
save_object = (self._storage, self._pruner, self._sampler,
|
||||
self._ot_trials, self._ot_study)
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(save_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
save_object = pickle.load(inputFile)
|
||||
self._storage, self._pruner, self._sampler, \
|
||||
self._ot_trials, self._ot_study = save_object
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> List[Tuple]:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -208,7 +218,7 @@ class OptunaSearch(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to an Optuna search space.")
|
||||
|
||||
def resolve_value(par, domain):
|
||||
def resolve_value(par: str, domain: Domain) -> Tuple:
|
||||
quantize = None
|
||||
|
||||
sampler = domain.get_sampler()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
@@ -10,7 +12,7 @@ TRIAL_INDEX = "__trial_index__"
|
||||
"""str: A constant value representing the repeat index of the trial."""
|
||||
|
||||
|
||||
def _warn_num_samples(searcher, num_samples):
|
||||
def _warn_num_samples(searcher: Searcher, num_samples: int):
|
||||
if isinstance(searcher, Repeater) and num_samples % searcher.repeat:
|
||||
logger.warning(
|
||||
"`num_samples` is now expected to be the total number of trials, "
|
||||
@@ -34,7 +36,10 @@ class _TrialGroup:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, primary_trial_id, config, max_trials=1):
|
||||
def __init__(self,
|
||||
primary_trial_id: str,
|
||||
config: Dict,
|
||||
max_trials: int = 1):
|
||||
assert type(config) is dict, (
|
||||
"config is not a dict, got {}".format(config))
|
||||
self.primary_trial_id = primary_trial_id
|
||||
@@ -42,27 +47,27 @@ class _TrialGroup:
|
||||
self._trials = {primary_trial_id: None}
|
||||
self.max_trials = max_trials
|
||||
|
||||
def add(self, trial_id):
|
||||
def add(self, trial_id: str):
|
||||
assert len(self._trials) < self.max_trials
|
||||
self._trials.setdefault(trial_id, None)
|
||||
|
||||
def full(self):
|
||||
def full(self) -> bool:
|
||||
return len(self._trials) == self.max_trials
|
||||
|
||||
def report(self, trial_id, score):
|
||||
def report(self, trial_id: str, score: float):
|
||||
assert trial_id in self._trials
|
||||
if score is None:
|
||||
raise ValueError("Internal Error: Score cannot be None.")
|
||||
self._trials[trial_id] = score
|
||||
|
||||
def finished_reporting(self):
|
||||
def finished_reporting(self) -> bool:
|
||||
return None not in self._trials.values() and len(
|
||||
self._trials) == self.max_trials
|
||||
|
||||
def scores(self):
|
||||
def scores(self) -> List[Optional[float]]:
|
||||
return list(self._trials.values())
|
||||
|
||||
def count(self):
|
||||
def count(self) -> int:
|
||||
return len(self._trials)
|
||||
|
||||
|
||||
@@ -103,7 +108,10 @@ class Repeater(Searcher):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, searcher, repeat=1, set_index=True):
|
||||
def __init__(self,
|
||||
searcher: Searcher,
|
||||
repeat: int = 1,
|
||||
set_index: bool = True):
|
||||
self.searcher = searcher
|
||||
self.repeat = repeat
|
||||
self._set_index = set_index
|
||||
@@ -113,7 +121,7 @@ class Repeater(Searcher):
|
||||
super(Repeater, self).__init__(
|
||||
metric=self.searcher.metric, mode=self.searcher.mode)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if self._current_group is None or self._current_group.full():
|
||||
config = self.searcher.suggest(trial_id)
|
||||
if config is None:
|
||||
@@ -132,7 +140,10 @@ class Repeater(Searcher):
|
||||
self._trial_id_to_group[trial_id] = self._current_group
|
||||
return config
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, **kwargs):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
**kwargs):
|
||||
"""Stores the score for and keeps track of a completed trial.
|
||||
|
||||
Stores the metric of a trial as nan if any of the following conditions
|
||||
@@ -160,13 +171,14 @@ class Repeater(Searcher):
|
||||
result={self.searcher.metric: np.nanmean(scores)},
|
||||
**kwargs)
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Dict:
|
||||
self_state = self.__dict__.copy()
|
||||
del self_state["searcher"]
|
||||
return self_state
|
||||
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: Dict):
|
||||
self.__dict__.update(state)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
return self.searcher.set_search_properties(metric, mode, config)
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
class SearchAlgorithm:
|
||||
"""Interface of an event handler API for hyperparameter search.
|
||||
|
||||
@@ -12,7 +18,8 @@ class SearchAlgorithm:
|
||||
"""
|
||||
_finished = False
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
"""Pass search properties to search algorithm.
|
||||
|
||||
This method acts as an alternative to instantiating search algorithms
|
||||
@@ -29,7 +36,9 @@ class SearchAlgorithm:
|
||||
"""
|
||||
return True
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
def add_configurations(
|
||||
self,
|
||||
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
|
||||
"""Tracks given experiment specifications.
|
||||
|
||||
Arguments:
|
||||
@@ -37,7 +46,7 @@ class SearchAlgorithm:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def next_trials(self):
|
||||
def next_trials(self) -> List[Trial]:
|
||||
"""Provides Trial objects to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
@@ -45,17 +54,21 @@ class SearchAlgorithm:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
"""Called on each intermediate result returned by a trial.
|
||||
|
||||
This will only be called when the trial is in the RUNNING state.
|
||||
|
||||
Arguments:
|
||||
trial_id: Identifier for the trial.
|
||||
result: Result dictionary.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
Arguments:
|
||||
@@ -69,7 +82,7 @@ class SearchAlgorithm:
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_finished(self):
|
||||
def is_finished(self) -> bool:
|
||||
"""Returns True if no trials left to be queued into TrialRunner.
|
||||
|
||||
Can return True before all trials have finished executing.
|
||||
@@ -80,14 +93,14 @@ class SearchAlgorithm:
|
||||
"""Marks the search algorithm as finished."""
|
||||
self._finished = True
|
||||
|
||||
def has_checkpoint(self, dirpath):
|
||||
def has_checkpoint(self, dirpath: str) -> bool:
|
||||
"""Should return False if not restoring is not implemented."""
|
||||
return False
|
||||
|
||||
def save_to_dir(self, dirpath, **kwargs):
|
||||
def save_to_dir(self, dirpath: str, **kwargs):
|
||||
"""Saves a search algorithm."""
|
||||
pass
|
||||
|
||||
def restore_from_dir(self, dirpath):
|
||||
def restore_from_dir(self, dirpath: str):
|
||||
"""Restores a search algorithm along with its wrapped state."""
|
||||
pass
|
||||
|
||||
@@ -2,10 +2,11 @@ import os
|
||||
import copy
|
||||
import logging
|
||||
import glob
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
@@ -21,7 +22,7 @@ def _warn_on_repeater(searcher, total_samples):
|
||||
_warn_num_samples(searcher, total_samples)
|
||||
|
||||
|
||||
def _atomic_save(state, checkpoint_dir, file_name):
|
||||
def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str):
|
||||
"""Atomically saves the object to the checkpoint directory
|
||||
|
||||
This is automatically used by tune.run during a Tune job.
|
||||
@@ -34,7 +35,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
|
||||
os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
|
||||
|
||||
|
||||
def _find_newest_ckpt(dirpath, pattern):
|
||||
def _find_newest_ckpt(dirpath: str, pattern: str):
|
||||
"""Returns path to most recently modified checkpoint."""
|
||||
full_paths = glob.glob(os.path.join(dirpath, pattern))
|
||||
if not full_paths:
|
||||
@@ -58,7 +59,7 @@ class SearchGenerator(SearchAlgorithm):
|
||||
"""
|
||||
CKPT_FILE_TMPL = "search_gen_state-{}.json"
|
||||
|
||||
def __init__(self, searcher):
|
||||
def __init__(self, searcher: Searcher):
|
||||
assert issubclass(
|
||||
type(searcher),
|
||||
Searcher), ("Searcher should be subclassing Searcher.")
|
||||
@@ -69,10 +70,13 @@ class SearchGenerator(SearchAlgorithm):
|
||||
self._total_samples = None # int: total samples to evaluate.
|
||||
self._finished = False
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
return self.searcher.set_search_properties(metric, mode, config)
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
def add_configurations(
|
||||
self,
|
||||
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
|
||||
"""Registers experiment specifications.
|
||||
|
||||
Arguments:
|
||||
@@ -91,7 +95,7 @@ class SearchGenerator(SearchAlgorithm):
|
||||
if "run" not in experiment_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(experiment_spec))
|
||||
|
||||
def next_trials(self):
|
||||
def next_trials(self) -> List[Trial]:
|
||||
"""Provides a batch of Trial objects to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
@@ -106,7 +110,8 @@ class SearchGenerator(SearchAlgorithm):
|
||||
trials.append(trial)
|
||||
return trials
|
||||
|
||||
def create_trial_if_possible(self, experiment_spec, output_path):
|
||||
def create_trial_if_possible(self, experiment_spec: Dict,
|
||||
output_path: str) -> Optional[Trial]:
|
||||
logger.debug("creating trial")
|
||||
trial_id = Trial.generate_id()
|
||||
suggested_config = self.searcher.suggest(trial_id)
|
||||
@@ -135,18 +140,21 @@ class SearchGenerator(SearchAlgorithm):
|
||||
trial_id=trial_id)
|
||||
return trial
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
"""Notifies the underlying searcher."""
|
||||
self.searcher.on_trial_result(trial_id, result)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
self.searcher.on_trial_complete(
|
||||
trial_id=trial_id, result=result, error=error)
|
||||
|
||||
def is_finished(self):
|
||||
def is_finished(self) -> bool:
|
||||
return self._counter >= self._total_samples or self._finished
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Dict:
|
||||
return {
|
||||
"counter": self._counter,
|
||||
"total_samples": self._total_samples,
|
||||
@@ -154,17 +162,17 @@ class SearchGenerator(SearchAlgorithm):
|
||||
"experiment": self._experiment
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: Dict):
|
||||
self._counter = state["counter"]
|
||||
self._total_samples = state["total_samples"]
|
||||
self._finished = state["finished"]
|
||||
self._experiment = state["experiment"]
|
||||
|
||||
def has_checkpoint(self, dirpath):
|
||||
def has_checkpoint(self, dirpath: str):
|
||||
return bool(
|
||||
_find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*")))
|
||||
|
||||
def save_to_dir(self, dirpath, session_str):
|
||||
def save_to_dir(self, dirpath: str, session_str: str):
|
||||
"""Saves self + searcher to dir.
|
||||
|
||||
Separates the "searcher" from its wrappers (concurrency, repeating).
|
||||
@@ -196,7 +204,7 @@ class SearchGenerator(SearchAlgorithm):
|
||||
_atomic_save(search_alg_state, dirpath,
|
||||
self.CKPT_FILE_TMPL.format(session_str))
|
||||
|
||||
def restore_from_dir(self, dirpath):
|
||||
def restore_from_dir(self, dirpath: str):
|
||||
"""Restores self + searcher + search wrappers from dirpath."""
|
||||
|
||||
searcher = self.searcher
|
||||
|
||||
@@ -2,10 +2,14 @@ import copy
|
||||
import os
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
import sigopt as sgo
|
||||
Connection = sgo.Connection
|
||||
except ImportError:
|
||||
sgo = None
|
||||
Connection = None
|
||||
|
||||
from ray.tune.suggest import Searcher
|
||||
|
||||
@@ -122,16 +126,16 @@ class SigOptSearch(Searcher):
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
space=None,
|
||||
name="Default Tune Experiment",
|
||||
max_concurrent=1,
|
||||
reward_attr=None,
|
||||
connection=None,
|
||||
experiment_id=None,
|
||||
observation_budget=None,
|
||||
project=None,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
space: List[Dict] = None,
|
||||
name: str = "Default Tune Experiment",
|
||||
max_concurrent: int = 1,
|
||||
reward_attr: Optional[str] = None,
|
||||
connection: Optional[Connection] = None,
|
||||
experiment_id: Optional[str] = None,
|
||||
observation_budget: Optional[int] = None,
|
||||
project: Optional[str] = None,
|
||||
metric: Union[None, str, List[str]] = "episode_reward_mean",
|
||||
mode: Union[None, str, List[str]] = "max",
|
||||
**kwargs):
|
||||
assert (experiment_id is
|
||||
None) ^ (space is None), "space xor experiment_id must be set"
|
||||
@@ -178,7 +182,7 @@ class SigOptSearch(Searcher):
|
||||
|
||||
super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str):
|
||||
if self._max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self._max_concurrent:
|
||||
return None
|
||||
@@ -190,7 +194,10 @@ class SigOptSearch(Searcher):
|
||||
|
||||
return copy.deepcopy(suggestion.assignments)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
If a trial fails, it will be reported as a failed Observation, telling
|
||||
@@ -214,7 +221,7 @@ class SigOptSearch(Searcher):
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
@staticmethod
|
||||
def serialize_metric(metrics, modes):
|
||||
def serialize_metric(metrics: List[str], modes: List[str]):
|
||||
"""
|
||||
Converts metrics to https://app.sigopt.com/docs/objects/metric
|
||||
"""
|
||||
@@ -224,7 +231,7 @@ class SigOptSearch(Searcher):
|
||||
dict(name=metric, **SigOptSearch.OBJECTIVE_MAP[mode].copy()))
|
||||
return serialized_metric
|
||||
|
||||
def serialize_result(self, result):
|
||||
def serialize_result(self, result: Dict):
|
||||
"""
|
||||
Converts experiments results to
|
||||
https://app.sigopt.com/docs/objects/metric_evaluation
|
||||
@@ -244,12 +251,12 @@ class SigOptSearch(Searcher):
|
||||
values.append(value)
|
||||
return values
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
trials_object = (self.conn, self.experiment)
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self.conn = trials_object[0]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ray.tune.sample import Categorical, Float, Integer, Quantized
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
@@ -17,8 +17,9 @@ from ray.tune.suggest import Searcher
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_warmstart(parameter_names, points_to_evaluate,
|
||||
evaluated_rewards):
|
||||
def _validate_warmstart(parameter_names: List[str],
|
||||
points_to_evaluate: List[List],
|
||||
evaluated_rewards: List):
|
||||
if points_to_evaluate:
|
||||
if not isinstance(points_to_evaluate, list):
|
||||
raise TypeError(
|
||||
@@ -125,14 +126,14 @@ class SkOptSearch(Searcher):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer=None,
|
||||
space=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
points_to_evaluate=None,
|
||||
evaluated_rewards=None,
|
||||
max_concurrent=None,
|
||||
use_early_stopped_trials=None):
|
||||
optimizer: Optional[sko.optimizer.Optimizer] = None,
|
||||
space: Union[List[str], Dict[str, Union[Tuple, List]]] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[List]] = None,
|
||||
evaluated_rewards: Optional[List] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None):
|
||||
assert sko is not None, """skopt must be installed!
|
||||
You can install Skopt with the command:
|
||||
`pip install scikit-optimize`."""
|
||||
@@ -162,7 +163,7 @@ class SkOptSearch(Searcher):
|
||||
"names.")
|
||||
self._parameter_names = space
|
||||
else:
|
||||
self._parameter_names = space.keys()
|
||||
self._parameter_names = list(space.keys())
|
||||
self._parameter_ranges = space.values()
|
||||
|
||||
self._points_to_evaluate = points_to_evaluate
|
||||
@@ -199,7 +200,8 @@ class SkOptSearch(Searcher):
|
||||
elif self._mode == "min":
|
||||
self._metric_op = 1.
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._skopt_opt:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -216,7 +218,7 @@ class SkOptSearch(Searcher):
|
||||
self.setup_skopt()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._skopt_opt:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -235,7 +237,10 @@ class SkOptSearch(Searcher):
|
||||
self._live_trial_mapping[trial_id] = suggested_config
|
||||
return unflatten_dict(dict(zip(self._parameters, suggested_config)))
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
The result is internally negated when interacting with Skopt
|
||||
@@ -247,24 +252,24 @@ class SkOptSearch(Searcher):
|
||||
self._process_result(trial_id, result)
|
||||
self._live_trial_mapping.pop(trial_id)
|
||||
|
||||
def _process_result(self, trial_id, result):
|
||||
def _process_result(self, trial_id: str, result: Dict):
|
||||
skopt_trial_info = self._live_trial_mapping[trial_id]
|
||||
self._skopt_opt.tell(skopt_trial_info,
|
||||
self._metric_op * result[self._metric])
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
trials_object = (self._initial_points, self._skopt_opt)
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._initial_points = trials_object[0]
|
||||
self._skopt_opt = trials_object[1]
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> Dict:
|
||||
spec = flatten_dict(spec, prevent_delimiter=True)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -273,7 +278,7 @@ class SkOptSearch(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a SkOpt search space.")
|
||||
|
||||
def resolve_value(domain):
|
||||
def resolve_value(domain: Domain) -> Union[Tuple, List]:
|
||||
sampler = domain.get_sampler()
|
||||
if isinstance(sampler, Quantized):
|
||||
logger.warning("SkOpt search does not support quantization. "
|
||||
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ray.util.debug import log_once
|
||||
|
||||
@@ -56,10 +57,10 @@ class Searcher:
|
||||
CKPT_FILE_TMPL = "searcher-state-{}.pkl"
|
||||
|
||||
def __init__(self,
|
||||
metric=None,
|
||||
mode=None,
|
||||
max_concurrent=None,
|
||||
use_early_stopped_trials=None):
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None):
|
||||
if use_early_stopped_trials is False:
|
||||
raise DeprecationWarning(
|
||||
"Early stopped trials are now always used. If this is a "
|
||||
@@ -90,7 +91,8 @@ class Searcher:
|
||||
else:
|
||||
raise ValueError("Mode most either be a list or string")
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
"""Pass search properties to searcher.
|
||||
|
||||
This method acts as an alternative to instantiating search algorithms
|
||||
@@ -106,7 +108,7 @@ class Searcher:
|
||||
"""
|
||||
return False
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
"""Optional notification for result during training.
|
||||
|
||||
Note that by default, the result dict may include NaNs or
|
||||
@@ -124,7 +126,10 @@ class Searcher:
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
Typically, this method is used for notifying the underlying
|
||||
@@ -143,7 +148,7 @@ class Searcher:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
"""Queries the algorithm to retrieve the next set of parameters.
|
||||
|
||||
Arguments:
|
||||
@@ -159,7 +164,7 @@ class Searcher:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
"""Save state to path for this search algorithm.
|
||||
|
||||
Args:
|
||||
@@ -190,7 +195,7 @@ class Searcher:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
"""Restore state for this search algorithm
|
||||
|
||||
|
||||
@@ -213,13 +218,13 @@ class Searcher:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Dict:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: Dict):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_to_dir(self, checkpoint_dir, session_str="default"):
|
||||
def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
|
||||
"""Automatically saves the given searcher to the checkpoint_dir.
|
||||
|
||||
This is automatically used by tune.run during a Tune job.
|
||||
@@ -246,7 +251,7 @@ class Searcher:
|
||||
os.path.join(checkpoint_dir,
|
||||
self.CKPT_FILE_TMPL.format(session_str)))
|
||||
|
||||
def restore_from_dir(self, checkpoint_dir):
|
||||
def restore_from_dir(self, checkpoint_dir: str):
|
||||
"""Restores the state of a searcher from a given checkpoint_dir.
|
||||
|
||||
Typically, you should use this function to restore from an
|
||||
@@ -277,12 +282,12 @@ class Searcher:
|
||||
self.restore(most_recent_checkpoint)
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
def metric(self) -> str:
|
||||
"""The training result objective value attribute."""
|
||||
return self._metric
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
def mode(self) -> str:
|
||||
"""Specifies if minimizing or maximizing the metric."""
|
||||
return self._mode
|
||||
|
||||
@@ -308,7 +313,10 @@ class ConcurrencyLimiter(Searcher):
|
||||
tune.run(trainable, search_alg=search_alg)
|
||||
"""
|
||||
|
||||
def __init__(self, searcher, max_concurrent, batch=False):
|
||||
def __init__(self,
|
||||
searcher: Searcher,
|
||||
max_concurrent: int,
|
||||
batch: bool = False):
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
self.searcher = searcher
|
||||
self.max_concurrent = max_concurrent
|
||||
@@ -318,7 +326,7 @@ class ConcurrencyLimiter(Searcher):
|
||||
super(ConcurrencyLimiter, self).__init__(
|
||||
metric=self.searcher.metric, mode=self.searcher.mode)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
assert trial_id not in self.live_trials, (
|
||||
f"Trial ID {trial_id} must be unique: already found in set.")
|
||||
if len(self.live_trials) >= self.max_concurrent:
|
||||
@@ -333,7 +341,10 @@ class ConcurrencyLimiter(Searcher):
|
||||
self.live_trials.add(trial_id)
|
||||
return suggestion
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
if trial_id not in self.live_trials:
|
||||
return
|
||||
elif self.batch:
|
||||
@@ -353,19 +364,20 @@ class ConcurrencyLimiter(Searcher):
|
||||
trial_id, result=result, error=error)
|
||||
self.live_trials.remove(trial_id)
|
||||
|
||||
def get_state(self):
|
||||
def get_state(self) -> Dict:
|
||||
state = self.__dict__.copy()
|
||||
del state["searcher"]
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: Dict):
|
||||
self.__dict__.update(state)
|
||||
|
||||
def on_pause(self, trial_id):
|
||||
def on_pause(self, trial_id: str):
|
||||
self.searcher.on_pause(trial_id)
|
||||
|
||||
def on_unpause(self, trial_id):
|
||||
def on_unpause(self, trial_id: str):
|
||||
self.searcher.on_unpause(trial_id)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
return self.searcher.set_search_properties(metric, mode, config)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
import numpy
|
||||
import random
|
||||
|
||||
@@ -9,7 +11,8 @@ from ray.tune.sample import Categorical, Domain, Function
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_variants(unresolved_spec):
|
||||
def generate_variants(
|
||||
unresolved_spec: Dict) -> Generator[Tuple[Dict, Dict], None, None]:
|
||||
"""Generates variants from a spec (dict) with unresolved values.
|
||||
|
||||
There are two types of unresolved values:
|
||||
@@ -45,7 +48,7 @@ def generate_variants(unresolved_spec):
|
||||
yield resolved_vars, spec
|
||||
|
||||
|
||||
def grid_search(values):
|
||||
def grid_search(values: List) -> Dict[str, List]:
|
||||
"""Convenience method for specifying grid search over a value.
|
||||
|
||||
Arguments:
|
||||
@@ -63,7 +66,7 @@ _STANDARD_IMPORTS = {
|
||||
_MAX_RESOLUTION_PASSES = 20
|
||||
|
||||
|
||||
def resolve_nested_dict(nested_dict):
|
||||
def resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]:
|
||||
"""Flattens a nested dict by joining keys into tuple of paths.
|
||||
|
||||
Can then be passed into `format_vars`.
|
||||
@@ -78,7 +81,7 @@ def resolve_nested_dict(nested_dict):
|
||||
return res
|
||||
|
||||
|
||||
def format_vars(resolved_vars):
|
||||
def format_vars(resolved_vars: Dict) -> str:
|
||||
"""Formats the resolved variable dict into a single string."""
|
||||
out = []
|
||||
for path, value in sorted(resolved_vars.items()):
|
||||
@@ -97,7 +100,7 @@ def format_vars(resolved_vars):
|
||||
return ",".join(out)
|
||||
|
||||
|
||||
def flatten_resolved_vars(resolved_vars):
|
||||
def flatten_resolved_vars(resolved_vars: Dict) -> Dict:
|
||||
"""Formats the resolved variable dict into a mapping of (str -> value)."""
|
||||
flattened_resolved_vars_dict = {}
|
||||
for pieces, value in resolved_vars.items():
|
||||
@@ -108,14 +111,15 @@ def flatten_resolved_vars(resolved_vars):
|
||||
return flattened_resolved_vars_dict
|
||||
|
||||
|
||||
def _clean_value(value):
|
||||
def _clean_value(value: Any) -> str:
|
||||
if isinstance(value, float):
|
||||
return "{:.5}".format(value)
|
||||
else:
|
||||
return str(value).replace("/", "_")
|
||||
|
||||
|
||||
def parse_spec_vars(spec):
|
||||
def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
|
||||
Tuple, Any]], List[Tuple[Tuple, Any]]]:
|
||||
resolved, unresolved = _split_resolved_unresolved_values(spec)
|
||||
resolved_vars = list(resolved.items())
|
||||
|
||||
@@ -134,7 +138,7 @@ def parse_spec_vars(spec):
|
||||
return resolved_vars, domain_vars, grid_vars
|
||||
|
||||
|
||||
def _generate_variants(spec):
|
||||
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
||||
spec = copy.deepcopy(spec)
|
||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -159,19 +163,20 @@ def _generate_variants(spec):
|
||||
yield resolved_vars, spec
|
||||
|
||||
|
||||
def assign_value(spec, path, value):
|
||||
def assign_value(spec: Dict, path: Tuple, value: Any):
|
||||
for k in path[:-1]:
|
||||
spec = spec[k]
|
||||
spec[path[-1]] = value
|
||||
|
||||
|
||||
def _get_value(spec, path):
|
||||
def _get_value(spec: Dict, path: Tuple) -> Any:
|
||||
for k in path:
|
||||
spec = spec[k]
|
||||
return spec
|
||||
|
||||
|
||||
def _resolve_domain_vars(spec, domain_vars):
|
||||
def _resolve_domain_vars(spec: Dict,
|
||||
domain_vars: List[Tuple[Tuple, Domain]]) -> Dict:
|
||||
resolved = {}
|
||||
error = True
|
||||
num_passes = 0
|
||||
@@ -197,7 +202,8 @@ def _resolve_domain_vars(spec, domain_vars):
|
||||
return resolved
|
||||
|
||||
|
||||
def _grid_search_generator(unresolved_spec, grid_vars):
|
||||
def _grid_search_generator(unresolved_spec: Dict,
|
||||
grid_vars: List) -> Generator[Dict, None, None]:
|
||||
value_indices = [0] * len(grid_vars)
|
||||
|
||||
def increment(i):
|
||||
@@ -225,12 +231,12 @@ def _grid_search_generator(unresolved_spec, grid_vars):
|
||||
break
|
||||
|
||||
|
||||
def _is_resolved(v):
|
||||
def _is_resolved(v) -> bool:
|
||||
resolved, _ = _try_resolve(v)
|
||||
return resolved
|
||||
|
||||
|
||||
def _try_resolve(v):
|
||||
def _try_resolve(v) -> Tuple[bool, Any]:
|
||||
if isinstance(v, Domain):
|
||||
# Domain to sample from
|
||||
return False, v
|
||||
@@ -249,7 +255,8 @@ def _try_resolve(v):
|
||||
return True, v
|
||||
|
||||
|
||||
def _split_resolved_unresolved_values(spec):
|
||||
def _split_resolved_unresolved_values(
|
||||
spec: Dict) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]:
|
||||
resolved_vars = {}
|
||||
unresolved_vars = {}
|
||||
for k, v in spec.items():
|
||||
@@ -278,11 +285,11 @@ def _split_resolved_unresolved_values(spec):
|
||||
return resolved_vars, unresolved_vars
|
||||
|
||||
|
||||
def _unresolved_values(spec):
|
||||
def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]:
|
||||
return _split_resolved_unresolved_values(spec)[1]
|
||||
|
||||
|
||||
def has_unresolved_values(spec):
|
||||
def has_unresolved_values(spec: Dict) -> bool:
|
||||
return True if _unresolved_values(spec) else False
|
||||
|
||||
|
||||
@@ -303,5 +310,5 @@ class _UnresolvedAccessGuard(dict):
|
||||
|
||||
|
||||
class RecursiveDependencyError(Exception):
|
||||
def __init__(self, msg):
|
||||
def __init__(self, msg: str):
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.tune.sample import Categorical, Float, Integer, Quantized, Uniform
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized, \
|
||||
Uniform
|
||||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import unflatten_dict
|
||||
from zoopt import ValueType
|
||||
@@ -106,11 +107,11 @@ class ZOOptSearch(Searcher):
|
||||
optimizer = None
|
||||
|
||||
def __init__(self,
|
||||
algo="asracos",
|
||||
budget=None,
|
||||
dim_dict=None,
|
||||
metric=None,
|
||||
mode=None,
|
||||
algo: str = "asracos",
|
||||
budget: Optional[int] = None,
|
||||
dim_dict: Optional[Dict] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
**kwargs):
|
||||
assert zoopt is not None, "Zoopt not found - please install zoopt."
|
||||
assert budget is not None, "`budget` should not be None!"
|
||||
@@ -154,7 +155,8 @@ class ZOOptSearch(Searcher):
|
||||
from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune
|
||||
self.optimizer = SRacosTune(dimension=dim, parameter=par)
|
||||
|
||||
def set_search_properties(self, metric, mode, config):
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
if self._dim_dict:
|
||||
return False
|
||||
space = self.convert_search_space(config)
|
||||
@@ -173,7 +175,7 @@ class ZOOptSearch(Searcher):
|
||||
self.setup_zoopt()
|
||||
return True
|
||||
|
||||
def suggest(self, trial_id):
|
||||
def suggest(self, trial_id: str) -> Optional[Dict]:
|
||||
if not self._dim_dict or not self.optimizer:
|
||||
raise RuntimeError(
|
||||
"Trying to sample a configuration from {}, but no search "
|
||||
@@ -189,7 +191,10 @@ class ZOOptSearch(Searcher):
|
||||
self._live_trial_mapping[trial_id] = new_trial
|
||||
return unflatten_dict(new_trial)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
def on_trial_complete(self,
|
||||
trial_id: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: bool = False):
|
||||
"""Notification for the completion of trial."""
|
||||
if result:
|
||||
_solution = self.solution_dict[str(trial_id)]
|
||||
@@ -200,18 +205,18 @@ class ZOOptSearch(Searcher):
|
||||
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
def save(self, checkpoint_path: str):
|
||||
trials_object = self.optimizer
|
||||
with open(checkpoint_path, "wb") as output:
|
||||
pickle.dump(trials_object, output)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
def restore(self, checkpoint_path: str):
|
||||
with open(checkpoint_path, "rb") as input:
|
||||
trials_object = pickle.load(input)
|
||||
self.optimizer = trials_object
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict):
|
||||
def convert_search_space(spec: Dict) -> Dict[str, Tuple]:
|
||||
spec = copy.deepcopy(spec)
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -223,7 +228,7 @@ class ZOOptSearch(Searcher):
|
||||
"Grid search parameters cannot be automatically converted "
|
||||
"to a ZOOpt search space.")
|
||||
|
||||
def resolve_value(domain):
|
||||
def resolve_value(domain: Domain) -> Tuple:
|
||||
quantize = None
|
||||
|
||||
sampler = domain.get_sampler()
|
||||
|
||||
Reference in New Issue
Block a user