From a96ddec3582e1f1007f41c4bbaf31761ba3851d4 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 3 Aug 2020 12:36:17 -0700 Subject: [PATCH] [tune] Fix restoration for function API PBT (#9853) --- python/ray/tune/function_runner.py | 150 ++++++++++-- python/ray/tune/session.py | 11 +- python/ray/tune/tests/test_function_api.py | 265 +++++++++++++++++++-- python/ray/tune/trainable.py | 6 +- 4 files changed, 394 insertions(+), 38 deletions(-) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 7408f780d..2fd1f9fa1 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -5,6 +5,7 @@ import inspect import shutil import threading import traceback +import uuid from six.moves import queue @@ -22,6 +23,84 @@ RESULT_FETCH_TIMEOUT = 0.2 ERROR_REPORT_TIMEOUT = 10 ERROR_FETCH_TIMEOUT = 1 +NULL_MARKER = ".null_marker" +TEMP_MARKER = ".temp_marker" + + +class FuncCheckpointUtil: + """Utility class holding various function-checkpointing mechanisms. + + The two special modes are "null" and "temporary" checkpoints. + + *Null Checkpoints* + ------------------- + + Null checkpoints are generated when a trial is being saved + but a checkpoint has not been created. In this case, + a marker is set, indicating that the checkpoint is null. + + When restoring from an null checkpoint, the FunctionRunner + will detect this and *not* restore from any checkpoint at all. + + *Temporary Checkpoints* + ----------------------- + + Temporary checkpoints are generated when a trial is being + restored from a prior in-memory checkpoint. In this case, a marker + will be set indicating that a checkpoint is temporary. + + Upon termination of the trial, temporary checkpoints + will be removed. We cannot remove them any earlier because + the loading of checkpoints is non-deterministic. + + If "save" is called on a trial whose most recent checkpoint + is temporary, "create_perm_checkpoint" will be called. This + copies the temporary checkpoint to a permanent checkpoint directory. + """ + + @staticmethod + def mk_null_checkpoint_dir(logdir): + """Indicate that the given checkpoint doesn't have state.""" + checkpoint_dir = TrainableUtil.make_checkpoint_dir( + logdir, index=-1, override=True) + open(os.path.join(checkpoint_dir, NULL_MARKER), "a").close() + return checkpoint_dir + + @staticmethod + def mk_temp_checkpoint_dir(logdir): + """Indicate that the checkpoint is only for restoration.""" + temporary_checkpoint_dir = TrainableUtil.make_checkpoint_dir( + logdir, index="tmp" + uuid.uuid4().hex[:6], override=True) + open(os.path.join(temporary_checkpoint_dir, TEMP_MARKER), "a").close() + return temporary_checkpoint_dir + + @staticmethod + def is_temp_checkpoint_dir(checkpoint_dir): + """Checks for the temp checkpoint marker.""" + return os.path.exists(os.path.join(checkpoint_dir, TEMP_MARKER)) + + @staticmethod + def is_null_checkpoint(checkpoint_dir): + """Checks for the empty checkpoint marker.""" + return os.path.exists(os.path.join(checkpoint_dir, NULL_MARKER)) + + @staticmethod + def create_perm_checkpoint(checkpoint_dir, logdir, step): + """Copies temporary checkpoint to a permanent checkpoint directory.""" + checkpoint_dir = os.path.abspath(checkpoint_dir) + temporary_marker = os.path.join(checkpoint_dir, TEMP_MARKER) + assert os.path.exists(temporary_marker), ( + "Should not be calling this method on a permanent checkpoint.") + os.remove(temporary_marker) + perm_checkpoint_dir = TrainableUtil.make_checkpoint_dir( + logdir, index=step, override=True) + shutil.rmtree(perm_checkpoint_dir) + + shutil.copytree(checkpoint_dir, perm_checkpoint_dir) + assert not os.path.exists( + os.path.join(perm_checkpoint_dir, TEMP_MARKER)) + return perm_checkpoint_dir + class StatusReporter: """Object passed into your function that you can report status through. @@ -44,7 +123,7 @@ class StatusReporter: self._trial_name = trial_name self._trial_id = trial_id self._logdir = logdir - self._last_checkpoint = {} + self._last_checkpoint = None self._fresh_checkpoint = False def __call__(self, **kwargs): @@ -83,13 +162,18 @@ class StatusReporter: # resume training. self._continue_semaphore.acquire() - def make_checkpoint_dir(self, step=None): + def make_checkpoint_dir(self, step): checkpoint_dir = TrainableUtil.make_checkpoint_dir( self.logdir, index=step) logger.debug("Making checkpoint dir at %s", checkpoint_dir) return checkpoint_dir - def save_checkpoint(self, checkpoint): + def set_checkpoint(self, checkpoint, is_new=True): + """Sets the checkpoint to be returned upon get_checkpoint. + + If this is a "new" checkpoint, it will notify Tune + (via has_new_checkpoint). Otherwise, it will NOT notify Tune. + """ if isinstance(checkpoint, str): try: TrainableUtil.find_checkpoint_dir(checkpoint) @@ -98,7 +182,8 @@ class StatusReporter: "make_checkpoint_dir.") raise self._last_checkpoint = checkpoint - self._fresh_checkpoint = True + if is_new: + self._fresh_checkpoint = True def has_new_checkpoint(self): return self._fresh_checkpoint @@ -189,7 +274,7 @@ class FunctionRunner(Trainable): session.init(self._status_reporter) self._runner = None self._restore_tmpdir = None - self.default_checkpoint_dir = None + self.temp_checkpoint_dir = None def _trainable_func(self): """Subclasses can override this to set the trainable func.""" @@ -282,11 +367,6 @@ class FunctionRunner(Trainable): def execute(self, fn): return fn(self) - def create_default_checkpoint_dir(self): - self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir( - self.logdir, index="default") - return self.default_checkpoint_dir - def save(self, checkpoint_path=None): if checkpoint_path: raise ValueError( @@ -297,12 +377,31 @@ class FunctionRunner(Trainable): if not checkpoint: state.update(iteration=0, timesteps_total=0, episodes_total=0) - parent_dir = self.create_default_checkpoint_dir() + # We drop a marker here to indicate that the checkpoint is empty + checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir) + parent_dir = checkpoint elif isinstance(checkpoint, dict): parent_dir = TrainableUtil.make_checkpoint_dir( self.logdir, index=self.training_iteration) - else: + elif isinstance(checkpoint, str): parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint) + # When the trainable is restored, a temporary checkpoint + # is created. However, when saved, it should become permanent. + # Ideally, there are no save calls upon a temporary + # checkpoint, but certain schedulers might. + if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir): + relative_path = os.path.relpath(checkpoint, parent_dir) + parent_dir = FuncCheckpointUtil.create_perm_checkpoint( + checkpoint_dir=parent_dir, + logdir=self.logdir, + step=self.training_iteration) + checkpoint = os.path.abspath( + os.path.join(parent_dir, relative_path)) + else: + raise ValueError("Provided checkpoint was expected to have " + "type (str, dict). Got {}.".format( + type(checkpoint))) + checkpoint_path = TrainableUtil.process_checkpoint( checkpoint, parent_dir, state) return checkpoint_path @@ -316,17 +415,20 @@ class FunctionRunner(Trainable): # This should be removed once Trainables are refactored. if "tune_checkpoint_path" in checkpoint: del checkpoint["tune_checkpoint_path"] - self._status_reporter.save_checkpoint(checkpoint) + # If there does not exist a checkpoint, we will not restore + # from it and will remove the marker. + if FuncCheckpointUtil.is_null_checkpoint(checkpoint): + return + # By informing that this checkpoint is not new, + # we will not return the checkpoint path + # as a new checkpoint. + self._status_reporter.set_checkpoint(checkpoint, is_new=False) def restore_from_object(self, obj): - if self.default_checkpoint_dir is not None and os.exists( - self.default_checkpoint_dir): - shutil.rmtree(self.default_checkpoint_dir) - logger.debug("Clearing default checkpoint: %s", - self.default_checkpoint_dir) - - checkpoint_dir = self.create_default_checkpoint_dir() - checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir) + self.temp_checkpoint_dir = (FuncCheckpointUtil.mk_temp_checkpoint_dir( + self.logdir)) + checkpoint_path = TrainableUtil.create_from_pickle( + obj, self.temp_checkpoint_dir) self.restore(checkpoint_path) def cleanup(self): @@ -340,6 +442,12 @@ class FunctionRunner(Trainable): self._report_thread_runner_error() session.shutdown() + if self.temp_checkpoint_dir is not None and os.path.exists( + self.temp_checkpoint_dir): + shutil.rmtree(self.temp_checkpoint_dir) + logger.debug("Clearing temporary checkpoint: %s", + self.temp_checkpoint_dir) + def _report_thread_runner_error(self, block=False): try: err_tb_str = self._error_queue.get( diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index db1ff426d..c682e9457 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -98,12 +98,16 @@ def save_checkpoint(checkpoint): @contextmanager -def checkpoint_dir(step=None): +def checkpoint_dir(step): """Returns a checkpoint dir inside a context. Store any files related to restoring state within the provided checkpoint dir. + Args: + step (int): Index for the checkpoint. Expected to be a + monotonically increasing quantity. + .. code-block:: python import os @@ -136,6 +140,9 @@ def checkpoint_dir(step=None): """ _session = get_session() + if step is None: + raise ValueError("checkpoint_dir(step) must be provided - got None.") + if _session: _checkpoint_dir = _session.make_checkpoint_dir(step=step) else: @@ -144,7 +151,7 @@ def checkpoint_dir(step=None): yield _checkpoint_dir if _session: - _session.save_checkpoint(_checkpoint_dir) + _session.set_checkpoint(_checkpoint_dir) def get_trial_dir(): diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 0ff5a7cf9..0e20849ad 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -1,36 +1,230 @@ import json import os +import shutil +import tempfile import unittest import ray from ray.rllib import _register_all from ray import tune -from ray.tune.function_runner import wrap_function +from ray.tune.logger import NoopLogger +from ray.tune.trainable import TrainableUtil +from ray.tune.function_runner import wrap_function, FuncCheckpointUtil from ray.tune.result import TRAINING_ITERATION -class FunctionApiTest(unittest.TestCase): +def creator_generator(logdir): + def logger_creator(config): + return NoopLogger(config, logdir) + + return logger_creator + + +class FuncCheckpointUtilTest(unittest.TestCase): def setUp(self): - ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024) + self.logdir = tempfile.mkdtemp() def tearDown(self): - ray.shutdown() - _register_all() # re-register the evicted objects + shutil.rmtree(self.logdir) + + def testEmptyCheckpoint(self): + checkpoint_dir = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir) + assert FuncCheckpointUtil.is_null_checkpoint(checkpoint_dir) + + def testTempCheckpointDir(self): + checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir) + assert FuncCheckpointUtil.is_temp_checkpoint_dir(checkpoint_dir) + + def testConvertTempToPermanent(self): + checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir) + new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint( + checkpoint_dir, self.logdir, step=4) + assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir( + new_checkpoint_dir) + assert os.path.exists(new_checkpoint_dir) + assert not FuncCheckpointUtil.is_temp_checkpoint_dir( + new_checkpoint_dir) + + tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir( + self.logdir) + assert tmp_checkpoint_dir != new_checkpoint_dir + + +class FunctionCheckpointingTest(unittest.TestCase): + def setUp(self): + self.logdir = tempfile.mkdtemp() + self.logger_creator = creator_generator(self.logdir) + + def tearDown(self): + shutil.rmtree(self.logdir) + + def testCheckpointReuse(self): + """Test that repeated save/restore never reuses same checkpoint dir.""" + + def train(config, checkpoint_dir=None): + if checkpoint_dir: + count = sum("checkpoint-" in path + for path in os.listdir(checkpoint_dir)) + assert count == 1, os.listdir(checkpoint_dir) + + for step in range(20): + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(step)) + open(path, "a").close() + tune.report(test=step) + + wrapped = wrap_function(train) + checkpoint = None + for i in range(5): + new_trainable = wrapped(logger_creator=self.logger_creator) + if checkpoint: + new_trainable.restore(checkpoint) + for i in range(2): + result = new_trainable.train() + checkpoint = new_trainable.save() + new_trainable.stop() + assert result[TRAINING_ITERATION] == 10 + + def testCheckpointReuseObject(self): + """Test that repeated save/restore never reuses same checkpoint dir.""" + + def train(config, checkpoint_dir=None): + if checkpoint_dir: + count = sum("checkpoint-" in path + for path in os.listdir(checkpoint_dir)) + assert count == 1, os.listdir(checkpoint_dir) + + for step in range(20): + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(step)) + open(path, "a").close() + tune.report(test=step) + + wrapped = wrap_function(train) + checkpoint = None + for i in range(5): + new_trainable = wrapped(logger_creator=self.logger_creator) + if checkpoint: + new_trainable.restore_from_object(checkpoint) + for i in range(2): + result = new_trainable.train() + checkpoint = new_trainable.save_to_object() + new_trainable.stop() + self.assertTrue(result[TRAINING_ITERATION] == 10) + + def testCheckpointReuseObjectWithoutTraining(self): + """Test that repeated save/restore never reuses same checkpoint dir.""" + + def train(config, checkpoint_dir=None): + if checkpoint_dir: + count = sum("checkpoint-" in path + for path in os.listdir(checkpoint_dir)) + assert count == 1, os.listdir(checkpoint_dir) + + for step in range(20): + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(step)) + open(path, "a").close() + tune.report(test=step) + + wrapped = wrap_function(train) + new_trainable = wrapped(logger_creator=self.logger_creator) + for i in range(2): + result = new_trainable.train() + checkpoint = new_trainable.save_to_object() + new_trainable.stop() + + new_trainable2 = wrapped(logger_creator=self.logger_creator) + new_trainable2.restore_from_object(checkpoint) + new_trainable2.stop() + + new_trainable2 = wrapped(logger_creator=self.logger_creator) + new_trainable2.restore_from_object(checkpoint) + result = new_trainable2.train() + new_trainable2.stop() + self.assertTrue(result[TRAINING_ITERATION] == 3) + + def testReuseNullCheckpoint(self): + def train(config, checkpoint_dir=None): + assert not checkpoint_dir + for step in range(10): + tune.report(test=step) + + # Create checkpoint + wrapped = wrap_function(train) + checkpoint = None + new_trainable = wrapped(logger_creator=self.logger_creator) + new_trainable.train() + checkpoint = new_trainable.save() + new_trainable.stop() + + # Use the checkpoint a couple of times + for i in range(3): + new_trainable = wrapped(logger_creator=self.logger_creator) + new_trainable.restore(checkpoint) + new_trainable.stop() + + # Make sure the result is still good + new_trainable = wrapped(logger_creator=self.logger_creator) + new_trainable.restore(checkpoint) + result = new_trainable.train() + checkpoint = new_trainable.save() + new_trainable.stop() + self.assertTrue(result[TRAINING_ITERATION] == 1) + + def testMultipleNullCheckpoints(self): + def train(config, checkpoint_dir=None): + assert not checkpoint_dir + for step in range(10): + tune.report(test=step) + + wrapped = wrap_function(train) + checkpoint = None + for i in range(5): + new_trainable = wrapped(logger_creator=self.logger_creator) + if checkpoint: + new_trainable.restore(checkpoint) + result = new_trainable.train() + checkpoint = new_trainable.save() + new_trainable.stop() + self.assertTrue(result[TRAINING_ITERATION] == 1) + + def testMultipleNullMemoryCheckpoints(self): + def train(config, checkpoint_dir=None): + assert not checkpoint_dir + for step in range(10): + tune.report(test=step) + + wrapped = wrap_function(train) + checkpoint = None + for i in range(5): + new_trainable = wrapped(logger_creator=self.logger_creator) + if checkpoint: + new_trainable.restore_from_object(checkpoint) + result = new_trainable.train() + checkpoint = new_trainable.save_to_object() + new_trainable.stop() + assert result[TRAINING_ITERATION] == 1 def testFunctionNoCheckpointing(self): def train(config, checkpoint_dir=None): - for i in range(10): - tune.report(test=i) + if checkpoint_dir: + assert os.path.exists(checkpoint_dir) + for step in range(10): + tune.report(test=step) wrapped = wrap_function(train) - new_trainable = wrapped() + new_trainable = wrapped(logger_creator=self.logger_creator) result = new_trainable.train() checkpoint = new_trainable.save() new_trainable.stop() - new_trainable2 = wrapped() + new_trainable2 = wrapped(logger_creator=self.logger_creator) new_trainable2.restore(checkpoint) result = new_trainable2.train() self.assertEquals(result[TRAINING_ITERATION], 1) @@ -41,6 +235,8 @@ class FunctionApiTest(unittest.TestCase): """This tests that save and restore are commutative.""" def train(config, checkpoint_dir=None): + if checkpoint_dir: + assert os.path.exists(checkpoint_dir) for step in range(10): if step % 3 == 0: with tune.checkpoint_dir(step=step) as checkpoint_dir: @@ -51,18 +247,59 @@ class FunctionApiTest(unittest.TestCase): wrapped = wrap_function(train) - new_trainable = wrapped() + new_trainable = wrapped(logger_creator=self.logger_creator) new_trainable.train() checkpoint_obj = new_trainable.save_to_object() new_trainable.restore_from_object(checkpoint_obj) checkpoint = new_trainable.save() + new_trainable.stop() - new_trainable2 = wrapped() + new_trainable2 = wrapped(logger_creator=self.logger_creator) new_trainable2.restore(checkpoint) new_trainable2.train() new_trainable2.stop() + def testFunctionImmediateSave(self): + """This tests that save and restore are commutative.""" + + def train(config, checkpoint_dir=None): + if checkpoint_dir: + assert os.path.exists(checkpoint_dir) + for step in range(10): + with tune.checkpoint_dir(step=step) as checkpoint_dir: + print(checkpoint_dir) + path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(step)) + open(path, "w").close() + tune.report(test=step) + + wrapped = wrap_function(train) + new_trainable = wrapped(logger_creator=self.logger_creator) + new_trainable.train() + new_trainable.train() + checkpoint_obj = new_trainable.save_to_object() + new_trainable.stop() + + new_trainable2 = wrapped(logger_creator=self.logger_creator) + new_trainable2.restore_from_object(checkpoint_obj) + checkpoint_obj = new_trainable2.save_to_object() + new_trainable2.train() + result = new_trainable2.train() + assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1 + new_trainable2.stop() + assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0 + assert result[TRAINING_ITERATION] == 4 + + +class FunctionApiTest(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024) + + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + def testCheckpointFunctionAtEnd(self): def train(config, checkpoint_dir=False): for i in range(10): @@ -90,12 +327,12 @@ class FunctionApiTest(unittest.TestCase): def testVariousCheckpointFunctionAtEnd(self): def train(config, checkpoint_dir=False): for i in range(10): - with tune.checkpoint_dir() as checkpoint_dir: + with tune.checkpoint_dir(step=i) as checkpoint_dir: checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") with open(checkpoint_path, "w") as f: f.write("hello") tune.report(test=i) - with tune.checkpoint_dir() as checkpoint_dir: + with tune.checkpoint_dir(step=i) as checkpoint_dir: checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2") with open(checkpoint_path, "w") as f: f.write("goodbye") @@ -164,7 +401,7 @@ class FunctionApiTest(unittest.TestCase): for i in range(itr, 10): if i == 5 and not restored: raise Exception("try to fail me") - with tune.checkpoint_dir() as checkpoint_dir: + with tune.checkpoint_dir(step=itr) as checkpoint_dir: checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") with open(checkpoint_path, "w") as f: f.write(str(i)) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index cb8c2a35e..486a7b587 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -109,19 +109,23 @@ class TrainableUtil: return checkpoint_dir @staticmethod - def make_checkpoint_dir(checkpoint_dir, index): + def make_checkpoint_dir(checkpoint_dir, index, override=False): """Creates a checkpoint directory within the provided path. Args: checkpoint_dir (str): Path to checkpoint directory. index (str): A subdirectory will be created at the checkpoint directory named 'checkpoint_{index}'. + override (bool): Deletes checkpoint_dir before creating + a new one. """ suffix = "checkpoint" if index is not None: suffix += "_{}".format(index) checkpoint_dir = os.path.join(checkpoint_dir, suffix) + if override and os.path.exists(checkpoint_dir): + shutil.rmtree(checkpoint_dir) os.makedirs(checkpoint_dir, exist_ok=True) # Drop marker in directory to identify it as a checkpoint dir. open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()