diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py new file mode 100644 index 000000000..3742cf598 --- /dev/null +++ b/python/ray/tune/tests/test_tune_restore.py @@ -0,0 +1,57 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile +import unittest + +import ray +from ray import tune +from ray.tune.util import recursive_fnmatch +from ray.rllib import _register_all + + +class TuneRestoreTest(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=1, num_gpus=0, local_mode=True) + tmpdir = tempfile.mkdtemp() + test_name = "TuneRestoreTest" + tune.run( + "PG", + name=test_name, + stop={"training_iteration": 1}, + checkpoint_freq=1, + local_dir=tmpdir, + config={ + "env": "CartPole-v0", + }, + ) + + logdir = os.path.expanduser(os.path.join(tmpdir, test_name)) + self.logdir = logdir + self.checkpoint_path = recursive_fnmatch(logdir, "checkpoint-1")[0] + + def tearDown(self): + shutil.rmtree(self.logdir) + ray.shutdown() + _register_all() + + def testTuneRestore(self): + self.assertTrue(os.path.isfile(self.checkpoint_path)) + tune.run( + "PG", + name="TuneRestoreTest", + stop={"training_iteration": 2}, # train one more iteration. + checkpoint_freq=1, + restore=self.checkpoint_path, # Restore the checkpoint + config={ + "env": "CartPole-v0", + }, + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 03ab3cd32..51f5dcdf2 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -190,10 +190,22 @@ def run(run_or_experiment, experiment = run_or_experiment if not isinstance(run_or_experiment, Experiment): experiment = Experiment( - name, run_or_experiment, stop, config, resources_per_trial, - num_samples, local_dir, upload_dir, trial_name_creator, loggers, - sync_function, checkpoint_freq, checkpoint_at_end, export_formats, - max_failures, restore) + name=name, + run=run_or_experiment, + stop=stop, + config=config, + resources_per_trial=resources_per_trial, + num_samples=num_samples, + local_dir=local_dir, + upload_dir=upload_dir, + trial_name_creator=trial_name_creator, + loggers=loggers, + sync_function=sync_function, + checkpoint_freq=checkpoint_freq, + checkpoint_at_end=checkpoint_at_end, + export_formats=export_formats, + max_failures=max_failures, + restore=restore) else: logger.debug("Ignoring some parameters passed into tune.run.") diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index a9a9ace94..7440ba49f 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -4,6 +4,8 @@ from __future__ import print_function import logging import base64 +import fnmatch +import os import copy import numpy as np import time @@ -128,6 +130,18 @@ def _from_pinnable(obj): return obj[0] +def recursive_fnmatch(dirpath, pattern): + """Looks at a file directory subtree for a filename pattern. + + Similar to glob.glob(..., recursive=True) but also supports 2.7 + """ + matches = [] + for root, dirnames, filenames in os.walk(dirpath): + for filename in fnmatch.filter(filenames, pattern): + matches.append(os.path.join(root, filename)) + return matches + + if __name__ == "__main__": ray.init() X = pin_in_object_store("hello")