diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 756059395..52b8c63f5 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -552,15 +552,17 @@ class RayTrialExecutor(TrialExecutor): """Before step() called, update the available resources.""" self._update_avail_resources() - def save(self, trial, storage=Checkpoint.DISK): + def save(self, trial, storage=Checkpoint.DISK, result=None): """Saves the trial's state to a checkpoint.""" + result = result or trial.last_result + if storage == Checkpoint.MEMORY: value = trial.runner.save_to_object.remote() - checkpoint = Checkpoint(storage, value, trial.last_result) + checkpoint = Checkpoint(storage, value, result) else: with warn_if_slow("save_checkpoint_to_disk"): value = ray.get(trial.runner.save.remote()) - checkpoint = Checkpoint(storage, value, trial.last_result) + checkpoint = Checkpoint(storage, value, result) with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile: try: diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index f863ab21a..7d4046059 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -240,8 +240,10 @@ class PopulationBasedTraining(FIFOScheduler): lower_quantile, upper_quantile = self._quantiles() if trial in upper_quantile: + # The trial's last result is only updated after the scheduler + # callback, so we override it with the current result. 
state.last_checkpoint = trial_runner.trial_executor.save( - trial, Checkpoint.MEMORY) + trial, Checkpoint.MEMORY, result=result) self._num_checkpoints += 1 else: state.last_checkpoint = None # not a top trial diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 4ca60236e..245f0b585 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -207,7 +207,7 @@ class _MockTrialExecutor(TrialExecutor): def restore(self, trial, checkpoint=None): pass - def save(self, trial, type=Checkpoint.DISK): + def save(self, trial, type=Checkpoint.DISK, result=None): return trial.trainable_name def reset_trial(self, trial, new_config, new_experiment_tag): diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index d6bd529a3..d8d35e023 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -226,12 +226,15 @@ class TrialExecutor(object): raise NotImplementedError("Subclasses of TrialExecutor must provide " "restore() method") - def save(self, trial, storage=Checkpoint.DISK): + def save(self, trial, storage=Checkpoint.DISK, result=None): """Saves training state of this trial to a checkpoint. + If result is None, this trial's last result will be used. + Args: trial (Trial): The state of this trial to be saved. storage (str): Where to store the checkpoint. Defaults to DISK. + result (dict): The result dictionary to store with the checkpoint. Return: A Python object if storage==Checkpoint.MEMORY otherwise