mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
223 lines
7.3 KiB
Python
223 lines
7.3 KiB
Python
import math

import torch
import torch.nn.functional as F
import torch.optim as optim
|
|
|
|
|
|
# TODO: In order to make this code faster:
|
|
# 1) Implement _extract_patches as a single cuda kernel
|
|
# 2) Compute the torch.symeig eigendecompositions in a separate process
|
|
# 3) Actually make a general KFAC optimizer so it fits PyTorch
|
|
|
|
|
|
def _extract_patches(x, kernel_size, stride, padding):
|
|
if padding[0] + padding[1] > 0:
|
|
x = F.pad(x, (padding[1], padding[1], padding[0],
|
|
padding[0])).data # Actually check dims
|
|
x = x.unfold(2, kernel_size[0], stride[0])
|
|
x = x.unfold(3, kernel_size[1], stride[1])
|
|
x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
|
|
x = x.view(
|
|
x.size(0), x.size(1), x.size(2), x.size(3) * x.size(4) * x.size(5))
|
|
return x
|
|
|
|
|
|
def compute_cov_a(a, classname, layer_info, fast_cnn):
    """Return the K-FAC input factor ``a^T (a / batch_size)`` for one layer.

    Args:
        a: the layer input; ``a.size(0)`` is taken as the batch size.
        classname: layer class name. 'Conv2d' inputs are turned into patch
            rows, 'AddBias' inputs are replaced by a column of ones, and any
            other class (e.g. 'Linear') uses ``a`` unchanged.
        layer_info: ``(kernel_size, stride, padding)`` for 'Conv2d'; ignored
            otherwise.
        fast_cnn: for 'Conv2d', average the patches over spatial positions
            (one row per sample) instead of keeping one row per position.

    Returns:
        A square ``(k, k)`` matrix, where ``k`` is the size of the last dim
        after the class-specific reshaping (``k == 1`` for 'AddBias').
    """
    batch_size = a.size(0)

    if classname == 'Conv2d':
        patches = _extract_patches(a, *layer_info)
        if fast_cnn:
            # (B, out_h, out_w, K) -> (B, K): mean over all spatial positions.
            a = patches.view(patches.size(0), -1, patches.size(-1)).mean(1)
        else:
            # One row per (sample, position), scaled down by out_h and out_w.
            # Out-of-place on purpose: when the permuted patch tensor is
            # already contiguous, contiguous() returns a view of the caller's
            # tensor, and an in-place divide would overwrite that tensor.
            a = (patches.view(-1, patches.size(-1))
                 / patches.size(1) / patches.size(2))
    elif classname == 'AddBias':
        # The bias acts on a constant input of 1. Build it with a's dtype
        # and device so the result matches the other branches.
        a = a.new(a.size(0), 1).fill_(1)

    return a.t() @ (a / batch_size)
|
|
|
|
|
|
def compute_cov_g(g, classname, layer_info, fast_cnn):
    """Return the K-FAC output-gradient factor for one layer.

    With ``B = g.size(0)`` on entry, ``g`` is reshaped per layer type into a
    2-D matrix, then ``g_ = g * B`` and the result is
    ``g_^T (g_ / rows)`` where ``rows`` is the row count after reshaping.

    Args:
        g: gradient w.r.t. the layer output; first dim is the batch.
        classname: 'Conv2d' and 'AddBias' gradients are reshaped as below;
            any other class (e.g. 'Linear') uses ``g`` unchanged.
        layer_info: unused; kept for signature parity with ``compute_cov_a``.
        fast_cnn: for 'Conv2d', sum over spatial positions (one row per
            sample) instead of keeping one row per position.

    Returns:
        A square ``(c, c)`` matrix, ``c`` being the column count after
        reshaping. The input ``g`` is never modified.
    """
    batch_size = g.size(0)

    if classname == 'Conv2d':
        if fast_cnn:
            # Sum over every spatial position; assumes (B, C, H, W) layout.
            g = g.view(g.size(0), g.size(1), -1).sum(-1)
        else:
            # (B, C, H, W) -> (B, H, W, C): one row per (sample, position),
            # scaled up by H and W.
            g_nhwc = g.transpose(1, 2).transpose(2, 3).contiguous()
            # Out-of-place on purpose: contiguous() returns a view of the
            # incoming gradient when it is already contiguous (e.g. C == 1),
            # and an in-place multiply would then overwrite the caller's data.
            g = (g_nhwc.view(-1, g_nhwc.size(-1))
                 * g_nhwc.size(1) * g_nhwc.size(2))
    elif classname == 'AddBias':
        g = g.view(g.size(0), g.size(1), -1).sum(-1)

    g_ = g * batch_size
    return g_.t() @ (g_ / g.size(0))
