[tune] Stop-gap fix for PBT checkpointing (#7794)

* Fix PBT

* lint

* reset

* rm

* tests

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Ujval Misra
2020-04-20 15:10:36 -07:00
committed by GitHub
parent 213d3894ca
commit 708dff6d8f
8 changed files with 157 additions and 31 deletions
+11 -3
View File
@@ -1,8 +1,11 @@
import logging
import os
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.syncer import get_cloud_sync_client
logger = logging.getLogger(__name__)
class DurableTrainable(Trainable):
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
@@ -57,7 +60,6 @@ class DurableTrainable(Trainable):
if checkpoint_dir.starts_with(os.path.abspath(self.logdir)):
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
"a sub-directory.")
checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)
self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir)
self.storage_client.wait()
@@ -81,9 +83,15 @@ class DurableTrainable(Trainable):
Args:
checkpoint_path (str): Local path to checkpoint.
"""
try:
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
except FileNotFoundError:
logger.warning(
"Trial %s: checkpoint path not found during "
"garbage collection. See issue #6697.", self.trial_id)
else:
self.storage_client.delete(self._storage_path(local_dirpath))
super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
self.storage_client.delete(self._storage_path(local_dirpath))
def _create_storage_client(self):
"""Returns a storage client."""
+1
View File
@@ -108,6 +108,7 @@ if __name__ == "__main__":
name="pbt_test",
scheduler=pbt,
reuse_actors=True,
checkpoint_freq=20,
verbose=False,
stop={
"training_iteration": 200,
+16 -11
View File
@@ -213,7 +213,7 @@ class RayTrialExecutor(TrialExecutor):
trial_item = self._find_item(self._running, trial)
assert len(trial_item) < 2, trial_item
def _start_trial(self, trial, checkpoint=None, runner=None):
def _start_trial(self, trial, checkpoint=None, runner=None, train=True):
"""Starts trial and restores last result if trial was paused.
Args:
@@ -223,6 +223,7 @@ class RayTrialExecutor(TrialExecutor):
from the beginning.
runner (Trainable): The remote runner to use. This can be the
cached actor. If None, a new runner is created.
train (bool): Whether or not to start training.
See `RayTrialExecutor.restore` for possible errors raised.
"""
@@ -239,7 +240,7 @@ class RayTrialExecutor(TrialExecutor):
# If Trial was in flight when paused, self._paused stores result.
self._paused.pop(previous_run[0])
self._running[previous_run[0]] = trial
elif not trial.is_restoring:
elif train and not trial.is_restoring:
self._train(trial)
def _stop_trial(self, trial, error=False, error_msg=None,
@@ -278,7 +279,7 @@ class RayTrialExecutor(TrialExecutor):
finally:
trial.set_runner(None)
def start_trial(self, trial, checkpoint=None):
def start_trial(self, trial, checkpoint=None, train=True):
"""Starts the trial.
Will not return resources if trial repeatedly fails on start.
@@ -287,10 +288,11 @@ class RayTrialExecutor(TrialExecutor):
trial (Trial): Trial to be started.
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
train (bool): Whether or not to start training.
"""
self._commit_resources(trial.resources)
try:
self._start_trial(trial, checkpoint)
self._start_trial(trial, checkpoint, train=train)
except AbortTrialExecution:
logger.exception("Trial %s: Error starting runner, aborting!",
trial)
@@ -342,10 +344,8 @@ class RayTrialExecutor(TrialExecutor):
Args:
trial (Trial): Trial to be reset.
new_config (dict): New configuration for Trial
trainable.
new_experiment_tag (str): New experiment name
for trial.
new_config (dict): New configuration for Trial trainable.
new_experiment_tag (str): New experiment name for trial.
Returns:
True if `reset_config` is successful else False.
@@ -633,7 +633,7 @@ class RayTrialExecutor(TrialExecutor):
self._running[value] = trial
return checkpoint
def restore(self, trial, checkpoint=None):
def restore(self, trial, checkpoint=None, block=False):
"""Restores training state from a given model checkpoint.
Args:
@@ -641,6 +641,7 @@ class RayTrialExecutor(TrialExecutor):
checkpoint (Checkpoint): The checkpoint to restore from. If None,
the most recent PERSISTENT checkpoint is used. Defaults to
None.
block (bool): Whether or not to block on restore before returning.
Raises:
RuntimeError: This error is raised if no runner is found.
@@ -680,8 +681,12 @@ class RayTrialExecutor(TrialExecutor):
"restoration. Pass in an `upload_dir` and a Trainable "
"extending `DurableTrainable` for remote storage-based "
"restoration")
self._running[remote] = trial
trial.restoring_from = checkpoint
if block:
ray.get(remote)
else:
self._running[remote] = trial
trial.restoring_from = checkpoint
def export_trial_if_needed(self, trial):
"""Exports model of this trial based on trial.export_formats.
+9 -9
View File
@@ -270,7 +270,6 @@ class PopulationBasedTraining(FIFOScheduler):
For each step, logs: [target trial tag, clone trial tag, target trial
iteration, clone trial iteration, old config, new config].
"""
trial_name, trial_to_clone_name = (trial_state.orig_tag,
new_state.orig_tag)
@@ -301,9 +300,7 @@ class PopulationBasedTraining(FIFOScheduler):
"""Transfers perturbed state from trial_to_clone -> trial.
If specified, also logs the updated hyperparam state.
"""
trial_state = self._trial_state[trial]
new_state = self._trial_state[trial_to_clone]
if not new_state.last_checkpoint:
@@ -326,13 +323,20 @@ class PopulationBasedTraining(FIFOScheduler):
self._hyperparam_mutations)
reset_successful = trial_executor.reset_trial(trial, new_config,
new_tag)
# TODO(ujvl): Refactor Scheduler abstraction to abstract
# mechanism for trial restart away. We block on restore
# and suppress train on start as a stop-gap fix to
# https://github.com/ray-project/ray/issues/7258.
if reset_successful:
trial_executor.restore(trial, new_state.last_checkpoint)
trial_executor.restore(
trial, new_state.last_checkpoint, block=True)
else:
trial_executor.stop_trial(trial, stop_logger=False)
trial.config = new_config
trial.experiment_tag = new_tag
trial_executor.start_trial(trial, new_state.last_checkpoint)
trial_executor.start_trial(
trial, new_state.last_checkpoint, train=False)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
@@ -342,9 +346,7 @@ class PopulationBasedTraining(FIFOScheduler):
"""Returns trials in the lower and upper `quantile` of the population.
If there is not enough data to compute this, returns empty lists.
"""
trials = []
for trial, state in self._trial_state.items():
if state.last_score is not None and not trial.is_finished():
@@ -366,9 +368,7 @@ class PopulationBasedTraining(FIFOScheduler):
This enables the PBT scheduler to support a greater number of
concurrent trials than can fit in the cluster at any given time.
"""
candidates = []
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
+103 -2
View File
@@ -9,6 +9,7 @@ import shutil
from unittest.mock import MagicMock
import ray
from ray import tune
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
@@ -186,7 +187,7 @@ class EarlyStoppingSuite(unittest.TestCase):
class _MockTrialExecutor(TrialExecutor):
def start_trial(self, trial, checkpoint_obj=None):
def start_trial(self, trial, checkpoint_obj=None, train=True):
trial.logger_running = True
trial.restored_checkpoint = checkpoint_obj.value
trial.status = Trial.RUNNING
@@ -196,7 +197,7 @@ class _MockTrialExecutor(TrialExecutor):
if stop_logger:
trial.logger_running = False
def restore(self, trial, checkpoint=None):
def restore(self, trial, checkpoint=None, block=False):
pass
def save(self, trial, type=Checkpoint.PERSISTENT, result=None):
@@ -1102,6 +1103,106 @@ class PopulationBasedTestingSuite(unittest.TestCase):
shutil.rmtree(tmpdir)
class E2EPopulationBasedTestingSuite(unittest.TestCase):
    """End-to-end checks that run PopulationBasedTraining through `tune.run`.

    Each test drives three samples of a small Trainable under a PBT
    scheduler and verifies that every trial terminates and ends up with a
    checkpoint.
    """

    def setUp(self):
        """Start Ray with four CPUs before each test."""
        ray.init(num_cpus=4)

    def tearDown(self):
        """Shut Ray down and restore the registry after each test."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def basicSetup(self,
                   resample_prob=0.0,
                   explore=None,
                   perturbation_interval=10,
                   log_config=False,
                   hyperparams=None,
                   hyperparam_mutations=None,
                   step_once=True):
        """Build the PopulationBasedTraining scheduler used by the tests.

        Args:
            resample_prob (float): Forwarded as `resample_probability`.
            explore (callable): Forwarded as `custom_explore_fn`.
            perturbation_interval (int): Forwarded unchanged.
            log_config (bool): Forwarded unchanged.
            hyperparams: Accepted but not used by this method.
            hyperparam_mutations (dict): Mutation spec; a default spec is
                substituted when this is falsy.
            step_once: Accepted but not used by this method.

        Returns:
            The configured PopulationBasedTraining instance.
        """
        if not hyperparam_mutations:
            hyperparam_mutations = {
                "float_factor": lambda: 100.0,
                "int_factor": lambda: 10,
                "id_factor": [100]
            }
        return PopulationBasedTraining(
            metric="mean_accuracy",
            time_attr="training_iteration",
            perturbation_interval=perturbation_interval,
            resample_probability=resample_prob,
            quantile_fraction=0.25,
            hyperparam_mutations=hyperparam_mutations,
            custom_explore_fn=explore,
            log_config=log_config)

    def _run_and_verify(self, trainable_cls, scheduler):
        """Run `trainable_cls` under `scheduler` and check every trial.

        Uses three samples, a checkpoint every three iterations and a stop
        condition of thirty training iterations, then asserts that each
        trial is TERMINATED and has a checkpoint.
        """
        base_config = {
            "float_factor": 2.0,
            "const_factor": 3,
            "int_factor": 10,
            "id_factor": 0
        }
        outcome = tune.run(
            trainable_cls,
            num_samples=3,
            scheduler=scheduler,
            checkpoint_freq=3,
            config=base_config,
            stop={"training_iteration": 30})
        for finished in outcome.trials:
            self.assertEqual(finished.status, Trial.TERMINATED)
            self.assertTrue(finished.has_checkpoint())

    def testCheckpointing(self):
        # Trainable whose `_save` writes a file and returns its path.
        scheduler = self.basicSetup(perturbation_interval=2)

        class train(tune.Trainable):
            def _train(self):
                return {"mean_accuracy": self.training_iteration}

            def _save(self, path):
                target = path + "/checkpoint"
                with open(target, "w") as handle:
                    handle.write("OK")
                return target

        self._run_and_verify(train, scheduler)

    def testCheckpointDict(self):
        # Trainable whose `_save` returns an in-memory dict instead of a path.
        scheduler = self.basicSetup(perturbation_interval=2)

        class train_dict(tune.Trainable):
            def _setup(self, config):
                self.state = {"hi": 1}

            def _train(self):
                return {"mean_accuracy": self.training_iteration}

            def _save(self, path):
                return self.state

            def _restore(self, state):
                self.state = state

        self._run_and_verify(train_dict, scheduler)
