[tune] fix restore error at tune.run() (#4733)

This commit is contained in:
Peng Zhenghao
2019-05-04 14:56:15 +08:00
committed by Richard Liaw
parent 36b71d1446
commit 897b35ce36
3 changed files with 87 additions and 4 deletions
@@ -0,0 +1,57 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import tempfile
import unittest
import ray
from ray import tune
from ray.tune.util import recursive_fnmatch
from ray.rllib import _register_all
class TuneRestoreTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1, num_gpus=0, local_mode=True)
tmpdir = tempfile.mkdtemp()
test_name = "TuneRestoreTest"
tune.run(
"PG",
name=test_name,
stop={"training_iteration": 1},
checkpoint_freq=1,
local_dir=tmpdir,
config={
"env": "CartPole-v0",
},
)
logdir = os.path.expanduser(os.path.join(tmpdir, test_name))
self.logdir = logdir
self.checkpoint_path = recursive_fnmatch(logdir, "checkpoint-1")[0]
def tearDown(self):
shutil.rmtree(self.logdir)
ray.shutdown()
_register_all()
def testTuneRestore(self):
self.assertTrue(os.path.isfile(self.checkpoint_path))
tune.run(
"PG",
name="TuneRestoreTest",
stop={"training_iteration": 2}, # train one more iteration.
checkpoint_freq=1,
restore=self.checkpoint_path, # Restore the checkpoint
config={
"env": "CartPole-v0",
},
)
if __name__ == "__main__":
unittest.main(verbosity=2)
+16 -4
View File
@@ -190,10 +190,22 @@ def run(run_or_experiment,
experiment = run_or_experiment
if not isinstance(run_or_experiment, Experiment):
experiment = Experiment(
name, run_or_experiment, stop, config, resources_per_trial,
num_samples, local_dir, upload_dir, trial_name_creator, loggers,
sync_function, checkpoint_freq, checkpoint_at_end, export_formats,
max_failures, restore)
name=name,
run=run_or_experiment,
stop=stop,
config=config,
resources_per_trial=resources_per_trial,
num_samples=num_samples,
local_dir=local_dir,
upload_dir=upload_dir,
trial_name_creator=trial_name_creator,
loggers=loggers,
sync_function=sync_function,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
export_formats=export_formats,
max_failures=max_failures,
restore=restore)
else:
logger.debug("Ignoring some parameters passed into tune.run.")
+14
View File
@@ -4,6 +4,8 @@ from __future__ import print_function
import logging
import base64
import fnmatch
import os
import copy
import numpy as np
import time
@@ -128,6 +130,18 @@ def _from_pinnable(obj):
return obj[0]
def recursive_fnmatch(dirpath, pattern):
"""Looks at a file directory subtree for a filename pattern.
Similar to glob.glob(..., recursive=True) but also supports 2.7
"""
matches = []
for root, dirnames, filenames in os.walk(dirpath):
for filename in fnmatch.filter(filenames, pattern):
matches.append(os.path.join(root, filename))
return matches
if __name__ == "__main__":
ray.init()
X = pin_in_object_store("hello")