Add large scale regression test for RLlib (#6093)

This commit is contained in:
Eric Liang
2019-11-13 12:22:55 -08:00
committed by GitHub
parent f3f86385d6
commit b924299833
4 changed files with 210 additions and 37 deletions
+61 -35
View File
@@ -213,41 +213,50 @@ def run(run_or_experiment,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
ray_auto_init=ray_auto_init)
experiment = run_or_experiment
if not isinstance(run_or_experiment, Experiment):
run_identifier = Experiment._register_if_needed(run_or_experiment)
experiment = Experiment(
name=name,
run=run_identifier,
stop=stop,
config=config,
resources_per_trial=resources_per_trial,
num_samples=num_samples,
local_dir=local_dir,
upload_dir=upload_dir,
sync_to_driver=sync_to_driver,
trial_name_creator=trial_name_creator,
loggers=loggers,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
keep_checkpoints_num=keep_checkpoints_num,
checkpoint_score_attr=checkpoint_score_attr,
export_formats=export_formats,
max_failures=max_failures,
restore=restore,
sync_function=sync_function)
if isinstance(run_or_experiment, list):
experiments = run_or_experiment
else:
experiments = [run_or_experiment]
if len(experiments) > 1:
logger.info(
"Running multiple concurrent experiments is experimental and may "
"not work with certain features.")
for i, exp in enumerate(experiments):
if not isinstance(exp, Experiment):
run_identifier = Experiment._register_if_needed(exp)
experiments[i] = Experiment(
name=name,
run=run_identifier,
stop=stop,
config=config,
resources_per_trial=resources_per_trial,
num_samples=num_samples,
local_dir=local_dir,
upload_dir=upload_dir,
sync_to_driver=sync_to_driver,
trial_name_creator=trial_name_creator,
loggers=loggers,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
keep_checkpoints_num=keep_checkpoints_num,
checkpoint_score_attr=checkpoint_score_attr,
export_formats=export_formats,
max_failures=max_failures,
restore=restore,
sync_function=sync_function)
else:
logger.debug("Ignoring some parameters passed into tune.run.")
if sync_to_cloud:
assert experiment.remote_checkpoint_dir, (
"Need `upload_dir` if `sync_to_cloud` given.")
for exp in experiments:
assert exp.remote_checkpoint_dir, (
"Need `upload_dir` if `sync_to_cloud` given.")
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
scheduler=scheduler or FIFOScheduler(),
local_checkpoint_dir=experiment.checkpoint_dir,
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
local_checkpoint_dir=experiments[0].checkpoint_dir,
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
sync_to_cloud=sync_to_cloud,
checkpoint_period=global_checkpoint_period,
resume=resume,
@@ -256,7 +265,8 @@ def run(run_or_experiment,
verbose=bool(verbose > 1),
trial_executor=trial_executor)
runner.add_experiment(experiment)
for exp in experiments:
runner.add_experiment(exp)
if IS_NOTEBOOK:
reporter = JupyterNotebookReporter(overwrite=verbose < 2)
@@ -269,7 +279,7 @@ def run(run_or_experiment,
dict) and "gpu" in resources_per_trial:
# "gpu" is manually set.
pass
elif _check_default_resources_override(experiment.run_identifier):
elif _check_default_resources_override(experiments[0].run_identifier):
# "default_resources" is manually overridden.
pass
else:
@@ -329,7 +339,8 @@ def run_experiments(experiments,
queue_trials=False,
reuse_actors=False,
trial_executor=None,
raise_on_failed_trial=True):
raise_on_failed_trial=True,
concurrent=False):
"""Runs and blocks until all trials finish.
Examples:
@@ -357,10 +368,9 @@ def run_experiments(experiments,
# and it conducts the implicit registration.
experiments = convert_to_experiment_list(experiments)
trials = []
for exp in experiments:
trials += run(
exp,
if concurrent:
return run(
experiments,
search_alg=search_alg,
scheduler=scheduler,
with_server=with_server,
@@ -372,4 +382,20 @@ def run_experiments(experiments,
trial_executor=trial_executor,
raise_on_failed_trial=raise_on_failed_trial,
return_trials=True)
return trials
else:
trials = []
for exp in experiments:
trials += run(
exp,
search_alg=search_alg,
scheduler=scheduler,
with_server=with_server,
server_port=server_port,
verbose=verbose,
resume=resume,
queue_trials=queue_trials,
reuse_actors=reuse_actors,
trial_executor=trial_executor,
raise_on_failed_trial=raise_on_failed_trial,
return_trials=True)
return trials