mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
adds swish activation function
This commit is contained in:
@@ -235,4 +235,48 @@ class SReLU(Layer):
|
||||
base_config = super(SReLU, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
get_custom_objects().update({'SReLU': SReLU})
|
||||
class Swish(Layer):
|
||||
""" Swish (Ramachandranet al., 2017)
|
||||
|
||||
# Input shape
|
||||
Arbitrary. Use the keyword argument `input_shape`
|
||||
(tuple of integers, does not include the samples axis)
|
||||
when using this layer as the first layer in a model.
|
||||
|
||||
# Output shape
|
||||
Same shape as the input.
|
||||
|
||||
# Arguments
|
||||
beta: float >= 0. Scaling factor
|
||||
if set to 1 and trainable set to False (default), Swish equals the SiLU activation (Elfwing et al., 2017)
|
||||
trainable: whether to learn the scaling factor during training or not
|
||||
|
||||
# References
|
||||
- [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
|
||||
- [Sigmoid-weighted linear units for neural network function approximation in reinforcement learning](https://arxiv.org/abs/1702.03118)
|
||||
"""
|
||||
|
||||
def __init__(self, beta=1.0, trainable=False, **kwargs):
|
||||
super(Swish, self).__init__(**kwargs)
|
||||
self.supports_masking = True
|
||||
self.beta = beta
|
||||
self.trainable = trainable
|
||||
|
||||
def build(self, input_shape):
|
||||
self.scaling_factor = K.variable(self.beta,
|
||||
dtype=K.floatx(),
|
||||
name='scaling_factor')
|
||||
if self.trainable:
|
||||
self._trainable_weights.append(self.scaling_factor)
|
||||
super(Swish, self).build(input_shape)
|
||||
|
||||
def call(self, inputs, mask=None):
|
||||
return inputs * K.sigmoid(self.scaling_factor * inputs)
|
||||
|
||||
def get_config(self):
|
||||
config = {'beta': self.get_weights()[0] if self.trainable else self.beta,
|
||||
'trainable': self.trainable}
|
||||
base_config = super(Swish, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
get_custom_objects().update({'SReLU': SReLU, 'Swish': Swish})
|
||||
|
||||
Reference in New Issue
Block a user