[tune] cleanup error messaging/diagnose_serialization helper (#10210)

This commit is contained in:
Richard Liaw
2020-08-22 11:50:49 -07:00
committed by GitHub
parent 24ee496b89
commit 6bd5458bef
8 changed files with 178 additions and 38 deletions
+21 -17
View File
@@ -9,6 +9,7 @@ import uuid
from six.moves import queue
from ray.util.debug import log_once
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
return validated
def wrap_function(train_func):
def wrap_function(train_func, warn=True):
if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + (FunctionRunner, )
else:
inherit_from = (FunctionRunner, )
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = detect_checkpoint_function(train_func)
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and not use_checkpoint:
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
if not use_checkpoint and not use_reporter:
if log_once("tune_function_checkpoint") and warn:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint_dir=None)`.")
class ImplicitFunc(*inherit_from):
_name = train_func.__name__ if hasattr(train_func, "__name__") \
else "func"
def _trainable_func(self, config, reporter, checkpoint_dir):
func_args = inspect.getfullargspec(train_func).args
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
not detect_checkpoint_function(train_func)):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = detect_checkpoint_function(train_func)
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint_dir=None)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint_dir=checkpoint_dir)