mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 16:54:21 +08:00
Expand local_dir in Trial init (#2013)
* Fix the case where Trial logs into wrong paths when `local_dir` argument starts with tilde (~), by expanding the `local_dir` argument * Add test case for checking that the tilde gets expanded
This commit is contained in:
committed by
Richard Liaw
parent
b1e32ca6c2
commit
2048b546ff
@@ -161,6 +161,26 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
}
|
||||
})
|
||||
|
||||
def testLogdirStartingWithTilde(self):
|
||||
local_dir = '~/ray_results/local_dir'
|
||||
|
||||
def train(config, reporter):
|
||||
cwd = os.getcwd()
|
||||
assert cwd.startswith(os.path.expanduser(local_dir)), cwd
|
||||
assert not cwd.startswith('~'), cwd
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable('f1', train)
|
||||
run_experiments({
|
||||
'foo': {
|
||||
'run': 'f1',
|
||||
'local_dir': local_dir,
|
||||
'config': {
|
||||
'a': 'b'
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testLongFilename(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
|
||||
@@ -110,7 +110,7 @@ class Trial(object):
|
||||
# Trial config
|
||||
self.trainable_name = trainable_name
|
||||
self.config = config or {}
|
||||
self.local_dir = local_dir
|
||||
self.local_dir = os.path.expanduser(local_dir)
|
||||
self.experiment_tag = experiment_tag
|
||||
self.resources = (
|
||||
resources
|
||||
|
||||
Reference in New Issue
Block a user