From 9ff687c093c955604ba844166eea6842e4040dd0 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Wed, 19 Aug 2020 17:22:28 -0700 Subject: [PATCH] [SGD][Docs] docs for training/ validation results (#10181) --- doc/source/raysgd/raysgd_pytorch.rst | 45 +++++++++++++++++++++++++++- doc/source/raysgd/raysgd_ref.rst | 13 ++++++++ python/ray/util/sgd/utils.py | 33 ++++++++++++++++++-- 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 532bf1612..19160df26 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -157,7 +157,7 @@ Now that the trainer is constructed, here's how to train the model. val_metrics = trainer.validate() -Each ``train`` call makes one pass over the training data, and each ``validate`` call runs the model on the validation data passed in by the ``data_creator``. +Each ``train`` call makes one pass over the training data (trains on 1 epoch), and each ``validate`` call runs the model on the validation data passed in by the ``data_creator``. You can also obtain profiling information: @@ -396,6 +396,49 @@ The trained torch model can be extracted for use within the same Python program trainer.train() model = trainer.get_model() # Returns multiple models if the model_creator does. +Training & Validation Results +----------------------------- +The output for ``trainer.train()`` and ``trainer.validate()`` are first collected on a per-batch basis. These results are then averaged: first across each batch in the epoch, and then across all workers. + +By default, the output of ``train`` contains the following: + +.. code-block:: python + + # Total number of samples trained on in this epoch. + num_samples + # Current training epoch. + epoch + # Number of batches trained on in this epoch averaged across all workers. + batch_count + # Training loss averaged across all batches on all workers. + train_loss + # Training loss for the last batch in epoch averaged across all workers. + last_train_loss + +And for ``validate``: + +.. code-block:: python + + # Total number of samples validated on. + num_samples + # Number of batches validated on averaged across all workers. + batch_count + # Validation loss averaged across all batches on all workers. + val_loss + # Validation loss for last batch averaged across all workers. + last_val_loss + # Validation accuracy for last batch averaged across all workers. + val_accuracy + # Validation accuracy for last batch averaged across all workers. + last_val_accuracy + +If ``train`` or ``validate`` are run with ``reduce_results=False``, results are not averaged across workers and a list of results for each worker is returned. +If run with ``profile=True``, timing stats for a single worker is returned alongside the results above. + +To add additional metrics to return you should implement your own custom training operator (:ref:`raysgd-custom-training`). +If overriding ``train_batch`` or ``validate_batch``, the result outputs are automatically averaged across all batches, and the results for the last batch are automatically returned. +If overriding ``train_epoch`` or ``validate`` you may find ``ray.util.sgd.utils.AverageMeterCollection`` (:ref:`ref-utils`) useful to handle this averaging. + Mixed Precision (FP16) Training ------------------------------- diff --git a/doc/source/raysgd/raysgd_ref.rst b/doc/source/raysgd/raysgd_ref.rst index aeb4bf842..692571a36 100644 --- a/doc/source/raysgd/raysgd_ref.rst +++ b/doc/source/raysgd/raysgd_ref.rst @@ -42,3 +42,16 @@ Dataset .. automethod:: __init__ +.. _ref-utils: + +Utils +----- + +.. autoclass:: ray.util.sgd.utils.AverageMeter + :members: + +.. autoclass:: ray.util.sgd.utils.AverageMeterCollection + :members: + + + diff --git a/python/ray/util/sgd/utils.py b/python/ray/util/sgd/utils.py index 5900fdcf8..090a46098 100644 --- a/python/ray/util/sgd/utils.py +++ b/python/ray/util/sgd/utils.py @@ -156,7 +156,17 @@ def find_free_port(): class AverageMeter: - """Computes and stores the average and current value.""" + """Utility for computing and storing the average and most recent value. + + Example: + >>> meter = AverageMeter() + >>> meter.update(5) + >>> meter.val, meter.avg, meter.sum + (5, 5.0, 5) + >>> meter.update(10, n=4) + >>> meter.val, meter.avg, meter.sum + (10, 9.0, 45) + """ def __init__(self): self.reset() @@ -168,6 +178,7 @@ class AverageMeter: self.count = 0 def update(self, val, n=1): + """Update current value, total sum, and average.""" self.val = val self.sum += val * n self.count += n @@ -175,7 +186,24 @@ class AverageMeter: class AverageMeterCollection: - """A grouping of AverageMeters.""" + """A grouping of AverageMeters. + + This utility is used in TrainingOperator.train_epoch and + TrainingOperator.validate to + collect averages and most recent value across all batches. One + AverageMeter object is used for each metric. + + Example: + >>> meter_collection = AverageMeterCollection() + >>> meter_collection.update({"loss": 0.5, "acc": 0.5}, n=32) + >>> meter_collection.summary() + {'batch_count': 1, 'num_samples': 32, 'loss': 0.5, + 'last_loss': 0.5, 'acc': 0.5, 'last_acc': 0.5} + >>> meter_collection.update({"loss": 0.1, "acc": 0.9}, n=32) + >>> meter_collection.summary() + {'batch_count': 2, 'num_samples': 64, 'loss': 0.3, + 'last_loss': 0.1, 'acc': 0.7, 'last_acc': 0.9} + """ def __init__(self): self._batch_count = 0 @@ -183,6 +211,7 @@ class AverageMeterCollection: self._meters = collections.defaultdict(AverageMeter) def update(self, metrics, n=1): + """Does one batch of updates for the provided metrics.""" self._batch_count += 1 self.n += n for metric, value in metrics.items():