mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[tune] add accessible trial_info (#7378)
* add accessible trial_info * trial name and info * doc * fix gp * Update doc/source/tune-package-ref.rst * Apply suggestions from code review * fix * trial * fixtest * testfix
This commit is contained in:
@@ -29,10 +29,17 @@ class StatusReporter:
|
||||
>>> reporter(timesteps_this_iter=1)
|
||||
"""
|
||||
|
||||
def __init__(self, result_queue, continue_semaphore, logdir=None):
|
||||
def __init__(self,
|
||||
result_queue,
|
||||
continue_semaphore,
|
||||
trial_name=None,
|
||||
trial_id=None,
|
||||
logdir=None):
|
||||
self._queue = result_queue
|
||||
self._last_report_time = None
|
||||
self._continue_semaphore = continue_semaphore
|
||||
self._trial_name = trial_name
|
||||
self._trial_id = trial_id
|
||||
self._logdir = logdir
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
@@ -78,6 +85,16 @@ class StatusReporter:
|
||||
def logdir(self):
|
||||
return self._logdir
|
||||
|
||||
@property
def trial_name(self):
    """Name of the trial this reporter belongs to.

    This is the ``trial_name`` value passed to the constructor, which
    defaults to ``None``.
    """
    return self._trial_name
|
||||
|
||||
@property
def trial_id(self):
    """Id of the trial this reporter belongs to.

    This is the ``trial_id`` value passed to the constructor, which
    defaults to ``None``.
    """
    return self._trial_id
|
||||
|
||||
|
||||
class _RunnerThread(threading.Thread):
|
||||
"""Supervisor thread that runs your script."""
|
||||
@@ -133,7 +150,11 @@ class FunctionRunner(Trainable):
|
||||
self._error_queue = queue.Queue(1)
|
||||
|
||||
self._status_reporter = StatusReporter(
|
||||
self._results_queue, self._continue_semaphore, self.logdir)
|
||||
self._results_queue,
|
||||
self._continue_semaphore,
|
||||
trial_name=self.trial_name,
|
||||
trial_id=self.trial_id,
|
||||
logdir=self.logdir)
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# coding: utf-8
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -13,9 +14,10 @@ from ray.resource_spec import ResourceSpec
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.error import AbortTrialExecution, TuneError
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.result import TRIAL_INFO
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.trial import Trial, Checkpoint, Location
|
||||
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.utils import warn_if_slow
|
||||
|
||||
@@ -119,8 +121,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
logger.debug("Trial %s: Setting up new remote runner.", trial)
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
trial_config = copy.deepcopy(trial.config)
|
||||
trial_config[TRIAL_INFO] = TrialInfo(trial)
|
||||
kwargs = {
|
||||
"config": trial.config,
|
||||
"config": trial_config,
|
||||
"logger_creator": logger_creator,
|
||||
}
|
||||
if issubclass(trial.get_trainable_cls(), DurableTrainable):
|
||||
|
||||
@@ -65,6 +65,10 @@ DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
# avoid double-logging results when using the Function API.
|
||||
RESULT_DUPLICATE = "__duplicate__"
|
||||
|
||||
# __trial_info__ is a magic keyword used internally to pass trial_info
|
||||
# to the Trainable via the constructor.
|
||||
TRIAL_INFO = "__trial_info__"
|
||||
|
||||
# Where Tune writes result files by default
|
||||
DEFAULT_RESULTS_DIR = (os.environ.get("TEST_TMPDIR")
|
||||
or os.environ.get("TUNE_RESULT_DIR")
|
||||
|
||||
@@ -584,6 +584,36 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
|
||||
|
||||
def testTrialInfoAccess(self):
    """A class-based Trainable can report its own trial name and id."""

    class TestTrainable(Trainable):
        def _train(self):
            # Echo the trial info back through the result so that the
            # driver side can compare it against the Trial object.
            info = {"name": self.trial_name, "trial_id": self.trial_id}
            print(info)
            return info

    finished = tune.run(TestTrainable, stop={TRAINING_ITERATION: 1}).trials[0]
    last = finished.last_result
    self.assertEqual(last.get("name"), str(finished))
    self.assertEqual(last.get("trial_id"), finished.trial_id)
|
||||
|
||||
def testTrialInfoAccessFunction(self):
    """Function trainables can read trial info via reporter and track."""

    def train(config, reporter):
        # Trial info exposed on the reporter object.
        reporter(name=reporter.trial_name, trial_id=reporter.trial_id)

    def track_train(config):
        # Trial info exposed through the module-level track API.
        tune.track.log(
            name=tune.track.trial_name(), trial_id=tune.track.trial_id())

    # Same run-and-compare sequence for both access paths, in the same
    # order as before: reporter first, then track.
    for trainable in (train, track_train):
        analysis = tune.run(trainable, stop={TRAINING_ITERATION: 1})
        trial = analysis.trials[0]
        self.assertEqual(trial.last_result.get("name"), str(trial))
        self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)
|
||||
|
||||
def testNestedResults(self):
|
||||
def create_result(i):
|
||||
return {"test": {"1": {"2": {"3": i, "4": False}}}}
|
||||
|
||||
@@ -64,4 +64,25 @@ def trial_dir():
|
||||
return _session.logdir
|
||||
|
||||
|
||||
__all__ = ["TrackSession", "session", "log", "trial_dir", "init", "shutdown"]
|
||||
def trial_name():
    """Return the trial name of the current track session.

    The session is looked up with ``get_session()`` and its
    ``trial_name`` attribute is returned unchanged.
    """
    return get_session().trial_name
|
||||
|
||||
|
||||
def trial_id():
    """Return the trial id of the current track session.

    The session is looked up with ``get_session()`` and its
    ``trial_id`` attribute is returned unchanged.
    """
    return get_session().trial_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TrackSession", "session", "log", "trial_dir", "init", "shutdown",
|
||||
"trial_name", "trial_id"
|
||||
]
|
||||
|
||||
@@ -17,7 +17,8 @@ class _ReporterHook(Logger):
|
||||
class TrackSession:
|
||||
"""Manages results for a single session.
|
||||
|
||||
Represents a single Trial in an experiment.
|
||||
Represents a single Trial in an experiment. This is automatically
|
||||
created when using ``tune.run``.
|
||||
|
||||
Attributes:
|
||||
trial_name (str): Custom trial name.
|
||||
@@ -31,7 +32,7 @@ class TrackSession:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trial_name="",
|
||||
trial_name=None,
|
||||
experiment_dir=None,
|
||||
upload_dir=None,
|
||||
trial_config=None,
|
||||
@@ -42,18 +43,17 @@ class TrackSession:
|
||||
self.trial_config = None
|
||||
self._iteration = -1
|
||||
self.is_tune_session = bool(_tune_reporter)
|
||||
self.trial_id = Trial.generate_id()
|
||||
if trial_name:
|
||||
self.trial_id = trial_name + "_" + self.trial_id
|
||||
if self.is_tune_session:
|
||||
self._logger = _ReporterHook(_tune_reporter)
|
||||
self._logdir = _tune_reporter.logdir
|
||||
self._trial_name = _tune_reporter.trial_name
|
||||
self._trial_id = _tune_reporter.trial_id
|
||||
else:
|
||||
self._initialize_logging(trial_name, experiment_dir, upload_dir,
|
||||
trial_config)
|
||||
self._trial_id = Trial.generate_id()
|
||||
self._trial_name = trial_name or self._trial_id
|
||||
self._initialize_logging(experiment_dir, upload_dir, trial_config)
|
||||
|
||||
def _initialize_logging(self,
|
||||
trial_name="",
|
||||
experiment_dir=None,
|
||||
upload_dir=None,
|
||||
trial_config=None):
|
||||
@@ -67,7 +67,8 @@ class TrackSession:
|
||||
self._experiment_dir = os.path.expanduser(experiment_dir)
|
||||
|
||||
# TODO(rliaw): Refactor `logdir` to `trial_dir`.
|
||||
self._logdir = Trial.create_logdir(trial_name, self._experiment_dir)
|
||||
self._logdir = Trial.create_logdir(self.trial_name,
|
||||
self._experiment_dir)
|
||||
self._upload_dir = upload_dir
|
||||
self.trial_config = trial_config or {}
|
||||
|
||||
@@ -95,6 +96,10 @@ class TrackSession:
|
||||
self._logger.on_result(metrics_dict)
|
||||
|
||||
def close(self):
|
||||
"""Closes loggers.
|
||||
|
||||
No need to call this when using ``tune.run``.
|
||||
"""
|
||||
self.trial_config["trial_completed"] = True
|
||||
self.trial_config["end_time"] = datetime.now().isoformat()
|
||||
# TODO(rliaw): Have Tune support updated configs
|
||||
@@ -106,3 +111,13 @@ class TrackSession:
|
||||
def logdir(self):
|
||||
"""Trial logdir (subdir of given experiment directory)"""
|
||||
return self._logdir
|
||||
|
||||
@property
def trial_name(self):
    """Trial name of this session.

    Taken from the Tune reporter when the session runs under Tune;
    otherwise it is the name given to the constructor, falling back to
    the generated trial id when no name was given.
    """
    return self._trial_name
|
||||
|
||||
@property
def trial_id(self):
    """Trial id of this session.

    Taken from the Tune reporter when the session runs under Tune;
    otherwise it is the id generated with ``Trial.generate_id()`` in
    the constructor.
    """
    return self._trial_id
|
||||
|
||||
@@ -18,7 +18,7 @@ from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
|
||||
EPISODES_THIS_ITER, EPISODES_TOTAL,
|
||||
TRAINING_ITERATION, RESULT_DUPLICATE)
|
||||
TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO)
|
||||
from ray.tune.utils import UtilMonitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -147,6 +147,7 @@ class Trainable:
|
||||
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
self.config = config or {}
|
||||
trial_info = self.config.pop(TRIAL_INFO, None)
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
@@ -167,6 +168,7 @@ class Trainable:
|
||||
self._timesteps_since_restore = 0
|
||||
self._iterations_since_restore = 0
|
||||
self._restored = False
|
||||
self._trial_info = trial_info
|
||||
|
||||
start_time = time.time()
|
||||
self._setup(copy.deepcopy(self.config))
|
||||
@@ -207,7 +209,7 @@ class Trainable:
|
||||
return ""
|
||||
|
||||
def current_ip(self):
|
||||
logger.warning("Getting current IP.")
|
||||
logger.info("Getting current IP.")
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
return self._local_ip
|
||||
|
||||
@@ -511,6 +513,30 @@ class Trainable:
|
||||
"""
|
||||
return os.path.join(self._logdir, "")
|
||||
|
||||
@property
def trial_name(self):
    """Trial name for the corresponding trial of this Trainable.

    The value is read from the trial info object that the constructor
    pops out of ``config`` under the ``TRIAL_INFO`` key. This is not set
    if not using Tune; in that case ``None`` is returned instead of
    failing with ``AttributeError`` on the missing trial info.

    .. code-block:: python

        name = self.trial_name

    Returns:
        The trial name, or ``None`` when no trial info is available.
    """
    trial_info = self._trial_info
    # ``_trial_info`` is ``None`` whenever the config carried no
    # TRIAL_INFO entry, so guard before dereferencing it.
    if trial_info is None:
        return None
    return trial_info.trial_name
|
||||
|
||||
@property
def trial_id(self):
    """Trial ID for the corresponding trial of this Trainable.

    The value is read from the trial info object that the constructor
    pops out of ``config`` under the ``TRIAL_INFO`` key. This is not set
    if not using Tune; in that case ``None`` is returned instead of
    failing with ``AttributeError`` on the missing trial info.

    .. code-block:: python

        trial_id = self.trial_id

    Returns:
        The trial id, or ``None`` when no trial info is available.
    """
    trial_info = self._trial_info
    # ``_trial_info`` is ``None`` whenever the config carried no
    # TRIAL_INFO entry, so guard before dereferencing it.
    if trial_info is None:
        return None
    return trial_info.trial_id
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
"""Current training iteration.
|
||||
|
||||
@@ -102,6 +102,27 @@ def checkpoint_deleter(trial_id, runner):
|
||||
return delete
|
||||
|
||||
|
||||
class TrialInfo:
    """Serializable struct for holding information for a Trial.

    Attributes:
        trial_name (str): String name of the current trial.
        trial_id (str): trial_id of the trial
    """

    def __init__(self, trial):
        """Capture the name and id of ``trial``.

        Args:
            trial: Object whose ``str()`` form is used as the trial name
                and whose ``trial_id`` attribute is used as the trial id.
        """
        # Snapshot the two values rather than keeping a reference to the
        # trial object itself.
        self._trial_name = str(trial)
        self._trial_id = trial.trial_id

    def __repr__(self):
        """Readable representation for logs and debugging output."""
        return "{}(trial_name={!r}, trial_id={!r})".format(
            type(self).__name__, self._trial_name, self._trial_id)

    @property
    def trial_name(self):
        """String name captured from the trial at construction time."""
        return self._trial_name

    @property
    def trial_id(self):
        """Trial id captured from the trial at construction time."""
        return self._trial_id
|
||||
|
||||
|
||||
class Trial:
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user