diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index a312c1afe..c67646e40 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -5,11 +5,11 @@ import os from typing import Sequence from ray.tune.error import TuneError -from ray.tune.function_runner import detect_checkpoint_function from ray.tune.registry import register_trainable, get_trainable_cls from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.sample import Domain from ray.tune.stopper import FunctionStopper, Stopper +from ray.tune.utils import detect_checkpoint_function logger = logging.getLogger(__name__) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 1c8b5f49c..e722bacf6 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -14,6 +14,8 @@ from ray.tune import TuneError, session from ray.tune.trainable import Trainable, TrainableUtil from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, SHOULD_CHECKPOINT) +from ray.tune.utils import (detect_checkpoint_function, detect_config_single, + detect_reporter) logger = logging.getLogger(__name__) @@ -459,24 +461,6 @@ class FunctionRunner(Trainable): pass -def detect_checkpoint_function(train_func, abort=False): - """Use checkpointing if any arg has "checkpoint_dir" and args = 2""" - argspec = inspect.getfullargspec(train_func) - func_args = argspec.args - func_kwargs = argspec.kwonlyargs - validated = len(func_args) == 2 and any("checkpoint_dir" in arg - for arg in func_args) - validated = validated or (len(func_args) == 1) and any( - "checkpoint_dir" in arg for arg in func_kwargs) - if abort and not validated: - raise ValueError( - "Provided training function must have 2 args " - "in the signature, and the latter arg must " - "contain `checkpoint_dir`. For example: " - "`func(config, checkpoint_dir=None)`. Got {}".format(func_args)) - return validated - - def wrap_function(train_func, warn=True): if hasattr(train_func, "__mixins__"): inherit_from = train_func.__mixins__ + (FunctionRunner, ) @@ -485,16 +469,18 @@ def wrap_function(train_func, warn=True): func_args = inspect.getfullargspec(train_func).args use_checkpoint = detect_checkpoint_function(train_func) - if len(func_args) > 1: # more arguments than just the config - if "reporter" not in func_args and not use_checkpoint: - raise ValueError( - "Unknown argument found in the Trainable function. " - "Arguments other than the 'config' arg must be one " - "of ['reporter', 'checkpoint_dir']. Found: {}".format( - func_args)) + use_config_single = detect_config_single(train_func) + use_reporter = detect_reporter(train_func) - use_reporter = "reporter" in func_args - if not use_checkpoint and not use_reporter: + if not any([use_checkpoint, use_config_single, use_reporter]): + # use_reporter is hidden + raise ValueError( + "Unknown argument found in the Trainable function. " + "The function args must include a 'config' positional " + "parameter. Any other args must be 'checkpoint_dir'. " + "Found: {}".format(func_args)) + + if use_config_single and not use_checkpoint: if log_once("tune_function_checkpoint") and warn: logger.warning( "Function checkpointing is disabled. This may result in " diff --git a/python/ray/tune/integration/torch.py b/python/ray/tune/integration/torch.py index 64a0bb88f..f7612a82a 100644 --- a/python/ray/tune/integration/torch.py +++ b/python/ray/tune/integration/torch.py @@ -12,10 +12,10 @@ import ray from ray import tune from ray.tune.result import RESULT_DUPLICATE from ray.tune.logger import NoopLogger -from ray.tune.function_runner import (wrap_function, - detect_checkpoint_function) +from ray.tune.function_runner import wrap_function from ray.tune.resources import Resources from ray.tune.trainable import TrainableUtil +from ray.tune.utils import detect_checkpoint_function from ray.util.sgd.torch.utils import setup_process_group, setup_address from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S diff --git a/python/ray/tune/tests/test_experiment.py b/python/ray/tune/tests/test_experiment.py index a154231a3..13289a020 100644 --- a/python/ray/tune/tests/test_experiment.py +++ b/python/ray/tune/tests/test_experiment.py @@ -79,7 +79,7 @@ class ValidateUtilTest(unittest.TestCase): # this is not serializable e = threading.Event() - def test(): + def test(config): print(e) assert diagnose_serialization(test) is not True @@ -88,7 +88,7 @@ class ValidateUtilTest(unittest.TestCase): # the `test` scope. # correct implementation - def test(): + def test(config): e = threading.Event() print(e) diff --git a/python/ray/tune/utils/__init__.py b/python/ray/tune/utils/__init__.py index ab61a4e88..d37a074cc 100644 --- a/python/ray/tune/utils/__init__.py +++ b/python/ray/tune/utils/__init__.py @@ -1,10 +1,14 @@ -from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \ - merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \ - validate_save_restore, warn_if_slow, diagnose_serialization, env_integer +from ray.tune.utils.util import ( + deep_update, flatten_dict, get_pinned_object, merge_dicts, + pin_in_object_store, unflattened_lookup, UtilMonitor, + validate_save_restore, warn_if_slow, diagnose_serialization, + detect_checkpoint_function, detect_reporter, detect_config_single, + env_integer) __all__ = [ "deep_update", "flatten_dict", "get_pinned_object", "merge_dicts", "pin_in_object_store", "unflattened_lookup", "UtilMonitor", "validate_save_restore", "warn_if_slow", "diagnose_serialization", + "detect_checkpoint_function", "detect_reporter", "detect_config_single", "env_integer" ] diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 8f49d4962..d31d07197 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -435,6 +435,50 @@ def validate_save_restore(trainable_cls, return True +def detect_checkpoint_function(train_func, abort=False): + """Use checkpointing if any arg has "checkpoint_dir" and args = 2""" + func_sig = inspect.signature(train_func) + validated = True + try: + # check if signature is func(config, checkpoint_dir=None) + func_sig.bind({}, checkpoint_dir="tmp/path") + except Exception as e: + logger.debug(str(e)) + validated = False + if abort and not validated: + func_args = inspect.getfullargspec(train_func).args + raise ValueError( + "Provided training function must have 2 args " + "in the signature, and the latter arg must " + "contain `checkpoint_dir`. For example: " + "`func(config, checkpoint_dir=None)`. Got {}".format(func_args)) + return validated + + +def detect_reporter(func): + """Use reporter if any arg has "reporter" and args = 2""" + func_sig = inspect.signature(func) + use_reporter = True + try: + func_sig.bind({}, reporter=None) + except Exception as e: + logger.debug(str(e)) + use_reporter = False + return use_reporter + + +def detect_config_single(func): + """Check if func({}) works.""" + func_sig = inspect.signature(func) + use_config_single = True + try: + func_sig.bind({}) + except Exception as e: + logger.debug(str(e)) + use_config_single = False + return use_config_single + + if __name__ == "__main__": ray.init() X = pin_in_object_store("hello") diff --git a/python/setup.py b/python/setup.py index 02d29ce0b..89cede854 100644 --- a/python/setup.py +++ b/python/setup.py @@ -111,7 +111,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on": extras = { "debug": [], "serve": ["uvicorn", "flask", "requests", "pydantic", "dataclasses"], - "tune": ["tabulate", "tensorboardX", "pandas"] + "tune": ["tabulate", "tensorboardX", "pandas", "dataclasses"] } extras["rllib"] = extras["tune"] + [