[tune] More robust resolution/detection of signature (#10365)

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-09-08 11:38:16 -07:00
committed by GitHub
parent 39c598bab0
commit 5851e893ee
7 changed files with 70 additions and 36 deletions
+13 -27
View File
@@ -14,6 +14,8 @@ from ray.tune import TuneError, session
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
SHOULD_CHECKPOINT)
from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
detect_reporter)
logger = logging.getLogger(__name__)
@@ -459,24 +461,6 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func, abort=False):
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
argspec = inspect.getfullargspec(train_func)
func_args = argspec.args
func_kwargs = argspec.kwonlyargs
validated = len(func_args) == 2 and any("checkpoint_dir" in arg
for arg in func_args)
validated = validated or (len(func_args) == 1) and any(
"checkpoint_dir" in arg for arg in func_kwargs)
if abort and not validated:
raise ValueError(
"Provided training function must have 2 args "
"in the signature, and the latter arg must "
"contain `checkpoint_dir`. For example: "
"`func(config, checkpoint_dir=None)`. Got {}".format(func_args))
return validated
def wrap_function(train_func, warn=True):
if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + (FunctionRunner, )
@@ -485,16 +469,18 @@ def wrap_function(train_func, warn=True):
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = detect_checkpoint_function(train_func)
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and not use_checkpoint:
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_config_single = detect_config_single(train_func)
use_reporter = detect_reporter(train_func)
use_reporter = "reporter" in func_args
if not use_checkpoint and not use_reporter:
if not any([use_checkpoint, use_config_single, use_reporter]):
# use_reporter is hidden
raise ValueError(
"Unknown argument found in the Trainable function. "
"The function args must include a 'config' positional "
"parameter. Any other args must be 'checkpoint_dir'. "
"Found: {}".format(func_args))
if use_config_single and not use_checkpoint:
if log_once("tune_function_checkpoint") and warn:
logger.warning(
"Function checkpointing is disabled. This may result in "