diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index d0c126227..e19457437 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -49,6 +49,8 @@ def run(run_or_experiment, sync_to_driver=None, checkpoint_freq=0, checkpoint_at_end=False, + keep_checkpoints_num=None, + checkpoint_score_attr=None, global_checkpoint_period=10, export_formats=None, max_failures=3, @@ -114,6 +116,13 @@ def run(run_or_experiment, checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the experiment regardless of the checkpoint_freq. Default is False. + keep_checkpoints_num (int): Number of checkpoints to keep. A value of + `None` keeps all checkpoints. Defaults to `None`. If set, need + to provide `checkpoint_score_attr`. + checkpoint_score_attr (str): Specifies by which attribute to rank the + best checkpoint. Default is increasing order. If attribute starts + with `min-` it will rank attribute in decreasing order, i.e. + `min-validation_loss`. global_checkpoint_period (int): Seconds between global checkpointing. This does not affect `checkpoint_freq`, which specifies frequency for individual trials. @@ -199,6 +208,8 @@ def run(run_or_experiment, loggers=loggers, checkpoint_freq=checkpoint_freq, checkpoint_at_end=checkpoint_at_end, + keep_checkpoints_num=keep_checkpoints_num, + checkpoint_score_attr=checkpoint_score_attr, export_formats=export_formats, max_failures=max_failures, restore=restore,