class AsyncHyperBandSuite(unittest.TestCase):
def setUp(self):
ray.init()
+11 -1
View File
@@ -186,7 +186,7 @@ class Trainable:
def default_resource_request(cls, config):
"""Provides a static resource requirement for the given configuration.
This can be overriden by sub-classes to set the correct trial resource
This can be overridden by sub-classes to set the correct trial resource
allocation, so the user does not need to.
.. code-block:: python
@@ -555,6 +555,16 @@ class Trainable:
"""
return self._iteration
@property
def training_iteration(self):
    """Return the current training iteration.

    Equivalent to `self.iteration`. The counter is incremented
    automatically on each `train()` call, and the same value is
    automatically placed into the training result dict.
    """
    return self._iteration
def get_config(self):
"""Returns configuration passed in by Tune."""
return self.config
+2 -3
View File
@@ -108,7 +108,7 @@ class TrialInfo:
"""Serializable struct for holding information for a Trial.
Attributes:
trial_name (str): String name of the currernt trial.
trial_name (str): String name of the current trial.
trial_id (str): trial_id of the trial
"""
@@ -191,8 +191,7 @@ class Trial:
self.evaluated_params = evaluated_params or {}
self.experiment_tag = experiment_tag
trainable_cls = self.get_trainable_cls()
if trainable_cls and hasattr(trainable_cls,
"default_resource_request"):
if trainable_cls:
default_resources = trainable_cls.default_resource_request(
self.config)
if default_resources:
+4 -2
View File
@@ -72,13 +72,14 @@ class TrialExecutor:
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"has_resources() method")
def start_trial(self, trial, checkpoint=None):
def start_trial(self, trial, checkpoint=None, train=True):
"""Starts the trial restoring from checkpoint if checkpoint is provided.
Args:
trial (Trial): Trial to be started.
checkpoint (Checkpoint): A Python object or path storing the state
of trial.
train (bool): Whether or not to start training.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"start_trial() method")
@@ -211,7 +212,7 @@ class TrialExecutor:
"""Returns a string describing the total resources available."""
raise NotImplementedError
def restore(self, trial, checkpoint=None):
def restore(self, trial, checkpoint=None, block=False):
"""Restores training state from a checkpoint.
If checkpoint is None, try to restore from trial.checkpoint.
@@ -220,6 +221,7 @@ class TrialExecutor:
Args:
trial (Trial): Trial to be restored.
checkpoint (Checkpoint): Checkpoint to restore from.
block (bool): Whether or not to block on restore before returning.
Returns:
False if error occurred, otherwise return True.