Wasserstein GAN
The Wasserstein Generative Adversarial Network (WGAN) is a variant of the generative adversarial network (GAN) proposed in 2017 that aims to "improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debugging and hyperparameter searches". Compared with the original GAN discriminator, the Wasserstein GAN discriminator provides a better learning signal to the generator. This allows the training to be more stable when the generator is learning distributions in very high-dimensional spaces.


Motivation


The GAN game

The original GAN method is based on the GAN game, a zero-sum game with two players: generator and discriminator. The game is defined over a probability space (\Omega, \mathcal B, \mu_{\text{ref}}). The generator's strategy set is the set of all probability measures \mu_G on (\Omega, \mathcal B), and the discriminator's strategy set is the set of measurable functions D: \Omega \to [0, 1].

The objective of the game is

L(\mu_G, D) := \mathbb E_{x\sim\mu_{\text{ref}}}[\ln D(x)] + \mathbb E_{x\sim\mu_G}[\ln (1-D(x))].

The generator aims to minimize it, and the discriminator aims to maximize it.

A basic theorem of the GAN game states that, for a fixed generator \mu_G, the payoff is maximized by the discriminator D^*(x) = \frac{\rho_{\text{ref}}(x)}{\rho_{\text{ref}}(x) + \rho_G(x)} (written in terms of the densities \rho_{\text{ref}} and \rho_G), with maximal value 2D_{JS}(\mu_{\text{ref}} \| \mu_G) - 2\ln 2, so that the generator's unique optimal strategy is \mu_G = \mu_{\text{ref}}.

Repeat the GAN game many times, each time with the generator moving first and the discriminator moving second. Each time the generator \mu_G changes, the discriminator must adapt by approaching the ideal

D^*(x) = \frac{\rho_{\text{ref}}(x)}{\rho_{\text{ref}}(x) + \rho_G(x)}.

Since we are really interested in \mu_{\text{ref}}, the discriminator function D is by itself rather uninteresting. It merely keeps track of the likelihood ratio between the generator distribution and the reference distribution. At equilibrium, the discriminator just outputs \frac 12 constantly, having given up on perceiving any difference.

Concretely, in the GAN game, let us fix a generator \mu_G and improve the discriminator step-by-step, with D_t being the discriminator at step t. Then we (ideally) have

L(\mu_G, D_1) \leq L(\mu_G, D_2) \leq \cdots \leq \max_{D} L(\mu_G, D) = 2D_{JS}(\mu_{\text{ref}} \| \mu_G) - 2\ln 2,

so we see that the successive discriminator payoffs are lower bounds of 2D_{JS}(\mu_{\text{ref}} \| \mu_G) - 2\ln 2; in this sense, the discriminator estimates the Jensen–Shannon divergence D_{JS}(\mu_{\text{ref}} \| \mu_G) from below.
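For completeness, here is a sketch of the standard argument behind that theorem (the usual computation from the GAN literature, not specific to this article). Maximizing the integrand of L pointwise in D(x) gives

\frac{\partial}{\partial D(x)}\left[\rho_{\text{ref}}(x)\ln D(x) + \rho_G(x)\ln(1-D(x))\right] = 0 \quad\Longrightarrow\quad D^*(x) = \frac{\rho_{\text{ref}}(x)}{\rho_{\text{ref}}(x)+\rho_G(x)},

and substituting the maximizer back in,

L(\mu_G, D^*) = \mathbb E_{x\sim\mu_{\text{ref}}}\left[\ln\frac{\rho_{\text{ref}}(x)}{\tfrac12(\rho_{\text{ref}}(x)+\rho_G(x))}\right] + \mathbb E_{x\sim\mu_G}\left[\ln\frac{\rho_G(x)}{\tfrac12(\rho_{\text{ref}}(x)+\rho_G(x))}\right] - 2\ln 2 = 2D_{JS}(\mu_{\text{ref}} \| \mu_G) - 2\ln 2,

since the two expectations are exactly the two Kullback–Leibler terms in the definition of the Jensen–Shannon divergence.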


Wasserstein distance

Thus, we see that the point of the discriminator is mainly as a critic to provide feedback for the generator, about "how far it is from perfection", where "far" is defined as the Jensen–Shannon divergence. Naturally, this brings the possibility of using a different criterion of farness. There are many possible divergences to choose from, such as the f-divergence family, which would give the f-GAN. The Wasserstein GAN is obtained by using the Wasserstein metric, which satisfies a "dual representation theorem" that renders it highly efficient to compute. By the Kantorovich–Rubinstein duality, for any two probability measures \mu, \nu on \Omega and any K > 0,

K \cdot W_1(\mu, \nu) = \sup_{\|f\|_L \leq K} \left(\mathbb E_{x\sim\mu}[f(x)] - \mathbb E_{x\sim\nu}[f(x)]\right),

where the supremum is taken over all functions f: \Omega \to \R with Lipschitz norm at most K. A proof can be found in the main page on Wasserstein metric.
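As a small illustration (a minimal sketch, not from the article, assuming one-dimensional data), the Wasserstein-1 distance between two empirical distributions on the real line with equal sample counts reduces to the mean absolute difference of the sorted samples, which makes the "farness" that the critic is supposed to measure concrete:

import numpy as np

def wasserstein1_1d(xs, ys):
    # For one-dimensional empirical distributions with the same number of samples,
    # the optimal transport plan matches sorted samples in order, so
    # W1 = mean_i |x_(i) - y_(i)|.
    xs, ys = np.sort(xs), np.sort(ys)
    return np.mean(np.abs(xs - ys))

rng = np.random.default_rng(0)
real = rng.normal(loc=2.0, scale=1.0, size=1000)  # stand-in for mu_ref
fake = rng.normal(loc=0.0, scale=1.0, size=1000)  # stand-in for mu_G
print(wasserstein1_1d(real, fake))  # close to 2.0, the distance between the two means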


Definition

By the Kantorovich–Rubinstein duality, the definition of the Wasserstein GAN is clear. The Wasserstein GAN game is played over the same probability space (\Omega, \mathcal B, \mu_{\text{ref}}): the generator's strategy set is still the set of all probability measures \mu_G on (\Omega, \mathcal B), but the discriminator's strategy set is now the set of measurable functions D: \Omega \to \R with bounded Lipschitz norm \|D\|_L \leq K. The objective of the game is

L_{WGAN}(\mu_G, D) := \mathbb E_{x\sim\mu_G}[D(x)] - \mathbb E_{x\sim\mu_{\text{ref}}}[D(x)].

The generator aims to minimize it, and the discriminator aims to maximize it. By the Kantorovich–Rubinstein duality, for any generator strategy \mu_G, the optimal reply by the discriminator is D^*, such that

L_{WGAN}(\mu_G, D^*) = K \cdot W_1(\mu_G, \mu_{\text{ref}}).

Consequently, if the discriminator is good, the generator is constantly pushed to minimize W_1(\mu_G, \mu_{\text{ref}}), and the optimal strategy for the generator is just \mu_G = \mu_{\text{ref}}, as it should be.
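In practice D is a neural network and the expectations are estimated from minibatches. Below is a minimal PyTorch-style sketch of the two empirical losses under the article's sign convention; critic is a hypothetical network, and the Lipschitz constraint is enforced separately, as described in the training section below:

import torch

def critic_loss(critic, real_batch, fake_batch):
    # Minimizing this maximizes the empirical L_WGAN = E_G[D] - E_ref[D]
    # (the Lipschitz constraint on the critic is enforced separately).
    return critic(real_batch).mean() - critic(fake_batch).mean()

def generator_loss(critic, fake_batch):
    # For a fixed critic, only the E_G[D] term depends on the generator.
    return critic(fake_batch).mean()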


Comparison with GAN

In the Wasserstein GAN game, the discriminator provides a better gradient than in the GAN game. Consider for example a game on the real line where both \mu_G and \mu_{\text{ref}} are Gaussian. The optimal Wasserstein critic D_{WGAN} is then a 1-Lipschitz function whose slope has magnitude 1 almost everywhere, while the optimal GAN discriminator D is a sigmoid-shaped function of the likelihood ratio that saturates away from the region where the two distributions overlap (a small numerical sketch illustrating this is given at the end of this subsection).

For a fixed discriminator, the generator needs to minimize the following objectives:
* For GAN, \mathbb E_{x\sim\mu_G}[\ln(1-D(x))].
* For Wasserstein GAN, \mathbb E_{x\sim\mu_G}[D_{WGAN}(x)].

Let \mu_G be parametrized by \theta; then we can perform stochastic gradient descent by using two unbiased estimators of the gradient:

\nabla_\theta \, \mathbb E_{x\sim\mu_G}[\ln(1-D(x))] = \mathbb E_{x\sim\mu_G}[\ln(1-D(x)) \cdot \nabla_\theta \ln\rho_{\mu_G}(x)]

\nabla_\theta \, \mathbb E_{x\sim\mu_G}[D_{WGAN}(x)] = \mathbb E_{x\sim\mu_G}[D_{WGAN}(x) \cdot \nabla_\theta \ln\rho_{\mu_G}(x)]

where \rho_{\mu_G} is the density of \mu_G and we used the log-derivative (score function) trick.

As shown, the generator in GAN is motivated to let its \mu_G "slide down the peak" of \ln(1-D(x)), and similarly for the generator in Wasserstein GAN. For Wasserstein GAN, D_{WGAN} has gradient of magnitude 1 almost everywhere, while for GAN, \ln(1-D) has a flat gradient in the middle and a steep gradient elsewhere. As a result, the variance of the estimator in GAN is usually much larger than that in Wasserstein GAN. See also Figure 3 of the cited reference.

The problem with the Jensen–Shannon divergence is much more severe in actual machine learning situations. Consider training a GAN to generate ImageNet, a collection of photos of size 256-by-256. The space of all such photos is \R^{256\times 256\times 3}, and the distribution of ImageNet pictures, \mu_{\text{ref}}, concentrates on a manifold of much lower dimension in it. Consequently, any generator strategy \mu_G would almost surely be entirely disjoint from \mu_{\text{ref}}, so that D_{JS}(\mu_G \| \mu_{\text{ref}}) saturates at its maximal value \ln 2. Thus, a good discriminator can almost perfectly distinguish \mu_{\text{ref}} from \mu_G, as well as any \mu_G' close to \mu_G, and so the gradient \nabla_\theta L(\mu_G, D) \approx 0, creating no learning signal for the generator. Detailed theorems can be found in the cited references.
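To make the Gaussian example concrete, here is a small Python sketch (an illustration assuming two equal-variance Gaussians; the function names are not from the article) that evaluates the optimal GAN discriminator and an optimal 1-Lipschitz critic on the real line; the GAN term \ln(1-D(x)) flattens out away from the overlap region, while the critic's slope never vanishes:

import numpy as np
from scipy.stats import norm

m_ref, m_g, s = 2.0, 0.0, 1.0  # reference N(2, 1), generator N(0, 1)

def gan_discriminator(x):
    # Optimal GAN discriminator D*(x) = rho_ref(x) / (rho_ref(x) + rho_G(x)),
    # here a logistic sigmoid in x.
    p_ref, p_g = norm.pdf(x, m_ref, s), norm.pdf(x, m_g, s)
    return p_ref / (p_ref + p_g)

def wasserstein_critic(x):
    # With equal variances, an optimal 1-Lipschitz critic is affine with slope
    # of magnitude 1; under the article's sign convention (the generator
    # minimizes E_G[D]) one can take D(x) = -x, unique up to a constant.
    return -x

xs = np.linspace(-6.0, 8.0, 8)
print(np.log(1.0 - gan_discriminator(xs)))  # nearly flat to the left of the overlap region
print(wasserstein_critic(xs))               # slope of magnitude 1 everywhere
# Duality check: E_G[D] - E_ref[D] = 0 - (-2) = 2 = W1(N(0,1), N(2,1)).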


Training Wasserstein GANs

Training the generator in Wasserstein GAN is just gradient descent, the same as in GAN (or most deep learning methods), but training the discriminator is different, as the discriminator is now restricted to have bounded Lipschitz norm. There are several methods for this.
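The overall training procedure alternates several critic updates with one generator update. The following is a minimal PyTorch-style sketch assuming weight clipping as the Lipschitz constraint; generator, critic, sample_real and the two optimizers are hypothetical placeholders, and the hyperparameters n_critic = 5 and clip_value = 0.01 follow the original paper's defaults:

import torch

def train_wgan(generator, critic, sample_real, opt_g, opt_c,
               batch_size=64, latent_dim=128, n_critic=5, clip_value=0.01, steps=10000):
    for step in range(steps):
        # Critic updates: move D toward the optimal bounded-Lipschitz critic.
        for _ in range(n_critic):
            real = sample_real(batch_size)
            z = torch.randn(batch_size, latent_dim)
            fake = generator(z).detach()  # no generator gradient during critic updates
            # Minimizing this maximizes the article's L_WGAN = E_G[D] - E_ref[D].
            loss_c = critic(real).mean() - critic(fake).mean()
            opt_c.zero_grad()
            loss_c.backward()
            opt_c.step()
            for p in critic.parameters():  # weight clipping keeps ||D||_L bounded (see below)
                p.data.clamp_(-clip_value, clip_value)
        # Generator update: for the current critic, minimize E_G[D].
        z = torch.randn(batch_size, latent_dim)
        loss_g = critic(generator(z)).mean()
        opt_g.zero_grad()
        loss_g.backward()
        opt_g.step()

The original paper pairs this loop with RMSProp optimizers for both networks.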


Upper-bounding the Lipschitz norm

Let the discriminator function D be implemented by a multilayer perceptron:

D = D_n \circ D_{n-1} \circ \cdots \circ D_1

where D_i(x) = h(W_i x), and h: \R \to \R is a fixed activation function with \sup_x |h'(x)| \leq 1. For example, the hyperbolic tangent function h = \tanh satisfies the requirement.

Then, for any x, letting x_i = (D_i \circ D_{i-1} \circ \cdots \circ D_1)(x), we have by the chain rule:

dD(x) = \mathrm{diag}(h'(W_n x_{n-1})) \cdot W_n \cdot \mathrm{diag}(h'(W_{n-1} x_{n-2})) \cdot W_{n-1} \cdots \mathrm{diag}(h'(W_1 x)) \cdot W_1 \cdot dx.

Thus, the Lipschitz norm of D is upper-bounded by

\|D\|_L \leq \sup_x \|\mathrm{diag}(h'(W_n x_{n-1})) \cdot W_n \cdot \mathrm{diag}(h'(W_{n-1} x_{n-2})) \cdot W_{n-1} \cdots \mathrm{diag}(h'(W_1 x)) \cdot W_1\|_s,

where \|\cdot\|_s is the operator norm of the matrix, that is, its largest singular value, also known as the spectral norm (for normal matrices this coincides with the spectral radius, but not for matrices in general, nor for general linear operators).

Since \sup_x |h'(x)| \leq 1, we have \|\mathrm{diag}(h'(W_i x_{i-1}))\|_s = \max_j |h'((W_i x_{i-1})_j)| \leq 1, and consequently the upper bound:

\|D\|_L \leq \prod_{i=1}^n \|W_i\|_s.

Thus, if we can upper-bound the operator norm \|W_i\|_s of each matrix, we can upper-bound the Lipschitz norm of D.
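As a quick numerical sanity check of this bound (a sketch, not from the article), one can build a small random tanh network and compare the operator norm of its Jacobian at a random point against the product of the layers' spectral norms:

import numpy as np

rng = np.random.default_rng(0)
# A small tanh network D = tanh(W3 tanh(W2 tanh(W1 x))) with random weights.
Ws = [rng.normal(size=(16, 8)), rng.normal(size=(16, 16)), rng.normal(size=(1, 16))]

def forward_with_jacobian(x):
    J = np.eye(x.size)
    for W in Ws:
        pre = W @ x
        x = np.tanh(pre)
        # Chain rule, exactly as in the bound above: dD = diag(h'(W x)) W ... dx.
        J = np.diag(1.0 - x ** 2) @ W @ J
    return x, J

x0 = rng.normal(size=8)
_, J = forward_with_jacobian(x0)
local_norm = np.linalg.norm(J, 2)                    # operator norm of the Jacobian at x0
bound = np.prod([np.linalg.norm(W, 2) for W in Ws])  # product of the layers' spectral norms
assert local_norm <= bound
print(local_norm, bound)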


Weight clipping

For any m\times l matrix W, let c = \max_{i,j} |W_{ij}|. Then

\|W\|_s^2 = \sup_{\|x\|_2 = 1}\|W x\|_2^2 = \sup_{\|x\|_2 = 1}\sum_i\left(\sum_j W_{ij} x_j\right)^2 = \sup_{\|x\|_2 = 1}\sum_{i,j,k} W_{ij}W_{ik}x_j x_k \leq c^2 m l^2,

so by clipping all entries of W to within some interval [-c, c] after each gradient update, we can bound \|W\|_s. This is the weight clipping method, proposed by the original paper.
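In code, weight clipping is one in-place operation applied to every critic parameter after each optimizer step (a PyTorch-style sketch; critic is a hypothetical module and c = 0.01 is the original paper's default):

import torch

def clip_critic_weights(critic, c=0.01):
    # Clamp every parameter entry of the critic into [-c, c] in place, which
    # bounds each layer's operator norm and hence the critic's Lipschitz norm.
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)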


Spectral normalization

The operator norm \|W\|_s (the largest singular value of W) can be efficiently approximated by power iteration: starting from a vector x, repeatedly apply x \leftarrow \frac{W^\mathsf{T} W x}{\|W^\mathsf{T} W x\|_2} until it converges to the top right-singular vector x^*, at which point \|W x^*\|_2 \approx \|W\|_s. By reassigning W_i \leftarrow \frac{W_i}{\|W_i\|_s} after each update of the discriminator, we can ensure \|W_i\|_s \leq 1, and thus bound \|D\|_L. The algorithm can be further accelerated by memoization: at step t, store x^*_i(t); then at step t+1, use x^*_i(t) as the initial guess for the power iteration. Since W_i(t+1) is very close to W_i(t), x^*_i(t) is close to x^*_i(t+1), so only a few iterations are needed for convergence. This is the spectral normalization method.
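A sketch of the power iteration with the memoized vector (NumPy; performing a single iteration per training step, as below, is how spectral normalization is usually implemented, though the exact schedule is an implementation choice):

import numpy as np

def spectral_normalize(W, u=None, n_iters=1):
    # Estimate sigma = ||W||_s by power iteration and return (W / sigma, u).
    # `u` is the memoized estimate of the top right-singular vector from the
    # previous training step; reusing it makes one iteration per step enough.
    if u is None:
        u = np.random.default_rng(0).normal(size=W.shape[1])
    for _ in range(n_iters):
        v = W @ u
        v = v / np.linalg.norm(v)
        u = W.T @ v
        u = u / np.linalg.norm(u)
    sigma = np.linalg.norm(W @ u)  # approximately the largest singular value of W
    return W / sigma, u

W = np.random.default_rng(1).normal(size=(32, 64))
W_hat, u = spectral_normalize(W, n_iters=50)  # cold start: iterate to convergence
print(np.linalg.norm(W_hat, 2))               # approximately 1.0
W_hat, u = spectral_normalize(W + 1e-3, u)    # warm start after a small weight update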


Gradient penalty

Instead of strictly bounding \|D\|_L, we can simply add a "gradient penalty" term to the discriminator loss, of the form

\mathbb E_{x\sim\hat\mu}[(\|\nabla D(x)\|_2 - a)^2],

where \hat\mu is a fixed distribution used to estimate how much the discriminator has violated the Lipschitz norm requirement. The discriminator, in attempting to minimize the new loss function, would naturally bring \|\nabla D(x)\|_2 close to a everywhere, thus making \|D\|_L \approx a. This is the gradient penalty method.
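A PyTorch-style sketch of the penalty term, under the common choices a = 1 and \hat\mu given by sampling points uniformly on line segments between real and generated samples (one standard way to realize the fixed distribution; none of the names below come from the article):

import torch

def gradient_penalty(critic, real, fake, target_norm=1.0):
    # E_{x ~ mu_hat}[(||grad_x D(x)||_2 - target_norm)^2], with mu_hat obtained by
    # interpolating uniformly between real and generated samples.
    real, fake = real.detach(), fake.detach()
    eps = torch.rand(real.shape[0], *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    d_hat = critic(x_hat)
    grads = torch.autograd.grad(outputs=d_hat.sum(), inputs=x_hat, create_graph=True)[0]
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - target_norm) ** 2).mean()

The penalty is added to the critic's loss with a weight, e.g. loss_c = critic(real).mean() - critic(fake).mean() + lambda_gp * gradient_penalty(critic, real, fake), where lambda_gp is a hypothetical name for the penalty coefficient.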


Further reading


From GAN to WGAN

Wasserstein GAN and the Kantorovich-Rubinstein Duality

Depth First Learning: Wasserstein GAN


See also

* Generative adversarial network
* Wasserstein metric
* Earth mover's distance
* Transportation theory


References


Notes
