From 737f3e3cf20849babf429273a9ff8e903f205303 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 27 Jun 2018 16:29:39 -0700 Subject: [PATCH] [tune] Fix registering trainable twice (#2293) * register twice * isolate * Update registry.py * Update registry.py --- python/ray/experimental/internal_kv.py | 7 +++++-- python/ray/tune/registry.py | 2 +- python/ray/tune/test/trial_runner_test.py | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index 99b2d73b5..573669d7d 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -17,7 +17,7 @@ def _internal_kv_get(key): return worker.redis_client.hget(key, "value") -def _internal_kv_put(key, value): +def _internal_kv_put(key, value, overwrite=False): """Globally associates a value with a given binary key. This only has an effect if the key does not already have a value. @@ -27,5 +27,8 @@ def _internal_kv_put(key, value): """ worker = ray.worker.get_global_worker() - updated = worker.redis_client.hsetnx(key, "value", value) + if overwrite: + updated = worker.redis_client.hset(key, "value", value) + else: + updated = worker.redis_client.hsetnx(key, "value", value) return updated == 0 # already exists diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 3441fc793..204d79c93 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -98,7 +98,7 @@ class _Registry(object): def flush_values(self): for (category, key), value in self._to_flush.items(): - _internal_kv_put(_make_key(category, key), value) + _internal_kv_put(_make_key(category, key), value, overwrite=True) self._to_flush.clear() diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 832bcde28..c72ceab68 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -61,6 +61,26 @@ class TrainableFunctionApiTest(unittest.TestCase): register_env("foo", lambda: None) self.assertRaises(TypeError, lambda: register_env("foo", 2)) + def testRegisterEnvOverwrite(self): + def train(config, reporter): + reporter(timesteps_total=100, done=True) + + def train2(config, reporter): + reporter(timesteps_total=200, done=True) + + register_trainable("f1", train) + register_trainable("f1", train2) + [trial] = run_experiments({ + "foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + } + }) + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result.timesteps_total, 200) + def testRegisterTrainable(self): def train(config, reporter): pass