mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[tune] support rerunning failed trials (#10060)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from collections import Counter
|
||||
import shutil
|
||||
|
||||
import tempfile
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
@@ -569,6 +570,38 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
print(trial.last_result)
|
||||
self.assertEqual(trial.last_result[DONE], True)
|
||||
|
||||
def testRerun(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmpdir))
|
||||
|
||||
def test(config):
|
||||
tid = config["id"]
|
||||
fail = config["fail"]
|
||||
marker = os.path.join(tmpdir, f"t{tid}-{fail}.log")
|
||||
if not os.path.exists(marker) and fail:
|
||||
open(marker, "w").close()
|
||||
raise ValueError
|
||||
for i in range(10):
|
||||
time.sleep(0.1)
|
||||
tune.report(hello=123)
|
||||
|
||||
config = dict(
|
||||
name="hi-2",
|
||||
config={
|
||||
"fail": tune.grid_search([True, False]),
|
||||
"id": tune.grid_search(list(range(5)))
|
||||
},
|
||||
verbose=1,
|
||||
local_dir=tmpdir,
|
||||
loggers=None)
|
||||
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
|
||||
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
|
||||
new_trials = tune.run(
|
||||
test, resume=True, run_errored_only=True, **config).trials
|
||||
self.assertEqual(Counter(t.status for t in new_trials)["ERROR"], 0)
|
||||
self.assertTrue(
|
||||
all(t.last_result.get("hello") == 123 for t in new_trials))
|
||||
|
||||
def testErrorReturn(self):
|
||||
def train(config, reporter):
|
||||
raise Exception("uh oh")
|
||||
|
||||
@@ -22,11 +22,15 @@ from ray.tune.suggest.search_generator import SearchGenerator
|
||||
|
||||
|
||||
class TrialRunnerTest3(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
shutil.rmtree(self.tmpdir)
|
||||
|
||||
def testStepHook(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
@@ -125,7 +129,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
|
||||
def testSearchAlgFinished(self):
|
||||
"""Checks that SearchAlg is Finished before all trials are done."""
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
ray.init(num_cpus=4, local_mode=True, include_dashboard=False)
|
||||
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 1}}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher = _MockSuggestionAlgorithm()
|
||||
@@ -150,7 +154,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
def on_trial_result(self, *args, **kwargs):
|
||||
return TrialScheduler.STOP
|
||||
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
ray.init(num_cpus=4, local_mode=True, include_dashboard=False)
|
||||
experiment_spec = {"run": "__fake", "stop": {"training_iteration": 2}}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
searcher = _MockSuggestionAlgorithm()
|
||||
@@ -241,7 +245,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
def suggest(self, trial_id):
|
||||
return {}
|
||||
|
||||
ray.init(num_cpus=2)
|
||||
ray.init(num_cpus=2, local_mode=True, include_dashboard=False)
|
||||
experiment_spec = {
|
||||
"run": "__fake",
|
||||
"num_samples": 2,
|
||||
@@ -271,7 +275,6 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
|
||||
def testSearcherSaveRestore(self):
|
||||
ray.init(num_cpus=8, local_mode=True)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def create_searcher():
|
||||
class TestSuggestion(Searcher):
|
||||
@@ -313,7 +316,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
searcher = create_searcher()
|
||||
runner = TrialRunner(
|
||||
search_alg=searcher,
|
||||
local_checkpoint_dir=tmpdir,
|
||||
local_checkpoint_dir=self.tmpdir,
|
||||
checkpoint_period=-1)
|
||||
for i in range(6):
|
||||
runner.step()
|
||||
@@ -331,7 +334,9 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
|
||||
searcher = create_searcher()
|
||||
runner2 = TrialRunner(
|
||||
search_alg=searcher, local_checkpoint_dir=tmpdir, resume="LOCAL")
|
||||
search_alg=searcher,
|
||||
local_checkpoint_dir=self.tmpdir,
|
||||
resume="LOCAL")
|
||||
assert len(runner2.get_trials()) == 6, [
|
||||
t.config for t in runner2.get_trials()
|
||||
]
|
||||
@@ -355,12 +360,87 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
count = Counter(evaluated)
|
||||
assert all(v <= 3 for v in count.values())
|
||||
|
||||
def testTrialErrorResumeFalse(self):
|
||||
ray.init(num_cpus=3, local_mode=True, include_dashboard=False)
|
||||
runner = TrialRunner(local_checkpoint_dir=self.tmpdir)
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 4
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=0),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", config={"mock_error": True}, **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
|
||||
runner.checkpoint(force=True)
|
||||
|
||||
assert trials[0].status == Trial.ERROR
|
||||
del runner
|
||||
|
||||
new_runner = TrialRunner(
|
||||
run_errored_only=False,
|
||||
resume=True,
|
||||
local_checkpoint_dir=self.tmpdir)
|
||||
assert len(new_runner.get_trials()) == 3
|
||||
assert Trial.ERROR in (t.status for t in new_runner.get_trials())
|
||||
|
||||
def testTrialErrorResumeTrue(self):
|
||||
ray.init(num_cpus=3, local_mode=True, include_dashboard=False)
|
||||
runner = TrialRunner(local_checkpoint_dir=self.tmpdir)
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 4
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=0),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", config={"mock_error": True}, **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
|
||||
runner.checkpoint(force=True)
|
||||
|
||||
assert trials[0].status == Trial.ERROR
|
||||
del runner
|
||||
|
||||
new_runner = TrialRunner(
|
||||
run_errored_only=True,
|
||||
resume=True,
|
||||
local_checkpoint_dir=self.tmpdir)
|
||||
assert len(new_runner.get_trials()) == 3
|
||||
assert Trial.ERROR not in (t.status for t in new_runner.get_trials())
|
||||
# The below is just a check for standard behavior.
|
||||
disable_error = False
|
||||
for t in new_runner.get_trials():
|
||||
if t.config.get("mock_error"):
|
||||
t.config["mock_error"] = False
|
||||
disable_error = True
|
||||
assert disable_error
|
||||
|
||||
while not new_runner.is_finished():
|
||||
new_runner.step()
|
||||
assert Trial.ERROR not in (t.status for t in new_runner.get_trials())
|
||||
|
||||
def testTrialSaveRestore(self):
|
||||
"""Creates different trials to test runner.checkpoint/restore."""
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
|
||||
trials = [
|
||||
Trial(
|
||||
"__fake",
|
||||
@@ -401,7 +481,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
|
||||
self.assertEquals(trials[2].status, Trial.RUNNING)
|
||||
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
|
||||
for tid in ["trial_terminate", "trial_fail"]:
|
||||
original_trial = runner.get_trial(tid)
|
||||
restored_trial = runner2.get_trial(tid)
|
||||
@@ -416,14 +496,13 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
runner2.step() # Process result, dispatch save
|
||||
runner2.step() # Process save
|
||||
self.assertRaises(TuneError, runner2.step)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testTrialNoSave(self):
|
||||
"""Check that non-checkpointing trials are not saved."""
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
"__fake",
|
||||
@@ -454,7 +533,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
runner.step()
|
||||
runner.step()
|
||||
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
|
||||
new_trials = runner2.get_trials()
|
||||
self.assertEquals(len(new_trials), 3)
|
||||
self.assertTrue(
|
||||
@@ -464,7 +543,6 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
|
||||
self.assertTrue(not runner2.get_trial("pending").last_result)
|
||||
runner2.step()
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testCheckpointWithFunction(self):
|
||||
ray.init()
|
||||
@@ -474,18 +552,17 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
"on_episode_start": lambda i: i,
|
||||
}},
|
||||
checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(trial)
|
||||
for _ in range(5):
|
||||
runner.step()
|
||||
# force checkpoint
|
||||
runner.checkpoint()
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
|
||||
new_trial = runner2.get_trials()[0]
|
||||
self.assertTrue("callbacks" in new_trial.config)
|
||||
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testCheckpointOverwrite(self):
|
||||
def count_checkpoints(cdir):
|
||||
|
||||
@@ -243,6 +243,7 @@ class Trial:
|
||||
self.last_debug = 0
|
||||
self.error_file = None
|
||||
self.error_msg = None
|
||||
self.trial_name_creator = trial_name_creator
|
||||
self.custom_trial_name = None
|
||||
|
||||
# Checkpointing fields
|
||||
@@ -253,6 +254,8 @@ class Trial:
|
||||
self.remote_checkpoint_dir_prefix = None
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
self.keep_checkpoints_num = keep_checkpoints_num
|
||||
self.checkpoint_score_attr = checkpoint_score_attr
|
||||
self.sync_on_checkpoint = sync_on_checkpoint
|
||||
self.checkpoint_manager = CheckpointManager(
|
||||
keep_checkpoints_num, checkpoint_score_attr,
|
||||
@@ -319,6 +322,31 @@ class Trial:
|
||||
prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=local_dir)
|
||||
|
||||
def reset(self):
|
||||
return Trial(
|
||||
self.trainable_name,
|
||||
config=self.config,
|
||||
trial_id=None,
|
||||
local_dir=self.local_dir,
|
||||
evaluated_params=self.evaluated_params,
|
||||
experiment_tag=self.experiment_tag,
|
||||
resources=self.resources,
|
||||
stopping_criterion=self.stopping_criterion,
|
||||
remote_checkpoint_dir=self.remote_checkpoint_dir,
|
||||
checkpoint_freq=self.checkpoint_freq,
|
||||
checkpoint_at_end=self.checkpoint_at_end,
|
||||
sync_on_checkpoint=self.sync_on_checkpoint,
|
||||
keep_checkpoints_num=self.keep_checkpoints_num,
|
||||
checkpoint_score_attr=self.checkpoint_score_attr,
|
||||
export_formats=self.export_formats,
|
||||
restore_path=self.restore_path,
|
||||
trial_name_creator=self.trial_name_creator,
|
||||
loggers=self.loggers,
|
||||
log_to_file=self.log_to_file,
|
||||
sync_to_driver_fn=self.sync_to_driver_fn,
|
||||
max_failures=self.max_failures,
|
||||
)
|
||||
|
||||
def init_logger(self):
|
||||
"""Init logger."""
|
||||
if not self.result_logger:
|
||||
@@ -579,6 +607,7 @@ class Trial:
|
||||
state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))
|
||||
|
||||
state["runner"] = None
|
||||
state["location"] = Location()
|
||||
state["result_logger"] = None
|
||||
# Avoid waiting for events that will never occur on resume.
|
||||
state["resuming_from"] = None
|
||||
|
||||
@@ -108,6 +108,10 @@ class TrialRunner:
|
||||
If fail_fast='raise' provided, Tune will automatically
|
||||
raise the exception received by the Trainable. fail_fast='raise'
|
||||
can easily leak resources and should be used with caution.
|
||||
run_errored_only (bool): Resets and reruns failed trials, assuming
|
||||
the provided Trainable is the same. Previous trial artifacts
|
||||
will be left untouched. Only to be used with
|
||||
`resume` enabled. Raises ValueError otherwise.
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
checkpoint_period (int): Trial runner checkpoint periodicity in
|
||||
@@ -130,6 +134,7 @@ class TrialRunner:
|
||||
resume=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
fail_fast=False,
|
||||
run_errored_only=False,
|
||||
verbose=True,
|
||||
checkpoint_period=10,
|
||||
trial_executor=None):
|
||||
@@ -181,8 +186,7 @@ class TrialRunner:
|
||||
|
||||
if self._validate_resume(resume_type=resume):
|
||||
try:
|
||||
self.resume()
|
||||
logger.info("Resuming trial.")
|
||||
self.resume(run_errored_only=run_errored_only)
|
||||
self._resumed = True
|
||||
except Exception as e:
|
||||
if self._verbose:
|
||||
@@ -192,6 +196,11 @@ class TrialRunner:
|
||||
raise
|
||||
logger.info("Restarting experiment.")
|
||||
else:
|
||||
if run_errored_only:
|
||||
raise ValueError(
|
||||
"'run_errored_only' should only be used with 'resume'. "
|
||||
f"Got: resume={resume}, "
|
||||
f"run_errored_only={run_errored_only}")
|
||||
logger.debug("Starting a new experiment.")
|
||||
|
||||
self._start_time = time.time()
|
||||
@@ -307,7 +316,7 @@ class TrialRunner:
|
||||
self._syncer.sync_up_if_needed()
|
||||
return self._local_checkpoint_dir
|
||||
|
||||
def resume(self):
|
||||
def resume(self, run_errored_only=False):
|
||||
"""Resumes all checkpointed trials from previous run.
|
||||
|
||||
Requires user to manually re-register their objects. Also stops
|
||||
@@ -335,7 +344,11 @@ class TrialRunner:
|
||||
trials += [new_trial]
|
||||
for trial in sorted(
|
||||
trials, key=lambda t: t.last_update_time, reverse=True):
|
||||
self.add_trial(trial)
|
||||
if run_errored_only and trial.status == Trial.ERROR:
|
||||
new_trial = trial.reset()
|
||||
self.add_trial(new_trial)
|
||||
else:
|
||||
self.add_trial(trial)
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
||||
+33
-16
@@ -95,6 +95,7 @@ def run(run_or_experiment,
|
||||
verbose=2,
|
||||
progress_reporter=None,
|
||||
resume=False,
|
||||
run_errored_only=False,
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
@@ -103,6 +104,32 @@ def run(run_or_experiment,
|
||||
ray_auto_init=True):
|
||||
"""Executes training.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Run 10 trials (each trial is one instance of a Trainable). Tune runs
|
||||
# in parallel and automatically determines concurrency.
|
||||
tune.run(trainable, num_samples=10)
|
||||
|
||||
# Run 1 trial, stop when trial has reached 10 iterations
|
||||
tune.run(my_trainable, stop={"training_iteration": 10})
|
||||
|
||||
# automatically retry failed trials up to 3 times
|
||||
tune.run(my_trainable, stop={"training_iteration": 10}, max_failures=3)
|
||||
|
||||
# Run 1 trial, search over hyperparameters, stop after 10 iterations.
|
||||
space = {"lr": tune.uniform(0, 1), "momentum": tune.uniform(0, 1)}
|
||||
tune.run(my_trainable, config=space, stop={"training_iteration": 10})
|
||||
|
||||
# Resumes training if a previous machine crashed
|
||||
tune.run(my_trainable, config=space,
|
||||
local_dir=<path/to/dir>, resume=True)
|
||||
|
||||
# Rerun ONLY failed trials after an experiment is finished.
|
||||
tune.run(my_trainable, config=space,
|
||||
local_dir=<path/to/dir>, resume=True, run_errored_only=True)
|
||||
|
||||
Args:
|
||||
run_or_experiment (function | class | str | :class:`Experiment`): If
|
||||
function|class|str, this is the algorithm or model to train.
|
||||
@@ -217,6 +244,11 @@ def run(run_or_experiment,
|
||||
PROMPT provides CLI feedback. False forces a new
|
||||
experiment. If resume is set but checkpoint does not exist,
|
||||
ValueError will be thrown.
|
||||
run_errored_only (bool): Only to be used with `resume` enabled.
|
||||
Resets and reruns ERRORED trials upon resume.
|
||||
Experiment location is determined
|
||||
by `name` and `local_dir`. Previous trial artifacts will
|
||||
be left untouched.
|
||||
queue_trials (bool): Whether to queue trials when the cluster does
|
||||
not currently have enough resources to launch one. This should
|
||||
be set to True when running on an autoscaling cluster to enable
|
||||
@@ -233,27 +265,11 @@ def run(run_or_experiment,
|
||||
if Ray is not initialized. Defaults to True.
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
ExperimentAnalysis: Object for experiment analysis.
|
||||
|
||||
Raises:
|
||||
TuneError: Any trials failed and `raise_on_failed_trial` is True.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Run 10 trials (each trial is one instance of a Trainable). Tune runs
|
||||
# in parallel and automatically determines concurrency.
|
||||
tune.run(trainable, num_samples=10)
|
||||
|
||||
# Run 1 trial, stop when trial has reached 10 iterations
|
||||
tune.run(my_trainable, stop={"training_iteration": 10})
|
||||
|
||||
# Run 1 trial, search over hyperparameters, stop after 10 iterations.
|
||||
space = {"lr": tune.uniform(0, 1), "momentum": tune.uniform(0, 1)}
|
||||
tune.run(my_trainable, config=space, stop={"training_iteration": 10})
|
||||
"""
|
||||
config = config or {}
|
||||
|
||||
@@ -315,6 +331,7 @@ def run(run_or_experiment,
|
||||
stopper=experiments[0].stopper,
|
||||
checkpoint_period=global_checkpoint_period,
|
||||
resume=resume,
|
||||
run_errored_only=run_errored_only,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=bool(verbose > 1),
|
||||
|
||||
Reference in New Issue
Block a user