mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:23:03 +08:00
[tune] clean up population based training prototype (#1478)
* patch up pbt * Sat Jan 27 01:00:03 PST 2018 * Sat Jan 27 01:04:14 PST 2018 * Sat Jan 27 01:04:21 PST 2018 * Sat Jan 27 01:15:15 PST 2018 * Sat Jan 27 01:15:42 PST 2018 * Sat Jan 27 01:16:14 PST 2018 * Sat Jan 27 01:38:42 PST 2018 * Sat Jan 27 01:39:21 PST 2018 * add pbt * Sat Jan 27 01:41:19 PST 2018 * Sat Jan 27 01:44:21 PST 2018 * Sat Jan 27 01:45:46 PST 2018 * Sat Jan 27 16:54:42 PST 2018 * Sat Jan 27 16:57:53 PST 2018 * clean up test * Sat Jan 27 18:01:15 PST 2018 * Sat Jan 27 18:02:54 PST 2018 * Sat Jan 27 18:11:18 PST 2018 * Sat Jan 27 18:11:55 PST 2018 * Sat Jan 27 18:14:09 PST 2018 * review * try out a ppo example * some tweaks to ppo example * add postprocess hook * Sun Jan 28 15:00:40 PST 2018 * clean up custom explore fn * Sun Jan 28 15:10:21 PST 2018 * Sun Jan 28 15:14:53 PST 2018 * Sun Jan 28 15:17:04 PST 2018 * Sun Jan 28 15:33:13 PST 2018 * Sun Jan 28 15:56:40 PST 2018 * Sun Jan 28 15:57:36 PST 2018 * Sun Jan 28 16:00:35 PST 2018 * Sun Jan 28 16:02:58 PST 2018 * Sun Jan 28 16:29:50 PST 2018 * Sun Jan 28 16:30:36 PST 2018 * Sun Jan 28 16:31:44 PST 2018 * improve tune doc * concepts * update humanoid * Fri Feb 2 18:03:33 PST 2018 * fix example * show error file
This commit is contained in:
@@ -76,13 +76,13 @@ class PPOEvaluator(Evaluator):
|
||||
# Value function predictions before the policy update.
|
||||
self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))
|
||||
|
||||
assert config["sgd_batchsize"] % len(devices) == 0, \
|
||||
"Batch size must be evenly divisible by devices"
|
||||
if is_remote:
|
||||
self.batch_size = config["rollout_batchsize"]
|
||||
self.per_device_batch_size = config["rollout_batchsize"]
|
||||
else:
|
||||
self.batch_size = config["sgd_batchsize"]
|
||||
self.batch_size = int(
|
||||
config["sgd_batchsize"] / len(devices)) * len(devices)
|
||||
assert self.batch_size % len(devices) == 0
|
||||
self.per_device_batch_size = int(self.batch_size / len(devices))
|
||||
|
||||
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):
|
||||
|
||||
Executable → Regular
@@ -4,6 +4,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@@ -49,6 +50,10 @@ class MyTrainableClass(Trainable):
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
@@ -60,7 +65,8 @@ if __name__ == "__main__":
|
||||
run_experiments({
|
||||
"hyperband_test": {
|
||||
"run": "my_class",
|
||||
"repeat": 100,
|
||||
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
|
||||
"repeat": 20,
|
||||
"resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
|
||||
Executable
+88
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, TrainingResult, register_trainable, \
|
||||
run_experiments
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
    """Fake agent whose reward growth is determined by two dummy factors.

    The agent reads `factor_1` and `factor_2` from `self.config` and
    accumulates a synthetic reward in `self.current_value`. Its state
    (`timestep` and `current_value`) is checkpointed as a small JSON file.
    """

    def _setup(self):
        # Number of completed `_train` calls and the reward accumulated so far.
        self.timestep = 0
        self.current_value = 0.0

    def _train(self):
        """Run one fake training iteration and report the accumulated reward."""
        time.sleep(0.1)

        # Advance the iteration counter so that the checkpoint written by
        # `_save` records real progress instead of a constant 0.
        self.timestep += 1

        # Reward increase is parabolic as a function of factor_1, with a
        # maximum around factor_1=10.0. Negative samples are clipped to 0.
        factor_1_mean = 5.0 - (self.config["factor_1"] - 10.0)**2
        self.current_value += max(0.0, random.gauss(factor_1_mean, 2.0))

        # Flat (noisy) increase by factor_2.
        self.current_value += random.gauss(self.config["factor_2"], 1.0)

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy (see tune/result.py).
        return TrainingResult(
            episode_reward_mean=self.current_value, timesteps_this_iter=1)

    def _save(self, checkpoint_dir):
        """Write the agent state as JSON and return the checkpoint path."""
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            json.dump(
                {"timestep": self.timestep, "value": self.current_value}, f)
        return path

    def _restore(self, checkpoint_path):
        """Load the agent state from a JSON file written by `_save`."""
        with open(checkpoint_path) as f:
            data = json.load(f)
        self.timestep = data["timestep"]
        self.current_value = data["value"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = arg_parser.parse_known_args()
    ray.init()

    # Hyperparameters PBT is allowed to mutate, and how.
    mutations = {
        # Allow for scaling-based perturbations, with a uniform backing
        # distribution for resampling.
        "factor_1": lambda config: random.uniform(0.0, 20.0),
        # Only allows resampling from this list as a perturbation.
        "factor_2": [1, 2],
    }
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        reward_attr="episode_reward_mean",
        perturbation_interval=10,
        hyperparam_mutations=mutations)

    # Stop after two iterations when smoke testing, otherwise run
    # (effectively) forever.
    max_iterations = 2 if args.smoke_test else 99999

    # Try to find the best factor 1 and factor 2
    experiment_spec = {
        "run": "my_class",
        "stop": {"training_iteration": max_iterations},
        "repeat": 10,
        "resources": {"cpu": 1, "gpu": 0},
        "config": {
            "factor_1": 4.0,
            "factor_2": 1.0,
        },
    }
    run_experiments({"pbt_test": experiment_spec}, scheduler=pbt, verbose=False)
|
||||
Executable
+71
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Example of using PBT with RLlib.
|
||||
|
||||
Note that this requires a cluster with at least 8 GPUs in order for all trials
|
||||
to run concurrently, otherwise PBT will round-robin train the trials which
|
||||
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Postprocess the perturbed config to ensure it's still valid
|
||||
def explore(config):
    """Clamp a perturbed config so that it is still valid.

    The dict is updated in place and also returned.

    Args:
        config (dict): Config containing `timesteps_per_batch`,
            `sgd_batchsize` and `num_sgd_iter`.

    Returns:
        dict: The same `config` object, after clamping.
    """
    # Ensure we collect enough timesteps to do sgd: at least twice the
    # sgd batch size.
    min_timesteps = config["sgd_batchsize"] * 2
    config["timesteps_per_batch"] = max(
        config["timesteps_per_batch"], min_timesteps)
    # Ensure we run at least one sgd iter.
    config["num_sgd_iter"] = max(config["num_sgd_iter"], 1)
    return config
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=120,
|
||||
resample_probability=0.25,
|
||||
# Specifies the resampling distributions of these hyperparams
|
||||
hyperparam_mutations={
|
||||
"lambda": lambda config: random.uniform(0.9, 1.0),
|
||||
"clip_param": lambda config: random.uniform(0.01, 0.5),
|
||||
"sgd_stepsize": lambda config: random.uniform(.00001, .001),
|
||||
"num_sgd_iter": lambda config: random.randint(1, 30),
|
||||
"sgd_batchsize": lambda config: random.randint(128, 16384),
|
||||
"timesteps_per_batch":
|
||||
lambda config: random.randint(2000, 160000),
|
||||
},
|
||||
custom_explore_fn=explore)
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"pbt_humanoid_test": {
|
||||
"run": "PPO",
|
||||
"env": "Humanoid-v1",
|
||||
"repeat": 8,
|
||||
"resources": {"cpu": 4, "gpu": 1},
|
||||
"config": {
|
||||
"kl_coeff": 1.0,
|
||||
"num_workers": 8,
|
||||
"devices": ["/gpu:0"],
|
||||
"model": {"free_log_std": True},
|
||||
# These params are tuned from their starting value
|
||||
"lambda": 0.95,
|
||||
"clip_param": 0.2,
|
||||
# Start off with several random variations
|
||||
"sgd_stepsize": lambda spec: random.uniform(.00001, .001),
|
||||
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
|
||||
"timesteps_per_batch":
|
||||
lambda spec: random.choice([10000, 20000, 40000])
|
||||
},
|
||||
},
|
||||
}, scheduler=pbt)
|
||||
@@ -205,7 +205,7 @@ def train(config={'activation': 'relu'}, reporter=None):
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--fast', action='store_true', help='Finish quickly for testing')
|
||||
'--smoke-test', action='store_true', help='Finish quickly for testing')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
register_trainable('train_mnist', train)
|
||||
@@ -220,7 +220,7 @@ if __name__ == '__main__':
|
||||
},
|
||||
}
|
||||
|
||||
if args.fast:
|
||||
if args.smoke_test:
|
||||
mnist_spec['stop']['training_iteration'] = 2
|
||||
|
||||
ray.init()
|
||||
|
||||
@@ -207,7 +207,7 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
"""Cleans up trial info from bracket if trial errored early."""
|
||||
self.on_trial_remove(trial_runner, trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, *args):
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
|
||||
List of trials not used since all trials are tracked as state
|
||||
|
||||
@@ -63,7 +63,6 @@ class UnifiedLogger(Logger):
|
||||
print("TF not installed - cannot log with {}...".format(cls))
|
||||
continue
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
print("Unified logger created with logdir '{}'".format(self.logdir))
|
||||
|
||||
def on_result(self, result):
|
||||
for logger in self._loggers:
|
||||
|
||||
@@ -31,7 +31,7 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='time_total_s', reward_attr='episode_reward_mean',
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
grace_period=60.0, min_samples_required=3, hard_stop=True):
|
||||
FIFOScheduler.__init__(self)
|
||||
self._stopped_trials = set()
|
||||
|
||||
+234
-154
@@ -2,189 +2,269 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import random
|
||||
import math
|
||||
import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.variant_generator import _format_vars
|
||||
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
# the bottom PBT_QUANTILE fraction.
|
||||
PBT_QUANTILE = 0.25
|
||||
|
||||
|
||||
class PBTTrialState(object):
    """Per-trial bookkeeping kept internally by the PBT scheduler."""

    def __init__(self, trial):
        # Experiment tag the trial carried when this state was created.
        self.orig_tag = trial.experiment_tag
        # Nothing has been recorded yet for a freshly created state.
        self.last_score = None
        self.last_checkpoint = None
        self.last_perturbation_time = 0

    def __repr__(self):
        # Rendered as the tuple (score, checkpoint, perturbation time).
        summary = (
            self.last_score,
            self.last_checkpoint,
            self.last_perturbation_time,
        )
        return repr(summary)
