mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 10:33:24 +08:00
[tune] cleanup error messaging/diagnose_serialization helper (#10210)
This commit is contained in:
@@ -9,6 +9,7 @@ import uuid
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
@@ -476,34 +477,37 @@ def detect_checkpoint_function(train_func, abort=False):
|
||||
return validated
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
def wrap_function(train_func, warn=True):
|
||||
if hasattr(train_func, "__mixins__"):
|
||||
inherit_from = train_func.__mixins__ + (FunctionRunner, )
|
||||
else:
|
||||
inherit_from = (FunctionRunner, )
|
||||
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_checkpoint = detect_checkpoint_function(train_func)
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and not use_checkpoint:
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
|
||||
func_args))
|
||||
|
||||
use_reporter = "reporter" in func_args
|
||||
if not use_checkpoint and not use_reporter:
|
||||
if log_once("tune_function_checkpoint") and warn:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
"unexpected behavior when using checkpointing features or "
|
||||
"certain schedulers. To enable, set the train function "
|
||||
"arguments to be `func(config, checkpoint_dir=None)`.")
|
||||
|
||||
class ImplicitFunc(*inherit_from):
|
||||
_name = train_func.__name__ if hasattr(train_func, "__name__") \
|
||||
else "func"
|
||||
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and (
|
||||
not detect_checkpoint_function(train_func)):
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
|
||||
func_args))
|
||||
use_reporter = "reporter" in func_args
|
||||
use_checkpoint = detect_checkpoint_function(train_func)
|
||||
if not use_checkpoint and not use_reporter:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
"unexpected behavior when using checkpointing features or "
|
||||
"certain schedulers. To enable, set the train function "
|
||||
"arguments to be `func(config, checkpoint_dir=None)`.")
|
||||
output = train_func(config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint_dir=checkpoint_dir)
|
||||
|
||||
Reference in New Issue
Block a user