mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:19:38 +08:00
[tune] check for running session (#10840)
This commit is contained in:
@@ -9,9 +9,9 @@ from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.session import (report, get_trial_dir, get_trial_name,
|
||||
get_trial_id, make_checkpoint_dir,
|
||||
save_checkpoint, checkpoint_dir)
|
||||
from ray.tune.session import (
|
||||
report, get_trial_dir, get_trial_name, get_trial_id, make_checkpoint_dir,
|
||||
save_checkpoint, checkpoint_dir, is_session_enabled)
|
||||
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
|
||||
JupyterNotebookReporter)
|
||||
from ray.tune.sample import (function, sample_from, uniform, quniform, choice,
|
||||
@@ -28,6 +28,7 @@ __all__ = [
|
||||
"qrandint", "randn", "qrandn", "loguniform", "qloguniform",
|
||||
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
|
||||
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
|
||||
"get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir",
|
||||
"SyncConfig", "create_searcher", "create_scheduler"
|
||||
"get_trial_id", "make_checkpoint_dir", "save_checkpoint",
|
||||
"is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher",
|
||||
"create_scheduler"
|
||||
]
|
||||
|
||||
@@ -7,6 +7,12 @@ logger = logging.getLogger(__name__)
|
||||
_session = None
|
||||
|
||||
|
||||
def is_session_enabled() -> bool:
|
||||
"""Returns True if running within an Tune process."""
|
||||
global _session
|
||||
return _session is not None
|
||||
|
||||
|
||||
def get_session():
|
||||
global _session
|
||||
if not _session:
|
||||
|
||||
@@ -400,6 +400,18 @@ class FunctionApiTest(unittest.TestCase):
|
||||
trial_dfs = list(analysis.trial_dataframes.values())
|
||||
assert len(trial_dfs[0]["training_iteration"]) == 10
|
||||
|
||||
def testEnabled(self):
|
||||
def train(config, checkpoint_dir=None):
|
||||
is_active = tune.is_session_enabled()
|
||||
if is_active:
|
||||
tune.report(active=is_active)
|
||||
return is_active
|
||||
|
||||
assert train({}) is False
|
||||
analysis = tune.run(train)
|
||||
t = analysis.trials[0]
|
||||
assert t.last_result["active"]
|
||||
|
||||
def testBlankCheckpoint(self):
|
||||
def train(config, checkpoint_dir=None):
|
||||
restored = bool(checkpoint_dir)
|
||||
|
||||
Reference in New Issue
Block a user