|
||||
|
||||
|
||||
def explore(config, mutations, resample_probability, custom_explore_fn):
    """Return a config perturbed as specified.

    The built-in perturbations are applied to a deep copy of `config`.
    For a key whose mutation spec is a list, the value is resampled from
    the list with probability `resample_probability` and otherwise left
    unchanged. For any other spec (a callable taking the original config),
    the value is resampled by calling it with probability
    `resample_probability`, and otherwise scaled by 1.2 or 0.8; values
    that were ints in `config` are truncated back to int.

    Args:
        config (dict): Original hyperparameter configuration.
        mutations (dict): Specification of mutations to perform as documented
            in the PopulationBasedTraining scheduler.
        resample_probability (float): Probability of allowing resampling of a
            particular variable.
        custom_explore_fn (func): Custom explore fn applied after the
            built-in config perturbations. It is called with the perturbed
            config and must return the config to use. May be None.

    Returns:
        dict: The perturbed configuration.

    Raises:
        TuneError: If `custom_explore_fn` returns None.
    """
    new_config = copy.deepcopy(config)
    for key, distribution in mutations.items():
        if isinstance(distribution, list):
            # Discrete parameter: resample from the list or leave unchanged.
            if random.random() < resample_probability:
                new_config[key] = random.choice(distribution)
        else:
            # Continuous parameter: resample, or scale up/down by 20%.
            if random.random() < resample_probability:
                new_config[key] = distribution(config)
            elif random.random() > 0.5:
                new_config[key] = config[key] * 1.2
            else:
                new_config[key] = config[key] * 0.8
            # Keep integer hyperparameters integral after perturbation.
            if type(config[key]) is int:
                new_config[key] = int(new_config[key])
    if custom_explore_fn:
        new_config = custom_explore_fn(new_config)
        # Explicit check rather than `assert`, which is removed under -O
        # and would let a None config propagate to the trial.
        if new_config is None:
            raise TuneError("Custom explore fn failed to return new config")
    print(
        "[explore] perturbed config from {} -> {}".format(config, new_config))
    return new_config
|
||||
|
||||
|
||||
def make_experiment_tag(orig_tag, config, mutations):
    """Build a trial tag that shows the perturbed params in the console.

    Args:
        orig_tag (str): Tag to use as the prefix of the new tag.
        config (dict): Config holding the current value of every mutated key.
        mutations (dict): Mutation spec; only its keys are used.

    Returns:
        str: `orig_tag` followed by `@perturbed[...]`, where the bracketed
            part is `_format_vars` applied to the mutated config values.
    """
    # Key each mutated hyperparameter by its ("config", name) path.
    perturbed = {("config", name): config[name] for name in mutations}
    return "{}@perturbed[{}]".format(orig_tag, _format_vars(perturbed))
