mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:48:54 +08:00
[tune] Asynchronous saves (#6912)
* Support asynchronous saves * Fix merge issues * Add test, fix existing tests * More informative warning * Lint, remove print statements * Address comments, add checkpoint.is_resolved fn * Add more detailed comments
This commit is contained in:
@@ -267,8 +267,7 @@ class FunctionActorManager:
|
||||
))
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load function {}.".format(function_name))
|
||||
logger.exception("Failed to load function %s.", function_name)
|
||||
raise Exception(
|
||||
"Function {} failed to be loaded from local code.".format(
|
||||
function_descriptor))
|
||||
@@ -428,8 +427,7 @@ class FunctionActorManager:
|
||||
else:
|
||||
return actor_class
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load actor_class %s.".format(class_name))
|
||||
logger.exception("Failed to load actor_class %s.", class_name)
|
||||
raise Exception(
|
||||
"Actor {} failed to be imported from local code.".format(
|
||||
class_name))
|
||||
@@ -475,8 +473,7 @@ class FunctionActorManager:
|
||||
with self.lock:
|
||||
actor_class = pickle.loads(pickled_class)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load actor class %s.".format(class_name))
|
||||
logger.exception("Failed to load actor class %s.", class_name)
|
||||
# The actor class failed to be unpickled, create a fake actor
|
||||
# class instead (just to produce error messages and to prevent
|
||||
# the driver from hanging).
|
||||
|
||||
@@ -13,7 +13,8 @@ class Checkpoint:
|
||||
Attributes:
|
||||
storage (str): Storage type.
|
||||
value (str): If storage==MEMORY, it is a Python object.
|
||||
If storage==PERSISTENT, it is a path to persistent storage.
|
||||
If storage==PERSISTENT, it is a path to persistent storage,
|
||||
or a future that will be resolved to such a path.
|
||||
"""
|
||||
|
||||
MEMORY = "memory"
|
||||
@@ -29,6 +30,18 @@ class Checkpoint:
|
||||
"""Creates a checkpoint from a Python object."""
|
||||
return Checkpoint(Checkpoint.MEMORY, value)
|
||||
|
||||
@property
|
||||
def is_ready(self):
|
||||
"""Returns whether the checkpoint is ready to be used for restoration.
|
||||
|
||||
A PERSISTENT checkpoint is considered ready once its value is resolved
|
||||
to an actual path. MEMORY checkpoints are always considered ready since
|
||||
they are transient.
|
||||
"""
|
||||
if self.storage == Checkpoint.PERSISTENT:
|
||||
return isinstance(self.value, str)
|
||||
return self.storage == Checkpoint.MEMORY
|
||||
|
||||
|
||||
class QueueItem:
|
||||
def __init__(self, priority, value):
|
||||
|
||||
@@ -554,7 +554,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._update_avail_resources()
|
||||
|
||||
def save(self, trial, storage=Checkpoint.PERSISTENT, result=None):
|
||||
"""Saves the trial's state to a checkpoint.
|
||||
"""Saves the trial's state to a checkpoint asynchronously.
|
||||
|
||||
Args:
|
||||
trial (Trial): The trial to be saved.
|
||||
@@ -567,29 +567,16 @@ class RayTrialExecutor(TrialExecutor):
|
||||
Checkpoint object, or None if an Exception occurs.
|
||||
"""
|
||||
result = result or trial.last_result
|
||||
|
||||
with self._change_working_directory(trial):
|
||||
if storage == Checkpoint.MEMORY:
|
||||
value = trial.runner.save_to_object.remote()
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
else:
|
||||
with warn_if_slow("save_checkpoint_to_storage"):
|
||||
# TODO(ujvl): Make this asynchronous.
|
||||
value = ray.get(trial.runner.save.remote())
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
|
||||
try:
|
||||
trial.on_checkpoint(checkpoint)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error handling checkpoint %s",
|
||||
trial, checkpoint.value)
|
||||
return None
|
||||
if profile.too_slow and trial.sync_on_checkpoint:
|
||||
logger.warning(
|
||||
"Consider turning off forced head-worker trial checkpoint "
|
||||
"syncs by setting sync_on_checkpoint=False. Note that this "
|
||||
"might result in faulty trial restoration for some worker "
|
||||
"failure modes.")
|
||||
else:
|
||||
value = trial.runner.save.remote()
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
trial.saving_to = checkpoint
|
||||
self._running[value] = trial
|
||||
return checkpoint
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
|
||||
@@ -146,14 +146,15 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
|
||||
trial = Trial("__fake", **kwargs)
|
||||
runner.add_trial(trial)
|
||||
|
||||
runner.step() # run 1
|
||||
runner.step() # Start trial
|
||||
assert trial.status == Trial.RUNNING
|
||||
cluster.remove_node(node)
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.cluster_resources()["CPU"] == 1
|
||||
|
||||
for i in range(3):
|
||||
# Process result (x2), process save, process result.
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
assert trial.status == Trial.TERMINATED
|
||||
|
||||
@@ -237,39 +238,45 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||
# Test recovery of trial that hasn't been checkpointed
|
||||
t = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result
|
||||
assert t.last_result
|
||||
node2 = cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
# TODO(ujvl): Node failure does not propagate until a step after it
|
||||
# actually should. This is possibly a problem with `Cluster`.
|
||||
runner.step()
|
||||
runner.step() # Recovery step
|
||||
|
||||
# TODO(rliaw): This assertion is not critical but will not pass
|
||||
# because checkpoint handling is messy and should be refactored
|
||||
# rather than hotfixed.
|
||||
# assert t.last_result is None, "Trial result not restored correctly."
|
||||
for i in range(4):
|
||||
|
||||
# Process result (x2), process save, process result (x2), process save
|
||||
for _ in range(6):
|
||||
runner.step()
|
||||
|
||||
assert t.status == Trial.TERMINATED
|
||||
assert t.status == Trial.TERMINATED, runner.debug_string()
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t2 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t2)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # 2 result and checkpoint
|
||||
# Start trial, process result (x2), process save
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
assert t2.has_checkpoint()
|
||||
node3 = cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node2)
|
||||
cluster.wait_for_nodes()
|
||||
runner.step() # 3 result + start and fail 4 result
|
||||
runner.step() # Recovery step
|
||||
runner.step() # Process recovery
|
||||
runner.step() # result
|
||||
runner.step() # Process result 3 + start and fail 4 result
|
||||
runner.step() # Dispatch restore
|
||||
runner.step() # Process restore
|
||||
runner.step() # Process result 5
|
||||
if t2.status != Trial.TERMINATED:
|
||||
runner.step()
|
||||
runner.step() # Process result 6, dispatch save
|
||||
runner.step() # Process save
|
||||
assert t2.status == Trial.TERMINATED, runner.debug_string()
|
||||
|
||||
# Test recovery of trial that won't be checkpointed
|
||||
@@ -282,8 +289,8 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||
}
|
||||
t3 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t3)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result 1
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node3)
|
||||
cluster.wait_for_nodes()
|
||||
@@ -318,13 +325,16 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
runner.step()
|
||||
assert all(t.status == Trial.PENDING for t in trials)
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save (detect error), requeue trial
|
||||
assert all(
|
||||
t.status == Trial.PENDING for t in trials), runner.debug_string()
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
runner.step()
|
||||
@@ -374,19 +384,21 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t1 = Trial(trainable_id, **kwargs)
|
||||
runner.add_trial(t1)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # 2 result and checkpoint
|
||||
|
||||
# Start trial, process result (x2), process save
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
assert t1.has_checkpoint()
|
||||
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
|
||||
|
||||
runner.step() # collect result 3, kick off + fail result 4
|
||||
runner.step() # Recovery step
|
||||
runner.step() # Process Recovery + step 4
|
||||
for i in range(3):
|
||||
runner.step() # Collect result 3, kick off + fail result 4
|
||||
runner.step() # Dispatch restore
|
||||
runner.step() # Process restore + step 4
|
||||
for _ in range(3):
|
||||
if t1.status != Trial.TERMINATED:
|
||||
runner.step()
|
||||
assert t1.status == Trial.TERMINATED, runner.debug_string()
|
||||
@@ -414,9 +426,9 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
runner.step() # start
|
||||
runner.step() # start2
|
||||
runner.step() # step
|
||||
# Start trial (x2), process result, process save
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
assert all(t.status == Trial.RUNNING for t in runner.get_trials())
|
||||
runner.checkpoint()
|
||||
|
||||
@@ -425,11 +437,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
|
||||
|
||||
cluster = _start_new_cluster()
|
||||
runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath)
|
||||
runner.step() # start
|
||||
runner.step() # process restore
|
||||
runner.step() # start2
|
||||
# Start trial, process restore, process result, process save
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
|
||||
for i in range(3):
|
||||
# Start trial 2, process result, process save, process result, process save
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
|
||||
@@ -30,11 +30,25 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.assertEqual(1, len(running))
|
||||
self.trial_executor.stop_trial(trial)
|
||||
|
||||
def testAsyncSave(self):
|
||||
"""Tests that saved checkpoint value not immediately set."""
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.assertEqual(checkpoint, trial.saving_to)
|
||||
self.assertEqual(trial.checkpoint.value, None)
|
||||
self.process_trial_save(trial)
|
||||
self.assertEqual(checkpoint, trial.checkpoint)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testSaveRestore(self):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.process_trial_save(trial)
|
||||
self.trial_executor.restore(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
@@ -59,6 +73,8 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
|
||||
# Process save result (simulates trial runner)
|
||||
self.process_trial_save(trial)
|
||||
# Pause
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
@@ -125,11 +141,20 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.assertEqual(trial.experiment_tag, "modified_mock")
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
|
||||
def generate_trials(self, spec, name):
|
||||
@staticmethod
|
||||
def generate_trials(spec, name):
|
||||
suggester = BasicVariantGenerator()
|
||||
suggester.add_configurations({name: spec})
|
||||
return suggester.next_trials()
|
||||
|
||||
@staticmethod
|
||||
def process_trial_save(trial):
|
||||
"""Simulates trial runner save."""
|
||||
checkpoint = trial.saving_to
|
||||
checkpoint_value = ray.get(checkpoint.value)
|
||||
checkpoint.value = checkpoint_value
|
||||
trial.on_checkpoint(checkpoint)
|
||||
|
||||
|
||||
class RayExecutorQueueTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -84,11 +84,12 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step() # Process result, dispatch save
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step() # Process save
|
||||
runner.step() # Error
|
||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||
self.assertEqual(trials[0].num_failures, 1)
|
||||
self.assertEqual(len(searchalg.errored_trials), 1)
|
||||
@@ -111,14 +112,15 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step() # Process result, dispatch save
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step() # Process save
|
||||
runner.step() # Error (transient), dispatch restore
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[0].num_failures, 1)
|
||||
runner.step()
|
||||
runner.step() # Process restore
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(len(searchalg.errored_trials), 0)
|
||||
self.assertEqual(len(scheduler.errored_trials), 0)
|
||||
@@ -142,15 +144,16 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
|
||||
with patch("ray.cluster_resources") as resource_mock:
|
||||
resource_mock.return_value = {"CPU": 1, "GPU": 1}
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
|
||||
runner.step()
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
|
||||
# Mimic a node failure
|
||||
resource_mock.return_value = {"CPU": 0, "GPU": 0}
|
||||
runner.step()
|
||||
runner.step() # Detect node failure
|
||||
self.assertEqual(trials[0].status, Trial.PENDING)
|
||||
self.assertEqual(trials[0].num_failures, 1)
|
||||
self.assertEqual(len(searchalg.errored_trials), 0)
|
||||
@@ -171,19 +174,20 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step() # Process result, dispatch save
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step() # Process save
|
||||
runner.step() # Error (transient), dispatch restore
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[0].num_failures, 1)
|
||||
runner.step() # Restore step
|
||||
runner.step()
|
||||
runner.step() # Process restore
|
||||
runner.step() # Error (transient), dispatch restore
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[0].num_failures, 2)
|
||||
runner.step() # Restore step
|
||||
runner.step()
|
||||
runner.step() # Process restore
|
||||
runner.step() # Error (terminal)
|
||||
self.assertEqual(trials[0].status, Trial.ERROR)
|
||||
self.assertEqual(trials[0].num_failures, 3)
|
||||
|
||||
@@ -195,61 +199,69 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
"checkpoint_freq": 1,
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
checkpoint = runner.trial_executor.save(trials[0])
|
||||
kwargs["restore_path"] = checkpoint.value
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save, stop trial
|
||||
kwargs["restore_path"] = trials[0].checkpoint.value
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial, dispatch restore
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
|
||||
runner.step() # Process restore
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
|
||||
self.addCleanup(os.remove, checkpoint.value)
|
||||
self.addCleanup(os.remove, trials[0].checkpoint.value)
|
||||
|
||||
def testRestoreMetricsAfterCheckpointing(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
"checkpoint_freq": 1,
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
checkpoint = runner.trial_executor.save(trials[0])
|
||||
# checkpoint = runner.trial_executor.save(trials[0])
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
runner.trial_executor.stop_trial(trials[0])
|
||||
kwargs["restore_path"] = checkpoint.value
|
||||
kwargs["restore_path"] = trials[0].checkpoint.value
|
||||
|
||||
kwargs.pop("checkpoint_freq") # No checkpointing for next trial
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial, dispatch restore
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
runner.step() # Restore step
|
||||
runner.step()
|
||||
runner.step() # Process restore
|
||||
runner.step() # Process result
|
||||
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
|
||||
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
|
||||
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
||||
runner.step()
|
||||
runner.step() # Process restore
|
||||
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
|
||||
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
|
||||
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
|
||||
self.addCleanup(os.remove, checkpoint.value)
|
||||
self.addCleanup(os.remove, trials[0].checkpoint.value)
|
||||
|
||||
def testCheckpointingAtEnd(self):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
@@ -264,11 +276,12 @@ class TrialRunnerTest2(unittest.TestCase):
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
runner.step()
|
||||
runner.step()
|
||||
runner.step() # Process result
|
||||
runner.step() # Process result, dispatch save
|
||||
self.assertEqual(trials[0].last_result[DONE], True)
|
||||
runner.step() # Process save
|
||||
self.assertEqual(trials[0].has_checkpoint(), True)
|
||||
|
||||
def testResultDone(self):
|
||||
|
||||
@@ -297,8 +297,9 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
checkpoint_freq=1)
|
||||
]
|
||||
runner.add_trial(trials[0])
|
||||
runner.step() # start
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
self.assertEquals(trials[0].status, Trial.TERMINATED)
|
||||
|
||||
trials += [
|
||||
@@ -310,9 +311,10 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
config={"mock_error": True})
|
||||
]
|
||||
runner.add_trial(trials[1])
|
||||
runner.step()
|
||||
runner.step()
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
runner.step() # Error
|
||||
self.assertEquals(trials[1].status, Trial.ERROR)
|
||||
|
||||
trials += [
|
||||
@@ -323,7 +325,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
checkpoint_freq=1)
|
||||
]
|
||||
runner.add_trial(trials[2])
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
|
||||
self.assertEquals(trials[2].status, Trial.RUNNING)
|
||||
|
||||
@@ -336,9 +338,11 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
restored_trial = runner2.get_trial("trial_succ")
|
||||
self.assertEqual(Trial.PENDING, restored_trial.status)
|
||||
|
||||
runner2.step()
|
||||
runner2.step()
|
||||
runner2.step()
|
||||
runner2.step() # Start trial
|
||||
runner2.step() # Process result, dispatch save
|
||||
runner2.step() # Process save
|
||||
runner2.step() # Process result, dispatch save
|
||||
runner2.step() # Process save
|
||||
self.assertRaises(TuneError, runner2.step)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
@@ -444,18 +448,19 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step()
|
||||
runner.step() # Start trial
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
runner.step() # 0
|
||||
runner.step() # Process result
|
||||
self.assertFalse(trials[0].has_checkpoint())
|
||||
runner.step() # 1
|
||||
runner.step() # Process result
|
||||
self.assertFalse(trials[0].has_checkpoint())
|
||||
runner.step() # 2
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
runner2.step()
|
||||
runner2.step() # 5: Start trial and dispatch restore
|
||||
trials2 = runner2.get_trials()
|
||||
self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
@@ -194,6 +194,7 @@ class Trial:
|
||||
self.custom_trial_name = None
|
||||
|
||||
# Checkpointing fields
|
||||
self.saving_to = None
|
||||
if remote_checkpoint_dir:
|
||||
self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
|
||||
else:
|
||||
@@ -210,7 +211,6 @@ class Trial:
|
||||
# Restoration fields
|
||||
self.restoring_from = None
|
||||
self.num_failures = 0
|
||||
self.num_consecutive_start_attempts = 0
|
||||
|
||||
# AutoML fields
|
||||
self.results = None
|
||||
@@ -460,6 +460,10 @@ class Trial:
|
||||
def is_restoring(self):
|
||||
return self.restoring_from is not None
|
||||
|
||||
@property
|
||||
def is_saving(self):
|
||||
return self.saving_to is not None
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
@@ -497,6 +501,9 @@ class Trial:
|
||||
|
||||
state["runner"] = None
|
||||
state["result_logger"] = None
|
||||
# Avoid waiting for events that will never occur on resume.
|
||||
state["resuming_from"] = None
|
||||
state["saving_to"] = None
|
||||
if self.result_logger:
|
||||
self.result_logger.flush(sync_down=False)
|
||||
state["__logger_started__"] = True
|
||||
|
||||
@@ -45,7 +45,7 @@ class TrialExecutor:
|
||||
self.try_checkpoint_metadata(trial)
|
||||
|
||||
def try_checkpoint_metadata(self, trial):
|
||||
"""Checkpoints metadata.
|
||||
"""Checkpoints trial metadata.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to checkpoint.
|
||||
|
||||
@@ -124,6 +124,8 @@ class TrialRunner:
|
||||
server_port (int): Port number for launching TuneServer.
|
||||
verbose (bool): Flag for verbosity. If False, trial results
|
||||
will not be output.
|
||||
checkpoint_period (int): Trial runner checkpoint periodicity in
|
||||
seconds. Defaults to 10.
|
||||
trial_executor (TrialExecutor): Defaults to RayTrialExecutor.
|
||||
"""
|
||||
self._search_alg = search_alg or BasicVariantGenerator()
|
||||
@@ -144,6 +146,7 @@ class TrialRunner:
|
||||
self._server = TuneServer(self, self._server_port)
|
||||
|
||||
self._trials = []
|
||||
self._cached_trial_decisions = {}
|
||||
self._stop_queue = []
|
||||
self._local_checkpoint_dir = local_checkpoint_dir
|
||||
|
||||
@@ -281,7 +284,6 @@ class TrialRunner:
|
||||
Requires user to manually re-register their objects. Also stops
|
||||
all ongoing trials.
|
||||
"""
|
||||
|
||||
newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir)
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f, cls=_TuneFunctionDecoder)
|
||||
@@ -307,7 +309,6 @@ class TrialRunner:
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
||||
if self._total_time > self._global_time_limit:
|
||||
logger.warning("Exceeded global time limit {} / {}".format(
|
||||
self._total_time, self._global_time_limit))
|
||||
@@ -362,7 +363,6 @@ class TrialRunner:
|
||||
|
||||
Note that the caller usually should not mutate trial state directly.
|
||||
"""
|
||||
|
||||
return self._trials
|
||||
|
||||
def add_trial(self, trial):
|
||||
@@ -427,12 +427,34 @@ class TrialRunner:
|
||||
if trial.is_restoring:
|
||||
with warn_if_slow("process_trial_restore"):
|
||||
self._process_trial_restore(trial)
|
||||
elif trial.is_saving:
|
||||
with warn_if_slow("process_trial_save") as profile:
|
||||
self._process_trial_save(trial)
|
||||
if profile.too_slow and trial.sync_on_checkpoint:
|
||||
# TODO(ujvl): Suggest using DurableTrainable once
|
||||
# API has converged.
|
||||
logger.warning(
|
||||
"Consider turning off forced head-worker trial "
|
||||
"checkpoint syncs by setting sync_on_checkpoint=False"
|
||||
". Note that this may result in faulty trial "
|
||||
"restoration if a failure occurs while the checkpoint "
|
||||
"is being synced from the worker to the head node.")
|
||||
else:
|
||||
with warn_if_slow("process_trial"):
|
||||
self._process_trial(trial)
|
||||
|
||||
def _process_trial(self, trial):
|
||||
"""Processes a trial result."""
|
||||
"""Processes a trial result.
|
||||
|
||||
Fetches the trial's latest result and makes a scheduling decision
|
||||
regarding its next action. If a checkpoint is taken, the decided
|
||||
action is cached and acted on only after the checkpoint is later
|
||||
processed (see `_process_trial_save`). Otherwise the decision is
|
||||
acted on immediately.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial with a result ready to be processed.
|
||||
"""
|
||||
try:
|
||||
result = self.trial_executor.fetch_result(trial)
|
||||
|
||||
@@ -480,25 +502,53 @@ class TrialRunner:
|
||||
self._checkpoint_trial_if_needed(
|
||||
trial, force=result.get(SHOULD_CHECKPOINT, False))
|
||||
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
self.trial_executor.continue_training(trial)
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self.trial_executor.pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
self.trial_executor.export_trial_if_needed(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
if trial.is_saving:
|
||||
# Cache decision to execute on after the save is processed.
|
||||
# This prevents changing the trial's state or kicking off
|
||||
# another training step prematurely.
|
||||
self._cached_trial_decisions[trial.trial_id] = decision
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
decision)
|
||||
self._execute_action(trial, decision)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error processing event.", trial)
|
||||
self._process_trial_failure(trial, traceback.format_exc())
|
||||
|
||||
def _process_trial_save(self, trial):
|
||||
"""Processes a trial save.
|
||||
|
||||
Acts on the decision cached during the last `_process_trial` call.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial being saved.
|
||||
"""
|
||||
logger.debug("Trial %s: Processing trial save.", trial)
|
||||
checkpoint_value = None
|
||||
|
||||
try:
|
||||
checkpoint_value = self.trial_executor.fetch_result(trial)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error processing result.", trial)
|
||||
self._process_trial_failure(trial, traceback.format_exc())
|
||||
|
||||
if checkpoint_value:
|
||||
try:
|
||||
trial.saving_to.value = checkpoint_value
|
||||
trial.on_checkpoint(trial.saving_to)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error handling checkpoint %s",
|
||||
trial, checkpoint_value)
|
||||
|
||||
trial.saving_to = None
|
||||
decision = self._cached_trial_decisions.pop(trial.trial_id, None)
|
||||
if decision and checkpoint_value:
|
||||
self._execute_action(trial, decision)
|
||||
|
||||
def _process_trial_restore(self, trial):
|
||||
"""Processes a trial restore.
|
||||
|
||||
Args:
|
||||
trial: Trial being restored.
|
||||
trial (Trial): Trial being restored.
|
||||
"""
|
||||
logger.debug("Trial %s: Processing trial restore.", trial)
|
||||
try:
|
||||
@@ -529,13 +579,29 @@ class TrialRunner:
|
||||
self.trial_executor.stop_trial(
|
||||
trial, error=True, error_msg=error_msg)
|
||||
|
||||
def _execute_action(self, trial, decision):
|
||||
"""Executes action based on decision.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to act on.
|
||||
decision (str): Scheduling decision to undertake.
|
||||
"""
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
self.trial_executor.continue_training(trial)
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self.trial_executor.pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
self.trial_executor.export_trial_if_needed(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
else:
|
||||
raise ValueError("Invalid decision: {}".format(decision))
|
||||
|
||||
def _checkpoint_trial_if_needed(self, trial, force=False):
|
||||
"""Checkpoints trial based off trial.last_result."""
|
||||
if trial.should_checkpoint() or force:
|
||||
# Save trial runtime if possible
|
||||
# Save trial runtime if possible.
|
||||
if trial.runner:
|
||||
self.trial_executor.save(trial, storage=Checkpoint.PERSISTENT)
|
||||
self.trial_executor.try_checkpoint_metadata(trial)
|
||||
|
||||
def _try_recover(self, trial, error_msg):
|
||||
"""Tries to recover trial.
|
||||
|
||||
Reference in New Issue
Block a user