mirror of
https://github.com/wassname/keras-contrib.git
synced 2026-06-27 16:10:11 +08:00
Update improved_wgan.py
Correctly sets the gradient_penalty hyperparameter to match the paper…
This commit is contained in:
@@ -55,7 +55,7 @@ def wasserstein_loss(y_true, y_pred):
|
||||
return K.mean(y_true * y_pred)
|
||||
|
||||
|
||||
-def gradient_penalty_loss(y_true, y_pred, averaged_samples):
+def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
|
||||
"""Calculates the gradient penalty loss for a batch of "averaged" samples.
|
||||
|
||||
In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss function
|
||||
@@ -74,7 +74,7 @@ def gradient_penalty_loss(y_true, y_pred, averaged_samples):
|
||||
gradients = K.gradients(y_pred, averaged_samples)
|
||||
gradients = K.concatenate([K.flatten(tensor) for tensor in gradients])
|
||||
gradient_l2_norm = K.sqrt(K.sum(K.square(gradients)))
|
||||
-gradient_penalty = 2 * GRADIENT_PENALTY_WEIGHT * K.square(1 - gradient_l2_norm)
+gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
|
||||
return gradient_penalty
|
||||
|
||||
|
||||
@@ -215,10 +215,12 @@ averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_
|
||||
# output for these samples - we're only running them to get the gradient norm for the gradient penalty loss.
|
||||
averaged_samples_out = discriminator(averaged_samples)
|
||||
|
||||
-# The gradient penalty loss function requires a list of trainable weights to get gradients on. However,
+# The gradient penalty loss function requires the input averaged samples to get gradients. However,
|
||||
# Keras loss functions can only have two arguments, y_true and y_pred. We get around this by making a partial()
|
||||
-# of the function with the discriminator's trainable weights here.
-partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=averaged_samples)
+# of the function with the averaged samples here.
+partial_gp_loss = partial(gradient_penalty_loss,
+                          averaged_samples=averaged_samples,
+                          gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
|
||||
partial_gp_loss.__name__ = 'gradient_penalty' # Functions need names or Keras will throw an error
|
||||
|
||||
# Keras requires that inputs and outputs have the same number of samples. This is why we didn't concatenate the
|
||||
|
||||
Reference in New Issue
Block a user