Update improved_wgan.py

Correctly sets the gradient_penalty hyperparameter to match the paper…
This commit is contained in:
Michael Oliver
2017-04-25 16:17:01 -07:00
committed by GitHub
parent ccc1be8857
commit 2d56791066
+7 -5
View File
@@ -55,7 +55,7 @@ def wasserstein_loss(y_true, y_pred):
return K.mean(y_true * y_pred)
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
"""Calculates the gradient penalty loss for a batch of "averaged" samples.
In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss function
@@ -74,7 +74,7 @@ def gradient_penalty_loss(y_true, y_pred, averaged_samples):
gradients = K.gradients(y_pred, averaged_samples)
gradients = K.concatenate([K.flatten(tensor) for tensor in gradients])
gradient_l2_norm = K.sqrt(K.sum(K.square(gradients)))
gradient_penalty = 2 * GRADIENT_PENALTY_WEIGHT * K.square(1 - gradient_l2_norm)
gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
return gradient_penalty
@@ -215,10 +215,12 @@ averaged_samples = RandomWeightedAverage()([real_samples, generated_samples_for_
# output for these samples - we're only running them to get the gradient norm for the gradient penalty loss.
averaged_samples_out = discriminator(averaged_samples)
# The gradient penalty loss function requires a list of trainable weights to get gradients on. However,
# The gradient penalty loss function requires the input averaged samples to get gradients. However,
# Keras loss functions can only have two arguments, y_true and y_pred. We get around this by making a partial()
# of the function with the discriminator's trainable weights here.
partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=averaged_samples)
# of the function with the averaged samples here.
partial_gp_loss = partial(gradient_penalty_loss,
averaged_samples=averaged_samples,
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
partial_gp_loss.__name__ = 'gradient_penalty' # Functions need names or Keras will throw an error
# Keras requires that inputs and outputs have the same number of samples. This is why we didn't concatenate the