mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
move constraints and initializers
This commit is contained in:
@@ -1,25 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from . import backend as K
|
||||
from keras.constraints import *
|
||||
|
||||
|
||||
class Clip(Constraint):
|
||||
"""Clips weights to [-c, c].
|
||||
|
||||
# Arguments
|
||||
c: Clipping parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, c=0.01):
|
||||
self.c = c
|
||||
|
||||
def __call__(self, p):
|
||||
return K.clip(p, -self.c, self.c)
|
||||
|
||||
def get_config(self):
|
||||
return {'name': self.__class__.__name__,
|
||||
'c': self.c}
|
||||
|
||||
# Aliases.
|
||||
|
||||
clip = Clip
|
||||
@@ -1,106 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
from . import backend as K
|
||||
from keras.initializers import *
|
||||
from keras.initializers import _compute_fans
|
||||
|
||||
|
||||
class ConvolutionAware(Initializer):
|
||||
"""
|
||||
Initializer that generates orthogonal convolution filters in the fourier
|
||||
space. If this initializer is passed a shape that is not 3D or 4D,
|
||||
orthogonal intialization will be used.
|
||||
# Arguments
|
||||
eps_std: Standard deviation for the random normal noise used to break
|
||||
symmetry in the inverse fourier transform.
|
||||
seed: A Python integer. Used to seed the random generator.
|
||||
# References
|
||||
Armen Aghajanyan, https://arxiv.org/abs/1702.06295
|
||||
"""
|
||||
|
||||
def __init__(self, eps_std=0.05, seed=None):
|
||||
self.eps_std = eps_std
|
||||
self.seed = seed
|
||||
self.orthogonal = Orthogonal()
|
||||
|
||||
def __call__(self, shape):
|
||||
rank = len(shape)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
|
||||
fan_in, fan_out = _compute_fans(shape, K.image_data_format())
|
||||
variance = 2 / fan_in
|
||||
|
||||
if rank == 3:
|
||||
row, stack_size, filters_size = shape
|
||||
|
||||
transpose_dimensions = (2, 1, 0)
|
||||
kernel_shape = (row,)
|
||||
correct_ifft = lambda shape, s=[None]: np.fft.irfft(shape, s[0])
|
||||
correct_fft = np.fft.rfft
|
||||
|
||||
elif rank == 4:
|
||||
row, column, stack_size, filters_size = shape
|
||||
|
||||
transpose_dimensions = (2, 3, 0, 1)
|
||||
kernel_shape = (row, column)
|
||||
correct_ifft = np.fft.irfft2
|
||||
correct_fft = np.fft.rfft2
|
||||
|
||||
elif rank == 5:
|
||||
x, y, z, stack_size, filters_size = shape
|
||||
|
||||
transpose_dimensions = (3, 4, 0, 1, 2)
|
||||
kernel_shape = (x, y, z)
|
||||
correct_fft = np.fft.rfftn
|
||||
correct_ifft = np.fft.irfftn
|
||||
else:
|
||||
return K.variable(self.orthogonal(shape), dtype=K.floatx())
|
||||
|
||||
kernel_fourier_shape = correct_fft(np.zeros(kernel_shape)).shape
|
||||
|
||||
init = []
|
||||
for i in range(filters_size):
|
||||
basis = self._create_basis(
|
||||
stack_size, np.prod(kernel_fourier_shape))
|
||||
basis = basis.reshape((stack_size,) + kernel_fourier_shape)
|
||||
|
||||
filters = [correct_ifft(x, kernel_shape) +
|
||||
np.random.normal(0, self.eps_std, kernel_shape) for
|
||||
x in basis]
|
||||
|
||||
init.append(filters)
|
||||
|
||||
# Format of array is now: filters, stack, row, column
|
||||
init = np.array(init)
|
||||
init = self._scale_filters(init, variance)
|
||||
return init.transpose(transpose_dimensions)
|
||||
|
||||
def _create_basis(self, filters, size):
|
||||
if size == 1:
|
||||
return np.random.normal(0.0, self.eps_std, (filters, size))
|
||||
|
||||
nbb = filters // size + 1
|
||||
li = []
|
||||
for i in range(nbb):
|
||||
a = np.random.normal(0.0, 1.0, (size, size))
|
||||
a = self._symmetrize(a)
|
||||
u, _, v = np.linalg.svd(a)
|
||||
li.extend(u.T.tolist())
|
||||
p = np.array(li[:filters], dtype=K.floatx())
|
||||
return p
|
||||
|
||||
def _symmetrize(self, a):
|
||||
return a + a.T - np.diag(a.diagonal())
|
||||
|
||||
def _scale_filters(self, filters, variance):
|
||||
c_var = np.var(filters)
|
||||
p = np.sqrt(variance / c_var)
|
||||
return filters * p
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'eps_std': self.eps_std,
|
||||
'seed': self.seed
|
||||
}
|
||||
Reference in New Issue
Block a user