mirror of
https://github.com/wassname/NALU-pytorch.git
synced 2026-06-27 16:00:06 +08:00
38 lines
900 B
Python
38 lines
900 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
# Maps each accepted activation name to the nn.Module class it builds.
# 'none' is special-cased in str2act and is deliberately absent here.
_ACTIVATIONS = {
    'hardtanh': nn.Hardtanh,
    'sigmoid': nn.Sigmoid,
    'relu6': nn.ReLU6,
    'tanh': nn.Tanh,
    'tanhshrink': nn.Tanhshrink,
    'hardshrink': nn.Hardshrink,
    'leakyrelu': nn.LeakyReLU,
    'softshrink': nn.Softshrink,
    'softsign': nn.Softsign,
    'relu': nn.ReLU,
    'prelu': nn.PReLU,
    'softplus': nn.Softplus,
    'elu': nn.ELU,
    'selu': nn.SELU,
}


def str2act(s):
    """Build an activation module from its lowercase name.

    Args:
        s: Activation name, e.g. ``'relu'``, ``'tanh'`` or ``'prelu'``.
            The special name ``'none'`` means "no activation".

    Returns:
        A newly constructed ``torch.nn`` activation module with default
        arguments, or ``None`` when ``s == 'none'``. A new instance is
        created on every call, so modules with parameters (``PReLU``)
        are never shared between callers.

    Raises:
        ValueError: If ``s`` is not a recognised activation name.
    """
    # Compare by value, not identity: the original `s is '...'` checks only
    # worked for interned strings and silently rejected names built at
    # runtime (argparse, config files, concatenation).
    if s == 'none':
        return None
    try:
        activation_cls = _ACTIVATIONS[s]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which the
        # original also rejected with ValueError.
        raise ValueError("[!] Invalid activation function.") from None
    return activation_cls()
|