From 708dff6d8f7dd6f7919e06c1845f1fea0cca5b89 Mon Sep 17 00:00:00 2001 From: Ujval Misra Date: Mon, 20 Apr 2020 15:10:36 -0700 Subject: [PATCH] [tune] Stop-gap fix for PBT checkpointing (#7794) * Fix PBT * lint * reset * rm * tests Co-authored-by: Richard Liaw --- python/ray/tune/durable_trainable.py | 14 ++- python/ray/tune/examples/pbt_example.py | 1 + python/ray/tune/ray_trial_executor.py | 27 +++-- python/ray/tune/schedulers/pbt.py | 18 +-- python/ray/tune/tests/test_trial_scheduler.py | 105 +++++++++++++++++- python/ray/tune/trainable.py | 12 +- python/ray/tune/trial.py | 5 +- python/ray/tune/trial_executor.py | 6 +- 8 files changed, 157 insertions(+), 31 deletions(-) diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py index 981e1bc89..ab25cfa02 100644 --- a/python/ray/tune/durable_trainable.py +++ b/python/ray/tune/durable_trainable.py @@ -1,8 +1,11 @@ +import logging import os from ray.tune.trainable import Trainable, TrainableUtil from ray.tune.syncer import get_cloud_sync_client +logger = logging.getLogger(__name__) + class DurableTrainable(Trainable): """Abstract class for a remote-storage backed fault-tolerant Trainable. @@ -57,7 +60,6 @@ class DurableTrainable(Trainable): if checkpoint_dir.starts_with(os.path.abspath(self.logdir)): raise ValueError("`checkpoint_dir` must be `self.logdir`, or " "a sub-directory.") - checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir) self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir) self.storage_client.wait() @@ -81,9 +83,15 @@ class DurableTrainable(Trainable): Args: checkpoint_path (str): Local path to checkpoint. """ + try: + local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path) + except FileNotFoundError: + logger.warning( + "Trial %s: checkpoint path not found during " + "garbage collection. See issue #6697.", self.trial_id) + else: + self.storage_client.delete(self._storage_path(local_dirpath)) super(DurableTrainable, self).delete_checkpoint(checkpoint_path) - local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path) - self.storage_client.delete(self._storage_path(local_dirpath)) def _create_storage_client(self): """Returns a storage client.""" diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index c906fe238..e0b9900a1 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -108,6 +108,7 @@ if __name__ == "__main__": name="pbt_test", scheduler=pbt, reuse_actors=True, + checkpoint_freq=20, verbose=False, stop={ "training_iteration": 200, diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index d25dc0c8e..637f5a307 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -213,7 +213,7 @@ class RayTrialExecutor(TrialExecutor): trial_item = self._find_item(self._running, trial) assert len(trial_item) < 2, trial_item - def _start_trial(self, trial, checkpoint=None, runner=None): + def _start_trial(self, trial, checkpoint=None, runner=None, train=True): """Starts trial and restores last result if trial was paused. Args: @@ -223,6 +223,7 @@ class RayTrialExecutor(TrialExecutor): from the beginning. runner (Trainable): The remote runner to use. This can be the cached actor. If None, a new runner is created. + train (bool): Whether or not to start training. See `RayTrialExecutor.restore` for possible errors raised. """ @@ -239,7 +240,7 @@ class RayTrialExecutor(TrialExecutor): # If Trial was in flight when paused, self._paused stores result. self._paused.pop(previous_run[0]) self._running[previous_run[0]] = trial - elif not trial.is_restoring: + elif train and not trial.is_restoring: self._train(trial) def _stop_trial(self, trial, error=False, error_msg=None, @@ -278,7 +279,7 @@ class RayTrialExecutor(TrialExecutor): finally: trial.set_runner(None) - def start_trial(self, trial, checkpoint=None): + def start_trial(self, trial, checkpoint=None, train=True): """Starts the trial. Will not return resources if trial repeatedly fails on start. @@ -287,10 +288,11 @@ class RayTrialExecutor(TrialExecutor): trial (Trial): Trial to be started. checkpoint (Checkpoint): A Python object or path storing the state of trial. + train (bool): Whether or not to start training. """ self._commit_resources(trial.resources) try: - self._start_trial(trial, checkpoint) + self._start_trial(trial, checkpoint, train=train) except AbortTrialExecution: logger.exception("Trial %s: Error starting runner, aborting!", trial) @@ -342,10 +344,8 @@ class RayTrialExecutor(TrialExecutor): Args: trial (Trial): Trial to be reset. - new_config (dict): New configuration for Trial - trainable. - new_experiment_tag (str): New experiment name - for trial. + new_config (dict): New configuration for Trial trainable. + new_experiment_tag (str): New experiment name for trial. Returns: True if `reset_config` is successful else False. @@ -633,7 +633,7 @@ class RayTrialExecutor(TrialExecutor): self._running[value] = trial return checkpoint - def restore(self, trial, checkpoint=None): + def restore(self, trial, checkpoint=None, block=False): """Restores training state from a given model checkpoint. Args: @@ -641,6 +641,7 @@ class RayTrialExecutor(TrialExecutor): checkpoint (Checkpoint): The checkpoint to restore from. If None, the most recent PERSISTENT checkpoint is used. Defaults to None. + block (bool): Whether or not to block on restore before returning. Raises: RuntimeError: This error is raised if no runner is found. @@ -680,8 +681,12 @@ class RayTrialExecutor(TrialExecutor): "restoration. Pass in an `upload_dir` and a Trainable " "extending `DurableTrainable` for remote storage-based " "restoration") - self._running[remote] = trial - trial.restoring_from = checkpoint + + if block: + ray.get(remote) + else: + self._running[remote] = trial + trial.restoring_from = checkpoint def export_trial_if_needed(self, trial): """Exports model of this trial based on trial.export_formats. diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 29304483e..495055c12 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -270,7 +270,6 @@ class PopulationBasedTraining(FIFOScheduler): For each step, logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config]. - """ trial_name, trial_to_clone_name = (trial_state.orig_tag, new_state.orig_tag) @@ -301,9 +300,7 @@ class PopulationBasedTraining(FIFOScheduler): """Transfers perturbed state from trial_to_clone -> trial. If specified, also logs the updated hyperparam state. - """ - trial_state = self._trial_state[trial] new_state = self._trial_state[trial_to_clone] if not new_state.last_checkpoint: @@ -326,13 +323,20 @@ class PopulationBasedTraining(FIFOScheduler): self._hyperparam_mutations) reset_successful = trial_executor.reset_trial(trial, new_config, new_tag) + + # TODO(ujvl): Refactor Scheduler abstraction to abstract + # mechanism for trial restart away. We block on restore + # and suppress train on start as a stop-gap fix to + # https://github.com/ray-project/ray/issues/7258. if reset_successful: - trial_executor.restore(trial, new_state.last_checkpoint) + trial_executor.restore( + trial, new_state.last_checkpoint, block=True) else: trial_executor.stop_trial(trial, stop_logger=False) trial.config = new_config trial.experiment_tag = new_tag - trial_executor.start_trial(trial, new_state.last_checkpoint) + trial_executor.start_trial( + trial, new_state.last_checkpoint, train=False) self._num_perturbations += 1 # Transfer over the last perturbation time as well @@ -342,9 +346,7 @@ class PopulationBasedTraining(FIFOScheduler): """Returns trials in the lower and upper `quantile` of the population. If there is not enough data to compute this, returns empty lists. - """ - trials = [] for trial, state in self._trial_state.items(): if state.last_score is not None and not trial.is_finished(): @@ -366,9 +368,7 @@ class PopulationBasedTraining(FIFOScheduler): This enables the PBT scheduler to support a greater number of concurrent trials than can fit in the cluster at any given time. - """ - candidates = [] for trial in trial_runner.get_trials(): if trial.status in [Trial.PENDING, Trial.PAUSED] and \ diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index f8e102220..105206518 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -9,6 +9,7 @@ import shutil from unittest.mock import MagicMock import ray +from ray import tune from ray.tune.result import TRAINING_ITERATION from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, PopulationBasedTraining, MedianStoppingRule, @@ -186,7 +187,7 @@ class EarlyStoppingSuite(unittest.TestCase): class _MockTrialExecutor(TrialExecutor): - def start_trial(self, trial, checkpoint_obj=None): + def start_trial(self, trial, checkpoint_obj=None, train=True): trial.logger_running = True trial.restored_checkpoint = checkpoint_obj.value trial.status = Trial.RUNNING @@ -196,7 +197,7 @@ class _MockTrialExecutor(TrialExecutor): if stop_logger: trial.logger_running = False - def restore(self, trial, checkpoint=None): + def restore(self, trial, checkpoint=None, block=False): pass def save(self, trial, type=Checkpoint.PERSISTENT, result=None): @@ -1102,6 +1103,106 @@ class PopulationBasedTestingSuite(unittest.TestCase): shutil.rmtree(tmpdir) +class E2EPopulationBasedTestingSuite(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=4) + + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + + def basicSetup(self, + resample_prob=0.0, + explore=None, + perturbation_interval=10, + log_config=False, + hyperparams=None, + hyperparam_mutations=None, + step_once=True): + hyperparam_mutations = hyperparam_mutations or { + "float_factor": lambda: 100.0, + "int_factor": lambda: 10, + "id_factor": [100] + } + pbt = PopulationBasedTraining( + metric="mean_accuracy", + time_attr="training_iteration", + perturbation_interval=perturbation_interval, + resample_probability=resample_prob, + quantile_fraction=0.25, + hyperparam_mutations=hyperparam_mutations, + custom_explore_fn=explore, + log_config=log_config) + return pbt + + def testCheckpointing(self): + pbt = self.basicSetup(perturbation_interval=2) + + class train(tune.Trainable): + def _train(self): + return {"mean_accuracy": self.training_iteration} + + def _save(self, path): + checkpoint = path + "/checkpoint" + with open(checkpoint, "w") as f: + f.write("OK") + return checkpoint + + trial_hyperparams = { + "float_factor": 2.0, + "const_factor": 3, + "int_factor": 10, + "id_factor": 0 + } + + analysis = tune.run( + train, + num_samples=3, + scheduler=pbt, + checkpoint_freq=3, + config=trial_hyperparams, + stop={"training_iteration": 30}) + + for trial in analysis.trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) + + def testCheckpointDict(self): + pbt = self.basicSetup(perturbation_interval=2) + + class train_dict(tune.Trainable): + def _setup(self, config): + self.state = {"hi": 1} + + def _train(self): + return {"mean_accuracy": self.training_iteration} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + trial_hyperparams = { + "float_factor": 2.0, + "const_factor": 3, + "int_factor": 10, + "id_factor": 0 + } + + analysis = tune.run( + train_dict, + num_samples=3, + scheduler=pbt, + checkpoint_freq=3, + config=trial_hyperparams, + stop={"training_iteration": 30}) + + for trial in analysis.trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) + + class AsyncHyperBandSuite(unittest.TestCase): def setUp(self): ray.init() diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 1c0eeeb94..b65a47453 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -186,7 +186,7 @@ class Trainable: def default_resource_request(cls, config): """Provides a static resource requirement for the given configuration. - This can be overriden by sub-classes to set the correct trial resource + This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to. .. code-block:: python @@ -555,6 +555,16 @@ class Trainable: """ return self._iteration + @property + def training_iteration(self): + """Current training iteration (same as `self.iteration`). + + This value is automatically incremented every time `train()` is called + and is automatically inserted into the training result dict. + + """ + return self._iteration + def get_config(self): """Returns configuration passed in by Tune.""" return self.config diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 942bdef95..802f8086d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -108,7 +108,7 @@ class TrialInfo: """Serializable struct for holding information for a Trial. Attributes: - trial_name (str): String name of the currernt trial. + trial_name (str): String name of the current trial. trial_id (str): trial_id of the trial """ @@ -191,8 +191,7 @@ class Trial: self.evaluated_params = evaluated_params or {} self.experiment_tag = experiment_tag trainable_cls = self.get_trainable_cls() - if trainable_cls and hasattr(trainable_cls, - "default_resource_request"): + if trainable_cls: default_resources = trainable_cls.default_resource_request( self.config) if default_resources: diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index f6ebacacc..e4e581ead 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -72,13 +72,14 @@ class TrialExecutor: raise NotImplementedError("Subclasses of TrialExecutor must provide " "has_resources() method") - def start_trial(self, trial, checkpoint=None): + def start_trial(self, trial, checkpoint=None, train=True): """Starts the trial restoring from checkpoint if checkpoint is provided. Args: trial (Trial): Trial to be started. checkpoint (Checkpoint): A Python object or path storing the state of trial. + train (bool): Whether or not to start training. """ raise NotImplementedError("Subclasses of TrialExecutor must provide " "start_trial() method") @@ -211,7 +212,7 @@ class TrialExecutor: """Returns a string describing the total resources available.""" raise NotImplementedError - def restore(self, trial, checkpoint=None): + def restore(self, trial, checkpoint=None, block=False): """Restores training state from a checkpoint. If checkpoint is None, try to restore from trial.checkpoint. @@ -220,6 +221,7 @@ class TrialExecutor: Args: trial (Trial): Trial to be restored. checkpoint (Checkpoint): Checkpoint to restore from. + block (bool): Whether or not to block on restore before returning. Returns: False if error occurred, otherwise return True.