diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index ef898d04a..813b05760 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -9,9 +9,9 @@ from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable from ray.tune.durable_trainable import DurableTrainable from ray.tune.suggest import grid_search -from ray.tune.session import (report, get_trial_dir, get_trial_name, - get_trial_id, make_checkpoint_dir, - save_checkpoint, checkpoint_dir) +from ray.tune.session import ( + report, get_trial_dir, get_trial_name, get_trial_id, make_checkpoint_dir, + save_checkpoint, checkpoint_dir, is_session_enabled) from ray.tune.progress_reporter import (ProgressReporter, CLIReporter, JupyterNotebookReporter) from ray.tune.sample import (function, sample_from, uniform, quniform, choice, @@ -28,6 +28,7 @@ __all__ = [ "qrandint", "randn", "qrandn", "loguniform", "qloguniform", "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir", "get_trial_name", - "get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir", - "SyncConfig", "create_searcher", "create_scheduler" + "get_trial_id", "make_checkpoint_dir", "save_checkpoint", + "is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher", + "create_scheduler" ] diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index c682e9457..4a4169bd9 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -7,6 +7,12 @@ logger = logging.getLogger(__name__) _session = None +def is_session_enabled() -> bool: + """Returns True if running within an Tune process.""" + global _session + return _session is not None + + def get_session(): global _session if not _session: diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index b3714326b..10f781457 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -400,6 +400,18 @@ class FunctionApiTest(unittest.TestCase): trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 10 + def testEnabled(self): + def train(config, checkpoint_dir=None): + is_active = tune.is_session_enabled() + if is_active: + tune.report(active=is_active) + return is_active + + assert train({}) is False + analysis = tune.run(train) + t = analysis.trials[0] + assert t.last_result["active"] + def testBlankCheckpoint(self): def train(config, checkpoint_dir=None): restored = bool(checkpoint_dir)