mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 19:45:48 +08:00
43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
from collections import OrderedDict
|
|
from itertools import tee
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from common.modules.LayerNorm import LayerNorm
|
|
|
|
|
|
def pairwise(iterable):
|
|
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
|
|
a, b = tee(iterable)
|
|
next(b, None)
|
|
return zip(a, b)
|
|
|
|
|
|
class LinearNet(nn.Module):
|
|
def __init__(self, layers, activation=torch.nn.ELU,
|
|
layer_norm=False, linear_layer=nn.Linear):
|
|
super(LinearNet, self).__init__()
|
|
self.input_shape = layers[0]
|
|
self.output_shape = layers[-1]
|
|
|
|
if layer_norm:
|
|
layer_fn = lambda layer: [
|
|
("linear_{}".format(layer[0]), linear_layer(layer[1][0], layer[1][1])),
|
|
("layer_norm_{}".format(layer[0]), LayerNorm(layer[1][1])),
|
|
("act_{}".format(layer[0]), activation())]
|
|
else:
|
|
layer_fn = lambda layer: [
|
|
("linear_{}".format(layer[0]), linear_layer(layer[1][0], layer[1][1])),
|
|
("act_{}".format(layer[0]), activation())]
|
|
|
|
self.net = torch.nn.Sequential(
|
|
OrderedDict([
|
|
x for y in map(
|
|
lambda layer: layer_fn(layer),
|
|
enumerate(pairwise(layers))) for x in y]))
|
|
|
|
def forward(self, x):
|
|
x = self.net.forward(x)
|
|
return x
|