Files
keras-contrib/keras_contrib/optimizers/ftml.py
T
Shuai Zheng ebd568283a FTML optimizer (#110)
* add ftml.py and fix __init__.py

* fix PEP8
2017-06-16 11:28:43 -07:00

82 lines
2.8 KiB
Python

from __future__ import absolute_import
from keras.optimizers import Optimizer
from .. import backend as K
class FTML(Optimizer):
    """FTML (Follow the Moving Leader) optimizer.

    Keeps three slot tensors per parameter (`z`, `v`, `d`) and, at every
    step, sets the parameter to `-z_t / d_t`.

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 0.5.
        beta_2: float, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor.
        decay: float >= 0. Learning rate decay over each update.

    # References
        - [FTML - Follow the Moving Leader in Deep Learning](http://www.cse.ust.hk/~szhengac/papers/icml17.pdf)
    """

    def __init__(self, lr=0.0025, beta_1=0.6, beta_2=0.999,
                 epsilon=1e-8, decay=0., **kwargs):
        """Create the optimizer state.

        `lr`, `beta_1`, `beta_2` and `decay` are wrapped in backend
        variables; `epsilon` is kept as the plain value that was passed in.
        Extra keyword arguments are forwarded to `Optimizer.__init__`.
        """
        super(FTML, self).__init__(**kwargs)
        # Assign explicitly instead of `self.__dict__.update(locals())`, which
        # would also stash `self` and `kwargs` on the instance.
        self.epsilon = epsilon
        self.iterations = K.variable(0)
        self.lr = K.variable(lr)
        self.beta_1 = K.variable(beta_1)
        self.beta_2 = K.variable(beta_2)
        self.decay = K.variable(decay)
        # Plain Python number, so the `> 0` test in `get_updates` is an
        # ordinary Python comparison rather than a backend op.
        self.initial_decay = decay

    def get_updates(self, params, constraints, loss):
        """Build the update ops for one optimization step.

        # Arguments
            params: variables to optimize.
            constraints: mapping from a parameter to a callable that is
                applied to its new value before it is written back.
            loss: passed to `self.get_gradients` together with `params`.

        # Returns
            The list of update ops, also stored in `self.updates`.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            # Plain multiplication (not `*=`): builds a new expression and
            # never attempts an in-place operation on the `lr` variable.
            lr = lr * (1. / (1. + self.decay * self.iterations))

        t = self.iterations + 1
        # Step size divided by the beta_1 correction term (1 - beta_1^t).
        lr_t = lr / (1. - K.pow(self.beta_1, t))

        shapes = [K.get_variable_shape(p) for p in params]
        zs = [K.zeros(shape) for shape in shapes]
        vs = [K.zeros(shape) for shape in shapes]
        ds = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + zs + vs + ds

        for p, g, z, v, d in zip(params, grads, zs, vs, ds):
            # Exponential moving average of the squared gradient.
            v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
            # Per-coordinate denominator built from the beta_2-corrected
            # second moment, epsilon and the step size.
            d_t = (K.sqrt(v_t / (1. - K.pow(self.beta_2, t))) +
                   self.epsilon) / lr_t
            sigma_t = d_t - self.beta_1 * d
            z_t = self.beta_1 * z + (1. - self.beta_1) * g - sigma_t * p
            p_t = - z_t / d_t

            self.updates.append(K.update(z, z_t))
            self.updates.append(K.update(v, v_t))
            self.updates.append(K.update(d, d_t))

            new_p = p_t
            # apply constraints
            if p in constraints:
                c = constraints[p]
                new_p = c(new_p)
            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration as a dict.

        The base-class config is extended with the current values of `lr`,
        `beta_1`, `beta_2` and `decay` (read from their backend variables)
        and with `epsilon`. Keys defined here win over base-class keys.
        """
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon}
        base_config = super(FTML, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Lowercase alias: `ftml` is bound to the same class object as `FTML`.
ftml = FTML