|
||||
|
||||
|
||||
class PopulationBasedTraining(FIFOScheduler):
|
||||
"""Implements the Population Based Training algorithm as described in the
|
||||
PBT paper (https://arxiv.org/abs/1711.09846)(Experimental):
|
||||
"""Implements the Population Based Training (PBT) algorithm.
|
||||
|
||||
https://deepmind.com/blog/population-based-training-neural-networks
|
||||
|
||||
PBT trains a group of models (or agents) in parallel. Periodically, poorly
|
||||
performing models clone the state of the top performers, and a random
|
||||
mutation is applied to their hyperparameters in the hopes of
|
||||
outperforming the current top models.
|
||||
|
||||
Unlike other hyperparameter search algorithms, PBT mutates hyperparameters
|
||||
during training time. This enables very fast hyperparameter discovery and
|
||||
also automatically discovers good annealing schedules.
|
||||
|
||||
This Ray Tune PBT implementation considers all trials added as part of the
|
||||
PBT population. If the number of trials exceeds the cluster capacity,
|
||||
they will be time-multiplexed as to balance training progress across the
|
||||
population.
|
||||
|
||||
Args:
|
||||
time_attr (str): The TrainingResult attr to use for documenting length
|
||||
of time since last ready() call. Attribute only has to increase
|
||||
monotonically.
|
||||
time_attr (str): The TrainingResult attr to use for comparing time.
|
||||
Note that you can pass in something non-temporal such as
|
||||
`training_iteration` as a measure of progress, the only requirement
|
||||
is that the attribute should increase monotonically.
|
||||
reward_attr (str): The TrainingResult objective value attribute. As
|
||||
with 'time_attr'. this may refer to any objective value that
|
||||
is supposed to increase with time.
|
||||
grace_period (float): Period of time, in which algorithm will not
|
||||
compare model to other models.
|
||||
perturbation_interval (float): Used in the truncation ready function to
|
||||
determine if enough time has passed so that a agent can be tested
|
||||
for readiness.
|
||||
hyperparameter_mutations (dict); Possible values that each
|
||||
hyperparameter can mutate to, as certain hyperparameters
|
||||
only work with certain values.
|
||||
with `time_attr`, this may refer to any objective value. Stopping
|
||||
procedures will use this attribute.
|
||||
perturbation_interval (float): Models will be considered for
|
||||
perturbation at this interval of `time_attr`. Note that
|
||||
perturbation incurs checkpoint overhead, so you shouldn't set this
|
||||
to be too frequent.
|
||||
hyperparam_mutations (dict): Hyperparams to mutate. The format is
|
||||
as follows: for each key, either a list or function can be
|
||||
provided. A list specifies values for a discrete parameter.
|
||||
A function specifies the distribution of a continuous parameter.
|
||||
You must specify at least one of `hyperparam_mutations` or
|
||||
`custom_explore_fn`.
|
||||
resample_probability (float): The probability of resampling from the
|
||||
original distribution when applying `hyperparam_mutations`. If not
|
||||
resampled, the value will be perturbed by a factor of 1.2 or 0.8
|
||||
if continuous, or left unchanged if discrete.
|
||||
custom_explore_fn (func): You can also specify a custom exploration
|
||||
function. This function is invoked as `f(config)` after built-in
|
||||
perturbations from `hyperparam_mutations` are applied, and should
|
||||
return `config` updated as needed. You must specify at least one of
|
||||
`hyperparam_mutations` or `custom_explore_fn`.
|
||||
|
||||
Example:
|
||||
>>> pbt = PopulationBasedTraining(
|
||||
>>> time_attr="training_iteration",
|
||||
>>> reward_attr="episode_reward_mean",
|
||||
>>> perturbation_interval=10, # every 10 `time_attr` units
|
||||
>>> # (training_iterations in this case)
|
||||
>>> hyperparam_mutations={
|
||||
>>> # Allow for scaling-based perturbations, with a uniform
|
||||
>>> # backing distribution for resampling.
|
||||
>>> "factor_1": lambda config: random.uniform(0.0, 20.0),
|
||||
>>> # Only allows resampling from this list as a perturbation.
|
||||
>>> "factor_2": [1, 2],
|
||||
>>> })
|
||||
>>> run_experiments({...}, scheduler=pbt)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean',
|
||||
grace_period=10.0, perturbation_interval=6.0,
|
||||
hyperparameter_mutations=None):
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=60.0, hyperparam_mutations={},
|
||||
resample_probability=0.25, custom_explore_fn=None):
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
"You must specify at least one of `hyperparam_mutations` or "
|
||||
"`custom_explore_fn` to use PBT.")
|
||||
FIFOScheduler.__init__(self)
|
||||
self._completed_trials = set()
|
||||
self._results = collections.defaultdict(list)
|
||||
self._last_perturbation_time = {}
|
||||
self._grace_period = grace_period
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
|
||||
self._hyperparameter_mutations = hyperparameter_mutations
|
||||
self._perturbation_interval = perturbation_interval
|
||||
self._checkpoint_paths = {}
|
||||
self._hyperparam_mutations = hyperparam_mutations
|
||||
self._resample_probability = resample_probability
|
||||
self._trial_state = {}
|
||||
self._custom_explore_fn = custom_explore_fn
|
||||
|
||||
# Metrics
|
||||
self._num_checkpoints = 0
|
||||
self._num_perturbations = 0
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
self._trial_state[trial] = PBTTrialState(trial)
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
|
||||
self._results[trial].append(result)
|
||||
time = getattr(result, self._time_attr)
|
||||
# check model is ready to undergo mutation, based on user
|
||||
# function or default function
|
||||
self._checkpoint_paths[trial] = trial.checkpoint()
|
||||
if time > self._grace_period:
|
||||
ready = self._truncation_ready(result, trial, time)
|
||||
else:
|
||||
ready = False
|
||||
if ready:
|
||||
print("ready to undergo mutation")
|
||||
print("----")
|
||||
print("Current Trial is: {0}".format(trial))
|
||||
# get best trial for current time
|
||||
best_trial = self._get_best_trial(result, time)
|
||||
print("Best Trial is: {0}".format(best_trial))
|
||||
print(best_trial.config)
|
||||
state = self._trial_state[trial]
|
||||
|
||||
if time - state.last_perturbation_time < self._perturbation_interval:
|
||||
return TrialScheduler.CONTINUE # avoid checkpoint overhead
|
||||
|
||||
score = getattr(result, self._reward_attr)
|
||||
state.last_score = score
|
||||
state.last_perturbation_time = time
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
|
||||
if trial in upper_quantile:
|
||||
state.last_checkpoint = trial.checkpoint(to_object_store=True)
|
||||
self._num_checkpoints += 1
|
||||
else:
|
||||
state.last_checkpoint = None # not a top trial
|
||||
|
||||
if trial in lower_quantile:
|
||||
trial_to_clone = random.choice(upper_quantile)
|
||||
assert trial is not trial_to_clone
|
||||
self._exploit(trial, trial_to_clone)
|
||||
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
return TrialScheduler.PAUSE # yield time to other trials
|
||||
|
||||
# if current trial is the best trial (as in same hyperparameters),
|
||||
# do nothing
|
||||
if trial.config == best_trial.config:
|
||||
print("current trial is best trial")
|
||||
return TrialScheduler.CONTINUE
|
||||
else:
|
||||
self._exploit(self._hyperparameter_mutations, best_trial,
|
||||
trial, trial_runner, time)
|
||||
return TrialScheduler.CONTINUE
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
self._results[trial].append(result)
|
||||
self._completed_trials.add(trial)
|
||||
def _exploit(self, trial, trial_to_clone):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial."""
|
||||
|
||||
def _exploit(self, hyperparameter_mutations, best_trial,
|
||||
trial, trial_runner, time):
|
||||
trial.stop()
|
||||
mutate_string = "_mutated@" + str(time)
|
||||
hyperparams = copy.deepcopy(best_trial.config)
|
||||
hyperparams = self._explore(hyperparams, hyperparameter_mutations,
|
||||
best_trial)
|
||||
print("new hyperparameter configuration: {0}".format(hyperparams))
|
||||
checkpoint = self._checkpoint_paths[best_trial]
|
||||
trial._checkpoint_path = checkpoint
|
||||
trial.config = hyperparams
|
||||
trial.experiment_tag = trial.experiment_tag + mutate_string
|
||||
trial.start()
|
||||
trial_state = self._trial_state[trial]
|
||||
new_state = self._trial_state[trial_to_clone]
|
||||
if not new_state.last_checkpoint:
|
||||
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
|
||||
return
|
||||
new_config = explore(
|
||||
trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability, self._custom_explore_fn)
|
||||
print(
|
||||
"[exploit] transferring weights from trial "
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
# TODO(ekl) restarting the trial is expensive. We should implement a
|
||||
# lighter-weight reset() method that can alter the trial config.
|
||||
trial.stop(stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = make_experiment_tag(
|
||||
trial_state.orig_tag, new_config, self._hyperparam_mutations)
|
||||
trial.start(new_state.last_checkpoint)
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||
|
||||
def _explore(self, hyperparams, hyperparameter_mutations, best_trial):
|
||||
if hyperparameter_mutations is not None:
|
||||
hyperparams = {
|
||||
param: random.choice(hyperparameter_mutations[param])
|
||||
for param in hyperparams
|
||||
if param != "env" and param in hyperparameter_mutations
|
||||
}
|
||||
for param in best_trial.config:
|
||||
if param not in hyperparameter_mutations and param != "env":
|
||||
hyperparams[param] = math.ceil(
|
||||
(best_trial.config[param]
|
||||
* random.choice([0.8, 1.2])/2.)) * 2
|
||||
def _quantiles(self):
|
||||
"""Returns trials in the lower and upper `quantile` of the population.
|
||||
|
||||
If there is not enough data to compute this, returns empty lists."""
|
||||
|
||||
trials = []
|
||||
for trial, state in self._trial_state.items():
|
||||
if state.last_score is not None and not trial.is_finished():
|
||||
trials.append(trial)
|
||||
trials.sort(key=lambda t: self._trial_state[t].last_score)
|
||||
|
||||
if len(trials) <= 1:
|
||||
return [], []
|
||||
else:
|
||||
hyperparams = {
|
||||
param: math.ceil(
|
||||
(random.choice([0.8, 1.2]) *
|
||||
hyperparams[param])/2.) * 2
|
||||
for param in hyperparams
|
||||
if param != "env"
|
||||
}
|
||||
hyperparams["env"] = best_trial.config["env"]
|
||||
return hyperparams
|
||||
return (
|
||||
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
|
||||
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
|
||||
|
||||
def _truncation_ready(self, result, trial, time):
|
||||
# function checks if appropriate time has passed
|
||||
# and trial is in the bottom 20% of all trials, and if so, is ready
|
||||
if trial not in self._last_perturbation_time:
|
||||
print("added trial to time tracker")
|
||||
self._last_perturbation_time[trial] = (time)
|
||||
else:
|
||||
time_since_last = time - self._last_perturbation_time[trial]
|
||||
if time_since_last >= self._perturbation_interval:
|
||||
self._last_perturbation_time[trial] = time
|
||||
sorted_result_keys = sorted(
|
||||
self._results, key=lambda x:
|
||||
max(self._results.get(x) if self._results.get(x) else [0])
|
||||
)
|
||||
max_index = int(round(len(sorted_result_keys) * 0.2))
|
||||
for i in range(0, max_index):
|
||||
if trial == sorted_result_keys[i]:
|
||||
print("{0} is in the bottomn 20 percent of {1}, \
|
||||
truncation is ready".format(
|
||||
trial,
|
||||
[x.experiment_tag for x in sorted_result_keys]
|
||||
))
|
||||
return True
|
||||
print("{0} is not in the bottomn 20 percent of {1}, \
|
||||
truncation is not ready".format(
|
||||
trial,
|
||||
[x.experiment_tag for x in sorted_result_keys]
|
||||
))
|
||||
else:
|
||||
print("not enough time has passed since last mutation")
|
||||
return False
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Ensures all trials get fair share of time (as defined by time_attr).
|
||||
|
||||
def _get_best_trial(self, result, time):
|
||||
results_at_time = {}
|
||||
for trial in self._results:
|
||||
results_at_time[trial] = [
|
||||
getattr(r, self._reward_attr)
|
||||
for r in self._results[trial]
|
||||
if getattr(r, self._time_attr) <= time
|
||||
]
|
||||
print("Results at {0}: {1}".format(time, results_at_time))
|
||||
return max(results_at_time, key=lambda x:
|
||||
max(results_at_time.get(x)
|
||||
if results_at_time.get(x) else [0]))
|
||||
This enables the PBT scheduler to support a greater number of
|
||||
concurrent trials than can fit in the cluster at any given time.
|
||||
"""
|
||||
|
||||
def _is_empty(self, x):
|
||||
if x:
|
||||
return False
|
||||
return True
|
||||
candidates = []
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
|
||||
trial_runner.has_resources(trial.resources):
|
||||
candidates.append(trial)
|
||||
candidates.sort(
|
||||
key=lambda trial: self._trial_state[trial].last_perturbation_time)
|
||||
return candidates[0] if candidates else None
|
||||
|
||||
def reset_stats(self):
|
||||
self._num_perturbations = 0
|
||||
self._num_checkpoints = 0
|
||||
|
||||
def last_scores(self, trials):
|
||||
scores = []
|
||||
for trial in trials:
|
||||
state = self._trial_state[trial]
|
||||
if state.last_score is not None and not trial.is_finished():
|
||||
scores.append(state.last_score)
|
||||
return scores
|
||||
|
||||
def debug_string(self):
|
||||
|
||||
min_time = 0
|
||||
best_trial = None
|
||||
for trial in self._completed_trials:
|
||||
last_result = self._results[trial][-1]
|
||||
if (getattr(last_result, self._time_attr)
|
||||
< min_time or min_time == 0):
|
||||
min_time = getattr(last_result, self._time_attr)
|
||||
best_trial = trial
|
||||
if best_trial is not None:
|
||||
return ("The Best Trial is currently {0} finishing in {1} iterations, \
|
||||
with the hyperparameters of {2}".format(
|
||||
best_trial, min_time, best_trial.config
|
||||
)
|
||||
)
|
||||
else:
|
||||
return "PBT has started"
|
||||
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
|
||||
self._num_checkpoints, self._num_perturbations)
|
||||
|
||||
@@ -84,10 +84,14 @@ TrainingResult = namedtuple("TrainingResult", [
|
||||
|
||||
# (Auto-filled) The hostname of the machine hosting the training process.
|
||||
"hostname",
|
||||
|
||||
# (Auto-filled) The current hyperparameter configuration.
|
||||
"config",
|
||||
])
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
result = result._replace(config=None) # drop config from pretty print
|
||||
out = {}
|
||||
for k, v in result._asdict().items():
|
||||
if v is not None:
|
||||
|
||||
@@ -0,0 +1,579 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.variant_generator import generate_trials, grid_search, \
|
||||
RecursiveDependencyError
|
||||
|
||||
|
||||
class TrainableFunctionApiTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testRegisterEnv(self):
|
||||
register_env("foo", lambda: None)
|
||||
self.assertRaises(TypeError, lambda: register_env("foo", 2))
|
||||
|
||||
def testRegisterTrainable(self):
|
||||
def train(config, reporter):
|
||||
pass
|
||||
|
||||
class A(object):
|
||||
pass
|
||||
|
||||
class B(Trainable):
|
||||
pass
|
||||
|
||||
register_trainable("foo", train)
|
||||
register_trainable("foo", B)
|
||||
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
|
||||
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
|
||||
|
||||
def testRewriteEnv(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=1)
|
||||
register_trainable("f1", train)
|
||||
|
||||
[trial] = run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"env": "CartPole-v0",
|
||||
}})
|
||||
self.assertEqual(trial.config["env"], "CartPole-v0")
|
||||
|
||||
def testConfigPurity(self):
|
||||
def train(config, reporter):
|
||||
assert config == {"a": "b"}, config
|
||||
reporter(timesteps_total=1)
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {"a": "b"},
|
||||
}})
|
||||
|
||||
def testLogdir(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {"a": "b"},
|
||||
}})
|
||||
|
||||
def testLongFilename(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}})
|
||||
|
||||
def testBadParams(self):
    """An empty experiment spec must be rejected with TuneError."""
    with self.assertRaises(TuneError):
        run_experiments({"foo": {}})
|
||||
|
||||
def testBadParams2(self):
    """An unrecognised top-level spec key must be rejected with TuneError."""
    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "bah": "this param is not allowed",
        }})
|
||||
|
||||
def testBadParams3(self):
    """A grid_search over a plain string as "run" must raise TuneError."""
    # The spec is built inside the context so that an error raised while
    # constructing it is still covered, exactly as in a deferred callable.
    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "run": grid_search("invalid grid search"),
        }})
|
||||
|
||||
def testBadParams4(self):
    """Running the name "asdf" must be rejected with TuneError."""
    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "run": "asdf",
        }})
|
||||
|
||||
def testBadParams5(self):
    """A stop criterion keyed on "asdf" must be rejected with TuneError."""
    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "run": "PPO",
            "stop": {"asdf": 1}
        }})
|
||||
|
||||
def testBadParams6(self):
    """A resources entry keyed on "asdf" must be rejected with TuneError."""
    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "run": "PPO",
            "resources": {"asdf": 1}
        }})
