[tune] checkpoint_dir test (#8024)

This commit is contained in:
Richard Liaw
2020-06-20 17:56:24 -07:00
committed by GitHub
parent 8fa584a445
commit e6ee39a6a3
@@ -9,6 +9,7 @@ import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.utils import validate_save_restore
class SerialTuneRelativeLocalDirTest(unittest.TestCase):
@@ -37,11 +38,13 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
self.state.update(extra_data)
def setUp(self):
self.absolute_local_dir = None
ray.init(num_cpus=1, num_gpus=0, local_mode=self.local_mode)
def tearDown(self):
shutil.rmtree(self.absolute_local_dir, ignore_errors=True)
self.absolute_local_dir = None
if self.absolute_local_dir is not None:
shutil.rmtree(self.absolute_local_dir, ignore_errors=True)
self.absolute_local_dir = None
ray.shutdown()
# Without this line, test_tune_server.testAddTrial would fail.
_register_all()
@@ -147,6 +150,31 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
self._train(exp_name, local_dir, local_dir)
self._restore(exp_name, local_dir, local_dir)
def testCheckpointWithNoop(self):
"""Tests that passing the checkpoint_dir right back works."""
class MockTrainable(Trainable):
def _setup(self, config):
pass
def _train(self):
return {"score": 1}
def _save(self, checkpoint_dir):
with open(os.path.join(checkpoint_dir, "test.txt"), "wb") as f:
pickle.dump("test", f)
return checkpoint_dir
def _restore(self, checkpoint_dir):
with open(os.path.join(checkpoint_dir, "test.txt"), "rb") as f:
x = pickle.load(f)
assert x == "test"
return checkpoint_dir
validate_save_restore(MockTrainable)
validate_save_restore(MockTrainable, use_object_store=True)
if __name__ == "__main__":
import pytest