mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] pass trainable function name when using tune.with_parameters (#14009)
This commit is contained in:
@@ -644,16 +644,22 @@ def with_parameters(fn, **kwargs):
|
||||
fn_kwargs[k] = parameter_registry.get(prefix + k)
|
||||
fn(config, **fn_kwargs)
|
||||
|
||||
fn_name = getattr(fn, "__name__", "tune_with_parameters")
|
||||
inner.__name__ = fn_name
|
||||
|
||||
# Use correct function signature if no `checkpoint_dir` parameter is set
|
||||
if not use_checkpoint:
|
||||
|
||||
def _inner(config):
|
||||
inner(config, checkpoint_dir=None)
|
||||
|
||||
_inner.__name__ = fn_name
|
||||
|
||||
if hasattr(fn, "__mixins__"):
|
||||
_inner.__mixins__ = fn.__mixins__
|
||||
return _inner
|
||||
|
||||
if hasattr(fn, "__mixins__"):
|
||||
inner.__mixins__ = fn.__mixins__
|
||||
|
||||
return inner
|
||||
|
||||
@@ -455,6 +455,7 @@ class FunctionApiTest(unittest.TestCase):
|
||||
self.assertEquals(trial_1.last_result["hundred"], 1)
|
||||
self.assertEquals(trial_2.last_result["metric"], 500_000)
|
||||
self.assertEquals(trial_2.last_result["hundred"], 1)
|
||||
self.assertTrue(str(trial_1).startswith("train_"))
|
||||
|
||||
# With checkpoint dir parameter
|
||||
def train(config, checkpoint_dir="DIR", data=None):
|
||||
@@ -469,6 +470,7 @@ class FunctionApiTest(unittest.TestCase):
|
||||
self.assertEquals(trial_1.last_result["cp"], "DIR")
|
||||
self.assertEquals(trial_2.last_result["metric"], 500_000)
|
||||
self.assertEquals(trial_2.last_result["cp"], "DIR")
|
||||
self.assertTrue(str(trial_1).startswith("train_"))
|
||||
|
||||
def testWithParameters2(self):
|
||||
class Data:
|
||||
|
||||
Reference in New Issue
Block a user