[tune] check for running session (#10840)

This commit is contained in:
Richard Liaw
2020-09-16 18:55:11 -07:00
committed by GitHub
parent 829a2307df
commit d3feb83053
3 changed files with 24 additions and 5 deletions
+6 -5
View File
@@ -9,9 +9,9 @@ from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.session import (report, get_trial_dir, get_trial_name,
get_trial_id, make_checkpoint_dir,
save_checkpoint, checkpoint_dir)
from ray.tune.session import (
report, get_trial_dir, get_trial_name, get_trial_id, make_checkpoint_dir,
save_checkpoint, checkpoint_dir, is_session_enabled)
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
@@ -28,6 +28,7 @@ __all__ = [
"qrandint", "randn", "qrandn", "loguniform", "qloguniform",
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
"get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir",
"SyncConfig", "create_searcher", "create_scheduler"
"get_trial_id", "make_checkpoint_dir", "save_checkpoint",
"is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher",
"create_scheduler"
]
+6
View File
@@ -7,6 +7,12 @@ logger = logging.getLogger(__name__)
_session = None
def is_session_enabled() -> bool:
"""Returns True if running within an Tune process."""
global _session
return _session is not None
def get_session():
global _session
if not _session:
@@ -400,6 +400,18 @@ class FunctionApiTest(unittest.TestCase):
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 10
def testEnabled(self):
def train(config, checkpoint_dir=None):
is_active = tune.is_session_enabled()
if is_active:
tune.report(active=is_active)
return is_active
assert train({}) is False
analysis = tune.run(train)
t = analysis.trials[0]
assert t.last_result["active"]
def testBlankCheckpoint(self):
def train(config, checkpoint_dir=None):
restored = bool(checkpoint_dir)