[tune] More robust resolution/detection of signature (#10365)

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-09-08 11:38:16 -07:00
committed by GitHub
parent 39c598bab0
commit 5851e893ee
7 changed files with 70 additions and 36 deletions
+1 -1
View File
@@ -5,11 +5,11 @@ import os
from typing import Sequence
from ray.tune.error import TuneError
from ray.tune.function_runner import detect_checkpoint_function
from ray.tune.registry import register_trainable, get_trainable_cls
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import Domain
from ray.tune.stopper import FunctionStopper, Stopper
from ray.tune.utils import detect_checkpoint_function
logger = logging.getLogger(__name__)
+13 -27
View File
@@ -14,6 +14,8 @@ from ray.tune import TuneError, session
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
SHOULD_CHECKPOINT)
from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
detect_reporter)
logger = logging.getLogger(__name__)
@@ -459,24 +461,6 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func, abort=False):
    """Report whether ``train_func`` has a checkpoint-enabled signature.

    Two layouts are accepted: exactly two positional parameters, at least
    one of whose names contains "checkpoint_dir"; or exactly one positional
    parameter together with a keyword-only parameter whose name contains
    "checkpoint_dir".

    Args:
        train_func: Function whose signature is inspected.
        abort: When true, raise instead of returning False.

    Returns:
        True if either layout matches, otherwise False.

    Raises:
        ValueError: If ``abort`` is true and neither layout matches.
    """
    spec = inspect.getfullargspec(train_func)
    positional = spec.args
    keyword_only = spec.kwonlyargs

    def mentions_checkpoint(names):
        # Substring match on each parameter name.
        return any("checkpoint_dir" in name for name in names)

    if len(positional) == 2 and mentions_checkpoint(positional):
        validated = True
    elif len(positional) == 1 and mentions_checkpoint(keyword_only):
        validated = True
    else:
        validated = False

    if not validated and abort:
        raise ValueError(
            "Provided training function must have 2 args "
            "in the signature, and the latter arg must "
            "contain `checkpoint_dir`. For example: "
            "`func(config, checkpoint_dir=None)`. Got {}".format(positional))
    return validated
def wrap_function(train_func, warn=True):
if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + (FunctionRunner, )
@@ -485,16 +469,18 @@ def wrap_function(train_func, warn=True):
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = detect_checkpoint_function(train_func)
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and not use_checkpoint:
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_config_single = detect_config_single(train_func)
use_reporter = detect_reporter(train_func)
use_reporter = "reporter" in func_args
if not use_checkpoint and not use_reporter:
if not any([use_checkpoint, use_config_single, use_reporter]):
# use_reporter is hidden
raise ValueError(
"Unknown argument found in the Trainable function. "
"The function args must include a 'config' positional "
"parameter. Any other args must be 'checkpoint_dir'. "
"Found: {}".format(func_args))
if use_config_single and not use_checkpoint:
if log_once("tune_function_checkpoint") and warn:
logger.warning(
"Function checkpointing is disabled. This may result in "
+2 -2
View File
@@ -12,10 +12,10 @@ import ray
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import (wrap_function,
detect_checkpoint_function)
from ray.tune.function_runner import wrap_function
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.utils import detect_checkpoint_function
from ray.util.sgd.torch.utils import setup_process_group, setup_address
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
+2 -2
View File
@@ -79,7 +79,7 @@ class ValidateUtilTest(unittest.TestCase):
# this is not serializable
e = threading.Event()
def test():
def test(config):
print(e)
assert diagnose_serialization(test) is not True
@@ -88,7 +88,7 @@ class ValidateUtilTest(unittest.TestCase):
# the `test` scope.
# correct implementation
def test():
def test(config):
e = threading.Event()
print(e)
+7 -3
View File
@@ -1,10 +1,14 @@
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
validate_save_restore, warn_if_slow, diagnose_serialization, env_integer
from ray.tune.utils.util import (
deep_update, flatten_dict, get_pinned_object, merge_dicts,
pin_in_object_store, unflattened_lookup, UtilMonitor,
validate_save_restore, warn_if_slow, diagnose_serialization,
detect_checkpoint_function, detect_reporter, detect_config_single,
env_integer)
__all__ = [
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
"validate_save_restore", "warn_if_slow", "diagnose_serialization",
"detect_checkpoint_function", "detect_reporter", "detect_config_single",
"env_integer"
]
+44
View File
@@ -435,6 +435,50 @@ def validate_save_restore(trainable_cls,
return True
def detect_checkpoint_function(train_func, abort=False):
    """Check whether ``train_func`` accepts a ``checkpoint_dir`` keyword.

    The signature is probed by binding one positional argument (standing
    in for the config) together with a ``checkpoint_dir`` keyword
    argument, i.e. the call shape ``train_func(config, checkpoint_dir=...)``.
    Only the signature is bound; ``train_func`` itself is not called.

    Args:
        train_func: Callable whose signature is inspected.
        abort: If True, raise instead of returning False.

    Returns:
        True if the probe arguments bind to the signature, else False.

    Raises:
        ValueError: If ``abort`` is True and the probe does not bind.
    """
    func_sig = inspect.signature(train_func)
    try:
        # check if signature is func(config, checkpoint_dir=None)
        func_sig.bind({}, checkpoint_dir="tmp/path")
    except TypeError as e:
        # Signature.bind signals every argument mismatch with TypeError.
        logger.debug("%s", e)
        validated = False
    else:
        validated = True
    if abort and not validated:
        argspec = inspect.getfullargspec(train_func)
        # Keyword-only parameters can satisfy the check too, so include
        # them in the reported parameter list.
        func_args = argspec.args + argspec.kwonlyargs
        raise ValueError(
            "Provided training function must have 2 args "
            "in the signature, and the latter arg must "
            "contain `checkpoint_dir`. For example: "
            "`func(config, checkpoint_dir=None)`. Got {}".format(func_args))
    return validated
def detect_reporter(func):
    """Check whether ``func`` accepts a ``reporter`` keyword argument.

    The signature is probed by binding one positional argument (standing
    in for the config) together with a ``reporter`` keyword argument,
    i.e. the call shape ``func(config, reporter=...)``. Only the
    signature is bound; ``func`` itself is not called.

    Args:
        func: Callable whose signature is inspected.

    Returns:
        True if the probe arguments bind to the signature, else False.
    """
    func_sig = inspect.signature(func)
    try:
        func_sig.bind({}, reporter=None)
    except TypeError as e:
        # Signature.bind signals every argument mismatch with TypeError.
        logger.debug("%s", e)
        return False
    return True
def detect_config_single(func):
    """Check whether ``func`` can be called with a single positional arg.

    The signature is probed by binding exactly one positional argument,
    i.e. the call shape ``func(config)``. Only the signature is bound;
    ``func`` itself is not called.

    Args:
        func: Callable whose signature is inspected.

    Returns:
        True if a single positional argument binds to the signature,
        else False.
    """
    func_sig = inspect.signature(func)
    try:
        func_sig.bind({})
    except TypeError as e:
        # Signature.bind signals every argument mismatch with TypeError.
        logger.debug("%s", e)
        return False
    return True
if __name__ == "__main__":
ray.init()
X = pin_in_object_store("hello")
+1 -1
View File
@@ -111,7 +111,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on":
extras = {
"debug": [],
"serve": ["uvicorn", "flask", "requests", "pydantic", "dataclasses"],
"tune": ["tabulate", "tensorboardX", "pandas"]
"tune": ["tabulate", "tensorboardX", "pandas", "dataclasses"]
}
extras["rllib"] = extras["tune"] + [