mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 00:35:15 +08:00
[tune] checkpoint_dir test (#8024)
This commit is contained in:
@@ -9,6 +9,7 @@ import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.utils import validate_save_restore
|
||||
|
||||
|
||||
class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||
@@ -37,11 +38,13 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||
self.state.update(extra_data)
|
||||
|
||||
def setUp(self):
|
||||
self.absolute_local_dir = None
|
||||
ray.init(num_cpus=1, num_gpus=0, local_mode=self.local_mode)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.absolute_local_dir, ignore_errors=True)
|
||||
self.absolute_local_dir = None
|
||||
if self.absolute_local_dir is not None:
|
||||
shutil.rmtree(self.absolute_local_dir, ignore_errors=True)
|
||||
self.absolute_local_dir = None
|
||||
ray.shutdown()
|
||||
# Without this line, test_tune_server.testAddTrial would fail.
|
||||
_register_all()
|
||||
@@ -147,6 +150,31 @@ class SerialTuneRelativeLocalDirTest(unittest.TestCase):
|
||||
self._train(exp_name, local_dir, local_dir)
|
||||
self._restore(exp_name, local_dir, local_dir)
|
||||
|
||||
def testCheckpointWithNoop(self):
|
||||
"""Tests that passing the checkpoint_dir right back works."""
|
||||
|
||||
class MockTrainable(Trainable):
|
||||
def _setup(self, config):
|
||||
pass
|
||||
|
||||
def _train(self):
|
||||
return {"score": 1}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
with open(os.path.join(checkpoint_dir, "test.txt"), "wb") as f:
|
||||
pickle.dump("test", f)
|
||||
return checkpoint_dir
|
||||
|
||||
def _restore(self, checkpoint_dir):
|
||||
with open(os.path.join(checkpoint_dir, "test.txt"), "rb") as f:
|
||||
x = pickle.load(f)
|
||||
|
||||
assert x == "test"
|
||||
return checkpoint_dir
|
||||
|
||||
validate_save_restore(MockTrainable)
|
||||
validate_save_restore(MockTrainable, use_object_store=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user