mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 10:52:02 +08:00
[tune] logger refactor part 3: Add ExperimentLogger class (#11749)
This commit is contained in:
@@ -190,7 +190,6 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
restore_path=spec.get("restore"),
|
||||
trial_name_creator=spec.get("trial_name_creator"),
|
||||
trial_dirname_creator=spec.get("trial_dirname_creator"),
|
||||
loggers=spec.get("loggers"),
|
||||
log_to_file=spec.get("log_to_file"),
|
||||
# str(None) doesn't create None
|
||||
max_failures=args.max_failures,
|
||||
|
||||
@@ -123,6 +123,17 @@ class Experiment:
|
||||
max_failures=0,
|
||||
restore=None):
|
||||
|
||||
if loggers is not None:
|
||||
# Most users won't run into this as `tune.run()` does not pass
|
||||
# the argument anymore. However, we will want to inform users
|
||||
# if they instantiate their `Experiment` objects themselves.
|
||||
raise ValueError(
|
||||
"Passing `loggers` to an `Experiment` is deprecated. Use "
|
||||
"an `ExperimentLogger` callback instead, e.g. by passing the "
|
||||
"`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
|
||||
"passing this as part of the `callback` parameter to "
|
||||
"`tune.run()`.")
|
||||
|
||||
config = config or {}
|
||||
if callable(run) and detect_checkpoint_function(run):
|
||||
if checkpoint_at_end:
|
||||
|
||||
@@ -5,10 +5,11 @@ import numpy as np
|
||||
import os
|
||||
import yaml
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type
|
||||
from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, Type
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.utils.util import SafeFallbackEncoder
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
@@ -129,7 +130,8 @@ class JsonLogger(Logger):
|
||||
self.local_out.write(b)
|
||||
|
||||
def flush(self):
|
||||
self.local_out.flush()
|
||||
if not self.local_out.closed:
|
||||
self.local_out.flush()
|
||||
|
||||
def close(self):
|
||||
self.local_out.close()
|
||||
@@ -181,7 +183,8 @@ class CSVLogger(Logger):
|
||||
self._file.flush()
|
||||
|
||||
def flush(self):
|
||||
self._file.flush()
|
||||
if not self._file.closed:
|
||||
self._file.flush()
|
||||
|
||||
def close(self):
|
||||
self._file.close()
|
||||
@@ -361,6 +364,97 @@ class UnifiedLogger(Logger):
|
||||
_logger.flush()
|
||||
|
||||
|
||||
class ExperimentLogger(Callback):
    """Callback base class that maps trial lifecycle hooks to ``log_*`` calls.

    Each ``on_trial_*`` hook defined here forwards to exactly one ``log_trial_*``
    method. The ``log_trial_*`` methods of this base class all raise
    ``NotImplementedError``, so a concrete logger has to override every
    ``log_trial_*`` method whose corresponding hook may fire.
    """

    # ------------------------------------------------------------------
    # Callback hooks: thin adapters onto the logging interface below.
    # ------------------------------------------------------------------

    def on_trial_start(self, iteration: int, trials: List["Trial"],
                       trial: "Trial", **info):
        """Forward to ``log_trial_start``; only ``trial`` is passed on."""
        self.log_trial_start(trial)

    def on_trial_restore(self, iteration: int, trials: List["Trial"],
                         trial: "Trial", **info):
        """Forward to ``log_trial_restore``; only ``trial`` is passed on."""
        self.log_trial_restore(trial)

    def on_trial_save(self, iteration: int, trials: List["Trial"],
                      trial: "Trial", **info):
        """Forward to ``log_trial_save``; only ``trial`` is passed on."""
        self.log_trial_save(trial)

    def on_trial_result(self, iteration: int, trials: List["Trial"],
                        trial: "Trial", result: Dict, **info):
        """Forward ``iteration``, ``trial`` and ``result`` to
        ``log_trial_result``."""
        self.log_trial_result(iteration, trial, result)

    def on_trial_complete(self, iteration: int, trials: List["Trial"],
                          trial: "Trial", **info):
        """Forward to ``log_trial_end`` with ``failed=False``."""
        self.log_trial_end(trial, failed=False)

    def on_trial_error(self, iteration: int, trials: List["Trial"],
                       trial: "Trial", **info):
        """Forward to ``log_trial_end`` with ``failed=True``."""
        self.log_trial_end(trial, failed=True)

    # ------------------------------------------------------------------
    # Logging interface: to be overridden by concrete loggers.
    # ------------------------------------------------------------------

    def log_trial_start(self, trial: "Trial"):
        """Handle the start of ``trial``. Not implemented in the base class."""
        raise NotImplementedError

    def log_trial_restore(self, trial: "Trial"):
        """Handle a restore of ``trial``. Not implemented in the base class."""
        raise NotImplementedError

    def log_trial_save(self, trial: "Trial"):
        """Handle a save of ``trial``. Not implemented in the base class."""
        raise NotImplementedError

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Handle a ``result`` dict reported for ``trial``.

        Not implemented in the base class.
        """
        raise NotImplementedError

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Handle the end of ``trial``.

        Args:
            trial (Trial): The trial that ended.
            failed (bool): ``True`` when reached via ``on_trial_error``,
                ``False`` when reached via ``on_trial_complete``.

        Not implemented in the base class.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class LegacyExperimentLogger(ExperimentLogger):
    """Supports logging to trial-specific `Logger` classes.

    Previously, Ray Tune logging was handled via `Logger` classes that have
    been instantiated per-trial. This callback is a fallback to these
    `Logger`-classes, instantiating each `Logger` class for each trial
    and logging to them.

    Args:
        logger_classes (Iterable[Type[Logger]]): Logger classes that should
            be instantiated for each trial.

    """

    def __init__(self, logger_classes: Iterable[Type[Logger]]):
        # Materialize once so that a one-shot iterable can be re-used for
        # every trial that starts.
        self.logger_classes = list(logger_classes)
        # Two-level registry: logger class -> trial -> logger instance.
        self._class_trial_loggers: Dict[Type[Logger], Dict["Trial",
                                                           Logger]] = {}

    def _iter_trial_loggers(self, trial: "Trial"):
        """Lazily yield every logger instance registered for ``trial``."""
        for per_trial in self._class_trial_loggers.values():
            if trial in per_trial:
                yield per_trial[trial]

    def log_trial_start(self, trial: "Trial"):
        """Create the trial's logdir and one logger per configured class.

        A logger that already exists for the (class, trial) pair is kept.
        """
        trial.init_logdir()

        for cls in self.logger_classes:
            per_trial = self._class_trial_loggers.get(cls)
            if per_trial is None:
                per_trial = {}
            if trial not in per_trial:
                per_trial[trial] = cls(trial.config, trial.logdir, trial)
            self._class_trial_loggers[cls] = per_trial

    def log_trial_restore(self, trial: "Trial"):
        """Flush all loggers registered for ``trial``."""
        for trial_logger in self._iter_trial_loggers(trial):
            trial_logger.flush()

    def log_trial_save(self, trial: "Trial"):
        """Flush all loggers registered for ``trial``."""
        for trial_logger in self._iter_trial_loggers(trial):
            trial_logger.flush()

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Pass ``result`` to ``on_result`` of each of the trial's loggers.

        ``iteration`` is accepted for interface compatibility and not used.
        """
        for trial_logger in self._iter_trial_loggers(trial):
            trial_logger.on_result(result)

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Close all loggers registered for ``trial``.

        The closed loggers stay in the registry; ``failed`` is not used.
        """
        for trial_logger in self._iter_trial_loggers(trial):
            trial_logger.close()
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
result = result.copy()
|
||||
result.update(config=None) # drop config from pretty print
|
||||
|
||||
@@ -177,7 +177,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._update_avail_resources()
|
||||
|
||||
def _setup_remote_runner(self, trial, reuse_allowed):
|
||||
trial.init_logger()
|
||||
trial.init_logdir()
|
||||
# We checkpoint metadata here to try mitigating logdir duplication
|
||||
self.try_checkpoint_metadata(trial)
|
||||
logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
|
||||
@@ -297,8 +297,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
elif train and not trial.is_restoring:
|
||||
self._train(trial)
|
||||
|
||||
def _stop_trial(self, trial, error=False, error_msg=None,
|
||||
stop_logger=True):
|
||||
def _stop_trial(self, trial, error=False, error_msg=None):
|
||||
"""Stops this trial.
|
||||
|
||||
Stops this trial, releasing all allocating resources. If stopping the
|
||||
@@ -308,7 +307,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
Args:
|
||||
error (bool): Whether to mark this trial as terminated in error.
|
||||
error_msg (str): Optional error message.
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
|
||||
trial.set_location(Location())
|
||||
@@ -329,8 +327,6 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
finally:
|
||||
trial.set_runner(None)
|
||||
if stop_logger:
|
||||
trial.close_logger()
|
||||
|
||||
def start_trial(self, trial, checkpoint=None, train=True):
|
||||
"""Starts the trial.
|
||||
@@ -365,11 +361,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
out = [rid for rid, t in dictionary.items() if t is item]
|
||||
return out
|
||||
|
||||
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
|
||||
def stop_trial(self, trial, error=False, error_msg=None):
|
||||
"""Only returns resources if resources allocated."""
|
||||
prior_status = trial.status
|
||||
self._stop_trial(
|
||||
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
|
||||
self._stop_trial(trial, error=error, error_msg=error_msg)
|
||||
if prior_status == Trial.RUNNING:
|
||||
logger.debug("Trial %s: Returning resources.", trial)
|
||||
self._return_resources(trial.resources)
|
||||
|
||||
@@ -574,7 +574,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
trial_executor.restore(
|
||||
trial, new_state.last_checkpoint, block=True)
|
||||
else:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial_executor.stop_trial(trial)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial_executor.start_trial(
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
|
||||
EarlyStopping)
|
||||
EarlyStopping, run)
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
|
||||
AsyncHyperBandScheduler)
|
||||
@@ -90,19 +90,19 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
class_trainable_name = "class_trainable"
|
||||
register_trainable(class_trainable_name, _WrappedTrainable)
|
||||
|
||||
trials = run_experiments(
|
||||
{
|
||||
"function_api": {
|
||||
"run": _function_trainable,
|
||||
"loggers": [FunctionAPILogger],
|
||||
},
|
||||
"class_api": {
|
||||
"run": class_trainable_name,
|
||||
"loggers": [ClassAPILogger],
|
||||
},
|
||||
},
|
||||
[trial1] = run(
|
||||
_function_trainable,
|
||||
loggers=[FunctionAPILogger],
|
||||
raise_on_failed_trial=False,
|
||||
scheduler=MockScheduler())
|
||||
scheduler=MockScheduler()).trials
|
||||
|
||||
[trial2] = run(
|
||||
class_trainable_name,
|
||||
loggers=[ClassAPILogger],
|
||||
raise_on_failed_trial=False,
|
||||
scheduler=MockScheduler()).trials
|
||||
|
||||
trials = [trial1, trial2]
|
||||
|
||||
# Ignore these fields
|
||||
NO_COMPARE_FIELDS = {
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib import _register_all
|
||||
from ray.tune.result import TIMESTEPS_TOTAL
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_trainable, run_experiments
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.logger import LegacyExperimentLogger, Logger
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, ExportFormat
|
||||
|
||||
@@ -173,21 +173,25 @@ class RunExperimentTest(unittest.TestCase):
|
||||
for trial in trials:
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
|
||||
def testCustomLogger(self):
|
||||
def testCustomLoggerNoAutoLogging(self):
|
||||
"""Does not create CSV/JSON logger callbacks automatically"""
|
||||
os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"
|
||||
|
||||
class CustomLogger(Logger):
|
||||
def on_result(self, result):
|
||||
with open(os.path.join(self.logdir, "test.log"), "w") as f:
|
||||
f.write("hi")
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"loggers": [CustomLogger]
|
||||
}
|
||||
})
|
||||
[trial] = run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
@@ -203,16 +207,65 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
[trial] = run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[])])
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
def testCustomLoggerWithAutoLogging(self):
|
||||
"""Creates CSV/JSON logger callbacks automatically"""
|
||||
if "TUNE_DISABLE_AUTO_CALLBACK_LOGGERS" in os.environ:
|
||||
del os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"]
|
||||
|
||||
class CustomLogger(Logger):
|
||||
def on_result(self, result):
|
||||
with open(os.path.join(self.logdir, "test.log"), "w") as f:
|
||||
f.write("hi")
|
||||
|
||||
[trial] = run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"loggers": []
|
||||
}
|
||||
}
|
||||
})
|
||||
self.assertFalse(
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
[trial] = run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "__fake",
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[])])
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
def testCustomTrialString(self):
|
||||
|
||||
@@ -263,7 +263,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
if result["training_iteration"] == 1:
|
||||
executor = trial_runner.trial_executor
|
||||
executor.stop_trial(trial, stop_logger=False)
|
||||
executor.stop_trial(trial)
|
||||
trial.update_resources(2, 0)
|
||||
executor.start_trial(trial)
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
@@ -9,12 +9,16 @@ import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.checkpoint_manager import Checkpoint
|
||||
from ray.tune.logger import DEFAULT_LOGGERS, ExperimentLogger, \
|
||||
LegacyExperimentLogger
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.syncer import SyncConfig, SyncerCallback
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune import Callback
|
||||
from ray.tune.utils.callback import create_default_callbacks
|
||||
|
||||
|
||||
class TestCallback(Callback):
|
||||
@@ -200,6 +204,63 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||
self.callback.state["trial_complete"]["trial"].config["do"],
|
||||
"delay")
|
||||
|
||||
def testCallbackReordering(self):
    """SyncerCallback should come after ExperimentLogger callbacks"""

    def get_positions(callbacks):
        """Return (first logger index, last logger index, syncer index).

        An entry is ``None`` when no callback of that kind is in the list.
        """
        first_logger_pos = None
        last_logger_pos = None
        syncer_pos = None
        for i, callback in enumerate(callbacks):
            if isinstance(callback, ExperimentLogger):
                if first_logger_pos is None:
                    first_logger_pos = i
                last_logger_pos = i
            elif isinstance(callback, SyncerCallback):
                syncer_pos = i
        return first_logger_pos, last_logger_pos, syncer_pos

    def assert_syncer_after_loggers(callbacks):
        """Check that the syncer follows the last logger; return its index."""
        _, last_logger_pos, syncer_pos = get_positions(callbacks)
        # Fail with a readable message instead of a `None < int` TypeError
        # if either kind of callback was not created at all.
        self.assertIsNotNone(last_logger_pos)
        self.assertIsNotNone(syncer_pos)
        self.assertLess(last_logger_pos, syncer_pos)
        return syncer_pos

    # Auto creation of loggers, no callbacks, no syncer
    assert_syncer_after_loggers(
        create_default_callbacks(None, SyncConfig(), None))

    # Auto creation of loggers with callbacks
    assert_syncer_after_loggers(
        create_default_callbacks([Callback()], SyncConfig(), None))

    # Auto creation of loggers with existing logger (but no CSV/JSON)
    assert_syncer_after_loggers(
        create_default_callbacks([ExperimentLogger()], SyncConfig(), None))

    # This should throw an error as the syncer comes before the logger
    with self.assertRaises(ValueError):
        create_default_callbacks(
            [SyncerCallback(None),
             ExperimentLogger()], SyncConfig(), None)

    # This should be reordered but preserve the regular callback order
    [mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
    # Has to be legacy logger to avoid logger callback creation
    lc = LegacyExperimentLogger(logger_classes=DEFAULT_LOGGERS)
    callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(),
                                         None)
    syncer_pos = assert_syncer_after_loggers(callbacks)
    self.assertLess(callbacks.index(mc1), callbacks.index(mc2))
    self.assertLess(callbacks.index(mc2), callbacks.index(mc3))
    self.assertLess(callbacks.index(lc), callbacks.index(mc3))
    # Syncer callback is appended
    self.assertLess(callbacks.index(mc3), syncer_pos)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -217,10 +217,8 @@ class _MockTrialExecutor(TrialExecutor):
|
||||
trial.restored_checkpoint = checkpoint_obj.value
|
||||
trial.status = Trial.RUNNING
|
||||
|
||||
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
|
||||
def stop_trial(self, trial, error=False, error_msg=None):
|
||||
trial.status = Trial.ERROR if error else Trial.TERMINATED
|
||||
if stop_logger:
|
||||
trial.logger_running = False
|
||||
|
||||
def restore(self, trial, checkpoint=None, block=False):
|
||||
pass
|
||||
|
||||
+11
-34
@@ -12,10 +12,10 @@ import os
|
||||
from numbers import Number
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
|
||||
# need because there are cyclic imports that may cause specific names to not
|
||||
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune.registry import get_trainable_cls, validate_trainable
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
|
||||
from ray.tune.resources import Resources, json_to_resources, resources_to_json
|
||||
@@ -189,7 +189,6 @@ class Trial:
|
||||
restore_path=None,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
@@ -222,7 +221,6 @@ class Trial:
|
||||
self.location = Location()
|
||||
self.resources = resources or Resources(cpu=1, gpu=0)
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.loggers = loggers
|
||||
|
||||
self.log_to_file = log_to_file
|
||||
# Make sure `stdout_file, stderr_file = Trial.log_to_file` works
|
||||
@@ -250,7 +248,6 @@ class Trial:
|
||||
self.start_time = None
|
||||
self.logdir = None
|
||||
self.runner = None
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.error_file = None
|
||||
self.error_msg = None
|
||||
@@ -285,7 +282,6 @@ class Trial:
|
||||
self.extra_arg = None
|
||||
|
||||
self._nonjson_fields = [
|
||||
"loggers",
|
||||
"results",
|
||||
"best_result",
|
||||
"param_config",
|
||||
@@ -350,22 +346,17 @@ class Trial:
|
||||
export_formats=self.export_formats,
|
||||
restore_path=self.restore_path,
|
||||
trial_name_creator=self.trial_name_creator,
|
||||
loggers=self.loggers,
|
||||
log_to_file=self.log_to_file,
|
||||
max_failures=self.max_failures,
|
||||
)
|
||||
|
||||
def init_logger(self):
|
||||
"""Init logger."""
|
||||
if not self.result_logger:
|
||||
if not self.logdir:
|
||||
self.logdir = create_logdir(self._generate_dirname(),
|
||||
self.local_dir)
|
||||
else:
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
|
||||
self.result_logger = UnifiedLogger(
|
||||
self.config, self.logdir, trial=self, loggers=self.loggers)
|
||||
def init_logdir(self):
    """Make sure this trial has a log directory on disk.

    If ``self.logdir`` is already set, that directory is (re-)created when
    missing. Otherwise a directory is created via ``create_logdir`` from
    the generated dirname and ``self.local_dir``, and the resulting path is
    stored in ``self.logdir``.
    """
    if self.logdir:
        # Path is already decided; just guarantee it exists.
        os.makedirs(self.logdir, exist_ok=True)
        return

    self.logdir = create_logdir(self._generate_dirname(), self.local_dir)
|
||||
|
||||
def update_resources(self, cpu, gpu, **kwargs):
|
||||
"""EXPERIMENTAL: Updates the resource requirements.
|
||||
@@ -395,12 +386,6 @@ class Trial:
|
||||
if self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
|
||||
def close_logger(self):
|
||||
"""Closes logger."""
|
||||
if self.result_logger:
|
||||
self.result_logger.close()
|
||||
self.result_logger = None
|
||||
|
||||
def write_error_log(self, error_msg):
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1
|
||||
@@ -479,7 +464,6 @@ class Trial:
|
||||
self.set_location(Location(result.get("node_ip"), result.get("pid")))
|
||||
self.last_result = result
|
||||
self.last_update_time = time.time()
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
for metric, value in flatten_dict(result).items():
|
||||
if isinstance(value, Number):
|
||||
@@ -569,7 +553,7 @@ class Trial:
|
||||
def __getstate__(self):
|
||||
"""Memento generator for Trial.
|
||||
|
||||
Sets RUNNING trials to PENDING, and flushes the result logger.
|
||||
Sets RUNNING trials to PENDING.
|
||||
Note this can only occur if the trial holds a PERSISTENT checkpoint.
|
||||
"""
|
||||
assert self.checkpoint.storage == Checkpoint.PERSISTENT, (
|
||||
@@ -582,19 +566,13 @@ class Trial:
|
||||
|
||||
state["runner"] = None
|
||||
state["location"] = Location()
|
||||
state["result_logger"] = None
|
||||
# Avoid waiting for events that will never occur on resume.
|
||||
state["resuming_from"] = None
|
||||
state["saving_to"] = None
|
||||
if self.result_logger:
|
||||
self.result_logger.flush()
|
||||
state["__logger_started__"] = True
|
||||
else:
|
||||
state["__logger_started__"] = False
|
||||
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def __setstate__(self, state):
|
||||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
|
||||
if state["status"] == Trial.RUNNING:
|
||||
@@ -604,5 +582,4 @@ class Trial:
|
||||
|
||||
self.__dict__.update(state)
|
||||
validate_trainable(self.trainable_name)
|
||||
if logger_started:
|
||||
self.init_logger()
|
||||
self.init_logdir() # Create logdir if it does not exist
|
||||
|
||||
@@ -85,7 +85,7 @@ class TrialExecutor:
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"start_trial() method")
|
||||
|
||||
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
|
||||
def stop_trial(self, trial, error=False, error_msg=None):
|
||||
"""Stops the trial.
|
||||
|
||||
Stops this trial, releasing all allocating resources.
|
||||
@@ -95,7 +95,6 @@ class TrialExecutor:
|
||||
Args:
|
||||
error (bool): Whether to mark this trial as terminated in error.
|
||||
error_msg (str): Optional error message.
|
||||
stop_logger (bool): Whether to shut down the trial logger.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"stop_trial() method")
|
||||
@@ -113,7 +112,7 @@ class TrialExecutor:
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
try:
|
||||
self.save(trial, Checkpoint.MEMORY)
|
||||
self.stop_trial(trial, stop_logger=False)
|
||||
self.stop_trial(trial)
|
||||
self.set_status(trial, Trial.PAUSED)
|
||||
except Exception:
|
||||
logger.exception("Error pausing runner.")
|
||||
|
||||
@@ -793,11 +793,7 @@ class TrialRunner:
|
||||
# Restore was unsuccessful, try again without checkpoint.
|
||||
trial.clear_checkpoint()
|
||||
self.trial_executor.stop_trial(
|
||||
trial,
|
||||
error=error_msg is not None,
|
||||
error_msg=error_msg,
|
||||
stop_logger=False)
|
||||
trial.result_logger.flush()
|
||||
trial, error=error_msg is not None, error_msg=error_msg)
|
||||
if self.trial_executor.has_resources(trial.resources):
|
||||
logger.info(
|
||||
"Trial %s: Attempting to restore "
|
||||
|
||||
@@ -311,7 +311,6 @@ def run(
|
||||
sync_to_driver=sync_config.sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
trial_dirname_creator=trial_dirname_creator,
|
||||
loggers=loggers,
|
||||
log_to_file=log_to_file,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
@@ -355,8 +354,9 @@ def run(
|
||||
"own `metric` and `mode` parameters. Either remove the arguments "
|
||||
"from your scheduler or from your call to `tune.run()`")
|
||||
|
||||
# Create syncer callbacks
|
||||
callbacks = create_default_callbacks(callbacks, sync_config)
|
||||
# Create logger and syncer callbacks
|
||||
callbacks = create_default_callbacks(
|
||||
callbacks, sync_config, loggers=loggers)
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg=search_alg,
|
||||
|
||||
@@ -3,16 +3,73 @@ from typing import List, Optional
|
||||
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.syncer import SyncConfig
|
||||
from ray.tune.logger import CSVLogger, DEFAULT_LOGGERS, ExperimentLogger, \
|
||||
JsonLogger, LegacyExperimentLogger, Logger
|
||||
from ray.tune.syncer import SyncerCallback
|
||||
|
||||
|
||||
def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
sync_config: SyncConfig):
|
||||
sync_config: SyncConfig,
|
||||
loggers: Optional[List[Logger]]):
|
||||
|
||||
callbacks = callbacks or []
|
||||
has_syncer_callback = False
|
||||
has_csv_logger = False
|
||||
has_json_logger = False
|
||||
|
||||
# Check if there is a SyncerCallback
|
||||
has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks)
|
||||
# Track syncer obj/index to move callback after loggers
|
||||
last_logger_index = None
|
||||
syncer_index = None
|
||||
|
||||
if not loggers:
|
||||
# If no logger callback and no `loggers` have been provided,
|
||||
# add DEFAULT_LOGGERS.
|
||||
if not any(
|
||||
isinstance(callback, ExperimentLogger)
|
||||
for callback in callbacks):
|
||||
loggers = DEFAULT_LOGGERS
|
||||
|
||||
# Create LegacyExperimentLogger for passed Logger classes
|
||||
if loggers:
|
||||
# Todo(krfricke): Deprecate `loggers` argument, print warning here.
|
||||
add_loggers = []
|
||||
for trial_logger in loggers:
|
||||
if isinstance(trial_logger, ExperimentLogger):
|
||||
callbacks.append(trial_logger)
|
||||
elif isinstance(trial_logger, type) and issubclass(
|
||||
trial_logger, Logger):
|
||||
add_loggers.append(trial_logger)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid value passed to `loggers` argument of "
|
||||
f"`tune.run()`: {trial_logger}")
|
||||
if add_loggers:
|
||||
callbacks.append(LegacyExperimentLogger(add_loggers))
|
||||
|
||||
# Check if we have a CSV and JSON logger
|
||||
for i, callback in enumerate(callbacks):
|
||||
if isinstance(callback, LegacyExperimentLogger):
|
||||
last_logger_index = i
|
||||
if CSVLogger in callback.logger_classes:
|
||||
has_csv_logger = True
|
||||
if JsonLogger in callback.logger_classes:
|
||||
has_json_logger = True
|
||||
# Todo(krfricke): add checks for new ExperimentLogger classes
|
||||
elif isinstance(callback, SyncerCallback):
|
||||
syncer_index = i
|
||||
has_syncer_callback = True
|
||||
|
||||
# If CSV or JSON logger is missing, add
|
||||
if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
|
||||
# Todo(krfricke): Switch to new ExperimentLogger classes
|
||||
add_loggers = []
|
||||
if not has_csv_logger:
|
||||
add_loggers.append(CSVLogger)
|
||||
if not has_json_logger:
|
||||
add_loggers.append(JsonLogger)
|
||||
if add_loggers:
|
||||
callbacks.append(LegacyExperimentLogger(add_loggers))
|
||||
last_logger_index = len(callbacks) - 1
|
||||
|
||||
# If no SyncerCallback was found, add
|
||||
if not has_syncer_callback and os.environ.get(
|
||||
@@ -20,5 +77,25 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
syncer_callback = SyncerCallback(
|
||||
sync_function=sync_config.sync_to_driver)
|
||||
callbacks.append(syncer_callback)
|
||||
syncer_index = len(callbacks) - 1
|
||||
|
||||
# Todo(krfricke): Maybe check if syncer comes after all loggers
|
||||
if syncer_index is not None and last_logger_index is not None and \
|
||||
syncer_index < last_logger_index:
|
||||
if (not has_csv_logger or not has_json_logger) and not loggers:
|
||||
# Only raise the warning if the loggers were passed by the user.
|
||||
# (I.e. don't warn if this was automatic behavior and they only
|
||||
# passed a custom SyncerCallback).
|
||||
raise ValueError(
|
||||
"The `SyncerCallback` you passed to `tune.run()` came before "
|
||||
"at least one `ExperimentLogger`. Syncing should be done "
|
||||
"after writing logs. Please re-order the callbacks so that "
|
||||
"the `SyncerCallback` comes after any `ExperimentLogger`.")
|
||||
else:
|
||||
# If these loggers were automatically created, just re-order
|
||||
# the callbacks
|
||||
syncer_obj = callbacks[syncer_index]
|
||||
callbacks.pop(syncer_index)
|
||||
callbacks.insert(last_logger_index, syncer_obj)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
@@ -540,6 +541,30 @@ def detect_config_single(func):
|
||||
return use_config_single
|
||||
|
||||
|
||||
def create_logdir(dirname: str, local_dir: str):
    """Create a log directory named `dirname` below `local_dir`.

    `local_dir` is user-expanded (`~`) first. If `local_dir`/`dirname`
    already exists, an underscore plus the first four hex characters of a
    random UUID are appended to the dirname and the new name is logged.

    Args:
        dirname (str): Dirname to create in `local_dir`
        local_dir (str): Root directory for the log dir

    Returns: full path to the log directory that was created.
    """
    root = os.path.expanduser(local_dir)

    if os.path.exists(os.path.join(root, dirname)):
        # Name is taken: derive a sibling name with a short random suffix.
        old_dirname = dirname
        dirname = f"{old_dirname}_{uuid.uuid4().hex[:4]}"
        logger.info(f"Creating a new dirname {dirname} because "
                    f"trial dirname '{old_dirname}' already exists.")

    logdir = os.path.join(root, dirname)
    os.makedirs(logdir, exist_ok=True)
    return logdir
|
||||
|
||||
|
||||
class SafeFallbackEncoder(json.JSONEncoder):
|
||||
def __init__(self, nan_str="null", **kwargs):
|
||||
super(SafeFallbackEncoder, self).__init__(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user