[tune] Use newest checkpoint in normal operation (#7563)

* Use persistent checkpoint for failures

* Fix test

* Add unpause test

* move test

* Fix tests

* remove debug statement

* Mark test as flaky
This commit is contained in:
Ujval Misra
2020-03-12 22:21:42 -07:00
committed by GitHub
parent f4656d8cc3
commit 6022eb53c4
7 changed files with 68 additions and 20 deletions
+11 -1
View File
@@ -2,6 +2,8 @@
import heapq
import logging
from ray.tune.result import TRAINING_ITERATION
logger = logging.getLogger(__name__)
@@ -84,6 +86,14 @@ class CheckpointManager:
self._best_checkpoints = []
self._membership = set()
@property
def newest_checkpoint(self):
"""Returns the newest checkpoint (based on training iteration)."""
newest_checkpoint = max(
[self.newest_persistent_checkpoint, self.newest_memory_checkpoint],
key=lambda c: c.result.get(TRAINING_ITERATION, -1))
return newest_checkpoint
def on_checkpoint(self, checkpoint):
"""Starts tracking checkpoint metadata on checkpoint.
@@ -127,7 +137,7 @@ class CheckpointManager:
self.delete(worst)
def best_checkpoints(self):
"""Returns best checkpoints, sorted by score."""
"""Returns best PERSISTENT checkpoints, sorted by score."""
checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority)
return [queue_item.value for queue_item in checkpoints]
@@ -4,18 +4,30 @@ import sys
import unittest
from unittest.mock import patch
from ray.tune.result import TRAINING_ITERATION
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager, logger
class CheckpointManagerTest(unittest.TestCase):
@staticmethod
def mock_result(i):
return {"i": i}
return {"i": i, TRAINING_ITERATION: i}
def checkpoint_manager(self, keep_checkpoints_num):
return CheckpointManager(
keep_checkpoints_num, "i", delete_fn=lambda c: None)
def testNewestCheckpoint(self):
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
memory_checkpoint = Checkpoint(Checkpoint.MEMORY, {0},
self.mock_result(0))
checkpoint_manager.on_checkpoint(memory_checkpoint)
persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT, {1},
self.mock_result(1))
checkpoint_manager.on_checkpoint(persistent_checkpoint)
self.assertEqual(checkpoint_manager.newest_persistent_checkpoint,
persistent_checkpoint)
def testOnCheckpointOrdered(self):
"""
Tests increasing priorities. Also tests that that the worst checkpoints
+1
View File
@@ -303,6 +303,7 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
runner.step()
@pytest.mark.skip(reason="Not very consistent.")
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
"""Removing a node in full cluster causes Trial to be requeued."""
@@ -7,6 +7,7 @@ from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import TRAINING_ITERATION
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, Checkpoint
from ray.tune.resources import Resources
@@ -35,6 +36,7 @@ class RayTrialExecutorTest(unittest.TestCase):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
@@ -47,6 +49,7 @@ class RayTrialExecutorTest(unittest.TestCase):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.process_trial_save(trial)
self.trial_executor.restore(trial)
@@ -65,16 +68,20 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testSavePauseResumeRestore(self):
def testSavePauseResumeErrorRestore(self):
"""Tests that pause checkpoint does not replace restore checkpoint."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
trial.last_result = self.trial_executor.fetch_result(trial)
# Save
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
# Process save result (simulates trial runner)
self.process_trial_save(trial)
# Train
self.trial_executor.continue_training(trial)
trial.last_result = self.trial_executor.fetch_result(trial)
# Pause
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
@@ -82,7 +89,8 @@ class RayTrialExecutorTest(unittest.TestCase):
# Resume
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
self.assertEqual(trial.checkpoint, checkpoint)
# Error
trial.set_status(Trial.ERROR)
# Restore
self.trial_executor.restore(trial)
self.trial_executor.stop_trial(trial)
@@ -107,6 +115,24 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testPauseUnpause(self):
"""Tests that unpausing works for trials being processed."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), 1)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.unpause_trial(trial)
self.assertEqual(Trial.PENDING, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), 2)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testNoResetTrial(self):
"""Tests that reset handles NotImplemented properly."""
trial = Trial("__fake")
@@ -147,11 +173,10 @@ class RayTrialExecutorTest(unittest.TestCase):
suggester.add_configurations({name: spec})
return suggester.next_trials()
@staticmethod
def process_trial_save(trial):
def process_trial_save(self, trial):
"""Simulates trial runner save."""
checkpoint = trial.saving_to
checkpoint_value = ray.get(checkpoint.value)
checkpoint_value = self.trial_executor.fetch_result(trial)
checkpoint.value = checkpoint_value
trial.on_checkpoint(checkpoint)
+3 -5
View File
@@ -315,7 +315,8 @@ class TrialRunnerTest2(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
runner.step() # Start trial
runner.step() # Process result
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.get_info.remote()), None)
@@ -326,12 +327,9 @@ class TrialRunnerTest2(unittest.TestCase):
runner.trial_executor.resume_trial(trials[0])
self.assertEqual(trials[0].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.get_info.remote()), 1)
runner.step()
runner.step() # Process result
self.assertEqual(trials[0].status, Trial.TERMINATED)
+6 -6
View File
@@ -236,13 +236,13 @@ class Trial:
def checkpoint(self):
"""Returns the most recent checkpoint.
If the trial is PAUSED, this is the most recent MEMORY checkpoint.
Otherwise, it is the most recent PERSISTENT checkpoint.
If the trial is in ERROR state, the most recent PERSISTENT checkpoint
is returned.
"""
if self.status == Trial.PAUSED:
assert self.checkpoint_manager.newest_memory_checkpoint.value
return self.checkpoint_manager.newest_memory_checkpoint
checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
if self.status == Trial.ERROR:
checkpoint = self.checkpoint_manager.newest_persistent_checkpoint
else:
checkpoint = self.checkpoint_manager.newest_checkpoint
if checkpoint.value is None:
checkpoint = Checkpoint(Checkpoint.PERSISTENT, self.restore_path)
return checkpoint