[tune/sgd] Document func_trainable and add checkpoint context (#9739)

Co-authored-by: krfricke <krfricke@users.noreply.github.com>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-07-30 09:46:37 -07:00
committed by GitHub
parent e540e425e4
commit 0c3b9ebeef
23 changed files with 619 additions and 452 deletions
+22 -10
View File
@@ -350,35 +350,47 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func):
    """Report whether ``train_func`` declares a ``checkpoint`` argument.

    Only the positional-or-keyword argument names of the signature are
    inspected; the match is on the exact name ``checkpoint``.
    """
    return "checkpoint" in inspect.getfullargspec(train_func).args
def detect_checkpoint_function(train_func, abort=False):
    """Return whether ``train_func`` has a checkpoint-enabled signature.

    A signature qualifies when either of the following holds:

    * it has exactly two positional-or-keyword args and the name of at
      least one of them contains ``checkpoint_dir``; or
    * it has exactly one positional-or-keyword arg and a keyword-only
      arg whose name contains ``checkpoint_dir``.

    Args:
        train_func: Callable whose signature is inspected.
        abort: If True, raise instead of returning False when the
            signature does not qualify.

    Returns:
        True if the signature qualifies, False otherwise.

    Raises:
        ValueError: If ``abort`` is set and the signature does not
            qualify.
    """
    argspec = inspect.getfullargspec(train_func)
    func_args = argspec.args
    func_kwargs = argspec.kwonlyargs

    def _mentions_checkpoint_dir(names):
        # Substring match, so e.g. ``my_checkpoint_dir`` also counts.
        return any("checkpoint_dir" in name for name in names)

    validated = (
        (len(func_args) == 2 and _mentions_checkpoint_dir(func_args))
        or (len(func_args) == 1 and _mentions_checkpoint_dir(func_kwargs)))
    if abort and not validated:
        # Report keyword-only names too: they take part in the validation
        # above, so leaving them out would hide the offending argument.
        raise ValueError(
            "Provided training function must have 2 args "
            "in the signature, and the latter arg must "
            "contain `checkpoint_dir`. For example: "
            "`func(config, checkpoint_dir=None)`. Got {}".format(
                func_args + func_kwargs))
    return validated
def wrap_function(train_func):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter, checkpoint):
def _trainable_func(self, config, reporter, checkpoint_dir):
func_args = inspect.getfullargspec(train_func).args
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
"checkpoint" not in func_args):
not detect_checkpoint_function(train_func)):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint']. Found: {}".format(
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = "checkpoint" in func_args
use_checkpoint = detect_checkpoint_function(train_func)
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint)`.")
"arguments to be `func(config, checkpoint_dir=None)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint=checkpoint)
output = train_func(config, checkpoint_dir=checkpoint_dir)
else:
output = train_func(config, reporter)