[tune] pass trainable function name when using tune.with_parameters (#14009)

This commit is contained in:
Kai Fricke
2021-02-09 17:51:14 +01:00
committed by GitHub
parent d7301a51f4
commit 3c8b164882
2 changed files with 8 additions and 0 deletions
+6
View File
@@ -644,16 +644,22 @@ def with_parameters(fn, **kwargs):
fn_kwargs[k] = parameter_registry.get(prefix + k)
fn(config, **fn_kwargs)
fn_name = getattr(fn, "__name__", "tune_with_parameters")
inner.__name__ = fn_name
# Use correct function signature if no `checkpoint_dir` parameter is set
if not use_checkpoint:
def _inner(config):
inner(config, checkpoint_dir=None)
_inner.__name__ = fn_name
if hasattr(fn, "__mixins__"):
_inner.__mixins__ = fn.__mixins__
return _inner
if hasattr(fn, "__mixins__"):
inner.__mixins__ = fn.__mixins__
return inner
@@ -455,6 +455,7 @@ class FunctionApiTest(unittest.TestCase):
self.assertEquals(trial_1.last_result["hundred"], 1)
self.assertEquals(trial_2.last_result["metric"], 500_000)
self.assertEquals(trial_2.last_result["hundred"], 1)
self.assertTrue(str(trial_1).startswith("train_"))
# With checkpoint dir parameter
def train(config, checkpoint_dir="DIR", data=None):
@@ -469,6 +470,7 @@ class FunctionApiTest(unittest.TestCase):
self.assertEquals(trial_1.last_result["cp"], "DIR")
self.assertEquals(trial_2.last_result["metric"], 500_000)
self.assertEquals(trial_2.last_result["cp"], "DIR")
self.assertTrue(str(trial_1).startswith("train_"))
def testWithParameters2(self):
class Data: