mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:07:41 +08:00
[tune] More robust resolution/detection of signature (#10365)
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,8 @@ from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
SHOULD_CHECKPOINT)
|
||||
from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
|
||||
detect_reporter)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -459,24 +461,6 @@ class FunctionRunner(Trainable):
|
||||
pass
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func, abort=False):
|
||||
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
|
||||
argspec = inspect.getfullargspec(train_func)
|
||||
func_args = argspec.args
|
||||
func_kwargs = argspec.kwonlyargs
|
||||
validated = len(func_args) == 2 and any("checkpoint_dir" in arg
|
||||
for arg in func_args)
|
||||
validated = validated or (len(func_args) == 1) and any(
|
||||
"checkpoint_dir" in arg for arg in func_kwargs)
|
||||
if abort and not validated:
|
||||
raise ValueError(
|
||||
"Provided training function must have 2 args "
|
||||
"in the signature, and the latter arg must "
|
||||
"contain `checkpoint_dir`. For example: "
|
||||
"`func(config, checkpoint_dir=None)`. Got {}".format(func_args))
|
||||
return validated
|
||||
|
||||
|
||||
def wrap_function(train_func, warn=True):
|
||||
if hasattr(train_func, "__mixins__"):
|
||||
inherit_from = train_func.__mixins__ + (FunctionRunner, )
|
||||
@@ -485,16 +469,18 @@ def wrap_function(train_func, warn=True):
|
||||
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_checkpoint = detect_checkpoint_function(train_func)
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and not use_checkpoint:
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
|
||||
func_args))
|
||||
use_config_single = detect_config_single(train_func)
|
||||
use_reporter = detect_reporter(train_func)
|
||||
|
||||
use_reporter = "reporter" in func_args
|
||||
if not use_checkpoint and not use_reporter:
|
||||
if not any([use_checkpoint, use_config_single, use_reporter]):
|
||||
# use_reporter is hidden
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"The function args must include a 'config' positional "
|
||||
"parameter. Any other args must be 'checkpoint_dir'. "
|
||||
"Found: {}".format(func_args))
|
||||
|
||||
if use_config_single and not use_checkpoint:
|
||||
if log_once("tune_function_checkpoint") and warn:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
|
||||
Reference in New Issue
Block a user