|
||||
|
||||
def testBadReturn(self):
    """Calling the reporter with no fields must surface as TuneError."""
    def train(config, reporter):
        reporter()

    register_trainable("f1", train)

    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "run": "f1",
            "config": {
                "script_min_iter_time_s": 0,
            },
        }})
|
||||
|
||||
def testEarlyReturn(self):
    """Reporting done=True ends the trial although the function sleeps on."""
    def train(config, reporter):
        reporter(timesteps_total=100, done=True)
        # Deliberately never returns in practice; the run must not wait.
        time.sleep(99999)

    register_trainable("f1", train)
    experiment = {
        "run": "f1",
        "config": {
            "script_min_iter_time_s": 0,
        },
    }
    (trial,) = run_experiments({"foo": experiment})
    self.assertEqual(trial.status, Trial.TERMINATED)
    self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testAbruptReturn(self):
    """Returning without done=True still terminates with the last result."""
    def train(config, reporter):
        reporter(timesteps_total=100)

    register_trainable("f1", train)
    experiment = {
        "run": "f1",
        "config": {
            "script_min_iter_time_s": 0,
        },
    }
    (trial,) = run_experiments({"foo": experiment})
    self.assertEqual(trial.status, Trial.TERMINATED)
    self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testErrorReturn(self):
    """An exception raised by the trainable must surface as TuneError."""
    def train(config, reporter):
        raise Exception("uh oh")

    register_trainable("f1", train)

    with self.assertRaises(TuneError):
        run_experiments({"foo": {
            "run": "f1",
            "config": {
                "script_min_iter_time_s": 0,
            },
        }})
|
||||
|
||||
def testSuccess(self):
    """A function reporting 100 results terminates with the final one."""
    def train(config, reporter):
        for step in range(100):
            reporter(timesteps_total=step)

    register_trainable("f1", train)
    experiment = {
        "run": "f1",
        "config": {
            "script_min_iter_time_s": 0,
        },
    }
    (trial,) = run_experiments({"foo": experiment})
    self.assertEqual(trial.status, Trial.TERMINATED)
    # range(100) ends at 99, so that is the last value reported.
    self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
    """Tests for expanding an experiment spec into trials."""

    def testParseToTrials(self):
        """`repeat: 2` yields two trials tagged "0" and "1"."""
        trials = generate_trials({
            "run": "PPO",
            "repeat": 2,
            "config": {
                "env": "Pong-v0",
                "foo": "bar"
            },
        }, "tune-pong")
        trials = list(trials)
        self.assertEqual(len(trials), 2)
        self.assertEqual(str(trials[0]), "PPO_Pong-v0_0")
        self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
        self.assertEqual(trials[0].trainable_name, "PPO")
        self.assertEqual(trials[0].experiment_tag, "0")
        self.assertEqual(
            trials[0].local_dir,
            os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
        self.assertEqual(trials[1].experiment_tag, "1")

    def testEval(self):
        """An {"eval": ...} entry is replaced by the evaluated value."""
        trials = generate_trials({
            "run": "PPO",
            "config": {
                "foo": {
                    "eval": "2 + 2"
                },
            },
        })
        trials = list(trials)
        self.assertEqual(len(trials), 1)
        self.assertEqual(trials[0].config, {"foo": 4})
        self.assertEqual(trials[0].experiment_tag, "0_foo=4")

    def testGridSearch(self):
        """Two grid_search entries (2 x 3 values) expand to six trials."""
        trials = generate_trials({
            "run": "PPO",
            "config": {
                "bar": {
                    "grid_search": [True, False]
                },
                "foo": {
                    "grid_search": [1, 2, 3]
                },
            },
        })
        trials = list(trials)
        self.assertEqual(len(trials), 6)
        self.assertEqual(trials[0].config, {"bar": True, "foo": 1})
        self.assertEqual(trials[0].experiment_tag, "0_bar=True,foo=1")
        self.assertEqual(trials[1].config, {"bar": False, "foo": 1})
        self.assertEqual(trials[1].experiment_tag, "1_bar=False,foo=1")
        self.assertEqual(trials[2].config, {"bar": True, "foo": 2})
        self.assertEqual(trials[3].config, {"bar": False, "foo": 2})
        self.assertEqual(trials[4].config, {"bar": True, "foo": 3})
        self.assertEqual(trials[5].config, {"bar": False, "foo": 3})

    def testGridSearchAndEval(self):
        """grid_search() helpers and a lambda value can be combined."""
        trials = generate_trials({
            "run": "PPO",
            "config": {
                "qux": lambda spec: 2 + 2,
                "bar": grid_search([True, False]),
                "foo": grid_search([1, 2, 3]),
            },
        })
        trials = list(trials)
        self.assertEqual(len(trials), 6)
        self.assertEqual(trials[0].config, {"bar": True, "foo": 1, "qux": 4})
        self.assertEqual(trials[0].experiment_tag, "0_bar=True,foo=1,qux=4")

    def testConditionResolution(self):
        """Lambdas may read other config values, including resolved ones."""
        trials = generate_trials({
            "run": "PPO",
            "config": {
                "x": 1,
                "y": lambda spec: spec.config.x + 1,
                "z": lambda spec: spec.config.y + 1,
            },
        })
        trials = list(trials)
        self.assertEqual(len(trials), 1)
        self.assertEqual(trials[0].config, {"x": 1, "y": 2, "z": 3})

    def testDependentLambda(self):
        """A lambda sees the grid-searched value chosen for each trial."""
        trials = generate_trials({
            "run": "PPO",
            "config": {
                "x": grid_search([1, 2]),
                "y": lambda spec: spec.config.x * 100,
            },
        })
        trials = list(trials)
        self.assertEqual(len(trials), 2)
        self.assertEqual(trials[0].config, {"x": 1, "y": 100})
        self.assertEqual(trials[1].config, {"x": 2, "y": 200})

    def testDependentGridSearch(self):
        """Grid-searched lambdas can depend on another lambda value."""
        trials = generate_trials({
            "run": "PPO",
            "config": {
                "x": grid_search([
                    lambda spec: spec.config.y * 100,
                    lambda spec: spec.config.y * 200
                ]),
                "y": lambda spec: 1,
            },
        })
        trials = list(trials)
        self.assertEqual(len(trials), 2)
        self.assertEqual(trials[0].config, {"x": 100, "y": 1})
        self.assertEqual(trials[1].config, {"x": 200, "y": 1})

    def testRecursiveDep(self):
        """A value depending on itself raises RecursiveDependencyError."""
        # Use unittest assertions rather than bare `assert` statements so
        # the check is not compiled away when Python runs with -O.
        with self.assertRaises(RecursiveDependencyError) as ctx:
            list(generate_trials({
                "run": "PPO",
                "config": {
                    "foo": lambda spec: spec.config.foo,
                },
            }))
        self.assertIn("`foo` recursively depends on", str(ctx.exception))
|
||||
|
||||
|
||||
class TrialRunnerTest(unittest.TestCase):
    """Tests that step a TrialRunner by hand and check trial statuses."""

    def tearDown(self):
        """Shut Ray down after each test and restore the registry."""
        ray.worker.cleanup()
        _register_all()  # re-register the evicted objects

    def testTrialStatus(self):
        """A trial moves PENDING -> RUNNING -> TERMINATED -> ERROR."""
        ray.init()
        trial = Trial("__fake")
        self.assertEqual(trial.status, Trial.PENDING)
        trial.start()
        self.assertEqual(trial.status, Trial.RUNNING)
        trial.stop()
        self.assertEqual(trial.status, Trial.TERMINATED)
        trial.stop(error=True)
        self.assertEqual(trial.status, Trial.ERROR)

    def testExperimentTagTruncation(self):
        """Long config keys/values must keep trial.logdir within 200 chars."""
        ray.init()

        def train(config, reporter):
            reporter(timesteps_total=1)

        register_trainable("f1", train)

        experiments = {"foo": {
            "run": "f1",
            "config": {
                "a" * 50: lambda spec: 5.0 / 7,
                "b" * 50: lambda spec: "long" * 40},
        }}

        for name, spec in experiments.items():
            for trial in generate_trials(spec, name):
                trial.start()
                self.assertLessEqual(len(trial.logdir), 200)
                trial.stop()

    def testTrialErrorOnStart(self):
        """Starting a trial whose registered trainable is None must raise."""
        ray.init()
        _default_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf")
        # Require the failure: a bare try/except would let this test pass
        # silently if start() stopped raising.
        with self.assertRaises(Exception) as ctx:
            trial.start()
        self.assertIn("a class", str(ctx.exception))

    def testResourceScheduler(self):
        """With a single GPU, two 1-GPU trials run one after the other."""
        ray.init(num_cpus=4, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.TERMINATED)

    def testMultiStepRun(self):
        """With two GPUs, both 1-GPU trials end up running concurrently."""
        ray.init(num_cpus=4, num_gpus=2)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 5},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.RUNNING)

    def testErrorHandling(self):
        """A trial that fails to start is marked ERROR; the next one runs."""
        ray.init(num_cpus=4, num_gpus=2)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        _default_registry.register(TRAINABLE_CLASS, "asdf", None)
        trials = [
            Trial("asdf", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[1].status, Trial.RUNNING)

    def testCheckpointing(self):
        """State saved by one trial is visible to a trial restored from it."""
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)

        path = trials[0].checkpoint()
        # Register removal immediately so the checkpoint is deleted even
        # when one of the assertions below fails.
        self.addCleanup(os.remove, path)
        kwargs["restore_path"] = path

        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)

    def testResultDone(self):
        """Tests that last_result is marked `done` after trial is complete."""
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 2},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        runner.step()
        self.assertNotEqual(trials[0].last_result.done, True)
        runner.step()
        self.assertEqual(trials[0].last_result.done, True)

    def testPauseThenResume(self):
        """Info set on a trial survives a pause()/resume() round trip."""
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 2},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].runner.get_info.remote()), None)

        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)

        trials[0].pause()
        self.assertEqual(trials[0].status, Trial.PAUSED)

        trials[0].resume()
        self.assertEqual(trials[0].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].runner.get_info.remote()), 1)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)

    def testStopTrial(self):
        """stop_trial() terminates both a running and a pending trial."""
        ray.init(num_cpus=4, num_gpus=2)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 5},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)
        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.PENDING)

        # Stop trial while running
        runner.stop_trial(trials[0])
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)
        self.assertEqual(trials[-1].status, Trial.PENDING)

        # Stop trial while pending
        runner.stop_trial(trials[-1])
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)
        self.assertEqual(trials[-1].status, Trial.TERMINATED)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)
        self.assertEqual(trials[2].status, Trial.RUNNING)
        self.assertEqual(trials[-1].status, Trial.TERMINATED)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Executed as a script: run every test case with per-test output.
    unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,722 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import TrialScheduler
