[tune] Asynchronous saves (#6912)

* Support asynchronous saves

* Fix merge issues

* Add test, fix existing tests

* More informative warning

* Lint, remove print statements

* Address comments, add checkpoint.is_resolved fn

* Add more detailed comments
This commit is contained in:
Ujval Misra
2020-02-09 12:17:45 -08:00
committed by GitHub
parent 0648bd28ef
commit 98a07fe37e
10 changed files with 254 additions and 128 deletions
+3 -6
View File
@@ -267,8 +267,7 @@ class FunctionActorManager:
))
self._num_task_executions[job_id][function_id] = 0
except Exception:
logger.exception(
"Failed to load function {}.".format(function_name))
logger.exception("Failed to load function %s.", function_name)
raise Exception(
"Function {} failed to be loaded from local code.".format(
function_descriptor))
@@ -428,8 +427,7 @@ class FunctionActorManager:
else:
return actor_class
except Exception:
logger.exception(
"Failed to load actor_class %s.".format(class_name))
logger.exception("Failed to load actor_class %s.", class_name)
raise Exception(
"Actor {} failed to be imported from local code.".format(
class_name))
@@ -475,8 +473,7 @@ class FunctionActorManager:
with self.lock:
actor_class = pickle.loads(pickled_class)
except Exception:
logger.exception(
"Failed to load actor class %s.".format(class_name))
logger.exception("Failed to load actor class %s.", class_name)
# The actor class failed to be unpickled, create a fake actor
# class instead (just to produce error messages and to prevent
# the driver from hanging).
+14 -1
View File
@@ -13,7 +13,8 @@ class Checkpoint:
Attributes:
storage (str): Storage type.
value (str): If storage==MEMORY, it is a Python object.
If storage==PERSISTENT, it is a path to persistent storage.
If storage==PERSISTENT, it is a path to persistent storage,
or a future that will be resolved to such a path.
"""
MEMORY = "memory"
@@ -29,6 +30,18 @@ class Checkpoint:
"""Creates a checkpoint from a Python object."""
return Checkpoint(Checkpoint.MEMORY, value)
@property
def is_ready(self):
"""Returns whether the checkpoint is ready to be used for restoration.
A PERSISTENT checkpoint is considered ready once its value is resolved
to an actual path. MEMORY checkpoints are always considered ready since
they are transient.
"""
if self.storage == Checkpoint.PERSISTENT:
return isinstance(self.value, str)
return self.storage == Checkpoint.MEMORY
class QueueItem:
def __init__(self, priority, value):
+6 -19
View File
@@ -554,7 +554,7 @@ class RayTrialExecutor(TrialExecutor):
self._update_avail_resources()
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
"""Saves the trial's state to a checkpoint.
"""Saves the trial's state to a checkpoint asynchronously.
Args:
trial (Trial): The trial to be saved.
@@ -567,29 +567,16 @@ class RayTrialExecutor(TrialExecutor):
Checkpoint object, or None if an Exception occurs.
"""
result = result or trial.last_result
with self._change_working_directory(trial):
if storage == Checkpoint.MEMORY:
value = trial.runner.save_to_object.remote()
checkpoint = Checkpoint(storage, value, result)
else:
with warn_if_slow("save_checkpoint_to_storage"):
# TODO(ujvl): Make this asynchronous.
value = ray.get(trial.runner.save.remote())
checkpoint = Checkpoint(storage, value, result)
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
try:
trial.on_checkpoint(checkpoint)
except Exception:
logger.exception("Trial %s: Error handling checkpoint %s",
trial, checkpoint.value)
return None
if profile.too_slow and trial.sync_on_checkpoint:
logger.warning(
"Consider turning off forced head-worker trial checkpoint "
"syncs by setting sync_on_checkpoint=False. Note that this "
"might result in faulty trial restoration for some worker "
"failure modes.")
else:
value = trial.runner.save.remote()
checkpoint = Checkpoint(storage, value, result)
trial.saving_to = checkpoint
self._running[value] = trial
return checkpoint
def restore(self, trial, checkpoint=None):
+47 -34
View File
@@ -146,14 +146,15 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
trial = Trial("__fake", **kwargs)
runner.add_trial(trial)
runner.step() # run 1
runner.step() # Start trial
assert trial.status == Trial.RUNNING
cluster.remove_node(node)
cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
assert ray.cluster_resources()["CPU"] == 1
for i in range(3):
# Process result (x2), process save, process result.
for _ in range(4):
runner.step()
assert trial.status == Trial.TERMINATED
@@ -237,39 +238,45 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
# Test recovery of trial that hasn't been checkpointed
t = Trial(trainable_id, **kwargs)
runner.add_trial(t)
runner.step() # start
runner.step() # 1 result
runner.step() # Start trial
runner.step() # Process result
assert t.last_result
node2 = cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
# TODO(ujvl): Node failure does not propagate until a step after it
# actually should. This is possibly a problem with `Cluster`.
runner.step()
runner.step() # Recovery step
# TODO(rliaw): This assertion is not critical but will not pass
# because checkpoint handling is messy and should be refactored
# rather than hotfixed.
# assert t.last_result is None, "Trial result not restored correctly."
for i in range(4):
# Process result (x2), process save, process result (x2), process save
for _ in range(6):
runner.step()
assert t.status == Trial.TERMINATED
assert t.status == Trial.TERMINATED, runner.debug_string()
# Test recovery of trial that has been checkpointed
t2 = Trial(trainable_id, **kwargs)
runner.add_trial(t2)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint
# Start trial, process result (x2), process save
for _ in range(4):
runner.step()
assert t2.has_checkpoint()
node3 = cluster.add_node(num_cpus=1)
cluster.remove_node(node2)
cluster.wait_for_nodes()
runner.step() # 3 result + start and fail 4 result
runner.step() # Recovery step
runner.step() # Process recovery
runner.step() # result
runner.step() # Process result 3 + start and fail 4 result
runner.step() # Dispatch restore
runner.step() # Process restore
runner.step() # Process result 5
if t2.status != Trial.TERMINATED:
runner.step()
runner.step() # Process result 6, dispatch save
runner.step() # Process save
assert t2.status == Trial.TERMINATED, runner.debug_string()
# Test recovery of trial that won't be checkpointed
@@ -282,8 +289,8 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
}
t3 = Trial(trainable_id, **kwargs)
runner.add_trial(t3)
runner.step() # start
runner.step() # 1 result
runner.step() # Start trial
runner.step() # Process result 1
cluster.add_node(num_cpus=1)
cluster.remove_node(node3)
cluster.wait_for_nodes()
@@ -318,13 +325,16 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
for t in trials:
runner.add_trial(t)
runner.step() # start
runner.step() # 1 result
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save
cluster.remove_node(node)
cluster.wait_for_nodes()
runner.step()
assert all(t.status == Trial.PENDING for t in trials)
runner.step() # Process result, dispatch save
runner.step() # Process save (detect error), requeue trial
assert all(
t.status == Trial.PENDING for t in trials), runner.debug_string()
with pytest.raises(TuneError):
runner.step()
@@ -374,19 +384,21 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
runner.add_trial(t1)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint
# Start trial, process result (x2), process save
for _ in range(4):
runner.step()
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # collect result 3, kick off + fail result 4
runner.step() # Recovery step
runner.step() # Process Recovery + step 4
for i in range(3):
runner.step() # Collect result 3, kick off + fail result 4
runner.step() # Dispatch restore
runner.step() # Process restore + step 4
for _ in range(3):
if t1.status != Trial.TERMINATED:
runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string()
@@ -414,9 +426,9 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
for t in trials:
runner.add_trial(t)
runner.step() # start
runner.step() # start2
runner.step() # step
# Start trial (x2), process result, process save
for _ in range(4):
runner.step()
assert all(t.status == Trial.RUNNING for t in runner.get_trials())
runner.checkpoint()
@@ -425,11 +437,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
cluster = _start_new_cluster()
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
runner.step() # start
runner.step() # process restore
runner.step() # start2
# Start trial, process restore, process result, process save
for _ in range(4):
runner.step()
for i in range(3):
# Start trial 2, process result, process save, process result, process save
for i in range(5):
runner.step()
with pytest.raises(TuneError):
@@ -30,11 +30,25 @@ class RayTrialExecutorTest(unittest.TestCase):
self.assertEqual(1, len(running))
self.trial_executor.stop_trial(trial)
def testAsyncSave(self):
"""Tests that saved checkpoint value not immediately set."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
self.process_trial_save(trial)
self.assertEqual(checkpoint, trial.checkpoint)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testSaveRestore(self):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.process_trial_save(trial)
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
@@ -59,6 +73,8 @@ class RayTrialExecutorTest(unittest.TestCase):
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
# Process save result (simulates trial runner)
self.process_trial_save(trial)
# Pause
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
@@ -125,11 +141,20 @@ class RayTrialExecutorTest(unittest.TestCase):
self.assertEqual(trial.experiment_tag, "modified_mock")
self.assertEqual(Trial.RUNNING, trial.status)
def generate_trials(self, spec, name):
@staticmethod
def generate_trials(spec, name):
suggester = BasicVariantGenerator()
suggester.add_configurations({name: spec})
return suggester.next_trials()
@staticmethod
def process_trial_save(trial):
"""Simulates trial runner save."""
checkpoint = trial.saving_to
checkpoint_value = ray.get(checkpoint.value)
checkpoint.value = checkpoint_value
trial.on_checkpoint(checkpoint)
class RayExecutorQueueTest(unittest.TestCase):
def setUp(self):
+48 -35
View File
@@ -84,11 +84,12 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process save
runner.step() # Error
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 1)
self.assertEqual(len(searchalg.errored_trials), 1)
@@ -111,14 +112,15 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process save
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 1)
runner.step()
runner.step() # Process restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(len(searchalg.errored_trials), 0)
self.assertEqual(len(scheduler.errored_trials), 0)
@@ -142,15 +144,16 @@ class TrialRunnerTest2(unittest.TestCase):
with patch("ray.cluster_resources") as resource_mock:
resource_mock.return_value = {"CPU": 1, "GPU": 1}
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process result, dispatch save
runner.step() # Process save
self.assertEqual(trials[0].status, Trial.RUNNING)
# Mimic a node failure
resource_mock.return_value = {"CPU": 0, "GPU": 0}
runner.step()
runner.step() # Detect node failure
self.assertEqual(trials[0].status, Trial.PENDING)
self.assertEqual(trials[0].num_failures, 1)
self.assertEqual(len(searchalg.errored_trials), 0)
@@ -171,19 +174,20 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step() # Process save
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 1)
runner.step() # Restore step
runner.step()
runner.step() # Process restore
runner.step() # Error (transient), dispatch restore
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(trials[0].num_failures, 2)
runner.step() # Restore step
runner.step()
runner.step() # Process restore
runner.step() # Error (terminal)
self.assertEqual(trials[0].status, Trial.ERROR)
self.assertEqual(trials[0].num_failures, 3)
@@ -195,61 +199,69 @@ class TrialRunnerTest2(unittest.TestCase):
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
checkpoint = runner.trial_executor.save(trials[0])
kwargs["restore_path"] = checkpoint.value
runner.step() # Process result, dispatch save
runner.step() # Process save, stop trial
kwargs["restore_path"] = trials[0].checkpoint.value
self.assertEqual(trials[0].status, Trial.TERMINATED)
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.PENDING)
runner.step()
runner.step() # Start trial, dispatch restore
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step() # Process restore
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, checkpoint.value)
self.addCleanup(os.remove, trials[0].checkpoint.value)
def testRestoreMetricsAfterCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"resources": Resources(cpu=1, gpu=1),
"checkpoint_freq": 1,
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
checkpoint = runner.trial_executor.save(trials[0])
# checkpoint = runner.trial_executor.save(trials[0])
runner.step() # Process result, dispatch save
runner.step() # Process save
runner.trial_executor.stop_trial(trials[0])
kwargs["restore_path"] = checkpoint.value
kwargs["restore_path"] = trials[0].checkpoint.value
kwargs.pop("checkpoint_freq") # No checkpointing for next trial
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial, dispatch restore
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step() # Restore step
runner.step()
runner.step() # Process restore
runner.step() # Process result
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
runner.step()
runner.step() # Process restore
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
self.addCleanup(os.remove, checkpoint.value)
self.addCleanup(os.remove, trials[0].checkpoint.value)
def testCheckpointingAtEnd(self):
ray.init(num_cpus=1, num_gpus=1)
@@ -264,11 +276,12 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
runner.step()
runner.step() # Process result
runner.step() # Process result, dispatch save
self.assertEqual(trials[0].last_result[DONE], True)
runner.step() # Process save
self.assertEqual(trials[0].has_checkpoint(), True)
def testResultDone(self):
+19 -14
View File
@@ -297,8 +297,9 @@ class TrialRunnerTest3(unittest.TestCase):
checkpoint_freq=1)
]
runner.add_trial(trials[0])
runner.step() # start
runner.step()
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save
self.assertEquals(trials[0].status, Trial.TERMINATED)
trials += [
@@ -310,9 +311,10 @@ class TrialRunnerTest3(unittest.TestCase):
config={"mock_error": True})
]
runner.add_trial(trials[1])
runner.step()
runner.step()
runner.step()
runner.step() # Start trial
runner.step() # Process result, dispatch save
runner.step() # Process save
runner.step() # Error
self.assertEquals(trials[1].status, Trial.ERROR)
trials += [
@@ -323,7 +325,7 @@ class TrialRunnerTest3(unittest.TestCase):
checkpoint_freq=1)
]
runner.add_trial(trials[2])
runner.step()
runner.step() # Start trial
self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
self.assertEquals(trials[2].status, Trial.RUNNING)
@@ -336,9 +338,11 @@ class TrialRunnerTest3(unittest.TestCase):
restored_trial = runner2.get_trial("trial_succ")
self.assertEqual(Trial.PENDING, restored_trial.status)
runner2.step()
runner2.step()
runner2.step()
runner2.step() # Start trial
runner2.step() # Process result, dispatch save
runner2.step() # Process save
runner2.step() # Process result, dispatch save
runner2.step() # Process save
self.assertRaises(TuneError, runner2.step)
shutil.rmtree(tmpdir)
@@ -444,18 +448,19 @@ class TrialRunnerTest3(unittest.TestCase):
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
runner.step() # 0
runner.step() # Process result
self.assertFalse(trials[0].has_checkpoint())
runner.step() # 1
runner.step() # Process result
self.assertFalse(trials[0].has_checkpoint())
runner.step() # 2
runner.step() # Process result, dispatch save
runner.step() # Process save
self.assertTrue(trials[0].has_checkpoint())
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
runner2.step()
runner2.step() # 5: Start trial and dispatch restore
trials2 = runner2.get_trials()
self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
shutil.rmtree(tmpdir)
+8 -1
View File
@@ -194,6 +194,7 @@ class Trial:
self.custom_trial_name = None
# Checkpointing fields
self.saving_to = None
if remote_checkpoint_dir:
self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
else:
@@ -210,7 +211,6 @@ class Trial:
# Restoration fields
self.restoring_from = None
self.num_failures = 0
self.num_consecutive_start_attempts = 0
# AutoML fields
self.results = None
@@ -460,6 +460,10 @@ class Trial:
def is_restoring(self):
return self.restoring_from is not None
@property
def is_saving(self):
return self.saving_to is not None
def __repr__(self):
return str(self)
@@ -497,6 +501,9 @@ class Trial:
state["runner"] = None
state["result_logger"] = None
# Avoid waiting for events that will never occur on resume.
state["resuming_from"] = None
state["saving_to"] = None
if self.result_logger:
self.result_logger.flush(sync_down=False)
state["__logger_started__"] = True
+1 -1
View File
@@ -45,7 +45,7 @@ class TrialExecutor:
self.try_checkpoint_metadata(trial)
def try_checkpoint_metadata(self, trial):
"""Checkpoints metadata.
"""Checkpoints trial metadata.
Args:
trial (Trial): Trial to checkpoint.
+82 -16
View File
@@ -124,6 +124,8 @@ class TrialRunner:
server_port (int): Port number for launching TuneServer.
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
checkpoint_period (int): Trial runner checkpoint periodicity in
seconds. Defaults to 10.
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
"""
self._search_alg = search_alg or BasicVariantGenerator()
@@ -144,6 +146,7 @@ class TrialRunner:
self._server = TuneServer(self, self._server_port)
self._trials = []
self._cached_trial_decisions = {}
self._stop_queue = []
self._local_checkpoint_dir = local_checkpoint_dir
@@ -281,7 +284,6 @@ class TrialRunner:
Requires user to manually re-register their objects. Also stops
all ongoing trials.
"""
newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f, cls=_TuneFunctionDecoder)
@@ -307,7 +309,6 @@ class TrialRunner:
def is_finished(self):
"""Returns whether all trials have finished running."""
if self._total_time > self._global_time_limit:
logger.warning("Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
@@ -362,7 +363,6 @@ class TrialRunner:
Note that the caller usually should not mutate trial state directly.
"""
return self._trials
def add_trial(self, trial):
@@ -427,12 +427,34 @@ class TrialRunner:
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
elif trial.is_saving:
with warn_if_slow("process_trial_save") as profile:
self._process_trial_save(trial)
if profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using DurableTrainable once
# API has converged.
logger.warning(
"Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node.")
else:
with warn_if_slow("process_trial"):
self._process_trial(trial)
def _process_trial(self, trial):
"""Processes a trial result."""
"""Processes a trial result.
Fetches the trial's latest result and makes a scheduling decision
regarding its next action. If a checkpoint is taken, the decided
action is cached and acted on only after the checkpoint is later
processed (see `_process_trial_save`). Otherwise the decision is
acted on immediately.
Args:
trial (Trial): Trial with a result ready to be processed.
"""
try:
result = self.trial_executor.fetch_result(trial)
@@ -480,25 +502,53 @@ class TrialRunner:
self._checkpoint_trial_if_needed(
trial, force=result.get(SHOULD_CHECKPOINT, False))
if decision == TrialScheduler.CONTINUE:
self.trial_executor.continue_training(trial)
elif decision == TrialScheduler.PAUSE:
self.trial_executor.pause_trial(trial)
elif decision == TrialScheduler.STOP:
self.trial_executor.export_trial_if_needed(trial)
self.trial_executor.stop_trial(trial)
if trial.is_saving:
# Cache decision to execute on after the save is processed.
# This prevents changing the trial's state or kicking off
# another training step prematurely.
self._cached_trial_decisions[trial.trial_id] = decision
else:
assert False, "Invalid scheduling decision: {}".format(
decision)
self._execute_action(trial, decision)
except Exception:
logger.exception("Trial %s: Error processing event.", trial)
self._process_trial_failure(trial, traceback.format_exc())
def _process_trial_save(self, trial):
"""Processes a trial save.
Acts on the decision cached during the last `_process_trial` call.
Args:
trial (Trial): Trial being saved.
"""
logger.debug("Trial %s: Processing trial save.", trial)
checkpoint_value = None
try:
checkpoint_value = self.trial_executor.fetch_result(trial)
except Exception:
logger.exception("Trial %s: Error processing result.", trial)
self._process_trial_failure(trial, traceback.format_exc())
if checkpoint_value:
try:
trial.saving_to.value = checkpoint_value
trial.on_checkpoint(trial.saving_to)
self.trial_executor.try_checkpoint_metadata(trial)
except Exception:
logger.exception("Trial %s: Error handling checkpoint %s",
trial, checkpoint_value)
trial.saving_to = None
decision = self._cached_trial_decisions.pop(trial.trial_id, None)
if decision and checkpoint_value:
self._execute_action(trial, decision)
def _process_trial_restore(self, trial):
"""Processes a trial restore.
Args:
trial: Trial being restored.
trial (Trial): Trial being restored.
"""
logger.debug("Trial %s: Processing trial restore.", trial)
try:
@@ -529,13 +579,29 @@ class TrialRunner:
self.trial_executor.stop_trial(
trial, error=True, error_msg=error_msg)
def _execute_action(self, trial, decision):
"""Executes action based on decision.
Args:
trial (Trial): Trial to act on.
decision (str): Scheduling decision to undertake.
"""
if decision == TrialScheduler.CONTINUE:
self.trial_executor.continue_training(trial)
elif decision == TrialScheduler.PAUSE:
self.trial_executor.pause_trial(trial)
elif decision == TrialScheduler.STOP:
self.trial_executor.export_trial_if_needed(trial)
self.trial_executor.stop_trial(trial)
else:
raise ValueError("Invalid decision: {}".format(decision))
def _checkpoint_trial_if_needed(self, trial, force=False):
"""Checkpoints trial based off trial.last_result."""
if trial.should_checkpoint() or force:
# Save trial runtime if possible
# Save trial runtime if possible.
if trial.runner:
self.trial_executor.save(trial, storage=Checkpoint.PERSISTENT)
self.trial_executor.try_checkpoint_metadata(trial)
def _try_recover(self, trial, error_msg):
"""Tries to recover trial.