mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:49:04 +08:00
[SGD][Docs] docs for training/ validation results (#10181)
This commit is contained in:
@@ -156,7 +156,17 @@ def find_free_port():
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value."""
|
||||
"""Utility for computing and storing the average and most recent value.
|
||||
|
||||
Example:
|
||||
>>> meter = AverageMeter()
|
||||
>>> meter.update(5)
|
||||
>>> meter.val, meter.avg, meter.sum
|
||||
(5, 5.0, 5)
|
||||
>>> meter.update(10, n=4)
|
||||
>>> meter.val, meter.avg, meter.sum
|
||||
(10, 9.0, 45)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -168,6 +178,7 @@ class AverageMeter:
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
"""Update current value, total sum, and average."""
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
@@ -175,7 +186,24 @@ class AverageMeter:
|
||||
|
||||
|
||||
class AverageMeterCollection:
|
||||
"""A grouping of AverageMeters."""
|
||||
"""A grouping of AverageMeters.
|
||||
|
||||
This utility is used in TrainingOperator.train_epoch and
|
||||
TrainingOperator.validate to
|
||||
collect averages and most recent value across all batches. One
|
||||
AverageMeter object is used for each metric.
|
||||
|
||||
Example:
|
||||
>>> meter_collection = AverageMeterCollection()
|
||||
>>> meter_collection.update({"loss": 0.5, "acc": 0.5}, n=32)
|
||||
>>> meter_collection.summary()
|
||||
{'batch_count': 1, 'num_samples': 32, 'loss': 0.5,
|
||||
'last_loss': 0.5, 'acc': 0.5, 'last_acc': 0.5}
|
||||
>>> meter_collection.update({"loss": 0.1, "acc": 0.9}, n=32)
|
||||
>>> meter_collection.summary()
|
||||
{'batch_count': 2, 'num_samples': 64, 'loss': 0.3,
|
||||
'last_loss': 0.1, 'acc': 0.7, 'last_acc': 0.9}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._batch_count = 0
|
||||
@@ -183,6 +211,7 @@ class AverageMeterCollection:
|
||||
self._meters = collections.defaultdict(AverageMeter)
|
||||
|
||||
def update(self, metrics, n=1):
|
||||
"""Does one batch of updates for the provided metrics."""
|
||||
self._batch_count += 1
|
||||
self.n += n
|
||||
for metric, value in metrics.items():
|
||||
|
||||
Reference in New Issue
Block a user