|
||||
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
|
||||
|
||||
def result(t, rew):
    """Build a TrainingResult for elapsed time `t` and mean reward `rew`.

    `training_iteration` is set to `int(t)`.
    """
    fields = {
        "time_total_s": t,
        "episode_reward_mean": rew,
        "training_iteration": int(t),
    }
    return TrainingResult(**fields)
|
||||
|
||||
|
||||
class EarlyStoppingSuite(unittest.TestCase):
    """Tests for the decisions returned by MedianStoppingRule."""

    def basicSetup(self, rule):
        """Feed two fresh trials into `rule` and return them.

        The first trial reports rewards 0, 100, ..., 900 at t = 0..9; the
        second reports a constant 450 at t = 0..4. Every one of those
        results must be answered with CONTINUE.
        """
        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
        for step in range(10):
            decision = rule.on_trial_result(
                None, t1, result(step, step * 100))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
        for step in range(5):
            decision = rule.on_trial_result(None, t2, result(step, 450))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
        return t1, t2

    def testMedianStoppingConstantPerf(self):
        """The constant-reward trial is stopped only at its t=10 result."""
        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
        t1, t2 = self.basicSetup(rule)
        rule.on_trial_complete(None, t1, result(10, 1000))

        decision = rule.on_trial_result(None, t2, result(5, 450))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t2, result(6, 0))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t2, result(10, 450))
        self.assertEqual(decision, TrialScheduler.STOP)

    def testMedianStoppingOnCompleteOnly(self):
        """A weak trial is stopped only after another trial has completed."""
        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
        t1, t2 = self.basicSetup(rule)

        before = rule.on_trial_result(None, t2, result(100, 0))
        self.assertEqual(before, TrialScheduler.CONTINUE)

        rule.on_trial_complete(None, t1, result(10, 1000))

        after = rule.on_trial_result(None, t2, result(101, 0))
        self.assertEqual(after, TrialScheduler.STOP)

    def testMedianStoppingGracePeriod(self):
        """With grace_period=2.5 a weak trial is only stopped at t=3."""
        rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
        t1, t2 = self.basicSetup(rule)
        for finished in (t1, t2):
            rule.on_trial_complete(None, finished, result(10, 1000))
        t3 = Trial("PPO")

        decision = rule.on_trial_result(None, t3, result(1, 10))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t3, result(2, 10))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t3, result(3, 10))
        self.assertEqual(decision, TrialScheduler.STOP)

    def testMedianStoppingMinSamples(self):
        """With min_samples_required=2, one completed trial is not enough."""
        rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
        t1, t2 = self.basicSetup(rule)
        rule.on_trial_complete(None, t1, result(10, 1000))
        t3 = Trial("PPO")

        one_done = rule.on_trial_result(None, t3, result(3, 10))
        self.assertEqual(one_done, TrialScheduler.CONTINUE)

        rule.on_trial_complete(None, t2, result(10, 1000))

        two_done = rule.on_trial_result(None, t3, result(3, 10))
        self.assertEqual(two_done, TrialScheduler.STOP)

    def testMedianStoppingUsesMedian(self):
        """A reward of 260 continues at t=1 and is stopped at t=2."""
        rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
        t1, t2 = self.basicSetup(rule)
        for finished in (t1, t2):
            rule.on_trial_complete(None, finished, result(10, 1000))
        t3 = Trial("PPO")

        decision = rule.on_trial_result(None, t3, result(1, 260))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t3, result(2, 260))
        self.assertEqual(decision, TrialScheduler.STOP)

    def testMedianStoppingSoftStop(self):
        """With hard_stop=False the rule answers PAUSE where it would STOP."""
        rule = MedianStoppingRule(
            grace_period=0, min_samples_required=1, hard_stop=False)
        t1, t2 = self.basicSetup(rule)
        for finished in (t1, t2):
            rule.on_trial_complete(None, finished, result(10, 1000))
        t3 = Trial("PPO")

        decision = rule.on_trial_result(None, t3, result(1, 260))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t3, result(2, 260))
        self.assertEqual(decision, TrialScheduler.PAUSE)

    def testAlternateMetrics(self):
        """The rule can be driven by custom time and reward attributes."""
        def result2(t, rew):
            return TrainingResult(training_iteration=t, neg_mean_loss=rew)

        rule = MedianStoppingRule(
            grace_period=0, min_samples_required=1,
            time_attr='training_iteration', reward_attr='neg_mean_loss')
        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
        for step in range(10):
            decision = rule.on_trial_result(
                None, t1, result2(step, step * 100))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
        for step in range(5):
            decision = rule.on_trial_result(None, t2, result2(step, 450))
            self.assertEqual(decision, TrialScheduler.CONTINUE)

        rule.on_trial_complete(None, t1, result2(10, 1000))

        decision = rule.on_trial_result(None, t2, result2(5, 450))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        decision = rule.on_trial_result(None, t2, result2(6, 0))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
|
||||
|
||||
|
||||
class _MockTrialRunner(object):
    """Minimal stand-in for a trial runner, used to drive a scheduler.

    It only tracks the added trials and flips `trial.status`. Derives from
    `object` so it is a new-style class on Python 2 as well, matching the
    other classes in these tests.
    """

    def __init__(self, scheduler):
        """Remember `scheduler` and start with an empty trial list."""
        self._scheduler_alg = scheduler
        self.trials = []

    def process_action(self, trial, action):
        """Apply a scheduler decision (`action`) to `trial`.

        CONTINUE is a no-op, PAUSE marks the trial paused, STOP calls
        `trial.stop()`. Any other value is ignored.
        """
        if action == TrialScheduler.CONTINUE:
            pass
        elif action == TrialScheduler.PAUSE:
            self._pause_trial(trial)
        elif action == TrialScheduler.STOP:
            trial.stop()

    def stop_trial(self, trial):
        """Notify the scheduler that `trial` is being stopped.

        Trials already in ERROR or TERMINATED are ignored. PENDING or
        PAUSED trials are reported as removed. Any other status is
        reported as completed with a fixed `result(100, 10)`.
        """
        if trial.status in [Trial.ERROR, Trial.TERMINATED]:
            return
        elif trial.status in [Trial.PENDING, Trial.PAUSED]:
            self._scheduler_alg.on_trial_remove(self, trial)
        else:
            self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))

    def add_trial(self, trial):
        """Record `trial` and announce it to the scheduler."""
        self.trials.append(trial)
        self._scheduler_alg.on_trial_add(self, trial)

    def get_trials(self):
        """Return the list of trials added through `add_trial`."""
        return self.trials

    def has_resources(self, resources):
        """Always report that the requested resources are available."""
        return True

    def _pause_trial(self, trial):
        """Mark `trial` as paused without touching any real process."""
        trial.status = Trial.PAUSED

    def _launch_trial(self, trial):
        """Mark `trial` as running without starting any real process."""
        trial.status = Trial.RUNNING
|
||||
|
||||
|
||||
class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
def schedulerSetup(self, num_trials):
    """Create a default HyperBandScheduler holding `num_trials` fake trials.

    The trials are added directly to the scheduler (with no runner), and
    a _MockTrialRunner wrapping the scheduler is created afterwards.

    Bracketing is placed as follows:
    (5, 81);
    (8, 27) -> (3, 81);
    (15, 9) -> (5, 27) -> (2, 81);
    (34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
    (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);

    Returns:
        A `(scheduler, runner)` tuple.
    """
    scheduler = HyperBandScheduler()
    for _ in range(num_trials):
        scheduler.on_trial_add(None, Trial("__fake"))
    return scheduler, _MockTrialRunner(scheduler)
|
||||
|
||||
def default_statistics(self):
    """Collect bracket sizes of a default HyperBandScheduler.

    Returns:
        A dict mapping each bracket index (as a string) to
        `{"n": _get_n0(s), "r": _get_r0(s)}`, plus the keys
        "max_trials" (sum of all "n"), "brack_count" (`_s_max_1`) and
        "s_max" (`_s_max_1 - 1`).
    """
    scheduler = HyperBandScheduler()
    bracket_count = scheduler._s_max_1

    stats = {}
    total_trials = 0
    for s in range(bracket_count):
        initial_n = scheduler._get_n0(s)
        stats[str(s)] = {"n": initial_n, "r": scheduler._get_r0(s)}
        total_trials += initial_n

    stats["max_trials"] = total_trials
    stats["brack_count"] = bracket_count
    stats["s_max"] = bracket_count - 1
    return stats
|
||||
|
||||
def downscale(self, n, sched):
    """Return `n` divided by the scheduler's `_eta`, rounded up to an int."""
    ratio = n / sched._eta
    return int(np.ceil(ratio))
|
||||
|
||||
def basicSetup(self):
    """Fill exactly one band and verify that every bracket is filled.

    Returns:
        The scheduler holding the filled band.
    """
    max_trials = self.default_statistics()["max_trials"]
    sched, _ = self.schedulerSetup(max_trials)

    bands = sched._hyperbands
    self.assertEqual(len(bands), 1)
    self.assertEqual(sched._cur_band_filled(), True)

    for bracket in bands[0]:
        self.assertEqual(bracket.filled(), True)
    return sched
|
||||
|
||||
def advancedSetup(self):
    """Start from a full band, add four more trials, and check the new band.

    Returns:
        The scheduler, whose newest band is now partially filled.
    """
    sched = self.basicSetup()
    for _ in range(4):
        sched.on_trial_add(None, Trial("__fake"))

    self.assertEqual(sched._cur_band_filled(), False)

    newest_band = sched._hyperbands[-1]
    self.assertEqual(len(newest_band), 2)
    newest_bracket = newest_band[-1]
    self.assertEqual(newest_bracket.filled(), False)
    self.assertEqual(len(newest_bracket.current_trials()), 7)

    return sched
