mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Update ftml for Keras changes to optimizers (#137)
* Update ftml for Keras changes to optimizers * fix import
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from keras.optimizers import Optimizer
|
||||
from .. import backend as K
|
||||
from keras.utils.generic_utils import get_custom_objects
|
||||
from keras.legacy import interfaces
|
||||
|
||||
|
||||
class FTML(Optimizer):
|
||||
@@ -30,7 +31,8 @@ class FTML(Optimizer):
|
||||
self.epsilon = epsilon
|
||||
self.inital_decay = decay
|
||||
|
||||
def get_updates(self, params, constraints, loss):
|
||||
@interfaces.legacy_get_updates_support
|
||||
def get_updates(self, loss, params):
|
||||
grads = self.get_gradients(loss, params)
|
||||
self.updates = [K.update_add(self.iterations, 1)]
|
||||
|
||||
@@ -61,10 +63,11 @@ class FTML(Optimizer):
|
||||
self.updates.append(K.update(d, d_t))
|
||||
|
||||
new_p = p_t
|
||||
# apply constraints
|
||||
if p in constraints:
|
||||
c = constraints[p]
|
||||
new_p = c(new_p)
|
||||
|
||||
# Apply constraints.
|
||||
if getattr(p, 'constraint', None) is not None:
|
||||
new_p = p.constraint(new_p)
|
||||
|
||||
self.updates.append(K.update(p, new_p))
|
||||
return self.updates
|
||||
|
||||
|
||||
Reference in New Issue
Block a user