[tune] clean up population based training prototype (#1478)

* patch up pbt

* Sat Jan 27 01:00:03 PST 2018

* Sat Jan 27 01:04:14 PST 2018

* Sat Jan 27 01:04:21 PST 2018

* Sat Jan 27 01:15:15 PST 2018

* Sat Jan 27 01:15:42 PST 2018

* Sat Jan 27 01:16:14 PST 2018

* Sat Jan 27 01:38:42 PST 2018

* Sat Jan 27 01:39:21 PST 2018

* add pbt

* Sat Jan 27 01:41:19 PST 2018

* Sat Jan 27 01:44:21 PST 2018

* Sat Jan 27 01:45:46 PST 2018

* Sat Jan 27 16:54:42 PST 2018

* Sat Jan 27 16:57:53 PST 2018

* clean up test

* Sat Jan 27 18:01:15 PST 2018

* Sat Jan 27 18:02:54 PST 2018

* Sat Jan 27 18:11:18 PST 2018

* Sat Jan 27 18:11:55 PST 2018

* Sat Jan 27 18:14:09 PST 2018

* review

* try out a ppo example

* some tweaks to ppo example

* add postprocess hook

* Sun Jan 28 15:00:40 PST 2018

* clean up custom explore fn

* Sun Jan 28 15:10:21 PST 2018

* Sun Jan 28 15:14:53 PST 2018

* Sun Jan 28 15:17:04 PST 2018

* Sun Jan 28 15:33:13 PST 2018

* Sun Jan 28 15:56:40 PST 2018

* Sun Jan 28 15:57:36 PST 2018

* Sun Jan 28 16:00:35 PST 2018

* Sun Jan 28 16:02:58 PST 2018

* Sun Jan 28 16:29:50 PST 2018

* Sun Jan 28 16:30:36 PST 2018

* Sun Jan 28 16:31:44 PST 2018

* improve tune doc

* concepts

* update humanoid

* Fri Feb  2 18:03:33 PST 2018

* fix example

* show error file
This commit is contained in:
Eric Liang
2018-02-02 23:03:12 -08:00
committed by GitHub
parent a936468f99
commit b948405532
22 changed files with 698 additions and 288 deletions
+3 -3
View File
@@ -76,13 +76,13 @@ class PPOEvaluator(Evaluator):
# Value function predictions before the policy update.
self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))
assert config["sgd_batchsize"] % len(devices) == 0, \
"Batch size must be evenly divisible by devices"
if is_remote:
self.batch_size = config["rollout_batchsize"]
self.per_device_batch_size = config["rollout_batchsize"]
else:
self.batch_size = config["sgd_batchsize"]
self.batch_size = int(
config["sgd_batchsize"] / len(devices)) * len(devices)
assert self.batch_size % len(devices) == 0
self.per_device_batch_size = int(self.batch_size / len(devices))
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):
View File
@@ -4,6 +4,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import random
@@ -49,6 +50,10 @@ class MyTrainableClass(Trainable):
register_trainable("my_class", MyTrainableClass)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
# Hyperband early stopping, configured with `episode_reward_mean` as the
@@ -60,7 +65,8 @@ if __name__ == "__main__":
run_experiments({
"hyperband_test": {
"run": "my_class",
"repeat": 100,
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
"repeat": 20,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
+88
View File
@@ -0,0 +1,88 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import random
import time
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
run_experiments
from ray.tune.pbt import PopulationBasedTraining
class MyTrainableClass(Trainable):
"""Fake agent whose learning rate is determined by dummy factors."""
def _setup(self):
self.timestep = 0
self.current_value = 0.0
def _train(self):
time.sleep(0.1)
# Reward increase is parabolic as a function of factor_2, with a
# maxima around factor_1=10.0.
self.current_value += max(
0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0))
# Flat increase by factor_2
self.current_value += random.gauss(self.config["factor_2"], 1.0)
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy (see tune/result.py).
return TrainingResult(
episode_reward_mean=self.current_value, timesteps_this_iter=1)
def _save(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps(
{"timestep": self.timestep, "value": self.current_value}))
return path
def _restore(self, checkpoint_path):
with open(checkpoint_path) as f:
data = json.loads(f.read())
self.timestep = data["timestep"]
self.current_value = data["value"]
register_trainable("my_class", MyTrainableClass)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration", reward_attr="episode_reward_mean",
perturbation_interval=10,
hyperparam_mutations={
# Allow for scaling-based perturbations, with a uniform backing
# distribution for resampling.
"factor_1": lambda config: random.uniform(0.0, 20.0),
# Only allows resampling from this list as a perturbation.
"factor_2": [1, 2],
})
# Try to find the best factor 1 and factor 2
run_experiments({
"pbt_test": {
"run": "my_class",
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
"repeat": 10,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,
},
}
}, scheduler=pbt, verbose=False)
+71
View File
@@ -0,0 +1,71 @@
#!/usr/bin/env python
"""Example of using PBT with RLlib.
Note that this requires a cluster with at least 8 GPUs in order for all trials
to run concurrently, otherwise PBT will round-robin train the trials which
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import ray
from ray.tune import run_experiments
from ray.tune.pbt import PopulationBasedTraining
if __name__ == "__main__":
# Postprocess the perturbed config to ensure it's still valid
def explore(config):
# ensure we collect enough timesteps to do sgd
if config["timesteps_per_batch"] < config["sgd_batchsize"] * 2:
config["timesteps_per_batch"] = config["sgd_batchsize"] * 2
# ensure we run at least one sgd iter
if config["num_sgd_iter"] < 1:
config["num_sgd_iter"] = 1
return config
pbt = PopulationBasedTraining(
time_attr="time_total_s", reward_attr="episode_reward_mean",
perturbation_interval=120,
resample_probability=0.25,
# Specifies the resampling distributions of these hyperparams
hyperparam_mutations={
"lambda": lambda config: random.uniform(0.9, 1.0),
"clip_param": lambda config: random.uniform(0.01, 0.5),
"sgd_stepsize": lambda config: random.uniform(.00001, .001),
"num_sgd_iter": lambda config: random.randint(1, 30),
"sgd_batchsize": lambda config: random.randint(128, 16384),
"timesteps_per_batch":
lambda config: random.randint(2000, 160000),
},
custom_explore_fn=explore)
ray.init()
run_experiments({
"pbt_humanoid_test": {
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"resources": {"cpu": 4, "gpu": 1},
"config": {
"kl_coeff": 1.0,
"num_workers": 8,
"devices": ["/gpu:0"],
"model": {"free_log_std": True},
# These params are tuned from their starting value
"lambda": 0.95,
"clip_param": 0.2,
# Start off with several random variations
"sgd_stepsize": lambda spec: random.uniform(.00001, .001),
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
lambda spec: random.choice([10000, 20000, 40000])
},
},
}, scheduler=pbt)
+2 -2
View File
@@ -205,7 +205,7 @@ def train(config={'activation': 'relu'}, reporter=None):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--fast', action='store_true', help='Finish quickly for testing')
'--smoke-test', action='store_true', help='Finish quickly for testing')
args, _ = parser.parse_known_args()
register_trainable('train_mnist', train)
@@ -220,7 +220,7 @@ if __name__ == '__main__':
},
}
if args.fast:
if args.smoke_test:
mnist_spec['stop']['training_iteration'] = 2
ray.init()
+1 -1
View File
@@ -207,7 +207,7 @@ class HyperBandScheduler(FIFOScheduler):
"""Cleans up trial info from bracket if trial errored early."""
self.on_trial_remove(trial_runner, trial)
def choose_trial_to_run(self, trial_runner, *args):
def choose_trial_to_run(self, trial_runner):
"""Fair scheduling within iteration by completion percentage.
List of trials not used since all trials are tracked as state
-1
View File
@@ -63,7 +63,6 @@ class UnifiedLogger(Logger):
print("TF not installed - cannot log with {}...".format(cls))
continue
self._loggers.append(cls(self.config, self.logdir, self.uri))
print("Unified logger created with logdir '{}'".format(self.logdir))
def on_result(self, result):
for logger in self._loggers:
+1 -1
View File
@@ -31,7 +31,7 @@ class MedianStoppingRule(FIFOScheduler):
"""
def __init__(
self, time_attr='time_total_s', reward_attr='episode_reward_mean',
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
grace_period=60.0, min_samples_required=3, hard_stop=True):
FIFOScheduler.__init__(self)
self._stopped_trials = set()
+234 -154
View File
@@ -2,189 +2,269 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import random
import math
import copy
from ray.tune.error import TuneError
from ray.tune.trial import Trial
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.variant_generator import _format_vars
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
# the bottom PBT_QUANTILE fraction.
PBT_QUANTILE = 0.25
class PBTTrialState(object):
"""Internal PBT state tracked per-trial."""
def __init__(self, trial):
self.orig_tag = trial.experiment_tag
self.last_score = None
self.last_checkpoint = None
self.last_perturbation_time = 0
def __repr__(self):
return str((
self.last_score, self.last_checkpoint,
self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
"""Return a config perturbed as specified.
Args:
config (dict): Original hyperparameter configuration.
mutations (dict): Specification of mutations to perform as documented
in the PopulationBasedTraining scheduler.
resample_probability (float): Probability of allowing resampling of a
particular variable.
custom_explore_fn (func): Custom explore fn applied after built-in
config perturbations are.
"""
new_config = copy.deepcopy(config)
for key, distribution in mutations.items():
if isinstance(distribution, list):
if random.random() < resample_probability:
new_config[key] = random.choice(distribution)
else:
if random.random() < resample_probability:
new_config[key] = distribution(config)
elif random.random() > 0.5:
new_config[key] = config[key] * 1.2
else:
new_config[key] = config[key] * 0.8
if type(config[key]) is int:
new_config[key] = int(new_config[key])
if custom_explore_fn:
new_config = custom_explore_fn(new_config)
assert new_config is not None, \
"Custom explore fn failed to return new config"
print(
"[explore] perturbed config from {} -> {}".format(config, new_config))
return new_config
def make_experiment_tag(orig_tag, config, mutations):
"""Appends perturbed params to the trial name to show in the console."""
resolved_vars = {}
for k in mutations.keys():
resolved_vars[("config", k)] = config[k]
return "{}@perturbed[{}]".format(orig_tag, _format_vars(resolved_vars))
class PopulationBasedTraining(FIFOScheduler):
"""Implements the Population Based Training algorithm as described in the
PBT paper (https://arxiv.org/abs/1711.09846)(Experimental):
"""Implements the Population Based Training (PBT) algorithm.
https://deepmind.com/blog/population-based-training-neural-networks
PBT trains a group of models (or agents) in parallel. Periodically, poorly
performing models clone the state of the top performers, and a random
mutation is applied to their hyperparameters in the hopes of
outperforming the current top models.
Unlike other hyperparameter search algorithms, PBT mutates hyperparameters
during training time. This enables very fast hyperparameter discovery and
also automatically discovers good annealing schedules.
This Ray Tune PBT implementation considers all trials added as part of the
PBT population. If the number of trials exceeds the cluster capacity,
they will be time-multiplexed as to balance training progress across the
population.
Args:
time_attr (str): The TrainingResult attr to use for documenting length
of time since last ready() call. Attribute only has to increase
monotonically.
time_attr (str): The TrainingResult attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
reward_attr (str): The TrainingResult objective value attribute. As
with 'time_attr'. this may refer to any objective value that
is supposed to increase with time.
grace_period (float): Period of time, in which algorithm will not
compare model to other models.
perturbation_interval (float): Used in the truncation ready function to
determine if enough time has passed so that a agent can be tested
for readiness.
hyperparameter_mutations (dict); Possible values that each
hyperparameter can mutate to, as certain hyperparameters
only work with certain values.
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
perturbation_interval (float): Models will be considered for
perturbation at this interval of `time_attr`. Note that
perturbation incurs checkpoint overhead, so you shouldn't set this
to be too frequent.
hyperparam_mutations (dict): Hyperparams to mutate. The format is
as follows: for each key, either a list or function can be
provided. A list specifies values for a discrete parameter.
A function specifies the distribution of a continuous parameter.
You must specify at least one of `hyperparam_mutations` or
`custom_explore_fn`.
resample_probability (float): The probability of resampling from the
original distribution when applying `hyperparam_mutations`. If not
resampled, the value will be perturbed by a factor of 1.2 or 0.8
if continuous, or left unchanged if discrete.
custom_explore_fn (func): You can also specify a custom exploration
function. This function is invoked as `f(config)` after built-in
perturbations from `hyperparam_mutations` are applied, and should
return `config` updated as needed. You must specify at least one of
`hyperparam_mutations` or `custom_explore_fn`.
Example:
>>> pbt = PopulationBasedTraining(
>>> time_attr="training_iteration",
>>> reward_attr="episode_reward_mean",
>>> perturbation_interval=10, # every 10 `time_attr` units
>>> # (training_iterations in this case)
>>> hyperparam_mutations={
>>> # Allow for scaling-based perturbations, with a uniform
>>> # backing distribution for resampling.
>>> "factor_1": lambda config: random.uniform(0.0, 20.0),
>>> # Only allows resampling from this list as a perturbation.
>>> "factor_2": [1, 2],
>>> })
>>> run_experiments({...}, scheduler=pbt)
"""
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean',
grace_period=10.0, perturbation_interval=6.0,
hyperparameter_mutations=None):
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
perturbation_interval=60.0, hyperparam_mutations={},
resample_probability=0.25, custom_explore_fn=None):
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
"You must specify at least one of `hyperparam_mutations` or "
"`custom_explore_fn` to use PBT.")
FIFOScheduler.__init__(self)
self._completed_trials = set()
self._results = collections.defaultdict(list)
self._last_perturbation_time = {}
self._grace_period = grace_period
self._reward_attr = reward_attr
self._time_attr = time_attr
self._hyperparameter_mutations = hyperparameter_mutations
self._perturbation_interval = perturbation_interval
self._checkpoint_paths = {}
self._hyperparam_mutations = hyperparam_mutations
self._resample_probability = resample_probability
self._trial_state = {}
self._custom_explore_fn = custom_explore_fn
# Metrics
self._num_checkpoints = 0
self._num_perturbations = 0
def on_trial_add(self, trial_runner, trial):
self._trial_state[trial] = PBTTrialState(trial)
def on_trial_result(self, trial_runner, trial, result):
self._results[trial].append(result)
time = getattr(result, self._time_attr)
# check model is ready to undergo mutation, based on user
# function or default function
self._checkpoint_paths[trial] = trial.checkpoint()
if time > self._grace_period:
ready = self._truncation_ready(result, trial, time)
else:
ready = False
if ready:
print("ready to undergo mutation")
print("----")
print("Current Trial is: {0}".format(trial))
# get best trial for current time
best_trial = self._get_best_trial(result, time)
print("Best Trial is: {0}".format(best_trial))
print(best_trial.config)
state = self._trial_state[trial]
if time - state.last_perturbation_time < self._perturbation_interval:
return TrialScheduler.CONTINUE # avoid checkpoint overhead
score = getattr(result, self._reward_attr)
state.last_score = score
state.last_perturbation_time = time
lower_quantile, upper_quantile = self._quantiles()
if trial in upper_quantile:
state.last_checkpoint = trial.checkpoint(to_object_store=True)
self._num_checkpoints += 1
else:
state.last_checkpoint = None # not a top trial
if trial in lower_quantile:
trial_to_clone = random.choice(upper_quantile)
assert trial is not trial_to_clone
self._exploit(trial, trial_to_clone)
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED]:
return TrialScheduler.PAUSE # yield time to other trials
# if current trial is the best trial (as in same hyperparameters),
# do nothing
if trial.config == best_trial.config:
print("current trial is best trial")
return TrialScheduler.CONTINUE
else:
self._exploit(self._hyperparameter_mutations, best_trial,
trial, trial_runner, time)
return TrialScheduler.CONTINUE
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
self._results[trial].append(result)
self._completed_trials.add(trial)
def _exploit(self, trial, trial_to_clone):
"""Transfers perturbed state from trial_to_clone -> trial."""
def _exploit(self, hyperparameter_mutations, best_trial,
trial, trial_runner, time):
trial.stop()
mutate_string = "_mutated@" + str(time)
hyperparams = copy.deepcopy(best_trial.config)
hyperparams = self._explore(hyperparams, hyperparameter_mutations,
best_trial)
print("new hyperparameter configuration: {0}".format(hyperparams))
checkpoint = self._checkpoint_paths[best_trial]
trial._checkpoint_path = checkpoint
trial.config = hyperparams
trial.experiment_tag = trial.experiment_tag + mutate_string
trial.start()
trial_state = self._trial_state[trial]
new_state = self._trial_state[trial_to_clone]
if not new_state.last_checkpoint:
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
return
new_config = explore(
trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability, self._custom_explore_fn)
print(
"[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
# TODO(ekl) restarting the trial is expensive. We should implement a
# lighter way reset() method that can alter the trial config.
trial.stop(stop_logger=False)
trial.config = new_config
trial.experiment_tag = make_experiment_tag(
trial_state.orig_tag, new_config, self._hyperparam_mutations)
trial.start(new_state.last_checkpoint)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
trial_state.last_perturbation_time = new_state.last_perturbation_time
def _explore(self, hyperparams, hyperparameter_mutations, best_trial):
if hyperparameter_mutations is not None:
hyperparams = {
param: random.choice(hyperparameter_mutations[param])
for param in hyperparams
if param != "env" and param in hyperparameter_mutations
}
for param in best_trial.config:
if param not in hyperparameter_mutations and param != "env":
hyperparams[param] = math.ceil(
(best_trial.config[param]
* random.choice([0.8, 1.2])/2.)) * 2
def _quantiles(self):
"""Returns trials in the lower and upper `quantile` of the population.
If there is not enough data to compute this, returns empty lists."""
trials = []
for trial, state in self._trial_state.items():
if state.last_score is not None and not trial.is_finished():
trials.append(trial)
trials.sort(key=lambda t: self._trial_state[t].last_score)
if len(trials) <= 1:
return [], []
else:
hyperparams = {
param: math.ceil(
(random.choice([0.8, 1.2]) *
hyperparams[param])/2.) * 2
for param in hyperparams
if param != "env"
}
hyperparams["env"] = best_trial.config["env"]
return hyperparams
return (
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
def _truncation_ready(self, result, trial, time):
# function checks if appropriate time has passed
# and trial is in the bottom 20% of all trials, and if so, is ready
if trial not in self._last_perturbation_time:
print("added trial to time tracker")
self._last_perturbation_time[trial] = (time)
else:
time_since_last = time - self._last_perturbation_time[trial]
if time_since_last >= self._perturbation_interval:
self._last_perturbation_time[trial] = time
sorted_result_keys = sorted(
self._results, key=lambda x:
max(self._results.get(x) if self._results.get(x) else [0])
)
max_index = int(round(len(sorted_result_keys) * 0.2))
for i in range(0, max_index):
if trial == sorted_result_keys[i]:
print("{0} is in the bottomn 20 percent of {1}, \
truncation is ready".format(
trial,
[x.experiment_tag for x in sorted_result_keys]
))
return True
print("{0} is not in the bottomn 20 percent of {1}, \
truncation is not ready".format(
trial,
[x.experiment_tag for x in sorted_result_keys]
))
else:
print("not enough time has passed since last mutation")
return False
def choose_trial_to_run(self, trial_runner):
"""Ensures all trials get fair share of time (as defined by time_attr).
def _get_best_trial(self, result, time):
results_at_time = {}
for trial in self._results:
results_at_time[trial] = [
getattr(r, self._reward_attr)
for r in self._results[trial]
if getattr(r, self._time_attr) <= time
]
print("Results at {0}: {1}".format(time, results_at_time))
return max(results_at_time, key=lambda x:
max(results_at_time.get(x)
if results_at_time.get(x) else [0]))
This enables the PBT scheduler to support a greater number of
concurrent trials than can fit in the cluster at any given time.
"""
def _is_empty(self, x):
if x:
return False
return True
candidates = []
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
trial_runner.has_resources(trial.resources):
candidates.append(trial)
candidates.sort(
key=lambda trial: self._trial_state[trial].last_perturbation_time)
return candidates[0] if candidates else None
def reset_stats(self):
self._num_perturbations = 0
self._num_checkpoints = 0
def last_scores(self, trials):
scores = []
for trial in trials:
state = self._trial_state[trial]
if state.last_score is not None and not trial.is_finished():
scores.append(state.last_score)
return scores
def debug_string(self):
min_time = 0
best_trial = None
for trial in self._completed_trials:
last_result = self._results[trial][-1]
if (getattr(last_result, self._time_attr)
< min_time or min_time == 0):
min_time = getattr(last_result, self._time_attr)
best_trial = trial
if best_trial is not None:
return ("The Best Trial is currently {0} finishing in {1} iterations, \
with the hyperparameters of {2}".format(
best_trial, min_time, best_trial.config
)
)
else:
return "PBT has started"
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
self._num_checkpoints, self._num_perturbations)
+4
View File
@@ -84,10 +84,14 @@ TrainingResult = namedtuple("TrainingResult", [
# (Auto-filled) The hostname of the machine hosting the training process.
"hostname",
# (Auto=filled) The current hyperparameter configuration.
"config",
])
def pretty_print(result):
result = result._replace(config=None) # drop config from pretty print
out = {}
for k, v in result._asdict().items():
if v is not None:
+579
View File
@@ -0,0 +1,579 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import unittest
import ray
from ray.rllib import _register_all
from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import generate_trials, grid_search, \
RecursiveDependencyError
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def testRegisterEnv(self):
register_env("foo", lambda: None)
self.assertRaises(TypeError, lambda: register_env("foo", 2))
def testRegisterTrainable(self):
def train(config, reporter):
pass
class A(object):
pass
class B(Trainable):
pass
register_trainable("foo", train)
register_trainable("foo", B)
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
def testRewriteEnv(self):
def train(config, reporter):
reporter(timesteps_total=1)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"env": "CartPole-v0",
}})
self.assertEqual(trial.config["env"], "CartPole-v0")
def testConfigPurity(self):
def train(config, reporter):
assert config == {"a": "b"}, config
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"config": {"a": "b"},
}})
def testLogdir(self):
def train(config, reporter):
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {"a": "b"},
}})
def testLongFilename(self):
def train(config, reporter):
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40},
}})
def testBadParams(self):
def f():
run_experiments({"foo": {}})
self.assertRaises(TuneError, f)
def testBadParams2(self):
def f():
run_experiments({"foo": {
"bah": "this param is not allowed",
}})
self.assertRaises(TuneError, f)
def testBadParams3(self):
def f():
run_experiments({"foo": {
"run": grid_search("invalid grid search"),
}})
self.assertRaises(TuneError, f)
def testBadParams4(self):
def f():
run_experiments({"foo": {
"run": "asdf",
}})
self.assertRaises(TuneError, f)
def testBadParams5(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"stop": {"asdf": 1}
}})
self.assertRaises(TuneError, f)
def testBadParams6(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"resources": {"asdf": 1}
}})
self.assertRaises(TuneError, f)
def testBadReturn(self):
def train(config, reporter):
reporter()
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertRaises(TuneError, f)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
time.sleep(99999)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testAbruptReturn(self):
def train(config, reporter):
reporter(timesteps_total=100)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testErrorReturn(self):
def train(config, reporter):
raise Exception("uh oh")
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertRaises(TuneError, f)
def testSuccess(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 99)
class VariantGeneratorTest(unittest.TestCase):
def testParseToTrials(self):
trials = generate_trials({
"run": "PPO",
"repeat": 2,
"config": {
"env": "Pong-v0",
"foo": "bar"
},
}, "tune-pong")
trials = list(trials)
self.assertEqual(len(trials), 2)
self.assertEqual(str(trials[0]), "PPO_Pong-v0_0")
self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
self.assertEqual(trials[0].trainable_name, "PPO")
self.assertEqual(trials[0].experiment_tag, "0")
self.assertEqual(
trials[0].local_dir,
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
self.assertEqual(trials[1].experiment_tag, "1")
def testEval(self):
trials = generate_trials({
"run": "PPO",
"config": {
"foo": {
"eval": "2 + 2"
},
},
})
trials = list(trials)
self.assertEqual(len(trials), 1)
self.assertEqual(trials[0].config, {"foo": 4})
self.assertEqual(trials[0].experiment_tag, "0_foo=4")
def testGridSearch(self):
trials = generate_trials({
"run": "PPO",
"config": {
"bar": {
"grid_search": [True, False]
},
"foo": {
"grid_search": [1, 2, 3]
},
},
})
trials = list(trials)
self.assertEqual(len(trials), 6)
self.assertEqual(trials[0].config, {"bar": True, "foo": 1})
self.assertEqual(trials[0].experiment_tag, "0_bar=True,foo=1")
self.assertEqual(trials[1].config, {"bar": False, "foo": 1})
self.assertEqual(trials[1].experiment_tag, "1_bar=False,foo=1")
self.assertEqual(trials[2].config, {"bar": True, "foo": 2})
self.assertEqual(trials[3].config, {"bar": False, "foo": 2})
self.assertEqual(trials[4].config, {"bar": True, "foo": 3})
self.assertEqual(trials[5].config, {"bar": False, "foo": 3})
def testGridSearchAndEval(self):
trials = generate_trials({
"run": "PPO",
"config": {
"qux": lambda spec: 2 + 2,
"bar": grid_search([True, False]),
"foo": grid_search([1, 2, 3]),
},
})
trials = list(trials)
self.assertEqual(len(trials), 6)
self.assertEqual(trials[0].config, {"bar": True, "foo": 1, "qux": 4})
self.assertEqual(trials[0].experiment_tag, "0_bar=True,foo=1,qux=4")
def testConditionResolution(self):
trials = generate_trials({
"run": "PPO",
"config": {
"x": 1,
"y": lambda spec: spec.config.x + 1,
"z": lambda spec: spec.config.y + 1,
},
})
trials = list(trials)
self.assertEqual(len(trials), 1)
self.assertEqual(trials[0].config, {"x": 1, "y": 2, "z": 3})
def testDependentLambda(self):
trials = generate_trials({
"run": "PPO",
"config": {
"x": grid_search([1, 2]),
"y": lambda spec: spec.config.x * 100,
},
})
trials = list(trials)
self.assertEqual(len(trials), 2)
self.assertEqual(trials[0].config, {"x": 1, "y": 100})
self.assertEqual(trials[1].config, {"x": 2, "y": 200})
def testDependentGridSearch(self):
trials = generate_trials({
"run": "PPO",
"config": {
"x": grid_search([
lambda spec: spec.config.y * 100,
lambda spec: spec.config.y * 200
]),
"y": lambda spec: 1,
},
})
trials = list(trials)
self.assertEqual(len(trials), 2)
self.assertEqual(trials[0].config, {"x": 100, "y": 1})
self.assertEqual(trials[1].config, {"x": 200, "y": 1})
def testRecursiveDep(self):
try:
list(generate_trials({
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
},
}))
except RecursiveDependencyError as e:
assert "`foo` recursively depends on" in str(e), e
else:
assert False
class TrialRunnerTest(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def testTrialStatus(self):
ray.init()
trial = Trial("__fake")
self.assertEqual(trial.status, Trial.PENDING)
trial.start()
self.assertEqual(trial.status, Trial.RUNNING)
trial.stop()
self.assertEqual(trial.status, Trial.TERMINATED)
trial.stop(error=True)
self.assertEqual(trial.status, Trial.ERROR)
def testExperimentTagTruncation(self):
ray.init()
def train(config, reporter):
reporter(timesteps_total=1)
register_trainable("f1", train)
experiments = {"foo": {
"run": "f1",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40},
}}
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
trial.start()
self.assertLessEqual(len(trial.logdir), 200)
trial.stop()
def testTrialErrorOnStart(self):
ray.init()
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf")
try:
trial.start()
except Exception as e:
self.assertIn("a class", str(e))
def testResourceScheduler(self):
ray.init(num_cpus=4, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.TERMINATED)
def testMultiStepRun(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.RUNNING)
def testErrorHandling(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"resources": Resources(cpu=1, gpu=1),
}
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trials = [
Trial("asdf", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[1].status, Trial.RUNNING)
def testCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
path = trials[0].checkpoint()
kwargs["restore_path"] = path
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, path)
def testResultDone(self):
"""Tests that last_result is marked `done` after trial is complete."""
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertNotEqual(trials[0].last_result.done, True)
runner.step()
self.assertEqual(trials[0].last_result.done, True)
def testPauseThenResume(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.get_info.remote()), None)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
trials[0].pause()
self.assertEqual(trials[0].status, Trial.PAUSED)
trials[0].resume()
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.get_info.remote()), 1)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
def testStopTrial(self):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
# Stop trial while running
runner.stop_trial(trials[0])
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.PENDING)
# Stop trial while pending
runner.stop_trial(trials[-1])
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.TERMINATED)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(trials[2].status, Trial.RUNNING)
self.assertEqual(trials[-1].status, Trial.TERMINATED)
if __name__ == "__main__":
unittest.main(verbosity=2)
@@ -0,0 +1,722 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.pbt import PopulationBasedTraining
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.result import TrainingResult
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import TrialScheduler
from ray.rllib import _register_all
_register_all()
def result(t, rew):
return TrainingResult(time_total_s=t,
episode_reward_mean=rew,
training_iteration=int(t))
class EarlyStoppingSuite(unittest.TestCase):
def basicSetup(self, rule):
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
for i in range(10):
self.assertEqual(
rule.on_trial_result(None, t1, result(i, i * 100)),
TrialScheduler.CONTINUE)
for i in range(5):
self.assertEqual(
rule.on_trial_result(None, t2, result(i, 450)),
TrialScheduler.CONTINUE)
return t1, t2
def testMedianStoppingConstantPerf(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t2, result(5, 450)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t2, result(6, 0)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t2, result(10, 450)),
TrialScheduler.STOP)
def testMedianStoppingOnCompleteOnly(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
self.assertEqual(
rule.on_trial_result(None, t2, result(100, 0)),
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t1, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t2, result(101, 0)),
TrialScheduler.STOP)
def testMedianStoppingGracePeriod(self):
rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 10)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(2, 10)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t2, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
def testMedianStoppingUsesMedian(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 260)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(2, 260)),
TrialScheduler.STOP)
def testMedianStoppingSoftStop(self):
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1, hard_stop=False)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 260)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(2, 260)),
TrialScheduler.PAUSE)
def testAlternateMetrics(self):
def result2(t, rew):
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1,
time_attr='training_iteration', reward_attr='neg_mean_loss')
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
for i in range(10):
self.assertEqual(
rule.on_trial_result(None, t1, result2(i, i * 100)),
TrialScheduler.CONTINUE)
for i in range(5):
self.assertEqual(
rule.on_trial_result(None, t2, result2(i, 450)),
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t1, result2(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t2, result2(5, 450)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t2, result2(6, 0)),
TrialScheduler.CONTINUE)
class _MockTrialRunner():
def __init__(self, scheduler):
self._scheduler_alg = scheduler
self.trials = []
def process_action(self, trial, action):
if action == TrialScheduler.CONTINUE:
pass
elif action == TrialScheduler.PAUSE:
self._pause_trial(trial)
elif action == TrialScheduler.STOP:
trial.stop()
def stop_trial(self, trial):
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
return
elif trial.status in [Trial.PENDING, Trial.PAUSED]:
self._scheduler_alg.on_trial_remove(self, trial)
else:
self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))
def add_trial(self, trial):
self.trials.append(trial)
self._scheduler_alg.on_trial_add(self, trial)
def get_trials(self):
return self.trials
def has_resources(self, resources):
return True
def _pause_trial(self, trial):
trial.status = Trial.PAUSED
def _launch_trial(self, trial):
trial.status = Trial.RUNNING
class HyperbandSuite(unittest.TestCase):
def schedulerSetup(self, num_trials):
"""Setup a scheduler and Runner with max Iter = 9
Bracketing is placed as follows:
(5, 81);
(8, 27) -> (3, 81);
(15, 9) -> (5, 27) -> (2, 81);
(34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
sched = HyperBandScheduler()
for i in range(num_trials):
t = Trial("__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner(sched)
return sched, runner
def default_statistics(self):
"""Default statistics for HyperBand"""
sched = HyperBandScheduler()
res = {
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
for s in range(sched._s_max_1)
}
res["max_trials"] = sum(v["n"] for v in res.values())
res["brack_count"] = sched._s_max_1
res["s_max"] = sched._s_max_1 - 1
return res
def downscale(self, n, sched):
return int(np.ceil(n / sched._eta))
def basicSetup(self):
"""Setup and verify full band.
"""
stats = self.default_statistics()
sched, _ = self.schedulerSetup(stats["max_trials"])
self.assertEqual(len(sched._hyperbands), 1)
self.assertEqual(sched._cur_band_filled(), True)
filled_band = sched._hyperbands[0]
for bracket in filled_band:
self.assertEqual(bracket.filled(), True)
return sched
def advancedSetup(self):
sched = self.basicSetup()
for i in range(4):
t = Trial("__fake")
sched.on_trial_add(None, t)
self.assertEqual(sched._cur_band_filled(), False)
unfilled_band = sched._hyperbands[-1]
self.assertEqual(len(unfilled_band), 2)
bracket = unfilled_band[-1]
self.assertEqual(bracket.filled(), False)
self.assertEqual(len(bracket.current_trials()), 7)
return sched
def testConfigSameEta(self):
sched = HyperBandScheduler()
i = 0
while not sched._cur_band_filled():
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertEqual(sched._hyperbands[0][0]._n, 5)
self.assertEqual(sched._hyperbands[0][0]._r, 81)
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
self.assertEqual(sched._hyperbands[0][-1]._r, 1)
sched = HyperBandScheduler(max_t=810)
i = 0
while not sched._cur_band_filled():
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertEqual(sched._hyperbands[0][0]._n, 5)
self.assertEqual(sched._hyperbands[0][0]._r, 810)
self.assertEqual(sched._hyperbands[0][-1]._n, 81)
self.assertEqual(sched._hyperbands[0][-1]._r, 10)
def testConfigSameEtaSmall(self):
sched = HyperBandScheduler(max_t=1)
i = 0
while len(sched._hyperbands) < 2:
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
self.assertTrue(all(v is None for v in sched._hyperbands[0][1:]))
def testSuccessiveHalving(self):
"""Setup full band, then iterate through last bracket (n=81)
to make sure successive halving is correct."""
stats = self.default_statistics()
sched, mock_runner = self.schedulerSetup(stats["max_trials"])
big_bracket = sched._state["bracket"]
cur_units = stats[str(stats["s_max"])]["r"]
# The last bracket will downscale 4 times
for x in range(stats["brack_count"] - 1):
trials = big_bracket.current_trials()
current_length = len(trials)
for trl in trials:
mock_runner._launch_trial(trl)
# Provides results from 0 to 8 in order, keeping last one running
for i, trl in enumerate(trials):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
if i < current_length - 1:
self.assertEqual(action, TrialScheduler.PAUSE)
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
new_length = len(big_bracket.current_trials())
self.assertEqual(new_length, self.downscale(current_length, sched))
cur_units += int(cur_units * sched._eta)
self.assertEqual(len(big_bracket.current_trials()), 1)
def testHalvingStop(self):
stats = self.default_statistics()
num_trials = stats[str(0)]["n"] + stats[str(1)]["n"]
sched, mock_runner = self.schedulerSetup(num_trials)
big_bracket = sched._state["bracket"]
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# # Provides result in reverse order, killing the last one
cur_units = stats[str(1)]["r"]
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.STOP)
def testContinueLastOne(self):
stats = self.default_statistics()
num_trials = stats[str(0)]["n"]
sched, mock_runner = self.schedulerSetup(num_trials)
big_bracket = sched._state["bracket"]
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# # Provides result in reverse order, killing the last one
cur_units = stats[str(0)]["r"]
for i, trl in enumerate(big_bracket.current_trials()):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.CONTINUE)
for x in range(100):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units + x, 10))
self.assertEqual(action, TrialScheduler.CONTINUE)
def testTrialErrored(self):
"""If a trial errored, make sure successive halving still happens"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + 3
sched, mock_runner = self.schedulerSetup(trial_count)
t1, t2, t3 = sched._state["bracket"].current_trials()
for t in [t1, t2, t3]:
mock_runner._launch_trial(t)
sched.on_trial_error(mock_runner, t3)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + stats[str(1)]["n"]
sched, mock_runner = self.schedulerSetup(trial_count)
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_error(mock_runner, trials[-1])
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testTrialEndedEarly(self):
"""Check successive halving happened even when one trial failed"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + 3
sched, mock_runner = self.schedulerSetup(trial_count)
t1, t2, t3 = sched._state["bracket"].current_trials()
for t in [t1, t2, t3]:
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, t3, result(1, 12))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial failed"""
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + stats[str(1)]["n"]
sched, mock_runner = self.schedulerSetup(trial_count)
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testAddAfterHalving(self):
stats = self.default_statistics()
trial_count = stats[str(0)]["n"] + 1
sched, mock_runner = self.schedulerSetup(trial_count)
bracket_trials = sched._state["bracket"].current_trials()
init_units = stats[str(1)]["r"]
for t in bracket_trials:
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
action = sched.on_trial_result(
mock_runner, t, result(init_units, i))
self.assertEqual(action, TrialScheduler.CONTINUE)
t = Trial("__fake")
sched.on_trial_add(None, t)
mock_runner._launch_trial(t)
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
# Make sure that newly added trial gets fair computation (not just 1)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
new_units = init_units + int(init_units * sched._eta)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
def testAlternateMetrics(self):
"""Checking that alternate metrics will pass."""
def result2(t, rew):
return TrainingResult(time_total_s=t, neg_mean_loss=rew)
sched = HyperBandScheduler(
time_attr='time_total_s', reward_attr='neg_mean_loss')
stats = self.default_statistics()
for i in range(stats["max_trials"]):
t = Trial("__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner(sched)
big_bracket = sched._hyperbands[0][-1]
for trl in big_bracket.current_trials():
runner._launch_trial(trl)
current_length = len(big_bracket.current_trials())
# Provides results from 0 to 8 in order, keeping the last one running
for i, trl in enumerate(big_bracket.current_trials()):
action = sched.on_trial_result(runner, trl, result2(1, i))
runner.process_action(trl, action)
new_length = len(big_bracket.current_trials())
self.assertEqual(action, TrialScheduler.CONTINUE)
self.assertEqual(new_length, self.downscale(current_length, sched))
def testJumpingTime(self):
sched, mock_runner = self.schedulerSetup(81)
big_bracket = sched._hyperbands[0][-1]
for trl in big_bracket.current_trials():
mock_runner._launch_trial(trl)
# Provides results from 0 to 8 in order, keeping the last one running
main_trials = big_bracket.current_trials()[:-1]
jump = big_bracket.current_trials()[-1]
for i, trl in enumerate(main_trials):
action = sched.on_trial_result(mock_runner, trl, result(1, i))
mock_runner.process_action(trl, action)
action = sched.on_trial_result(mock_runner, jump, result(4, i))
self.assertEqual(action, TrialScheduler.PAUSE)
current_length = len(big_bracket.current_trials())
self.assertLess(current_length, 27)
def testRemove(self):
"""Test with 4: start 1, remove 1 pending, add 2, remove 1 pending"""
sched, runner = self.schedulerSetup(4)
trials = sorted(list(sched._trial_info), key=lambda t: t.trial_id)
runner._launch_trial(trials[0])
sched.on_trial_result(runner, trials[0], result(1, 5))
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[1].status, Trial.PENDING)
bracket, _ = sched._trial_info[trials[1]]
self.assertTrue(trials[1] in bracket._live_trials)
sched.on_trial_remove(runner, trials[1])
self.assertFalse(trials[1] in bracket._live_trials)
for i in range(2):
trial = Trial("__fake")
sched.on_trial_add(None, trial)
bracket, _ = sched._trial_info[trial]
self.assertTrue(trial in bracket._live_trials)
sched.on_trial_remove(runner, trial) # where trial is not running
self.assertFalse(trial in bracket._live_trials)
class _MockTrial(Trial):
def __init__(self, i, config):
self.trainable_name = "trial_{}".format(i)
self.config = config
self.experiment_tag = "tag"
self.logger_running = False
self.restored_checkpoint = None
self.resources = Resources(1, 0)
def checkpoint(self, to_object_store=False):
return self.trainable_name
def start(self, checkpoint=None):
self.logger_running = True
self.restored_checkpoint = checkpoint
def stop(self, stop_logger=False):
if stop_logger:
self.logger_running = False
class PopulationBasedTestingSuite(unittest.TestCase):
def basicSetup(self, resample_prob=0.0, explore=None):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
perturbation_interval=10,
resample_probability=resample_prob,
hyperparam_mutations={
"id_factor": [100],
"float_factor": lambda c: 100.0,
"int_factor": lambda c: 10,
},
custom_explore_fn=explore)
runner = _MockTrialRunner(pbt)
for i in range(5):
trial = _MockTrial(
i,
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
"int_factor": 10})
runner.add_trial(trial)
trial.status = Trial.RUNNING
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
pbt.reset_stats()
return pbt, runner
def testCheckpointsMostPromisingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# checkpoint: both past interval and upper quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 1)
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(30, 201)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 2)
# not upper quantile any more
self.assertEqual(
pbt.on_trial_result(runner, trials[4], result(30, 199)),
TrialScheduler.CONTINUE)
self.assertEqual(pbt._num_checkpoints, 2)
self.assertEqual(pbt._num_perturbations, 0)
def testPerturbsLowPerformingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
# no perturbation: haven't hit next perturbation interval
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
# perturb since it's lower quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertTrue("@perturbed" in trials[0].experiment_tag)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(pbt._num_perturbations, 1)
# also perturbed
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(20, 40)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt._num_perturbations, 2)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertTrue("@perturbed" in trials[2].experiment_tag)
def testPerturbWithoutResample(self):
pbt, runner = self.basicSetup(resample_prob=0.0)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertIn(trials[0].config["id_factor"], [3, 4])
self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
self.assertEqual(type(trials[0].config["float_factor"]), float)
self.assertIn(trials[0].config["int_factor"], [8, 12])
self.assertEqual(type(trials[0].config["int_factor"]), int)
self.assertEqual(trials[0].config["const_factor"], 3)
def testPerturbWithResample(self):
pbt, runner = self.basicSetup(resample_prob=1.0)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(trials[0].config["id_factor"], 100)
self.assertEqual(trials[0].config["float_factor"], 100.0)
self.assertEqual(type(trials[0].config["float_factor"]), float)
self.assertEqual(trials[0].config["int_factor"], 10)
self.assertEqual(type(trials[0].config["int_factor"]), int)
self.assertEqual(trials[0].config["const_factor"], 3)
def testYieldsTimeToOtherTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
trials[0].status = Trial.PENDING # simulate not enough resources
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
TrialScheduler.PAUSE)
self.assertEqual(
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
def testSchedulesMostBehindTrialToRun(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
pbt.on_trial_result(runner, trials[0], result(800, 1000))
pbt.on_trial_result(runner, trials[1], result(700, 1001))
pbt.on_trial_result(runner, trials[2], result(600, 1002))
pbt.on_trial_result(runner, trials[3], result(500, 1003))
pbt.on_trial_result(runner, trials[4], result(700, 1004))
self.assertEqual(pbt.choose_trial_to_run(runner), None)
for i in range(5):
trials[i].status = Trial.PENDING
self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])
def testPerturbationResetsLastPerturbTime(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
pbt.on_trial_result(runner, trials[0], result(10000, 1005))
pbt.on_trial_result(runner, trials[1], result(10000, 1004))
pbt.on_trial_result(runner, trials[2], result(600, 1003))
self.assertEqual(pbt._num_perturbations, 0)
pbt.on_trial_result(runner, trials[3], result(500, 1002))
self.assertEqual(pbt._num_perturbations, 1)
pbt.on_trial_result(runner, trials[3], result(600, 100))
self.assertEqual(pbt._num_perturbations, 1)
pbt.on_trial_result(runner, trials[3], result(11000, 100))
self.assertEqual(pbt._num_perturbations, 2)
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42
new_config["float_factor"] = 43
return new_config
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(trials[0].config["id_factor"], 42)
self.assertEqual(trials[0].config["float_factor"], 43)
if __name__ == "__main__":
unittest.main(verbosity=2)
+103
View File
@@ -0,0 +1,103 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import socket
import ray
from ray.rllib import _register_all
from ray.tune.trial import Trial, Resources
from ray.tune.web_server import TuneClient
from ray.tune.trial_runner import TrialRunner
def get_valid_port():
port = 4321
while True:
try:
print("Trying port", port)
port_test_socket = socket.socket()
port_test_socket.bind(("127.0.0.1", port))
port_test_socket.close()
break
except socket.error:
port += 1
return port
class TuneServerSuite(unittest.TestCase):
def basicSetup(self):
ray.init(num_cpus=4, num_gpus=1)
port = get_valid_port()
self.runner = TrialRunner(
launch_web_server=True, server_port=port)
runner = self.runner
kwargs = {
"stopping_criterion": {"training_iteration": 3},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
client = TuneClient("localhost:{}".format(port))
return runner, client
def tearDown(self):
print("Tearing down....")
try:
self.runner._server.shutdown()
self.runner = None
except Exception as e:
print(e)
ray.worker.cleanup()
_register_all()
def testAddTrial(self):
runner, client = self.basicSetup()
for i in range(3):
runner.step()
spec = {
"run": "__fake",
"stop": {"training_iteration": 3},
"resources": dict(cpu=1, gpu=1),
}
client.add_trial("test", spec)
runner.step()
all_trials = client.get_all_trials()["trials"]
runner.step()
self.assertEqual(len(all_trials), 3)
def testGetTrials(self):
runner, client = self.basicSetup()
for i in range(3):
runner.step()
all_trials = client.get_all_trials()["trials"]
self.assertEqual(len(all_trials), 2)
tid = all_trials[0]["id"]
client.get_trial(tid)
runner.step()
self.assertEqual(len(all_trials), 2)
def testStopTrial(self):
"""Check if Stop Trial works"""
runner, client = self.basicSetup()
for i in range(2):
runner.step()
all_trials = client.get_all_trials()["trials"]
self.assertEqual(
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1)
tid = [t for t in all_trials if t["status"] == Trial.RUNNING][0]["id"]
client.stop_trial(tid)
runner.step()
all_trials = client.get_all_trials()["trials"]
self.assertEqual(
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
if __name__ == "__main__":
unittest.main(verbosity=2)
+4 -3
View File
@@ -135,7 +135,8 @@ class Trainable(object):
time_total_s=self._time_total,
neg_mean_loss=neg_loss,
pid=os.getpid(),
hostname=os.uname()[1])
hostname=os.uname()[1],
config=self.config)
self._result_logger.on_result(result)
@@ -185,8 +186,8 @@ class Trainable(object):
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
})
print("Saving checkpoint to object store, {} bytes".format(
len(compressed)))
if len(compressed) > 10e6: # getting pretty large
print("Checkpoint size is {} bytes".format(len(compressed)))
f.write(compressed)
shutil.rmtree(tmpdir)
+35 -16
View File
@@ -96,7 +96,7 @@ class Trial(object):
"Stopping condition key `{}` must be one of {}".format(
k, TrainingResult._fields))
# Immutable config
# Trial config
self.trainable_name = trainable_name
self.config = config or {}
self.local_dir = local_dir
@@ -105,6 +105,7 @@ class Trial(object):
self.stopping_criterion = stopping_criterion or {}
self.checkpoint_freq = checkpoint_freq
self.upload_dir = upload_dir
self.verbose = True
# Local trial state that is updated during the run
self.last_result = None
@@ -117,16 +118,22 @@ class Trial(object):
self.result_logger = None
self.last_debug = 0
self.trial_id = binary_to_hex(random_string())[:8]
self.error_file = None
def start(self):
def start(self, checkpoint_obj=None):
"""Starts this trial.
If an error is encountered when starting the trial, an exception will
be thrown.
Args:
checkpoint_obj (obj): Optional checkpoint to resume from.
"""
self._setup_runner()
if self._checkpoint_path:
if checkpoint_obj:
self.restore_from_obj(checkpoint_obj)
elif self._checkpoint_path:
self.restore_from_path(self._checkpoint_path)
elif self._checkpoint_obj:
self.restore_from_obj(self._checkpoint_obj)
@@ -155,6 +162,7 @@ class Trial(object):
self.logdir, "error_{}.txt".format(date_str()))
with open(error_file, "w") as f:
f.write(error_msg)
self.error_file = error_file
if self.runner:
stop_tasks = []
stop_tasks.append(self.runner.stop.remote())
@@ -163,9 +171,6 @@ class Trial(object):
# TODO(ekl) seems like wait hangs when killing actors
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
if unfinished:
print(("Stopping %s Actor timed out, "
"but moving on...") % self)
except Exception:
print("Error stopping runner:", traceback.format_exc())
self.status = Trial.ERROR
@@ -230,7 +235,7 @@ class Trial(object):
"""Returns a progress message for printing out to the console."""
if self.last_result is None:
return self.status
return self._status_string()
def location_string(hostname, pid):
if hostname == os.uname()[1]:
@@ -240,7 +245,8 @@ class Trial(object):
pieces = [
'{} [{}]'.format(
self.status, location_string(
self._status_string(),
location_string(
self.last_result.hostname, self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)),
'{} ts'.format(int(self.last_result.timesteps_total))]
@@ -259,6 +265,11 @@ class Trial(object):
return ', '.join(pieces)
def _status_string(self):
return "{}{}".format(
self.status,
" => {}".format(self.error_file) if self.error_file else "")
def checkpoint(self, to_object_store=False):
"""Checkpoints the state of this trial.
@@ -276,7 +287,8 @@ class Trial(object):
self._checkpoint_path = path
self._checkpoint_obj = obj
print("Saved checkpoint to:", path or obj)
if self.verbose:
print("Saved checkpoint for {} to {}".format(self, path or obj))
return path or obj
def restore_from_path(self, path):
@@ -310,7 +322,9 @@ class Trial(object):
def update_last_result(self, result, terminate=False):
if terminate:
result = result._replace(done=True)
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
if terminate or (
self.verbose and
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
@@ -348,12 +362,17 @@ class Trial(object):
config=self.config, registry=get_registry(),
logger_creator=logger_creator)
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
def set_verbose(self, verbose):
self.verbose = verbose
Truncates to MAX_LEN_IDENTIFIER (default is 130) to avoid problems
when creating logging directories.
"""
def is_finished(self):
return self.status in [Trial.TERMINATED, Trial.ERROR]
def __repr__(self):
return str(self)
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
if "env" in self.config:
identifier = "{}_{}".format(
self.trainable_name, self.config["env"])
@@ -361,4 +380,4 @@ class Trial(object):
identifier = self.trainable_name
if self.experiment_tag:
identifier += "_" + self.experiment_tag
return identifier[:MAX_LEN_IDENTIFIER]
return identifier
+2 -1
View File
@@ -31,7 +31,7 @@ def _make_scheduler(args):
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT):
server_port=TuneServer.DEFAULT_PORT, verbose=True):
# Make sure rllib agents are registered
from ray import rllib # noqa # pylint: disable=unused-import
@@ -44,6 +44,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
trial.set_verbose(verbose)
runner.add_trial(trial)
print(runner.debug_string(max_debug=99999))