mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 14:31:32 +08:00
[tune/rllib] Add checkpoint eraser (#4490)
This commit is contained in:
committed by
Eric Liang
parent
7746d20d30
commit
820c71b7d0
@@ -104,6 +104,8 @@ def run(args, parser):
|
||||
args.experiment_name: { # i.e. log to ~/ray_results/default
|
||||
"run": args.run,
|
||||
"checkpoint_freq": args.checkpoint_freq,
|
||||
"keep_checkpoints_num": args.keep_checkpoints_num,
|
||||
"checkpoint_score_attr": args.checkpoint_score_attr,
|
||||
"local_dir": args.local_dir,
|
||||
"resources_per_trial": (
|
||||
args.resources_per_trial and
|
||||
|
||||
@@ -103,6 +103,20 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
action="store_true",
|
||||
help="Whether to checkpoint at the end of the experiment. "
|
||||
"Default is False.")
|
||||
parser.add_argument(
|
||||
"--keep-checkpoints-num",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Number of last checkpoints to keep. Others get "
|
||||
"deleted. Default (None) keeps all checkpoints.")
|
||||
parser.add_argument(
|
||||
"--checkpoint-score-attr",
|
||||
default="training_iteration",
|
||||
type=str,
|
||||
help="Specifies by which attribute to rank the best checkpoint. "
|
||||
"Default is increasing order. If attribute starts with min- it "
|
||||
"will rank attribute in decreasing order. Example: "
|
||||
"min-validation_loss")
|
||||
parser.add_argument(
|
||||
"--export-formats",
|
||||
default=None,
|
||||
@@ -143,6 +157,8 @@ def to_argv(config):
|
||||
for k, v in config.items():
|
||||
if "-" in k:
|
||||
raise ValueError("Use '_' instead of '-' in `{}`".format(k))
|
||||
if v is None:
|
||||
continue
|
||||
if not isinstance(v, bool) or v: # for argparse flags
|
||||
argv.append("--{}".format(k.replace("_", "-")))
|
||||
if isinstance(v, string_types):
|
||||
@@ -188,6 +204,8 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
stopping_criterion=spec.get("stop", {}),
|
||||
checkpoint_freq=args.checkpoint_freq,
|
||||
checkpoint_at_end=args.checkpoint_at_end,
|
||||
keep_checkpoints_num=args.keep_checkpoints_num,
|
||||
checkpoint_score_attr=args.checkpoint_score_attr,
|
||||
export_formats=spec.get("export_formats", []),
|
||||
# str(None) doesn't create None
|
||||
restore_path=spec.get("restore"),
|
||||
|
||||
@@ -71,6 +71,8 @@ class Experiment(object):
|
||||
sync_function=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr=None,
|
||||
export_formats=None,
|
||||
max_failures=3,
|
||||
restore=None,
|
||||
@@ -102,6 +104,8 @@ class Experiment(object):
|
||||
"sync_function": sync_function,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
"keep_checkpoints_num": keep_checkpoints_num,
|
||||
"checkpoint_score_attr": checkpoint_score_attr,
|
||||
"export_formats": export_formats or [],
|
||||
"max_failures": max_failures,
|
||||
"restore": restore
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
@@ -469,10 +470,44 @@ class RayTrialExecutor(TrialExecutor):
|
||||
if storage == Checkpoint.MEMORY:
|
||||
trial._checkpoint.value = trial.runner.save_to_object.remote()
|
||||
else:
|
||||
with warn_if_slow("save_to_disk"):
|
||||
trial._checkpoint.value = ray.get(trial.runner.save.remote())
|
||||
# Keeps only highest performing checkpoints if enabled
|
||||
if trial.keep_checkpoints_num:
|
||||
try:
|
||||
last_attr_val = trial.last_result[
|
||||
trial.checkpoint_score_attr]
|
||||
if (trial.compare_checkpoints(last_attr_val)
|
||||
and not math.isnan(last_attr_val)):
|
||||
trial.best_checkpoint_attr_value = last_attr_val
|
||||
self._checkpoint_and_erase(trial)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"Result dict has no key: {}. keep"
|
||||
"_checkpoints_num flag will not work".format(
|
||||
trial.checkpoint_score_attr))
|
||||
else:
|
||||
with warn_if_slow("save_to_disk"):
|
||||
trial._checkpoint.value = ray.get(
|
||||
trial.runner.save.remote())
|
||||
|
||||
return trial._checkpoint.value
|
||||
|
||||
def _checkpoint_and_erase(self, trial):
|
||||
"""Checkpoints the model and erases old checkpoints
|
||||
if needed.
|
||||
Parameters
|
||||
----------
|
||||
trial : trial to save
|
||||
"""
|
||||
|
||||
with warn_if_slow("save_to_disk"):
|
||||
trial._checkpoint.value = ray.get(trial.runner.save.remote())
|
||||
|
||||
if len(trial.history) >= trial.keep_checkpoints_num:
|
||||
ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
|
||||
trial.history.pop()
|
||||
|
||||
trial.history.insert(0, trial._checkpoint.value)
|
||||
|
||||
def restore(self, trial, checkpoint=None):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
|
||||
@@ -210,6 +210,17 @@ class Trainable(object):
|
||||
|
||||
return result
|
||||
|
||||
def delete_checkpoint(self, checkpoint_dir):
|
||||
"""Removes subdirectory within checkpoint_folder
|
||||
Parameters
|
||||
----------
|
||||
checkpoint_dir : path to checkpoint
|
||||
"""
|
||||
if os.path.isfile(checkpoint_dir):
|
||||
shutil.rmtree(os.path.dirname(checkpoint_dir))
|
||||
else:
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
|
||||
def save(self, checkpoint_dir=None):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
|
||||
@@ -253,6 +253,8 @@ class Trial(object):
|
||||
stopping_criterion=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
keep_checkpoints_num=None,
|
||||
checkpoint_score_attr="",
|
||||
export_formats=None,
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
@@ -288,6 +290,16 @@ class Trial(object):
|
||||
self.last_update_time = -float("inf")
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
|
||||
self.history = []
|
||||
self.keep_checkpoints_num = keep_checkpoints_num
|
||||
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
|
||||
self.best_checkpoint_attr_value = -float("inf") \
|
||||
if self._cmp_greater else float("inf")
|
||||
# Strip off "min-" from checkpoint attribute
|
||||
self.checkpoint_score_attr = checkpoint_score_attr \
|
||||
if self._cmp_greater else checkpoint_score_attr[4:]
|
||||
|
||||
self._checkpoint = Checkpoint(
|
||||
storage=Checkpoint.DISK, value=restore_path)
|
||||
self.export_formats = export_formats
|
||||
@@ -299,7 +311,6 @@ class Trial(object):
|
||||
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
|
||||
self.error_file = None
|
||||
self.num_failures = 0
|
||||
|
||||
self.custom_trial_name = None
|
||||
|
||||
# AutoML fields
|
||||
@@ -495,6 +506,28 @@ class Trial(object):
|
||||
self.last_update_time = time.time()
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
def compare_checkpoints(self, attr_mean):
|
||||
"""Compares two checkpoints based on the attribute attr_mean param.
|
||||
Greater than is used by default. If command-line parameter
|
||||
checkpoint_score_attr starts with "min-" less than is used.
|
||||
|
||||
Arguments:
|
||||
attr_mean: mean of attribute value for the current checkpoint
|
||||
|
||||
Returns:
|
||||
True: when attr_mean is greater than previous checkpoint attr_mean
|
||||
and greater than function is selected
|
||||
when attr_mean is less than previous checkpoint attr_mean and
|
||||
less than function is selected
|
||||
False: when attr_mean is not in alignment with selected cmp fn
|
||||
"""
|
||||
if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
|
||||
return True
|
||||
elif (not self._cmp_greater
|
||||
and attr_mean < self.best_checkpoint_attr_value):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_trainable_cls(self):
|
||||
return ray.tune.registry._global_registry.get(
|
||||
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
|
||||
|
||||
Reference in New Issue
Block a user