mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Merge pull request #181 from valentindey/master
Adds swish activation function
This commit is contained in:
@@ -236,3 +236,50 @@ class SReLU(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
get_custom_objects().update({'SReLU': SReLU})
|
||||
|
||||
|
||||
class Swish(Layer):
    """Swish activation (Ramachandran et al., 2017).

    Computes `inputs * sigmoid(beta * inputs)` element-wise.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as the input.

    # Arguments
        beta: float >= 0. Scaling factor applied to the input inside the
            sigmoid. If set to 1 and `trainable` is False (the default),
            Swish equals the SiLU activation (Elfwing et al., 2017).
        trainable: whether to learn the scaling factor during training
            or not.

    # References
        - [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
        - [Sigmoid-weighted linear units for neural network function approximation in reinforcement learning](https://arxiv.org/abs/1702.03118)
    """

    def __init__(self, beta=1.0, trainable=False, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.supports_masking = True
        self.beta = beta
        self.trainable = trainable

    def build(self, input_shape):
        """Create the backend variable holding the scaling factor."""
        self.scaling_factor = K.variable(self.beta,
                                         dtype=K.floatx(),
                                         name='scaling_factor')
        # The variable is only registered as a weight when it should be
        # learned; otherwise it acts as a constant initialised to `beta`.
        if self.trainable:
            self._trainable_weights.append(self.scaling_factor)
        super(Swish, self).build(input_shape)

    def call(self, inputs, mask=None):
        """Apply `inputs * sigmoid(scaling_factor * inputs)`."""
        return inputs * K.sigmoid(self.scaling_factor * inputs)

    def get_config(self):
        """Return the layer configuration as a dict.

        For a trainable layer that has been built, `beta` is the current
        value of the scaling factor, converted to a plain Python float so
        the config does not contain a NumPy array. Before `build` has run
        (no `scaling_factor` yet), or for a non-trainable layer, the
        `beta` passed to the constructor is reported instead.
        """
        scaling_factor = getattr(self, 'scaling_factor', None)
        if self.trainable and scaling_factor is not None:
            beta = float(K.get_value(scaling_factor))
        else:
            beta = self.beta
        config = {'beta': beta,
                  'trainable': self.trainable}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'Swish': Swish})
|
||||
|
||||
@@ -26,5 +26,18 @@ def test_srelu_share():
|
||||
layer_test(advanced_activations.SReLU, kwargs={'shared_axes': 1},
|
||||
input_shape=(2, 3, 4))
|
||||
|
||||
|
||||
@keras_test
def test_swish_constant():
    """Run `layer_test` on Swish with a fixed (non-trainable) beta of 1.0."""
    swish_kwargs = {'beta': 1.0, 'trainable': False}
    layer_test(advanced_activations.Swish,
               kwargs=swish_kwargs,
               input_shape=(2, 3, 4))
|
||||
|
||||
|
||||
@keras_test
def test_swish_trainable():
    """Run `layer_test` on Swish with a trainable beta initialised to 1.0."""
    swish_kwargs = {'beta': 1.0, 'trainable': True}
    layer_test(advanced_activations.Swish,
               kwargs=swish_kwargs,
               input_shape=(2, 3, 4))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Running this module directly hands this file to pytest.
    pytest.main(args=[__file__])
|
||||
|
||||
Reference in New Issue
Block a user