[tune] Fix restoration for function API PBT (#9853)

This commit is contained in:
Richard Liaw
2020-08-03 12:36:17 -07:00
committed by GitHub
parent 323bc23c21
commit b5068d08bf
4 changed files with 394 additions and 38 deletions
+129 -21
View File
@@ -5,6 +5,7 @@ import inspect
import shutil
import threading
import traceback
import uuid
from six.moves import queue
@@ -22,6 +23,84 @@ RESULT_FETCH_TIMEOUT = 0.2
ERROR_REPORT_TIMEOUT = 10
ERROR_FETCH_TIMEOUT = 1
NULL_MARKER = ".null_marker"
TEMP_MARKER = ".temp_marker"
class FuncCheckpointUtil:
"""Utility class holding various function-checkpointing mechanisms.
The two special modes are "null" and "temporary" checkpoints.
*Null Checkpoints*
-------------------
Null checkpoints are generated when a trial is being saved
but a checkpoint has not been created. In this case,
a marker is set, indicating that the checkpoint is null.
When restoring from an null checkpoint, the FunctionRunner
will detect this and *not* restore from any checkpoint at all.
*Temporary Checkpoints*
-----------------------
Temporary checkpoints are generated when a trial is being
restored from a prior in-memory checkpoint. In this case, a marker
will be set indicating that a checkpoint is temporary.
Upon termination of the trial, temporary checkpoints
will be removed. We cannot remove them any earlier because
the loading of checkpoints is non-deterministic.
If "save" is called on a trial whose most recent checkpoint
is temporary, "create_perm_checkpoint" will be called. This
copies the temporary checkpoint to a permanent checkpoint directory.
"""
@staticmethod
def mk_null_checkpoint_dir(logdir):
"""Indicate that the given checkpoint doesn't have state."""
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
logdir, index=-1, override=True)
open(os.path.join(checkpoint_dir, NULL_MARKER), "a").close()
return checkpoint_dir
@staticmethod
def mk_temp_checkpoint_dir(logdir):
"""Indicate that the checkpoint is only for restoration."""
temporary_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
logdir, index="tmp" + uuid.uuid4().hex[:6], override=True)
open(os.path.join(temporary_checkpoint_dir, TEMP_MARKER), "a").close()
return temporary_checkpoint_dir
@staticmethod
def is_temp_checkpoint_dir(checkpoint_dir):
"""Checks for the temp checkpoint marker."""
return os.path.exists(os.path.join(checkpoint_dir, TEMP_MARKER))
@staticmethod
def is_null_checkpoint(checkpoint_dir):
"""Checks for the empty checkpoint marker."""
return os.path.exists(os.path.join(checkpoint_dir, NULL_MARKER))
@staticmethod
def create_perm_checkpoint(checkpoint_dir, logdir, step):
"""Copies temporary checkpoint to a permanent checkpoint directory."""
checkpoint_dir = os.path.abspath(checkpoint_dir)
temporary_marker = os.path.join(checkpoint_dir, TEMP_MARKER)
assert os.path.exists(temporary_marker), (
"Should not be calling this method on a permanent checkpoint.")
os.remove(temporary_marker)
perm_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
logdir, index=step, override=True)
shutil.rmtree(perm_checkpoint_dir)
shutil.copytree(checkpoint_dir, perm_checkpoint_dir)
assert not os.path.exists(
os.path.join(perm_checkpoint_dir, TEMP_MARKER))
return perm_checkpoint_dir
class StatusReporter:
"""Object passed into your function that you can report status through.
@@ -44,7 +123,7 @@ class StatusReporter:
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = {}
self._last_checkpoint = None
self._fresh_checkpoint = False
def __call__(self, **kwargs):
@@ -83,13 +162,18 @@ class StatusReporter:
# resume training.
self._continue_semaphore.acquire()
def make_checkpoint_dir(self, step=None):
def make_checkpoint_dir(self, step):
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=step)
logger.debug("Making checkpoint dir at %s", checkpoint_dir)
return checkpoint_dir
def save_checkpoint(self, checkpoint):
def set_checkpoint(self, checkpoint, is_new=True):
"""Sets the checkpoint to be returned upon get_checkpoint.
If this is a "new" checkpoint, it will notify Tune
(via has_new_checkpoint). Otherwise, it will NOT notify Tune.
"""
if isinstance(checkpoint, str):
try:
TrainableUtil.find_checkpoint_dir(checkpoint)
@@ -98,7 +182,8 @@ class StatusReporter:
"make_checkpoint_dir.")
raise
self._last_checkpoint = checkpoint
self._fresh_checkpoint = True
if is_new:
self._fresh_checkpoint = True
def has_new_checkpoint(self):
return self._fresh_checkpoint
@@ -189,7 +274,7 @@ class FunctionRunner(Trainable):
session.init(self._status_reporter)
self._runner = None
self._restore_tmpdir = None
self.default_checkpoint_dir = None
self.temp_checkpoint_dir = None
def _trainable_func(self):
"""Subclasses can override this to set the trainable func."""
@@ -282,11 +367,6 @@ class FunctionRunner(Trainable):
def execute(self, fn):
return fn(self)
def create_default_checkpoint_dir(self):
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index="default")
return self.default_checkpoint_dir
def save(self, checkpoint_path=None):
if checkpoint_path:
raise ValueError(
@@ -297,12 +377,31 @@ class FunctionRunner(Trainable):
if not checkpoint:
state.update(iteration=0, timesteps_total=0, episodes_total=0)
parent_dir = self.create_default_checkpoint_dir()
# We drop a marker here to indicate that the checkpoint is empty
checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
parent_dir = checkpoint
elif isinstance(checkpoint, dict):
parent_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=self.training_iteration)
else:
elif isinstance(checkpoint, str):
parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
# When the trainable is restored, a temporary checkpoint
# is created. However, when saved, it should become permanent.
# Ideally, there are no save calls upon a temporary
# checkpoint, but certain schedulers might.
if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
relative_path = os.path.relpath(checkpoint, parent_dir)
parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
checkpoint_dir=parent_dir,
logdir=self.logdir,
step=self.training_iteration)
checkpoint = os.path.abspath(
os.path.join(parent_dir, relative_path))
else:
raise ValueError("Provided checkpoint was expected to have "
"type (str, dict). Got {}.".format(
type(checkpoint)))
checkpoint_path = TrainableUtil.process_checkpoint(
checkpoint, parent_dir, state)
return checkpoint_path
@@ -316,17 +415,20 @@ class FunctionRunner(Trainable):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
self._status_reporter.save_checkpoint(checkpoint)
# If there does not exist a checkpoint, we will not restore
# from it and will remove the marker.
if FuncCheckpointUtil.is_null_checkpoint(checkpoint):
return
# By informing that this checkpoint is not new,
# we will not return the checkpoint path
# as a new checkpoint.
self._status_reporter.set_checkpoint(checkpoint, is_new=False)
def restore_from_object(self, obj):
if self.default_checkpoint_dir is not None and os.exists(
self.default_checkpoint_dir):
shutil.rmtree(self.default_checkpoint_dir)
logger.debug("Clearing default checkpoint: %s",
self.default_checkpoint_dir)
checkpoint_dir = self.create_default_checkpoint_dir()
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
self.temp_checkpoint_dir = (FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir))
checkpoint_path = TrainableUtil.create_from_pickle(
obj, self.temp_checkpoint_dir)
self.restore(checkpoint_path)
def cleanup(self):
@@ -340,6 +442,12 @@ class FunctionRunner(Trainable):
self._report_thread_runner_error()
session.shutdown()
if self.temp_checkpoint_dir is not None and os.path.exists(
self.temp_checkpoint_dir):
shutil.rmtree(self.temp_checkpoint_dir)
logger.debug("Clearing temporary checkpoint: %s",
self.temp_checkpoint_dir)
def _report_thread_runner_error(self, block=False):
try:
err_tb_str = self._error_queue.get(