[tune] Fixed bug in PBT where initial trial result is empty. (#6351)

* Fixed bug in Tune PBT where the initial trial result is empty.

* Updated mock trial executor in test suite.

* Added comment.
This commit is contained in:
visatish
2019-12-06 15:30:27 -08:00
committed by Richard Liaw
parent 53d62d3eec
commit e2ba8c1898
4 changed files with 13 additions and 6 deletions
+5 -3
View File
@@ -552,15 +552,17 @@ class RayTrialExecutor(TrialExecutor):
"""Before step() called, update the available resources."""
self._update_avail_resources()
def save(self, trial, storage=Checkpoint.DISK):
def save(self, trial, storage=Checkpoint.DISK, result=None):
"""Saves the trial's state to a checkpoint."""
result = result or trial.last_result
if storage == Checkpoint.MEMORY:
value = trial.runner.save_to_object.remote()
checkpoint = Checkpoint(storage, value, trial.last_result)
checkpoint = Checkpoint(storage, value, result)
else:
with warn_if_slow("save_checkpoint_to_disk"):
value = ray.get(trial.runner.save.remote())
checkpoint = Checkpoint(storage, value, trial.last_result)
checkpoint = Checkpoint(storage, value, result)
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
try:
+3 -1
View File
@@ -240,8 +240,10 @@ class PopulationBasedTraining(FIFOScheduler):
lower_quantile, upper_quantile = self._quantiles()
if trial in upper_quantile:
# The trial's last result is only updated after the scheduler
# callback, so we override it with the current result.
state.last_checkpoint = trial_runner.trial_executor.save(
trial, Checkpoint.MEMORY)
trial, Checkpoint.MEMORY, result=result)
self._num_checkpoints += 1
else:
state.last_checkpoint = None # not a top trial
@@ -207,7 +207,7 @@ class _MockTrialExecutor(TrialExecutor):
def restore(self, trial, checkpoint=None):
pass
def save(self, trial, type=Checkpoint.DISK):
def save(self, trial, type=Checkpoint.DISK, result=None):
return trial.trainable_name
def reset_trial(self, trial, new_config, new_experiment_tag):
+4 -1
View File
@@ -226,12 +226,15 @@ class TrialExecutor(object):
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"restore() method")
def save(self, trial, storage=Checkpoint.DISK):
def save(self, trial, storage=Checkpoint.DISK, result=None):
"""Saves training state of this trial to a checkpoint.
If result is None, this trial's last result will be used.
Args:
trial (Trial): The state of this trial to be saved.
storage (str): Where to store the checkpoint. Defaults to DISK.
result (dict): The result dictionary of this trial to be saved.
Returns:
A Python object if storage==Checkpoint.MEMORY otherwise