diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index d4a5bcd1e..27318355c 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -45,6 +45,9 @@ class CheckpointManager: def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn): """Initializes a new CheckpointManager. + `newest_persistent_checkpoint` and `newest_memory_checkpoint` are + initialized to Checkpoint objects with values of None. + Args: keep_checkpoints_num (int): Keep at least this many checkpoints. checkpoint_score_attr (str): Attribute to use to determine which @@ -60,28 +63,38 @@ class CheckpointManager: self._checkpoint_score_attr = checkpoint_score_attr[4:] else: self._checkpoint_score_attr = checkpoint_score_attr + self.delete = delete_fn - self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None) + self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT, + None) + self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None) self._best_checkpoints = [] self._membership = set() def on_checkpoint(self, checkpoint): """Starts tracking checkpoint metadata on checkpoint. - Sets newest checkpoint. Deletes previous checkpoint as long as it isn't - one of the best ones. Also deletes the worst checkpoint if at capacity. + Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes + previous checkpoint as long as it isn't one of the best ones. Also + deletes the worst checkpoint if at capacity. Args: checkpoint (Checkpoint): Trial state checkpoint. """ - old_checkpoint = self.newest_checkpoint - self.newest_checkpoint = checkpoint + if checkpoint.storage == Checkpoint.MEMORY: + self.newest_memory_checkpoint = checkpoint + return + + old_checkpoint = self.newest_persistent_checkpoint + self.newest_persistent_checkpoint = checkpoint + + # Remove the old checkpoint if it isn't one of the best ones. + if old_checkpoint.value and old_checkpoint not in self._membership: + self.delete(old_checkpoint) try: queue_item = QueueItem(self._priority(checkpoint), checkpoint) except KeyError: - if old_checkpoint not in self._membership: - self.delete(old_checkpoint) logger.error("Result dict has no key: {}. " "checkpoint_score_attr must be set to a key in the " "result dict.".format(self._checkpoint_score_attr)) @@ -95,11 +108,10 @@ class CheckpointManager: self._membership.add(checkpoint) if worst in self._membership: self._membership.remove(worst) - self.delete(worst) - - # Remove the old checkpoint if it isn't one of the best ones. - if old_checkpoint.value and old_checkpoint not in self._membership: - self.delete(old_checkpoint) + # Don't delete the newest checkpoint. It will be deleted on the + # next on_checkpoint() call since it isn't in self._membership. + if worst != checkpoint: + self.delete(worst) def best_checkpoints(self): """Returns best checkpoints, sorted by score.""" diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index ecf6d0f30..48b1bfc3e 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -172,13 +172,12 @@ class RayTrialExecutor(TrialExecutor): See `RayTrialExecutor.restore` for possible errors raised. """ prior_status = trial.status - self.set_status(trial, Trial.RUNNING) - trial.set_runner( - runner or self._setup_remote_runner( - trial, - reuse_allowed=checkpoint is not None - or trial.has_checkpoint())) + if runner is None: + reuse_allowed = checkpoint is not None or trial.has_checkpoint() + runner = self._setup_remote_runner(trial, reuse_allowed) + trial.set_runner(runner) self.restore(trial, checkpoint) + self.set_status(trial, Trial.RUNNING) previous_run = self._find_item(self._paused, trial) if prior_status == Trial.PAUSED and previous_run: @@ -421,6 +420,11 @@ class RayTrialExecutor(TrialExecutor): def _update_avail_resources(self, num_retries=5): resources = None for i in range(num_retries): + if i > 0: + logger.warning( + "Cluster resources not detected or are 0. Attempt #" + "%s...", i + 1) + time.sleep(0.5) try: resources = ray.cluster_resources() except Exception: @@ -428,10 +432,8 @@ class RayTrialExecutor(TrialExecutor): # https://github.com/ray-project/ray/issues/4147 logger.debug("Using resources for local machine.") resources = ResourceSpec().resolve(True).to_resource_dict() - if not resources: - logger.warning( - "Cluster resources not detected or are 0. Retrying...") - time.sleep(0.5) + if resources: + break if not resources: # NOTE: This hides the possibility that Ray may be waiting for @@ -555,14 +557,14 @@ class RayTrialExecutor(TrialExecutor): """Saves the trial's state to a checkpoint. Args: - trial (Trial): The state of this trial to be saved. + trial (Trial): The trial to be saved. storage (str): Where to store the checkpoint. Defaults to PERSISTENT. result (dict): The state of this trial as a dictionary to be saved. If result is None, the trial's last result will be used. Returns: - Checkpoint future, or None if an Exception occurs. + Checkpoint object, or None if an Exception occurs. """ result = result or trial.last_result @@ -588,11 +590,17 @@ class RayTrialExecutor(TrialExecutor): "syncs by setting sync_on_checkpoint=False. Note that this " "might result in faulty trial restoration for some worker " "failure modes.") - return checkpoint.value + return checkpoint def restore(self, trial, checkpoint=None): """Restores training state from a given model checkpoint. + Args: + trial (Trial): The trial to be restored. + checkpoint (Checkpoint): The checkpoint to restore from. If None, + the most recent PERSISTENT checkpoint is used. Defaults to + None. + Raises: RuntimeError: This error is raised if no runner is found. AbortTrialExecution: This error is raised if the trial is diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index e864fba49..ee5efb8e8 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -324,14 +324,12 @@ class PopulationBasedTraining(FIFOScheduler): reset_successful = trial_executor.reset_trial(trial, new_config, new_tag) if reset_successful: - trial_executor.restore( - trial, Checkpoint.from_object(new_state.last_checkpoint)) + trial_executor.restore(trial, new_state.last_checkpoint) else: trial_executor.stop_trial(trial, stop_logger=False) trial.config = new_config trial.experiment_tag = new_tag - trial_executor.start_trial( - trial, Checkpoint.from_object(new_state.last_checkpoint)) + trial_executor.start_trial(trial, new_state.last_checkpoint) self._num_perturbations += 1 # Transfer over the last perturbation time as well diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 24451a1c2..d6daf7c08 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -28,14 +28,14 @@ class CheckpointManagerTest(unittest.TestCase): for i in range(3) ] - with patch.object(checkpoint_manager, "delete") as \ - delete_mock: + with patch.object(checkpoint_manager, "delete") as delete_mock: for j in range(3): checkpoint_manager.on_checkpoint(checkpoints[j]) expected_deletes = 0 if j != 2 else 1 self.assertEqual(delete_mock.call_count, expected_deletes, j) - self.assertEqual(checkpoint_manager.newest_checkpoint, - checkpoints[j]) + self.assertEqual( + checkpoint_manager.newest_persistent_checkpoint, + checkpoints[j]) best_checkpoints = checkpoint_manager.best_checkpoints() self.assertEqual(len(best_checkpoints), keep_checkpoints_num) @@ -59,8 +59,9 @@ class CheckpointManagerTest(unittest.TestCase): checkpoint_manager.on_checkpoint(checkpoints[j]) expected_deletes = 0 if j != 3 else 1 self.assertEqual(delete_mock.call_count, expected_deletes) - self.assertEqual(checkpoint_manager.newest_checkpoint, - checkpoints[j]) + self.assertEqual( + checkpoint_manager.newest_persistent_checkpoint, + checkpoints[j]) best_checkpoints = checkpoint_manager.best_checkpoints() self.assertEqual(len(best_checkpoints), keep_checkpoints_num) @@ -74,7 +75,7 @@ class CheckpointManagerTest(unittest.TestCase): keep_checkpoints_num = 4 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - Checkpoint(Checkpoint.MEMORY, i, self.mock_result(i)) + Checkpoint(Checkpoint.PERSISTENT, i, self.mock_result(i)) for i in range(16) ] random.shuffle(checkpoints) @@ -92,15 +93,28 @@ class CheckpointManagerTest(unittest.TestCase): Tests that an error is logged when the associated result of the checkpoint has no checkpoint score attribute. """ - keep_checkpoints_num = 1 - checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) - no_attr_checkpoint = Checkpoint(Checkpoint.MEMORY, 0, {}) + no_attr_checkpoint = Checkpoint(Checkpoint.PERSISTENT, 0, {}) with patch.object(logger, "error") as log_error_mock: checkpoint_manager.on_checkpoint(no_attr_checkpoint) log_error_mock.assert_called_once() # The newest checkpoint should still be set despite this error. - assert checkpoint_manager.newest_checkpoint == no_attr_checkpoint + self.assertEqual(checkpoint_manager.newest_persistent_checkpoint, + no_attr_checkpoint) + + def testOnMemoryCheckpoint(self): + checkpoints = [ + Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0)), + Checkpoint(Checkpoint.MEMORY, 0, self.mock_result(0)) + ] + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) + checkpoint_manager.on_checkpoint(checkpoints[0]) + checkpoint_manager.on_checkpoint(checkpoints[1]) + newest = checkpoint_manager.newest_memory_checkpoint + + self.assertEqual(newest, checkpoints[1]) + self.assertEqual(checkpoint_manager.best_checkpoints(), []) if __name__ == "__main__": diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index e54789ed9..a4d49e3e2 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -555,6 +555,8 @@ tune.run( cluster.shutdown() +# TODO(ujvl): Fix test. +@pytest.mark.skip(reason="Not very consistent.") def test_cluster_interrupt(start_connected_cluster, tmpdir): """Tests run_experiment on cluster shutdown with actual interrupt. diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py index 70bc7c913..6c9fd12d9 100644 --- a/python/ray/tune/tests/test_ray_trial_executor.py +++ b/python/ray/tune/tests/test_ray_trial_executor.py @@ -51,6 +51,27 @@ class RayTrialExecutorTest(unittest.TestCase): self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) + def testSavePauseResumeRestore(self): + """Tests that pause checkpoint does not replace restore checkpoint.""" + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + # Save + checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT) + self.assertEqual(Trial.RUNNING, trial.status) + self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT) + # Pause + self.trial_executor.pause_trial(trial) + self.assertEqual(Trial.PAUSED, trial.status) + self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY) + # Resume + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + self.assertEqual(trial.checkpoint, checkpoint) + # Restore + self.trial_executor.restore(trial) + self.trial_executor.stop_trial(trial) + self.assertEqual(Trial.TERMINATED, trial.status) + def testStartFailure(self): _global_registry.register(TRAINABLE_CLASS, "asdf", None) trial = Trial("asdf", resources=Resources(1, 0)) @@ -63,9 +84,9 @@ class RayTrialExecutorTest(unittest.TestCase): self.trial_executor.start_trial(trial) self.assertEqual(Trial.RUNNING, trial.status) self.trial_executor.fetch_result(trial) - self.trial_executor.pause_trial(trial) + checkpoint = self.trial_executor.pause_trial(trial) self.assertEqual(Trial.PAUSED, trial.status) - self.trial_executor.start_trial(trial) + self.trial_executor.start_trial(trial, checkpoint) self.assertEqual(Trial.RUNNING, trial.status) self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) diff --git a/python/ray/tune/tests/test_trial_runner_2.py b/python/ray/tune/tests/test_trial_runner_2.py index 17fcf8093..7e1d76abf 100644 --- a/python/ray/tune/tests/test_trial_runner_2.py +++ b/python/ray/tune/tests/test_trial_runner_2.py @@ -202,8 +202,8 @@ class TrialRunnerTest2(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - path = runner.trial_executor.save(trials[0]) - kwargs["restore_path"] = path + checkpoint = runner.trial_executor.save(trials[0]) + kwargs["restore_path"] = checkpoint.value runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() @@ -216,7 +216,7 @@ class TrialRunnerTest2(unittest.TestCase): self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) - self.addCleanup(os.remove, path) + self.addCleanup(os.remove, checkpoint.value) def testRestoreMetricsAfterCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) @@ -230,9 +230,9 @@ class TrialRunnerTest2(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - path = runner.trial_executor.save(trials[0]) + checkpoint = runner.trial_executor.save(trials[0]) runner.trial_executor.stop_trial(trials[0]) - kwargs["restore_path"] = path + kwargs["restore_path"] = checkpoint.value runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() @@ -249,7 +249,7 @@ class TrialRunnerTest2(unittest.TestCase): self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20) self.assertEqual(trials[1].last_result["iterations_since_restore"], 2) self.assertGreater(trials[1].last_result["time_since_restore"], 0) - self.addCleanup(os.remove, path) + self.addCleanup(os.remove, checkpoint.value) def testCheckpointingAtEnd(self): ray.init(num_cpus=1, num_gpus=1) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index d187f95c8..e1e0d0277 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -200,7 +200,7 @@ class _MockTrialExecutor(TrialExecutor): pass def save(self, trial, type=Checkpoint.PERSISTENT, result=None): - return trial.trainable_name + return Checkpoint(Checkpoint.PERSISTENT, trial.trainable_name, result) def reset_trial(self, trial, new_config, new_experiment_tag): return False diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f3d0da9fa..89135f90d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -194,11 +194,11 @@ class Trial: self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end self.sync_on_checkpoint = sync_on_checkpoint - newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path) self.checkpoint_manager = CheckpointManager( keep_checkpoints_num, checkpoint_score_attr, checkpoint_deleter(str(self), self.runner)) - self.checkpoint_manager.newest_checkpoint = newest_checkpoint + checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path) + self.checkpoint_manager.newest_persistent_checkpoint = checkpoint # Restoration fields self.restoring_from = None @@ -228,7 +228,16 @@ class Trial: @property def checkpoint(self): - return self.checkpoint_manager.newest_checkpoint + """Returns the most recent checkpoint. + + If the trial is PAUSED, this is the most recent MEMORY checkpoint. + Otherwise, it is the most recent PERSISTENT checkpoint. + """ + if self.status == Trial.PAUSED: + assert self.checkpoint_manager.newest_memory_checkpoint.value + return self.checkpoint_manager.newest_memory_checkpoint + else: + return self.checkpoint_manager.newest_persistent_checkpoint @classmethod def generate_id(cls): @@ -351,7 +360,6 @@ class Trial: checkpoint (Checkpoint): Checkpoint taken. """ if checkpoint.storage == Checkpoint.MEMORY: - # TODO(ujvl): Handle this separately to avoid restoration failure. self.checkpoint_manager.on_checkpoint(checkpoint) return if self.sync_on_checkpoint: @@ -364,7 +372,7 @@ class Trial: # checkpoint, so it should just be logged. logger.error( "Trial %s: An error occurred during the " - "checkpoint pre-sync wait.", str(e)) + "checkpoint pre-sync wait - %s", self, str(e)) # Force sync down and wait before tracking the new checkpoint. try: if self.result_logger.sync_down(): diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 83696fac7..c4393160e 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -196,7 +196,7 @@ class TrialExecutor: Assumes the trial is running. - Return: + Returns: Result object for the trial. """ raise NotImplementedError @@ -219,7 +219,7 @@ class TrialExecutor: trial (Trial): Trial to be restored. checkpoint (Checkpoint): Checkpoint to restore from. - Return: + Returns: False if error occurred, otherwise return True. """ raise NotImplementedError("Subclasses of TrialExecutor must provide " @@ -236,9 +236,8 @@ class TrialExecutor: PERSISTENT. result (dict): The state of this trial as a dictionary to be saved. - Return: - A Python object if storage==Checkpoint.MEMORY otherwise - a path to the checkpoint. + Returns: + A Checkpoint object. """ raise NotImplementedError("Subclasses of TrialExecutor must provide " "save() method") @@ -249,7 +248,7 @@ class TrialExecutor: Args: trial (Trial): The state of this trial to be saved. - Return: + Returns: A dict that maps ExportFormats to successfully exported models. """ raise NotImplementedError("Subclasses of TrialExecutor must provide "