From 2048b546ff93d41a1d3f644a56297b615c3fd9ac Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Mon, 7 May 2018 21:44:28 -0700 Subject: [PATCH] Expand local_dir in Trial init (#2013) * Fix the case where Trial logs into wrong paths when `local_dir` argument starts with tilde (~), by expanding the `local_dir` argument * Add test case for checking that the tilde gets expanded --- python/ray/tune/test/trial_runner_test.py | 20 ++++++++++++++++++++ python/ray/tune/trial.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index d51f9ec6f..2eba2693d 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -161,6 +161,26 @@ class TrainableFunctionApiTest(unittest.TestCase): } }) + def testLogdirStartingWithTilde(self): + local_dir = '~/ray_results/local_dir' + + def train(config, reporter): + cwd = os.getcwd() + assert cwd.startswith(os.path.expanduser(local_dir)), cwd + assert not cwd.startswith('~'), cwd + reporter(timesteps_total=1) + + register_trainable('f1', train) + run_experiments({ + 'foo': { + 'run': 'f1', + 'local_dir': local_dir, + 'config': { + 'a': 'b' + }, + } + }) + def testLongFilename(self): def train(config, reporter): assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 9d12e768c..f94c09b60 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -110,7 +110,7 @@ class Trial(object): # Trial config self.trainable_name = trainable_name self.config = config or {} - self.local_dir = local_dir + self.local_dir = os.path.expanduser(local_dir) self.experiment_tag = experiment_tag self.resources = ( resources