mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 08:31:42 +08:00
[tune/sgd] Document func_trainable and add checkpoint context (#9739)
Co-authored-by: krfricke <krfricke@users.noreply.github.com> Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
@@ -350,35 +350,47 @@ class FunctionRunner(Trainable):
|
||||
pass
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func):
    """Tell whether ``train_func`` declares a ``checkpoint`` parameter.

    Args:
        train_func: Callable whose signature is inspected.

    Returns:
        bool: True when one of the function's positional parameters is
        named exactly ``checkpoint``; False otherwise.
    """
    positional_names = inspect.getfullargspec(train_func).args
    return "checkpoint" in positional_names
|
||||
def detect_checkpoint_function(train_func, abort=False):
    """Report whether ``train_func`` has a checkpoint-enabled signature.

    A signature is accepted in either of two shapes:

    * exactly two positional parameters, the second of which contains
      ``checkpoint_dir`` in its name, e.g.
      ``func(config, checkpoint_dir=None)``;
    * exactly one positional parameter plus a keyword-only parameter
      whose name contains ``checkpoint_dir``, e.g.
      ``func(config, *, checkpoint_dir=None)``.

    Args:
        train_func: Callable whose signature is inspected.
        abort: If True, raise instead of returning False when the
            signature is not accepted.

    Returns:
        bool: True if the signature matches one of the shapes above.

    Raises:
        ValueError: If ``abort`` is set and the signature is not accepted.
    """
    argspec = inspect.getfullargspec(train_func)
    func_args = argspec.args
    func_kwargs = argspec.kwonlyargs

    # Only the latter positional arg may carry the checkpoint directory:
    # in the documented form ``func(config, checkpoint_dir=None)`` the first
    # slot is the config, so a match there would not be usable.
    positional_match = (len(func_args) == 2
                        and "checkpoint_dir" in func_args[1])
    keyword_only_match = (len(func_args) == 1 and any(
        "checkpoint_dir" in arg for arg in func_kwargs))
    validated = positional_match or keyword_only_match

    if abort and not validated:
        # Report keyword-only names too, since they take part in validation.
        raise ValueError(
            "Provided training function must have 2 args "
            "in the signature, and the latter arg must "
            "contain `checkpoint_dir`. For example: "
            "`func(config, checkpoint_dir=None)`. Got {}".format(
                func_args + func_kwargs))
    return validated
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter, checkpoint):
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and (
|
||||
"checkpoint" not in func_args):
|
||||
not detect_checkpoint_function(train_func)):
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint']. Found: {}".format(
|
||||
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
|
||||
func_args))
|
||||
use_reporter = "reporter" in func_args
|
||||
use_checkpoint = "checkpoint" in func_args
|
||||
use_checkpoint = detect_checkpoint_function(train_func)
|
||||
if not use_checkpoint and not use_reporter:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
"unexpected behavior when using checkpointing features or "
|
||||
"certain schedulers. To enable, set the train function "
|
||||
"arguments to be `func(config, checkpoint)`.")
|
||||
"arguments to be `func(config, checkpoint_dir=None)`.")
|
||||
output = train_func(config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint=checkpoint)
|
||||
output = train_func(config, checkpoint_dir=checkpoint_dir)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user