[tune] Tune Documentation and expose better API (#1681)

This commit is contained in:
Richard Liaw
2018-03-19 12:55:10 -07:00
committed by GitHub
parent 7b493aa4a1
commit 23954e7ce2
9 changed files with 315 additions and 89 deletions
+2 -1
View File
@@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.tune.error import TuneError
from ray.tune.tune import run_experiments
from ray.tune.tune import run_experiments, Experiment
from ray.tune.registry import register_env, register_trainable
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
@@ -18,4 +18,5 @@ __all__ = [
"register_env",
"register_trainable",
"run_experiments",
"Experiment"
]
+12 -13
View File
@@ -13,7 +13,7 @@ import numpy as np
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
run_experiments
run_experiments, Experiment
from ray.tune.hyperband import HyperBandScheduler
@@ -62,15 +62,14 @@ if __name__ == "__main__":
time_attr="timesteps_total", reward_attr="episode_reward_mean",
max_t=100)
run_experiments({
"hyperband_test": {
"run": "my_class",
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
"repeat": 20,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
}, scheduler=hyperband)
exp = Experiment(
name="hyperband_test",
run="my_class",
repeat=20,
stop={"training_iteration": 1 if args.smoke_test else 99999},
config={
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random())
})
run_experiments(exp, scheduler=hyperband)
+56
View File
@@ -0,0 +1,56 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.tune.variant_generator import generate_trials
from ray.tune.result import DEFAULT_RESULTS_DIR
class Experiment(object):
    """Tracks experiment specifications.

    The constructor folds its arguments into a single spec dict and passes
    that dict, together with the experiment name, to ``generate_trials``.
    The result is kept on the instance and exposed through ``trials()``.

    Parameters:
        name (str): Name of experiment.
        run (str): The algorithm or model to train. This may refer to the
            name of a built-in algorithm (e.g. RLLib's DQN or PPO), or a
            user-defined trainable function or class
            registered in the tune registry.
        stop (dict): The stopping criteria. The keys may be any field in
            TrainingResult, whichever is reached first. Defaults to
            empty dict.
        config (dict): Algorithm-specific configuration
            (e.g. env, hyperparams). Defaults to empty dict.
        resources (dict): Machine resources to allocate per trial,
            e.g. ``{"cpu": 64, "gpu": 8}``. Note that GPUs will not be
            assigned unless you specify them here. Defaults to 1 CPU and 0
            GPUs.
        repeat (int): Number of times to repeat each trial. Defaults to 1.
        local_dir (str): Local dir to save training results to.
            Defaults to ``~/ray_results``.
        upload_dir (str): Optional URI to sync training results
            to (e.g. ``s3://bucket``).
        checkpoint_freq (int): How many training iterations between
            checkpoints. A value of 0 (default) disables checkpointing.
        max_failures (int): Try to recover a trial from its last
            checkpoint at least this many times. Only applies if
            checkpointing is enabled. Defaults to 3.
    """

    def __init__(self, name, run, stop=None, config=None,
                 resources=None, repeat=1, local_dir=None,
                 upload_dir="", checkpoint_freq=0, max_failures=3):
        # Any falsy argument (None, an empty dict, an empty string) is
        # replaced by its default, so callers may omit these entirely.
        stop_criteria = stop if stop else {}
        trial_config = config if config else {}
        trial_resources = resources if resources else {"cpu": 1, "gpu": 0}
        results_dir = local_dir if local_dir else DEFAULT_RESULTS_DIR

        spec = {
            "run": run,
            "stop": stop_criteria,
            "config": trial_config,
            "resources": trial_resources,
            "repeat": repeat,
            "local_dir": results_dir,
            "upload_dir": upload_dir,
            "checkpoint_freq": checkpoint_freq,
            "max_failures": max_failures
        }
        self._trials = generate_trials(spec, name)

    def trials(self):
        """Yield, one at a time, every item held in ``self._trials``."""
        # NOTE(review): if generate_trials returns a one-shot iterator, a
        # second call to trials() would yield nothing -- confirm.
        for generated in self._trials:
            yield generated
