# Mirror of https://github.com/wassname/NALU-pytorch.git
# Synced 2026-06-27 16:00:06 +08:00
# 70 lines, 1.9 KiB, Python
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.init as init
|
|
import torch.nn.functional as F
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
class NeuralAccumulatorCell_exact(nn.Module):
    """A gated variant of the Neural Accumulator (NAC) cell [1].

    Two linear layers are applied to the same input: ``transform`` yields the
    candidate values ``x`` and ``gate`` yields, after a sigmoid, a gate ``g``
    with entries in (0, 1). The cell returns ``(1 - g) * x - g * x``, which is
    ``x`` scaled element-wise by ``1 - 2 * g``, a factor lying in (-1, 1).

    Attributes:
        in_dim: size of the input sample.
        out_dim: size of the output sample.
        gate: linear layer whose sigmoid-activated output is the gate ``g``.
        transform: linear layer producing the candidate values ``x``.

    Sources:
        [1]: https://arxiv.org/abs/1808.00508
    """

    def __init__(self, in_dim, out_dim):
        """Build the gate and transform layers.

        Args:
            in_dim: number of features in each input sample.
            out_dim: number of features in each output sample.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.gate = nn.Linear(in_dim, out_dim)
        self.transform = nn.Linear(in_dim, out_dim)

        # Only the weight matrices are (re)initialised here; the biases keep
        # whatever initialisation nn.Linear gave them.
        init.kaiming_uniform_(self.gate.weight, a=math.sqrt(5))
        init.kaiming_uniform_(self.transform.weight, a=math.sqrt(5))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the gated transform to ``input``.

        Args:
            input: tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of size ``out_dim``.
        """
        x = self.transform(input)
        # torch.sigmoid is the supported spelling; F.sigmoid is flagged as
        # deprecated by older PyTorch releases.
        g = torch.sigmoid(self.gate(input))
        # Kept in this exact form (rather than (1 - 2 * g) * x) so the
        # floating-point result is unchanged.
        return (1 - g) * x - g * x

    def extra_repr(self):
        """Return the dimension summary shown in the module's repr."""
        return f'in_dim={self.in_dim}, out_dim={self.out_dim}'
|
|
|
|
|
|
class NAC_exact(nn.Module):
    """Sequential stack of ``NeuralAccumulatorCell_exact`` layers.

    The first cell consumes ``in_dim`` features, the last cell emits
    ``out_dim`` features, and every connection in between is ``hidden_dim``
    wide. With a single layer the cell maps ``in_dim`` straight to
    ``out_dim``.

    Attributes:
        num_layers: the number of NAC layers.
        in_dim: the size of the input sample.
        hidden_dim: the size of the hidden layers.
        out_dim: the size of the output.
        model: the ``nn.Sequential`` container holding the cells in order.
    """

    def __init__(self, num_layers, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.num_layers = num_layers
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        final_idx = num_layers - 1
        # Only the first cell reads in_dim and only the last cell writes
        # out_dim; everything else is hidden_dim wide.
        cells = [
            NeuralAccumulatorCell_exact(
                in_dim if idx == 0 else hidden_dim,
                out_dim if idx == final_idx else hidden_dim,
            )
            for idx in range(num_layers)
        ]
        self.model = nn.Sequential(*cells)

    def forward(self, x):
        """Run ``x`` through every cell in order and return the result."""
        return self.model(x)
|