diff --git a/doc/source/tune/_tutorials/tune-wandb.rst b/doc/source/tune/_tutorials/tune-wandb.rst index 8cb851f58..4722dd086 100644 --- a/doc/source/tune/_tutorials/tune-wandb.rst +++ b/doc/source/tune/_tutorials/tune-wandb.rst @@ -15,7 +15,7 @@ tools. :target: https://www.wandb.com/ Ray Tune currently offers two lightweight integrations for Weights & Biases. -One is the :ref:`WandbLogger `, which automatically logs +One is the :ref:`WandbLoggerCallback `, which automatically logs metrics reported to Tune to the Wandb API. The other one is the :ref:`@wandb_mixin ` decorator, which can be @@ -28,7 +28,7 @@ Please :doc:`see here for a full example `. .. _tune-wandb-logger: -.. autoclass:: ray.tune.integration.wandb.WandbLogger +.. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback :noindex: .. _tune-wandb-mixin: diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst index 4731b8d51..7eebe4fcb 100644 --- a/doc/source/tune/api_docs/integration.rst +++ b/doc/source/tune/api_docs/integration.rst @@ -77,7 +77,7 @@ Weights and Biases (tune.integration.wandb) :ref:`See also here `. -.. autoclass:: ray.tune.integration.wandb.WandbLogger +.. autoclass:: ray.tune.integration.wandb.WandbLoggerCallback .. autofunction:: ray.tune.integration.wandb.wandb_mixin diff --git a/doc/source/tune/api_docs/logging.rst b/doc/source/tune/api_docs/logging.rst index 14561fb1a..240ec97af 100644 --- a/doc/source/tune/api_docs/logging.rst +++ b/doc/source/tune/api_docs/logging.rst @@ -7,56 +7,61 @@ Tune has default loggers for Tensorboard, CSV, and JSON formats. By default, Tun If you need to log something lower level like model weights or gradients, see :ref:`Trainable Logging `. +.. note:: + Tune's per-trial ``Logger`` classes have been deprecated. They can still be used, but we encourage you + to use our new interface with the ``LoggerCallback`` class instead. 
+ Custom Loggers -------------- -You can create a custom logger by inheriting the Logger interface (:ref:`logger-interface`): +You can create a custom logger by inheriting the LoggerCallback interface (:ref:`logger-interface`): .. code-block:: python - from ray.tune.logger import Logger + from typing import Dict, List - class MLFLowLogger(Logger): - """MLFlow logger. + import json + import os - Requires the experiment configuration to have a MLFlow Experiment ID - or manually set the proper environment variables. - """ + from ray.tune.logger import LoggerCallback - def _init(self): - from mlflow.tracking import MlflowClient - client = MlflowClient() - # self.config is the same config that your Trainable will see. - run = client.create_run(self.config.get("mlflow_experiment_id")) - self._run_id = run.info.run_id - for key, value in self.config.items(): - client.log_param(self._run_id, key, value) - self.client = client + class CustomLoggerCallback(LoggerCallback): + """Custom logger interface""" - def on_result(self, result): - for key, value in result.items(): - if not isinstance(value, float): - continue - self.client.log_metric( - self._run_id, key, value, step=result.get(TRAINING_ITERATION)) + def __init__(self, filename: str = "log.txt"): + self._trial_files = {} + self._filename = filename + + def log_trial_start(self, trial: "Trial"): + trial_logfile = os.path.join(trial.logdir, self._filename) + self._trial_files[trial] = open(trial_logfile, "at") + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial in self._trial_files: + self._trial_files[trial].write(json.dumps(result)) + + def on_trial_complete(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + if trial in self._trial_files: + self._trial_files[trial].close() + del self._trial_files[trial] - def close(self): - self.client.set_terminated(self._run_id) You can then pass in your own logger as follows: .. 
code-block:: python - from ray.tune.logger import DEFAULT_LOGGERS + from ray import tune tune.run( MyTrainableClass, name="experiment_name", - loggers=DEFAULT_LOGGERS + (CustomLogger1, CustomLogger2) + callbacks=[CustomLoggerCallback("log_test.txt")] ) -These loggers will be called along with the default Tune loggers. You can also check out `logger.py `__ for implementation details. +Per default, Ray Tune creates JSON, CSV and TensorboardX logger callbacks if you don't pass them yourself. +You can disable this behavior by setting the ``TUNE_DISABLE_AUTO_CALLBACK_LOGGERS`` environment variable to ``"1"``. An example of creating a custom logger can be found in :doc:`/tune/examples/logging_example`. @@ -131,7 +136,7 @@ In the distributed case, these logs will be sync'ed back to the driver under you Viskit ------ -Tune automatically integrates with `Viskit `_ via the ``CSVLogger`` outputs. To use VisKit (you may have to install some dependencies), run: +Tune automatically integrates with `Viskit `_ via the ``CSVLoggerCallback`` outputs. To use VisKit (you may have to install some dependencies), run: .. code-block:: bash @@ -143,25 +148,20 @@ The nonrelevant metrics (like timing stats) can be disabled on the left to show .. image:: /ray-tune-viskit.png -UnifiedLogger -------------- - -.. autoclass:: ray.tune.logger.UnifiedLogger - TBXLogger --------- -.. autoclass:: ray.tune.logger.TBXLogger +.. autoclass:: ray.tune.logger.TBXLoggerCallback JsonLogger ---------- -.. autoclass:: ray.tune.logger.JsonLogger +.. autoclass:: ray.tune.logger.JsonLoggerCallback CSVLogger --------- -.. autoclass:: ray.tune.logger.CSVLogger +.. autoclass:: ray.tune.logger.CSVLoggerCallback MLFLowLogger ------------ @@ -173,7 +173,8 @@ Tune also provides a default logger for `MLFlow `_. You can .. _logger-interface: -Logger ------- +LoggerCallback +-------------- -.. autoclass:: ray.tune.logger.Logger +.. 
autoclass:: ray.tune.logger.LoggerCallback + :members: log_trial_start, log_trial_restore, log_trial_save, log_trial_result, log_trial_end diff --git a/python/ray/tune/examples/logging_example.py b/python/ray/tune/examples/logging_example.py index 53fede738..318331c96 100755 --- a/python/ray/tune/examples/logging_example.py +++ b/python/ray/tune/examples/logging_example.py @@ -4,11 +4,12 @@ import argparse import time from ray import tune +from ray.tune.logger import LoggerCallback -class TestLogger(tune.logger.Logger): - def on_result(self, result): - print("TestLogger", result) +class TestLoggerCallback(LoggerCallback): + def on_trial_result(self, iteration, trials, trial, result, **info): + print(f"TestLogger for trial {trial}: {result}") def trial_str_creator(trial): @@ -44,8 +45,8 @@ if __name__ == "__main__": mode="min", num_samples=5, trial_name_creator=trial_str_creator, - loggers=[TestLogger], - stop={"training_iteration": 1 if args.smoke_test else 99999}, + callbacks=[TestLoggerCallback()], + stop={"training_iteration": 1 if args.smoke_test else 100}, config={ "steps": 100, "width": tune.randint(10, 100), diff --git a/python/ray/tune/examples/wandb_example.py b/python/ray/tune/examples/wandb_example.py index acd89f9d8..2a7b0a437 100644 --- a/python/ray/tune/examples/wandb_example.py +++ b/python/ray/tune/examples/wandb_example.py @@ -7,9 +7,9 @@ import wandb from ray import tune from ray.tune import Trainable -from ray.tune.integration.wandb import WandbLogger, WandbTrainableMixin, \ +from ray.tune.integration.wandb import WandbLoggerCallback, \ + WandbTrainableMixin, \ wandb_mixin -from ray.tune.logger import DEFAULT_LOGGERS def train_function(config, checkpoint_dir=None): @@ -19,20 +19,19 @@ def train_function(config, checkpoint_dir=None): def tune_function(api_key_file): - """Example for using a WandbLogger with the function API""" + """Example for using a WandbLoggerCallback with the function API""" analysis = tune.run( train_function, metric="loss", 
mode="min", config={ "mean": tune.grid_search([1, 2, 3, 4, 5]), - "sd": tune.uniform(0.2, 0.8), - "wandb": { - "api_key_file": api_key_file, - "project": "Wandb_example" - } + "sd": tune.uniform(0.2, 0.8) }, - loggers=DEFAULT_LOGGERS + (WandbLogger, )) + callbacks=[ + WandbLoggerCallback( + api_key_file=api_key_file, project="Wandb_example") + ]) return analysis.best_config @@ -95,7 +94,7 @@ if __name__ == "__main__": api_key_file = "~/.wandb_api_key" if args.mock_api: - WandbLogger._logger_process_cls = MagicMock + WandbLoggerCallback._logger_process_cls = MagicMock decorated_train_function.__mixins__ = tuple() WandbTrainable._wandb = MagicMock() wandb = MagicMock() # noqa: F811 diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 76c4b9c03..98ae6e640 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -129,8 +129,8 @@ class Experiment: # if they instantiate their `Experiment` objects themselves. raise ValueError( "Passing `loggers` to an `Experiment` is deprecated. Use " - "an `ExperimentLogger` callback instead, e.g. by passing the " - "`Logger` classes to `tune.logger.LegacyExperimentLogger` and " + "an `LoggerCallback` callback instead, e.g. 
by passing the " + "`Logger` classes to `tune.logger.LegacyLoggerCallback` and " "passing this as part of the `callback` parameter to " "`tune.run()`.") diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 3c2efacf4..82327fffc 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -2,14 +2,15 @@ import os import pickle from multiprocessing import Process, Queue from numbers import Number -from typing import Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np from ray import logger from ray.tune import Trainable from ray.tune.function_runner import FunctionRunner -from ray.tune.logger import Logger +from ray.tune.logger import LoggerCallback, Logger from ray.tune.utils import flatten_dict +from ray.tune.trial import Trial import yaml @@ -30,7 +31,7 @@ def _is_allowed_type(obj): return isinstance(obj, Number) -def _clean_log(obj): +def _clean_log(obj: Any): # Fixes https://github.com/ray-project/ray/issues/10631 if isinstance(obj, dict): return {k: _clean_log(v) for k, v in obj.items()} @@ -142,12 +143,10 @@ def wandb_mixin(func: Callable): return func -def _set_api_key(wandb_config: Dict): +def _set_api_key(api_key_file: Optional[str] = None, + api_key: Optional[str] = None): """Set WandB API key from `wandb_config`. Will pop the `api_key_file` and `api_key` keys from `wandb_config` parameter""" - api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", "")) - api_key = wandb_config.pop("api_key", None) - if api_key_file: if api_key: raise ValueError("Both WandB `api_key_file` and `api_key` set.") @@ -166,7 +165,8 @@ def _set_api_key(wandb_config: Dict): pass raise ValueError( "No WandB API key found. 
Either set the {} environment " - "variable, pass `api_key` or `api_key_file` in the config, " + "variable, pass `api_key` or `api_key_file` to the" + "`WandbLoggerCallback` class as arguments, " "or run `wandb login` from the command line".format(WANDB_ENV_VAR)) @@ -219,9 +219,166 @@ class _WandbLoggingProcess(Process): return log, config_update +class WandbLoggerCallback(LoggerCallback): + """WandbLoggerCallback + + Weights and biases (https://www.wandb.com/) is a tool for experiment + tracking, model optimization, and dataset versioning. This Ray Tune + ``LoggerCallback`` sends metrics to Wandb for automatic tracking and + visualization. + + Args: + project (str): Name of the Wandb project. Mandatory. + group (str): Name of the Wandb group. Defaults to the trainable + name. + api_key_file (str): Path to file containing the Wandb API KEY. This + file only needs to be present on the node running the Tune script + if using the WandbLoggerCallback. + api_key (str): Wandb API Key. Alternative to setting ``api_key_file``. + excludes (list): List of metrics that should be excluded from + the log. + log_config (bool): Boolean indicating if the ``config`` parameter of + the ``results`` dict should be logged. This makes sense if + parameters will change during training, e.g. with + PopulationBasedTraining. Defaults to False. + **kwargs: The keyword arguments will be passed to ``wandb.init()``. + + Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected + by Tune, but can be overwritten by filling out the respective configuration + values. + + Please see here for all other valid configuration settings: + https://docs.wandb.com/library/init + + Example: + + .. 
code-block:: python + + from ray import tune + from ray.tune.integration.wandb import WandbLoggerCallback + tune.run( + train_fn, + config={ + # define search space here + "parameter_1": tune.choice([1, 2, 3]), + "parameter_2": tune.choice([4, 5, 6]), + }, + callbacks=[WandbLoggerCallback( + project="Optimization_Project", + api_key_file="/path/to/file", + log_config=True)]) + + """ + + # Do not log these result keys + _exclude_results = ["done", "should_checkpoint"] + + # Use these result keys to update `wandb.config` + _config_results = [ + "trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname", + "pid", "date" + ] + + _logger_process_cls = _WandbLoggingProcess + + def __init__(self, + project: str, + group: Optional[str] = None, + api_key_file: Optional[str] = None, + api_key: Optional[str] = None, + excludes: Optional[List[str]] = None, + log_config: bool = False, + **kwargs): + self.project = project + self.group = group + self.api_key_file = os.path.expanduser( + api_key_file) if api_key_file else None + self.api_key = api_key + self.excludes = excludes or [] + self.log_config = log_config + self.kwargs = kwargs + + self._trial_processes: Dict["Trial", _WandbLoggingProcess] = {} + self._trial_queues: Dict["Trial", Queue] = {} + + _set_api_key(self.api_key_file, self.api_key) + + def log_trial_start(self, trial: "Trial"): + config = trial.config.copy() + + config.pop("callbacks", None) # Remove callbacks + + exclude_results = self._exclude_results.copy() + + # Additional excludes + exclude_results += self.excludes + + # Log config keys on each result? + if not self.log_config: + exclude_results += ["config"] + + # Fill trial ID and name + trial_id = trial.trial_id if trial else None + trial_name = str(trial) if trial else None + + # Project name for Wandb + wandb_project = self.project + + # Grouping + wandb_group = self.group or trial.trainable_name if trial else None + + # remove unpickleable items! 
+ config = _clean_log(config) + + wandb_init_kwargs = dict( + id=trial_id, + name=trial_name, + resume=True, + reinit=True, + allow_val_change=True, + group=wandb_group, + project=wandb_project, + config=config) + wandb_init_kwargs.update(self.kwargs) + + self._trial_queues[trial] = Queue() + self._trial_processes[trial] = self._logger_process_cls( + queue=self._trial_queues[trial], + exclude=exclude_results, + to_config=self._config_results, + **wandb_init_kwargs) + self._trial_processes[trial].start() + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_processes: + self.log_trial_start(trial) + + result = _clean_log(result) + self._trial_queues[trial].put(result) + + def log_trial_end(self, trial: "Trial", failed: bool = False): + self._trial_queues[trial].put(_WANDB_QUEUE_END) + self._trial_processes[trial].join(timeout=10) + + del self._trial_queues[trial] + del self._trial_processes[trial] + + def __del__(self): + for trial in self._trial_processes: + if trial in self._trial_queues: + self._trial_queues[trial].put(_WANDB_QUEUE_END) + del self._trial_queues[trial] + self._trial_processes[trial].join(timeout=2) + del self._trial_processes[trial] + + class WandbLogger(Logger): """WandbLogger + .. warning:: + This `Logger` class is deprecated. Use the `WandbLoggerCallback` + callback instead. + Weights and biases (https://www.wandb.com/) is a tool for experiment tracking, model optimization, and dataset versioning. 
This Ray Tune ``Logger`` sends metrics to Wandb for automatic tracking and @@ -300,21 +457,10 @@ class WandbLogger(Logger): """ - - # Do not log these result keys - _exclude_results = ["done", "should_checkpoint"] - - # Use these result keys to update `wandb.config` - _config_results = [ - "trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname", - "pid", "date" - ] - - _logger_process_cls = _WandbLoggingProcess + _experiment_logger_cls = WandbLoggerCallback def _init(self): config = self.config.copy() - config.pop("callbacks", None) # Remove callbacks try: @@ -329,63 +475,17 @@ class WandbLogger(Logger): "Make sure to include a `wandb` key in your `config` dict " "containing at least a `project` specification.") - _set_api_key(wandb_config) + self._trial_experiment_logger = self._experiment_logger_cls( + **wandb_config) - exclude_results = self._exclude_results.copy() - - # Additional excludes - additional_excludes = wandb_config.pop("excludes", []) - exclude_results += additional_excludes - - # Log config keys on each result? - log_config = wandb_config.pop("log_config", False) - if not log_config: - exclude_results += ["config"] - - # Fill trial ID and name - trial_id = self.trial.trial_id if self.trial else None - trial_name = str(self.trial) if self.trial else None - - # Project name for Wandb - try: - wandb_project = wandb_config.pop("project") - except KeyError: - raise ValueError( - "You need to specify a `project` in your wandb `config` dict.") - - # Grouping - wandb_group = wandb_config.pop( - "group", self.trial.trainable_name if self.trial else None) - - # remove unpickleable items! 
- config = _clean_log(config) - - wandb_init_kwargs = dict( - id=trial_id, - name=trial_name, - resume=True, - reinit=True, - allow_val_change=True, - group=wandb_group, - project=wandb_project, - config=config) - wandb_init_kwargs.update(wandb_config) - - self._queue = Queue() - self._wandb = self._logger_process_cls( - queue=self._queue, - exclude=exclude_results, - to_config=self._config_results, - **wandb_init_kwargs) - self._wandb.start() + self._trial_experiment_logger.log_trial_start(self.trial) def on_result(self, result: Dict): - result = _clean_log(result) - self._queue.put(result) + self._trial_experiment_logger.log_trial_result(0, self.trial, result) def close(self): - self._queue.put(_WANDB_QUEUE_END) - self._wandb.join(timeout=10) + self._trial_experiment_logger.log_trial_end(self.trial, failed=False) + del self._trial_experiment_logger class WandbTrainableMixin: @@ -411,7 +511,11 @@ class WandbTrainableMixin: "Make sure to include a `wandb` key in your `config` dict " "containing at least a `project` specification.") - _set_api_key(wandb_config) + api_key_file = wandb_config.pop("api_key_file", None) + if api_key_file: + api_key_file = os.path.expanduser(api_key_file) + + _set_api_key(api_key_file, wandb_config.pop("api_key", None)) # Fill trial ID and name trial_id = self.trial_id diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 6f363baca..ff55676cc 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -5,7 +5,7 @@ import numpy as np import os import yaml -from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, Type +from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, TextIO, Type import ray.cloudpickle as cloudpickle @@ -364,21 +364,60 @@ class UnifiedLogger(Logger): _logger.flush() -class ExperimentLogger(Callback): +class LoggerCallback(Callback): + """Base class for experiment-level logger callbacks + + This base class defines a general interface for logging events, + like trial 
starts, restores, ends, checkpoint saves, and receiving + trial results. + + Callbacks implementing this interface should make sure that logging + utilities are cleaned up properly on trial termination, i.e. when + ``log_trial_end`` is received. This includes e.g. closing files. + """ + def log_trial_start(self, trial: "Trial"): - raise NotImplementedError + """Handle logging when a trial starts. + + Args: + trial (Trial): Trial object. + """ + pass def log_trial_restore(self, trial: "Trial"): - raise NotImplementedError + """Handle logging when a trial restores. + + Args: + trial (Trial): Trial object. + """ + pass def log_trial_save(self, trial: "Trial"): - raise NotImplementedError + """Handle logging when a trial saves a checkpoint. + + Args: + trial (Trial): Trial object. + """ + pass def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): - raise NotImplementedError + """Handle logging when a trial reports a result. + + Args: + trial (Trial): Trial object. + result (dict): Result dictionary. + """ + pass def log_trial_end(self, trial: "Trial", failed: bool = False): - raise NotImplementedError + """Handle logging when a trial ends. + + Args: + trial (Trial): Trial object. + failed (bool): True if the Trial failed (e.g. when it raised an + exception), False if it finished gracefully. + """ + pass def on_trial_result(self, iteration: int, trials: List["Trial"], trial: "Trial", result: Dict, **info): @@ -405,7 +444,7 @@ class ExperimentLogger(Callback): self.log_trial_end(trial, failed=True) -class LegacyExperimentLogger(ExperimentLogger): +class LegacyLoggerCallback(LoggerCallback): """Supports logging to trial-specific `Logger` classes. Previously, Ray Tune logging was handled via `Logger` classes that have @@ -455,6 +494,244 @@ class LegacyExperimentLogger(ExperimentLogger): trial_loggers[trial].close() +class JsonLoggerCallback(LoggerCallback): + """Logs trial results in json format. 
+ + Also writes to a results file and param.json file when results or + configurations are updated. Experiments must be executed with the + JsonLoggerCallback to be compatible with the ExperimentAnalysis tool. + """ + + def __init__(self): + self._trial_configs: Dict["Trial", Dict] = {} + self._trial_files: Dict["Trial", TextIO] = {} + + def log_trial_start(self, trial: "Trial"): + if trial in self._trial_files: + self._trial_files[trial].close() + + # Update config + self.update_config(trial, trial.config) + + # Make sure logdir exists + trial.init_logdir() + local_file = os.path.join(trial.logdir, EXPR_RESULT_FILE) + self._trial_files[trial] = open(local_file, "at") + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_files: + self.log_trial_start(trial) + json.dump(result, self._trial_files[trial], cls=SafeFallbackEncoder) + self._trial_files[trial].write("\n") + self._trial_files[trial].flush() + + def log_trial_end(self, trial: "Trial", failed: bool = False): + if trial not in self._trial_files: + return + + self._trial_files[trial].close() + del self._trial_files[trial] + + def update_config(self, trial: "Trial", config: Dict): + self._trial_configs[trial] = config + + config_out = os.path.join(trial.logdir, EXPR_PARAM_FILE) + with open(config_out, "w") as f: + json.dump( + self._trial_configs[trial], + f, + indent=2, + sort_keys=True, + cls=SafeFallbackEncoder) + + config_pkl = os.path.join(trial.logdir, EXPR_PARAM_PICKLE_FILE) + with open(config_pkl, "wb") as f: + cloudpickle.dump(self._trial_configs[trial], f) + + +class CSVLoggerCallback(LoggerCallback): + """Logs results to progress.csv under the trial directory. 
+ + Automatically flattens nested dicts in the result dict before writing + to csv: + + {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} + + """ + + def __init__(self): + self._trial_continue: Dict["Trial", bool] = {} + self._trial_files: Dict["Trial", TextIO] = {} + self._trial_csv: Dict["Trial", csv.DictWriter] = {} + + def log_trial_start(self, trial: "Trial"): + if trial in self._trial_files: + self._trial_files[trial].close() + + # Make sure logdir exists + trial.init_logdir() + local_file = os.path.join(trial.logdir, EXPR_PROGRESS_FILE) + self._trial_continue[trial] = os.path.exists(local_file) + self._trial_files[trial] = open(local_file, "at") + self._trial_csv[trial] = None + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_files: + self.log_trial_start(trial) + + tmp = result.copy() + tmp.pop("config", None) + result = flatten_dict(tmp, delimiter="/") + + if not self._trial_csv[trial]: + self._trial_csv[trial] = csv.DictWriter(self._trial_files[trial], + result.keys()) + if not self._trial_continue[trial]: + self._trial_csv[trial].writeheader() + + self._trial_csv[trial].writerow({ + k: v + for k, v in result.items() + if k in self._trial_csv[trial].fieldnames + }) + self._trial_files[trial].flush() + + def log_trial_end(self, trial: "Trial", failed: bool = False): + if trial not in self._trial_files: + return + + del self._trial_csv[trial] + self._trial_files[trial].close() + del self._trial_files[trial] + + +class TBXLoggerCallback(LoggerCallback): + """TensorBoardX Logger. + + Note that hparams will be written only after a trial has terminated. + This logger automatically flattens nested dicts to show on TensorBoard: + + {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} + """ + + # NoneType is not supported on the last TBX release yet. 
+ VALID_HPARAMS = (str, bool, np.bool8, int, np.integer, float, list) + + def __init__(self): + try: + from tensorboardX import SummaryWriter + self._summary_writer_cls = SummaryWriter + except ImportError: + if log_once("tbx-install"): + logger.info( + "pip install 'ray[tune]' to see TensorBoard files.") + raise + self._trial_writer: Dict["Trial", SummaryWriter] = {} + self._trial_result: Dict["Trial", Dict] = {} + + def log_trial_start(self, trial: "Trial"): + trial.init_logdir() + self._trial_writer[trial] = self._summary_writer_cls( + trial.logdir, flush_secs=30) + self._trial_result[trial] = {} + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_writer: + self.log_trial_start(trial) + + step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION] + + tmp = result.copy() + for k in [ + "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION + ]: + if k in tmp: + del tmp[k] # not useful to log these + + flat_result = flatten_dict(tmp, delimiter="/") + path = ["ray", "tune"] + valid_result = {} + + for attr, value in flat_result.items(): + full_attr = "/".join(path + [attr]) + if (isinstance(value, tuple(VALID_SUMMARY_TYPES)) + and not np.isnan(value)): + valid_result[full_attr] = value + self._trial_writer[trial].add_scalar( + full_attr, value, global_step=step) + elif ((isinstance(value, list) and len(value) > 0) + or (isinstance(value, np.ndarray) and value.size > 0)): + valid_result[full_attr] = value + + # Must be video + if isinstance(value, np.ndarray) and value.ndim == 5: + self._trial_writer[trial].add_video( + full_attr, value, global_step=step, fps=20) + continue + + try: + self._trial_writer[trial].add_histogram( + full_attr, value, global_step=step) + # In case TensorboardX still doesn't think it's a valid value + # (e.g. `[[]]`), warn and move on. 
+ except (ValueError, TypeError): + if log_once("invalid_tbx_value"): + logger.warning( + "You are trying to log an invalid value ({}={}) " + "via {}!".format(full_attr, value, + type(self).__name__)) + + self._trial_result[trial] = valid_result + self._trial_writer[trial].flush() + + def log_trial_end(self, trial: "Trial", failed: bool = False): + if trial in self._trial_writer: + if trial and trial.evaluated_params and self._trial_result[trial]: + flat_result = flatten_dict( + self._trial_result[trial], delimiter="/") + scrubbed_result = { + k: value + for k, value in flat_result.items() + if isinstance(value, tuple(VALID_SUMMARY_TYPES)) + } + self._try_log_hparams(trial, scrubbed_result) + self._trial_writer[trial].close() + del self._trial_writer[trial] + del self._trial_result[trial] + + def _try_log_hparams(self, trial: "Trial", result: Dict): + # TBX currently errors if the hparams value is None. + flat_params = flatten_dict(trial.evaluated_params) + scrubbed_params = { + k: v + for k, v in flat_params.items() + if isinstance(v, self.VALID_HPARAMS) + } + + removed = { + k: v + for k, v in flat_params.items() + if not isinstance(v, self.VALID_HPARAMS) + } + if removed: + logger.info( + "Removed the following hyperparameter values when " + "logging to tensorboard: %s", str(removed)) + + from tensorboardX.summary import hparams + try: + experiment_tag, session_start_tag, session_end_tag = hparams( + hparam_dict=scrubbed_params, metric_dict=result) + self._trial_writer[trial].file_writer.add_summary(experiment_tag) + self._trial_writer[trial].file_writer.add_summary( + session_start_tag) + self._trial_writer[trial].file_writer.add_summary(session_end_tag) + except Exception: + logger.exception("TensorboardX failed to log hparams. 
" + "This may be due to an unsupported type " + "in the hyperparameter values.") + + def pretty_print(result): result = result.copy() result.update(config=None) # drop config from pretty print diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/tune/tests/test_integration_wandb.py index 6baa6b341..9869585eb 100644 --- a/python/ray/tune/tests/test_integration_wandb.py +++ b/python/ray/tune/tests/test_integration_wandb.py @@ -8,15 +8,22 @@ import numpy as np from ray.tune import Trainable from ray.tune.function_runner import wrap_function -from ray.tune.integration.wandb import _WandbLoggingProcess, \ +from ray.tune.integration.wandb import WandbLoggerCallback, \ + _WandbLoggingProcess, \ _WANDB_QUEUE_END, WandbLogger, WANDB_ENV_VAR, WandbTrainableMixin, \ wandb_mixin from ray.tune.result import TRIAL_INFO from ray.tune.trial import TrialInfo -Trial = namedtuple("MockTrial", - ["config", "trial_id", "trial_name", "trainable_name"]) -Trial.__str__ = lambda t: t.trial_name + +class Trial( + namedtuple("MockTrial", + ["config", "trial_id", "trial_name", "trainable_name"])): + def __hash__(self): + return hash(self.trial_id) + + def __str__(self): + return self.trial_name class _MockWandbLoggingProcess(_WandbLoggingProcess): @@ -37,9 +44,21 @@ class _MockWandbLoggingProcess(_WandbLoggingProcess): self.logs.put(log) -class WandbTestLogger(WandbLogger): +class WandbTestExperimentLogger(WandbLoggerCallback): _logger_process_cls = _MockWandbLoggingProcess + @property + def trial_processes(self): + return self._trial_processes + + +class WandbTestLogger(WandbLogger): + _experiment_logger_cls = WandbTestExperimentLogger + + @property + def trial_process(self): + return self._trial_experiment_logger.trial_processes[self.trial] + class _MockWandbAPI(object): def init(self, *args, **kwargs): @@ -63,7 +82,7 @@ class WandbIntegrationTest(unittest.TestCase): def tearDown(self): pass - def testWandbLoggerConfig(self): + def testWandbLegacyLoggerConfig(self): 
trial_config = {"par1": 4, "par2": 9.12345678} trial = Trial(trial_config, 0, "trial_0", "trainable") @@ -115,11 +134,13 @@ class WandbIntegrationTest(unittest.TestCase): trial_config["wandb"] = {"project": "test_project"} logger = WandbTestLogger(trial_config, "/tmp", trial) - self.assertEqual(logger._wandb.kwargs["project"], "test_project") - self.assertEqual(logger._wandb.kwargs["id"], trial.trial_id) - self.assertEqual(logger._wandb.kwargs["name"], trial.trial_name) - self.assertEqual(logger._wandb.kwargs["group"], trial.trainable_name) - self.assertIn("config", logger._wandb._exclude) + self.assertEqual(logger.trial_process.kwargs["project"], + "test_project") + self.assertEqual(logger.trial_process.kwargs["id"], trial.trial_id) + self.assertEqual(logger.trial_process.kwargs["name"], trial.trial_name) + self.assertEqual(logger.trial_process.kwargs["group"], + trial.trainable_name) + self.assertIn("config", logger.trial_process._exclude) logger.close() @@ -127,8 +148,8 @@ class WandbIntegrationTest(unittest.TestCase): trial_config["wandb"] = {"project": "test_project", "log_config": True} logger = WandbTestLogger(trial_config, "/tmp", trial) - self.assertNotIn("config", logger._wandb._exclude) - self.assertNotIn("metric", logger._wandb._exclude) + self.assertNotIn("config", logger.trial_process._exclude) + self.assertNotIn("metric", logger.trial_process._exclude) logger.close() @@ -139,12 +160,12 @@ class WandbIntegrationTest(unittest.TestCase): } logger = WandbTestLogger(trial_config, "/tmp", trial) - self.assertIn("config", logger._wandb._exclude) - self.assertIn("metric", logger._wandb._exclude) + self.assertIn("config", logger.trial_process._exclude) + self.assertIn("metric", logger.trial_process._exclude) logger.close() - def testWandbLoggerReporting(self): + def testWandbLegacyLoggerReporting(self): trial_config = {"par1": 4, "par2": 9.12345678} trial = Trial(trial_config, 0, "trial_0", "trainable") @@ -166,7 +187,7 @@ class 
WandbIntegrationTest(unittest.TestCase): logger.on_result(r1) - logged = logger._wandb.logs.get(timeout=10) + logged = logger.trial_process.logs.get(timeout=10) self.assertIn("metric1", logged) self.assertNotIn("metric2", logged) self.assertIn("metric3", logged) @@ -176,6 +197,106 @@ class WandbIntegrationTest(unittest.TestCase): logger.close() + def testWandbLoggerConfig(self): + trial_config = {"par1": 4, "par2": 9.12345678} + trial = Trial(trial_config, 0, "trial_0", "trainable") + + if WANDB_ENV_VAR in os.environ: + del os.environ[WANDB_ENV_VAR] + + # No API key + with self.assertRaises(ValueError): + logger = WandbTestExperimentLogger(project="test_project") + + # API Key in config + logger = WandbTestExperimentLogger( + project="test_project", api_key="1234") + self.assertEqual(os.environ[WANDB_ENV_VAR], "1234") + + del logger + del os.environ[WANDB_ENV_VAR] + + # API Key file + with tempfile.NamedTemporaryFile("wt") as fp: + fp.write("5678") + fp.flush() + + logger = WandbTestExperimentLogger( + project="test_project", api_key_file=fp.name) + self.assertEqual(os.environ[WANDB_ENV_VAR], "5678") + + del logger + del os.environ[WANDB_ENV_VAR] + + # API Key in env + os.environ[WANDB_ENV_VAR] = "9012" + logger = WandbTestExperimentLogger(project="test_project") + del logger + + # From now on, the API key is in the env variable. + + logger = WandbTestExperimentLogger(project="test_project") + logger.log_trial_start(trial) + + self.assertEqual(logger.trial_processes[trial].kwargs["project"], + "test_project") + self.assertEqual(logger.trial_processes[trial].kwargs["id"], + trial.trial_id) + self.assertEqual(logger.trial_processes[trial].kwargs["name"], + trial.trial_name) + self.assertEqual(logger.trial_processes[trial].kwargs["group"], + trial.trainable_name) + self.assertIn("config", logger.trial_processes[trial]._exclude) + + del logger + + # log config. 
+ logger = WandbTestExperimentLogger( + project="test_project", log_config=True) + logger.log_trial_start(trial) + self.assertNotIn("config", logger.trial_processes[trial]._exclude) + self.assertNotIn("metric", logger.trial_processes[trial]._exclude) + + del logger + + # Exclude metric. + logger = WandbTestExperimentLogger( + project="test_project", excludes=["metric"]) + logger.log_trial_start(trial) + self.assertIn("config", logger.trial_processes[trial]._exclude) + self.assertIn("metric", logger.trial_processes[trial]._exclude) + + del logger + + def testWandbLoggerReporting(self): + trial_config = {"par1": 4, "par2": 9.12345678} + trial = Trial(trial_config, 0, "trial_0", "trainable") + + logger = WandbTestExperimentLogger( + project="test_project", api_key="1234", excludes=["metric2"]) + logger.on_trial_start(0, [], trial) + + r1 = { + "metric1": 0.8, + "metric2": 1.4, + "metric3": np.asarray(32.0), + "metric4": np.float32(32.0), + "const": "text", + "config": trial_config + } + + logger.on_trial_result(0, [], trial, r1) + + logged = logger.trial_processes[trial].logs.get(timeout=10) + self.assertIn("metric1", logged) + self.assertNotIn("metric2", logged) + self.assertIn("metric3", logged) + self.assertIn("metric4", logged) + self.assertNotIn("const", logged) + self.assertNotIn("config", logged) + + del logger + def testWandbMixinConfig(self): config = {"par1": 4, "par2": 9.12345678} trial = Trial(config, 0, "trial_0", "trainable") diff --git a/python/ray/tune/tests/test_logger.py b/python/ray/tune/tests/test_logger.py index 17215b36c..234b65c43 100644 --- a/python/ray/tune/tests/test_logger.py +++ b/python/ray/tune/tests/test_logger.py @@ -1,12 +1,33 @@ +import csv +import glob +import json +import os from collections import namedtuple import unittest import tempfile import shutil import numpy as np +from ray.cloudpickle import cloudpickle -from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger +from ray.tune.logger import CSVLoggerCallback, 
JsonLoggerCallback, \ + JsonLogger, CSVLogger, \ + TBXLoggerCallback, TBXLogger +from ray.tune.result import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, \ + EXPR_PROGRESS_FILE, \ + EXPR_RESULT_FILE -Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"]) + +class Trial( + namedtuple("MockTrial", ["evaluated_params", "trial_id", "logdir"])): + @property + def config(self): + return self.evaluated_params + + def init_logdir(self): + return + + def __hash__(self): + return hash(self.trial_id) def result(t, rew, **kwargs): @@ -28,24 +49,116 @@ class LoggerSuite(unittest.TestCase): def tearDown(self): shutil.rmtree(self.test_dir, ignore_errors=True) - def testCSV(self): + def testLegacyCSV(self): config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}} - t = Trial(evaluated_params=config, trial_id="csv") + t = Trial( + evaluated_params=config, trial_id="csv", logdir=self.test_dir) logger = CSVLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(2, 4)) - logger.on_result(result(2, 4)) - logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) + logger.on_result(result(2, 5)) + logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1})) logger.close() + self._validate_csv_result() + + def testCSV(self): + config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}} + t = Trial( + evaluated_params=config, trial_id="csv", logdir=self.test_dir) + logger = CSVLoggerCallback() + logger.on_trial_result(0, [], t, result(0, 4)) + logger.on_trial_result(1, [], t, result(1, 5)) + logger.on_trial_result( + 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) + + logger.on_trial_complete(3, [], t) + self._validate_csv_result() + + def _validate_csv_result(self): + results = [] + result_file = os.path.join(self.test_dir, EXPR_PROGRESS_FILE) + with open(result_file, "rt") as fp: + reader = csv.DictReader(fp) + for row in reader: + results.append(row) + + self.assertEqual(len(results), 3) + self.assertSequenceEqual( + 
[int(row["episode_reward_mean"]) for row in results], [4, 5, 6]) + + def testJSONLegacyLogger(self): + config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}} + t = Trial( + evaluated_params=config, trial_id="json", logdir=self.test_dir) + logger = JsonLogger(config=config, logdir=self.test_dir, trial=t) + logger.on_result(result(0, 4)) + logger.on_result(result(1, 5)) + logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1})) + logger.close() + + self._validate_json_result(config) + def testJSON(self): config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}} - t = Trial(evaluated_params=config, trial_id="json") - logger = JsonLogger(config=config, logdir=self.test_dir, trial=t) + t = Trial( + evaluated_params=config, trial_id="json", logdir=self.test_dir) + logger = JsonLoggerCallback() + logger.on_trial_result(0, [], t, result(0, 4)) + logger.on_trial_result(1, [], t, result(1, 5)) + logger.on_trial_result( + 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) + + logger.on_trial_complete(3, [], t) + self._validate_json_result(config) + + def _validate_json_result(self, config): + # Check result logs + results = [] + result_file = os.path.join(self.test_dir, EXPR_RESULT_FILE) + with open(result_file, "rt") as fp: + for row in fp.readlines(): + results.append(json.loads(row)) + + self.assertEqual(len(results), 3) + self.assertSequenceEqual( + [int(row["episode_reward_mean"]) for row in results], [4, 5, 6]) + + # Check json saved config file + config_file = os.path.join(self.test_dir, EXPR_PARAM_FILE) + with open(config_file, "rt") as fp: + loaded_config = json.load(fp) + + self.assertEqual(loaded_config, config) + + # Check pickled config file + config_file = os.path.join(self.test_dir, EXPR_PARAM_PICKLE_FILE) + with open(config_file, "rb") as fp: + loaded_config = cloudpickle.load(fp) + + self.assertEqual(loaded_config, config) + + def testLegacyTBX(self): + config = { + "a": 2, + "b": [1, 2], + "c": { + "c": { + "D": 123 + } + }, + 
"d": np.int64(1), + "e": np.bool8(True) + } + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) + logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(0, 4)) - logger.on_result(result(1, 4)) - logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) + logger.on_result(result(1, 5)) + logger.on_result(result(2, 6, score=[1, 2, 3], hello={"world": 1})) logger.close() + self._validate_tbx_result() + def testTBX(self): config = { "a": 2, @@ -58,16 +171,40 @@ class LoggerSuite(unittest.TestCase): "d": np.int64(1), "e": np.bool8(True) } - t = Trial(evaluated_params=config, trial_id="tbx") - logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) - logger.on_result(result(0, 4)) - logger.on_result(result(1, 4)) - logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) - logger.close() + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) + logger = TBXLoggerCallback() + logger.on_trial_result(0, [], t, result(0, 4)) + logger.on_trial_result(1, [], t, result(1, 5)) + logger.on_trial_result( + 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) - def testBadTBX(self): + logger.on_trial_complete(3, [], t) + + self._validate_tbx_result() + + def _validate_tbx_result(self): + try: + from tensorflow.python.summary.summary_iterator \ + import summary_iterator + except ImportError: + print("Skipping rest of test as tensorflow is not installed.") + return + + events_file = list(glob.glob(f"{self.test_dir}/events*"))[0] + results = [] + for event in summary_iterator(events_file): + for v in event.summary.value: + if v.tag == "ray/tune/episode_reward_mean": + results.append(v.simple_value) + + self.assertEqual(len(results), 3) + self.assertSequenceEqual([int(res) for res in results], [4, 5, 6]) + + def testLegacyBadTBX(self): config = {"b": (1, 2, 3)} - t = Trial(evaluated_params=config, trial_id="tbx") + t = Trial( + evaluated_params=config, trial_id="tbx", 
logdir=self.test_dir) logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(0, 4)) logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) @@ -76,7 +213,8 @@ class LoggerSuite(unittest.TestCase): assert "INFO" in cm.output[0] config = {"None": None} - t = Trial(evaluated_params=config, trial_id="tbx") + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(0, 4)) logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) @@ -84,6 +222,31 @@ class LoggerSuite(unittest.TestCase): logger.close() assert "INFO" in cm.output[0] + def testBadTBX(self): + config = {"b": (1, 2, 3)} + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) + logger = TBXLoggerCallback() + logger.on_trial_result(0, [], t, result(0, 4)) + logger.on_trial_result(1, [], t, result(1, 5)) + logger.on_trial_result( + 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) + with self.assertLogs("ray.tune.logger", level="INFO") as cm: + logger.on_trial_complete(3, [], t) + assert "INFO" in cm.output[0] + + config = {"None": None} + t = Trial( + evaluated_params=config, trial_id="tbx", logdir=self.test_dir) + logger = TBXLoggerCallback() + logger.on_trial_result(0, [], t, result(0, 4)) + logger.on_trial_result(1, [], t, result(1, 5)) + logger.on_trial_result( + 2, [], t, result(2, 6, score=[1, 2, 3], hello={"world": 1})) + with self.assertLogs("ray.tune.logger", level="INFO") as cm: + logger.on_trial_complete(3, [], t) + assert "INFO" in cm.output[0] + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tests/test_run_experiment.py b/python/ray/tune/tests/test_run_experiment.py index e95a0d0c9..6845507b2 100644 --- a/python/ray/tune/tests/test_run_experiment.py +++ b/python/ray/tune/tests/test_run_experiment.py @@ -7,7 +7,7 @@ from ray.rllib import _register_all from ray.tune.result import 
TIMESTEPS_TOTAL from ray.tune import Trainable, TuneError from ray.tune import register_trainable, run_experiments -from ray.tune.logger import LegacyExperimentLogger, Logger +from ray.tune.logger import LegacyLoggerCallback, Logger from ray.tune.experiment import Experiment from ray.tune.trial import Trial, ExportFormat @@ -191,7 +191,7 @@ class RunExperimentTest(unittest.TestCase): } } }, - callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])]) + callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])]) self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) self.assertFalse( os.path.exists(os.path.join(trial.logdir, "params.json"))) @@ -204,7 +204,7 @@ class RunExperimentTest(unittest.TestCase): } } }) - self.assertTrue( + self.assertFalse( os.path.exists(os.path.join(trial.logdir, "params.json"))) [trial] = run_experiments( @@ -216,7 +216,7 @@ class RunExperimentTest(unittest.TestCase): } } }, - callbacks=[LegacyExperimentLogger(logger_classes=[])]) + callbacks=[LegacyLoggerCallback(logger_classes=[])]) self.assertFalse( os.path.exists(os.path.join(trial.logdir, "params.json"))) @@ -239,7 +239,7 @@ class RunExperimentTest(unittest.TestCase): } } }, - callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])]) + callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])]) self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) self.assertTrue( os.path.exists(os.path.join(trial.logdir, "params.json"))) @@ -264,7 +264,7 @@ class RunExperimentTest(unittest.TestCase): } } }, - callbacks=[LegacyExperimentLogger(logger_classes=[])]) + callbacks=[LegacyLoggerCallback(logger_classes=[])]) self.assertTrue( os.path.exists(os.path.join(trial.logdir, "params.json"))) diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index f774f9e50..1649c7e93 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ 
b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -9,8 +9,8 @@ import ray from ray import tune from ray.rllib import _register_all from ray.tune.checkpoint_manager import Checkpoint -from ray.tune.logger import DEFAULT_LOGGERS, ExperimentLogger, \ - LegacyExperimentLogger +from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, \ + LegacyLoggerCallback from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TRAINING_ITERATION from ray.tune.syncer import SyncConfig, SyncerCallback @@ -205,14 +205,14 @@ class TrialRunnerCallbacks(unittest.TestCase): "delay") def testCallbackReordering(self): - """SyncerCallback should come after ExperimentLogger callbacks""" + """SyncerCallback should come after LoggerCallback callbacks""" def get_positions(callbacks): first_logger_pos = None last_logger_pos = None syncer_pos = None for i, callback in enumerate(callbacks): - if isinstance(callback, ExperimentLogger): + if isinstance(callback, LoggerCallback): if first_logger_pos is None: first_logger_pos = i last_logger_pos = i @@ -233,8 +233,8 @@ class TrialRunnerCallbacks(unittest.TestCase): self.assertLess(last_logger_pos, syncer_pos) # Auto creation of loggers with existing logger (but no CSV/JSON) - callbacks = create_default_callbacks([ExperimentLogger()], - SyncConfig(), None) + callbacks = create_default_callbacks([LoggerCallback()], SyncConfig(), + None) first_logger_pos, last_logger_pos, syncer_pos = get_positions( callbacks) self.assertLess(last_logger_pos, syncer_pos) @@ -242,13 +242,12 @@ class TrialRunnerCallbacks(unittest.TestCase): # This should throw an error as the syncer comes before the logger with self.assertRaises(ValueError): callbacks = create_default_callbacks( - [SyncerCallback(None), - ExperimentLogger()], SyncConfig(), None) + [SyncerCallback(None), LoggerCallback()], SyncConfig(), None) # This should be reordered but preserve the regular callback order [mc1, mc2, mc3] = [Callback(), Callback(), Callback()] # Has 
to be legacy logger to avoid logger callback creation - lc = LegacyExperimentLogger(logger_classes=DEFAULT_LOGGERS) + lc = LegacyLoggerCallback(logger_classes=DEFAULT_LOGGERS) callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(), None) print(callbacks) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 8f35af055..e809fa5ce 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -73,7 +73,6 @@ def run( checkpoint_at_end=False, verbose=Verbosity.V3_TRIAL_DETAILS, progress_reporter=None, - loggers=None, log_to_file=False, trial_name_creator=None, trial_dirname_creator=None, @@ -90,6 +89,7 @@ def run( raise_on_failed_trial=True, callbacks=None, # Deprecated args + loggers=None, ray_auto_init=None, run_errored_only=None, global_checkpoint_period=None, @@ -196,9 +196,6 @@ def run( intermediate experiment progress. Defaults to CLIReporter if running in command-line, or JupyterNotebookReporter if running in a Jupyter notebook. - loggers (list): List of logger creators to be used with - each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS. - See `ray/tune/logger.py`. log_to_file (bool|str|Sequence): Log stdout and stderr to files in Tune's trial directories. If this is `False` (default), no files are written. If `true`, outputs are written to `trialdir/stdout` @@ -251,7 +248,9 @@ def run( trial (of ERROR state) when the experiments complete. callbacks (list): List of callbacks that will be called at different times in the training loop. Must be instances of the - ``ray.tune.trial_runner.Callback`` class. + ``ray.tune.trial_runner.Callback`` class. If not passed, + `LoggerCallback` and `SyncerCallback` callbacks are automatically + added. 
Returns: diff --git a/python/ray/tune/utils/callback.py b/python/ray/tune/utils/callback.py index 91ec27f91..88e3e4cbe 100644 --- a/python/ray/tune/utils/callback.py +++ b/python/ray/tune/utils/callback.py @@ -1,13 +1,18 @@ +import logging import os from typing import List, Optional from ray.tune.callback import Callback from ray.tune.progress_reporter import TrialProgressCallback from ray.tune.syncer import SyncConfig -from ray.tune.logger import CSVLogger, DEFAULT_LOGGERS, ExperimentLogger, \ - JsonLogger, LegacyExperimentLogger, Logger +from ray.tune.logger import CSVLoggerCallback, CSVLogger, \ + LoggerCallback, \ + JsonLoggerCallback, JsonLogger, LegacyLoggerCallback, Logger, \ + TBXLoggerCallback, TBXLogger from ray.tune.syncer import SyncerCallback +logger = logging.getLogger(__name__) + def create_default_callbacks(callbacks: Optional[List[Callback]], sync_config: SyncConfig, @@ -40,6 +45,7 @@ def create_default_callbacks(callbacks: Optional[List[Callback]], has_syncer_callback = False has_csv_logger = False has_json_logger = False + has_tbx_logger = False has_trial_progress_callback = any( isinstance(c, TrialProgressCallback) for c in callbacks) @@ -52,20 +58,14 @@ def create_default_callbacks(callbacks: Optional[List[Callback]], last_logger_index = None syncer_index = None - if not loggers: - # If no logger callback and no `loggers` have been provided, - # add DEFAULT_LOGGERS. - if not any( - isinstance(callback, ExperimentLogger) - for callback in callbacks): - loggers = DEFAULT_LOGGERS - - # Create LegacyExperimentLogger for passed Logger classes + # Create LegacyLoggerCallback for passed Logger classes if loggers: # Todo(krfricke): Deprecate `loggers` argument, print warning here. + # Add warning as soon as we ported all loggers to LoggerCallback + # classes. 
add_loggers = [] for trial_logger in loggers: - if isinstance(trial_logger, ExperimentLogger): + if isinstance(trial_logger, LoggerCallback): callbacks.append(trial_logger) elif isinstance(trial_logger, type) and issubclass( trial_logger, Logger): @@ -75,31 +75,41 @@ def create_default_callbacks(callbacks: Optional[List[Callback]], f"Invalid value passed to `loggers` argument of " f"`tune.run()`: {trial_logger}") if add_loggers: - callbacks.append(LegacyExperimentLogger(add_loggers)) + callbacks.append(LegacyLoggerCallback(add_loggers)) - # Check if we have a CSV and JSON logger + # Check if we have a CSV, JSON and TensorboardX logger for i, callback in enumerate(callbacks): - if isinstance(callback, LegacyExperimentLogger): + if isinstance(callback, LegacyLoggerCallback): last_logger_index = i if CSVLogger in callback.logger_classes: has_csv_logger = True if JsonLogger in callback.logger_classes: has_json_logger = True - # Todo(krfricke): add checks for new ExperimentLogger classes + if TBXLogger in callback.logger_classes: + has_tbx_logger = True + elif isinstance(callback, CSVLoggerCallback): + has_csv_logger = True + last_logger_index = i + elif isinstance(callback, JsonLoggerCallback): + has_json_logger = True + last_logger_index = i + elif isinstance(callback, TBXLoggerCallback): + has_tbx_logger = True + last_logger_index = i elif isinstance(callback, SyncerCallback): syncer_index = i has_syncer_callback = True - # If CSV or JSON logger is missing, add + # If CSV, JSON or TensorboardX loggers are missing, add if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1": - # Todo(krfricke): Switch to new ExperimentLogger classes - add_loggers = [] if not has_csv_logger: - add_loggers.append(CSVLogger) + callbacks.append(CSVLoggerCallback()) + last_logger_index = len(callbacks) - 1 if not has_json_logger: - add_loggers.append(JsonLogger) - if add_loggers: - callbacks.append(LegacyExperimentLogger(add_loggers)) + callbacks.append(JsonLoggerCallback()) + 
last_logger_index = len(callbacks) - 1 + if not has_tbx_logger: + callbacks.append(TBXLoggerCallback()) last_logger_index = len(callbacks) - 1 # If no SyncerCallback was found, add @@ -110,18 +120,18 @@ def create_default_callbacks(callbacks: Optional[List[Callback]], callbacks.append(syncer_callback) syncer_index = len(callbacks) - 1 - # Todo(krfricke): Maybe check if syncer comes after all loggers if syncer_index is not None and last_logger_index is not None and \ syncer_index < last_logger_index: - if (not has_csv_logger or not has_json_logger) and not loggers: + if (not has_csv_logger or not has_json_logger or not has_tbx_logger) \ + and not loggers: # Only raise the warning if the loggers were passed by the user. # (I.e. don't warn if this was automatic behavior and they only # passed a customer SyncerCallback). raise ValueError( "The `SyncerCallback` you passed to `tune.run()` came before " - "at least one `ExperimentLogger`. Syncing should be done " + "at least one `LoggerCallback`. Syncing should be done " "after writing logs. Please re-order the callbacks so that " - "the `SyncerCallback` comes after any `ExperimentLogger`.") + "the `SyncerCallback` comes after any `LoggerCallback`.") else: # If these loggers were automatically created. just re-order # the callbacks