mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:07:41 +08:00
c2ade075a3
Implements distributed SGD using distributed PyTorch.
241 lines
6.3 KiB
Python
241 lines
6.3 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from collections import namedtuple
|
|
from contextlib import closing
|
|
import numpy as np
|
|
import socket
|
|
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
def train(train_iterator, model, criterion, optimizer):
    """Run one training epoch over ``train_iterator``.

    Args:
        train_iterator: iterable yielding ``(features, target)`` batches.
        model: callable applied to each batch of features; ``model.train()``
            is invoked once before the first batch.
        criterion: callable invoked as ``criterion(output, target)`` to
            produce the loss.
        optimizer: object whose ``zero_grad()`` and ``step()`` are called
            once per batch.

    Returns:
        dict: ``batch_time``, ``train_loss`` and ``data_time`` averages,
        ``batch_processed`` (the sum of ``features.size(0)`` over all
        batches), and the windowed mean duration of each of the
        ``d2h``, ``fwd``, ``grad`` and ``apply`` phases.
    """
    step_meter = AverageMeter()
    load_meter = AverageMeter()
    loss_meter = AverageMeter()

    phase_timers = {
        phase: TimerStat() for phase in ["d2h", "fwd", "grad", "apply"]
    }

    # Put the model into training mode before touching any batch.
    model.train()

    tick = time.time()

    for features, target in train_iterator:
        # Time spent waiting on the iterator for this batch.
        load_meter.update(time.time() - tick)

        # Move the batch to the GPU (when one is available) without blocking.
        with phase_timers["d2h"]:
            if torch.cuda.is_available():
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

        # Forward pass and loss bookkeeping.
        with phase_timers["fwd"]:
            output = model(features)
            loss = criterion(output, target)

            # Weight the running loss by the number of rows in the batch.
            loss_meter.update(loss.item(), features.size(0))

        # Backward pass: clear stale gradients, then compute fresh ones.
        with phase_timers["grad"]:
            optimizer.zero_grad()
            loss.backward()

        # Apply the gradients to the model parameters.
        with phase_timers["apply"]:
            optimizer.step()

        # Wall-clock time for the whole batch, including data loading.
        step_meter.update(time.time() - tick)
        tick = time.time()

    stats = {
        "batch_time": step_meter.avg,
        "batch_processed": loss_meter.count,
        "train_loss": loss_meter.avg,
        "data_time": load_meter.avg,
    }
    for phase, timer in phase_timers.items():
        stats[phase] = timer.mean
    return stats
