# Repository listing: Files
# Commit 7401266fe7 by Kolesnikov Sergey: "pytorch version"
# Committed 2017-11-15 22:18:46 +03:00
#
# Original listing: 43 lines, 1.3 KiB, Python

from collections import OrderedDict
from itertools import tee
import torch
import torch.nn as nn
from common.modules.LayerNorm import LayerNorm
def pairwise(iterable):
    """Return an iterator over overlapping consecutive pairs of ``iterable``.

    s -> (s0, s1), (s1, s2), (s2, s3), ...
    """
    current, following = tee(iterable, 2)
    # Discard the first item of the second copy (a no-op when it is empty)
    # so that it always runs one element ahead of the first copy.
    next(following, None)
    return zip(current, following)
class LinearNet(nn.Module):
def __init__(self, layers, activation=torch.nn.ELU,
layer_norm=False, linear_layer=nn.Linear):
super(LinearNet, self).__init__()
self.input_shape = layers[0]
self.output_shape = layers[-1]
if layer_norm:
layer_fn = lambda layer: [
("linear_{}".format(layer[0]), linear_layer(layer[1][0], layer[1][1])),
("layer_norm_{}".format(layer[0]), LayerNorm(layer[1][1])),
("act_{}".format(layer[0]), activation())]
else:
layer_fn = lambda layer: [
("linear_{}".format(layer[0]), linear_layer(layer[1][0], layer[1][1])),
("act_{}".format(layer[0]), activation())]
self.net = torch.nn.Sequential(
OrderedDict([
x for y in map(
lambda layer: layer_fn(layer),
enumerate(pairwise(layers))) for x in y]))
def forward(self, x):
x = self.net.forward(x)
return x