From e6ee39a6a347e0c035753f16165002e250f43b4f Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sat, 20 Jun 2020 17:56:24 -0700 Subject: [PATCH] [tune] checkpoint_dir test (#8024) --- .../ray/tune/tests/test_tune_save_restore.py | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/tests/test_tune_save_restore.py b/python/ray/tune/tests/test_tune_save_restore.py index 5b476ecdf..dd4d8fe11 100644 --- a/python/ray/tune/tests/test_tune_save_restore.py +++ b/python/ray/tune/tests/test_tune_save_restore.py @@ -9,6 +9,7 @@ import ray from ray import tune from ray.rllib import _register_all from ray.tune import Trainable +from ray.tune.utils import validate_save_restore class SerialTuneRelativeLocalDirTest(unittest.TestCase): @@ -37,11 +38,13 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase): self.state.update(extra_data) def setUp(self): + self.absolute_local_dir = None ray.init(num_cpus=1, num_gpus=0, local_mode=self.local_mode) def tearDown(self): - shutil.rmtree(self.absolute_local_dir, ignore_errors=True) - self.absolute_local_dir = None + if self.absolute_local_dir is not None: + shutil.rmtree(self.absolute_local_dir, ignore_errors=True) + self.absolute_local_dir = None ray.shutdown() # Without this line, test_tune_server.testAddTrial would fail. _register_all() @@ -147,6 +150,31 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase): self._train(exp_name, local_dir, local_dir) self._restore(exp_name, local_dir, local_dir) + def testCheckpointWithNoop(self): + """Tests that passing the checkpoint_dir right back works.""" + + class MockTrainable(Trainable): + def _setup(self, config): + pass + + def _train(self): + return {"score": 1} + + def _save(self, checkpoint_dir): + with open(os.path.join(checkpoint_dir, "test.txt"), "wb") as f: + pickle.dump("test", f) + return checkpoint_dir + + def _restore(self, checkpoint_dir): + with open(os.path.join(checkpoint_dir, "test.txt"), "rb") as f: + x = pickle.load(f) + + assert x == "test" + return checkpoint_dir + + validate_save_restore(MockTrainable) + validate_save_restore(MockTrainable, use_object_store=True) + if __name__ == "__main__": import pytest