|
|
|
|
|
|
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on every batch of ``val_loader``.

    The model is switched to eval mode and the whole loop runs under
    ``torch.no_grad()``.

    Args:
        val_loader: iterable yielding ``(features, target)`` batches.
        model: callable applied to each batch of features.
        criterion: callable invoked as ``criterion(output, target)`` to
            produce the loss.

    Returns:
        dict: ``batch_time`` (average wall-clock time per batch) and
        ``validation_loss`` (loss averaged with ``features.size(0)`` as
        the per-batch weight).
    """
    step_meter = AverageMeter()
    loss_meter = AverageMeter()

    # Put the model into evaluation mode.
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for features, target in val_loader:
            if torch.cuda.is_available():
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # Forward pass only; no gradients are recorded here.
            output = model(features)
            loss = criterion(output, target)

            # Weight the running loss by the number of rows in the batch.
            loss_meter.update(loss.item(), features.size(0))

            # Wall-clock time for this batch.
            step_meter.update(time.time() - tick)
            tick = time.time()

    return {
        "batch_time": step_meter.avg,
        "validation_loss": loss_meter.avg,
    }
|
|
|
|
|
|
class TimerStat(object):
    """A running stat for conveniently logging the duration of a code block.

    Durations are kept in a sliding window holding the most recent
    ``window_size`` samples. The aggregate properties (``mean``,
    ``median``, ``sum``, ``max``, ``first``, ``last``, ``size``) are
    computed over that window only, whereas ``count`` keeps increasing
    with every pushed sample until ``reset()`` is called.

    Note that this class is *not* thread-safe.

    Examples:
        Time a call to 'time.sleep'.

        >>> import time
        >>> sleep_timer = TimerStat()
        >>> with sleep_timer:
        ...     time.sleep(1)
        >>> round(sleep_timer.mean)
        1
    """

    def __init__(self, window_size=10):
        """Create an empty timer.

        Args:
            window_size (int): maximum number of recent samples (and of
                recent units-processed entries) that are retained.
        """
        self._window_size = window_size
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0

    def __enter__(self):
        """Start timing; returns ``self`` so ``with TimerStat() as t`` works."""
        assert self._start_time is None, "concurrent updates not supported"
        self._start_time = time.time()
        return self

    def __exit__(self, type, value, tb):
        """Stop timing and record the elapsed time as a new sample.

        The sample is recorded whether or not the block raised; the
        implicit ``None`` return lets any exception propagate.
        """
        assert self._start_time is not None
        time_delta = time.time() - self._start_time
        self.push(time_delta)
        self._start_time = None

    def push(self, time_delta):
        """Record a duration, evicting the oldest sample if the window is full."""
        self._samples.append(time_delta)
        if len(self._samples) > self._window_size:
            self._samples.pop(0)
        self.count += 1
        self._total_time += time_delta

    def push_units_processed(self, n):
        """Record a units-processed figure, evicting the oldest if needed."""
        self._units_processed.append(n)
        if len(self._units_processed) > self._window_size:
            self._units_processed.pop(0)

    @property
    def mean(self):
        """Mean of the windowed samples, or 0.0 when there are none."""
        # np.mean([]) would emit a RuntimeWarning and return NaN.
        if not self._samples:
            return 0.0
        return np.mean(self._samples)

    @property
    def median(self):
        """Median of the windowed samples."""
        return np.median(self._samples)

    @property
    def sum(self):
        """Sum of the windowed samples."""
        return np.sum(self._samples)

    @property
    def max(self):
        """Largest windowed sample."""
        return np.max(self._samples)

    @property
    def first(self):
        """Oldest windowed sample, or None when there are none."""
        return self._samples[0] if self._samples else None

    @property
    def last(self):
        """Newest windowed sample, or None when there are none."""
        return self._samples[-1] if self._samples else None

    @property
    def size(self):
        """Number of samples currently held in the window."""
        return len(self._samples)

    @property
    def mean_units_processed(self):
        """Mean of the windowed units-processed entries, or 0.0 if empty."""
        # Same NaN/RuntimeWarning hazard as in ``mean``.
        if not self._units_processed:
            return 0.0
        return float(np.mean(self._units_processed))

    @property
    def mean_throughput(self):
        """Windowed units processed per unit of windowed time.

        Returns 0.0 when the windowed time is zero (including when no
        samples have been recorded).
        """
        # ``sum`` below is the builtin: names bound in the class body
        # (such as the ``sum`` property) are not visible inside methods.
        time_total = sum(self._samples)
        if not time_total:
            return 0.0
        return sum(self._units_processed) / time_total

    def reset(self):
        """Discard all samples and counters, returning to the initial state."""
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0
|
|
|
|
|
|
def find_free_port():
    """Return a TCP port number the OS reports as free at call time.

    A temporary IPv4 stream socket is bound to port 0, its assigned
    port is read back, and the socket is closed before returning.
    The port is therefore not held for the caller after this returns.

    Returns:
        int: the port number reported by ``getsockname()``.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock):
        sock.bind(("", 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = sock.getsockname()
        return port
|
|
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are plain fields updated by ``update()``:
    ``val`` is the most recent value, ``sum`` the weighted running
    total, ``count`` the total weight, and ``avg`` equals
    ``sum / count`` (or 0 while ``count`` is 0).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold ``val`` into the running average with weight ``n``.

        Args:
            val: the new observation.
            n: weight of the observation (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With zero accumulated weight (e.g. the first batch was empty)
        # there is no mean to compute; report 0 instead of dividing by 0.
        self.avg = self.sum / self.count if self.count else 0
|
|
|
|
|
|
_ResourcesBase = namedtuple("Resources", ["num_cpus", "num_gpus", "resources"])


class Resources(_ResourcesBase):
    """Immutable ``(num_cpus, num_gpus, resources)`` record with defaults.

    Defaults are ``num_cpus=1``, ``num_gpus=0`` and a fresh empty dict
    for ``resources`` when it is omitted or passed as ``None``.
    """

    __slots__ = ()

    def __new__(cls, num_cpus=1, num_gpus=0, resources=None):
        # Build a new dict per instance rather than sharing one default.
        extra = {} if resources is None else resources
        return super(Resources, cls).__new__(
            cls, num_cpus=num_cpus, num_gpus=num_gpus, resources=extra)
|
|
|
|
|
|
def sgd_mse_optimizer(model, config):
    """Build an MSE loss and an SGD optimizer for ``model``.

    Args:
        model (torch.nn.Module): the model whose parameters are optimized.
        config (dict): optimizer configuration.
            lr (float): the learning rate; 0.01 is used when absent.

    Returns:
        tuple: ``(criterion, optimizer)`` where ``criterion`` is an
        ``nn.MSELoss`` and ``optimizer`` is a ``torch.optim.SGD`` over
        ``model.parameters()``.
    """
    mse_loss = nn.MSELoss()
    sgd = torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.01))
    return mse_loss, sgd
|