diff --git a/doc/source/conf.py b/doc/source/conf.py index 300b06ea3..9c8567f0c 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -57,6 +57,7 @@ MOCK_MODULES = [ "tensorflow.contrib.slim", "tensorflow.core", "tensorflow.core.util", + "tensorflow.keras", "tensorflow.python", "tensorflow.python.client", "tensorflow.python.util", @@ -78,7 +79,9 @@ for mod_name in MOCK_MODULES: # ray.rllib.models.action_dist.py and # ray.rllib.models.lstm.py will use tf.VERSION sys.modules["tensorflow"].VERSION = "9.9.9" +sys.modules["tensorflow.keras.callbacks"] = ChildClassMock() sys.modules["pytorch_lightning"] = ChildClassMock() + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst index 61537af49..d03cba0cc 100644 --- a/doc/source/tune/api_docs/integration.rst +++ b/doc/source/tune/api_docs/integration.rst @@ -7,6 +7,15 @@ External library integrations (tune.integration) :local: :depth: 1 +.. _tune-integration-keras: + +Keras (tune.integration.keras) +------------------------------------------------------ + +.. autoclass:: ray.tune.integration.keras.TuneReportCallback + +.. autoclass:: ray.tune.integration.keras.TuneReportCheckpointCallback + .. _tune-integration-kubernetes: Kubernetes (tune.integration.kubernetes) diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index b0fd0f21c..3eac52b4a 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -2,7 +2,7 @@ import argparse import numpy as np from tensorflow.keras.datasets import mnist -from ray.tune.integration.keras import TuneReporterCallback +from ray.tune.integration.keras import TuneReportCallback parser = argparse.ArgumentParser() parser.add_argument( @@ -39,7 +39,9 @@ def train_mnist(config): epochs=epochs, verbose=0, validation_data=(x_test, y_test), - callbacks=[TuneReporterCallback()]) + callbacks=[TuneReportCallback({ + "mean_accuracy": "accuracy" + })]) if __name__ == "__main__": diff --git a/python/ray/tune/integration/keras.py b/python/ray/tune/integration/keras.py index 02f29125e..1a0a85e04 100644 --- a/python/ray/tune/integration/keras.py +++ b/python/ray/tune/integration/keras.py @@ -1,47 +1,296 @@ -from tensorflow import keras +from collections import Counter +from typing import Dict, List, Union + +from tensorflow.keras.callbacks import Callback +from ray import tune + +import os -class TuneReporterCallback(keras.callbacks.Callback): - """Tune Callback for Keras.""" +class TuneCallback(Callback): + """Base class for Tune's Keras callbacks.""" + _allowed = [ + "batch_begin", + "batch_end", + "epoch_begin", + "epoch_end", + "train_batch_begin", + "train_batch_end", + "test_batch_begin", + "test_batch_end", + "predict_batch_begin", + "predict_batch_end", + "train_begin", + "train_end", + "test_begin", + "test_end", + "predict_begin", + "predict_end", + ] - def __init__(self, reporter=None, freq="batch", logs=None): - """Initializer. + def __init__(self, on: Union[str, List[str]] = "validation_end"): + super(TuneCallback, self).__init__() - Args: - freq (str): Sets the frequency of reporting intermediate results. - One of ["batch", "epoch"]. - """ - self.iteration = 0 - logs = logs or {} - if freq not in ["batch", "epoch"]: - raise ValueError("{} not supported as a frequency.".format(freq)) - self.freq = freq - super(TuneReporterCallback, self).__init__() + if not isinstance(on, list): + on = [on] + if any(w not in self._allowed for w in on): + raise ValueError( + "Invalid trigger time selected: {}. Must be one of {}".format( + on, self._allowed)) + self._on = on + + def _handle(self, logs: Dict): + raise NotImplementedError + + def on_batch_begin(self, batch, logs=None): + if "batch_begin" in self._on: + self._handle(logs, "batch_begin") def on_batch_end(self, batch, logs=None): - from ray import tune - logs = logs or {} - if not self.freq == "batch": - return - self.iteration += 1 - for metric in list(logs): - if "loss" in metric and "neg_" not in metric: - logs["neg_" + metric] = -logs[metric] - if "acc" in logs: - tune.report(keras_info=logs, mean_accuracy=logs["acc"]) - else: - tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy")) + if "batch_end" in self._on: + self._handle(logs, "batch_end") - def on_epoch_end(self, batch, logs=None): - from ray import tune - logs = logs or {} - if not self.freq == "epoch": - return - self.iteration += 1 - for metric in list(logs): - if "loss" in metric and "neg_" not in metric: - logs["neg_" + metric] = -logs[metric] - if "acc" in logs: - tune.report(keras_info=logs, mean_accuracy=logs["acc"]) + def on_epoch_begin(self, epoch, logs=None): + if "epoch_begin" in self._on: + self._handle(logs, "epoch_begin") + + def on_epoch_end(self, epoch, logs=None): + if "epoch_end" in self._on: + self._handle(logs, "epoch_end") + + def on_train_batch_begin(self, batch, logs=None): + if "train_batch_begin" in self._on: + self._handle(logs, "train_batch_begin") + + def on_train_batch_end(self, batch, logs=None): + if "train_batch_end" in self._on: + self._handle(logs, "train_batch_end") + + def on_test_batch_begin(self, batch, logs=None): + if "test_batch_begin" in self._on: + self._handle(logs, "test_batch_begin") + + def on_test_batch_end(self, batch, logs=None): + if "test_batch_end" in self._on: + self._handle(logs, "test_batch_end") + + def on_predict_batch_begin(self, batch, logs=None): + if "predict_batch_begin" in self._on: + self._handle(logs, "predict_batch_begin") + + def on_predict_batch_end(self, batch, logs=None): + if "predict_batch_end" in self._on: + self._handle(logs, "predict_batch_end") + + def on_train_begin(self, logs=None): + if "train_begin" in self._on: + self._handle(logs, "train_begin") + + def on_train_end(self, logs=None): + if "train_end" in self._on: + self._handle(logs, "train_end") + + def on_test_begin(self, logs=None): + if "test_begin" in self._on: + self._handle(logs, "test_begin") + + def on_test_end(self, logs=None): + if "test_end" in self._on: + self._handle(logs, "test_end") + + def on_predict_begin(self, logs=None): + if "predict_begin" in self._on: + self._handle(logs, "predict_begin") + + def on_predict_end(self, logs=None): + if "predict_end" in self._on: + self._handle(logs, "predict_end") + + +class TuneReportCallback(TuneCallback): + """Keras to Ray Tune reporting callback + + Reports metrics to Ray Tune. + + Args: + metrics (str|list|dict): Metrics to report to Tune. If this is a list, + each item describes the metric key reported to Keras, + and it will reported under the same name to Tune. If this is a + dict, each key will be the name reported to Tune and the respective + value will be the metric key reported to Keras. If this is None, + all Keras logs will be reported. + on (str|list): When to trigger checkpoint creations. Must be one of + the Keras event hooks (less the ``on_``), e.g. + "train_start", or "predict_end". Defaults to "epoch_end". + + Example: + + .. code-block:: python + + from ray.tune.integration.keras import TuneReportCallback + + # Report accuracy to Tune after each epoch: + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + verbose=0, + validation_data=(x_test, y_test), + callbacks=[TuneReportCallback( + {"mean_accuracy": "accuracy"}, on="epoch_end")]) + + """ + + def __init__(self, + metrics: Union[None, str, List[str], Dict[str, str]] = None, + on: Union[str, List[str]] = "epoch_end"): + super(TuneReportCallback, self).__init__(on) + if isinstance(metrics, str): + metrics = [metrics] + self._metrics = metrics + + def _handle(self, logs: Dict, when: str = None): + if not self._metrics: + report_dict = logs else: - tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy")) + report_dict = {} + for key in self._metrics: + if isinstance(self._metrics, dict): + metric = self._metrics[key] + else: + metric = key + report_dict[key] = logs[metric] + tune.report(**report_dict) + + +class _TuneCheckpointCallback(TuneCallback): + """Keras checkpoint callback + + Saves checkpoints after each validation step. + + Checkpoint are currently not registered if no ``tune.report()`` call + is made afterwards. Consider using ``TuneReportCheckpointCallback`` + instead. + + Args: + filename (str): Filename of the checkpoint within the checkpoint + directory. Defaults to "checkpoint". + frequency (int|list): Checkpoint frequency. If this is an integer `n`, + checkpoints are saved every `n` times each hook was called. If + this is a list, it specifies the checkpoint frequencies for each + hook individually. + on (str|list): When to trigger checkpoint creations. Must be one of + the Keras event hooks (less the ``on_``), e.g. + "train_start", or "predict_end". Defaults to "epoch_end". + + + """ + + def __init__(self, + filename: str = "checkpoint", + frequency: Union[int, List[int]] = 1, + on: Union[str, List[str]] = "epoch_end"): + + if isinstance(frequency, list): + if not isinstance(on, list) or len(frequency) != len(on): + raise ValueError( + "If you pass a list for checkpoint frequencies, the `on` " + "parameter has to be a list with the same length.") + + self._frequency = frequency + + super(_TuneCheckpointCallback, self).__init__(on) + + self._filename = filename + self._counter = Counter() + self._cp_count = 0 # Has to be monotonically increasing + + def _handle(self, logs: Dict, when: str = None): + self._counter[when] += 1 + + if isinstance(self._frequency, list): + index = self._on.index(when) + freq = self._frequency[index] + else: + freq = self._frequency + + if self._counter[when] % freq == 0: + with tune.checkpoint_dir(step=self._cp_count) as checkpoint_dir: + self.model.save( + os.path.join(checkpoint_dir, self._filename), + overwrite=True) + self._cp_count += 1 + + +class TuneReportCheckpointCallback(TuneCallback): + """Keras report and checkpoint callback + + Saves checkpoints after each validation step. Also reports metrics to Tune, + which is needed for checkpoint registration. + + Use this callback to register saved checkpoints with Ray Tune. This means + that checkpoints will be manages by the `CheckpointManager` and can be + used for advanced scheduling and search algorithms, like + Population Based Training. + + The ``tf.keras.callbacks.ModelCheckpoint`` callback also saves checkpoints, + but doesn't register them with Ray Tune. + + Args: + metrics (str|list|dict): Metrics to report to Tune. If this is a list, + each item describes the metric key reported to Keras, + and it will reported under the same name to Tune. If this is a + dict, each key will be the name reported to Tune and the respective + value will be the metric key reported to Keras. If this is None, + all Keras logs will be reported. + filename (str): Filename of the checkpoint within the checkpoint + directory. Defaults to "checkpoint". + frequency (int|list): Checkpoint frequency. If this is an integer `n`, + checkpoints are saved every `n` times each hook was called. If + this is a list, it specifies the checkpoint frequencies for each + hook individually. + on (str|list): When to trigger checkpoint creations. Must be one of + the Keras event hooks (less the ``on_``), e.g. + "train_start", or "predict_end". Defaults to "epoch_end". + + + Example: + + .. code-block:: python + + from ray.tune.integration.keras import TuneReportCheckpointCallback + + # Save checkpoint and report accuracy to Tune after each epoch: + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + verbose=0, + validation_data=(x_test, y_test), + callbacks=[TuneReportCheckpointCallback( + metrics={"mean_accuracy": "accuracy"}, + filename="model", + on="epoch_end")]) + + + """ + + def __init__(self, + metrics: Union[None, str, List[str], Dict[str, str]] = None, + filename: str = "checkpoint", + frequency: Union[int, List[int]] = 1, + on: Union[str, List[str]] = "epoch_end"): + super(TuneReportCheckpointCallback, self).__init__(on) + self._checkpoint = _TuneCheckpointCallback(filename, frequency, on) + self._report = TuneReportCallback(metrics, on) + + def _handle(self, logs: Dict, when: str = None): + self._checkpoint._handle(logs, when) + self._report._handle(logs, when) + + def set_model(self, model): + # Pass through for the checkpoint callback to set model + self._checkpoint.set_model(model) + self._report.set_model(model)