[tune] Fix restoration for function API PBT (#9853)

This commit is contained in:
Richard Liaw
2020-08-03 12:36:17 -07:00
committed by SangBin Cho
parent ea1ac15da0
commit a96ddec358
4 changed files with 394 additions and 38 deletions
+129 -21
View File
@@ -5,6 +5,7 @@ import inspect
import shutil
import threading
import traceback
import uuid
from six.moves import queue
@@ -22,6 +23,84 @@ RESULT_FETCH_TIMEOUT = 0.2
ERROR_REPORT_TIMEOUT = 10
ERROR_FETCH_TIMEOUT = 1
NULL_MARKER = ".null_marker"
TEMP_MARKER = ".temp_marker"
class FuncCheckpointUtil:
    """Static helpers for the two marker-based function checkpoint modes.

    A checkpoint directory is tagged by dropping an empty marker file
    into it; the two tagged modes are "null" and "temporary".

    Null checkpoints
        Created when a trial is saved although no checkpoint has been
        produced yet. On restore, the FunctionRunner sees the marker
        and does *not* restore from any checkpoint at all.

    Temporary checkpoints
        Created when a trial is restored from an in-memory checkpoint.
        They are removed when the trial terminates; they cannot be
        removed earlier because the loading of checkpoints is
        non-deterministic. If "save" is called while the most recent
        checkpoint is temporary, ``create_perm_checkpoint`` copies it
        into a permanent checkpoint directory.
    """

    @staticmethod
    def mk_null_checkpoint_dir(logdir):
        """Create a checkpoint dir under ``logdir`` tagged as stateless.

        The directory uses index ``-1`` and replaces any directory that
        already exists at that location.

        Args:
            logdir (str): Directory in which the checkpoint dir is made.

        Returns:
            The directory path returned by
            ``TrainableUtil.make_checkpoint_dir``.
        """
        null_dir = TrainableUtil.make_checkpoint_dir(
            logdir, index=-1, override=True)
        marker_path = os.path.join(null_dir, NULL_MARKER)
        # Append mode creates the (empty) marker without truncating.
        with open(marker_path, "a"):
            pass
        return null_dir

    @staticmethod
    def mk_temp_checkpoint_dir(logdir):
        """Create a checkpoint dir under ``logdir`` tagged as temporary.

        The index is ``"tmp"`` followed by six random hex characters, so
        each call targets a differently named directory.

        Args:
            logdir (str): Directory in which the checkpoint dir is made.

        Returns:
            The directory path returned by
            ``TrainableUtil.make_checkpoint_dir``.
        """
        random_index = "tmp" + uuid.uuid4().hex[:6]
        temp_dir = TrainableUtil.make_checkpoint_dir(
            logdir, index=random_index, override=True)
        marker_path = os.path.join(temp_dir, TEMP_MARKER)
        with open(marker_path, "a"):
            pass
        return temp_dir

    @staticmethod
    def is_temp_checkpoint_dir(checkpoint_dir):
        """Return True if ``checkpoint_dir`` contains the temp marker."""
        marker_path = os.path.join(checkpoint_dir, TEMP_MARKER)
        return os.path.exists(marker_path)

    @staticmethod
    def is_null_checkpoint(checkpoint_dir):
        """Return True if ``checkpoint_dir`` contains the null marker."""
        marker_path = os.path.join(checkpoint_dir, NULL_MARKER)
        return os.path.exists(marker_path)

    @staticmethod
    def create_perm_checkpoint(checkpoint_dir, logdir, step):
        """Copy a temporary checkpoint into a permanent checkpoint dir.

        The temp marker is deleted from ``checkpoint_dir`` itself before
        the copy, so the source directory is no longer tagged as
        temporary afterwards.

        Args:
            checkpoint_dir (str): Temporary checkpoint directory.
            logdir (str): Directory that receives the permanent dir.
            step: Index used to name the permanent checkpoint dir.

        Returns:
            Path of the permanent checkpoint directory.

        Raises:
            AssertionError: If ``checkpoint_dir`` has no temp marker, or
                if the copy still contains one.
        """
        source_dir = os.path.abspath(checkpoint_dir)
        marker_path = os.path.join(source_dir, TEMP_MARKER)
        assert os.path.exists(marker_path), (
            "Should not be calling this method on a permanent checkpoint.")
        # Drop the marker first so that it is not carried into the copy.
        os.remove(marker_path)

        target_dir = TrainableUtil.make_checkpoint_dir(
            logdir, index=step, override=True)
        # copytree creates the destination itself, so clear it first.
        shutil.rmtree(target_dir)
        shutil.copytree(source_dir, target_dir)

        copied_marker = os.path.join(target_dir, TEMP_MARKER)
        assert not os.path.exists(copied_marker)
        return target_dir
class StatusReporter:
"""Object passed into your function that you can report status through.
@@ -44,7 +123,7 @@ class StatusReporter:
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = {}
self._last_checkpoint = None
self._fresh_checkpoint = False
def __call__(self, **kwargs):
@@ -83,13 +162,18 @@ class StatusReporter:
# resume training.
self._continue_semaphore.acquire()
def make_checkpoint_dir(self, step=None):
def make_checkpoint_dir(self, step):
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=step)
logger.debug("Making checkpoint dir at %s", checkpoint_dir)
return checkpoint_dir
def save_checkpoint(self, checkpoint):
def set_checkpoint(self, checkpoint, is_new=True):
"""Sets the checkpoint to be returned upon get_checkpoint.
If this is a "new" checkpoint, it will notify Tune
(via has_new_checkpoint). Otherwise, it will NOT notify Tune.
"""
if isinstance(checkpoint, str):
try:
TrainableUtil.find_checkpoint_dir(checkpoint)
@@ -98,7 +182,8 @@ class StatusReporter:
"make_checkpoint_dir.")
raise
self._last_checkpoint = checkpoint
self._fresh_checkpoint = True
if is_new:
self._fresh_checkpoint = True
def has_new_checkpoint(self):
    """Return the current value of the ``_fresh_checkpoint`` flag."""
    is_fresh = self._fresh_checkpoint
    return is_fresh
@@ -189,7 +274,7 @@ class FunctionRunner(Trainable):
session.init(self._status_reporter)
self._runner = None
self._restore_tmpdir = None
self.default_checkpoint_dir = None
self.temp_checkpoint_dir = None
def _trainable_func(self):
"""Subclasses can override this to set the trainable func."""
@@ -282,11 +367,6 @@ class FunctionRunner(Trainable):
def execute(self, fn):
    """Call ``fn`` with this object as its sole argument.

    Args:
        fn: Callable accepting one positional argument.

    Returns:
        Whatever ``fn`` returns.
    """
    outcome = fn(self)
    return outcome
def create_default_checkpoint_dir(self):
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index="default")
return self.default_checkpoint_dir
def save(self, checkpoint_path=None):
if checkpoint_path:
raise ValueError(
@@ -297,12 +377,31 @@ class FunctionRunner(Trainable):
if not checkpoint:
state.update(iteration=0, timesteps_total=0, episodes_total=0)
parent_dir = self.create_default_checkpoint_dir()
# We drop a marker here to indicate that the checkpoint is empty
checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
parent_dir = checkpoint
elif isinstance(checkpoint, dict):
parent_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index=self.training_iteration)
else:
elif isinstance(checkpoint, str):
parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
# When the trainable is restored, a temporary checkpoint
# is created. However, when saved, it should become permanent.
# Ideally, save is never called on a temporary
# checkpoint, but certain schedulers might do so.
if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
relative_path = os.path.relpath(checkpoint, parent_dir)
parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
checkpoint_dir=parent_dir,
logdir=self.logdir,
step=self.training_iteration)
checkpoint = os.path.abspath(
os.path.join(parent_dir, relative_path))
else:
raise ValueError("Provided checkpoint was expected to have "
"type (str, dict). Got {}.".format(
type(checkpoint)))
checkpoint_path = TrainableUtil.process_checkpoint(
checkpoint, parent_dir, state)
return checkpoint_path
@@ -316,17 +415,20 @@ class FunctionRunner(Trainable):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
self._status_reporter.save_checkpoint(checkpoint)
# A null checkpoint carries no state, so there is nothing to
# restore from: return without setting a checkpoint.
if FuncCheckpointUtil.is_null_checkpoint(checkpoint):
return
# By informing that this checkpoint is not new,
# we will not return the checkpoint path
# as a new checkpoint.
self._status_reporter.set_checkpoint(checkpoint, is_new=False)
def restore_from_object(self, obj):
if self.default_checkpoint_dir is not None and os.exists(
self.default_checkpoint_dir):
shutil.rmtree(self.default_checkpoint_dir)
logger.debug("Clearing default checkpoint: %s",
self.default_checkpoint_dir)
checkpoint_dir = self.create_default_checkpoint_dir()
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
self.temp_checkpoint_dir = (FuncCheckpointUtil.mk_temp_checkpoint_dir(
self.logdir))
checkpoint_path = TrainableUtil.create_from_pickle(
obj, self.temp_checkpoint_dir)
self.restore(checkpoint_path)
def cleanup(self):
@@ -340,6 +442,12 @@ class FunctionRunner(Trainable):
self._report_thread_runner_error()
session.shutdown()
if self.temp_checkpoint_dir is not None and os.path.exists(
self.temp_checkpoint_dir):
shutil.rmtree(self.temp_checkpoint_dir)
logger.debug("Clearing temporary checkpoint: %s",
self.temp_checkpoint_dir)
def _report_thread_runner_error(self, block=False):
try:
err_tb_str = self._error_queue.get(
+9 -2
View File
@@ -98,12 +98,16 @@ def save_checkpoint(checkpoint):
@contextmanager
def checkpoint_dir(step=None):
def checkpoint_dir(step):
"""Returns a checkpoint dir inside a context.
Store any files related to restoring state within the
provided checkpoint dir.
Args:
step (int): Index for the checkpoint. Expected to be a
monotonically increasing quantity.
.. code-block:: python
import os
@@ -136,6 +140,9 @@ def checkpoint_dir(step=None):
"""
_session = get_session()
if step is None:
raise ValueError("checkpoint_dir(step) must be provided - got None.")
if _session:
_checkpoint_dir = _session.make_checkpoint_dir(step=step)
else:
@@ -144,7 +151,7 @@ def checkpoint_dir(step=None):
yield _checkpoint_dir
if _session:
_session.save_checkpoint(_checkpoint_dir)
_session.set_checkpoint(_checkpoint_dir)
def get_trial_dir():
+251 -14
View File
@@ -1,36 +1,230 @@
import json
import os
import shutil
import tempfile
import unittest
import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune.function_runner import wrap_function
from ray.tune.logger import NoopLogger
from ray.tune.trainable import TrainableUtil
from ray.tune.function_runner import wrap_function, FuncCheckpointUtil
from ray.tune.result import TRAINING_ITERATION
class FunctionApiTest(unittest.TestCase):
def creator_generator(logdir):
    """Build a logger-creator callback bound to ``logdir``.

    Args:
        logdir (str): Directory handed to every logger that is created.

    Returns:
        A function taking ``config`` and returning
        ``NoopLogger(config, logdir)``.
    """

    def logger_creator(config):
        noop_logger = NoopLogger(config, logdir)
        return noop_logger

    return logger_creator
class FuncCheckpointUtilTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
self.logdir = tempfile.mkdtemp()
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
shutil.rmtree(self.logdir)
def testEmptyCheckpoint(self):
    """A directory made by mk_null_checkpoint_dir is detected as null."""
    null_dir = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
    assert FuncCheckpointUtil.is_null_checkpoint(null_dir)
def testTempCheckpointDir(self):
    """A directory made by mk_temp_checkpoint_dir is detected as temp."""
    temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    assert FuncCheckpointUtil.is_temp_checkpoint_dir(temp_dir)
def testConvertTempToPermanent(self):
    """Promote a temporary checkpoint dir and inspect the result.

    The permanent directory must be its own checkpoint root, must exist,
    and must not be tagged as temporary. A temp dir created afterwards
    must have a different path.
    """
    temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    perm_dir = FuncCheckpointUtil.create_perm_checkpoint(
        temp_dir, self.logdir, step=4)
    assert perm_dir == TrainableUtil.find_checkpoint_dir(perm_dir)
    assert os.path.exists(perm_dir)
    assert not FuncCheckpointUtil.is_temp_checkpoint_dir(perm_dir)

    # A later temp dir must not collide with the permanent one.
    later_temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
    assert later_temp_dir != perm_dir
class FunctionCheckpointingTest(unittest.TestCase):
def setUp(self):
    """Create a scratch log directory and a logger creator bound to it."""
    scratch_dir = tempfile.mkdtemp()
    self.logdir = scratch_dir
    self.logger_creator = creator_generator(scratch_dir)
def tearDown(self):
    """Recursively delete the directory stored in ``self.logdir``."""
    scratch_dir = self.logdir
    shutil.rmtree(scratch_dir)
def testCheckpointReuse(self):
    """Chain save/restore by path and check each restored checkpoint dir.

    Five trainables run back to back; each restores from the path saved
    by its predecessor and trains twice. When the training function is
    handed a checkpoint dir, it asserts that the dir holds exactly one
    "checkpoint-" file.
    """

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            entries = os.listdir(checkpoint_dir)
            count = sum("checkpoint-" in entry for entry in entries)
            assert count == 1, entries

        for step in range(20):
            with tune.checkpoint_dir(step=step) as step_dir:
                marker = os.path.join(step_dir,
                                      "checkpoint-{}".format(step))
                open(marker, "a").close()
            tune.report(test=step)

    trainable_cls = wrap_function(train)
    checkpoint = None
    for _cycle in range(5):
        trainable = trainable_cls(logger_creator=self.logger_creator)
        if checkpoint:
            trainable.restore(checkpoint)
        for _ in range(2):
            result = trainable.train()
        checkpoint = trainable.save()
        trainable.stop()
    assert result[TRAINING_ITERATION] == 10
def testCheckpointReuseObject(self):
    """Chain save_to_object/restore_from_object across five trainables.

    Each trainable restores from the in-memory checkpoint object of its
    predecessor and trains twice. When the training function is handed
    a checkpoint dir, it asserts that the dir holds exactly one
    "checkpoint-" file. After five rounds of two iterations the
    training iteration must be 10.
    """

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            count = sum("checkpoint-" in path
                        for path in os.listdir(checkpoint_dir))
            assert count == 1, os.listdir(checkpoint_dir)

        for step in range(20):
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                path = os.path.join(checkpoint_dir,
                                    "checkpoint-{}".format(step))
                open(path, "a").close()
            tune.report(test=step)

    wrapped = wrap_function(train)
    checkpoint = None
    for _cycle in range(5):
        new_trainable = wrapped(logger_creator=self.logger_creator)
        if checkpoint:
            new_trainable.restore_from_object(checkpoint)
        # Distinct loop names: the original reused ``i`` for both loops.
        for _ in range(2):
            result = new_trainable.train()
        checkpoint = new_trainable.save_to_object()
        new_trainable.stop()
    # assertEqual reports both values on failure; assertTrue(a == b)
    # would only say "False is not true".
    self.assertEqual(result[TRAINING_ITERATION], 10)
def testCheckpointReuseObjectWithoutTraining(self):
    """Restore twice from one checkpoint object, training only once.

    A trainable trains twice and is saved to an object. A second
    trainable restores from that object and is stopped without
    training. A third trainable restores from the same object and
    trains once, which must yield training iteration 3.
    """

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            count = sum("checkpoint-" in path
                        for path in os.listdir(checkpoint_dir))
            assert count == 1, os.listdir(checkpoint_dir)

        for step in range(20):
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                path = os.path.join(checkpoint_dir,
                                    "checkpoint-{}".format(step))
                open(path, "a").close()
            tune.report(test=step)

    wrapped = wrap_function(train)
    new_trainable = wrapped(logger_creator=self.logger_creator)
    for _ in range(2):
        new_trainable.train()
    checkpoint = new_trainable.save_to_object()
    new_trainable.stop()

    # Restore and stop immediately, without a single train() call.
    new_trainable2 = wrapped(logger_creator=self.logger_creator)
    new_trainable2.restore_from_object(checkpoint)
    new_trainable2.stop()

    # The same checkpoint object must still be usable afterwards.
    new_trainable2 = wrapped(logger_creator=self.logger_creator)
    new_trainable2.restore_from_object(checkpoint)
    result = new_trainable2.train()
    new_trainable2.stop()
    self.assertEqual(result[TRAINING_ITERATION], 3)
def testReuseNullCheckpoint(self):
    """Restore repeatedly from a save made without any checkpoint.

    The training function never creates a checkpoint and asserts that
    it is never handed a checkpoint dir. The saved path is restored
    three times without training, then once more followed by a single
    train() call, which must report training iteration 1.
    """

    def train(config, checkpoint_dir=None):
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    # Create checkpoint
    wrapped = wrap_function(train)
    new_trainable = wrapped(logger_creator=self.logger_creator)
    new_trainable.train()
    checkpoint = new_trainable.save()
    new_trainable.stop()

    # Use the checkpoint a couple of times
    for _ in range(3):
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.restore(checkpoint)
        new_trainable.stop()

    # Make sure the result is still good
    new_trainable = wrapped(logger_creator=self.logger_creator)
    new_trainable.restore(checkpoint)
    result = new_trainable.train()
    # Saving again is still exercised; its return value is not needed.
    new_trainable.save()
    new_trainable.stop()
    self.assertEqual(result[TRAINING_ITERATION], 1)
def testMultipleNullCheckpoints(self):
    """Chain save/restore by path when no checkpoint is ever created.

    Five trainables run back to back; each restores from the path saved
    by its predecessor, trains once and saves. The training function
    asserts that it is never handed a checkpoint dir, and the last
    train() call must report training iteration 1.
    """

    def train(config, checkpoint_dir=None):
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    wrapped = wrap_function(train)
    checkpoint = None
    for _ in range(5):
        new_trainable = wrapped(logger_creator=self.logger_creator)
        if checkpoint:
            new_trainable.restore(checkpoint)
        result = new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()
    # assertEqual shows the offending value if the count is wrong.
    self.assertEqual(result[TRAINING_ITERATION], 1)
def testMultipleNullMemoryCheckpoints(self):
    """Chain in-memory save/restore when no checkpoint is ever created.

    Five trainables run back to back; each restores from the object
    saved by its predecessor, trains once and saves to an object. The
    training function asserts that it is never handed a checkpoint dir.
    """

    def train(config, checkpoint_dir=None):
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    trainable_cls = wrap_function(train)
    checkpoint = None
    for _ in range(5):
        trainable = trainable_cls(logger_creator=self.logger_creator)
        if checkpoint:
            trainable.restore_from_object(checkpoint)
        result = trainable.train()
        checkpoint = trainable.save_to_object()
        trainable.stop()
    assert result[TRAINING_ITERATION] == 1
def testFunctionNoCheckpointing(self):
def train(config, checkpoint_dir=None):
for i in range(10):
tune.report(test=i)
if checkpoint_dir:
assert os.path.exists(checkpoint_dir)
for step in range(10):
tune.report(test=step)
wrapped = wrap_function(train)
new_trainable = wrapped()
new_trainable = wrapped(logger_creator=self.logger_creator)
result = new_trainable.train()
checkpoint = new_trainable.save()
new_trainable.stop()
new_trainable2 = wrapped()
new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore(checkpoint)
result = new_trainable2.train()
self.assertEquals(result[TRAINING_ITERATION], 1)
@@ -41,6 +235,8 @@ class FunctionApiTest(unittest.TestCase):
"""This tests that save and restore are commutative."""
def train(config, checkpoint_dir=None):
if checkpoint_dir:
assert os.path.exists(checkpoint_dir)
for step in range(10):
if step % 3 == 0:
with tune.checkpoint_dir(step=step) as checkpoint_dir:
@@ -51,18 +247,59 @@ class FunctionApiTest(unittest.TestCase):
wrapped = wrap_function(train)
new_trainable = wrapped()
new_trainable = wrapped(logger_creator=self.logger_creator)
new_trainable.train()
checkpoint_obj = new_trainable.save_to_object()
new_trainable.restore_from_object(checkpoint_obj)
checkpoint = new_trainable.save()
new_trainable.stop()
new_trainable2 = wrapped()
new_trainable2 = wrapped(logger_creator=self.logger_creator)
new_trainable2.restore(checkpoint)
new_trainable2.train()
new_trainable2.stop()
def testFunctionImmediateSave(self):
    """Save to an object immediately after restoring from one.

    A trainable trains twice and is saved to an object. A second
    trainable restores from that object, saves again before any
    training, then trains twice. While it is alive, exactly one entry
    in the log dir has "tmp" in its name; after stop() there is none.
    The final training iteration must be 4.
    """

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                path = os.path.join(checkpoint_dir,
                                    "checkpoint-{}".format(step))
                open(path, "w").close()
            tune.report(test=step)

    wrapped = wrap_function(train)
    new_trainable = wrapped(logger_creator=self.logger_creator)
    new_trainable.train()
    new_trainable.train()
    checkpoint_obj = new_trainable.save_to_object()
    new_trainable.stop()

    new_trainable2 = wrapped(logger_creator=self.logger_creator)
    new_trainable2.restore_from_object(checkpoint_obj)
    # The save itself is what is under test; its result is not reused.
    new_trainable2.save_to_object()
    new_trainable2.train()
    result = new_trainable2.train()
    assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
    new_trainable2.stop()
    assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
    assert result[TRAINING_ITERATION] == 4
class FunctionApiTest(unittest.TestCase):
def setUp(self):
    """Start Ray with 4 CPUs, no GPUs and a 150 MiB object store."""
    object_store_bytes = 150 * 1024 * 1024
    ray.init(
        num_cpus=4, num_gpus=0, object_store_memory=object_store_bytes)
def tearDown(self):
    """Shut Ray down, then call ``_register_all`` again."""
    ray.shutdown()
    # re-register the evicted objects
    _register_all()
def testCheckpointFunctionAtEnd(self):
def train(config, checkpoint_dir=False):
for i in range(10):
@@ -90,12 +327,12 @@ class FunctionApiTest(unittest.TestCase):
def testVariousCheckpointFunctionAtEnd(self):
def train(config, checkpoint_dir=False):
for i in range(10):
with tune.checkpoint_dir() as checkpoint_dir:
with tune.checkpoint_dir(step=i) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.report(test=i)
with tune.checkpoint_dir() as checkpoint_dir:
with tune.checkpoint_dir(step=i) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2")
with open(checkpoint_path, "w") as f:
f.write("goodbye")
@@ -164,7 +401,7 @@ class FunctionApiTest(unittest.TestCase):
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
with tune.checkpoint_dir() as checkpoint_dir:
with tune.checkpoint_dir(step=itr) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
+5 -1
View File
@@ -109,19 +109,23 @@ class TrainableUtil:
return checkpoint_dir
@staticmethod
def make_checkpoint_dir(checkpoint_dir, index):
def make_checkpoint_dir(checkpoint_dir, index, override=False):
"""Creates a checkpoint directory within the provided path.
Args:
checkpoint_dir (str): Path to checkpoint directory.
index (str): A subdirectory will be created
at the checkpoint directory named 'checkpoint_{index}'.
override (bool): Deletes checkpoint_dir before creating
a new one.
"""
suffix = "checkpoint"
if index is not None:
suffix += "_{}".format(index)
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
if override and os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
os.makedirs(checkpoint_dir, exist_ok=True)
# Drop marker in directory to identify it as a checkpoint dir.
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()