mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[tune] More robust resolution/detection of signature (#10365)
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
@@ -5,11 +5,11 @@ import os
|
||||
from typing import Sequence
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.function_runner import detect_checkpoint_function
|
||||
from ray.tune.registry import register_trainable, get_trainable_cls
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import Domain
|
||||
from ray.tune.stopper import FunctionStopper, Stopper
|
||||
from ray.tune.utils import detect_checkpoint_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
SHOULD_CHECKPOINT)
|
||||
from ray.tune.utils import (detect_checkpoint_function, detect_config_single,
|
||||
detect_reporter)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -459,24 +461,6 @@ class FunctionRunner(Trainable):
|
||||
pass
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func, abort=False):
    """Return whether ``train_func`` looks like a checkpointing function.

    A function validates when either

    * it has exactly two positional parameters and at least one of their
      names contains ``"checkpoint_dir"``, or
    * it has exactly one positional parameter and at least one keyword-only
      parameter whose name contains ``"checkpoint_dir"``.

    Args:
        train_func: Function whose argument spec is inspected.
        abort: If True, raise instead of returning False.

    Returns:
        bool: True if the signature validates, False otherwise.

    Raises:
        ValueError: If ``abort`` is set and the signature does not validate.
    """
    argspec = inspect.getfullargspec(train_func)
    func_args = argspec.args
    func_kwargs = argspec.kwonlyargs

    # func(config, checkpoint_dir=None)
    positional_ok = len(func_args) == 2 and any(
        "checkpoint_dir" in arg for arg in func_args)
    # func(config, *, checkpoint_dir=None)
    kwonly_ok = len(func_args) == 1 and any(
        "checkpoint_dir" in arg for arg in func_kwargs)
    validated = positional_ok or kwonly_ok

    if abort and not validated:
        # Keyword-only names take part in the check above, so show them too;
        # otherwise the message hides the parameters that were rejected.
        raise ValueError(
            "Provided training function must have 2 args "
            "in the signature, and the latter arg must "
            "contain `checkpoint_dir`. For example: "
            "`func(config, checkpoint_dir=None)`. Got {}".format(
                func_args + func_kwargs))
    return validated
