mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
ebd568283a
* add ftml.py and fix __init__.py * fix PEP8
82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
from __future__ import absolute_import
|
|
from keras.optimizers import Optimizer
|
|
from .. import backend as K
|
|
|
|
|
|
class FTML(Optimizer):
    """FTML optimizer.

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 0.5.
        beta_2: float, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor.
        decay: float >= 0. Learning rate decay over each update.
        **kwargs: forwarded unchanged to the base `Optimizer` constructor.

    # References
        - [FTML - Follow the Moving Leader in Deep Learning](http://www.cse.ust.hk/~szhengac/papers/icml17.pdf)
    """

    def __init__(self, lr=0.0025, beta_1=0.6, beta_2=0.999,
                 epsilon=1e-8, decay=0., **kwargs):
        super(FTML, self).__init__(**kwargs)
        # Set attributes explicitly: dumping `locals()` into `__dict__`
        # would also attach `self` (a reference cycle) and `kwargs`.
        self.epsilon = epsilon
        self.iterations = K.variable(0)
        self.lr = K.variable(lr)
        self.beta_1 = K.variable(beta_1)
        self.beta_2 = K.variable(beta_2)
        self.decay = K.variable(decay)
        # Plain Python float, so `get_updates` can decide at graph-build
        # time whether learning-rate decay is needed at all.
        self.initial_decay = decay
        # Historical misspelling, kept so existing readers of the old
        # attribute name keep working.
        self.inital_decay = decay

    def get_updates(self, params, constraints, loss):
        """Build the update operations for one optimization step.

        # Arguments
            params: list of variables to optimize.
            constraints: mapping from a parameter to a callable that is
                applied to that parameter's new value before assignment.
            loss: loss tensor whose gradients with respect to `params`
                drive the update.

        # Returns
            The list of update operations (also stored in `self.updates`).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            # Rebind to a new expression instead of using `*=`, so the
            # stored `self.lr` variable is only ever read here.
            lr = lr * (1. / (1. + self.decay * self.iterations))

        t = self.iterations + 1

        # Both correction terms depend only on the step count, so they
        # are built once rather than once per parameter.
        lr_t = lr / (1. - K.pow(self.beta_1, t))
        beta_2_correction = 1. - K.pow(self.beta_2, t)

        # Three zero-initialised state tensors per parameter.
        shapes = [K.get_variable_shape(p) for p in params]
        zs = [K.zeros(shape) for shape in shapes]
        vs = [K.zeros(shape) for shape in shapes]
        ds = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + zs + vs + ds

        for p, g, z, v, d in zip(params, grads, zs, vs, ds):
            # Exponential moving average of the squared gradient.
            v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
            d_t = (K.sqrt(v_t / beta_2_correction) + self.epsilon) / lr_t
            sigma_t = d_t - self.beta_1 * d
            z_t = self.beta_1 * z + (1. - self.beta_1) * g - sigma_t * p

            p_t = - z_t / d_t

            self.updates.append(K.update(z, z_t))
            self.updates.append(K.update(v, v_t))
            self.updates.append(K.update(d, d_t))

            new_p = p_t
            # apply constraints
            if p in constraints:
                c = constraints[p]
                new_p = c(new_p)
            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration as a plain dict.

        The base class configuration is merged with this optimizer's
        hyperparameters; on a key clash the values defined here win.
        """
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon}
        base_config = super(FTML, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
|
|
|
|
|
|
# Aliases.
|
|
|
|
# Lowercase name bound to the same class object as `FTML`.
ftml = FTML
|