|
||||
|
||||
def testConfigSameEta(self):
    """Bracket sizes are the same for max_t=810 as for the default.

    Only the per-bracket resource `_r` differs (81/1 vs. 810/10).
    """
    # (constructor kwargs, expected _r of first bracket, of last bracket)
    cases = (
        ({}, 81, 1),
        ({"max_t": 810}, 810, 10),
    )
    for sched_kwargs, first_r, last_r in cases:
        sched = HyperBandScheduler(**sched_kwargs)
        # Keep adding fake trials until the current band reports filled.
        while not sched._cur_band_filled():
            sched.on_trial_add(None, Trial("__fake"))

        self.assertEqual(len(sched._hyperbands[0]), 5)
        self.assertEqual(sched._hyperbands[0][0]._n, 5)
        self.assertEqual(sched._hyperbands[0][0]._r, first_r)
        self.assertEqual(sched._hyperbands[0][-1]._n, 81)
        self.assertEqual(sched._hyperbands[0][-1]._r, last_r)
|
||||
|
||||
def testConfigSameEtaSmall(self):
    """With max_t=1 the first band has 5 slots, all but the first None."""
    sched = HyperBandScheduler(max_t=1)
    # Add fake trials until a second band has been opened.
    while len(sched._hyperbands) < 2:
        sched.on_trial_add(None, Trial("__fake"))

    first_band = sched._hyperbands[0]
    self.assertEqual(len(first_band), 5)
    self.assertTrue(all(v is None for v in first_band[1:]))
|
||||
|
||||
def testSuccessiveHalving(self):
    """Fill a full band, then step the current bracket through halving.

    In every round all trials but the last reported must be paused, the
    last one must continue, and the bracket must shrink to
    `downscale(previous_size)`. A single trial must remain at the end.
    """
    stats = self.default_statistics()
    sched, mock_runner = self.schedulerSetup(stats["max_trials"])
    big_bracket = sched._state["bracket"]
    cur_units = stats[str(stats["s_max"])]["r"]

    # One halving round per bracket below the top (brack_count - 1).
    rounds = stats["brack_count"] - 1
    for _ in range(rounds):
        trials = big_bracket.current_trials()
        current_length = len(trials)
        for trl in trials:
            mock_runner._launch_trial(trl)

        # Report rewards in increasing order; only the final report is
        # expected to keep its trial running.
        last_index = current_length - 1
        for i, trl in enumerate(trials):
            action = sched.on_trial_result(
                mock_runner, trl, result(cur_units, i))
            if i < last_index:
                self.assertEqual(action, TrialScheduler.PAUSE)
            mock_runner.process_action(trl, action)

        self.assertEqual(action, TrialScheduler.CONTINUE)
        new_length = len(big_bracket.current_trials())
        self.assertEqual(new_length, self.downscale(current_length, sched))
        cur_units += int(cur_units * sched._eta)

    self.assertEqual(len(big_bracket.current_trials()), 1)
|
||||
|
||||
def testHalvingStop(self):
    """Reporting in reverse reward order ends with a STOP decision."""
    stats = self.default_statistics()
    num_trials = stats["0"]["n"] + stats["1"]["n"]
    sched, mock_runner = self.schedulerSetup(num_trials)
    big_bracket = sched._state["bracket"]
    for trl in big_bracket.current_trials():
        mock_runner._launch_trial(trl)

    # Walk the (index, trial) pairs back to front, so the trial reported
    # last carries the lowest reward (its index, 0).
    cur_units = stats["1"]["r"]
    indexed = list(enumerate(big_bracket.current_trials()))
    for i, trl in reversed(indexed):
        action = sched.on_trial_result(
            mock_runner, trl, result(cur_units, i))
        mock_runner.process_action(trl, action)

    self.assertEqual(action, TrialScheduler.STOP)
|
||||
|
||||
def testContinueLastOne(self):
    """The last trial reported keeps receiving CONTINUE afterwards."""
    stats = self.default_statistics()
    num_trials = stats["0"]["n"]
    sched, mock_runner = self.schedulerSetup(num_trials)
    big_bracket = sched._state["bracket"]
    for trl in big_bracket.current_trials():
        mock_runner._launch_trial(trl)

    # Report results in forward order with increasing reward.
    cur_units = stats["0"]["r"]
    for i, trl in enumerate(big_bracket.current_trials()):
        action = sched.on_trial_result(
            mock_runner, trl, result(cur_units, i))
        mock_runner.process_action(trl, action)

    self.assertEqual(action, TrialScheduler.CONTINUE)

    # `trl` is still the last trial of the loop above; keep feeding it
    # results at later and later times.
    for extra in range(100):
        action = sched.on_trial_result(
            mock_runner, trl, result(cur_units + extra, 10))
        self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
|
||||
def testTrialErrored(self):
    """If a trial errored, make sure successive halving still happens."""
    stats = self.default_statistics()
    trial_count = stats["0"]["n"] + 3
    sched, mock_runner = self.schedulerSetup(trial_count)
    t1, t2, t3 = sched._state["bracket"].current_trials()
    for trial in (t1, t2, t3):
        mock_runner._launch_trial(trial)

    sched.on_trial_error(mock_runner, t3)

    # With t3 errored, t1's result pauses it and t2's lets it continue.
    first = sched.on_trial_result(
        mock_runner, t1, result(stats["1"]["r"], 10))
    self.assertEqual(TrialScheduler.PAUSE, first)
    second = sched.on_trial_result(
        mock_runner, t2, result(stats["1"]["r"], 10))
    self.assertEqual(TrialScheduler.CONTINUE, second)
|
||||
|
||||
def testTrialErrored2(self):
    """Check successive halving happened even when last trial failed."""
    stats = self.default_statistics()
    trial_count = stats["0"]["n"] + stats["1"]["n"]
    sched, mock_runner = self.schedulerSetup(trial_count)
    trials = sched._state["bracket"].current_trials()

    # Every trial but the last one reports a normal result.
    for trial in trials[:-1]:
        mock_runner._launch_trial(trial)
        sched.on_trial_result(
            mock_runner, trial, result(stats["1"]["r"], 10))

    # The last trial errors out instead of reporting.
    failing = trials[-1]
    mock_runner._launch_trial(failing)
    sched.on_trial_error(mock_runner, failing)

    remaining = len(sched._state["bracket"].current_trials())
    self.assertEqual(remaining,
                     self.downscale(stats["1"]["n"], sched))
|
||||
|
||||
def testTrialEndedEarly(self):
    """Successive halving still happens when one trial completes early."""
    stats = self.default_statistics()
    trial_count = stats["0"]["n"] + 3
    sched, mock_runner = self.schedulerSetup(trial_count)

    t1, t2, t3 = sched._state["bracket"].current_trials()
    for trial in (t1, t2, t3):
        mock_runner._launch_trial(trial)

    sched.on_trial_complete(mock_runner, t3, result(1, 12))

    # With t3 completed, t1's result pauses it and t2's lets it continue.
    first = sched.on_trial_result(
        mock_runner, t1, result(stats["1"]["r"], 10))
    self.assertEqual(TrialScheduler.PAUSE, first)
    second = sched.on_trial_result(
        mock_runner, t2, result(stats["1"]["r"], 10))
    self.assertEqual(TrialScheduler.CONTINUE, second)
|
||||
|
||||
def testTrialEndedEarly2(self):
    """Successive halving still happens when the last trial completes."""
    stats = self.default_statistics()
    trial_count = stats["0"]["n"] + stats["1"]["n"]
    sched, mock_runner = self.schedulerSetup(trial_count)
    trials = sched._state["bracket"].current_trials()

    # Every trial but the last one reports a normal result.
    for trial in trials[:-1]:
        mock_runner._launch_trial(trial)
        sched.on_trial_result(
            mock_runner, trial, result(stats["1"]["r"], 10))

    # The last trial is reported as complete instead.
    finishing = trials[-1]
    mock_runner._launch_trial(finishing)
    sched.on_trial_complete(mock_runner, finishing, result(100, 12))

    remaining = len(sched._state["bracket"].current_trials())
    self.assertEqual(remaining,
                     self.downscale(stats["1"]["n"], sched))
|
||||
|
||||
def testAddAfterHalving(self):
    """A trial added after a bracket's first result gets a full budget."""
    stats = self.default_statistics()
    trial_count = stats["0"]["n"] + 1
    sched, mock_runner = self.schedulerSetup(trial_count)
    bracket_trials = sched._state["bracket"].current_trials()
    init_units = stats["1"]["r"]

    for trial in bracket_trials:
        mock_runner._launch_trial(trial)

    for i, trial in enumerate(bracket_trials):
        action = sched.on_trial_result(
            mock_runner, trial, result(init_units, i))
    self.assertEqual(action, TrialScheduler.CONTINUE)

    late_trial = Trial("__fake")
    sched.on_trial_add(None, late_trial)
    mock_runner._launch_trial(late_trial)
    self.assertEqual(len(sched._state["bracket"].current_trials()), 2)

    # Make sure that newly added trial gets fair computation (not just 1)
    at_first_milestone = sched.on_trial_result(
        mock_runner, late_trial, result(init_units, 12))
    self.assertEqual(TrialScheduler.CONTINUE, at_first_milestone)

    new_units = init_units + int(init_units * sched._eta)
    at_next_milestone = sched.on_trial_result(
        mock_runner, late_trial, result(new_units, 12))
    self.assertEqual(TrialScheduler.PAUSE, at_next_milestone)
|
||||
|
||||
def testAlternateMetrics(self):
    """Check halving with non-default time and reward attributes.

    The scheduler is keyed on ``time_total_s`` and ``neg_mean_loss``, and
    every result fed to it is built from just those two fields.
    """

    def alt_result(elapsed, reward):
        return TrainingResult(time_total_s=elapsed, neg_mean_loss=reward)

    sched = HyperBandScheduler(
        time_attr='time_total_s', reward_attr='neg_mean_loss')
    stats = self.default_statistics()

    for _ in range(stats["max_trials"]):
        sched.on_trial_add(None, Trial("__fake"))
    runner = _MockTrialRunner(sched)

    big_bracket = sched._hyperbands[0][-1]
    for trl in big_bracket.current_trials():
        runner._launch_trial(trl)
    current_length = len(big_bracket.current_trials())

    # Each trial reports at time 1 with a reward equal to its position in
    # the bracket's trial list, and the returned action is applied.
    for position, trl in enumerate(big_bracket.current_trials()):
        action = sched.on_trial_result(
            runner, trl, alt_result(1, position))
        runner.process_action(trl, action)

    new_length = len(big_bracket.current_trials())
    # ``action`` is the one returned for the last trial that reported.
    self.assertEqual(action, TrialScheduler.CONTINUE)
    self.assertEqual(new_length, self.downscale(current_length, sched))
