diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 16e49b788..539a857bc 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -104,6 +104,8 @@ def run(args, parser): args.experiment_name: { # i.e. log to ~/ray_results/default "run": args.run, "checkpoint_freq": args.checkpoint_freq, + "keep_checkpoints_num": args.keep_checkpoints_num, + "checkpoint_score_attr": args.checkpoint_score_attr, "local_dir": args.local_dir, "resources_per_trial": ( args.resources_per_trial and diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 667fa0061..864ed6402 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -103,6 +103,20 @@ def make_parser(parser_creator=None, **kwargs): action="store_true", help="Whether to checkpoint at the end of the experiment. " "Default is False.") + parser.add_argument( + "--keep-checkpoints-num", + default=None, + type=int, + help="Number of last checkpoints to keep. Others get " + "deleted. Default (None) keeps all checkpoints.") + parser.add_argument( + "--checkpoint-score-attr", + default="training_iteration", + type=str, + help="Specifies by which attribute to rank the best checkpoint. " + "Default is increasing order. If attribute starts with min- it " + "will rank attribute in decreasing order. Example: " + "min-validation_loss") parser.add_argument( "--export-formats", default=None, @@ -143,6 +157,8 @@ def to_argv(config): for k, v in config.items(): if "-" in k: raise ValueError("Use '_' instead of '-' in `{}`".format(k)) + if v is None: + continue if not isinstance(v, bool) or v: # for argparse flags argv.append("--{}".format(k.replace("_", "-"))) if isinstance(v, string_types): @@ -188,6 +204,8 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): stopping_criterion=spec.get("stop", {}), checkpoint_freq=args.checkpoint_freq, checkpoint_at_end=args.checkpoint_at_end, + keep_checkpoints_num=args.keep_checkpoints_num, + checkpoint_score_attr=args.checkpoint_score_attr, export_formats=spec.get("export_formats", []), # str(None) doesn't create None restore_path=spec.get("restore"), diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index d6bf12cc0..5f3e46aab 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -71,6 +71,8 @@ class Experiment(object): sync_function=None, checkpoint_freq=0, checkpoint_at_end=False, + keep_checkpoints_num=None, + checkpoint_score_attr=None, export_formats=None, max_failures=3, restore=None, @@ -102,6 +104,8 @@ class Experiment(object): "sync_function": sync_function, "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, + "keep_checkpoints_num": keep_checkpoints_num, + "checkpoint_score_attr": checkpoint_score_attr, "export_formats": export_formats or [], "max_failures": max_failures, "restore": restore diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 326398b3a..f0f32e6fb 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -4,6 +4,7 @@ from __future__ import division from __future__ import print_function import logging +import math import os import random import time @@ -469,10 +470,44 @@ class RayTrialExecutor(TrialExecutor): if storage == Checkpoint.MEMORY: trial._checkpoint.value = trial.runner.save_to_object.remote() else: - with warn_if_slow("save_to_disk"): - trial._checkpoint.value = ray.get(trial.runner.save.remote()) + # Keeps only highest performing checkpoints if enabled + if trial.keep_checkpoints_num: + try: + last_attr_val = trial.last_result[ + trial.checkpoint_score_attr] + if (trial.compare_checkpoints(last_attr_val) + and not math.isnan(last_attr_val)): + trial.best_checkpoint_attr_value = last_attr_val + self._checkpoint_and_erase(trial) + except KeyError: + logger.warning( + "Result dict has no key: {}. keep" + "_checkpoints_num flag will not work".format( + trial.checkpoint_score_attr)) + else: + with warn_if_slow("save_to_disk"): + trial._checkpoint.value = ray.get( + trial.runner.save.remote()) + return trial._checkpoint.value + def _checkpoint_and_erase(self, trial): + """Checkpoints the model and erases old checkpoints + if needed. + Parameters + ---------- + trial : trial to save + """ + + with warn_if_slow("save_to_disk"): + trial._checkpoint.value = ray.get(trial.runner.save.remote()) + + if len(trial.history) >= trial.keep_checkpoints_num: + ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1])) + trial.history.pop() + + trial.history.insert(0, trial._checkpoint.value) + def restore(self, trial, checkpoint=None): """Restores training state from a given model checkpoint. diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 31d766af3..c10934896 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -210,6 +210,17 @@ class Trainable(object): return result + def delete_checkpoint(self, checkpoint_dir): + """Removes subdirectory within checkpoint_folder + Parameters + ---------- + checkpoint_dir : path to checkpoint + """ + if os.path.isfile(checkpoint_dir): + shutil.rmtree(os.path.dirname(checkpoint_dir)) + else: + shutil.rmtree(checkpoint_dir) + def save(self, checkpoint_dir=None): """Saves the current model state to a checkpoint. diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index bb2c3ae50..83d7fd05e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -253,6 +253,8 @@ class Trial(object): stopping_criterion=None, checkpoint_freq=0, checkpoint_at_end=False, + keep_checkpoints_num=None, + checkpoint_score_attr="", export_formats=None, restore_path=None, upload_dir=None, @@ -288,6 +290,16 @@ class Trial(object): self.last_update_time = -float("inf") self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end + + self.history = [] + self.keep_checkpoints_num = keep_checkpoints_num + self._cmp_greater = not checkpoint_score_attr.startswith("min-") + self.best_checkpoint_attr_value = -float("inf") \ + if self._cmp_greater else float("inf") + # Strip off "min-" from checkpoint attribute + self.checkpoint_score_attr = checkpoint_score_attr \ + if self._cmp_greater else checkpoint_score_attr[4:] + self._checkpoint = Checkpoint( storage=Checkpoint.DISK, value=restore_path) self.export_formats = export_formats @@ -299,7 +311,6 @@ class Trial(object): self.trial_id = Trial.generate_id() if trial_id is None else trial_id self.error_file = None self.num_failures = 0 - self.custom_trial_name = None # AutoML fields @@ -495,6 +506,28 @@ class Trial(object): self.last_update_time = time.time() self.result_logger.on_result(self.last_result) + def compare_checkpoints(self, attr_mean): + """Compares two checkpoints based on the attribute attr_mean param. + Greater than is used by default. If command-line parameter + checkpoint_score_attr starts with "min-" less than is used. + + Arguments: + attr_mean: mean of attribute value for the current checkpoint + + Returns: + True: when attr_mean is greater than previous checkpoint attr_mean + and greater than function is selected + when attr_mean is less than previous checkpoint attr_mean and + less than function is selected + False: when attr_mean is not in alignment with selected cmp fn + """ + if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value: + return True + elif (not self._cmp_greater + and attr_mean < self.best_checkpoint_attr_value): + return True + return False + def _get_trainable_cls(self): return ray.tune.registry._global_registry.get( ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)