[tune] logger migration to ExperimentLogger classes (#11984)

This commit is contained in:
Kai Fricke
2020-11-17 00:08:37 +01:00
committed by GitHub
parent 3dc68533a9
commit 9f5986ee58
14 changed files with 897 additions and 223 deletions
+6 -5
View File
@@ -4,11 +4,12 @@ import argparse
import time
from ray import tune
from ray.tune.logger import LoggerCallback
class TestLogger(tune.logger.Logger):
def on_result(self, result):
print("TestLogger", result)
class TestLoggerCallback(LoggerCallback):
def on_trial_result(self, iteration, trials, trial, result, **info):
print(f"TestLogger for trial {trial}: {result}")
def trial_str_creator(trial):
@@ -44,8 +45,8 @@ if __name__ == "__main__":
mode="min",
num_samples=5,
trial_name_creator=trial_str_creator,
loggers=[TestLogger],
stop={"training_iteration": 1 if args.smoke_test else 99999},
callbacks=[TestLoggerCallback()],
stop={"training_iteration": 1 if args.smoke_test else 100},
config={
"steps": 100,
"width": tune.randint(10, 100),
+9 -10
View File
@@ -7,9 +7,9 @@ import wandb
from ray import tune
from ray.tune import Trainable
from ray.tune.integration.wandb import WandbLogger, WandbTrainableMixin, \
from ray.tune.integration.wandb import WandbLoggerCallback, \
WandbTrainableMixin, \
wandb_mixin
from ray.tune.logger import DEFAULT_LOGGERS
def train_function(config, checkpoint_dir=None):
@@ -19,20 +19,19 @@ def train_function(config, checkpoint_dir=None):
def tune_function(api_key_file):
"""Example for using a WandbLogger with the function API"""
"""Example for using a WandbLoggerCallback with the function API"""
analysis = tune.run(
train_function,
metric="loss",
mode="min",
config={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
"wandb": {
"api_key_file": api_key_file,
"project": "Wandb_example"
}
"sd": tune.uniform(0.2, 0.8)
},
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
callbacks=[
WandbLoggerCallback(
api_key_file=api_key_file, project="Wandb_example")
])
return analysis.best_config
@@ -95,7 +94,7 @@ if __name__ == "__main__":
api_key_file = "~/.wandb_api_key"
if args.mock_api:
WandbLogger._logger_process_cls = MagicMock
WandbLoggerCallback._logger_process_cls = MagicMock
decorated_train_function.__mixins__ = tuple()
WandbTrainable._wandb = MagicMock()
wandb = MagicMock() # noqa: F811
+2 -2
View File
@@ -129,8 +129,8 @@ class Experiment:
# if they instantiate their `Experiment` objects themselves.
raise ValueError(
"Passing `loggers` to an `Experiment` is deprecated. Use "
"an `ExperimentLogger` callback instead, e.g. by passing the "
"`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
"an `LoggerCallback` callback instead, e.g. by passing the "
"`Logger` classes to `tune.logger.LegacyLoggerCallback` and "
"passing this as part of the `callback` parameter to "
"`tune.run()`.")
+177 -73
View File
@@ -2,14 +2,15 @@ import os
import pickle
from multiprocessing import Process, Queue
from numbers import Number
from typing import Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from ray import logger
from ray.tune import Trainable
from ray.tune.function_runner import FunctionRunner
from ray.tune.logger import Logger
from ray.tune.logger import LoggerCallback, Logger
from ray.tune.utils import flatten_dict
from ray.tune.trial import Trial
import yaml
@@ -30,7 +31,7 @@ def _is_allowed_type(obj):
return isinstance(obj, Number)
def _clean_log(obj):
def _clean_log(obj: Any):
# Fixes https://github.com/ray-project/ray/issues/10631
if isinstance(obj, dict):
return {k: _clean_log(v) for k, v in obj.items()}
@@ -142,12 +143,10 @@ def wandb_mixin(func: Callable):
return func
def _set_api_key(wandb_config: Dict):
def _set_api_key(api_key_file: Optional[str] = None,
api_key: Optional[str] = None):
"""Set WandB API key from `wandb_config`. Will pop the
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", ""))
api_key = wandb_config.pop("api_key", None)
if api_key_file:
if api_key:
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
@@ -166,7 +165,8 @@ def _set_api_key(wandb_config: Dict):
pass
raise ValueError(
"No WandB API key found. Either set the {} environment "
"variable, pass `api_key` or `api_key_file` in the config, "
"variable, pass `api_key` or `api_key_file` to the"
"`WandbLoggerCallback` class as arguments, "
"or run `wandb login` from the command line".format(WANDB_ENV_VAR))
@@ -219,9 +219,166 @@ class _WandbLoggingProcess(Process):
return log, config_update
class WandbLoggerCallback(LoggerCallback):
"""WandbLoggerCallback
Weights and biases (https://www.wandb.com/) is a tool for experiment
tracking, model optimization, and dataset versioning. This Ray Tune
``LoggerCallback`` sends metrics to Wandb for automatic tracking and
visualization.
Args:
project (str): Name of the Wandb project. Mandatory.
group (str): Name of the Wandb group. Defaults to the trainable
name.
api_key_file (str): Path to file containing the Wandb API KEY. This
file only needs to be present on the node running the Tune script
if using the WandbLogger.
api_key (str): Wandb API Key. Alternative to setting ``api_key_file``.
excludes (list): List of metrics that should be excluded from
the log.
log_config (bool): Boolean indicating if the ``config`` parameter of
the ``results`` dict should be logged. This makes sense if
parameters will change during training, e.g. with
PopulationBasedTraining. Defaults to False.
**kwargs: The keyword arguments will be pased to ``wandb.init()``.
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
by Tune, but can be overwritten by filling out the respective configuration
values.
Please see here for all other valid configuration settings:
https://docs.wandb.com/library/init
Example:
.. code-block:: python
from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune.integration.wandb import WandbLoggerCallback
tune.run(
train_fn,
config={
# define search space here
"parameter_1": tune.choice([1, 2, 3]),
"parameter_2": tune.choice([4, 5, 6]),
},
callbacks=[WandbLoggerCallback(
project="Optimization_Project",
api_key_file="/path/to/file",
log_config=True)])
"""
# Do not log these result keys
_exclude_results = ["done", "should_checkpoint"]
# Use these result keys to update `wandb.config`
_config_results = [
"trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname",
"pid", "date"
]
_logger_process_cls = _WandbLoggingProcess
def __init__(self,
project: str,
group: Optional[str] = None,
api_key_file: Optional[str] = None,
api_key: Optional[str] = None,
excludes: Optional[List[str]] = None,
log_config: bool = False,
**kwargs):
self.project = project
self.group = group
self.api_key_file = os.path.expanduser(
api_key_file) if api_key_file else None
self.api_key = api_key
self.excludes = excludes or []
self.log_config = log_config
self.kwargs = kwargs
self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
self._trial_queues: Dict["Trial", Queue] = {}
_set_api_key(self.api_key_file, self.api_key)
def log_trial_start(self, trial: "Trial"):
config = trial.config.copy()
config.pop("callbacks", None) # Remove callbacks
exclude_results = self._exclude_results.copy()
# Additional excludes
exclude_results += self.excludes
# Log config keys on each result?
if not self.log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = trial.trial_id if trial else None
trial_name = str(trial) if trial else None
# Project name for Wandb
wandb_project = self.project
# Grouping
wandb_group = self.group or trial.trainable_name if trial else None
# remove unpickleable items!
config = _clean_log(config)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
resume=True,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config)
wandb_init_kwargs.update(self.kwargs)
self._trial_queues[trial] = Queue()
self._trial_processes[trial] = self._logger_process_cls(
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs)
self._trial_processes[trial].start()
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
if trial not in self._trial_processes:
self.log_trial_start(trial)
result = _clean_log(result)
self._trial_queues[trial].put(result)
def log_trial_end(self, trial: "Trial", failed: bool = False):
self._trial_queues[trial].put(_WANDB_QUEUE_END)
self._trial_processes[trial].join(timeout=10)
del self._trial_queues[trial]
del self._trial_processes[trial]
def __del__(self):
for trial in self._trial_processes:
if trial in self._trial_queues:
self._trial_queues[trial].put(_WANDB_QUEUE_END)
del self._trial_queues[trial]
self._trial_processes[trial].join(timeout=2)
del self._trial_processes[trial]
class WandbLogger(Logger):
"""WandbLogger
.. warning::
This `Logger` class is deprecated. Use the `WandbLoggerCallback`
callback instead.
Weights and biases (https://www.wandb.com/) is a tool for experiment
tracking, model optimization, and dataset versioning. This Ray Tune
``Logger`` sends metrics to Wandb for automatic tracking and
@@ -300,21 +457,10 @@ class WandbLogger(Logger):
"""
# Do not log these result keys
_exclude_results = ["done", "should_checkpoint"]
# Use these result keys to update `wandb.config`
_config_results = [
"trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname",
"pid", "date"
]
_logger_process_cls = _WandbLoggingProcess
_experiment_logger_cls = WandbLoggerCallback
def _init(self):
config = self.config.copy()
config.pop("callbacks", None) # Remove callbacks
try:
@@ -329,63 +475,17 @@ class WandbLogger(Logger):
"Make sure to include a `wandb` key in your `config` dict "
"containing at least a `project` specification.")
_set_api_key(wandb_config)
self._trial_experiment_logger = self._experiment_logger_cls(
**wandb_config)
exclude_results = self._exclude_results.copy()
# Additional excludes
additional_excludes = wandb_config.pop("excludes", [])
exclude_results += additional_excludes
# Log config keys on each result?
log_config = wandb_config.pop("log_config", False)
if not log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = self.trial.trial_id if self.trial else None
trial_name = str(self.trial) if self.trial else None
# Project name for Wandb
try:
wandb_project = wandb_config.pop("project")
except KeyError:
raise ValueError(
"You need to specify a `project` in your wandb `config` dict.")
# Grouping
wandb_group = wandb_config.pop(
"group", self.trial.trainable_name if self.trial else None)
# remove unpickleable items!
config = _clean_log(config)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
resume=True,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config)
wandb_init_kwargs.update(wandb_config)
self._queue = Queue()
self._wandb = self._logger_process_cls(
queue=self._queue,
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs)
self._wandb.start()
self._trial_experiment_logger.log_trial_start(self.trial)
def on_result(self, result: Dict):
result = _clean_log(result)
self._queue.put(result)
self._trial_experiment_logger.log_trial_result(0, self.trial, result)
def close(self):
self._queue.put(_WANDB_QUEUE_END)
self._wandb.join(timeout=10)
self._trial_experiment_logger.log_trial_end(self.trial, failed=False)
del self._trial_experiment_logger
class WandbTrainableMixin:
@@ -411,7 +511,11 @@ class WandbTrainableMixin:
"Make sure to include a `wandb` key in your `config` dict "
"containing at least a `project` specification.")
_set_api_key(wandb_config)
api_key_file = wandb_config.pop("api_key_file", None)
if api_key_file:
api_key_file = os.path.expanduser(api_key_file)
_set_api_key(api_key_file, wandb_config.pop("api_key", None))
# Fill trial ID and name
trial_id = self.trial_id
+285 -8
View File
@@ -5,7 +5,7 @@ import numpy as np
import os
import yaml
from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, Type
from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, TextIO, Type
import ray.cloudpickle as cloudpickle
@@ -364,21 +364,60 @@ class UnifiedLogger(Logger):
_logger.flush()
class ExperimentLogger(Callback):
class LoggerCallback(Callback):
"""Base class for experiment-level logger callbacks
This base class defines a general interface for logging events,
like trial starts, restores, ends, checkpoint saves, and receiving
trial results.
Callbacks implementing this interface should make sure that logging
utilities are cleaned up properly on trial termination, i.e. when
``log_trial_end`` is received. This includes e.g. closing files.
"""
def log_trial_start(self, trial: "Trial"):
raise NotImplementedError
"""Handle logging when a trial starts.
Args:
trial (Trial): Trial object.
"""
pass
def log_trial_restore(self, trial: "Trial"):
raise NotImplementedError
"""Handle logging when a trial restores.
Args:
trial (Trial): Trial object.
"""
pass
def log_trial_save(self, trial: "Trial"):
raise NotImplementedError
"""Handle logging when a trial saves a checkpoint.
Args:
trial (Trial): Trial object.
"""
pass
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
raise NotImplementedError
"""Handle logging when a trial reports a result.
Args:
trial (Trial): Trial object.
result (dict): Result dictionary.
"""
pass
def log_trial_end(self, trial: "Trial", failed: bool = False):
raise NotImplementedError
"""Handle logging when a trial ends.
Args:
trial (Trial): Trial object.
failed (bool): True if the Trial finished gracefully, False if
it failed (e.g. when it raised an exception).
"""
pass
def on_trial_result(self, iteration: int, trials: List["Trial"],
trial: "Trial", result: Dict, **info):
@@ -405,7 +444,7 @@ class ExperimentLogger(Callback):
self.log_trial_end(trial, failed=True)
class LegacyExperimentLogger(ExperimentLogger):
class LegacyLoggerCallback(LoggerCallback):
"""Supports logging to trial-specific `Logger` classes.
Previously, Ray Tune logging was handled via `Logger` classes that have
@@ -455,6 +494,244 @@ class LegacyExperimentLogger(ExperimentLogger):
trial_loggers[trial].close()
class JsonLoggerCallback(LoggerCallback):
"""Logs trial results in json format.
Also writes to a results file and param.json file when results or
configurations are updated. Experiments must be executed with the
JsonLoggerCallback to be compatible with the ExperimentAnalysis tool.
"""
def __init__(self):
self._trial_configs: Dict["Trial", Dict] = {}
self._trial_files: Dict["Trial", TextIO] = {}
def log_trial_start(self, trial: "Trial"):
if trial in self._trial_files:
self._trial_files[trial].close()
# Update config
self.update_config(trial, trial.config)
# Make sure logdir exists
trial.init_logdir()
local_file = os.path.join(trial.logdir, EXPR_RESULT_FILE)
self._trial_files[trial] = open(local_file, "at")
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
if trial not in self._trial_files:
self.log_trial_start(trial)
json.dump(result, self._trial_files[trial], cls=SafeFallbackEncoder)
self._trial_files[trial].write("\n")
self._trial_files[trial].flush()
def log_trial_end(self, trial: "Trial", failed: bool = False):
if trial not in self._trial_files:
return
self._trial_files[trial].close()
del self._trial_files[trial]
def update_config(self, trial: "Trial", config: Dict):
self._trial_configs[trial] = config
config_out = os.path.join(trial.logdir, EXPR_PARAM_FILE)
with open(config_out, "w") as f:
json.dump(
self._trial_configs[trial],
f,
indent=2,
sort_keys=True,
cls=SafeFallbackEncoder)
config_pkl = os.path.join(trial.logdir, EXPR_PARAM_PICKLE_FILE)
with open(config_pkl, "wb") as f:
cloudpickle.dump(self._trial_configs[trial], f)
class CSVLoggerCallback(LoggerCallback):
"""Logs results to progress.csv under the trial directory.
Automatically flattens nested dicts in the result dict before writing
to csv:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
"""
def __init__(self):
self._trial_continue: Dict["Trial", bool] = {}
self._trial_files: Dict["Trial", TextIO] = {}
self._trial_csv: Dict["Trial", csv.DictWriter] = {}
def log_trial_start(self, trial: "Trial"):
if trial in self._trial_files:
self._trial_files[trial].close()
# Make sure logdir exists
trial.init_logdir()
local_file = os.path.join(trial.logdir, EXPR_PROGRESS_FILE)
self._trial_continue[trial] = os.path.exists(local_file)
self._trial_files[trial] = open(local_file, "at")
self._trial_csv[trial] = None
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
if trial not in self._trial_files:
self.log_trial_start(trial)
tmp = result.copy()
tmp.pop("config", None)
result = flatten_dict(tmp, delimiter="/")
if not self._trial_csv[trial]:
self._trial_csv[trial] = csv.DictWriter(self._trial_files[trial],
result.keys())
if not self._trial_continue[trial]:
self._trial_csv[trial].writeheader()
self._trial_csv[trial].writerow({
k: v
for k, v in result.items()
if k in self._trial_csv[trial].fieldnames
})
self._trial_files[trial].flush()
def log_trial_end(self, trial: "Trial", failed: bool = False):
if trial not in self._trial_files:
return
del self._trial_csv[trial]
self._trial_files[trial].close()
del self._trial_files[trial]
class TBXLoggerCallback(LoggerCallback):
"""TensorBoardX Logger.
Note that hparams will be written only after a trial has terminated.
This logger automatically flattens nested dicts to show on TensorBoard:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
"""
# NoneType is not supported on the last TBX release yet.
VALID_HPARAMS = (str, bool, np.bool8, int, np.integer, float, list)
def __init__(self):
try:
from tensorboardX import SummaryWriter
self._summary_writer_cls = SummaryWriter
except ImportError:
if log_once("tbx-install"):
logger.info(
"pip install 'ray[tune]' to see TensorBoard files.")
raise
self._trial_writer: Dict["Trial", SummaryWriter] = {}
self._trial_result: Dict["Trial", Dict] = {}
def log_trial_start(self, trial: "Trial"):
trial.init_logdir()
self._trial_writer[trial] = self._summary_writer_cls(
trial.logdir, flush_secs=30)
self._trial_result[trial] = {}
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
if trial not in self._trial_writer:
self.log_trial_start(trial)
step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
tmp = result.copy()
for k in [
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
]:
if k in tmp:
del tmp[k] # not useful to log these
flat_result = flatten_dict(tmp, delimiter="/")
path = ["ray", "tune"]
valid_result = {}
for attr, value in flat_result.items():
full_attr = "/".join(path + [attr])
if (isinstance(value, tuple(VALID_SUMMARY_TYPES))
and not np.isnan(value)):
valid_result[full_attr] = value
self._trial_writer[trial].add_scalar(
full_attr, value, global_step=step)
elif ((isinstance(value, list) and len(value) > 0)
or (isinstance(value, np.ndarray) and value.size > 0)):
valid_result[full_attr] = value
# Must be video
if isinstance(value, np.ndarray) and value.ndim == 5:
self._trial_writer[trial].add_video(
full_attr, value, global_step=step, fps=20)
continue
try:
self._trial_writer[trial].add_histogram(
full_attr, value, global_step=step)
# In case TensorboardX still doesn't think it's a valid value
# (e.g. `[[]]`), warn and move on.
except (ValueError, TypeError):
if log_once("invalid_tbx_value"):
logger.warning(
"You are trying to log an invalid value ({}={}) "
"via {}!".format(full_attr, value,
type(self).__name__))
self._trial_result[trial] = valid_result
self._trial_writer[trial].flush()
def log_trial_end(self, trial: "Trial", failed: bool = False):
if trial in self._trial_writer:
if trial and trial.evaluated_params and self._trial_result[trial]:
flat_result = flatten_dict(
self._trial_result[trial], delimiter="/")
scrubbed_result = {
k: value
for k, value in flat_result.items()
if isinstance(value, tuple(VALID_SUMMARY_TYPES))
}
self._try_log_hparams(trial, scrubbed_result)
self._trial_writer[trial].close()
del self._trial_writer[trial]
del self._trial_result[trial]
def _try_log_hparams(self, trial: "Trial", result: Dict):
# TBX currently errors if the hparams value is None.
flat_params = flatten_dict(trial.evaluated_params)
scrubbed_params = {
k: v
for k, v in flat_params.items()
if isinstance(v, self.VALID_HPARAMS)
}
removed = {
k: v
for k, v in flat_params.items()
if not isinstance(v, self.VALID_HPARAMS)
}
if removed:
logger.info(
"Removed the following hyperparameter values when "
"logging to tensorboard: %s", str(removed))
from tensorboardX.summary import hparams
try:
experiment_tag, session_start_tag, session_end_tag = hparams(
hparam_dict=scrubbed_params, metric_dict=result)
self._trial_writer[trial].file_writer.add_summary(experiment_tag)
self._trial_writer[trial].file_writer.add_summary(
session_start_tag)
self._trial_writer[trial].file_writer.add_summary(session_end_tag)
except Exception:
logger.exception("TensorboardX failed to log hparams. "
"This may be due to an unsupported type "
"in the hyperparameter values.")
def pretty_print(result):
result = result.copy()
result.update(config=None) # drop config from pretty print
+138 -17
View File
@@ -8,15 +8,22 @@ import numpy as np
from ray.tune import Trainable
from ray.tune.function_runner import wrap_function
from ray.tune.integration.wandb import _WandbLoggingProcess, \
from ray.tune.integration.wandb import WandbLoggerCallback, \
_WandbLoggingProcess, \
_WANDB_QUEUE_END, WandbLogger, WANDB_ENV_VAR, WandbTrainableMixin, \
wandb_mixin
from ray.tune.result import TRIAL_INFO
from ray.tune.trial import TrialInfo
Trial = namedtuple("MockTrial",
["config", "trial_id", "trial_name", "trainable_name"])
Trial.__str__ = lambda t: t.trial_name
class Trial(
namedtuple("MockTrial",
["config", "trial_id", "trial_name", "trainable_name"])):
def __hash__(self):
return hash(self.trial_id)
def __str__(self):
return self.trial_name
class _MockWandbLoggingProcess(_WandbLoggingProcess):
@@ -37,9 +44,21 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess):
self.logs.put(log)
class WandbTestLogger(WandbLogger):
class WandbTestExperimentLogger(WandbLoggerCallback):
_logger_process_cls = _MockWandbLoggingProcess
@property
def trial_processes(self):
return self._trial_processes
class WandbTestLogger(WandbLogger):
_experiment_logger_cls = WandbTestExperimentLogger
@property
def trial_process(self):
return self._trial_experiment_logger.trial_processes[self.trial]
class _MockWandbAPI(object):
def init(self, *args, **kwargs):
@@ -63,7 +82,7 @@ class WandbIntegrationTest(unittest.TestCase):
def tearDown(self):
pass
def testWandbLoggerConfig(self):
def testWandbLegacyLoggerConfig(self):
trial_config = {"par1": 4, "par2": 9.12345678}
trial = Trial(trial_config, 0, "trial_0", "trainable")
@@ -115,11 +134,13 @@ class WandbIntegrationTest(unittest.TestCase):
trial_config["wandb"] = {"project": "test_project"}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertEqual(logger._wandb.kwargs["project"], "test_project")
self.assertEqual(logger._wandb.kwargs["id"], trial.trial_id)
self.assertEqual(logger._wandb.kwargs["name"], trial.trial_name)
self.assertEqual(logger._wandb.kwargs["group"], trial.trainable_name)
self.assertIn("config", logger._wandb._exclude)
self.assertEqual(logger.trial_process.kwargs["project"],
"test_project")
self.assertEqual(logger.trial_process.kwargs["id"], trial.trial_id)
self.assertEqual(logger.trial_process.kwargs["name"], trial.trial_name)
self.assertEqual(logger.trial_process.kwargs["group"],
trial.trainable_name)
self.assertIn("config", logger.trial_process._exclude)
logger.close()
@@ -127,8 +148,8 @@ class WandbIntegrationTest(unittest.TestCase):
trial_config["wandb"] = {"project": "test_project", "log_config": True}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertNotIn("config", logger._wandb._exclude)
self.assertNotIn("metric", logger._wandb._exclude)
self.assertNotIn("config", logger.trial_process._exclude)
self.assertNotIn("metric", logger.trial_process._exclude)
logger.close()
@@ -139,12 +160,12 @@ class WandbIntegrationTest(unittest.TestCase):
}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertIn("config", logger._wandb._exclude)
self.assertIn("metric", logger._wandb._exclude)
self.assertIn("config", logger.trial_process._exclude)
self.assertIn("metric", logger.trial_process._exclude)
logger.close()
def testWandbLoggerReporting(self):
def testWandbLegacyLoggerReporting(self):
trial_config = {"par1": 4, "par2": 9.12345678}
trial = Trial(trial_config, 0, "trial_0", "trainable")
@@ -166,7 +187,7 @@ class WandbIntegrationTest(unittest.TestCase):
logger.on_result(r1)
logged = logger._wandb.logs.get(timeout=10)
logged = logger.trial_process.logs.get(timeout=10)
self.assertIn("metric1", logged)
self.assertNotIn("metric2", logged)
self.assertIn("metric3", logged)
@@ -176,6 +197,106 @@ class WandbIntegrationTest(unittest.TestCase):
logger.close()
def testWandbLoggerConfig(self):
trial_config = {"par1": 4, "par2": 9.12345678}
trial = Trial(trial_config, 0, "trial_0", "trainable")
if WANDB_ENV_VAR in os.environ:
del os.environ[WANDB_ENV_VAR]
# No API key
with self.assertRaises(ValueError):
logger = WandbTestExperimentLogger(project="test_project")
# API Key in config
logger = WandbTestExperimentLogger(
project="test_project", api_key="1234")
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
del logger
del os.environ[WANDB_ENV_VAR]
# API Key file
with tempfile.NamedTemporaryFile("wt") as fp:
fp.write("5678")
fp.flush()
logger = WandbTestExperimentLogger(
project="test_project", api_key_file=fp.name)
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
del logger
del os.environ[WANDB_ENV_VAR]
# API Key in env
os.environ[WANDB_ENV_VAR] = "9012"
logger = WandbTestExperimentLogger(project="test_project")
del logger
# From now on, the API key is in the env variable.
logger = WandbTestExperimentLogger(project="test_project")
logger.log_trial_start(trial)
self.assertEqual(logger.trial_processes[trial].kwargs["project"],
"test_project")
self.assertEqual(logger.trial_processes[trial].kwargs["id"],
trial.trial_id)
self.assertEqual(logger.trial_processes[trial].kwargs["name"],
trial.trial_name)
self.assertEqual(logger.trial_processes[trial].kwargs["group"],
trial.trainable_name)
self.assertIn("config", logger.trial_processes[trial]._exclude)
del logger
# log config.
logger = WandbTestExperimentLogger(
project="test_project", log_config=True)
logger.log_trial_start(trial)
self.assertNotIn("config", logger.trial_processes[trial]._exclude)
self.assertNotIn("metric", logger.trial_processes[trial]._exclude)
del logger
# Exclude metric.
logger = WandbTestExperimentLogger(
project="test_project", excludes=["metric"])
logger.log_trial_start(trial)
self.assertIn("config", logger.trial_processes[trial]._exclude)
self.assertIn("metric", logger.trial_processes[trial]._exclude)
del logger
def testWandbLoggerReporting(self):
trial_config = {"par1": 4, "par2": 9.12345678}
trial = Trial(trial_config, 0, "trial_0", "trainable")
logger = WandbTestExperimentLogger(
project="test_project", api_key="1234", excludes=["metric2"])
logger.on_trial_start(0, [], trial)
r1 = {
"metric1": 0.8,
"metric2": 1.4,
"metric3": np.asarray(32.0),
"metric4": np.float32(32.0),
"const": "text",
"config": trial_config
}
logger.on_trial_result(0, [], trial, r1)
logged = logger.trial_processes[trial].logs.get(timeout=10)
self.assertIn("metric1", logged)
self.assertNotIn("metric2", logged)
self.assertIn("metric3", logged)
self.assertIn("metric4", logged)
self.assertNotIn("const", logged)
self.assertNotIn("config", logged)
del logger
def testWandbMixinConfig(self):
config = {"par1": 4, "par2": 9.12345678}
trial = Trial(config, 0, "trial_0", "trainable")
+182 -19
View File
@@ -1,12 +1,33 @@
import csv
import glob
import json
import os
from collections import namedtuple
import unittest
import tempfile
import shutil
import numpy as np
from ray.cloudpickle import cloudpickle
from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger
from ray.tune.logger import CSVLoggerCallback, JsonLoggerCallback, \
JsonLogger, CSVLogger, \
TBXLoggerCallback, TBXLogger
from ray.tune.result import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, \
EXPR_PROGRESS_FILE, \
EXPR_RESULT_FILE
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
class Trial(
namedtuple("MockTrial", ["evaluated_params", "trial_id", "logdir"])):
@property
def config(self):
return self.evaluated_params
def init_logdir(self):
return
def __hash__(self):
return hash(self.trial_id)
def result(t, rew, **kwargs):
@@ -28,24 +49,116 @@ class LoggerSuite(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.test_dir, ignore_errors=True)
def testCSV(self):
def testLegacyCSV(self):
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
t = Trial(evaluated_params=config, trial_id="csv")
t = Trial(
evaluated_params=config, trial_id="csv", logdir=self.test_dir)
logger = CSVLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
logger.on_result(result(2, 5))
logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1}))
logger.close()
self._validate_csv_result()
def testCSV(self):
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
t = Trial(
evaluated_params=config, trial_id="csv", logdir=self.test_dir)
logger = CSVLoggerCallback()
logger.on_trial_result(0, [], t, result(0, 4))
logger.on_trial_result(1, [], t, result(1, 5))
logger.on_trial_result(
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
logger.on_trial_complete(3, [], t)
self._validate_csv_result()
def _validate_csv_result(self):
results = []
result_file = os.path.join(self.test_dir, EXPR_PROGRESS_FILE)
with open(result_file, "rt") as fp:
reader = csv.DictReader(fp)
for row in reader:
results.append(row)
self.assertEqual(len(results), 3)
self.assertSequenceEqual(
[int(row["episode_reward_mean"]) for row in results], [4, 5, 6])
def testJSONLegacyLogger(self):
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
t = Trial(
evaluated_params=config, trial_id="json", logdir=self.test_dir)
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(1, 5))
logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1}))
logger.close()
self._validate_json_result(config)
def testJSON(self):
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
t = Trial(evaluated_params=config, trial_id="json")
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
t = Trial(
evaluated_params=config, trial_id="json", logdir=self.test_dir)
logger = JsonLoggerCallback()
logger.on_trial_result(0, [], t, result(0, 4))
logger.on_trial_result(1, [], t, result(1, 5))
logger.on_trial_result(
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
logger.on_trial_complete(3, [], t)
self._validate_json_result(config)
def _validate_json_result(self, config):
# Check result logs
results = []
result_file = os.path.join(self.test_dir, EXPR_RESULT_FILE)
with open(result_file, "rt") as fp:
for row in fp.readlines():
results.append(json.loads(row))
self.assertEqual(len(results), 3)
self.assertSequenceEqual(
[int(row["episode_reward_mean"]) for row in results], [4, 5, 6])
# Check json saved config file
config_file = os.path.join(self.test_dir, EXPR_PARAM_FILE)
with open(config_file, "rt") as fp:
loaded_config = json.load(fp)
self.assertEqual(loaded_config, config)
# Check pickled config file
config_file = os.path.join(self.test_dir, EXPR_PARAM_PICKLE_FILE)
with open(config_file, "rb") as fp:
loaded_config = cloudpickle.load(fp)
self.assertEqual(loaded_config, config)
def testLegacyTBX(self):
config = {
"a": 2,
"b": [1, 2],
"c": {
"c": {
"D": 123
}
},
"d": np.int64(1),
"e": np.bool8(True)
}
t = Trial(
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(1, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
logger.on_result(result(1, 5))
logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1}))
logger.close()
self._validate_tbx_result()
def testTBX(self):
config = {
"a": 2,
@@ -58,16 +171,40 @@ class LoggerSuite(unittest.TestCase):
"d": np.int64(1),
"e": np.bool8(True)
}
t = Trial(evaluated_params=config, trial_id="tbx")
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(1, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
logger.close()
t = Trial(
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
logger = TBXLoggerCallback()
logger.on_trial_result(0, [], t, result(0, 4))
logger.on_trial_result(1, [], t, result(1, 5))
logger.on_trial_result(
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
def testBadTBX(self):
logger.on_trial_complete(3, [], t)
self._validate_tbx_result()
def _validate_tbx_result(self):
try:
from tensorflow.python.summary.summary_iterator \
import summary_iterator
except ImportError:
print("Skipping rest of test as tensorflow is not installed.")
return
events_file = list(glob.glob(f"{self.test_dir}/events*"))[0]
results = []
for event in summary_iterator(events_file):
for v in event.summary.value:
if v.tag == "ray/tune/episode_reward_mean":
results.append(v.simple_value)
self.assertEqual(len(results), 3)
self.assertSequenceEqual([int(res) for res in results], [4, 5, 6])
def testLegacyBadTBX(self):
config = {"b": (1, 2, 3)}
t = Trial(evaluated_params=config, trial_id="tbx")
t = Trial(
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
@@ -76,7 +213,8 @@ class LoggerSuite(unittest.TestCase):
assert "INFO" in cm.output[0]
config = {"None": None}
t = Trial(evaluated_params=config, trial_id="tbx")
t = Trial(
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
@@ -84,6 +222,31 @@ class LoggerSuite(unittest.TestCase):
logger.close()
assert "INFO" in cm.output[0]
def testBadTBX(self):
config = {"b": (1, 2, 3)}
t = Trial(
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
logger = TBXLoggerCallback()
logger.on_trial_result(0, [], t, result(0, 4))
logger.on_trial_result(1, [], t, result(1, 5))
logger.on_trial_result(
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
logger.on_trial_complete(3, [], t)
assert "INFO" in cm.output[0]
config = {"None": None}
t = Trial(
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
logger = TBXLoggerCallback()
logger.on_trial_result(0, [], t, result(0, 4))
logger.on_trial_result(1, [], t, result(1, 5))
logger.on_trial_result(
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
logger.on_trial_complete(3, [], t)
assert "INFO" in cm.output[0]
if __name__ == "__main__":
import pytest
+6 -6
View File
@@ -7,7 +7,7 @@ from ray.rllib import _register_all
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune import Trainable, TuneError
from ray.tune import register_trainable, run_experiments
from ray.tune.logger import LegacyExperimentLogger, Logger
from ray.tune.logger import LegacyLoggerCallback, Logger
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, ExportFormat
@@ -191,7 +191,7 @@ class RunExperimentTest(unittest.TestCase):
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])])
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
self.assertFalse(
os.path.exists(os.path.join(trial.logdir, "params.json")))
@@ -204,7 +204,7 @@ class RunExperimentTest(unittest.TestCase):
}
}
})
self.assertTrue(
self.assertFalse(
os.path.exists(os.path.join(trial.logdir, "params.json")))
[trial] = run_experiments(
@@ -216,7 +216,7 @@ class RunExperimentTest(unittest.TestCase):
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[])])
callbacks=[LegacyLoggerCallback(logger_classes=[])])
self.assertFalse(
os.path.exists(os.path.join(trial.logdir, "params.json")))
@@ -239,7 +239,7 @@ class RunExperimentTest(unittest.TestCase):
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])])
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "params.json")))
@@ -264,7 +264,7 @@ class RunExperimentTest(unittest.TestCase):
}
}
},
callbacks=[LegacyExperimentLogger(logger_classes=[])])
callbacks=[LegacyLoggerCallback(logger_classes=[])])
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "params.json")))
@@ -9,8 +9,8 @@ import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.logger import DEFAULT_LOGGERS, ExperimentLogger, \
LegacyExperimentLogger
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, \
LegacyLoggerCallback
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TRAINING_ITERATION
from ray.tune.syncer import SyncConfig, SyncerCallback
@@ -205,14 +205,14 @@ class TrialRunnerCallbacks(unittest.TestCase):
"delay")
def testCallbackReordering(self):
"""SyncerCallback should come after ExperimentLogger callbacks"""
"""SyncerCallback should come after LoggerCallback callbacks"""
def get_positions(callbacks):
first_logger_pos = None
last_logger_pos = None
syncer_pos = None
for i, callback in enumerate(callbacks):
if isinstance(callback, ExperimentLogger):
if isinstance(callback, LoggerCallback):
if first_logger_pos is None:
first_logger_pos = i
last_logger_pos = i
@@ -233,8 +233,8 @@ class TrialRunnerCallbacks(unittest.TestCase):
self.assertLess(last_logger_pos, syncer_pos)
# Auto creation of loggers with existing logger (but no CSV/JSON)
callbacks = create_default_callbacks([ExperimentLogger()],
SyncConfig(), None)
callbacks = create_default_callbacks([LoggerCallback()], SyncConfig(),
None)
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
callbacks)
self.assertLess(last_logger_pos, syncer_pos)
@@ -242,13 +242,12 @@ class TrialRunnerCallbacks(unittest.TestCase):
# This should throw an error as the syncer comes before the logger
with self.assertRaises(ValueError):
callbacks = create_default_callbacks(
[SyncerCallback(None),
ExperimentLogger()], SyncConfig(), None)
[SyncerCallback(None), LoggerCallback()], SyncConfig(), None)
# This should be reordered but preserve the regular callback order
[mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
# Has to be legacy logger to avoid logger callback creation
lc = LegacyExperimentLogger(logger_classes=DEFAULT_LOGGERS)
lc = LegacyLoggerCallback(logger_classes=DEFAULT_LOGGERS)
callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(),
None)
print(callbacks)
+4 -5
View File
@@ -73,7 +73,6 @@ def run(
checkpoint_at_end=False,
verbose=Verbosity.V3_TRIAL_DETAILS,
progress_reporter=None,
loggers=None,
log_to_file=False,
trial_name_creator=None,
trial_dirname_creator=None,
@@ -90,6 +89,7 @@ def run(
raise_on_failed_trial=True,
callbacks=None,
# Deprecated args
loggers=None,
ray_auto_init=None,
run_errored_only=None,
global_checkpoint_period=None,
@@ -196,9 +196,6 @@ def run(
intermediate experiment progress. Defaults to CLIReporter if
running in command-line, or JupyterNotebookReporter if running in
a Jupyter notebook.
loggers (list): List of logger creators to be used with
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
See `ray/tune/logger.py`.
log_to_file (bool|str|Sequence): Log stdout and stderr to files in
Tune's trial directories. If this is `False` (default), no files
are written. If `true`, outputs are written to `trialdir/stdout`
@@ -251,7 +248,9 @@ def run(
trial (of ERROR state) when the experiments complete.
callbacks (list): List of callbacks that will be called at different
times in the training loop. Must be instances of the
``ray.tune.trial_runner.Callback`` class.
``ray.tune.trial_runner.Callback`` class. If not passed,
`LoggerCallback` and `SyncerCallback` callbacks are automatically
added.
Returns:
+37 -27
View File
@@ -1,13 +1,18 @@
import logging
import os
from typing import List, Optional
from ray.tune.callback import Callback
from ray.tune.progress_reporter import TrialProgressCallback
from ray.tune.syncer import SyncConfig
from ray.tune.logger import CSVLogger, DEFAULT_LOGGERS, ExperimentLogger, \
JsonLogger, LegacyExperimentLogger, Logger
from ray.tune.logger import CSVLoggerCallback, CSVLogger, \
LoggerCallback, \
JsonLoggerCallback, JsonLogger, LegacyLoggerCallback, Logger, \
TBXLoggerCallback, TBXLogger
from ray.tune.syncer import SyncerCallback
logger = logging.getLogger(__name__)
def create_default_callbacks(callbacks: Optional[List[Callback]],
sync_config: SyncConfig,
@@ -40,6 +45,7 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
has_syncer_callback = False
has_csv_logger = False
has_json_logger = False
has_tbx_logger = False
has_trial_progress_callback = any(
isinstance(c, TrialProgressCallback) for c in callbacks)
@@ -52,20 +58,14 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
last_logger_index = None
syncer_index = None
if not loggers:
# If no logger callback and no `loggers` have been provided,
# add DEFAULT_LOGGERS.
if not any(
isinstance(callback, ExperimentLogger)
for callback in callbacks):
loggers = DEFAULT_LOGGERS
# Create LegacyExperimentLogger for passed Logger classes
# Create LegacyLoggerCallback for passed Logger classes
if loggers:
# Todo(krfricke): Deprecate `loggers` argument, print warning here.
# Add warning as soon as we ported all loggers to LoggerCallback
# classes.
add_loggers = []
for trial_logger in loggers:
if isinstance(trial_logger, ExperimentLogger):
if isinstance(trial_logger, LoggerCallback):
callbacks.append(trial_logger)
elif isinstance(trial_logger, type) and issubclass(
trial_logger, Logger):
@@ -75,31 +75,41 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
f"Invalid value passed to `loggers` argument of "
f"`tune.run()`: {trial_logger}")
if add_loggers:
callbacks.append(LegacyExperimentLogger(add_loggers))
callbacks.append(LegacyLoggerCallback(add_loggers))
# Check if we have a CSV and JSON logger
# Check if we have a CSV, JSON and TensorboardX logger
for i, callback in enumerate(callbacks):
if isinstance(callback, LegacyExperimentLogger):
if isinstance(callback, LegacyLoggerCallback):
last_logger_index = i
if CSVLogger in callback.logger_classes:
has_csv_logger = True
if JsonLogger in callback.logger_classes:
has_json_logger = True
# Todo(krfricke): add checks for new ExperimentLogger classes
if TBXLogger in callback.logger_classes:
has_tbx_logger = True
elif isinstance(callback, CSVLoggerCallback):
has_csv_logger = True
last_logger_index = i
elif isinstance(callback, JsonLoggerCallback):
has_json_logger = True
last_logger_index = i
elif isinstance(callback, TBXLoggerCallback):
has_tbx_logger = True
last_logger_index = i
elif isinstance(callback, SyncerCallback):
syncer_index = i
has_syncer_callback = True
# If CSV or JSON logger is missing, add
# If CSV, JSON or TensorboardX loggers are missing, add
if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
# Todo(krfricke): Switch to new ExperimentLogger classes
add_loggers = []
if not has_csv_logger:
add_loggers.append(CSVLogger)
callbacks.append(CSVLoggerCallback())
last_logger_index = len(callbacks) - 1
if not has_json_logger:
add_loggers.append(JsonLogger)
if add_loggers:
callbacks.append(LegacyExperimentLogger(add_loggers))
callbacks.append(JsonLoggerCallback())
last_logger_index = len(callbacks) - 1
if not has_tbx_logger:
callbacks.append(TBXLoggerCallback())
last_logger_index = len(callbacks) - 1
# If no SyncerCallback was found, add
@@ -110,18 +120,18 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
callbacks.append(syncer_callback)
syncer_index = len(callbacks) - 1
# Todo(krfricke): Maybe check if syncer comes after all loggers
if syncer_index is not None and last_logger_index is not None and \
syncer_index < last_logger_index:
if (not has_csv_logger or not has_json_logger) and not loggers:
if (not has_csv_logger or not has_json_logger or not has_tbx_logger) \
and not loggers:
# Only raise the warning if the loggers were passed by the user.
# (I.e. don't warn if this was automatic behavior and they only
# passed a customer SyncerCallback).
raise ValueError(
"The `SyncerCallback` you passed to `tune.run()` came before "
"at least one `ExperimentLogger`. Syncing should be done "
"at least one `LoggerCallback`. Syncing should be done "
"after writing logs. Please re-order the callbacks so that "
"the `SyncerCallback` comes after any `ExperimentLogger`.")
"the `SyncerCallback` comes after any `LoggerCallback`.")
else:
# If these loggers were automatically created. just re-order
# the callbacks