[tune/rllib] Add checkpoint eraser (#4490)

This commit is contained in:
Dušan Josipović
2019-04-07 05:01:54 +02:00
committed by Eric Liang
parent 7746d20d30
commit 820c71b7d0
6 changed files with 106 additions and 3 deletions
+2
View File
@@ -104,6 +104,8 @@ def run(args, parser):
args.experiment_name: { # i.e. log to ~/ray_results/default
"run": args.run,
"checkpoint_freq": args.checkpoint_freq,
"keep_checkpoints_num": args.keep_checkpoints_num,
"checkpoint_score_attr": args.checkpoint_score_attr,
"local_dir": args.local_dir,
"resources_per_trial": (
args.resources_per_trial and
+18
View File
@@ -103,6 +103,20 @@ def make_parser(parser_creator=None, **kwargs):
action="store_true",
help="Whether to checkpoint at the end of the experiment. "
"Default is False.")
parser.add_argument(
"--keep-checkpoints-num",
default=None,
type=int,
help="Number of last checkpoints to keep. Others get "
"deleted. Default (None) keeps all checkpoints.")
parser.add_argument(
"--checkpoint-score-attr",
default="training_iteration",
type=str,
help="Specifies by which attribute to rank the best checkpoint. "
"Default is increasing order. If attribute starts with min- it "
"will rank attribute in decreasing order. Example: "
"min-validation_loss")
parser.add_argument(
"--export-formats",
default=None,
@@ -143,6 +157,8 @@ def to_argv(config):
for k, v in config.items():
if "-" in k:
raise ValueError("Use '_' instead of '-' in `{}`".format(k))
if v is None:
continue
if not isinstance(v, bool) or v: # for argparse flags
argv.append("--{}".format(k.replace("_", "-")))
if isinstance(v, string_types):
@@ -188,6 +204,8 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
keep_checkpoints_num=args.keep_checkpoints_num,
checkpoint_score_attr=args.checkpoint_score_attr,
export_formats=spec.get("export_formats", []),
# str(None) doesn't create None
restore_path=spec.get("restore"),
+4
View File
@@ -71,6 +71,8 @@ class Experiment(object):
sync_function=None,
checkpoint_freq=0,
checkpoint_at_end=False,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
export_formats=None,
max_failures=3,
restore=None,
@@ -102,6 +104,8 @@ class Experiment(object):
"sync_function": sync_function,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"keep_checkpoints_num": keep_checkpoints_num,
"checkpoint_score_attr": checkpoint_score_attr,
"export_formats": export_formats or [],
"max_failures": max_failures,
"restore": restore
+37 -2
View File
@@ -4,6 +4,7 @@ from __future__ import division
from __future__ import print_function
import logging
import math
import os
import random
import time
@@ -469,10 +470,44 @@ class RayTrialExecutor(TrialExecutor):
if storage == Checkpoint.MEMORY:
trial._checkpoint.value = trial.runner.save_to_object.remote()
else:
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(trial.runner.save.remote())
# Keeps only highest performing checkpoints if enabled
if trial.keep_checkpoints_num:
try:
last_attr_val = trial.last_result[
trial.checkpoint_score_attr]
if (trial.compare_checkpoints(last_attr_val)
and not math.isnan(last_attr_val)):
trial.best_checkpoint_attr_value = last_attr_val
self._checkpoint_and_erase(trial)
except KeyError:
logger.warning(
"Result dict has no key: {}. keep"
"_checkpoints_num flag will not work".format(
trial.checkpoint_score_attr))
else:
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(
trial.runner.save.remote())
return trial._checkpoint.value
def _checkpoint_and_erase(self, trial):
"""Checkpoints the model and erases old checkpoints
if needed.
Parameters
----------
trial : trial to save
"""
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(trial.runner.save.remote())
if len(trial.history) >= trial.keep_checkpoints_num:
ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
trial.history.pop()
trial.history.insert(0, trial._checkpoint.value)
def restore(self, trial, checkpoint=None):
"""Restores training state from a given model checkpoint.
+11
View File
@@ -210,6 +210,17 @@ class Trainable(object):
return result
def delete_checkpoint(self, checkpoint_dir):
"""Removes subdirectory within checkpoint_folder
Parameters
----------
checkpoint_dir : path to checkpoint
"""
if os.path.isfile(checkpoint_dir):
shutil.rmtree(os.path.dirname(checkpoint_dir))
else:
shutil.rmtree(checkpoint_dir)
def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint.
+34 -1
View File
@@ -253,6 +253,8 @@ class Trial(object):
stopping_criterion=None,
checkpoint_freq=0,
checkpoint_at_end=False,
keep_checkpoints_num=None,
checkpoint_score_attr="",
export_formats=None,
restore_path=None,
upload_dir=None,
@@ -288,6 +290,16 @@ class Trial(object):
self.last_update_time = -float("inf")
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end
self.history = []
self.keep_checkpoints_num = keep_checkpoints_num
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
self.best_checkpoint_attr_value = -float("inf") \
if self._cmp_greater else float("inf")
# Strip off "min-" from checkpoint attribute
self.checkpoint_score_attr = checkpoint_score_attr \
if self._cmp_greater else checkpoint_score_attr[4:]
self._checkpoint = Checkpoint(
storage=Checkpoint.DISK, value=restore_path)
self.export_formats = export_formats
@@ -299,7 +311,6 @@ class Trial(object):
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
self.error_file = None
self.num_failures = 0
self.custom_trial_name = None
# AutoML fields
@@ -495,6 +506,28 @@ class Trial(object):
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)
def compare_checkpoints(self, attr_mean):
"""Compares two checkpoints based on the attribute attr_mean param.
Greater than is used by default. If command-line parameter
checkpoint_score_attr starts with "min-" less than is used.
Arguments:
attr_mean: mean of attribute value for the current checkpoint
Returns:
True: when attr_mean is greater than previous checkpoint attr_mean
and greater than function is selected
when attr_mean is less than previous checkpoint attr_mean and
less than function is selected
False: when attr_mean is not in alignment with selected cmp fn
"""
if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
return True
elif (not self._cmp_greater
and attr_mean < self.best_checkpoint_attr_value):
return True
return False
def _get_trainable_cls(self):
return ray.tune.registry._global_registry.get(
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)