mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[tune] fix restore error at tune.run() (#4733)
This commit is contained in:
committed by
Richard Liaw
parent
36b71d1446
commit
897b35ce36
@@ -0,0 +1,57 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.util import recursive_fnmatch
|
||||
from ray.rllib import _register_all
|
||||
|
||||
|
||||
class TuneRestoreTest(unittest.TestCase):
    """Checks that ``tune.run`` accepts a ``restore`` checkpoint path.

    ``setUp`` runs a one-iteration "PG" experiment on CartPole-v0 with
    ``checkpoint_freq=1`` and records the resulting ``checkpoint-1`` file;
    the test then starts a second run that restores from that file.
    """

    def setUp(self):
        """Start Ray and produce a checkpoint to restore from."""
        ray.init(num_cpus=1, num_gpus=0, local_mode=True)
        # Keep the temp root on the instance: mkdtemp() never deletes what it
        # creates, so tearDown has to remove the whole tree itself.
        self.tmpdir = tempfile.mkdtemp()
        test_name = "TuneRestoreTest"
        tune.run(
            "PG",
            name=test_name,
            stop={"training_iteration": 1},
            checkpoint_freq=1,
            local_dir=self.tmpdir,
            config={
                "env": "CartPole-v0",
            },
        )

        logdir = os.path.expanduser(os.path.join(self.tmpdir, test_name))
        self.logdir = logdir
        # First file named "checkpoint-1" anywhere below the experiment dir.
        self.checkpoint_path = recursive_fnmatch(logdir, "checkpoint-1")[0]

    def tearDown(self):
        """Delete the temp tree and shut Ray down."""
        try:
            # Removing the temp root also removes self.logdir inside it;
            # deleting only logdir would leak the mkdtemp() directory.
            shutil.rmtree(self.tmpdir)
        finally:
            # Always stop Ray, even if cleanup failed, so a later
            # ray.init() in another test does not hit a live instance.
            ray.shutdown()
            # NOTE(review): presumably re-registers the RLlib trainables
            # after shutdown -- confirm against ray.rllib._register_all.
            _register_all()

    def testTuneRestore(self):
        # No docstring on purpose: unittest would print it instead of the
        # method name in verbose output.
        self.assertTrue(os.path.isfile(self.checkpoint_path))
        tune.run(
            "PG",
            name="TuneRestoreTest",
            stop={"training_iteration": 2},  # train one more iteration.
            checkpoint_freq=1,
            restore=self.checkpoint_path,  # Restore the checkpoint
            config={
                "env": "CartPole-v0",
            },
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
+16
-4
@@ -190,10 +190,22 @@ def run(run_or_experiment,
|
||||
experiment = run_or_experiment
|
||||
if not isinstance(run_or_experiment, Experiment):
|
||||
experiment = Experiment(
|
||||
name, run_or_experiment, stop, config, resources_per_trial,
|
||||
num_samples, local_dir, upload_dir, trial_name_creator, loggers,
|
||||
sync_function, checkpoint_freq, checkpoint_at_end, export_formats,
|
||||
max_failures, restore)
|
||||
name=name,
|
||||
run=run_or_experiment,
|
||||
stop=stop,
|
||||
config=config,
|
||||
resources_per_trial=resources_per_trial,
|
||||
num_samples=num_samples,
|
||||
local_dir=local_dir,
|
||||
upload_dir=upload_dir,
|
||||
trial_name_creator=trial_name_creator,
|
||||
loggers=loggers,
|
||||
sync_function=sync_function,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
export_formats=export_formats,
|
||||
max_failures=max_failures,
|
||||
restore=restore)
|
||||
else:
|
||||
logger.debug("Ignoring some parameters passed into tune.run.")
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import fnmatch
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
import time
|
||||
@@ -128,6 +130,18 @@ def _from_pinnable(obj):
|
||||
return obj[0]
|
||||
|
||||
|
||||
def recursive_fnmatch(dirpath, pattern):
    """Collects every file below ``dirpath`` whose name matches ``pattern``.

    Walks the directory tree with ``os.walk`` and applies ``fnmatch.filter``
    to the file names of each directory (directory names are not matched).
    Works like ``glob.glob(..., recursive=True)`` while staying usable on
    Python 2.7.

    Args:
        dirpath: Root directory of the subtree to search.
        pattern: fnmatch-style pattern tested against bare file names.

    Returns:
        List of matching paths, each the walked directory joined with the
        file name, in the order ``os.walk`` visits them.
    """
    found = []
    for parent, _subdirs, files in os.walk(dirpath):
        hits = fnmatch.filter(files, pattern)
        found.extend(os.path.join(parent, name) for name in hits)
    return found
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
Reference in New Issue
Block a user