[SGD][Docs] docs for training/ validation results (#10181)

This commit is contained in:
Amog Kamsetty
2020-08-19 17:22:28 -07:00
committed by GitHub
parent a785106b47
commit 9ff687c093
3 changed files with 88 additions and 3 deletions
+31 -2
View File
@@ -156,7 +156,17 @@ def find_free_port():
class AverageMeter:
"""Computes and stores the average and current value."""
"""Utility for computing and storing the average and most recent value.
Example:
>>> meter = AverageMeter()
>>> meter.update(5)
>>> meter.val, meter.avg, meter.sum
(5, 5.0, 5)
>>> meter.update(10, n=4)
>>> meter.val, meter.avg, meter.sum
(10, 9.0, 45)
"""
def __init__(self):
self.reset()
@@ -168,6 +178,7 @@ class AverageMeter:
self.count = 0
def update(self, val, n=1):
"""Update current value, total sum, and average."""
self.val = val
self.sum += val * n
self.count += n
@@ -175,7 +186,24 @@ class AverageMeter:
class AverageMeterCollection:
"""A grouping of AverageMeters."""
"""A grouping of AverageMeters.
This utility is used in TrainingOperator.train_epoch and
TrainingOperator.validate to
collect averages and most recent value across all batches. One
AverageMeter object is used for each metric.
Example:
>>> meter_collection = AverageMeterCollection()
>>> meter_collection.update({"loss": 0.5, "acc": 0.5}, n=32)
>>> meter_collection.summary()
{'batch_count': 1, 'num_samples': 32, 'loss': 0.5,
'last_loss': 0.5, 'acc': 0.5, 'last_acc': 0.5}
>>> meter_collection.update({"loss": 0.1, "acc": 0.9}, n=32)
>>> meter_collection.summary()
{'batch_count': 2, 'num_samples': 64, 'loss': 0.3,
'last_loss': 0.1, 'acc': 0.7, 'last_acc': 0.9}
"""
def __init__(self):
self._batch_count = 0
@@ -183,6 +211,7 @@ class AverageMeterCollection:
self._meters = collections.defaultdict(AverageMeter)
def update(self, metrics, n=1):
"""Does one batch of updates for the provided metrics."""
self._batch_count += 1
self.n += n
for metric, value in metrics.items():