mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 08:31:18 +08:00
[tune] Fix registering trainable twice (#2293)
* register twice * isolate * Update registry.py * Update registry.py
This commit is contained in:
@@ -17,7 +17,7 @@ def _internal_kv_get(key):
|
||||
return worker.redis_client.hget(key, "value")
|
||||
|
||||
|
||||
def _internal_kv_put(key, value):
|
||||
def _internal_kv_put(key, value, overwrite=False):
|
||||
"""Globally associates a value with a given binary key.
|
||||
|
||||
This only has an effect if the key does not already have a value.
|
||||
@@ -27,5 +27,8 @@ def _internal_kv_put(key, value):
|
||||
"""
|
||||
|
||||
worker = ray.worker.get_global_worker()
|
||||
updated = worker.redis_client.hsetnx(key, "value", value)
|
||||
if overwrite:
|
||||
updated = worker.redis_client.hset(key, "value", value)
|
||||
else:
|
||||
updated = worker.redis_client.hsetnx(key, "value", value)
|
||||
return updated == 0 # already exists
|
||||
|
||||
@@ -98,7 +98,7 @@ class _Registry(object):
|
||||
|
||||
def flush_values(self):
|
||||
for (category, key), value in self._to_flush.items():
|
||||
_internal_kv_put(_make_key(category, key), value)
|
||||
_internal_kv_put(_make_key(category, key), value, overwrite=True)
|
||||
self._to_flush.clear()
|
||||
|
||||
|
||||
|
||||
@@ -61,6 +61,26 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
register_env("foo", lambda: None)
|
||||
self.assertRaises(TypeError, lambda: register_env("foo", 2))
|
||||
|
||||
def testRegisterEnvOverwrite(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
|
||||
def train2(config, reporter):
|
||||
reporter(timesteps_total=200, done=True)
|
||||
|
||||
register_trainable("f1", train)
|
||||
register_trainable("f1", train2)
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 200)
|
||||
|
||||
def testRegisterTrainable(self):
|
||||
def train(config, reporter):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user