
Improved Training of Wasserstein GANs

I really like the WGAN-GP paper. But I think the theory in the paper scared off a lot of people, which is a bit of a shame because it's quite cool.
This is my attempt to make the paper more accessible!

You may want to download the paper yourself, especially if you want more of the theoretical details. To aid anyone who takes me up on this, the section names in this post will match the ones in the paper.

Abstract

 

  • The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only poor samples or fail to converge. We find that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. 

  • We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input. Our proposed method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures with almost no hyperparameter tuning.

Introduction

The Wasserstein distance produces a value function with better theoretical properties than the original GAN value function. WGAN requires the discriminator (called the critic in that work) to lie within the space of 1-Lipschitz functions, which the authors enforce through weight clipping.

Our contributions are as follows:

  1. On toy datasets, we demonstrate how critic weight clipping can lead to undesired behavior.

  2. We propose a gradient penalty (WGAN-GP), which does not suffer from the same problems.

  3. We demonstrate stable training of varied GAN architectures, performance improvements over weight clipping, high-quality image generation, and a character-level GAN language model without any discrete sampling.

Background

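$$\min_G \max_{D \in \mathcal{D}} \; \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] - \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})]$$

where $\mathcal{D}$ is the set of 1-Lipschitz functions, $\mathbb{P}_r$ is the data distribution and $\mathbb{P}_g$ is the generator's distribution, with $\tilde{x} = G(z)$.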

To enforce the Lipschitz constraint on the critic, the weights of the critic are clipped to lie within a compact space [−c, c]. The set of functions satisfying this constraint is a subset of the k-Lipschitz functions for some k which depends on c and the critic architecture. In the following sections, we demonstrate some of the issues with this approach and propose an alternative.
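As a minimal sketch of what weight clipping does, the hypothetical helper below clamps every critic weight into [−c, c] after an update. Representing the critic's weights as plain Python lists is an assumption made to keep the example dependency-free; a real implementation would clamp the framework's parameter tensors in place.

```python
def clip_weights(layers, c=0.01):
    """Clamp every weight into [-c, c], as WGAN does after each critic update.

    `layers` is a hypothetical stand-in for the critic's parameters:
    a list of layers, each given as a flat list of float weights.
    """
    return [[max(-c, min(c, w)) for w in layer] for layer in layers]


layers = [[0.5, -0.003, 0.02], [-0.2, 0.007]]
print(clip_weights(layers, c=0.01))  # [[0.01, -0.003, 0.01], [-0.01, 0.007]]
```

Weights already inside the range are left alone; everything else is snapped to the nearest boundary, which is why so many weights end up sitting exactly at ±c.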

Difficulties with weight constraints 

Lipschitz Continuity

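[Figure: gradient norms of the critic's layers under weight clipping with c = 0.01 and c = 0.1 (no batch normalization), and a histogram of the clipped weights]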

For c = 0.1, the gradient norm grows as we move back through the network from layer 13 to layer 1 (exploding gradients); for c = 0.01 it shrinks (vanishing gradients). On the right-hand side you can see that the weights are pushed to the two extremes of the clipping range, −c and +c: because the range is so small, most weights end up at its boundaries.

 

The model's performance is very sensitive to this hyperparameter. In the diagram above, with batch normalization off, the critic goes from vanishing gradients to exploding gradients as c increases from 0.01 to 0.1.
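To see why the outcome is so sensitive to c, here is a toy model (my own simplification, not the paper's experiment): a stack of linear layers of width n whose weights all sit at the clipping boundary +c. Backpropagating through one such layer multiplies an all-ones gradient by n·c, so over many layers the gradient vanishes when n·c < 1 and explodes when n·c > 1. The function name and the width of 32 are illustrative assumptions.

```python
def backprop_gradient_norm(c, width, depth):
    """Gradient norm after backpropagating through `depth` toy linear layers.

    Toy assumption: every layer is a width x width matrix W whose entries all
    equal the clipping value c, i.e. weights pushed to the boundary of [-c, c].
    """
    grad = [1.0] * width  # gradient arriving from the layer above
    for _ in range(depth):
        # grad <- W^T grad; with all entries equal to c, every component becomes c * sum(grad)
        total = sum(grad)
        grad = [c * total] * width
    return sum(g * g for g in grad) ** 0.5


for c in (0.01, 0.1):
    norms = [backprop_gradient_norm(c, width=32, depth=d) for d in (1, 5, 13)]
    print(f"c={c}: " + ", ".join(f"{n:.3g}" for n in norms))
```

With width 32, c = 0.01 gives a per-layer factor of 0.32 (the norm shrinks with depth), while c = 0.1 gives 3.2 (the norm blows up), which mirrors the trend in the diagram.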

Model Capacity

 

Weight clipping acts as a weight regularizer: it reduces the capacity of the critic f and limits its ability to model complex functions. In the experiment below, the first row shows contour plots of the value function estimated by a WGAN critic (weight clipping), and the second row shows the one estimated by WGAN-GP. With its reduced capacity, the WGAN critic fails to draw a complex boundary around the modes (orange dots), while WGAN-GP succeeds.

 

In other words, weight clipping oversimplifies whatever we are trying to model.


So weight clipping is kind of either too hard on the critic (small c) or too loose (large c), and finding the right c takes a lot of hyperparameter tuning.

Gradient Penalty

As we already know, the WGAN critic minimizes the loss:

$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)]$$
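Estimated on a batch, this loss is just a difference of two means. The sketch below assumes a critic is any Python callable mapping a sample (a list of floats) to a score; `wgan_critic_loss` is a hypothetical name.

```python
def wgan_critic_loss(critic, real_batch, fake_batch):
    """Batch estimate of E[D(x_tilde)] - E[D(x)], the loss the WGAN critic minimizes."""
    mean_real = sum(critic(x) for x in real_batch) / len(real_batch)
    mean_fake = sum(critic(x) for x in fake_batch) / len(fake_batch)
    return mean_fake - mean_real


critic = lambda x: 0.5 * x[0]  # hypothetical 1-D critic
print(wgan_critic_loss(critic, [[2.0], [4.0]], [[0.0], [1.0]]))  # -1.25
```

A more negative loss means the critic separates real from fake samples more strongly. Without a Lipschitz constraint, the critic could drive this loss toward minus infinity simply by scaling its outputs up, which is why the constraint matters.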

The critic needs to be 1-Lipschitz (i.e., the norm of its gradient with respect to its input must be at most 1 at every point).
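For a differentiable critic, the gradient condition is equivalent to the definition |D(x) − D(y)| ≤ ‖x − y‖ for all x and y. The hypothetical helper below probes that definition on sampled pairs of points. Sampling can only expose a violation, never prove that a critic is 1-Lipschitz, which is exactly the difficulty the gradient penalty has to work around.

```python
import math
import random


def max_slope(critic, points):
    """Largest |D(x) - D(y)| / ||x - y|| over all pairs of sampled points."""
    worst = 0.0
    for i, x in enumerate(points):
        for y in points[i + 1:]:
            dist = math.dist(x, y)
            if dist > 0:
                worst = max(worst, abs(critic(x) - critic(y)) / dist)
    return worst


rng = random.Random(0)
points = [[rng.uniform(-3, 3), rng.uniform(-3, 3)] for _ in range(200)]
smooth = lambda x: math.sin(x[0])     # |d/dx sin| <= 1, so 1-Lipschitz
steep = lambda x: 2.0 * x[0] + x[1]   # gradient norm sqrt(5), about 2.24
print(max_slope(smooth, points) <= 1.0)  # True
print(max_slope(steep, points) > 1.0)    # True
```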


The gradient penalty is a softer way of enforcing the 1-Lipschitz condition than weight clipping. Essentially, it just adds a regularization term to the loss function:

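$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda \cdot \mathrm{reg}$$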

The ‘reg’ term penalizes the critic when its gradient norm exceeds 1. λ decides how much weight this regularization gets (i.e., how tightly you want to enforce the Lipschitz condition); the paper uses λ = 10 in all of its experiments.
 

Ideally we would check the critic's gradient norm at every point of the input space, but evaluating it everywhere is impossible. So the authors came up with the idea of only checking interpolations of real and fake data, i.e. points on straight lines between pairs of real and fake samples.

$$\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \quad \epsilon \sim U[0, 1]$$

So you sample a real image x and a fake image x̃, mix them with a random weight ϵ drawn uniformly from [0, 1], and get an interpolated image x̂.
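The sampling step can be sketched as follows, assuming each image is flattened into a list of floats and one ϵ is drawn per real/fake pair (as in the paper's training algorithm). `interpolate` is a hypothetical helper name.

```python
import random


def interpolate(real, fake, rng):
    """Return (x_hat, eps) with x_hat = eps * real + (1 - eps) * fake and eps ~ U[0, 1]."""
    eps = rng.random()
    x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]
    return x_hat, eps


rng = random.Random(42)
real = [1.0, 1.0, 1.0]    # stand-in for a real image (flattened pixels)
fake = [-1.0, 0.0, 1.0]   # stand-in for a generated image
x_hat, eps = interpolate(real, fake, rng)
print(eps, x_hat)
```

Each coordinate of x̂ lies between the corresponding real and fake values, so the penalized points cover the region between the two distributions rather than the whole input space.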

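So the critic looks at x̂: we take the gradient of the critic's prediction with respect to x̂ and compute its norm, which should be at most 1. Note that this penalty actually pushes the gradient norm toward exactly 1 rather than merely below 1, because it penalizes deviations in both directions. The authors justify this two-sided penalty by showing that the optimal critic has gradient norm 1 almost everywhere under the real and generated distributions.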

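$$\mathrm{reg} = \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$

where $\mathbb{P}_{\hat{x}}$ is the distribution of the interpolated samples $\hat{x}$.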
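The difference between penalizing deviations in both directions and only penalizing norms above 1 is easy to see numerically. Below, `two_sided_penalty` is the form used by WGAN-GP, and `one_sided_penalty` is the alternative that only enforces "at most 1"; both names are hypothetical.

```python
def two_sided_penalty(grad_norm):
    """(||grad|| - 1)^2: penalizes norms both above and below 1."""
    return (grad_norm - 1.0) ** 2


def one_sided_penalty(grad_norm):
    """max(0, ||grad|| - 1)^2: only penalizes norms above 1."""
    return max(0.0, grad_norm - 1.0) ** 2


for norm in (0.5, 1.0, 1.5):
    print(norm, two_sided_penalty(norm), one_sided_penalty(norm))
# 0.5 0.25 0.0
# 1.0 0.0 0.0
# 1.5 0.25 0.25
```

A gradient norm of 0.5 is free under the one-sided version but costs 0.25 under the two-sided one, which is what "encourages the norm to be 1" means in practice.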

So the full WGAN-GP critic loss is:

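$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$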

Squaring the deviation penalizes large violations much more heavily than small ones.
 

It does not strictly enforce the 1-Lipschitz condition; it only encourages it.
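Putting the pieces together, here is a minimal, dependency-free sketch of this loss for one batch. It is not the authors' code: gradients are estimated with central finite differences only to keep the example self-contained (real implementations use automatic differentiation, e.g. `torch.autograd.grad` with `create_graph=True` so the penalty can itself be backpropagated), and the function names are hypothetical. λ defaults to 10, the value used in the paper.

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    """Central-difference estimate of the gradient of f at x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2.0 * h))
    return grad


def wgan_gp_critic_loss(critic, real_batch, fake_batch, lam=10.0, rng=random):
    """Batch estimate of E[D(x_tilde)] - E[D(x)] + lam * E[(||grad D(x_hat)|| - 1)^2]."""
    n = len(real_batch)
    assert n == len(fake_batch), "pair each real sample with one fake sample"
    wasserstein_term = (sum(critic(x) for x in fake_batch) - sum(critic(x) for x in real_batch)) / n
    penalty = 0.0
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.random()  # eps ~ U[0, 1], drawn per real/fake pair
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]
        g = numerical_grad(critic, x_hat)
        penalty += (math.sqrt(sum(gi * gi for gi in g)) - 1.0) ** 2
    return wasserstein_term + lam * penalty / n


steep_critic = lambda x: 3.0 * x[0]  # hypothetical critic with gradient norm 3 everywhere
loss = wgan_gp_critic_loss(steep_critic, [[1.0], [2.0]], [[0.0], [0.0]], rng=random.Random(0))
print(round(loss, 6))  # 35.5: Wasserstein term -4.5 plus penalty 10 * (3 - 1)^2 = 40
```

The steep critic separates the samples well (Wasserstein term −4.5), but the penalty of 40 dominates, so minimizing this loss would push its slope back toward 1.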


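Batch normalization is not used in the critic. It creates correlations between samples in the same batch, so the critic no longer maps each input to its output independently, while the gradient penalty is computed for each input independently. This undermines the gradient penalty, which is confirmed by experiments. The paper recommends layer normalization as a drop-in replacement.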
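A small experiment makes the batch-coupling argument concrete. The sketch below implements simplified batch and layer normalization (no learnable scale or shift, an assumption for brevity) and changes only the second sample of a batch: the batch-normalized output for the first sample changes, while the layer-normalized one does not.

```python
import math


def batch_norm(batch, eps=1e-5):
    """Normalize each feature with statistics taken across the batch (couples samples)."""
    n, d = len(batch), len(batch[0])
    means = [sum(x[j] for x in batch) / n for j in range(d)]
    variances = [sum((x[j] - means[j]) ** 2 for x in batch) / n for j in range(d)]
    return [[(x[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(d)] for x in batch]


def layer_norm(batch, eps=1e-5):
    """Normalize each sample with its own features only (samples stay independent)."""
    out = []
    for x in batch:
        m = sum(x) / len(x)
        v = sum((xi - m) ** 2 for xi in x) / len(x)
        out.append([(xi - m) / math.sqrt(v + eps) for xi in x])
    return out


a = [[1.0, 2.0, 3.0], [4.0, 0.0, 1.0]]
b = [[1.0, 2.0, 3.0], [9.0, 9.0, -5.0]]      # same first sample, different second sample
print(batch_norm(a)[0] == batch_norm(b)[0])  # False: sample 0's output depends on sample 1
print(layer_norm(a)[0] == layer_norm(b)[0])  # True: sample 0's output is unaffected
```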

By design or not, other cost functions have since added a gradient penalty as well; some did so purely because models were observed to misbehave when the gradient grows. The gradient penalty does add computational cost (it requires differentiating through the critic's gradient), which may not be desirable, but it does produce higher-quality images.


Note: the proofs behind these claims can be found in the appendix of the paper.

Experiments

WGAN-GP enhances training stability.

 

Below are the Inception scores of the different methods over the course of training. The experiments in the WGAN-GP paper show better image quality and convergence than WGAN with weight clipping. DCGAN reaches slightly better image quality and converges faster, but the Inception score of WGAN-GP is more stable once it starts converging. So although WGAN-GP may perform a little worse than DCGAN, it makes training more stable and the model therefore easier to train. Because WGAN-GP helps models converge, we can also use more complex models, such as a deep ResNet, for both the generator and the discriminator.

[Figure: Inception score over training for the different methods]

Conclusion

Stability in training is the biggest contribution of the paper.
