mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 05:43:03 +08:00
Add large scale regression test for RLlib (#6093)
This commit is contained in:
+61
-35
@@ -213,41 +213,50 @@ def run(run_or_experiment,
|
||||
queue_trials=queue_trials,
|
||||
reuse_actors=reuse_actors,
|
||||
ray_auto_init=ray_auto_init)
|
||||
experiment = run_or_experiment
|
||||
if not isinstance(run_or_experiment, Experiment):
|
||||
run_identifier = Experiment._register_if_needed(run_or_experiment)
|
||||
experiment = Experiment(
|
||||
name=name,
|
||||
run=run_identifier,
|
||||
stop=stop,
|
||||
config=config,
|
||||
resources_per_trial=resources_per_trial,
|
||||
num_samples=num_samples,
|
||||
local_dir=local_dir,
|
||||
upload_dir=upload_dir,
|
||||
sync_to_driver=sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
loggers=loggers,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
keep_checkpoints_num=keep_checkpoints_num,
|
||||
checkpoint_score_attr=checkpoint_score_attr,
|
||||
export_formats=export_formats,
|
||||
max_failures=max_failures,
|
||||
restore=restore,
|
||||
sync_function=sync_function)
|
||||
if isinstance(run_or_experiment, list):
|
||||
experiments = run_or_experiment
|
||||
else:
|
||||
experiments = [run_or_experiment]
|
||||
if len(experiments) > 1:
|
||||
logger.info(
|
||||
"Running multiple concurrent experiments is experimental and may "
|
||||
"not work with certain features.")
|
||||
for i, exp in enumerate(experiments):
|
||||
if not isinstance(exp, Experiment):
|
||||
run_identifier = Experiment._register_if_needed(exp)
|
||||
experiments[i] = Experiment(
|
||||
name=name,
|
||||
run=run_identifier,
|
||||
stop=stop,
|
||||
config=config,
|
||||
resources_per_trial=resources_per_trial,
|
||||
num_samples=num_samples,
|
||||
local_dir=local_dir,
|
||||
upload_dir=upload_dir,
|
||||
sync_to_driver=sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
loggers=loggers,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
keep_checkpoints_num=keep_checkpoints_num,
|
||||
checkpoint_score_attr=checkpoint_score_attr,
|
||||
export_formats=export_formats,
|
||||
max_failures=max_failures,
|
||||
restore=restore,
|
||||
sync_function=sync_function)
|
||||
else:
|
||||
logger.debug("Ignoring some parameters passed into tune.run.")
|
||||
|
||||
if sync_to_cloud:
|
||||
assert experiment.remote_checkpoint_dir, (
|
||||
"Need `upload_dir` if `sync_to_cloud` given.")
|
||||
for exp in experiments:
|
||||
assert exp.remote_checkpoint_dir, (
|
||||
"Need `upload_dir` if `sync_to_cloud` given.")
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg=search_alg or BasicVariantGenerator(),
|
||||
scheduler=scheduler or FIFOScheduler(),
|
||||
local_checkpoint_dir=experiment.checkpoint_dir,
|
||||
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
|
||||
local_checkpoint_dir=experiments[0].checkpoint_dir,
|
||||
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
|
||||
sync_to_cloud=sync_to_cloud,
|
||||
checkpoint_period=global_checkpoint_period,
|
||||
resume=resume,
|
||||
@@ -256,7 +265,8 @@ def run(run_or_experiment,
|
||||
verbose=bool(verbose > 1),
|
||||
trial_executor=trial_executor)
|
||||
|
||||
runner.add_experiment(experiment)
|
||||
for exp in experiments:
|
||||
runner.add_experiment(exp)
|
||||
|
||||
if IS_NOTEBOOK:
|
||||
reporter = JupyterNotebookReporter(overwrite=verbose < 2)
|
||||
@@ -269,7 +279,7 @@ def run(run_or_experiment,
|
||||
dict) and "gpu" in resources_per_trial:
|
||||
# "gpu" is manually set.
|
||||
pass
|
||||
elif _check_default_resources_override(experiment.run_identifier):
|
||||
elif _check_default_resources_override(experiments[0].run_identifier):
|
||||
# "default_resources" is manually overriden.
|
||||
pass
|
||||
else:
|
||||
@@ -329,7 +339,8 @@ def run_experiments(experiments,
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True):
|
||||
raise_on_failed_trial=True,
|
||||
concurrent=False):
|
||||
"""Runs and blocks until all trials finish.
|
||||
|
||||
Examples:
|
||||
@@ -357,10 +368,9 @@ def run_experiments(experiments,
|
||||
# and it conducts the implicit registration.
|
||||
experiments = convert_to_experiment_list(experiments)
|
||||
|
||||
trials = []
|
||||
for exp in experiments:
|
||||
trials += run(
|
||||
exp,
|
||||
if concurrent:
|
||||
return run(
|
||||
experiments,
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
with_server=with_server,
|
||||
@@ -372,4 +382,20 @@ def run_experiments(experiments,
|
||||
trial_executor=trial_executor,
|
||||
raise_on_failed_trial=raise_on_failed_trial,
|
||||
return_trials=True)
|
||||
return trials
|
||||
else:
|
||||
trials = []
|
||||
for exp in experiments:
|
||||
trials += run(
|
||||
exp,
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler,
|
||||
with_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose,
|
||||
resume=resume,
|
||||
queue_trials=queue_trials,
|
||||
reuse_actors=reuse_actors,
|
||||
trial_executor=trial_executor,
|
||||
raise_on_failed_trial=raise_on_failed_trial,
|
||||
return_trials=True)
|
||||
return trials
|
||||
|
||||
Reference in New Issue
Block a user