diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 9da6b2601..c7c088293 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -644,16 +644,22 @@ def with_parameters(fn, **kwargs): fn_kwargs[k] = parameter_registry.get(prefix + k) fn(config, **fn_kwargs) + fn_name = getattr(fn, "__name__", "tune_with_parameters") + inner.__name__ = fn_name + # Use correct function signature if no `checkpoint_dir` parameter is set if not use_checkpoint: def _inner(config): inner(config, checkpoint_dir=None) + _inner.__name__ = fn_name + if hasattr(fn, "__mixins__"): _inner.__mixins__ = fn.__mixins__ return _inner if hasattr(fn, "__mixins__"): inner.__mixins__ = fn.__mixins__ + return inner diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 9ee2cdc64..f7084a1fa 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -455,6 +455,7 @@ class FunctionApiTest(unittest.TestCase): self.assertEquals(trial_1.last_result["hundred"], 1) self.assertEquals(trial_2.last_result["metric"], 500_000) self.assertEquals(trial_2.last_result["hundred"], 1) + self.assertTrue(str(trial_1).startswith("train_")) # With checkpoint dir parameter def train(config, checkpoint_dir="DIR", data=None): @@ -469,6 +470,7 @@ class FunctionApiTest(unittest.TestCase): self.assertEquals(trial_1.last_result["cp"], "DIR") self.assertEquals(trial_2.last_result["metric"], 500_000) self.assertEquals(trial_2.last_result["cp"], "DIR") + self.assertTrue(str(trial_1).startswith("train_")) def testWithParameters2(self): class Data: