mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 10:46:13 +08:00
[tune] Fixed bug in PBT where initial trial result is empty. (#6351)
* Fixed bug in Tune PBT where the initial trial result is empty.
* Updated mock trial executor in test suite.
* Added comment.
This commit is contained in:
@@ -552,15 +552,17 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"""Before step() called, update the available resources."""
|
||||
self._update_avail_resources()
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK):
|
||||
def save(self, trial, storage=Checkpoint.DISK, result=None):
|
||||
"""Saves the trial's state to a checkpoint."""
|
||||
result = result or trial.last_result
|
||||
|
||||
if storage == Checkpoint.MEMORY:
|
||||
value = trial.runner.save_to_object.remote()
|
||||
checkpoint = Checkpoint(storage, value, trial.last_result)
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
else:
|
||||
with warn_if_slow("save_checkpoint_to_disk"):
|
||||
value = ray.get(trial.runner.save.remote())
|
||||
checkpoint = Checkpoint(storage, value, trial.last_result)
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
|
||||
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
|
||||
try:
|
||||
|
||||
@@ -240,8 +240,10 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
|
||||
if trial in upper_quantile:
|
||||
# The trial last result is only updated after the scheduler
|
||||
# callback. So, we override with the current result.
|
||||
state.last_checkpoint = trial_runner.trial_executor.save(
|
||||
trial, Checkpoint.MEMORY)
|
||||
trial, Checkpoint.MEMORY, result=result)
|
||||
self._num_checkpoints += 1
|
||||
else:
|
||||
state.last_checkpoint = None # not a top trial
|
||||
|
||||
@@ -207,7 +207,7 @@ class _MockTrialExecutor(TrialExecutor):
|
||||
def restore(self, trial, checkpoint=None):
|
||||
pass
|
||||
|
||||
def save(self, trial, type=Checkpoint.DISK):
|
||||
def save(self, trial, type=Checkpoint.DISK, result=None):
|
||||
return trial.trainable_name
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||
|
||||
@@ -226,12 +226,15 @@ class TrialExecutor(object):
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"restore() method")
|
||||
|
||||
def save(self, trial, storage=Checkpoint.DISK):
|
||||
def save(self, trial, storage=Checkpoint.DISK, result=None):
|
||||
"""Saves training state of this trial to a checkpoint.
|
||||
|
||||
If result is None, this trial's last result will be used.
|
||||
|
||||
Args:
|
||||
trial (Trial): The state of this trial to be saved.
|
||||
storage (str): Where to store the checkpoint. Defaults to DISK.
|
||||
result (dict): The state of this trial as a dictionary to be saved.
|
||||
|
||||
Return:
|
||||
A Python object if storage==Checkpoint.MEMORY otherwise
|
||||
|
||||
Reference in New Issue
Block a user