|
|
|
|
|
|
def update_running_stat(aa, m_aa, momentum):
    """Update ``m_aa`` in place to ``momentum * m_aa + (1 - momentum) * aa``.

    Args:
        aa: the new statistic; left unchanged.
        m_aa: running statistic, modified in place.
        momentum: decay factor. ``momentum == 1`` keeps ``m_aa`` unchanged.
    """
    if momentum == 1:
        # A decay of 1 freezes the statistic; the rescaling trick below
        # would divide by zero.
        return

    # Rescale first so a single in-place add suffices: aa stays unchanged
    # and no temporary tensor is allocated.
    m_aa *= momentum / (1 - momentum)
    m_aa += aa
    m_aa *= (1 - momentum)
|
|
|
|
|
|
class KFACOptimizer(optim.Optimizer):
    """K-FAC optimizer driven by hooks on the model's direct children.

    For every supported child ('Linear' or 'Conv2d' without bias, and
    'AddBias') a forward pre-hook keeps a running average of the input
    factor (``m_aa``) and a backward hook keeps one of the output-gradient
    factor (``m_gg``). ``step()`` eigendecomposes both factors every ``Tf``
    steps, preconditions each gradient with them, rescales all gradients by
    ``nu = min(1, sqrt(kl_clip / (lr**2 * sum(v * grad))))`` and then runs an
    internal SGD-with-momentum step.

    Input statistics are collected on steps that are multiples of ``Ts``
    (skipping volatile inputs). Gradient statistics are collected only while
    ``acc_stats`` is true; nothing in this class sets it to True, so the
    caller must toggle it around the relevant backward pass.
    """

    def __init__(self,
                 model,
                 lr=0.25,
                 momentum=0.9,
                 stat_decay=0.99,
                 kl_clip=0.001,
                 damping=1e-2,
                 weight_decay=0,
                 fast_cnn=False,
                 Ts=1,
                 Tf=10):
        """Wrap ``model`` and register the statistic hooks on its children.

        Args:
            model: module whose direct children are the layers to optimise.
            lr: learning rate.
            momentum: momentum of the internal SGD optimizer.
            stat_decay: decay of the running factor averages.
            kl_clip: bound used to compute the gradient rescaling ``nu``.
            damping: added (with ``weight_decay``) to every eigenvalue product.
            weight_decay: L2 coefficient added to the gradients in ``step()``.
            fast_cnn: use the spatially-pooled Conv2d factor approximation.
            Ts: collect input statistics every ``Ts`` steps.
            Tf: recompute the eigendecompositions every ``Tf`` steps.
        """
        defaults = dict()
        super(KFACOptimizer, self).__init__(model.parameters(), defaults)

        self.known_modules = {'Linear', 'Conv2d', 'AddBias'}

        self.modules = []
        self.grad_outputs = {}

        # Read by the backward hook. Off until the caller opts in; without
        # this the hook would raise AttributeError on any earlier backward.
        self.acc_stats = False

        self.model = model
        self._prepare_model()

        self.steps = 0

        # Running factors and their cached eigendecompositions, per module.
        self.m_aa, self.m_gg = {}, {}
        self.Q_a, self.Q_g = {}, {}
        self.d_a, self.d_g = {}, {}

        self.momentum = momentum
        self.stat_decay = stat_decay

        self.lr = lr
        self.kl_clip = kl_clip
        self.damping = damping
        self.weight_decay = weight_decay

        self.fast_cnn = fast_cnn

        self.Ts = Ts
        self.Tf = Tf

        # SGD's momentum buffer (no dampening) grows towards g / (1 - momentum)
        # for a constant gradient g, so lr * (1 - momentum) keeps the
        # effective step size at lr.
        self.optim = optim.SGD(
            model.parameters(),
            lr=self.lr * (1 - self.momentum),
            momentum=self.momentum)

    @staticmethod
    def _layer_info(module, classname):
        """Return ``(kernel_size, stride, padding)`` for Conv2d, else None."""
        if classname == 'Conv2d':
            return module.kernel_size, module.stride, module.padding
        return None

    def _save_input(self, module, input):
        """Forward pre-hook: fold the input factor into ``m_aa[module]``.

        Runs only for non-volatile inputs on steps that are multiples of
        ``Ts``. While ``steps == 0`` the running average is re-seeded from
        the current batch.
        """
        if input[0].volatile or self.steps % self.Ts != 0:
            return

        classname = module.__class__.__name__
        aa = compute_cov_a(input[0].data, classname,
                           self._layer_info(module, classname), self.fast_cnn)

        # Initialize buffers
        if self.steps == 0:
            self.m_aa[module] = aa.clone()

        update_running_stat(aa, self.m_aa[module], self.stat_decay)

    def _save_grad_output(self, module, grad_input, grad_output):
        """Backward hook: fold the output-gradient factor into ``m_gg[module]``.

        A no-op unless ``acc_stats`` is true. While ``steps == 0`` the
        running average is re-seeded from the current batch.
        """
        if not self.acc_stats:
            return

        classname = module.__class__.__name__
        gg = compute_cov_g(grad_output[0].data, classname,
                           self._layer_info(module, classname), self.fast_cnn)

        # Initialize buffers
        if self.steps == 0:
            self.m_gg[module] = gg.clone()

        update_running_stat(gg, self.m_gg[module], self.stat_decay)

    def _prepare_model(self):
        """Register the statistic hooks on every supported direct child.

        Raises:
            AssertionError: a Linear/Conv2d child still has its own bias.
            NotImplementedError: an unsupported child owns parameters.
        """
        for module in self.model.children():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \
                    "You must have a bias as a separate layer"

                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                module.register_backward_hook(self._save_grad_output)
            elif len(list(module.parameters())) > 0:
                raise NotImplementedError(
                    'Layer {} is not supported'.format(classname))

    def _update_eigenbasis(self, m):
        """Recompute and cache the eigendecompositions of m_aa[m] and m_gg[m]."""
        # My asynchronous implementation exists, I will add it later.
        # Experimenting with different ways to do this in PyTorch.
        d_a, Q_a = torch.symeig(self.m_aa[m].cpu().double(), eigenvectors=True)
        d_g, Q_g = torch.symeig(self.m_gg[m].cpu().double(), eigenvectors=True)

        d_a, Q_a = d_a.float(), Q_a.float()
        d_g, Q_g = d_g.float(), Q_g.float()
        if self.m_aa[m].is_cuda:
            d_a, Q_a = d_a.cuda(), Q_a.cuda()
            d_g, Q_g = d_g.cuda(), Q_g.cuda()

        # Zero out eigenvalues at or below 1e-6 (including negative ones);
        # the damping term keeps the later division finite.
        d_a.mul_((d_a > 1e-6).float())
        d_g.mul_((d_g > 1e-6).float())

        self.d_a[m], self.Q_a[m] = d_a, Q_a
        self.d_g[m], self.Q_g[m] = d_g, Q_g

    def _precondition(self, m, p, la):
        """Return ``p.grad`` preconditioned by module m's cached eigenbasis.

        Computes ``Q_g (Q_g^T G Q_a / (d_g d_a^T + la)) Q_a^T`` with ``G`` the
        gradient as a matrix, reshaped back to ``p.grad``'s size.
        """
        grad = p.grad.data
        if m.__class__.__name__ == 'Conv2d':
            # Flatten every dim after the first into the matrix columns.
            grad_mat = grad.view(grad.size(0), -1)
        else:
            grad_mat = grad

        v1 = self.Q_g[m].t() @ grad_mat @ self.Q_a[m]
        v2 = v1 / (
            self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la)
        v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
        return v.view(grad.size())

    def step(self):
        """Replace every gradient by its rescaled K-FAC direction and step SGD.

        Assumes ``p.grad`` is populated for every model parameter and that
        ``m_aa`` / ``m_gg`` hold statistics for every tracked module.
        """
        # Add weight decay
        if self.weight_decay > 0:
            for p in self.model.parameters():
                p.grad.data.add_(self.weight_decay, p.data)

        # Damping added to every eigenvalue product; constant for this step.
        la = self.damping + self.weight_decay

        updates = {}
        for m in self.modules:
            assert len(list(m.parameters())
                       ) == 1, "Can handle only one parameter at the moment"
            p = next(m.parameters())

            if self.steps % self.Tf == 0:
                self._update_eigenbasis(m)

            updates[p] = self._precondition(m, p, la)

        vg_sum = 0
        for p in self.model.parameters():
            vg_sum += (updates[p] * p.grad.data * self.lr * self.lr).sum()
        vg_sum = float(vg_sum)

        # vg_sum is 0 when every gradient is zero; dividing by it (or taking
        # the sqrt of a non-positive ratio) would fail, so leave nu at 1.
        nu = min(1, math.sqrt(self.kl_clip / vg_sum)) if vg_sum > 0 else 1

        for p in self.model.parameters():
            p.grad.data.copy_(updates[p])
            p.grad.data.mul_(nu)

        self.optim.step()
        self.steps += 1
|