|
||||
|
||||
|
||||
def wrap_function(train_func, warn=True):
|
||||
if hasattr(train_func, "__mixins__"):
|
||||
inherit_from = train_func.__mixins__ + (FunctionRunner, )
|
||||
@@ -485,16 +469,18 @@ def wrap_function(train_func, warn=True):
|
||||
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_checkpoint = detect_checkpoint_function(train_func)
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and not use_checkpoint:
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
|
||||
func_args))
|
||||
use_config_single = detect_config_single(train_func)
|
||||
use_reporter = detect_reporter(train_func)
|
||||
|
||||
use_reporter = "reporter" in func_args
|
||||
if not use_checkpoint and not use_reporter:
|
||||
if not any([use_checkpoint, use_config_single, use_reporter]):
|
||||
# use_reporter is hidden
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"The function args must include a 'config' positional "
|
||||
"parameter. Any other args must be 'checkpoint_dir'. "
|
||||
"Found: {}".format(func_args))
|
||||
|
||||
if use_config_single and not use_checkpoint:
|
||||
if log_once("tune_function_checkpoint") and warn:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
|
||||
@@ -12,10 +12,10 @@ import ray
|
||||
from ray import tune
|
||||
from ray.tune.result import RESULT_DUPLICATE
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.function_runner import (wrap_function,
|
||||
detect_checkpoint_function)
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.utils import detect_checkpoint_function
|
||||
from ray.util.sgd.torch.utils import setup_process_group, setup_address
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ class ValidateUtilTest(unittest.TestCase):
|
||||
# this is not serializable
|
||||
e = threading.Event()
|
||||
|
||||
def test():
|
||||
def test(config):
|
||||
print(e)
|
||||
|
||||
assert diagnose_serialization(test) is not True
|
||||
@@ -88,7 +88,7 @@ class ValidateUtilTest(unittest.TestCase):
|
||||
# the `test` scope.
|
||||
|
||||
# correct implementation
|
||||
def test():
|
||||
def test(config):
|
||||
e = threading.Event()
|
||||
print(e)
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from ray.tune.utils.util import deep_update, flatten_dict, get_pinned_object, \
|
||||
merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor, \
|
||||
validate_save_restore, warn_if_slow, diagnose_serialization, env_integer
|
||||
from ray.tune.utils.util import (
|
||||
deep_update, flatten_dict, get_pinned_object, merge_dicts,
|
||||
pin_in_object_store, unflattened_lookup, UtilMonitor,
|
||||
validate_save_restore, warn_if_slow, diagnose_serialization,
|
||||
detect_checkpoint_function, detect_reporter, detect_config_single,
|
||||
env_integer)
|
||||
|
||||
__all__ = [
|
||||
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
|
||||
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
|
||||
"validate_save_restore", "warn_if_slow", "diagnose_serialization",
|
||||
"detect_checkpoint_function", "detect_reporter", "detect_config_single",
|
||||
"env_integer"
|
||||
]
|
||||
|
||||
@@ -435,6 +435,50 @@ def validate_save_restore(trainable_cls,
|
||||
return True
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func, abort=False):
    """Return whether ``train_func`` accepts a ``checkpoint_dir`` argument.

    The call ``train_func({}, checkpoint_dir="tmp/path")`` is bound against
    the function's signature (the function itself is not called). Any
    signature that accepts that call validates, e.g.
    ``func(config, checkpoint_dir=None)``.

    Args:
        train_func: Callable whose signature is inspected.
        abort: If True, raise instead of returning False.

    Returns:
        bool: True if the call binds, False otherwise.

    Raises:
        ValueError: If ``abort`` is set and the call does not bind.
    """
    func_sig = inspect.signature(train_func)
    try:
        # check if signature is func(config, checkpoint_dir=None)
        func_sig.bind({}, checkpoint_dir="tmp/path")
    except TypeError as e:
        # Signature.bind signals a mismatch with TypeError; anything else
        # would be an unrelated error and should not be swallowed here.
        logger.debug(str(e))
        validated = False
    else:
        validated = True

    if abort and not validated:
        # Take the names from the signature already computed so that
        # keyword-only and variadic parameters appear in the message too.
        func_args = list(func_sig.parameters)
        raise ValueError(
            "Provided training function must have 2 args "
            "in the signature, and the latter arg must "
            "contain `checkpoint_dir`. For example: "
            "`func(config, checkpoint_dir=None)`. Got {}".format(func_args))
    return validated
|
||||
|
||||
|
||||
def detect_reporter(func):
    """Return whether ``func`` accepts a ``reporter`` argument.

    The call ``func({}, reporter=None)`` is bound against the function's
    signature (the function itself is not called).

    Args:
        func: Callable whose signature is inspected.

    Returns:
        bool: True if the call binds, False otherwise.
    """
    func_sig = inspect.signature(func)
    try:
        func_sig.bind({}, reporter=None)
    except TypeError as e:
        # Signature.bind signals a mismatch with TypeError.
        logger.debug(str(e))
        return False
    return True
|
||||
|
||||
|
||||
def detect_config_single(func):
    """Return whether ``func`` can be called with a single config argument.

    The call ``func({})`` is bound against the function's signature (the
    function itself is not called), so extra parameters that have defaults
    do not prevent a match.

    Args:
        func: Callable whose signature is inspected.

    Returns:
        bool: True if the call binds, False otherwise.
    """
    func_sig = inspect.signature(func)
    try:
        func_sig.bind({})
    except TypeError as e:
        # Signature.bind signals a mismatch with TypeError.
        logger.debug(str(e))
        return False
    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
+1
-1
@@ -111,7 +111,7 @@ if os.getenv("RAY_USE_NEW_GCS") == "on":
|
||||
extras = {
|
||||
"debug": [],
|
||||
"serve": ["uvicorn", "flask", "requests", "pydantic", "dataclasses"],
|
||||
"tune": ["tabulate", "tensorboardX", "pandas"]
|
||||
"tune": ["tabulate", "tensorboardX", "pandas", "dataclasses"]
|
||||
}
|
||||
|
||||
extras["rllib"] = extras["tune"] + [
|
||||
|
||||
Reference in New Issue
Block a user