[tune] Fix registering trainable twice (#2293)

* register twice

* isolate

* Update registry.py

* Update registry.py
This commit is contained in:
Eric Liang
2018-06-27 16:29:39 -07:00
committed by GitHub
parent 965e182384
commit 737f3e3cf2
3 changed files with 26 additions and 3 deletions
+5 -2
View File
@@ -17,7 +17,7 @@ def _internal_kv_get(key):
return worker.redis_client.hget(key, "value")
def _internal_kv_put(key, value):
def _internal_kv_put(key, value, overwrite=False):
"""Globally associates a value with a given binary key.
This only has an effect if the key does not already have a value.
@@ -27,5 +27,8 @@ def _internal_kv_put(key, value):
"""
worker = ray.worker.get_global_worker()
updated = worker.redis_client.hsetnx(key, "value", value)
if overwrite:
updated = worker.redis_client.hset(key, "value", value)
else:
updated = worker.redis_client.hsetnx(key, "value", value)
return updated == 0 # already exists
+1 -1
View File
@@ -98,7 +98,7 @@ class _Registry(object):
def flush_values(self):
for (category, key), value in self._to_flush.items():
_internal_kv_put(_make_key(category, key), value)
_internal_kv_put(_make_key(category, key), value, overwrite=True)
self._to_flush.clear()
+20
View File
@@ -61,6 +61,26 @@ class TrainableFunctionApiTest(unittest.TestCase):
register_env("foo", lambda: None)
self.assertRaises(TypeError, lambda: register_env("foo", 2))
def testRegisterEnvOverwrite(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
def train2(config, reporter):
reporter(timesteps_total=200, done=True)
register_trainable("f1", train)
register_trainable("f1", train2)
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 200)
def testRegisterTrainable(self):
def train(config, reporter):
pass