[tune] logger refactor part 3: Add ExperimentLogger class (#11749)

This commit is contained in:
Kai Fricke
2020-11-05 17:55:38 +01:00
committed by GitHub
parent f6717b8b03
commit 603accf1c2
17 changed files with 389 additions and 95 deletions
-1
View File
@@ -190,7 +190,6 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
restore_path=spec.get("restore"),
trial_name_creator=spec.get("trial_name_creator"),
trial_dirname_creator=spec.get("trial_dirname_creator"),
loggers=spec.get("loggers"),
log_to_file=spec.get("log_to_file"),
# str(None) doesn't create None
max_failures=args.max_failures,
+11
View File
@@ -123,6 +123,17 @@ class Experiment:
max_failures=0,
restore=None):
if loggers is not None:
# Most users won't run into this as `tune.run()` does not pass
# the argument anymore. However, we will want to inform users
# if they instantiate their `Experiment` objects themselves.
raise ValueError(
"Passing `loggers` to an `Experiment` is deprecated. Use "
"an `ExperimentLogger` callback instead, e.g. by passing the "
"`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
"passing this as part of the `callback` parameter to "
"`tune.run()`.")
config = config or {}
if callable(run) and detect_checkpoint_function(run):
if checkpoint_at_end:
+97 -3
View File
@@ -5,10 +5,11 @@ import numpy as np
import os
import yaml
from typing import TYPE_CHECKING, Dict, List, Optional, Type
from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, Type
import ray.cloudpickle as cloudpickle
from ray.tune.callback import Callback
from ray.tune.utils.util import SafeFallbackEncoder
from ray.util.debug import log_once
from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
@@ -129,7 +130,8 @@ class JsonLogger(Logger):
self.local_out.write(b)
def flush(self):
self.local_out.flush()
if not self.local_out.closed:
self.local_out.flush()
def close(self):
self.local_out.close()
@@ -181,7 +183,8 @@ class CSVLogger(Logger):
self._file.flush()
def flush(self):
self._file.flush()
if not self._file.closed:
self._file.flush()
def close(self):
self._file.close()
@@ -361,6 +364,97 @@ class UnifiedLogger(Logger):
_logger.flush()
class ExperimentLogger(Callback):
def log_trial_start(self, trial: "Trial"):
raise NotImplementedError
def log_trial_restore(self, trial: "Trial"):
raise NotImplementedError
def log_trial_save(self, trial: "Trial"):
raise NotImplementedError
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
raise NotImplementedError
def log_trial_end(self, trial: "Trial", failed: bool = False):
raise NotImplementedError
def on_trial_result(self, iteration: int, trials: List["Trial"],
trial: "Trial", result: Dict, **info):
self.log_trial_result(iteration, trial, result)
def on_trial_start(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
self.log_trial_start(trial)
def on_trial_restore(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
self.log_trial_restore(trial)
def on_trial_save(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
self.log_trial_save(trial)
def on_trial_complete(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
self.log_trial_end(trial, failed=False)
def on_trial_error(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
self.log_trial_end(trial, failed=True)
class LegacyExperimentLogger(ExperimentLogger):
"""Supports logging to trial-specific `Logger` classes.
Previously, Ray Tune logging was handled via `Logger` classes that have
been instantiated per-trial. This callback is a fallback to these
`Logger`-classes, instantiating each `Logger` class for each trial
and logging to them.
Args:
logger_classes (Iterable[Type[Logger]]): Logger classes that should
be instantiated for each trial.
"""
def __init__(self, logger_classes: Iterable[Type[Logger]]):
self.logger_classes = list(logger_classes)
self._class_trial_loggers: Dict[Type[Logger], Dict["Trial",
Logger]] = {}
def log_trial_start(self, trial: "Trial"):
trial.init_logdir()
for logger_class in self.logger_classes:
trial_loggers = self._class_trial_loggers.get(logger_class, {})
if trial not in trial_loggers:
logger = logger_class(trial.config, trial.logdir, trial)
trial_loggers[trial] = logger
self._class_trial_loggers[logger_class] = trial_loggers
def log_trial_restore(self, trial: "Trial"):
for logger_class, trial_loggers in self._class_trial_loggers.items():
if trial in trial_loggers:
trial_loggers[trial].flush()
def log_trial_save(self, trial: "Trial"):
for logger_class, trial_loggers in self._class_trial_loggers.items():
if trial in trial_loggers:
trial_loggers[trial].flush()
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
for logger_class, trial_loggers in self._class_trial_loggers.items():
if trial in trial_loggers:
trial_loggers[trial].on_result(result)
def log_trial_end(self, trial: "Trial", failed: bool = False):
for logger_class, trial_loggers in self._class_trial_loggers.items():
if trial in trial_loggers:
trial_loggers[trial].close()
def pretty_print(result):
result = result.copy()
result.update(config=None) # drop config from pretty print
+4 -9
View File
@@ -177,7 +177,7 @@ class RayTrialExecutor(TrialExecutor):
self._update_avail_resources()
def _setup_remote_runner(self, trial, reuse_allowed):
trial.init_logger()
trial.init_logdir()
# We checkpoint metadata here to try mitigating logdir duplication
self.try_checkpoint_metadata(trial)
logger_creator = partial(noop_logger_creator, logdir=trial.logdir)
@@ -297,8 +297,7 @@ class RayTrialExecutor(TrialExecutor):
elif train and not trial.is_restoring:
self._train(trial)
def _stop_trial(self, trial, error=False, error_msg=None,
stop_logger=True):
def _stop_trial(self, trial, error=False, error_msg=None):
"""Stops this trial.
Stops this trial, releasing all allocating resources. If stopping the
@@ -308,7 +307,6 @@ class RayTrialExecutor(TrialExecutor):
Args:
error (bool): Whether to mark this trial as terminated in error.
error_msg (str): Optional error message.
stop_logger (bool): Whether to shut down the trial logger.
"""
self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
trial.set_location(Location())
@@ -329,8 +327,6 @@ class RayTrialExecutor(TrialExecutor):
self.set_status(trial, Trial.ERROR)
finally:
trial.set_runner(None)
if stop_logger:
trial.close_logger()
def start_trial(self, trial, checkpoint=None, train=True):
"""Starts the trial.
@@ -365,11 +361,10 @@ class RayTrialExecutor(TrialExecutor):
out = [rid for rid, t in dictionary.items() if t is item]
return out
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
def stop_trial(self, trial, error=False, error_msg=None):
"""Only returns resources if resources allocated."""
prior_status = trial.status
self._stop_trial(
trial, error=error, error_msg=error_msg, stop_logger=stop_logger)
self._stop_trial(trial, error=error, error_msg=error_msg)
if prior_status == Trial.RUNNING:
logger.debug("Trial %s: Returning resources.", trial)
self._return_resources(trial.resources)
+1 -1
View File
@@ -574,7 +574,7 @@ class PopulationBasedTraining(FIFOScheduler):
trial_executor.restore(
trial, new_state.last_checkpoint, block=True)
else:
trial_executor.stop_trial(trial, stop_logger=False)
trial_executor.stop_trial(trial)
trial.config = new_config
trial.experiment_tag = new_tag
trial_executor.start_trial(
+13 -13
View File
@@ -12,7 +12,7 @@ from ray.rllib import _register_all
from ray import tune
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
EarlyStopping)
EarlyStopping, run)
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
@@ -90,19 +90,19 @@ class TrainableFunctionApiTest(unittest.TestCase):
class_trainable_name = "class_trainable"
register_trainable(class_trainable_name, _WrappedTrainable)
trials = run_experiments(
{
"function_api": {
"run": _function_trainable,
"loggers": [FunctionAPILogger],
},
"class_api": {
"run": class_trainable_name,
"loggers": [ClassAPILogger],
},
},
[trial1] = run(
_function_trainable,
loggers=[FunctionAPILogger],
raise_on_failed_trial=False,
scheduler=MockScheduler())
scheduler=MockScheduler()).trials
[trial2] = run(
class_trainable_name,
loggers=[ClassAPILogger],
raise_on_failed_trial=False,
scheduler=MockScheduler()).trials
trials = [trial1, trial2]
# Ignore these fields
NO_COMPARE_FIELDS = {
+67 -14
View File
@@ -7,7 +7,7 @@ from ray.rllib import _register_all
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune import Trainable, TuneError
from ray.tune import register_trainable, run_experiments
from ray.tune.logger import Logger
from ray.tune.logger import LegacyExperimentLogger, Logger
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, ExportFormat
@@ -173,21 +173,25 @@ class RunExperimentTest(unittest.TestCase):
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
def testCustomLogger(self):
def testCustomLoggerNoAutoLogging(self):
"""Does not create CSV/JSON logger callbacks automatically"""
os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"
class CustomLogger(Logger):
def on_result(self, result):
with open(os.path.join(self.logdir, "test.log"), "w") as f:
f.write("hi")
[trial] = run_experiments({
"foo": {
"run": "__fake",
"stop": {
"training_iteration": 1
},
"loggers": [CustomLogger]
}
})
[trial] = run_experiments(
{
"foo": {
"run": "__fake",
"stop": {
"training_iteration": 1
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
self.assertFalse(
os.path.exists(os.path.join(trial.logdir, "params.json")))
@@ -203,16 +207,65 @@ class RunExperimentTest(unittest.TestCase):
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "params.json")))
[trial] = run_experiments(
{
"foo": {
"run": "__fake",
"stop": {
"training_iteration": 1
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[])])
self.assertFalse(
os.path.exists(os.path.join(trial.logdir, "params.json")))
def testCustomLoggerWithAutoLogging(self):
"""Creates CSV/JSON logger callbacks automatically"""
if "TUNE_DISABLE_AUTO_CALLBACK_LOGGERS" in os.environ:
del os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"]
class CustomLogger(Logger):
def on_result(self, result):
with open(os.path.join(self.logdir, "test.log"), "w") as f:
f.write("hi")
[trial] = run_experiments(
{
"foo": {
"run": "__fake",
"stop": {
"training_iteration": 1
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "params.json")))
[trial] = run_experiments({
"foo": {
"run": "__fake",
"stop": {
"training_iteration": 1
},
"loggers": []
}
}
})
self.assertFalse(
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "params.json")))
[trial] = run_experiments(
{
"foo": {
"run": "__fake",
"stop": {
"training_iteration": 1
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[])])
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "params.json")))
def testCustomTrialString(self):
+1 -1
View File
@@ -263,7 +263,7 @@ class TrialRunnerTest(unittest.TestCase):
def on_trial_result(self, trial_runner, trial, result):
if result["training_iteration"] == 1:
executor = trial_runner.trial_executor
executor.stop_trial(trial, stop_logger=False)
executor.stop_trial(trial)
trial.update_resources(2, 0)
executor.start_trial(trial)
return TrialScheduler.CONTINUE
@@ -9,12 +9,16 @@ import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.logger import DEFAULT_LOGGERS, ExperimentLogger, \
LegacyExperimentLogger
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TRAINING_ITERATION
from ray.tune.syncer import SyncConfig, SyncerCallback
from ray.tune.trial import Trial
from ray.tune.callback import Callback
from ray.tune.trial_runner import TrialRunner
from ray.tune import Callback
from ray.tune.utils.callback import create_default_callbacks
class TestCallback(Callback):
@@ -200,6 +204,63 @@ class TrialRunnerCallbacks(unittest.TestCase):
self.callback.state["trial_complete"]["trial"].config["do"],
"delay")
def testCallbackReordering(self):
"""SyncerCallback should come after ExperimentLogger callbacks"""
def get_positions(callbacks):
first_logger_pos = None
last_logger_pos = None
syncer_pos = None
for i, callback in enumerate(callbacks):
if isinstance(callback, ExperimentLogger):
if first_logger_pos is None:
first_logger_pos = i
last_logger_pos = i
elif isinstance(callback, SyncerCallback):
syncer_pos = i
return first_logger_pos, last_logger_pos, syncer_pos
# Auto creation of loggers, no callbacks, no syncer
callbacks = create_default_callbacks(None, SyncConfig(), None)
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
callbacks)
self.assertLess(last_logger_pos, syncer_pos)
# Auto creation of loggers with callbacks
callbacks = create_default_callbacks([Callback()], SyncConfig(), None)
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
callbacks)
self.assertLess(last_logger_pos, syncer_pos)
# Auto creation of loggers with existing logger (but no CSV/JSON)
callbacks = create_default_callbacks([ExperimentLogger()],
SyncConfig(), None)
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
callbacks)
self.assertLess(last_logger_pos, syncer_pos)
# This should throw an error as the syncer comes before the logger
with self.assertRaises(ValueError):
callbacks = create_default_callbacks(
[SyncerCallback(None),
ExperimentLogger()], SyncConfig(), None)
# This should be reordered but preserve the regular callback order
[mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
# Has to be legacy logger to avoid logger callback creation
lc = LegacyExperimentLogger(logger_classes=DEFAULT_LOGGERS)
callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(),
None)
print(callbacks)
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
callbacks)
self.assertLess(last_logger_pos, syncer_pos)
self.assertLess(callbacks.index(mc1), callbacks.index(mc2))
self.assertLess(callbacks.index(mc2), callbacks.index(mc3))
self.assertLess(callbacks.index(lc), callbacks.index(mc3))
# Syncer callback is appended
self.assertLess(callbacks.index(mc3), syncer_pos)
if __name__ == "__main__":
import pytest
@@ -217,10 +217,8 @@ class _MockTrialExecutor(TrialExecutor):
trial.restored_checkpoint = checkpoint_obj.value
trial.status = Trial.RUNNING
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
def stop_trial(self, trial, error=False, error_msg=None):
trial.status = Trial.ERROR if error else Trial.TERMINATED
if stop_logger:
trial.logger_running = False
def restore(self, trial, checkpoint=None, block=False):
pass
+11 -34
View File
@@ -12,10 +12,10 @@ import os
from numbers import Number
from ray.tune import TuneError
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
from ray.tune.logger import pretty_print, UnifiedLogger
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
# need because there are cyclic imports that may cause specific names to not
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls, validate_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
from ray.tune.resources import Resources, json_to_resources, resources_to_json
@@ -189,7 +189,6 @@ class Trial:
restore_path=None,
trial_name_creator=None,
trial_dirname_creator=None,
loggers=None,
log_to_file=None,
max_failures=0):
"""Initialize a new trial.
@@ -222,7 +221,6 @@ class Trial:
self.location = Location()
self.resources = resources or Resources(cpu=1, gpu=0)
self.stopping_criterion = stopping_criterion or {}
self.loggers = loggers
self.log_to_file = log_to_file
# Make sure `stdout_file, stderr_file = Trial.log_to_file` works
@@ -250,7 +248,6 @@ class Trial:
self.start_time = None
self.logdir = None
self.runner = None
self.result_logger = None
self.last_debug = 0
self.error_file = None
self.error_msg = None
@@ -285,7 +282,6 @@ class Trial:
self.extra_arg = None
self._nonjson_fields = [
"loggers",
"results",
"best_result",
"param_config",
@@ -350,22 +346,17 @@ class Trial:
export_formats=self.export_formats,
restore_path=self.restore_path,
trial_name_creator=self.trial_name_creator,
loggers=self.loggers,
log_to_file=self.log_to_file,
max_failures=self.max_failures,
)
def init_logger(self):
"""Init logger."""
if not self.result_logger:
if not self.logdir:
self.logdir = create_logdir(self._generate_dirname(),
self.local_dir)
else:
os.makedirs(self.logdir, exist_ok=True)
self.result_logger = UnifiedLogger(
self.config, self.logdir, trial=self, loggers=self.loggers)
def init_logdir(self):
"""Init logdir."""
if not self.logdir:
self.logdir = create_logdir(self._generate_dirname(),
self.local_dir)
else:
os.makedirs(self.logdir, exist_ok=True)
def update_resources(self, cpu, gpu, **kwargs):
"""EXPERIMENTAL: Updates the resource requirements.
@@ -395,12 +386,6 @@ class Trial:
if self.start_time is None:
self.start_time = time.time()
def close_logger(self):
"""Closes logger."""
if self.result_logger:
self.result_logger.close()
self.result_logger = None
def write_error_log(self, error_msg):
if error_msg and self.logdir:
self.num_failures += 1
@@ -479,7 +464,6 @@ class Trial:
self.set_location(Location(result.get("node_ip"), result.get("pid")))
self.last_result = result
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)
for metric, value in flatten_dict(result).items():
if isinstance(value, Number):
@@ -569,7 +553,7 @@ class Trial:
def __getstate__(self):
"""Memento generator for Trial.
Sets RUNNING trials to PENDING, and flushes the result logger.
Sets RUNNING trials to PENDING.
Note this can only occur if the trial holds a PERSISTENT checkpoint.
"""
assert self.checkpoint.storage == Checkpoint.PERSISTENT, (
@@ -582,19 +566,13 @@ class Trial:
state["runner"] = None
state["location"] = Location()
state["result_logger"] = None
# Avoid waiting for events that will never occur on resume.
state["resuming_from"] = None
state["saving_to"] = None
if self.result_logger:
self.result_logger.flush()
state["__logger_started__"] = True
else:
state["__logger_started__"] = False
return copy.deepcopy(state)
def __setstate__(self, state):
logger_started = state.pop("__logger_started__")
state["resources"] = json_to_resources(state["resources"])
if state["status"] == Trial.RUNNING:
@@ -604,5 +582,4 @@ class Trial:
self.__dict__.update(state)
validate_trainable(self.trainable_name)
if logger_started:
self.init_logger()
self.init_logdir() # Create logdir if it does not exist
+2 -3
View File
@@ -85,7 +85,7 @@ class TrialExecutor:
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"start_trial() method")
def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True):
def stop_trial(self, trial, error=False, error_msg=None):
"""Stops the trial.
Stops this trial, releasing all allocating resources.
@@ -95,7 +95,6 @@ class TrialExecutor:
Args:
error (bool): Whether to mark this trial as terminated in error.
error_msg (str): Optional error message.
stop_logger (bool): Whether to shut down the trial logger.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"stop_trial() method")
@@ -113,7 +112,7 @@ class TrialExecutor:
assert trial.status == Trial.RUNNING, trial.status
try:
self.save(trial, Checkpoint.MEMORY)
self.stop_trial(trial, stop_logger=False)
self.stop_trial(trial)
self.set_status(trial, Trial.PAUSED)
except Exception:
logger.exception("Error pausing runner.")
+1 -5
View File
@@ -793,11 +793,7 @@ class TrialRunner:
# Restore was unsuccessful, try again without checkpoint.
trial.clear_checkpoint()
self.trial_executor.stop_trial(
trial,
error=error_msg is not None,
error_msg=error_msg,
stop_logger=False)
trial.result_logger.flush()
trial, error=error_msg is not None, error_msg=error_msg)
if self.trial_executor.has_resources(trial.resources):
logger.info(
"Trial %s: Attempting to restore "
+3 -3
View File
@@ -311,7 +311,6 @@ def run(
sync_to_driver=sync_config.sync_to_driver,
trial_name_creator=trial_name_creator,
trial_dirname_creator=trial_dirname_creator,
loggers=loggers,
log_to_file=log_to_file,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
@@ -355,8 +354,9 @@ def run(
"own `metric` and `mode` parameters. Either remove the arguments "
"from your scheduler or from your call to `tune.run()`")
# Create syncer callbacks
callbacks = create_default_callbacks(callbacks, sync_config)
# Create logger and syncer callbacks
callbacks = create_default_callbacks(
callbacks, sync_config, loggers=loggers)
runner = TrialRunner(
search_alg=search_alg,
+80 -3
View File
@@ -3,16 +3,73 @@ from typing import List, Optional
from ray.tune.callback import Callback
from ray.tune.syncer import SyncConfig
from ray.tune.logger import CSVLogger, DEFAULT_LOGGERS, ExperimentLogger, \
JsonLogger, LegacyExperimentLogger, Logger
from ray.tune.syncer import SyncerCallback
def create_default_callbacks(callbacks: Optional[List[Callback]],
sync_config: SyncConfig):
sync_config: SyncConfig,
loggers: Optional[List[Logger]]):
callbacks = callbacks or []
has_syncer_callback = False
has_csv_logger = False
has_json_logger = False
# Check if there is a SyncerCallback
has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks)
# Track syncer obj/index to move callback after loggers
last_logger_index = None
syncer_index = None
if not loggers:
# If no logger callback and no `loggers` have been provided,
# add DEFAULT_LOGGERS.
if not any(
isinstance(callback, ExperimentLogger)
for callback in callbacks):
loggers = DEFAULT_LOGGERS
# Create LegacyExperimentLogger for passed Logger classes
if loggers:
# Todo(krfricke): Deprecate `loggers` argument, print warning here.
add_loggers = []
for trial_logger in loggers:
if isinstance(trial_logger, ExperimentLogger):
callbacks.append(trial_logger)
elif isinstance(trial_logger, type) and issubclass(
trial_logger, Logger):
add_loggers.append(trial_logger)
else:
raise ValueError(
f"Invalid value passed to `loggers` argument of "
f"`tune.run()`: {trial_logger}")
if add_loggers:
callbacks.append(LegacyExperimentLogger(add_loggers))
# Check if we have a CSV and JSON logger
for i, callback in enumerate(callbacks):
if isinstance(callback, LegacyExperimentLogger):
last_logger_index = i
if CSVLogger in callback.logger_classes:
has_csv_logger = True
if JsonLogger in callback.logger_classes:
has_json_logger = True
# Todo(krfricke): add checks for new ExperimentLogger classes
elif isinstance(callback, SyncerCallback):
syncer_index = i
has_syncer_callback = True
# If CSV or JSON logger is missing, add
if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
# Todo(krfricke): Switch to new ExperimentLogger classes
add_loggers = []
if not has_csv_logger:
add_loggers.append(CSVLogger)
if not has_json_logger:
add_loggers.append(JsonLogger)
if add_loggers:
callbacks.append(LegacyExperimentLogger(add_loggers))
last_logger_index = len(callbacks) - 1
# If no SyncerCallback was found, add
if not has_syncer_callback and os.environ.get(
@@ -20,5 +77,25 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
syncer_callback = SyncerCallback(
sync_function=sync_config.sync_to_driver)
callbacks.append(syncer_callback)
syncer_index = len(callbacks) - 1
# Todo(krfricke): Maybe check if syncer comes after all loggers
if syncer_index is not None and last_logger_index is not None and \
syncer_index < last_logger_index:
if (not has_csv_logger or not has_json_logger) and not loggers:
# Only raise the warning if the loggers were passed by the user.
# (I.e. don't warn if this was automatic behavior and they only
# passed a customer SyncerCallback).
raise ValueError(
"The `SyncerCallback` you passed to `tune.run()` came before "
"at least one `ExperimentLogger`. Syncing should be done "
"after writing logs. Please re-order the callbacks so that "
"the `SyncerCallback` comes after any `ExperimentLogger`.")
else:
# If these loggers were automatically created. just re-order
# the callbacks
syncer_obj = callbacks[syncer_index]
callbacks.pop(syncer_index)
callbacks.insert(last_logger_index, syncer_obj)
return callbacks
+25
View File
@@ -6,6 +6,7 @@ import os
import inspect
import threading
import time
import uuid
from collections import defaultdict, deque, Mapping, Sequence
from datetime import datetime
from threading import Thread
@@ -540,6 +541,30 @@ def detect_config_single(func):
return use_config_single
def create_logdir(dirname: str, local_dir: str):
"""Create an empty logdir with name `dirname` in `local_dir`.
If `local_dir`/`dirname` already exists, a unique string is appended
to the dirname.
Args:
dirname (str): Dirname to create in `local_dir`
local_dir (str): Root directory for the log dir
Returns: full path to the newly created logdir.
"""
local_dir = os.path.expanduser(local_dir)
logdir = os.path.join(local_dir, dirname)
if os.path.exists(logdir):
old_dirname = dirname
dirname += "_" + uuid.uuid4().hex[:4]
logger.info(f"Creating a new dirname {dirname} because "
f"trial dirname '{old_dirname}' already exists.")
logdir = os.path.join(local_dir, dirname)
os.makedirs(logdir, exist_ok=True)
return logdir
class SafeFallbackEncoder(json.JSONEncoder):
def __init__(self, nan_str="null", **kwargs):
super(SafeFallbackEncoder, self).__init__(**kwargs)