diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index a765ca23e..d485dcdd0 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -254,6 +254,20 @@ Additionally, checkpointing can be used to provide fault-tolerance for experimen }, }) +The checkpoint_freq may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end +of a trial, you can additionally set the checkpoint_at_end to True. An example is shown below: + +.. code-block:: python + :emphasize-lines: 5 + + run_experiments({ + "my_experiment_name": { + "run": my_trainable + "checkpoint_freq": 10, + "checkpoint_at_end": True, + "max_failures": 5, + }, + }) Handling Large Datasets ----------------------- diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 91d1c63e7..8fa754ae5 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -112,6 +112,12 @@ def make_parser(parser_creator=None, **kwargs): type=int, help="How many training iterations between checkpoints. " "A value of 0 (default) disables checkpointing.") + parser.add_argument( + "--checkpoint-at-end", + default=False, + type=bool, + help="Whether to checkpoint at the end of the experiment. " + "Default is False.") parser.add_argument( "--max-failures", default=3, @@ -149,6 +155,8 @@ def to_argv(config): argv.append("--{}".format(k.replace("_", "-"))) if isinstance(v, string_types): argv.append(v) + elif isinstance(v, bool): + argv.append(v) else: argv.append(json.dumps(v, cls=_SafeFallbackEncoder)) return argv @@ -186,6 +194,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): # json.load leads to str -> unicode in py2.7 stopping_criterion=spec.get("stop", {}), checkpoint_freq=args.checkpoint_freq, + checkpoint_at_end=args.checkpoint_at_end, # str(None) doesn't create None restore_path=spec.get("restore"), upload_dir=args.upload_dir, diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 80767515d..ed5e742b1 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -43,6 +43,8 @@ class Experiment(object): to (e.g. ``s3://bucket``). checkpoint_freq (int): How many training iterations between checkpoints. A value of 0 (default) disables checkpointing. + checkpoint_at_end (bool): Whether to checkpoint at the end of the + experiment regardless of the checkpoint_freq. Default is False. max_failures (int): Try to recover a trial from its last checkpoint at least this many times. Only applies if checkpointing is enabled. Defaults to 3. @@ -82,6 +84,7 @@ class Experiment(object): local_dir=None, upload_dir="", checkpoint_freq=0, + checkpoint_at_end=False, max_failures=3, restore=None): spec = { @@ -93,6 +96,7 @@ class Experiment(object): "local_dir": local_dir or DEFAULT_RESULTS_DIR, "upload_dir": upload_dir, "checkpoint_freq": checkpoint_freq, + "checkpoint_at_end": checkpoint_at_end, "max_failures": max_failures, "restore": restore } diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index d2930a0b3..f5771f081 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -938,6 +938,26 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) self.addCleanup(os.remove, path) + def testCheckpointingAtEnd(self): + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner(BasicVariantGenerator()) + kwargs = { + "stopping_criterion": { + "training_iteration": 2 + }, + "checkpoint_at_end": True, + "resources": Resources(cpu=1, gpu=1), + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + runner.step() + self.assertEqual(trials[0].last_result[DONE], True) + self.assertEqual(trials[0].has_checkpoint(), True) + def testResultDone(self): """Tests that last_result is marked `done` after trial is complete.""" ray.init(num_cpus=1, num_gpus=1) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index e764369f9..9f677532d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -112,6 +112,7 @@ class Trial(object): resources=None, stopping_criterion=None, checkpoint_freq=0, + checkpoint_at_end=False, restore_path=None, upload_dir=None, max_failures=0): @@ -142,6 +143,7 @@ class Trial(object): # Local trial state that is updated during the run self.last_result = None self.checkpoint_freq = checkpoint_freq + self.checkpoint_at_end = checkpoint_at_end self._checkpoint = Checkpoint( storage=Checkpoint.DISK, value=restore_path) self.status = Trial.PENDING @@ -203,9 +205,12 @@ class Trial(object): return False - def should_checkpoint(self): + def should_checkpoint(self, result): """Whether this trial is due for checkpointing.""" + if result.get(DONE) and self.checkpoint_at_end: + return True + if not self.checkpoint_freq: return False diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index c69b11835..c4ed36afe 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -223,6 +223,7 @@ class TrialRunner(object): self._search_alg.on_trial_complete( trial.trial_id, result=result) decision = TrialScheduler.STOP + else: decision = self._scheduler_alg.on_trial_result( self, trial, result) @@ -234,13 +235,17 @@ class TrialRunner(object): result, terminate=(decision == TrialScheduler.STOP)) if decision == TrialScheduler.CONTINUE: - if trial.should_checkpoint(): + if trial.should_checkpoint(result): # TODO(rliaw): This is a blocking call self.trial_executor.save(trial) self.trial_executor.continue_training(trial) elif decision == TrialScheduler.PAUSE: self.trial_executor.pause_trial(trial) elif decision == TrialScheduler.STOP: + # Checkpoint before ending the trial + # if checkpoint_at_end experiment option is set to True + if trial.should_checkpoint(result): + self.trial_executor.save(trial) self.trial_executor.stop_trial(trial) else: assert False, "Invalid scheduling decision: {}".format(