[tune] add accessible trial_info (#7378)

* add accessible trial_info

* trial name and info

* doc

* fix
gp

* Update doc/source/tune-package-ref.rst

* Apply suggestions from code review

* fix

* trial

* fixtest

* testfix
This commit is contained in:
Richard Liaw
2020-03-17 23:44:18 -07:00
committed by GitHub
parent 745b9d643d
commit ea10cd212c
8 changed files with 158 additions and 16 deletions
+23 -2
View File
@@ -29,10 +29,17 @@ class StatusReporter:
>>> reporter(timesteps_this_iter=1)
"""
def __init__(self, result_queue, continue_semaphore, logdir=None):
def __init__(self,
result_queue,
continue_semaphore,
trial_name=None,
trial_id=None,
logdir=None):
self._queue = result_queue
self._last_report_time = None
self._continue_semaphore = continue_semaphore
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
def __call__(self, **kwargs):
@@ -78,6 +85,16 @@ class StatusReporter:
def logdir(self):
return self._logdir
@property
def trial_name(self):
"""Trial name for the corresponding trial of this Trainable.

Read-only view of the ``trial_name`` value handed to the constructor,
which defaults to ``None``.
"""
return self._trial_name
@property
def trial_id(self):
"""Trial id for the corresponding trial of this Trainable.

Read-only view of the ``trial_id`` value handed to the constructor,
which defaults to ``None``.
"""
return self._trial_id
class _RunnerThread(threading.Thread):
"""Supervisor thread that runs your script."""
@@ -133,7 +150,11 @@ class FunctionRunner(Trainable):
self._error_queue = queue.Queue(1)
self._status_reporter = StatusReporter(
self._results_queue, self._continue_semaphore, self.logdir)
self._results_queue,
self._continue_semaphore,
trial_name=self.trial_name,
trial_id=self.trial_id,
logdir=self.logdir)
self._last_result = {}
config = config.copy()
+6 -2
View File
@@ -1,4 +1,5 @@
# coding: utf-8
import copy
import logging
import os
import random
@@ -13,9 +14,10 @@ from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.trial import Trial, Checkpoint, Location
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
from ray.tune.trial_executor import TrialExecutor
from ray.tune.utils import warn_if_slow
@@ -119,8 +121,10 @@ class RayTrialExecutor(TrialExecutor):
logger.debug("Trial %s: Setting up new remote runner.", trial)
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote runner to use a noop-logger.
trial_config = copy.deepcopy(trial.config)
trial_config[TRIAL_INFO] = TrialInfo(trial)
kwargs = {
"config": trial.config,
"config": trial_config,
"logger_creator": logger_creator,
}
if issubclass(trial.get_trainable_cls(), DurableTrainable):
+4
View File
@@ -65,6 +65,10 @@ DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
# avoid double-logging results when using the Function API.
RESULT_DUPLICATE = "__duplicate__"
# __trial_info__ is a magic keyword used internally to pass trial_info
# to the Trainable via the constructor.
TRIAL_INFO = "__trial_info__"
# Where Tune writes result files by default
DEFAULT_RESULTS_DIR = (os.environ.get("TEST_TMPDIR")
or os.environ.get("TUNE_RESULT_DIR")
+30
View File
@@ -584,6 +584,36 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
def testTrialInfoAccess(self):
# Class API: a Trainable subclass reports its own trial_name/trial_id
# and the values are compared against the Trial returned by tune.run.
class TestTrainable(Trainable):
def _train(self):
# Echo the trial info seen from inside the Trainable as the result.
result = {"name": self.trial_name, "trial_id": self.trial_id}
print(result)
return result
analysis = tune.run(TestTrainable, stop={TRAINING_ITERATION: 1})
trial = analysis.trials[0]
# The reported name must equal str(trial); the id must equal trial.trial_id.
self.assertEqual(trial.last_result.get("name"), str(trial))
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
def testTrialInfoAccessFunction(self):
# Function API, variant 1: trial info is read from the reporter argument.
def train(config, reporter):
reporter(name=reporter.trial_name, trial_id=reporter.trial_id)
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
trial = analysis.trials[0]
self.assertEqual(trial.last_result.get("name"), str(trial))
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
# Function API, variant 2: trial info is read through the module-level
# tune.track accessors instead of a reporter argument.
def track_train(config):
tune.track.log(
name=tune.track.trial_name(), trial_id=tune.track.trial_id())
analysis = tune.run(track_train, stop={TRAINING_ITERATION: 1})
trial = analysis.trials[0]
# Same expectations as above: name equals str(trial), id equals trial.trial_id.
self.assertEqual(trial.last_result.get("name"), str(trial))
self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
def testNestedResults(self):
def create_result(i):
return {"test": {"1": {"2": {"3": i, "4": False}}}}
+22 -1
View File
@@ -64,4 +64,25 @@ def trial_dir():
return _session.logdir
__all__ = ["TrackSession", "session", "log", "trial_dir", "init", "shutdown"]
def trial_name():
    """Return the trial name of the current track session.

    Delegates to the ``trial_name`` attribute of the session object
    returned by ``get_session()``.

    This is not set if not using Tune.
    """
    return get_session().trial_name
def trial_id():
    """Return the trial id of the current track session.

    Delegates to the ``trial_id`` attribute of the session object
    returned by ``get_session()``.

    This is not set if not using Tune.
    """
    return get_session().trial_id
# Names exported by ``from <this module> import *``; includes the
# ``trial_name`` and ``trial_id`` accessors alongside the session helpers.
__all__ = [
"TrackSession", "session", "log", "trial_dir", "init", "shutdown",
"trial_name", "trial_id"
]
+24 -9
View File
@@ -17,7 +17,8 @@ class _ReporterHook(Logger):
class TrackSession:
"""Manages results for a single session.
Represents a single Trial in an experiment.
Represents a single Trial in an experiment. This is automatically
created when using ``tune.run``.
Attributes:
trial_name (str): Custom trial name.
@@ -31,7 +32,7 @@ class TrackSession:
"""
def __init__(self,
trial_name="",
trial_name=None,
experiment_dir=None,
upload_dir=None,
trial_config=None,
@@ -42,18 +43,17 @@ class TrackSession:
self.trial_config = None
self._iteration = -1
self.is_tune_session = bool(_tune_reporter)
self.trial_id = Trial.generate_id()
if trial_name:
self.trial_id = trial_name + "_" + self.trial_id
if self.is_tune_session:
self._logger = _ReporterHook(_tune_reporter)
self._logdir = _tune_reporter.logdir
self._trial_name = _tune_reporter.trial_name
self._trial_id = _tune_reporter.trial_id
else:
self._initialize_logging(trial_name, experiment_dir, upload_dir,
trial_config)
self._trial_id = Trial.generate_id()
self._trial_name = trial_name or self._trial_id
self._initialize_logging(experiment_dir, upload_dir, trial_config)
def _initialize_logging(self,
trial_name="",
experiment_dir=None,
upload_dir=None,
trial_config=None):
@@ -67,7 +67,8 @@ class TrackSession:
self._experiment_dir = os.path.expanduser(experiment_dir)
# TODO(rliaw): Refactor `logdir` to `trial_dir`.
self._logdir = Trial.create_logdir(trial_name, self._experiment_dir)
self._logdir = Trial.create_logdir(self.trial_name,
self._experiment_dir)
self._upload_dir = upload_dir
self.trial_config = trial_config or {}
@@ -95,6 +96,10 @@ class TrackSession:
self._logger.on_result(metrics_dict)
def close(self):
"""Closes loggers.
No need to call this when using ``tune.run``.
"""
self.trial_config["trial_completed"] = True
self.trial_config["end_time"] = datetime.now().isoformat()
# TODO(rliaw): Have Tune support updated configs
@@ -106,3 +111,13 @@ class TrackSession:
def logdir(self):
"""Trial logdir (subdir of given experiment directory)"""
return self._logdir
@property
def trial_name(self):
"""Trial name for this session.

In a Tune session this is the ``trial_name`` of the Tune reporter;
otherwise it is the ``trial_name`` constructor argument, falling back
to the generated trial id when that argument is empty.
"""
return self._trial_name
@property
def trial_id(self):
"""Trial id for this session.

In a Tune session this is the ``trial_id`` of the Tune reporter;
otherwise it is an id produced by ``Trial.generate_id()``.
"""
return self._trial_id
+28 -2
View File
@@ -18,7 +18,7 @@ from ray.tune.logger import UnifiedLogger
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
EPISODES_THIS_ITER, EPISODES_TOTAL,
TRAINING_ITERATION, RESULT_DUPLICATE)
TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO)
from ray.tune.utils import UtilMonitor
logger = logging.getLogger(__name__)
@@ -147,6 +147,7 @@ class Trainable:
self._experiment_id = uuid.uuid4().hex
self.config = config or {}
trial_info = self.config.pop(TRIAL_INFO, None)
if logger_creator:
self._result_logger = logger_creator(self.config)
@@ -167,6 +168,7 @@ class Trainable:
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
self._restored = False
self._trial_info = trial_info
start_time = time.time()
self._setup(copy.deepcopy(self.config))
@@ -207,7 +209,7 @@ class Trainable:
return ""
def current_ip(self):
logger.warning("Getting current IP.")
logger.info("Getting current IP.")
self._local_ip = ray.services.get_node_ip_address()
return self._local_ip
@@ -511,6 +513,30 @@ class Trainable:
"""
return os.path.join(self._logdir, "")
@property
def trial_name(self):
    """Trial name for the corresponding trial of this Trainable.

    This is not set if not using Tune.

    .. code-block:: python

        name = self.trial_name

    Returns:
        The trial name, or ``None`` when no trial info was supplied
        through the config (for example when not running under Tune).
    """
    # ``_trial_info`` comes from ``config.pop(TRIAL_INFO, None)`` in the
    # constructor, so it is ``None`` whenever the key was absent. Guard the
    # access so the property yields ``None`` rather than AttributeError.
    if self._trial_info is None:
        return None
    return self._trial_info.trial_name
@property
def trial_id(self):
    """Trial ID for the corresponding trial of this Trainable.

    This is not set if not using Tune.

    .. code-block:: python

        trial_id = self.trial_id

    Returns:
        The trial id, or ``None`` when no trial info was supplied
        through the config (for example when not running under Tune).
    """
    # ``_trial_info`` comes from ``config.pop(TRIAL_INFO, None)`` in the
    # constructor, so it is ``None`` whenever the key was absent. Guard the
    # access so the property yields ``None`` rather than AttributeError.
    if self._trial_info is None:
        return None
    return self._trial_info.trial_id
@property
def iteration(self):
"""Current training iteration.
+21
View File
@@ -102,6 +102,27 @@ def checkpoint_deleter(trial_id, runner):
return delete
class TrialInfo:
    """Serializable struct for holding information for a Trial.

    Attributes:
        trial_name (str): String name of the current trial.
        trial_id (str): trial_id of the trial
    """

    def __init__(self, trial):
        # Snapshot plain values at construction time; only ``str(trial)``
        # and ``trial.trial_id`` are kept, not the trial object itself.
        self._trial_name = str(trial)
        self._trial_id = trial.trial_id

    def __repr__(self):
        # Readable form for logs and config dumps instead of the default
        # ``<... object at 0x...>`` representation.
        return "TrialInfo(trial_name={!r}, trial_id={!r})".format(
            self._trial_name, self._trial_id)

    @property
    def trial_name(self):
        """String name of the trial, i.e. ``str(trial)`` at construction."""
        return self._trial_name

    @property
    def trial_id(self):
        """The ``trial_id`` of the trial at construction."""
        return self._trial_id
class Trial:
"""A trial object holds the state for one model training run.