mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:45:44 +08:00
[tune] logger migration to ExperimentLogger classes (#11984)
This commit is contained in:
@@ -4,11 +4,12 @@ import argparse
|
||||
import time
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.logger import LoggerCallback
|
||||
|
||||
|
||||
class TestLogger(tune.logger.Logger):
|
||||
def on_result(self, result):
|
||||
print("TestLogger", result)
|
||||
class TestLoggerCallback(LoggerCallback):
|
||||
def on_trial_result(self, iteration, trials, trial, result, **info):
|
||||
print(f"TestLogger for trial {trial}: {result}")
|
||||
|
||||
|
||||
def trial_str_creator(trial):
|
||||
@@ -44,8 +45,8 @@ if __name__ == "__main__":
|
||||
mode="min",
|
||||
num_samples=5,
|
||||
trial_name_creator=trial_str_creator,
|
||||
loggers=[TestLogger],
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
callbacks=[TestLoggerCallback()],
|
||||
stop={"training_iteration": 1 if args.smoke_test else 100},
|
||||
config={
|
||||
"steps": 100,
|
||||
"width": tune.randint(10, 100),
|
||||
|
||||
@@ -7,9 +7,9 @@ import wandb
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.integration.wandb import WandbLogger, WandbTrainableMixin, \
|
||||
from ray.tune.integration.wandb import WandbLoggerCallback, \
|
||||
WandbTrainableMixin, \
|
||||
wandb_mixin
|
||||
from ray.tune.logger import DEFAULT_LOGGERS
|
||||
|
||||
|
||||
def train_function(config, checkpoint_dir=None):
|
||||
@@ -19,20 +19,19 @@ def train_function(config, checkpoint_dir=None):
|
||||
|
||||
|
||||
def tune_function(api_key_file):
|
||||
"""Example for using a WandbLogger with the function API"""
|
||||
"""Example for using a WandbLoggerCallback with the function API"""
|
||||
analysis = tune.run(
|
||||
train_function,
|
||||
metric="loss",
|
||||
mode="min",
|
||||
config={
|
||||
"mean": tune.grid_search([1, 2, 3, 4, 5]),
|
||||
"sd": tune.uniform(0.2, 0.8),
|
||||
"wandb": {
|
||||
"api_key_file": api_key_file,
|
||||
"project": "Wandb_example"
|
||||
}
|
||||
"sd": tune.uniform(0.2, 0.8)
|
||||
},
|
||||
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
|
||||
callbacks=[
|
||||
WandbLoggerCallback(
|
||||
api_key_file=api_key_file, project="Wandb_example")
|
||||
])
|
||||
return analysis.best_config
|
||||
|
||||
|
||||
@@ -95,7 +94,7 @@ if __name__ == "__main__":
|
||||
api_key_file = "~/.wandb_api_key"
|
||||
|
||||
if args.mock_api:
|
||||
WandbLogger._logger_process_cls = MagicMock
|
||||
WandbLoggerCallback._logger_process_cls = MagicMock
|
||||
decorated_train_function.__mixins__ = tuple()
|
||||
WandbTrainable._wandb = MagicMock()
|
||||
wandb = MagicMock() # noqa: F811
|
||||
|
||||
@@ -129,8 +129,8 @@ class Experiment:
|
||||
# if they instantiate their `Experiment` objects themselves.
|
||||
raise ValueError(
|
||||
"Passing `loggers` to an `Experiment` is deprecated. Use "
|
||||
"an `ExperimentLogger` callback instead, e.g. by passing the "
|
||||
"`Logger` classes to `tune.logger.LegacyExperimentLogger` and "
|
||||
"an `LoggerCallback` callback instead, e.g. by passing the "
|
||||
"`Logger` classes to `tune.logger.LegacyLoggerCallback` and "
|
||||
"passing this as part of the `callback` parameter to "
|
||||
"`tune.run()`.")
|
||||
|
||||
|
||||
@@ -2,14 +2,15 @@ import os
|
||||
import pickle
|
||||
from multiprocessing import Process, Queue
|
||||
from numbers import Number
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
from ray import logger
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.function_runner import FunctionRunner
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.logger import LoggerCallback, Logger
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -30,7 +31,7 @@ def _is_allowed_type(obj):
|
||||
return isinstance(obj, Number)
|
||||
|
||||
|
||||
def _clean_log(obj):
|
||||
def _clean_log(obj: Any):
|
||||
# Fixes https://github.com/ray-project/ray/issues/10631
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_log(v) for k, v in obj.items()}
|
||||
@@ -142,12 +143,10 @@ def wandb_mixin(func: Callable):
|
||||
return func
|
||||
|
||||
|
||||
def _set_api_key(wandb_config: Dict):
|
||||
def _set_api_key(api_key_file: Optional[str] = None,
|
||||
api_key: Optional[str] = None):
|
||||
"""Set WandB API key from either `api_key_file` or `api_key`.
|
||||
Raises a ValueError if both are set or if no API key can be found."""
|
||||
api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", ""))
|
||||
api_key = wandb_config.pop("api_key", None)
|
||||
|
||||
if api_key_file:
|
||||
if api_key:
|
||||
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
|
||||
@@ -166,7 +165,8 @@ def _set_api_key(wandb_config: Dict):
|
||||
pass
|
||||
raise ValueError(
|
||||
"No WandB API key found. Either set the {} environment "
|
||||
"variable, pass `api_key` or `api_key_file` in the config, "
|
||||
"variable, pass `api_key` or `api_key_file` to the"
|
||||
"`WandbLoggerCallback` class as arguments, "
|
||||
"or run `wandb login` from the command line".format(WANDB_ENV_VAR))
|
||||
|
||||
|
||||
@@ -219,9 +219,166 @@ class _WandbLoggingProcess(Process):
|
||||
return log, config_update
|
||||
|
||||
|
||||
class WandbLoggerCallback(LoggerCallback):
    """WandbLoggerCallback

    Weights and biases (https://www.wandb.com/) is a tool for experiment
    tracking, model optimization, and dataset versioning. This Ray Tune
    ``LoggerCallback`` sends metrics to Wandb for automatic tracking and
    visualization.

    Args:
        project (str): Name of the Wandb project. Mandatory.
        group (str): Name of the Wandb group. Defaults to the trainable
            name.
        api_key_file (str): Path to file containing the Wandb API KEY. This
            file only needs to be present on the node running the Tune script
            if using the WandbLoggerCallback.
        api_key (str): Wandb API Key. Alternative to setting ``api_key_file``.
        excludes (list): List of metrics that should be excluded from
            the log.
        log_config (bool): Boolean indicating if the ``config`` parameter of
            the ``results`` dict should be logged. This makes sense if
            parameters will change during training, e.g. with
            PopulationBasedTraining. Defaults to False.
        **kwargs: The keyword arguments will be passed to ``wandb.init()``.
            They take precedence over the values selected by Tune.

    Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
    by Tune, but can be overwritten by passing the respective arguments.

    Please see here for all other valid configuration settings:
    https://docs.wandb.com/library/init

    Example:

    .. code-block:: python

        from ray.tune.integration.wandb import WandbLoggerCallback
        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[WandbLoggerCallback(
                project="Optimization_Project",
                api_key_file="/path/to/file",
                log_config=True)])

    """

    # Do not log these result keys
    _exclude_results = ["done", "should_checkpoint"]

    # Use these result keys to update `wandb.config`
    _config_results = [
        "trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname",
        "pid", "date"
    ]

    # Process class used per trial; overridable (e.g. by tests/mocks).
    _logger_process_cls = _WandbLoggingProcess

    def __init__(self,
                 project: str,
                 group: Optional[str] = None,
                 api_key_file: Optional[str] = None,
                 api_key: Optional[str] = None,
                 excludes: Optional[List[str]] = None,
                 log_config: bool = False,
                 **kwargs):
        self.project = project
        self.group = group
        self.api_key_file = os.path.expanduser(
            api_key_file) if api_key_file else None
        self.api_key = api_key
        self.excludes = excludes or []
        self.log_config = log_config
        self.kwargs = kwargs

        # One logging process and one feeding queue per running trial.
        self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {}
        self._trial_queues: Dict["Trial", Queue] = {}

        # Raises a ValueError if no (or an ambiguous) API key is available.
        _set_api_key(self.api_key_file, self.api_key)

    def log_trial_start(self, trial: "Trial"):
        """Start a Wandb logging process for ``trial``.

        Args:
            trial (Trial): Trial object.
        """
        config = trial.config.copy()

        config.pop("callbacks", None)  # Remove callbacks

        # Copy, so the class-level default list is never mutated.
        exclude_results = self._exclude_results.copy()

        # Additional excludes
        exclude_results += self.excludes

        # Log config keys on each result?
        if not self.log_config:
            exclude_results += ["config"]

        # Fill trial ID and name
        trial_id = trial.trial_id if trial else None
        trial_name = str(trial) if trial else None

        # Project name for Wandb
        wandb_project = self.project

        # Grouping: an explicitly configured group always wins; only the
        # fallback depends on the trial.
        wandb_group = self.group or (trial.trainable_name if trial else None)

        # remove unpickleable items!
        config = _clean_log(config)

        wandb_init_kwargs = dict(
            id=trial_id,
            name=trial_name,
            resume=True,
            reinit=True,
            allow_val_change=True,
            group=wandb_group,
            project=wandb_project,
            config=config)
        # User-supplied kwargs override the defaults chosen above.
        wandb_init_kwargs.update(self.kwargs)

        self._trial_queues[trial] = Queue()
        self._trial_processes[trial] = self._logger_process_cls(
            queue=self._trial_queues[trial],
            exclude=exclude_results,
            to_config=self._config_results,
            **wandb_init_kwargs)
        self._trial_processes[trial].start()

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Send a cleaned copy of ``result`` to the trial's logging process.

        Args:
            iteration (int): Iteration number (unused here).
            trial (Trial): Trial object.
            result (dict): Result dictionary.
        """
        # Lazily start logging for trials we have not seen a start for.
        if trial not in self._trial_processes:
            self.log_trial_start(trial)

        result = _clean_log(result)
        self._trial_queues[trial].put(result)

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Signal the trial's logging process to stop and release it.

        Args:
            trial (Trial): Trial object.
            failed (bool): Whether the trial failed (unused here).
        """
        # Nothing to clean up for trials that never started logging; this
        # mirrors the guard used by the file-based logger callbacks.
        if trial not in self._trial_processes:
            return

        self._trial_queues[trial].put(_WANDB_QUEUE_END)
        self._trial_processes[trial].join(timeout=10)

        del self._trial_queues[trial]
        del self._trial_processes[trial]

    def __del__(self):
        # Iterate over a snapshot of the keys: entries are removed inside the
        # loop, and mutating a dict while iterating over it raises
        # RuntimeError.
        for trial in list(self._trial_processes):
            queue = self._trial_queues.pop(trial, None)
            if queue is not None:
                queue.put(_WANDB_QUEUE_END)
            self._trial_processes.pop(trial).join(timeout=2)
|
||||
|
||||
|
||||
class WandbLogger(Logger):
|
||||
"""WandbLogger
|
||||
|
||||
.. warning::
|
||||
This `Logger` class is deprecated. Use the `WandbLoggerCallback`
|
||||
callback instead.
|
||||
|
||||
Weights and biases (https://www.wandb.com/) is a tool for experiment
|
||||
tracking, model optimization, and dataset versioning. This Ray Tune
|
||||
``Logger`` sends metrics to Wandb for automatic tracking and
|
||||
@@ -300,21 +457,10 @@ class WandbLogger(Logger):
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# Do not log these result keys
|
||||
_exclude_results = ["done", "should_checkpoint"]
|
||||
|
||||
# Use these result keys to update `wandb.config`
|
||||
_config_results = [
|
||||
"trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname",
|
||||
"pid", "date"
|
||||
]
|
||||
|
||||
_logger_process_cls = _WandbLoggingProcess
|
||||
_experiment_logger_cls = WandbLoggerCallback
|
||||
|
||||
def _init(self):
|
||||
config = self.config.copy()
|
||||
|
||||
config.pop("callbacks", None) # Remove callbacks
|
||||
|
||||
try:
|
||||
@@ -329,63 +475,17 @@ class WandbLogger(Logger):
|
||||
"Make sure to include a `wandb` key in your `config` dict "
|
||||
"containing at least a `project` specification.")
|
||||
|
||||
_set_api_key(wandb_config)
|
||||
self._trial_experiment_logger = self._experiment_logger_cls(
|
||||
**wandb_config)
|
||||
|
||||
exclude_results = self._exclude_results.copy()
|
||||
|
||||
# Additional excludes
|
||||
additional_excludes = wandb_config.pop("excludes", [])
|
||||
exclude_results += additional_excludes
|
||||
|
||||
# Log config keys on each result?
|
||||
log_config = wandb_config.pop("log_config", False)
|
||||
if not log_config:
|
||||
exclude_results += ["config"]
|
||||
|
||||
# Fill trial ID and name
|
||||
trial_id = self.trial.trial_id if self.trial else None
|
||||
trial_name = str(self.trial) if self.trial else None
|
||||
|
||||
# Project name for Wandb
|
||||
try:
|
||||
wandb_project = wandb_config.pop("project")
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"You need to specify a `project` in your wandb `config` dict.")
|
||||
|
||||
# Grouping
|
||||
wandb_group = wandb_config.pop(
|
||||
"group", self.trial.trainable_name if self.trial else None)
|
||||
|
||||
# remove unpickleable items!
|
||||
config = _clean_log(config)
|
||||
|
||||
wandb_init_kwargs = dict(
|
||||
id=trial_id,
|
||||
name=trial_name,
|
||||
resume=True,
|
||||
reinit=True,
|
||||
allow_val_change=True,
|
||||
group=wandb_group,
|
||||
project=wandb_project,
|
||||
config=config)
|
||||
wandb_init_kwargs.update(wandb_config)
|
||||
|
||||
self._queue = Queue()
|
||||
self._wandb = self._logger_process_cls(
|
||||
queue=self._queue,
|
||||
exclude=exclude_results,
|
||||
to_config=self._config_results,
|
||||
**wandb_init_kwargs)
|
||||
self._wandb.start()
|
||||
self._trial_experiment_logger.log_trial_start(self.trial)
|
||||
|
||||
def on_result(self, result: Dict):
|
||||
result = _clean_log(result)
|
||||
self._queue.put(result)
|
||||
self._trial_experiment_logger.log_trial_result(0, self.trial, result)
|
||||
|
||||
def close(self):
|
||||
self._queue.put(_WANDB_QUEUE_END)
|
||||
self._wandb.join(timeout=10)
|
||||
self._trial_experiment_logger.log_trial_end(self.trial, failed=False)
|
||||
del self._trial_experiment_logger
|
||||
|
||||
|
||||
class WandbTrainableMixin:
|
||||
@@ -411,7 +511,11 @@ class WandbTrainableMixin:
|
||||
"Make sure to include a `wandb` key in your `config` dict "
|
||||
"containing at least a `project` specification.")
|
||||
|
||||
_set_api_key(wandb_config)
|
||||
api_key_file = wandb_config.pop("api_key_file", None)
|
||||
if api_key_file:
|
||||
api_key_file = os.path.expanduser(api_key_file)
|
||||
|
||||
_set_api_key(api_key_file, wandb_config.pop("api_key", None))
|
||||
|
||||
# Fill trial ID and name
|
||||
trial_id = self.trial_id
|
||||
|
||||
+285
-8
@@ -5,7 +5,7 @@ import numpy as np
|
||||
import os
|
||||
import yaml
|
||||
|
||||
from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, Type
|
||||
from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, TextIO, Type
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
|
||||
@@ -364,21 +364,60 @@ class UnifiedLogger(Logger):
|
||||
_logger.flush()
|
||||
|
||||
|
||||
class ExperimentLogger(Callback):
|
||||
class LoggerCallback(Callback):
|
||||
"""Base class for experiment-level logger callbacks
|
||||
|
||||
This base class defines a general interface for logging events,
|
||||
like trial starts, restores, ends, checkpoint saves, and receiving
|
||||
trial results.
|
||||
|
||||
Callbacks implementing this interface should make sure that logging
|
||||
utilities are cleaned up properly on trial termination, i.e. when
|
||||
``log_trial_end`` is received. This includes e.g. closing files.
|
||||
"""
|
||||
|
||||
def log_trial_start(self, trial: "Trial"):
|
||||
raise NotImplementedError
|
||||
"""Handle logging when a trial starts.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def log_trial_restore(self, trial: "Trial"):
|
||||
raise NotImplementedError
|
||||
"""Handle logging when a trial restores.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def log_trial_save(self, trial: "Trial"):
|
||||
raise NotImplementedError
|
||||
"""Handle logging when a trial saves a checkpoint.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
|
||||
raise NotImplementedError
|
||||
"""Handle logging when a trial reports a result.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial object.
|
||||
result (dict): Result dictionary.
|
||||
"""
|
||||
pass
|
||||
|
||||
def log_trial_end(self, trial: "Trial", failed: bool = False):
|
||||
raise NotImplementedError
|
||||
"""Handle logging when a trial ends.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial object.
|
||||
failed (bool): True if the Trial failed (e.g. when it raised an
|
||||
exception), False if it finished gracefully.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_trial_result(self, iteration: int, trials: List["Trial"],
|
||||
trial: "Trial", result: Dict, **info):
|
||||
@@ -405,7 +444,7 @@ class ExperimentLogger(Callback):
|
||||
self.log_trial_end(trial, failed=True)
|
||||
|
||||
|
||||
class LegacyExperimentLogger(ExperimentLogger):
|
||||
class LegacyLoggerCallback(LoggerCallback):
|
||||
"""Supports logging to trial-specific `Logger` classes.
|
||||
|
||||
Previously, Ray Tune logging was handled via `Logger` classes that have
|
||||
@@ -455,6 +494,244 @@ class LegacyExperimentLogger(ExperimentLogger):
|
||||
trial_loggers[trial].close()
|
||||
|
||||
|
||||
class JsonLoggerCallback(LoggerCallback):
    """Logs trial results in json format.

    Also writes to a results file and param.json file when results or
    configurations are updated. Experiments must be executed with the
    JsonLoggerCallback to be compatible with the ExperimentAnalysis tool.
    """

    def __init__(self):
        # Last config written per trial, and the open result file per trial.
        self._trial_configs: Dict["Trial", Dict] = {}
        self._trial_files: Dict["Trial", TextIO] = {}

    def log_trial_start(self, trial: "Trial"):
        """Write the trial config and (re)open the trial's result file.

        Args:
            trial (Trial): Trial object.
        """
        # A restarted trial gets a fresh handle; close the stale one first.
        if trial in self._trial_files:
            self._trial_files[trial].close()

        # Make sure logdir exists *before* anything is written into it
        # (``update_config`` below creates files inside ``trial.logdir``).
        trial.init_logdir()

        # Update config
        self.update_config(trial, trial.config)

        local_file = os.path.join(trial.logdir, EXPR_RESULT_FILE)
        # Append mode, so results of a resumed trial are kept.
        self._trial_files[trial] = open(local_file, "at")

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Append ``result`` as one JSON line to the trial's result file.

        Args:
            iteration (int): Iteration number (unused here).
            trial (Trial): Trial object.
            result (dict): Result dictionary.
        """
        if trial not in self._trial_files:
            self.log_trial_start(trial)

        result_file = self._trial_files[trial]
        json.dump(result, result_file, cls=SafeFallbackEncoder)
        result_file.write("\n")
        result_file.flush()

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Close the trial's result file, if one is open.

        Args:
            trial (Trial): Trial object.
            failed (bool): Whether the trial failed (unused here).
        """
        if trial not in self._trial_files:
            return

        self._trial_files[trial].close()
        del self._trial_files[trial]

    def update_config(self, trial: "Trial", config: Dict):
        """Store ``config`` for ``trial`` and dump it as JSON and pickle.

        Args:
            trial (Trial): Trial object; files are written to its ``logdir``.
            config (dict): Configuration to persist.
        """
        self._trial_configs[trial] = config

        config_out = os.path.join(trial.logdir, EXPR_PARAM_FILE)
        with open(config_out, "w") as f:
            json.dump(
                self._trial_configs[trial],
                f,
                indent=2,
                sort_keys=True,
                cls=SafeFallbackEncoder)

        config_pkl = os.path.join(trial.logdir, EXPR_PARAM_PICKLE_FILE)
        with open(config_pkl, "wb") as f:
            cloudpickle.dump(self._trial_configs[trial], f)
|
||||
|
||||
|
||||
class CSVLoggerCallback(LoggerCallback):
    """Logs results to progress.csv under the trial directory.

    Automatically flattens nested dicts in the result dict before writing
    to csv:

        {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}

    """

    def __init__(self):
        # Whether the progress file already had content when it was opened
        # (in which case no header row is written again).
        self._trial_continue: Dict["Trial", bool] = {}
        self._trial_files: Dict["Trial", TextIO] = {}
        # Lazily created on the first result, once the columns are known.
        self._trial_csv: Dict["Trial", csv.DictWriter] = {}

    def log_trial_start(self, trial: "Trial"):
        """(Re)open the trial's progress file in append mode.

        Args:
            trial (Trial): Trial object.
        """
        if trial in self._trial_files:
            self._trial_files[trial].close()

        # Make sure logdir exists
        trial.init_logdir()
        local_file = os.path.join(trial.logdir, EXPR_PROGRESS_FILE)
        # Only continue an existing file if it actually has content; an
        # empty file still needs its header row.
        self._trial_continue[trial] = (os.path.exists(local_file)
                                       and os.path.getsize(local_file) > 0)
        self._trial_files[trial] = open(local_file, "at")
        self._trial_csv[trial] = None

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Append the flattened ``result`` (without config) as a CSV row.

        Args:
            iteration (int): Iteration number (unused here).
            trial (Trial): Trial object.
            result (dict): Result dictionary.
        """
        if trial not in self._trial_files:
            self.log_trial_start(trial)

        tmp = result.copy()
        tmp.pop("config", None)
        result = flatten_dict(tmp, delimiter="/")

        if self._trial_csv[trial] is None:
            # The first result fixes the set of columns for this file.
            self._trial_csv[trial] = csv.DictWriter(self._trial_files[trial],
                                                    result.keys())
            if not self._trial_continue[trial]:
                self._trial_csv[trial].writeheader()

        writer = self._trial_csv[trial]
        # Keys that were not present in the first result are dropped.
        writer.writerow(
            {k: v
             for k, v in result.items() if k in writer.fieldnames})
        self._trial_files[trial].flush()

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Close the trial's progress file and drop its per-trial state.

        Args:
            trial (Trial): Trial object.
            failed (bool): Whether the trial failed (unused here).
        """
        if trial not in self._trial_files:
            return

        del self._trial_csv[trial]
        self._trial_files[trial].close()
        del self._trial_files[trial]
        # Recomputed on the next start; release the reference to the trial.
        self._trial_continue.pop(trial, None)
|
||||
|
||||
|
||||
class TBXLoggerCallback(LoggerCallback):
    """TensorBoardX Logger.

    Note that hparams will be written only after a trial has terminated.
    This logger automatically flattens nested dicts to show on TensorBoard:

        {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
    """

    # NoneType is not supported on the last TBX release yet.
    # ``np.bool_`` is used instead of the alias ``np.bool8``, which was
    # removed in NumPy 2.0.
    VALID_HPARAMS = (str, bool, np.bool_, int, np.integer, float, list)

    def __init__(self):
        try:
            from tensorboardX import SummaryWriter
            self._summary_writer_cls = SummaryWriter
        except ImportError:
            if log_once("tbx-install"):
                logger.info(
                    "pip install 'ray[tune]' to see TensorBoard files.")
            raise
        # One summary writer and the last valid result per running trial.
        self._trial_writer: Dict["Trial", SummaryWriter] = {}
        self._trial_result: Dict["Trial", Dict] = {}

    def log_trial_start(self, trial: "Trial"):
        """Create a summary writer for ``trial`` in its log directory.

        Args:
            trial (Trial): Trial object.
        """
        # A restarted trial gets a fresh writer; close the stale one first,
        # like the file-based logger callbacks do with their files.
        if trial in self._trial_writer:
            self._trial_writer[trial].close()

        trial.init_logdir()
        self._trial_writer[trial] = self._summary_writer_cls(
            trial.logdir, flush_secs=30)
        self._trial_result[trial] = {}

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Write scalars, videos and histograms from ``result``.

        Args:
            iteration (int): Iteration number (unused here).
            trial (Trial): Trial object.
            result (dict): Result dictionary.
        """
        if trial not in self._trial_writer:
            self.log_trial_start(trial)

        writer = self._trial_writer[trial]
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        for k in [
                "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
        ]:
            if k in tmp:
                del tmp[k]  # not useful to log these

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        valid_result = {}

        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if (isinstance(value, tuple(VALID_SUMMARY_TYPES))
                    and not np.isnan(value)):
                valid_result[full_attr] = value
                writer.add_scalar(full_attr, value, global_step=step)
            elif ((isinstance(value, list) and len(value) > 0)
                  or (isinstance(value, np.ndarray) and value.size > 0)):
                valid_result[full_attr] = value

                # Must be video
                if isinstance(value, np.ndarray) and value.ndim == 5:
                    writer.add_video(
                        full_attr, value, global_step=step, fps=20)
                    continue

                try:
                    writer.add_histogram(full_attr, value, global_step=step)
                # In case TensorboardX still doesn't think it's a valid value
                # (e.g. `[[]]`), warn and move on.
                except (ValueError, TypeError):
                    if log_once("invalid_tbx_value"):
                        logger.warning(
                            "You are trying to log an invalid value ({}={}) "
                            "via {}!".format(full_attr, value,
                                             type(self).__name__))

        # Remembered so hparams can be logged against it when the trial ends.
        self._trial_result[trial] = valid_result
        writer.flush()

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Log hparams for the last result, then close the trial's writer.

        Args:
            trial (Trial): Trial object.
            failed (bool): Whether the trial failed (unused here).
        """
        if trial in self._trial_writer:
            if trial and trial.evaluated_params and self._trial_result[trial]:
                flat_result = flatten_dict(
                    self._trial_result[trial], delimiter="/")
                scrubbed_result = {
                    k: value
                    for k, value in flat_result.items()
                    if isinstance(value, tuple(VALID_SUMMARY_TYPES))
                }
                self._try_log_hparams(trial, scrubbed_result)
            self._trial_writer[trial].close()
            del self._trial_writer[trial]
            del self._trial_result[trial]

    def _try_log_hparams(self, trial: "Trial", result: Dict):
        """Best-effort logging of the trial's hyperparameters with ``result``.

        Failures inside TensorboardX are logged, not raised.
        """
        # TBX currently errors if the hparams value is None.
        flat_params = flatten_dict(trial.evaluated_params)
        scrubbed_params = {
            k: v
            for k, v in flat_params.items()
            if isinstance(v, self.VALID_HPARAMS)
        }

        removed = {
            k: v
            for k, v in flat_params.items()
            if not isinstance(v, self.VALID_HPARAMS)
        }
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to tensorboard: %s", str(removed))

        from tensorboardX.summary import hparams
        try:
            experiment_tag, session_start_tag, session_end_tag = hparams(
                hparam_dict=scrubbed_params, metric_dict=result)
            file_writer = self._trial_writer[trial].file_writer
            file_writer.add_summary(experiment_tag)
            file_writer.add_summary(session_start_tag)
            file_writer.add_summary(session_end_tag)
        except Exception:
            logger.exception("TensorboardX failed to log hparams. "
                             "This may be due to an unsupported type "
                             "in the hyperparameter values.")
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
result = result.copy()
|
||||
result.update(config=None) # drop config from pretty print
|
||||
|
||||
@@ -8,15 +8,22 @@ import numpy as np
|
||||
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.integration.wandb import _WandbLoggingProcess, \
|
||||
from ray.tune.integration.wandb import WandbLoggerCallback, \
|
||||
_WandbLoggingProcess, \
|
||||
_WANDB_QUEUE_END, WandbLogger, WANDB_ENV_VAR, WandbTrainableMixin, \
|
||||
wandb_mixin
|
||||
from ray.tune.result import TRIAL_INFO
|
||||
from ray.tune.trial import TrialInfo
|
||||
|
||||
Trial = namedtuple("MockTrial",
|
||||
["config", "trial_id", "trial_name", "trainable_name"])
|
||||
Trial.__str__ = lambda t: t.trial_name
|
||||
|
||||
class Trial(
|
||||
namedtuple("MockTrial",
|
||||
["config", "trial_id", "trial_name", "trainable_name"])):
|
||||
def __hash__(self):
|
||||
return hash(self.trial_id)
|
||||
|
||||
def __str__(self):
|
||||
return self.trial_name
|
||||
|
||||
|
||||
class _MockWandbLoggingProcess(_WandbLoggingProcess):
|
||||
@@ -37,9 +44,21 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess):
|
||||
self.logs.put(log)
|
||||
|
||||
|
||||
class WandbTestLogger(WandbLogger):
|
||||
class WandbTestExperimentLogger(WandbLoggerCallback):
|
||||
_logger_process_cls = _MockWandbLoggingProcess
|
||||
|
||||
@property
|
||||
def trial_processes(self):
|
||||
return self._trial_processes
|
||||
|
||||
|
||||
class WandbTestLogger(WandbLogger):
|
||||
_experiment_logger_cls = WandbTestExperimentLogger
|
||||
|
||||
@property
|
||||
def trial_process(self):
|
||||
return self._trial_experiment_logger.trial_processes[self.trial]
|
||||
|
||||
|
||||
class _MockWandbAPI(object):
|
||||
def init(self, *args, **kwargs):
|
||||
@@ -63,7 +82,7 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def testWandbLoggerConfig(self):
|
||||
def testWandbLegacyLoggerConfig(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(trial_config, 0, "trial_0", "trainable")
|
||||
|
||||
@@ -115,11 +134,13 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(logger._wandb.kwargs["project"], "test_project")
|
||||
self.assertEqual(logger._wandb.kwargs["id"], trial.trial_id)
|
||||
self.assertEqual(logger._wandb.kwargs["name"], trial.trial_name)
|
||||
self.assertEqual(logger._wandb.kwargs["group"], trial.trainable_name)
|
||||
self.assertIn("config", logger._wandb._exclude)
|
||||
self.assertEqual(logger.trial_process.kwargs["project"],
|
||||
"test_project")
|
||||
self.assertEqual(logger.trial_process.kwargs["id"], trial.trial_id)
|
||||
self.assertEqual(logger.trial_process.kwargs["name"], trial.trial_name)
|
||||
self.assertEqual(logger.trial_process.kwargs["group"],
|
||||
trial.trainable_name)
|
||||
self.assertIn("config", logger.trial_process._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
@@ -127,8 +148,8 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
trial_config["wandb"] = {"project": "test_project", "log_config": True}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertNotIn("config", logger._wandb._exclude)
|
||||
self.assertNotIn("metric", logger._wandb._exclude)
|
||||
self.assertNotIn("config", logger.trial_process._exclude)
|
||||
self.assertNotIn("metric", logger.trial_process._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
@@ -139,12 +160,12 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertIn("config", logger._wandb._exclude)
|
||||
self.assertIn("metric", logger._wandb._exclude)
|
||||
self.assertIn("config", logger.trial_process._exclude)
|
||||
self.assertIn("metric", logger.trial_process._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
def testWandbLoggerReporting(self):
|
||||
def testWandbLegacyLoggerReporting(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(trial_config, 0, "trial_0", "trainable")
|
||||
|
||||
@@ -166,7 +187,7 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
|
||||
logger.on_result(r1)
|
||||
|
||||
logged = logger._wandb.logs.get(timeout=10)
|
||||
logged = logger.trial_process.logs.get(timeout=10)
|
||||
self.assertIn("metric1", logged)
|
||||
self.assertNotIn("metric2", logged)
|
||||
self.assertIn("metric3", logged)
|
||||
@@ -176,6 +197,106 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||
|
||||
logger.close()
|
||||
|
||||
def testWandbLoggerConfig(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(trial_config, 0, "trial_0", "trainable")
|
||||
|
||||
if WANDB_ENV_VAR in os.environ:
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# No API key
|
||||
with self.assertRaises(ValueError):
|
||||
logger = WandbTestExperimentLogger(project="test_project")
|
||||
|
||||
# API Key in config
|
||||
logger = WandbTestExperimentLogger(
|
||||
project="test_project", api_key="1234")
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
|
||||
|
||||
del logger
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key file
|
||||
with tempfile.NamedTemporaryFile("wt") as fp:
|
||||
fp.write("5678")
|
||||
fp.flush()
|
||||
|
||||
logger = WandbTestExperimentLogger(
|
||||
project="test_project", api_key_file=fp.name)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
|
||||
|
||||
del logger
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key in env
|
||||
os.environ[WANDB_ENV_VAR] = "9012"
|
||||
logger = WandbTestExperimentLogger(project="test_project")
|
||||
del logger
|
||||
|
||||
# From now on, the API key is in the env variable.
|
||||
|
||||
logger = WandbTestExperimentLogger(project="test_project")
|
||||
logger.log_trial_start(trial)
|
||||
|
||||
self.assertEqual(logger.trial_processes[trial].kwargs["project"],
|
||||
"test_project")
|
||||
self.assertEqual(logger.trial_processes[trial].kwargs["id"],
|
||||
trial.trial_id)
|
||||
self.assertEqual(logger.trial_processes[trial].kwargs["name"],
|
||||
trial.trial_name)
|
||||
self.assertEqual(logger.trial_processes[trial].kwargs["group"],
|
||||
trial.trainable_name)
|
||||
self.assertIn("config", logger.trial_processes[trial]._exclude)
|
||||
|
||||
del logger
|
||||
|
||||
# log config.
|
||||
logger = WandbTestExperimentLogger(
|
||||
project="test_project", log_config=True)
|
||||
logger.log_trial_start(trial)
|
||||
self.assertNotIn("config", logger.trial_processes[trial]._exclude)
|
||||
self.assertNotIn("metric", logger.trial_processes[trial]._exclude)
|
||||
|
||||
del logger
|
||||
|
||||
# Exclude metric.
|
||||
logger = WandbTestExperimentLogger(
|
||||
project="test_project", excludes=["metric"])
|
||||
logger.log_trial_start(trial)
|
||||
self.assertIn("config", logger.trial_processes[trial]._exclude)
|
||||
self.assertIn("metric", logger.trial_processes[trial]._exclude)
|
||||
|
||||
del logger
|
||||
|
||||
def testWandbLoggerReporting(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(trial_config, 0, "trial_0", "trainable")
|
||||
|
||||
logger = WandbTestExperimentLogger(
|
||||
project="test_project", api_key="1234", excludes=["metric2"])
|
||||
logger.on_trial_start(0, [], trial)
|
||||
|
||||
r1 = {
|
||||
"metric1": 0.8,
|
||||
"metric2": 1.4,
|
||||
"metric3": np.asarray(32.0),
|
||||
"metric4": np.float32(32.0),
|
||||
"const": "text",
|
||||
"config": trial_config
|
||||
}
|
||||
|
||||
logger.on_trial_result(0, [], trial, r1)
|
||||
|
||||
logged = logger.trial_processes[trial].logs.get(timeout=10)
|
||||
self.assertIn("metric1", logged)
|
||||
self.assertNotIn("metric2", logged)
|
||||
self.assertIn("metric3", logged)
|
||||
self.assertIn("metric4", logged)
|
||||
self.assertNotIn("const", logged)
|
||||
self.assertNotIn("config", logged)
|
||||
|
||||
del logger
|
||||
|
||||
def testWandbMixinConfig(self):
|
||||
config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(config, 0, "trial_0", "trainable")
|
||||
|
||||
@@ -1,12 +1,33 @@
|
||||
import csv
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from collections import namedtuple
|
||||
import unittest
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
from ray.cloudpickle import cloudpickle
|
||||
|
||||
from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger
|
||||
from ray.tune.logger import CSVLoggerCallback, JsonLoggerCallback, \
|
||||
JsonLogger, CSVLogger, \
|
||||
TBXLoggerCallback, TBXLogger
|
||||
from ray.tune.result import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, \
|
||||
EXPR_PROGRESS_FILE, \
|
||||
EXPR_RESULT_FILE
|
||||
|
||||
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
|
||||
|
||||
class Trial(
|
||||
namedtuple("MockTrial", ["evaluated_params", "trial_id", "logdir"])):
|
||||
@property
|
||||
def config(self):
|
||||
return self.evaluated_params
|
||||
|
||||
def init_logdir(self):
|
||||
return
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.trial_id)
|
||||
|
||||
|
||||
def result(t, rew, **kwargs):
|
||||
@@ -28,24 +49,116 @@ class LoggerSuite(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir, ignore_errors=True)
|
||||
|
||||
def testCSV(self):
|
||||
def testLegacyCSV(self):
|
||||
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
|
||||
t = Trial(evaluated_params=config, trial_id="csv")
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="csv", logdir=self.test_dir)
|
||||
logger = CSVLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.on_result(result(2, 5))
|
||||
logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.close()
|
||||
|
||||
self._validate_csv_result()
|
||||
|
||||
def testCSV(self):
|
||||
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="csv", logdir=self.test_dir)
|
||||
logger = CSVLoggerCallback()
|
||||
logger.on_trial_result(0, [], t, result(0, 4))
|
||||
logger.on_trial_result(1, [], t, result(1, 5))
|
||||
logger.on_trial_result(
|
||||
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
|
||||
logger.on_trial_complete(3, [], t)
|
||||
self._validate_csv_result()
|
||||
|
||||
def _validate_csv_result(self):
|
||||
results = []
|
||||
result_file = os.path.join(self.test_dir, EXPR_PROGRESS_FILE)
|
||||
with open(result_file, "rt") as fp:
|
||||
reader = csv.DictReader(fp)
|
||||
for row in reader:
|
||||
results.append(row)
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
self.assertSequenceEqual(
|
||||
[int(row["episode_reward_mean"]) for row in results], [4, 5, 6])
|
||||
|
||||
def testJSONLegacyLogger(self):
|
||||
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="json", logdir=self.test_dir)
|
||||
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(1, 5))
|
||||
logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.close()
|
||||
|
||||
self._validate_json_result(config)
|
||||
|
||||
def testJSON(self):
|
||||
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
|
||||
t = Trial(evaluated_params=config, trial_id="json")
|
||||
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="json", logdir=self.test_dir)
|
||||
logger = JsonLoggerCallback()
|
||||
logger.on_trial_result(0, [], t, result(0, 4))
|
||||
logger.on_trial_result(1, [], t, result(1, 5))
|
||||
logger.on_trial_result(
|
||||
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
|
||||
logger.on_trial_complete(3, [], t)
|
||||
self._validate_json_result(config)
|
||||
|
||||
def _validate_json_result(self, config):
|
||||
# Check result logs
|
||||
results = []
|
||||
result_file = os.path.join(self.test_dir, EXPR_RESULT_FILE)
|
||||
with open(result_file, "rt") as fp:
|
||||
for row in fp.readlines():
|
||||
results.append(json.loads(row))
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
self.assertSequenceEqual(
|
||||
[int(row["episode_reward_mean"]) for row in results], [4, 5, 6])
|
||||
|
||||
# Check json saved config file
|
||||
config_file = os.path.join(self.test_dir, EXPR_PARAM_FILE)
|
||||
with open(config_file, "rt") as fp:
|
||||
loaded_config = json.load(fp)
|
||||
|
||||
self.assertEqual(loaded_config, config)
|
||||
|
||||
# Check pickled config file
|
||||
config_file = os.path.join(self.test_dir, EXPR_PARAM_PICKLE_FILE)
|
||||
with open(config_file, "rb") as fp:
|
||||
loaded_config = cloudpickle.load(fp)
|
||||
|
||||
self.assertEqual(loaded_config, config)
|
||||
|
||||
def testLegacyTBX(self):
|
||||
config = {
|
||||
"a": 2,
|
||||
"b": [1, 2],
|
||||
"c": {
|
||||
"c": {
|
||||
"D": 123
|
||||
}
|
||||
},
|
||||
"d": np.int64(1),
|
||||
"e": np.bool8(True)
|
||||
}
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(1, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.on_result(result(1, 5))
|
||||
logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.close()
|
||||
|
||||
self._validate_tbx_result()
|
||||
|
||||
def testTBX(self):
|
||||
config = {
|
||||
"a": 2,
|
||||
@@ -58,16 +171,40 @@ class LoggerSuite(unittest.TestCase):
|
||||
"d": np.int64(1),
|
||||
"e": np.bool8(True)
|
||||
}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(1, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.close()
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
|
||||
logger = TBXLoggerCallback()
|
||||
logger.on_trial_result(0, [], t, result(0, 4))
|
||||
logger.on_trial_result(1, [], t, result(1, 5))
|
||||
logger.on_trial_result(
|
||||
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
|
||||
def testBadTBX(self):
|
||||
logger.on_trial_complete(3, [], t)
|
||||
|
||||
self._validate_tbx_result()
|
||||
|
||||
def _validate_tbx_result(self):
|
||||
try:
|
||||
from tensorflow.python.summary.summary_iterator \
|
||||
import summary_iterator
|
||||
except ImportError:
|
||||
print("Skipping rest of test as tensorflow is not installed.")
|
||||
return
|
||||
|
||||
events_file = list(glob.glob(f"{self.test_dir}/events*"))[0]
|
||||
results = []
|
||||
for event in summary_iterator(events_file):
|
||||
for v in event.summary.value:
|
||||
if v.tag == "ray/tune/episode_reward_mean":
|
||||
results.append(v.simple_value)
|
||||
|
||||
self.assertEqual(len(results), 3)
|
||||
self.assertSequenceEqual([int(res) for res in results], [4, 5, 6])
|
||||
|
||||
def testLegacyBadTBX(self):
|
||||
config = {"b": (1, 2, 3)}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
@@ -76,7 +213,8 @@ class LoggerSuite(unittest.TestCase):
|
||||
assert "INFO" in cm.output[0]
|
||||
|
||||
config = {"None": None}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
@@ -84,6 +222,31 @@ class LoggerSuite(unittest.TestCase):
|
||||
logger.close()
|
||||
assert "INFO" in cm.output[0]
|
||||
|
||||
def testBadTBX(self):
|
||||
config = {"b": (1, 2, 3)}
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
|
||||
logger = TBXLoggerCallback()
|
||||
logger.on_trial_result(0, [], t, result(0, 4))
|
||||
logger.on_trial_result(1, [], t, result(1, 5))
|
||||
logger.on_trial_result(
|
||||
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
|
||||
logger.on_trial_complete(3, [], t)
|
||||
assert "INFO" in cm.output[0]
|
||||
|
||||
config = {"None": None}
|
||||
t = Trial(
|
||||
evaluated_params=config, trial_id="tbx", logdir=self.test_dir)
|
||||
logger = TBXLoggerCallback()
|
||||
logger.on_trial_result(0, [], t, result(0, 4))
|
||||
logger.on_trial_result(1, [], t, result(1, 5))
|
||||
logger.on_trial_result(
|
||||
2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1}))
|
||||
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
|
||||
logger.on_trial_complete(3, [], t)
|
||||
assert "INFO" in cm.output[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib import _register_all
|
||||
from ray.tune.result import TIMESTEPS_TOTAL
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_trainable, run_experiments
|
||||
from ray.tune.logger import LegacyExperimentLogger, Logger
|
||||
from ray.tune.logger import LegacyLoggerCallback, Logger
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, ExportFormat
|
||||
|
||||
@@ -191,7 +191,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
|
||||
callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])])
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
@@ -204,7 +204,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
})
|
||||
self.assertTrue(
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
[trial] = run_experiments(
|
||||
@@ -216,7 +216,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[])])
|
||||
callbacks=[LegacyLoggerCallback(logger_classes=[])])
|
||||
self.assertFalse(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
@@ -239,7 +239,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])])
|
||||
callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])])
|
||||
self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log")))
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
@@ -264,7 +264,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
},
|
||||
callbacks=[LegacyExperimentLogger(logger_classes=[])])
|
||||
callbacks=[LegacyLoggerCallback(logger_classes=[])])
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(trial.logdir, "params.json")))
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.checkpoint_manager import Checkpoint
|
||||
from ray.tune.logger import DEFAULT_LOGGERS, ExperimentLogger, \
|
||||
LegacyExperimentLogger
|
||||
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, \
|
||||
LegacyLoggerCallback
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.syncer import SyncConfig, SyncerCallback
|
||||
@@ -205,14 +205,14 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||
"delay")
|
||||
|
||||
def testCallbackReordering(self):
|
||||
"""SyncerCallback should come after ExperimentLogger callbacks"""
|
||||
"""SyncerCallback should come after LoggerCallback callbacks"""
|
||||
|
||||
def get_positions(callbacks):
|
||||
first_logger_pos = None
|
||||
last_logger_pos = None
|
||||
syncer_pos = None
|
||||
for i, callback in enumerate(callbacks):
|
||||
if isinstance(callback, ExperimentLogger):
|
||||
if isinstance(callback, LoggerCallback):
|
||||
if first_logger_pos is None:
|
||||
first_logger_pos = i
|
||||
last_logger_pos = i
|
||||
@@ -233,8 +233,8 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||
self.assertLess(last_logger_pos, syncer_pos)
|
||||
|
||||
# Auto creation of loggers with existing logger (but no CSV/JSON)
|
||||
callbacks = create_default_callbacks([ExperimentLogger()],
|
||||
SyncConfig(), None)
|
||||
callbacks = create_default_callbacks([LoggerCallback()], SyncConfig(),
|
||||
None)
|
||||
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
|
||||
callbacks)
|
||||
self.assertLess(last_logger_pos, syncer_pos)
|
||||
@@ -242,13 +242,12 @@ class TrialRunnerCallbacks(unittest.TestCase):
|
||||
# This should throw an error as the syncer comes before the logger
|
||||
with self.assertRaises(ValueError):
|
||||
callbacks = create_default_callbacks(
|
||||
[SyncerCallback(None),
|
||||
ExperimentLogger()], SyncConfig(), None)
|
||||
[SyncerCallback(None), LoggerCallback()], SyncConfig(), None)
|
||||
|
||||
# This should be reordered but preserve the regular callback order
|
||||
[mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
|
||||
# Has to be legacy logger to avoid logger callback creation
|
||||
lc = LegacyExperimentLogger(logger_classes=DEFAULT_LOGGERS)
|
||||
lc = LegacyLoggerCallback(logger_classes=DEFAULT_LOGGERS)
|
||||
callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(),
|
||||
None)
|
||||
print(callbacks)
|
||||
|
||||
@@ -73,7 +73,6 @@ def run(
|
||||
checkpoint_at_end=False,
|
||||
verbose=Verbosity.V3_TRIAL_DETAILS,
|
||||
progress_reporter=None,
|
||||
loggers=None,
|
||||
log_to_file=False,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
@@ -90,6 +89,7 @@ def run(
|
||||
raise_on_failed_trial=True,
|
||||
callbacks=None,
|
||||
# Deprecated args
|
||||
loggers=None,
|
||||
ray_auto_init=None,
|
||||
run_errored_only=None,
|
||||
global_checkpoint_period=None,
|
||||
@@ -196,9 +196,6 @@ def run(
|
||||
intermediate experiment progress. Defaults to CLIReporter if
|
||||
running in command-line, or JupyterNotebookReporter if running in
|
||||
a Jupyter notebook.
|
||||
loggers (list): List of logger creators to be used with
|
||||
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
|
||||
See `ray/tune/logger.py`.
|
||||
log_to_file (bool|str|Sequence): Log stdout and stderr to files in
|
||||
Tune's trial directories. If this is `False` (default), no files
|
||||
are written. If `True`, outputs are written to `trialdir/stdout`
|
||||
@@ -251,7 +248,9 @@ def run(
|
||||
trial (of ERROR state) when the experiments complete.
|
||||
callbacks (list): List of callbacks that will be called at different
|
||||
times in the training loop. Must be instances of the
|
||||
``ray.tune.trial_runner.Callback`` class.
|
||||
``ray.tune.trial_runner.Callback`` class. If not passed,
|
||||
`LoggerCallback` and `SyncerCallback` callbacks are automatically
|
||||
added.
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.progress_reporter import TrialProgressCallback
|
||||
from ray.tune.syncer import SyncConfig
|
||||
from ray.tune.logger import CSVLogger, DEFAULT_LOGGERS, ExperimentLogger, \
|
||||
JsonLogger, LegacyExperimentLogger, Logger
|
||||
from ray.tune.logger import CSVLoggerCallback, CSVLogger, \
|
||||
LoggerCallback, \
|
||||
JsonLoggerCallback, JsonLogger, LegacyLoggerCallback, Logger, \
|
||||
TBXLoggerCallback, TBXLogger
|
||||
from ray.tune.syncer import SyncerCallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
sync_config: SyncConfig,
|
||||
@@ -40,6 +45,7 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
has_syncer_callback = False
|
||||
has_csv_logger = False
|
||||
has_json_logger = False
|
||||
has_tbx_logger = False
|
||||
|
||||
has_trial_progress_callback = any(
|
||||
isinstance(c, TrialProgressCallback) for c in callbacks)
|
||||
@@ -52,20 +58,14 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
last_logger_index = None
|
||||
syncer_index = None
|
||||
|
||||
if not loggers:
|
||||
# If no logger callback and no `loggers` have been provided,
|
||||
# add DEFAULT_LOGGERS.
|
||||
if not any(
|
||||
isinstance(callback, ExperimentLogger)
|
||||
for callback in callbacks):
|
||||
loggers = DEFAULT_LOGGERS
|
||||
|
||||
# Create LegacyExperimentLogger for passed Logger classes
|
||||
# Create LegacyLoggerCallback for passed Logger classes
|
||||
if loggers:
|
||||
# Todo(krfricke): Deprecate `loggers` argument, print warning here.
|
||||
# Add warning as soon as we ported all loggers to LoggerCallback
|
||||
# classes.
|
||||
add_loggers = []
|
||||
for trial_logger in loggers:
|
||||
if isinstance(trial_logger, ExperimentLogger):
|
||||
if isinstance(trial_logger, LoggerCallback):
|
||||
callbacks.append(trial_logger)
|
||||
elif isinstance(trial_logger, type) and issubclass(
|
||||
trial_logger, Logger):
|
||||
@@ -75,31 +75,41 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
f"Invalid value passed to `loggers` argument of "
|
||||
f"`tune.run()`: {trial_logger}")
|
||||
if add_loggers:
|
||||
callbacks.append(LegacyExperimentLogger(add_loggers))
|
||||
callbacks.append(LegacyLoggerCallback(add_loggers))
|
||||
|
||||
# Check if we have a CSV and JSON logger
|
||||
# Check if we have a CSV, JSON and TensorboardX logger
|
||||
for i, callback in enumerate(callbacks):
|
||||
if isinstance(callback, LegacyExperimentLogger):
|
||||
if isinstance(callback, LegacyLoggerCallback):
|
||||
last_logger_index = i
|
||||
if CSVLogger in callback.logger_classes:
|
||||
has_csv_logger = True
|
||||
if JsonLogger in callback.logger_classes:
|
||||
has_json_logger = True
|
||||
# Todo(krfricke): add checks for new ExperimentLogger classes
|
||||
if TBXLogger in callback.logger_classes:
|
||||
has_tbx_logger = True
|
||||
elif isinstance(callback, CSVLoggerCallback):
|
||||
has_csv_logger = True
|
||||
last_logger_index = i
|
||||
elif isinstance(callback, JsonLoggerCallback):
|
||||
has_json_logger = True
|
||||
last_logger_index = i
|
||||
elif isinstance(callback, TBXLoggerCallback):
|
||||
has_tbx_logger = True
|
||||
last_logger_index = i
|
||||
elif isinstance(callback, SyncerCallback):
|
||||
syncer_index = i
|
||||
has_syncer_callback = True
|
||||
|
||||
# If CSV or JSON logger is missing, add
|
||||
# If CSV, JSON or TensorboardX loggers are missing, add
|
||||
if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
|
||||
# Todo(krfricke): Switch to new ExperimentLogger classes
|
||||
add_loggers = []
|
||||
if not has_csv_logger:
|
||||
add_loggers.append(CSVLogger)
|
||||
callbacks.append(CSVLoggerCallback())
|
||||
last_logger_index = len(callbacks) - 1
|
||||
if not has_json_logger:
|
||||
add_loggers.append(JsonLogger)
|
||||
if add_loggers:
|
||||
callbacks.append(LegacyExperimentLogger(add_loggers))
|
||||
callbacks.append(JsonLoggerCallback())
|
||||
last_logger_index = len(callbacks) - 1
|
||||
if not has_tbx_logger:
|
||||
callbacks.append(TBXLoggerCallback())
|
||||
last_logger_index = len(callbacks) - 1
|
||||
|
||||
# If no SyncerCallback was found, add
|
||||
@@ -110,18 +120,18 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
|
||||
callbacks.append(syncer_callback)
|
||||
syncer_index = len(callbacks) - 1
|
||||
|
||||
# Todo(krfricke): Maybe check if syncer comes after all loggers
|
||||
if syncer_index is not None and last_logger_index is not None and \
|
||||
syncer_index < last_logger_index:
|
||||
if (not has_csv_logger or not has_json_logger) and not loggers:
|
||||
if (not has_csv_logger or not has_json_logger or not has_tbx_logger) \
|
||||
and not loggers:
|
||||
# Only raise the warning if the loggers were passed by the user.
|
||||
# (I.e. don't warn if this was automatic behavior and they only
|
||||
# passed a custom SyncerCallback).
|
||||
raise ValueError(
|
||||
"The `SyncerCallback` you passed to `tune.run()` came before "
|
||||
"at least one `ExperimentLogger`. Syncing should be done "
|
||||
"at least one `LoggerCallback`. Syncing should be done "
|
||||
"after writing logs. Please re-order the callbacks so that "
|
||||
"the `SyncerCallback` comes after any `ExperimentLogger`.")
|
||||
"the `SyncerCallback` comes after any `LoggerCallback`.")
|
||||
else:
|
||||
# If these loggers were automatically created, just re-order
|
||||
# the callbacks
|
||||
|
||||
Reference in New Issue
Block a user