mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
committed by
Michael Oliver
parent
a466050b5b
commit
ebd568283a
@@ -0,0 +1 @@
|
||||
from .ftml import FTML
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import absolute_import
|
||||
from keras.optimizers import Optimizer
|
||||
from .. import backend as K
|
||||
|
||||
|
||||
class FTML(Optimizer):
|
||||
"""FTML optimizer.
|
||||
|
||||
# Arguments
|
||||
lr: float >= 0. Learning rate.
|
||||
beta_1: float, 0 < beta < 1. Generally close to 0.5.
|
||||
beta_2: float, 0 < beta < 1. Generally close to 1.
|
||||
epsilon: float >= 0. Fuzz factor.
|
||||
decay: float >= 0. Learning rate decay over each update.
|
||||
|
||||
# References
|
||||
- [FTML - Follow the Moving Leader in Deep Learning](http://www.cse.ust.hk/~szhengac/papers/icml17.pdf)
|
||||
"""
|
||||
|
||||
def __init__(self, lr=0.0025, beta_1=0.6, beta_2=0.999,
|
||||
epsilon=1e-8, decay=0., **kwargs):
|
||||
super(FTML, self).__init__(**kwargs)
|
||||
self.__dict__.update(locals())
|
||||
self.iterations = K.variable(0)
|
||||
self.lr = K.variable(lr)
|
||||
self.beta_1 = K.variable(beta_1)
|
||||
self.beta_2 = K.variable(beta_2)
|
||||
self.decay = K.variable(decay)
|
||||
self.inital_decay = decay
|
||||
|
||||
def get_updates(self, params, constraints, loss):
|
||||
grads = self.get_gradients(loss, params)
|
||||
self.updates = [K.update_add(self.iterations, 1)]
|
||||
|
||||
lr = self.lr
|
||||
if self.inital_decay > 0:
|
||||
lr *= (1. / (1. + self.decay * self.iterations))
|
||||
|
||||
t = self.iterations + 1
|
||||
|
||||
lr_t = lr / (1. - K.pow(self.beta_1, t))
|
||||
|
||||
shapes = [K.get_variable_shape(p) for p in params]
|
||||
zs = [K.zeros(shape) for shape in shapes]
|
||||
vs = [K.zeros(shape) for shape in shapes]
|
||||
ds = [K.zeros(shape) for shape in shapes]
|
||||
self.weights = [self.iterations] + zs + vs + ds
|
||||
|
||||
for p, g, z, v, d in zip(params, grads, zs, vs, ds):
|
||||
v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
|
||||
d_t = (K.sqrt(v_t / (1. - K.pow(self.beta_2, t))) + self.epsilon) / lr_t
|
||||
sigma_t = d_t - self.beta_1 * d
|
||||
z_t = self.beta_1 * z + (1. - self.beta_1) * g - sigma_t * p
|
||||
|
||||
p_t = - z_t / d_t
|
||||
|
||||
self.updates.append(K.update(z, z_t))
|
||||
self.updates.append(K.update(v, v_t))
|
||||
self.updates.append(K.update(d, d_t))
|
||||
|
||||
new_p = p_t
|
||||
# apply constraints
|
||||
if p in constraints:
|
||||
c = constraints[p]
|
||||
new_p = c(new_p)
|
||||
self.updates.append(K.update(p, new_p))
|
||||
return self.updates
|
||||
|
||||
def get_config(self):
|
||||
config = {'lr': float(K.get_value(self.lr)),
|
||||
'beta_1': float(K.get_value(self.beta_1)),
|
||||
'beta_2': float(K.get_value(self.beta_2)),
|
||||
'decay': float(K.get_value(self.decay)),
|
||||
'epsilon': self.epsilon}
|
||||
base_config = super(FTML, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
# Aliases.
|
||||
|
||||
ftml = FTML
|
||||
Reference in New Issue
Block a user