|
||||
|
||||
def testJumpingTime(self):
    """Check a trial whose reported time jumps ahead of its peers.

    All but the last trial of the largest bracket report at time 1; the
    last one then reports at time 4 and is expected to be paused, with
    the bracket holding fewer than 27 trials afterwards.
    """
    sched, mock_runner = self.schedulerSetup(81)
    big_bracket = sched._hyperbands[0][-1]

    for trl in big_bracket.current_trials():
        mock_runner._launch_trial(trl)

    # Split the bracket into the trials reporting normally and the single
    # trial that will report a later time.
    main_trials = big_bracket.current_trials()[:-1]
    jump = big_bracket.current_trials()[-1]
    for i, trl in enumerate(main_trials):
        step_action = sched.on_trial_result(
            mock_runner, trl, result(1, i))
        mock_runner.process_action(trl, step_action)

    # ``i`` still holds the index of the last trial in ``main_trials``.
    jump_action = sched.on_trial_result(mock_runner, jump, result(4, i))
    self.assertEqual(jump_action, TrialScheduler.PAUSE)

    current_length = len(big_bracket.current_trials())
    self.assertLess(current_length, 27)
|
||||
|
||||
def testRemove(self):
    """Test with 4: start 1, remove 1 pending, add 2, remove 1 pending"""
    sched, runner = self.schedulerSetup(4)
    # Sort by trial id so the indices below refer to a stable ordering.
    trials = sorted(sched._trial_info, key=lambda t: t.trial_id)
    runner._launch_trial(trials[0])
    sched.on_trial_result(runner, trials[0], result(1, 5))
    self.assertEqual(trials[0].status, Trial.RUNNING)
    self.assertEqual(trials[1].status, Trial.PENDING)

    # Removing a pending trial must drop it from its bracket's live set.
    bracket, _ = sched._trial_info[trials[1]]
    self.assertIn(trials[1], bracket._live_trials)
    sched.on_trial_remove(runner, trials[1])
    self.assertNotIn(trials[1], bracket._live_trials)

    for _ in range(2):
        trial = Trial("__fake")
        sched.on_trial_add(None, trial)

    # ``trial`` is the second of the two trials added just above.
    bracket, _ = sched._trial_info[trial]
    self.assertIn(trial, bracket._live_trials)
    sched.on_trial_remove(runner, trial)  # where trial is not running
    self.assertNotIn(trial, bracket._live_trials)
|
||||
|
||||
|
||||
class _MockTrial(Trial):
    """Trial subclass that records start/stop calls on plain attributes.

    ``Trial.__init__`` is not called here; only the attributes assigned
    below are set on the instance.
    """

    def __init__(self, i, config):
        self.config = config
        self.trainable_name = "trial_{}".format(i)
        self.experiment_tag = "tag"
        self.resources = Resources(1, 0)
        # Bookkeeping inspected by the tests.
        self.logger_running = False
        self.restored_checkpoint = None

    def checkpoint(self, to_object_store=False):
        """Return the trainable name in place of a real checkpoint."""
        return self.trainable_name

    def start(self, checkpoint=None):
        """Mark the logger as running and remember the given checkpoint."""
        self.restored_checkpoint = checkpoint
        self.logger_running = True

    def stop(self, stop_logger=False):
        """Mark the logger as stopped, but only when ``stop_logger`` is set."""
        if not stop_logger:
            return
        self.logger_running = False
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
    """Tests for the PopulationBasedTraining scheduler using mock trials."""

    def basicSetup(self, resample_prob=0.0, explore=None):
        """Build a PBT scheduler and a mock runner holding five trials.

        Trial ``i`` is marked running and reports ``result(10, 50 * i)``,
        which must yield CONTINUE. The scheduler's stats are reset before
        returning, so counters start from zero in each test.

        Args:
            resample_prob: Forwarded as ``resample_probability``.
            explore: Forwarded as ``custom_explore_fn``.

        Returns:
            Tuple of (scheduler, mock trial runner).
        """
        pbt = PopulationBasedTraining(
            time_attr="training_iteration",
            perturbation_interval=10,
            resample_probability=resample_prob,
            hyperparam_mutations={
                "id_factor": [100],
                "float_factor": lambda c: 100.0,
                "int_factor": lambda c: 10,
            },
            custom_explore_fn=explore)
        runner = _MockTrialRunner(pbt)
        for i in range(5):
            trial = _MockTrial(
                i,
                {"id_factor": i, "float_factor": 2.0, "const_factor": 3,
                 "int_factor": 10})
            runner.add_trial(trial)
            trial.status = Trial.RUNNING
            self.assertEqual(
                pbt.on_trial_result(runner, trial, result(10, 50 * i)),
                TrialScheduler.CONTINUE)
        pbt.reset_stats()
        return pbt, runner

    def testCheckpointsMostPromisingTrials(self):
        """Check that only upper-quantile trials past the interval checkpoint."""
        pbt, runner = self.basicSetup()
        trials = runner.get_trials()

        # no checkpoint: haven't hit next perturbation interval yet
        self.assertEqual(
            pbt.last_scores(trials), [0, 50, 100, 150, 200])
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(15, 200)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            pbt.last_scores(trials), [0, 50, 100, 150, 200])
        self.assertEqual(pbt._num_checkpoints, 0)

        # checkpoint: both past interval and upper quantile
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(20, 200)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            pbt.last_scores(trials), [200, 50, 100, 150, 200])
        self.assertEqual(pbt._num_checkpoints, 1)
        self.assertEqual(
            pbt.on_trial_result(runner, trials[1], result(30, 201)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            pbt.last_scores(trials), [200, 201, 100, 150, 200])
        self.assertEqual(pbt._num_checkpoints, 2)

        # not upper quantile any more
        self.assertEqual(
            pbt.on_trial_result(runner, trials[4], result(30, 199)),
            TrialScheduler.CONTINUE)
        self.assertEqual(pbt._num_checkpoints, 2)
        self.assertEqual(pbt._num_perturbations, 0)

    def testPerturbsLowPerformingTrials(self):
        """Check that lower-quantile trials past the interval are perturbed."""
        pbt, runner = self.basicSetup()
        trials = runner.get_trials()

        # no perturbation: haven't hit next perturbation interval
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(15, -100)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            pbt.last_scores(trials), [0, 50, 100, 150, 200])
        self.assertTrue("@perturbed" not in trials[0].experiment_tag)
        self.assertEqual(pbt._num_perturbations, 0)

        # perturb since it's lower quantile
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(20, -100)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            pbt.last_scores(trials), [-100, 50, 100, 150, 200])
        self.assertTrue("@perturbed" in trials[0].experiment_tag)
        self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
        self.assertEqual(pbt._num_perturbations, 1)

        # also perturbed
        self.assertEqual(
            pbt.on_trial_result(runner, trials[2], result(20, 40)),
            TrialScheduler.CONTINUE)
        self.assertEqual(
            pbt.last_scores(trials), [-100, 50, 40, 150, 200])
        self.assertEqual(pbt._num_perturbations, 2)
        # Verify the trial that was just perturbed; trials[0] was already
        # checked after the first perturbation above.
        self.assertIn(trials[2].restored_checkpoint, ["trial_3", "trial_4"])
        self.assertTrue("@perturbed" in trials[2].experiment_tag)

    def testPerturbWithoutResample(self):
        """Check the config of a perturbed trial when resampling is off.

        Mutated values must be one of the two listed candidates and keep
        their original type; the unlisted ``const_factor`` is unchanged.
        """
        pbt, runner = self.basicSetup(resample_prob=0.0)
        trials = runner.get_trials()
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(20, -100)),
            TrialScheduler.CONTINUE)
        self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
        self.assertIn(trials[0].config["id_factor"], [3, 4])
        self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
        self.assertEqual(type(trials[0].config["float_factor"]), float)
        self.assertIn(trials[0].config["int_factor"], [8, 12])
        self.assertEqual(type(trials[0].config["int_factor"]), int)
        self.assertEqual(trials[0].config["const_factor"], 3)

    def testPerturbWithResample(self):
        """Check the config of a perturbed trial when always resampling.

        Each mutated key must take the value given by its entry in
        ``hyperparam_mutations``; ``const_factor`` is unchanged.
        """
        pbt, runner = self.basicSetup(resample_prob=1.0)
        trials = runner.get_trials()
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(20, -100)),
            TrialScheduler.CONTINUE)
        self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
        self.assertEqual(trials[0].config["id_factor"], 100)
        self.assertEqual(trials[0].config["float_factor"], 100.0)
        self.assertEqual(type(trials[0].config["float_factor"]), float)
        self.assertEqual(trials[0].config["int_factor"], 10)
        self.assertEqual(type(trials[0].config["int_factor"]), int)
        self.assertEqual(trials[0].config["const_factor"], 3)

    def testYieldsTimeToOtherTrials(self):
        """Check that a trial is paused while another trial is pending."""
        pbt, runner = self.basicSetup()
        trials = runner.get_trials()
        trials[0].status = Trial.PENDING  # simulate not enough resources

        self.assertEqual(
            pbt.on_trial_result(runner, trials[1], result(20, 1000)),
            TrialScheduler.PAUSE)
        self.assertEqual(
            pbt.last_scores(trials), [0, 1000, 100, 150, 200])
        self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])

    def testSchedulesMostBehindTrialToRun(self):
        """Check which pending trial is chosen to run next.

        Nothing is chosen while every trial is running; once all are
        pending, the trial with the smallest reported time (500) is picked.
        """
        pbt, runner = self.basicSetup()
        trials = runner.get_trials()
        pbt.on_trial_result(runner, trials[0], result(800, 1000))
        pbt.on_trial_result(runner, trials[1], result(700, 1001))
        pbt.on_trial_result(runner, trials[2], result(600, 1002))
        pbt.on_trial_result(runner, trials[3], result(500, 1003))
        pbt.on_trial_result(runner, trials[4], result(700, 1004))
        self.assertEqual(pbt.choose_trial_to_run(runner), None)
        for i in range(5):
            trials[i].status = Trial.PENDING
        self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])

    def testPerturbationResetsLastPerturbTime(self):
        """Check that a perturbed trial waits a full interval before the next.

        After trials[3] is perturbed at time 500, a further low score at
        time 600 does not perturb it again, but one at time 11000 does.
        """
        pbt, runner = self.basicSetup()
        trials = runner.get_trials()
        pbt.on_trial_result(runner, trials[0], result(10000, 1005))
        pbt.on_trial_result(runner, trials[1], result(10000, 1004))
        pbt.on_trial_result(runner, trials[2], result(600, 1003))
        self.assertEqual(pbt._num_perturbations, 0)
        pbt.on_trial_result(runner, trials[3], result(500, 1002))
        self.assertEqual(pbt._num_perturbations, 1)
        pbt.on_trial_result(runner, trials[3], result(600, 100))
        self.assertEqual(pbt._num_perturbations, 1)
        pbt.on_trial_result(runner, trials[3], result(11000, 100))
        self.assertEqual(pbt._num_perturbations, 2)

    def testPostprocessingHook(self):
        """Check that ``custom_explore_fn`` output ends up in the config."""

        def explore(new_config):
            new_config["id_factor"] = 42
            new_config["float_factor"] = 43
            return new_config

        pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
        trials = runner.get_trials()
        self.assertEqual(
            pbt.on_trial_result(runner, trials[0], result(20, -100)),
            TrialScheduler.CONTINUE)
        self.assertEqual(trials[0].config["id_factor"], 42)
        self.assertEqual(trials[0].config["float_factor"], 43)
