Mirror of https://github.com/wassname/ray.git (last synced 2026-06-29 09:22:57 +08:00)
[tune] Stop-gap fix for PBT checkpointing (#7794)
* Fix PBT
* lint
* reset
* rm
* tests

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.syncer import get_cloud_sync_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DurableTrainable(Trainable):
|
||||
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
|
||||
@@ -57,7 +60,6 @@ class DurableTrainable(Trainable):
|
||||
if checkpoint_dir.starts_with(os.path.abspath(self.logdir)):
|
||||
raise ValueError("`checkpoint_dir` must be `self.logdir`, or "
|
||||
"a sub-directory.")
|
||||
|
||||
checkpoint_path = super(DurableTrainable, self).save(checkpoint_dir)
|
||||
self.storage_client.sync_up(self.logdir, self.remote_checkpoint_dir)
|
||||
self.storage_client.wait()
|
||||
@@ -81,9 +83,15 @@ class DurableTrainable(Trainable):
|
||||
Args:
|
||||
checkpoint_path (str): Local path to checkpoint.
|
||||
"""
|
||||
try:
|
||||
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
|
||||
except FileNotFoundError:
|
||||
logger.warning(
|
||||
"Trial %s: checkpoint path not found during "
|
||||
"garbage collection. See issue #6697.", self.trial_id)
|
||||
else:
|
||||
self.storage_client.delete(self._storage_path(local_dirpath))
|
||||
super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
|
||||
local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
|
||||
self.storage_client.delete(self._storage_path(local_dirpath))
|
||||
|
||||
def _create_storage_client(self):
|
||||
"""Returns a storage client."""
|
||||
|
||||
@@ -108,6 +108,7 @@ if __name__ == "__main__":
|
||||
name="pbt_test",
|
||||
scheduler=pbt,
|
||||
reuse_actors=True,
|
||||
checkpoint_freq=20,
|
||||
verbose=False,
|
||||
stop={
|
||||
"training_iteration": 200,
|
||||
|
||||
@@ -213,7 +213,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trial_item = self._find_item(self._running, trial)
|
||||
assert len(trial_item) < 2, trial_item
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None, runner=None):
|
||||
def _start_trial(self, trial, checkpoint=None, runner=None, train=True):
|
||||
"""Starts trial and restores last result if trial was paused.
|
||||
|
||||
Args:
|
||||
@@ -223,6 +223,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
from the beginning.
|
||||
runner (Trainable): The remote runner to use. This can be the
|
||||
cached actor. If None, a new runner is created.
|
||||
train (bool): Whether or not to start training.
|
||||
|
||||
See `RayTrialExecutor.restore` for possible errors raised.
|
||||
"""
|
||||
@@ -239,7 +240,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
# If Trial was in flight when paused, self._paused stores result.
|
||||
self._paused.pop(previous_run[0])
|
||||
self._running[previous_run[0]] = trial
|
||||
elif not trial.is_restoring:
|
||||
elif train and not trial.is_restoring:
|
||||
self._train(trial)
|
||||
|
||||
def _stop_trial(self, trial, error=False, error_msg=None,
|
||||
@@ -278,7 +279,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
finally:
|
||||
trial.set_runner(None)
|
||||
|
||||
def start_trial(self, trial, checkpoint=None):
|
||||
def start_trial(self, trial, checkpoint=None, train=True):
|
||||
"""Starts the trial.
|
||||
|
||||
Will not return resources if trial repeatedly fails on start.
|
||||
@@ -287,10 +288,11 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trial (Trial): Trial to be started.
|
||||
checkpoint (Checkpoint): A Python object or path storing the state
|
||||
of trial.
|
||||
train (bool): Whether or not to start training.
|
||||
"""
|
||||
self._commit_resources(trial.resources)
|
||||
try:
|
||||
self._start_trial(trial, checkpoint)
|
||||
self._start_trial(trial, checkpoint, train=train)
|
||||
except AbortTrialExecution:
|
||||
logger.exception("Trial %s: Error starting runner, aborting!",
|
||||
trial)
|
||||
@@ -342,10 +344,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to be reset.
|
||||
new_config (dict): New configuration for Trial
|
||||
trainable.
|
||||
new_experiment_tag (str): New experiment name
|
||||
for trial.
|
||||
new_config (dict): New configuration for Trial trainable.
|
||||
new_experiment_tag (str): New experiment name for trial.
|
||||
|
||||
Returns:
|
||||
True if `reset_config` is successful else False.
|
||||
@@ -633,7 +633,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._running[value] = trial
|
||||
return checkpoint
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
def restore(self, trial, checkpoint=None, block=False):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
Args:
|
||||
@@ -641,6 +641,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
checkpoint (Checkpoint): The checkpoint to restore from. If None,
|
||||
the most recent PERSISTENT checkpoint is used. Defaults to
|
||||
None.
|
||||
block (bool): Whether or not to block on restore before returning.
|
||||
|
||||
Raises:
|
||||
RuntimeError: This error is raised if no runner is found.
|
||||
@@ -680,8 +681,12 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"restoration. Pass in an `upload_dir` and a Trainable "
|
||||
"extending `DurableTrainable` for remote storage-based "
|
||||
"restoration")
|
||||
self._running[remote] = trial
|
||||
trial.restoring_from = checkpoint
|
||||
|
||||
if block:
|
||||
ray.get(remote)
|
||||
else:
|
||||
self._running[remote] = trial
|
||||
trial.restoring_from = checkpoint
|
||||
|
||||
def export_trial_if_needed(self, trial):
|
||||
"""Exports model of this trial based on trial.export_formats.
|
||||
|
||||
@@ -270,7 +270,6 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
For each step, logs: [target trial tag, clone trial tag, target trial
|
||||
iteration, clone trial iteration, old config, new config].
|
||||
|
||||
"""
|
||||
trial_name, trial_to_clone_name = (trial_state.orig_tag,
|
||||
new_state.orig_tag)
|
||||
@@ -301,9 +300,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial.
|
||||
|
||||
If specified, also logs the updated hyperparam state.
|
||||
|
||||
"""
|
||||
|
||||
trial_state = self._trial_state[trial]
|
||||
new_state = self._trial_state[trial_to_clone]
|
||||
if not new_state.last_checkpoint:
|
||||
@@ -326,13 +323,20 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._hyperparam_mutations)
|
||||
reset_successful = trial_executor.reset_trial(trial, new_config,
|
||||
new_tag)
|
||||
|
||||
# TODO(ujvl): Refactor Scheduler abstraction to abstract
|
||||
# mechanism for trial restart away. We block on restore
|
||||
# and suppress train on start as a stop-gap fix to
|
||||
# https://github.com/ray-project/ray/issues/7258.
|
||||
if reset_successful:
|
||||
trial_executor.restore(trial, new_state.last_checkpoint)
|
||||
trial_executor.restore(
|
||||
trial, new_state.last_checkpoint, block=True)
|
||||
else:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial_executor.start_trial(trial, new_state.last_checkpoint)
|
||||
trial_executor.start_trial(
|
||||
trial, new_state.last_checkpoint, train=False)
|
||||
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
@@ -342,9 +346,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"""Returns trials in the lower and upper `quantile` of the population.
|
||||
|
||||
If there is not enough data to compute this, returns empty lists.
|
||||
|
||||
"""
|
||||
|
||||
trials = []
|
||||
for trial, state in self._trial_state.items():
|
||||
if state.last_score is not None and not trial.is_finished():
|
||||
@@ -366,9 +368,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
This enables the PBT scheduler to support a greater number of
|
||||
concurrent trials than can fit in the cluster at any given time.
|
||||
|
||||
"""
|
||||
|
||||
candidates = []
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
|
||||
|
||||
@@ -9,6 +9,7 @@ import shutil
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
@@ -186,7 +187,7 @@ class EarlyStoppingSuite(unittest.TestCase):
|
||||
|
||||
|
||||
class _MockTrialExecutor(TrialExecutor):
|
||||
def start_trial(self, trial, checkpoint_obj=None):
|
||||
def start_trial(self, trial, checkpoint_obj=None, train=True):
|
||||
trial.logger_running = True
|
||||
trial.restored_checkpoint = checkpoint_obj.value
|
||||
trial.status = Trial.RUNNING
|
||||
@@ -196,7 +197,7 @@ class _MockTrialExecutor(TrialExecutor):
|
||||
if stop_logger:
|
||||
trial.logger_running = False
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
def restore(self, trial, checkpoint=None, block=False):
|
||||
pass
|
||||
|
||||
def save(self, trial, type=Checkpoint.PERSISTENT, result=None):
|
||||
@@ -1102,6 +1103,106 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
class E2EPopulationBasedTestingSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def basicSetup(self,
|
||||
resample_prob=0.0,
|
||||
explore=None,
|
||||
perturbation_interval=10,
|
||||
log_config=False,
|
||||
hyperparams=None,
|
||||
hyperparam_mutations=None,
|
||||
step_once=True):
|
||||
hyperparam_mutations = hyperparam_mutations or {
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
"id_factor": [100]
|
||||
}
|
||||
pbt = PopulationBasedTraining(
|
||||
metric="mean_accuracy",
|
||||
time_attr="training_iteration",
|
||||
perturbation_interval=perturbation_interval,
|
||||
resample_probability=resample_prob,
|
||||
quantile_fraction=0.25,
|
||||
hyperparam_mutations=hyperparam_mutations,
|
||||
custom_explore_fn=explore,
|
||||
log_config=log_config)
|
||||
return pbt
|
||||
|
||||
def testCheckpointing(self):
|
||||
pbt = self.basicSetup(perturbation_interval=2)
|
||||
|
||||
class train(tune.Trainable):
|
||||
def _train(self):
|
||||
return {"mean_accuracy": self.training_iteration}
|
||||
|
||||
def _save(self, path):
|
||||
checkpoint = path + "/checkpoint"
|
||||
with open(checkpoint, "w") as f:
|
||||
f.write("OK")
|
||||
return checkpoint
|
||||
|
||||
trial_hyperparams = {
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
"int_factor": 10,
|
||||
"id_factor": 0
|
||||
}
|
||||
|
||||
analysis = tune.run(
|
||||
train,
|
||||
num_samples=3,
|
||||
scheduler=pbt,
|
||||
checkpoint_freq=3,
|
||||
config=trial_hyperparams,
|
||||
stop={"training_iteration": 30})
|
||||
|
||||
for trial in analysis.trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
def testCheckpointDict(self):
|
||||
pbt = self.basicSetup(perturbation_interval=2)
|
||||
|
||||
class train_dict(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
self.state = {"hi": 1}
|
||||
|
||||
def _train(self):
|
||||
return {"mean_accuracy": self.training_iteration}
|
||||
|
||||
def _save(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
self.state = state
|
||||
|
||||
trial_hyperparams = {
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
"int_factor": 10,
|
||||
"id_factor": 0
|
||||
}
|
||||
|
||||
analysis = tune.run(
|
||||
train_dict,
|
||||
num_samples=3,
|
||||
scheduler=pbt,
|
||||
checkpoint_freq=3,
|
||||
config=trial_hyperparams,
|
||||
stop={"training_iteration": 30})
|
||||
|
||||
for trial in analysis.trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
|
||||
class AsyncHyperBandSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
@@ -186,7 +186,7 @@ class Trainable:
|
||||
def default_resource_request(cls, config):
|
||||
"""Provides a static resource requirement for the given configuration.
|
||||
|
||||
This can be overriden by sub-classes to set the correct trial resource
|
||||
This can be overridden by sub-classes to set the correct trial resource
|
||||
allocation, so the user does not need to.
|
||||
|
||||
.. code-block:: python
|
||||
@@ -555,6 +555,16 @@ class Trainable:
|
||||
"""
|
||||
return self._iteration
|
||||
|
||||
@property
|
||||
def training_iteration(self):
|
||||
"""Current training iteration (same as `self.iteration`).
|
||||
|
||||
This value is automatically incremented every time `train()` is called
|
||||
and is automatically inserted into the training result dict.
|
||||
|
||||
"""
|
||||
return self._iteration
|
||||
|
||||
def get_config(self):
|
||||
"""Returns configuration passed in by Tune."""
|
||||
return self.config
|
||||
|
||||
@@ -108,7 +108,7 @@ class TrialInfo:
|
||||
"""Serializable struct for holding information for a Trial.
|
||||
|
||||
Attributes:
|
||||
trial_name (str): String name of the currernt trial.
|
||||
trial_name (str): String name of the current trial.
|
||||
trial_id (str): trial_id of the trial
|
||||
"""
|
||||
|
||||
@@ -191,8 +191,7 @@ class Trial:
|
||||
self.evaluated_params = evaluated_params or {}
|
||||
self.experiment_tag = experiment_tag
|
||||
trainable_cls = self.get_trainable_cls()
|
||||
if trainable_cls and hasattr(trainable_cls,
|
||||
"default_resource_request"):
|
||||
if trainable_cls:
|
||||
default_resources = trainable_cls.default_resource_request(
|
||||
self.config)
|
||||
if default_resources:
|
||||
|
||||
@@ -72,13 +72,14 @@ class TrialExecutor:
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"has_resources() method")
|
||||
|
||||
def start_trial(self, trial, checkpoint=None):
|
||||
def start_trial(self, trial, checkpoint=None, train=True):
|
||||
"""Starts the trial restoring from checkpoint if checkpoint is provided.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to be started.
|
||||
checkpoint (Checkpoint): A Python object or path storing the state
|
||||
of trial.
|
||||
train (bool): Whether or not to start training.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"start_trial() method")
|
||||
@@ -211,7 +212,7 @@ class TrialExecutor:
|
||||
"""Returns a string describing the total resources available."""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
def restore(self, trial, checkpoint=None, block=False):
|
||||
"""Restores training state from a checkpoint.
|
||||
|
||||
If checkpoint is None, try to restore from trial.checkpoint.
|
||||
@@ -220,6 +221,7 @@ class TrialExecutor:
|
||||
Args:
|
||||
trial (Trial): Trial to be restored.
|
||||
checkpoint (Checkpoint): Checkpoint to restore from.
|
||||
block (bool): Whether or not to block on restore before returning.
|
||||
|
||||
Returns:
|
||||
False if error occurred, otherwise return True.
|
||||
|
||||
Reference in New Issue
Block a user