+74
View File
@@ -13,6 +13,7 @@ from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import generate_trials, grid_search, \
@@ -203,6 +204,79 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.last_result.timesteps_total, 99)
class RunExperimentTest(unittest.TestCase):
    """Exercises run_experiments with each accepted input form.

    The three cases pass a dict of specs, a single Experiment, and a list
    of Experiments. Each registers the same 100-step trainable as "f1" and
    checks that every returned trial terminated with timesteps_total == 99.
    """

    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.worker.cleanup()
        _register_all()  # re-register the evicted objects

    def testDict(self):
        # Dict form: experiment name -> spec dict.
        def train(config, reporter):
            for step in range(100):
                reporter(timesteps_total=step)

        register_trainable("f1", train)
        specs = {
            exp_name: {
                "run": "f1",
                "config": {
                    "script_min_iter_time_s": 0
                }
            }
            for exp_name in ("foo", "bar")
        }
        finished = run_experiments(specs)
        for trial in finished:
            self.assertEqual(trial.status, Trial.TERMINATED)
            self.assertEqual(trial.last_result.timesteps_total, 99)

    def testExperiment(self):
        # Single Experiment object; exactly one trial is expected back.
        def train(config, reporter):
            for step in range(100):
                reporter(timesteps_total=step)

        register_trainable("f1", train)
        exp1 = Experiment(
            name="foo",
            run="f1",
            config={
                "script_min_iter_time_s": 0
            })
        [trial] = run_experiments(exp1)
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result.timesteps_total, 99)

    def testExperimentList(self):
        # List form: several Experiment objects run together.
        def train(config, reporter):
            for step in range(100):
                reporter(timesteps_total=step)

        register_trainable("f1", train)
        experiment_list = [
            Experiment(
                name=exp_name,
                run="f1",
                config={
                    "script_min_iter_time_s": 0
                })
            for exp_name in ("foo", "bar")
        ]
        finished = run_experiments(experiment_list)
        for trial in finished:
            self.assertEqual(trial.status, Trial.TERMINATED)
            self.assertEqual(trial.last_result.timesteps_total, 99)
class VariantGeneratorTest(unittest.TestCase):
def testParseToTrials(self):
trials = generate_trials({
+27 -2
View File
@@ -14,6 +14,7 @@ from ray.tune.trial_runner import TrialRunner
from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.web_server import TuneServer
from ray.tune.variant_generator import generate_trials
from ray.tune.experiment import Experiment
_SCHEDULERS = {
@@ -35,6 +36,18 @@ def _make_scheduler(args):
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT, verbose=True):
"""Tunes experiments.
Args:
experiments (Experiment | list | dict): Experiments to run.
scheduler (TrialScheduler): Scheduler for executing
the experiment. Choose among FIFO (default), MedianStopping,
AsyncHyperBand, or HyperBand.
with_server (bool): Starts a background Tune server. Needed for
using the Client API.
server_port (int): Port number for launching TuneServer.
verbose (bool): How much output should be printed for each trial.
"""
# Make sure rllib agents are registered
from ray import rllib # noqa # pylint: disable=unused-import
@@ -45,10 +58,22 @@ def run_experiments(experiments, scheduler=None, with_server=False,
runner = TrialRunner(
scheduler, launch_web_server=with_server, server_port=server_port)
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
if type(experiments) is dict:
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
trial.set_verbose(verbose)
runner.add_trial(trial)
elif (type(experiments) is list and
all(isinstance(exp, Experiment) for exp in experiments)):
for experiment in experiments:
for trial in experiment.trials():
trial.set_verbose(verbose)
runner.add_trial(trial)
elif isinstance(experiments, Experiment):
for trial in experiments.trials():
trial.set_verbose(verbose)
runner.add_trial(trial)
print(runner.debug_string(max_debug=99999))
last_debug = 0