Generative Adversarial Networks (GAN)
Generative adversarial networks (GANs; Goodfellow et al., 2014) are one of the most creative architectural patterns in the history of machine learning. Yann LeCun called adversarial training “the most interesting idea in the last 10 years in ML.” GANs spawned an entire era of synthetic media: realistic faces, deepfakes, AI artists.
Principle of GAN: Game Setting
Two agents: Generator G: z → x̂ (generates synthetic data from noise z). Discriminator D: x → [0,1] (distinguishes real data from synthetic).
Minimax game:
min_G max_D V(D,G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 − D(G(z)))]
Interpretation: D maximizes the probability of classifying both real data (D(x)→1) and fake data (D(G(z))→0) correctly. G minimizes the same objective, that is, it tries to make D accept its fakes as real (D(G(z))→1).
Optimal D with fixed G: D*(x) = p_data(x)/(p_data(x) + p_G(x)). Substituting: max_D V = −log(4) + 2·JSD(p_data || p_G). JSD is the Jensen-Shannon divergence. At equilibrium: p_G = p_data → D* = 1/2 → V = −log(4).
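The identity above can be checked numerically on a small discrete example. The sketch below is illustrative only: the three-point distributions `p_data` and `p_g` are made-up values, and all function names are hypothetical helpers, not part of any library.

```python
import math

def value_function(p_data, p_g, d):
    """V(D, G) on a finite support: sum_x p_data(x) log D(x) + p_g(x) log(1 - D(x))."""
    return sum(pd * math.log(dx) + pg * math.log(1.0 - dx)
               for pd, pg, dx in zip(p_data, p_g, d))

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for a fixed generator."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

def kl(p, q):
    """Kullback-Leibler divergence KL(p || q) with natural logarithms."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Made-up distributions over three outcomes (assumption for illustration).
p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.3, 0.5]

d_star = optimal_discriminator(p_data, p_g)
v_star = value_function(p_data, p_g, d_star)
identity_rhs = -math.log(4) + 2 * jsd(p_data, p_g)

# At equilibrium the generator matches the data distribution exactly.
d_equilibrium = optimal_discriminator(p_data, p_data)
v_equilibrium = value_function(p_data, p_data, d_equilibrium)

print("V(D*, G)           =", v_star)
print("-log 4 + 2 JSD     =", identity_rhs)
print("V at p_G = p_data  =", v_equilibrium)
```

Since JSD is non-negative and zero only when the two distributions coincide, −log(4) is the smallest value that max_D V can take.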
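Nash equilibrium theorem for GANs: in the idealized setting where G and D have unlimited capacity, the game has a single Nash equilibrium. At that point G reproduces p_data exactly and D cannot tell real from fake (D*(x) = 0.5 everywhere).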
GAN Training Problems
Mode collapse: G “finds” a few modes that reliably fool D and generates only those. Instead of diverse faces, it produces several “safe” variants. Reason: the objective does not directly penalize G for lack of diversity.
Instability: training depends on a fine balance. If D is too good, the gradient for G is ≈ 0 (D saturates). If D is too weak, G gets no useful signal. As a result, G and D “chase” each other and the losses oscillate instead of converging.
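Vanishing gradients when D saturates: when D confidently rejects fakes (D(G(z)) → 0), the term log(1 − D(G(z))) flattens out near 0, and its gradient with respect to D's pre-sigmoid output, −D(G(z)), goes to 0 as well. Practical solution: have G minimize −E[log D(G(z))] instead (non-saturating loss). Its gradient is largest exactly when D(G(z)) is small, so G keeps receiving a signal even against a strong D.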
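The difference between the two generator losses can be made concrete by differentiating with respect to the discriminator's pre-sigmoid output a, where D = sigmoid(a). This is a minimal sketch under that assumption (a sigmoid output layer); the function names are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_loss(a):
    """Original generator loss term: log(1 - D), with D = sigmoid(a)."""
    return math.log(1.0 - sigmoid(a))

def non_saturating_loss(a):
    """Non-saturating generator loss term: -log D, with D = sigmoid(a)."""
    return -math.log(sigmoid(a))

def grad_saturating(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log sigmoid(a)] = sigmoid(a) - 1
    return sigmoid(a) - 1.0

# Negative logits mean D is confident the sample is fake.
for a in (-8.0, -4.0, 0.0, 4.0):
    print(f"a={a:+.1f}  D={sigmoid(a):.4f}  "
          f"saturating grad={grad_saturating(a):+.4f}  "
          f"non-saturating grad={grad_non_saturating(a):+.4f}")
```

For strongly negative a (a confident D), the saturating gradient shrinks toward 0 while the non-saturating gradient approaches −1 in magnitude, which is the regime G is in early in training.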
Improvements: DCGAN, WGAN, StyleGAN
DCGAN (Radford et al., 2015): Recommendations for stable training: replace pooling with strided convolutions (transposed convolutions in G, ordinary strided convolutions in D). Use Batch Normalization everywhere except G's output layer and D's input layer. Use LeakyReLU in D (α=0.2) and tanh at G's output. Remove fully connected hidden layers.
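WGAN (Arjovsky et al., 2017): Replace JS divergence with the Wasserstein distance (Earth Mover's Distance): W₁(p,q) = inf_{γ∈Π(p,q)} E_{(x,y)~γ}[||x−y||]. Kantorovich-Rubinstein theorem: W₁(p,q) = sup_{||f||_L≤1} [E_p[f] − E_q[f]]. The critic f (called a critic rather than a discriminator because it outputs an unbounded score, not a probability) must be 1-Lipschitz. The original paper enforces this with weight clipping (every weight kept in [−c, c]); the later WGAN-GP variant (Gulrajani et al., 2017) uses a gradient penalty instead. Training is much more stable, and the loss value becomes meaningful: it tracks how far p_G is from p_data.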
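One way to see why the Wasserstein distance gives a more informative signal is to compare it with JS divergence on distributions whose supports do not overlap. The sketch below uses a one-dimensional shortcut (for equal-size samples on the real line, W₁ is the mean absolute difference of the sorted samples); the sample values and helper names are assumptions chosen for illustration.

```python
import math

def wasserstein_1d(xs, ys):
    """W1 between two equal-size empirical distributions on the real line."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample counts")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)

def jsd(p, q):
    """Jensen-Shannon divergence between discrete distributions given as dicts."""
    total = 0.0
    for x in set(p) | set(q):
        px, qx = p.get(x, 0.0), q.get(x, 0.0)
        m = 0.5 * (px + qx)
        if px > 0:
            total += 0.5 * px * math.log(px / m)
        if qx > 0:
            total += 0.5 * qx * math.log(qx / m)
    return total

# "Real" data is a point mass at 0; the "generator" outputs a point mass at theta.
for theta in (2.0, 1.0, 0.5, 0.0):
    w1 = wasserstein_1d([0.0], [theta])
    js = jsd({0.0: 1.0}, {theta: 1.0})
    print(f"theta={theta:.1f}  W1={w1:.3f}  JSD={js:.3f}")
```

As theta moves toward 0, W₁ decreases smoothly, while JSD stays at log 2 until the supports coincide, so it carries no information about which direction improves the generator.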
StyleGAN (Karras et al., 2018/2019): “Style” is injected into G via Adaptive Instance Normalization (AdaIN) at each level of the network: AdaIN(x, y) = y_s (x − μ(x))/σ(x) + y_b. Here μ(x) and σ(x) are the mean and standard deviation of a single feature map, and y_s and y_b are a scale and a shift computed from the latent code w (via the mapping network z → w → y). Different levels control different scales: coarse layers (pose, shape) → middle layers (facial features) → fine layers (texture, color). Thispersondoesnotexist.com is an example.
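The AdaIN formula can be sketched for a single feature map stored as a flat list of activations. This is an illustrative stand-in, not StyleGAN's code: the `eps` term is an added assumption for numerical safety, and the style values `y_s`, `y_b` are made up rather than produced by a mapping network.

```python
import math

def adain(channel, y_s, y_b, eps=1e-8):
    """AdaIN(x, y) = y_s * (x - mean(x)) / std(x) + y_b for one feature map.

    `channel` is a flat list of the activations in one channel of one sample.
    `eps` is not in the formula above; it only guards against division by zero.
    """
    n = len(channel)
    mu = sum(channel) / n
    var = sum((v - mu) ** 2 for v in channel) / n
    sigma = math.sqrt(var + eps)
    return [y_s * (v - mu) / sigma + y_b for v in channel]

# Made-up activations and style parameters (assumptions for illustration).
feature_map = [1.0, 2.0, 3.0, 4.0]
styled = adain(feature_map, y_s=2.0, y_b=3.0)

new_mean = sum(styled) / len(styled)
new_std = math.sqrt(sum((v - new_mean) ** 2 for v in styled) / len(styled))
print("styled activations:", styled)
print("mean:", new_mean, "std:", new_std)
```

After the operation, the feature map's own statistics are discarded: its mean becomes y_b and its standard deviation becomes |y_s|, which is how the latent code takes control of each layer.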
FID (Fréchet Inception Distance): Standard GAN quality metric. Use Inception-v3 as feature extractor. Real and synthetic images → feature vectors → approximate with Gaussians (μ_r, Σ_r) and (μ_g, Σ_g). FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). Lower FID → better. StyleGAN2: FID=2.8 on FFHQ (vs 35 for the first GAN).
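To make the FID formula concrete, the sketch below evaluates it in a simplified special case: both covariance matrices are assumed diagonal, so Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}) reduces to the sum over dimensions of (σ_r − σ_g)². Real FID uses full covariances of Inception-v3 features; the two-dimensional feature vectors and function names here are hypothetical.

```python
import math

def mean_and_std(features):
    """Per-dimension mean and (population) standard deviation of feature vectors."""
    n, dim = len(features), len(features[0])
    mu = [sum(f[i] for f in features) / n for i in range(dim)]
    sd = [math.sqrt(sum((f[i] - mu[i]) ** 2 for f in features) / n)
          for i in range(dim)]
    return mu, sd

def fid_diagonal(real_feats, gen_feats):
    """Frechet distance between Gaussians, assuming diagonal covariances."""
    mu_r, sd_r = mean_and_std(real_feats)
    mu_g, sd_g = mean_and_std(gen_feats)
    # ||mu_r - mu_g||^2
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    # For diagonal covariances: s_r^2 + s_g^2 - 2 s_r s_g = (s_r - s_g)^2 per dimension.
    cov_term = sum((a - b) ** 2 for a, b in zip(sd_r, sd_g))
    return mean_term + cov_term

# Made-up 2-D "features"; real FID would use Inception-v3 activations instead.
real = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
shifted = [[x + 1.0, y + 2.0] for x, y in real]   # same spread, different mean
squeezed = [[0.5 * x, 0.5 * y] for x, y in real]  # different mean and spread

print("FID(real, real)     =", fid_diagonal(real, real))
print("FID(real, shifted)  =", fid_diagonal(real, shifted))
print("FID(real, squeezed) =", fid_diagonal(real, squeezed))
```

The mean term penalizes generated features that are centered in the wrong place, and the covariance term penalizes features with the wrong spread, which is why FID reacts to both poor quality and poor diversity.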
Numerical Example
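DCGAN for MNIST generation (28×28, grayscale). Convolution layers are written as (output channels, kernel size, stride). G architecture: FC(100→256·7·7) → BN → Reshape(256,7,7) → ConvTranspose(128,4,2) → BN → ConvTranspose(1,4,2) → tanh. Each transposed convolution doubles the spatial size: 7 → 14 → 28. D architecture: Conv(64,4,2) → LeakyReLU → Conv(128,4,2) → BN → LeakyReLU → Flatten → FC → sigmoid.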
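The spatial sizes in this architecture can be verified with the standard output-size formulas for convolutions and transposed convolutions. The text does not state the padding; the sketch assumes padding 1 (with dilation 1 and no output padding), which is the value that makes the sizes work out to 28×28.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of an ordinary convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel

PADDING = 1  # assumption: padding is not given in the text

# Generator: Reshape(256, 7, 7) followed by two ConvTranspose(., 4, 2) layers.
g_sizes = [7]
for _ in range(2):
    g_sizes.append(conv_transpose_out(g_sizes[-1], kernel=4, stride=2, padding=PADDING))

# Discriminator: 28x28 input followed by two Conv(., 4, 2) layers.
d_sizes = [28]
for _ in range(2):
    d_sizes.append(conv_out(d_sizes[-1], kernel=4, stride=2, padding=PADDING))

# Input width of the final FC layer in D: 128 channels times the last spatial area.
flattened = 128 * d_sizes[-1] ** 2

print("G spatial sizes:", g_sizes)
print("D spatial sizes:", d_sizes)
print("D flattened features:", flattened)
```

With these settings D mirrors G exactly, shrinking 28 → 14 → 7 where G grows 7 → 14 → 28.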
Training for 50 epochs (batch=64, Adam lr=0.0002): after 5 epochs the digits are blurry, after 20 they are recognizable, and after 50 the samples are high quality. FID(50 epochs) ≈ 12 (close to real data). Mode collapse was avoided with the help of label smoothing: the target for real samples is 0.9 instead of 1.0, which keeps D from becoming overconfident.
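The effect of the smoothed target can be illustrated with the binary cross-entropy that D minimizes on real samples. This is a small sketch, not the training code behind the numbers above: it searches a grid of candidate outputs (an assumption made for simplicity) for the value of D(x) that minimizes the loss under each target.

```python
import math

def bce(d, target):
    """Binary cross-entropy for one prediction d in (0, 1) and a soft target."""
    return -(target * math.log(d) + (1.0 - target) * math.log(1.0 - d))

# Candidate discriminator outputs 0.01, 0.02, ..., 0.99.
candidates = [i / 100 for i in range(1, 100)]

# With a hard target of 1.0 the loss keeps falling as D(x) approaches 1.
best_hard = min(candidates, key=lambda d: bce(d, 1.0))

# With a smoothed target of 0.9 the loss is minimized at D(x) = 0.9.
best_smooth = min(candidates, key=lambda d: bce(d, 0.9))

print("best D(x) with target 1.0:", best_hard)
print("best D(x) with target 0.9:", best_smooth)
```

Because the smoothed loss rises again once D(x) exceeds 0.9, D is not rewarded for driving its outputs into the saturated region.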
Assignment: Implement DCGAN for MNIST.
(1) G: Linear(100) → Unflatten → ConvTranspose×3 → Tanh. D: Conv×3 → LeakyReLU → Flatten → Linear → Sigmoid.
(2) Train for 50 epochs. Visualize progress every 5 epochs (8×8 grid).
(3) Compute FID on 1000 generated samples.
(4) Implement interpolation in z-space: linearly interpolate between z₁ and z₂ (8 steps) and visualize the “smooth” transition between digits.
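As a starting point for the interpolation step only, a minimal sketch of linear interpolation between two latent vectors might look like this. The function name `lerp_path`, the seed, and the use of plain Python lists are assumptions; in a real solution each vector on the path would be passed through the trained generator.

```python
import random

def lerp_path(z1, z2, steps=8):
    """Return `steps` latent vectors evenly spaced from z1 to z2, endpoints included."""
    path = []
    for k in range(steps):
        t = k / (steps - 1)  # t runs from 0.0 to 1.0
        path.append([(1.0 - t) * a + t * b for a, b in zip(z1, z2)])
    return path

# Two random 100-dimensional latent codes drawn from N(0, 1); seed fixed for repeatability.
rng = random.Random(0)
z1 = [rng.gauss(0.0, 1.0) for _ in range(100)]
z2 = [rng.gauss(0.0, 1.0) for _ in range(100)]

path = lerp_path(z1, z2, steps=8)
print("number of latent vectors:", len(path))
print("first coordinate along the path:", [round(z[0], 3) for z in path])
```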