|
||||
|
||||
|
||||
# Run the suites in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,103 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import socket
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.web_server import TuneClient
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
|
||||
|
||||
def get_valid_port():
    """Return the first port at or above 4321 that binds on 127.0.0.1.

    Candidates are probed in increasing order by binding a throwaway
    socket. The probe socket is always closed, including when the bind
    fails, so occupied ports do not leak file descriptors.

    Returns:
        int: A port number that could be bound at the time of the probe.
    """
    port = 4321
    while True:
        print("Trying port", port)
        port_test_socket = socket.socket()
        try:
            port_test_socket.bind(("127.0.0.1", port))
        except socket.error:
            # Port is unavailable; try the next one.
            port += 1
        else:
            return port
        finally:
            port_test_socket.close()
|
||||
|
||||
|
||||
class TuneServerSuite(unittest.TestCase):
    """Tests that drive a TrialRunner through its web server via TuneClient."""

    def basicSetup(self):
        """Start Ray and a server-enabled runner holding two trials.

        The runner is also stored on ``self.runner`` so that ``tearDown``
        can shut its server down.

        Returns:
            Tuple of (runner, client), where the client targets the
            runner's server port on localhost.
        """
        ray.init(num_cpus=4, num_gpus=1)
        port = get_valid_port()
        self.runner = TrialRunner(
            launch_web_server=True, server_port=port)
        runner = self.runner
        kwargs = {
            "stopping_criterion": {"training_iteration": 3},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)
        client = TuneClient("localhost:{}".format(port))
        return runner, client

    def tearDown(self):
        """Shut the server down (best effort) and reset Ray state."""
        print("Tearing down....")
        try:
            self.runner._server.shutdown()
            self.runner = None
        except Exception as e:
            # Deliberately non-fatal: cleanup below must still run.
            print(e)
        ray.worker.cleanup()
        _register_all()

    def testAddTrial(self):
        """Check that a trial added through the client is listed."""
        runner, client = self.basicSetup()
        for _ in range(3):
            runner.step()
        spec = {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "resources": dict(cpu=1, gpu=1),
        }
        client.add_trial("test", spec)
        runner.step()
        all_trials = client.get_all_trials()["trials"]
        runner.step()
        self.assertEqual(len(all_trials), 3)

    def testGetTrials(self):
        """Check that listing and fetching trials leaves the count at two."""
        runner, client = self.basicSetup()
        for _ in range(3):
            runner.step()
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(len(all_trials), 2)
        tid = all_trials[0]["id"]
        client.get_trial(tid)
        runner.step()
        # Fetch the list again so this assertion reflects the state after
        # the extra step rather than re-checking the earlier snapshot.
        all_trials = client.get_all_trials()["trials"]
        self.assertEqual(len(all_trials), 2)

    def testStopTrial(self):
        """Check if Stop Trial works"""
        runner, client = self.basicSetup()
        for _ in range(2):
            runner.step()
        all_trials = client.get_all_trials()["trials"]
        running = [t for t in all_trials if t["status"] == Trial.RUNNING]
        self.assertEqual(len(running), 1)

        tid = running[0]["id"]
        client.stop_trial(tid)
        runner.step()

        all_trials = client.get_all_trials()["trials"]
        running = [t for t in all_trials if t["status"] == Trial.RUNNING]
        self.assertEqual(len(running), 0)
|
||||
|
||||
|
||||
# Run the suite in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
@@ -135,7 +135,8 @@ class Trainable(object):
|
||||
time_total_s=self._time_total,
|
||||
neg_mean_loss=neg_loss,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
hostname=os.uname()[1],
|
||||
config=self.config)
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
@@ -185,8 +186,8 @@ class Trainable(object):
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
})
|
||||
print("Saving checkpoint to object store, {} bytes".format(
|
||||
len(compressed)))
|
||||
if len(compressed) > 10e6: # getting pretty large
|
||||
print("Checkpoint size is {} bytes".format(len(compressed)))
|
||||
f.write(compressed)
|
||||
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
+35
-16
@@ -96,7 +96,7 @@ class Trial(object):
|
||||
"Stopping condition key `{}` must be one of {}".format(
|
||||
k, TrainingResult._fields))
|
||||
|
||||
# Immutable config
|
||||
# Trial config
|
||||
self.trainable_name = trainable_name
|
||||
self.config = config or {}
|
||||
self.local_dir = local_dir
|
||||
@@ -105,6 +105,7 @@ class Trial(object):
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.upload_dir = upload_dir
|
||||
self.verbose = True
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
@@ -117,16 +118,22 @@ class Trial(object):
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.trial_id = binary_to_hex(random_string())[:8]
|
||||
self.error_file = None
|
||||
|
||||
def start(self):
|
||||
def start(self, checkpoint_obj=None):
|
||||
"""Starts this trial.
|
||||
|
||||
If an error is encountered when starting the trial, an exception will
|
||||
be thrown.
|
||||
|
||||
Args:
|
||||
checkpoint_obj (obj): Optional checkpoint to resume from.
|
||||
"""
|
||||
|
||||
self._setup_runner()
|
||||
if self._checkpoint_path:
|
||||
if checkpoint_obj:
|
||||
self.restore_from_obj(checkpoint_obj)
|
||||
elif self._checkpoint_path:
|
||||
self.restore_from_path(self._checkpoint_path)
|
||||
elif self._checkpoint_obj:
|
||||
self.restore_from_obj(self._checkpoint_obj)
|
||||
@@ -155,6 +162,7 @@ class Trial(object):
|
||||
self.logdir, "error_{}.txt".format(date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
self.error_file = error_file
|
||||
if self.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(self.runner.stop.remote())
|
||||
@@ -163,9 +171,6 @@ class Trial(object):
|
||||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
if unfinished:
|
||||
print(("Stopping %s Actor timed out, "
|
||||
"but moving on...") % self)
|
||||
except Exception:
|
||||
print("Error stopping runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
@@ -230,7 +235,7 @@ class Trial(object):
|
||||
"""Returns a progress message for printing out to the console."""
|
||||
|
||||
if self.last_result is None:
|
||||
return self.status
|
||||
return self._status_string()
|
||||
|
||||
def location_string(hostname, pid):
|
||||
if hostname == os.uname()[1]:
|
||||
@@ -240,7 +245,8 @@ class Trial(object):
|
||||
|
||||
pieces = [
|
||||
'{} [{}]'.format(
|
||||
self.status, location_string(
|
||||
self._status_string(),
|
||||
location_string(
|
||||
self.last_result.hostname, self.last_result.pid)),
|
||||
'{} s'.format(int(self.last_result.time_total_s)),
|
||||
'{} ts'.format(int(self.last_result.timesteps_total))]
|
||||
@@ -259,6 +265,11 @@ class Trial(object):
|
||||
|
||||
return ', '.join(pieces)
|
||||
|
||||
def _status_string(self):
    """Return the status, followed by the error file path when one is set."""
    suffix = " => {}".format(self.error_file) if self.error_file else ""
    return "{}{}".format(self.status, suffix)
|
||||
|
||||
def checkpoint(self, to_object_store=False):
|
||||
"""Checkpoints the state of this trial.
|
||||
|
||||
@@ -276,7 +287,8 @@ class Trial(object):
|
||||
self._checkpoint_path = path
|
||||
self._checkpoint_obj = obj
|
||||
|
||||
print("Saved checkpoint to:", path or obj)
|
||||
if self.verbose:
|
||||
print("Saved checkpoint for {} to {}".format(self, path or obj))
|
||||
return path or obj
|
||||
|
||||
def restore_from_path(self, path):
|
||||
@@ -310,7 +322,9 @@ class Trial(object):
|
||||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
|
||||
if terminate or (
|
||||
self.verbose and
|
||||
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
@@ -348,12 +362,17 @@ class Trial(object):
|
||||
config=self.config, registry=get_registry(),
|
||||
logger_creator=logger_creator)
|
||||
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
|
||||
def set_verbose(self, verbose):
    """Store ``verbose`` on this object as its ``verbose`` attribute."""
    self.verbose = verbose
|
||||
|
||||
Truncates to MAX_LEN_IDENTIFIER (default is 130) to avoid problems
|
||||
when creating logging directories.
|
||||
"""
|
||||
def is_finished(self):
    """Return True when the status is either TERMINATED or ERROR."""
    finished_states = (Trial.TERMINATED, Trial.ERROR)
    return self.status in finished_states
|
||||
|
||||
def __repr__(self):
    """Return the same text as ``str(self)``."""
    return str(self)
|
||||
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(
|
||||
self.trainable_name, self.config["env"])
|
||||
@@ -361,4 +380,4 @@ class Trial(object):
|
||||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
identifier += "_" + self.experiment_tag
|
||||
return identifier[:MAX_LEN_IDENTIFIER]
|
||||
return identifier
|
||||
|
||||
@@ -31,7 +31,7 @@ def _make_scheduler(args):
|
||||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT):
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa # pylint: disable=unused-import
|
||||
@@ -44,6 +44,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
trial.set_verbose(verbose)
|
||||
runner.add_trial(trial)
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user