mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 01:55:25 +08:00
[tune] Tune Documentation and expose better API (#1681)
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments
|
||||
from ray.tune.tune import run_experiments, Experiment
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
@@ -18,4 +18,5 @@ __all__ = [
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run_experiments",
|
||||
"Experiment"
|
||||
]
|
||||
|
||||
@@ -13,7 +13,7 @@ import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, TrainingResult, register_trainable, \
|
||||
run_experiments
|
||||
run_experiments, Experiment
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
|
||||
|
||||
@@ -62,15 +62,14 @@ if __name__ == "__main__":
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
max_t=100)
|
||||
|
||||
run_experiments({
|
||||
"hyperband_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
|
||||
"repeat": 20,
|
||||
"resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random()),
|
||||
},
|
||||
}
|
||||
}, scheduler=hyperband)
|
||||
exp = Experiment(
|
||||
name="hyperband_test",
|
||||
run="my_class",
|
||||
repeat=20,
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
config={
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random())
|
||||
})
|
||||
|
||||
run_experiments(exp, scheduler=hyperband)
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
|
||||
class Experiment(object):
|
||||
"""Tracks experiment specifications.
|
||||
|
||||
Parameters:
|
||||
name (str): Name of experiment.
|
||||
run (str): The algorithm or model to train. This may refer to the
|
||||
name of a built-on algorithm (e.g. RLLib's DQN or PPO), or a
|
||||
user-defined trainable function or class
|
||||
registered in the tune registry.
|
||||
stop (dict): The stopping criteria. The keys may be any field in
|
||||
TrainingResult, whichever is reached first. Defaults to
|
||||
empty dict.
|
||||
config (dict): Algorithm-specific configuration
|
||||
(e.g. env, hyperparams). Defaults to empty dict.
|
||||
resources (dict): Machine resources to allocate per trial,
|
||||
e.g. ``{"cpu": 64, "gpu": 8}``. Note that GPUs will not be
|
||||
assigned unless you specify them here. Defaults to 1 CPU and 0
|
||||
GPUs.
|
||||
repeat (int): Number of times to repeat each trial. Defaults to 1.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
upload_dir (str): Optional URI to sync training results
|
||||
to (e.g. ``s3://bucket``).
|
||||
checkpoint_freq (int): How many training iterations between
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
max_failures (int): Try to recover a trial from its last
|
||||
checkpoint at least this many times. Only applies if
|
||||
checkpointing is enabled. Defaults to 3.
|
||||
"""
|
||||
def __init__(self, name, run, stop=None, config=None,
|
||||
resources=None, repeat=1, local_dir=None,
|
||||
upload_dir="", checkpoint_freq=0, max_failures=3):
|
||||
spec = {
|
||||
"run": run,
|
||||
"stop": stop or {},
|
||||
"config": config or {},
|
||||
"resources": resources or {"cpu": 1, "gpu": 0},
|
||||
"repeat": repeat,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"upload_dir": upload_dir,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"max_failures": max_failures
|
||||
}
|
||||
self._trials = generate_trials(spec, name)
|
||||
|
||||
def trials(self):
|
||||
for trial in self._trials:
|
||||
yield trial
|
||||
@@ -13,6 +13,7 @@ from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.variant_generator import generate_trials, grid_search, \
|
||||
@@ -203,6 +204,79 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testDict(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
register_trainable("f1", train)
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
},
|
||||
"bar": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
}
|
||||
})
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
def testExperiment(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
})
|
||||
[trial] = run_experiments(exp1)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
def testExperimentList(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
})
|
||||
exp2 = Experiment(**{
|
||||
"name": "bar",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
})
|
||||
trials = run_experiments([exp1, exp2])
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
|
||||
class VariantGeneratorTest(unittest.TestCase):
|
||||
def testParseToTrials(self):
|
||||
trials = generate_trials({
|
||||
|
||||
+27
-2
@@ -14,6 +14,7 @@ from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
from ray.tune.experiment import Experiment
|
||||
|
||||
|
||||
_SCHEDULERS = {
|
||||
@@ -35,6 +36,18 @@ def _make_scheduler(args):
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
"""Tunes experiments.
|
||||
|
||||
Args:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
scheduler (TrialScheduler): Scheduler for executing
|
||||
the experiment. Choose among FIFO (default), MedianStopping,
|
||||
AsyncHyperBand, or HyperBand.
|
||||
with_server (bool): Starts a background Tune server. Needed for
|
||||
using the Client API.
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (bool): How much output should be printed for each trial.
|
||||
"""
|
||||
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa # pylint: disable=unused-import
|
||||
@@ -45,10 +58,22 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
runner = TrialRunner(
|
||||
scheduler, launch_web_server=with_server, server_port=server_port)
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
if type(experiments) is dict:
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
trial.set_verbose(verbose)
|
||||
runner.add_trial(trial)
|
||||
elif (type(experiments) is list and
|
||||
all(isinstance(exp, Experiment) for exp in experiments)):
|
||||
for experiment in experiments:
|
||||
for trial in experiment.trials():
|
||||
trial.set_verbose(verbose)
|
||||
runner.add_trial(trial)
|
||||
elif isinstance(experiments, Experiment):
|
||||
for trial in experiments.trials():
|
||||
trial.set_verbose(verbose)
|
||||
runner.add_trial(trial)
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
last_debug = 0
|
||||
|
||||
Reference in New Issue
Block a user