Stochastic Optimization and Modern Methods

Convex Optimization for ML

Stochastic Optimization in Deep Learning

Modern neural networks are trained on billions of examples and can have billions of parameters. Computing the exact gradient over the whole dataset at every step is prohibitively expensive, so stochastic methods are used instead. Understanding their theoretical properties helps us tune training and diagnose problems.

Stochastic Gradient Descent: Theory

Problem statement: min_θ f(θ) = (1/n) Σᵢ fᵢ(θ). Stochastic gradient: gₜ = ∇fᵢₜ(θₜ), where iₜ is drawn uniformly at random from {1, …, n}. Key properties: E[gₜ] = ∇f(θₜ) (unbiasedness) and E||gₜ − ∇f(θₜ)||² ≤ σ² (bounded variance).
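
Both properties can be checked exactly on a toy finite sum, because the expectation over a uniform index is just an average over all n terms. This is a minimal sketch assuming a hypothetical 1-D least-squares problem fᵢ(θ) = ½(aᵢθ − bᵢ)² with random data; the helper names are illustrative. It also shows a fact used later for variance reduction: at the optimum the mean of gₜ is zero but its variance is not.

```python
import random

# Hypothetical toy finite sum: f_i(theta) = 0.5 * (a_i * theta - b_i)^2, i = 0..n-1
rng = random.Random(0)
n = 100
a = [rng.uniform(0.5, 2.0) for _ in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]

def grad_i(theta, i):
    """Gradient of the i-th term f_i at theta."""
    return a[i] * (a[i] * theta - b[i])

def full_grad(theta):
    """Exact gradient of f = (1/n) * sum_i f_i."""
    return sum(grad_i(theta, i) for i in range(n)) / n

def sg_mean_and_variance(theta):
    """Exact mean and variance of g = grad f_i(theta) with i ~ Uniform{0, ..., n-1}."""
    gs = [grad_i(theta, i) for i in range(n)]
    mean = sum(gs) / n
    return mean, sum((g - mean) ** 2 for g in gs) / n

theta_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)
mean_1, var_1 = sg_mean_and_variance(1.0)
mean_opt, var_opt = sg_mean_and_variance(theta_star)
print(mean_1 - full_grad(1.0))   # 0 up to rounding: E[g] = grad f (unbiased)
print(mean_opt, var_opt)         # mean ~ 0 at the optimum, but the variance is not
```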

Learning rate schedules and their guarantees:

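  • Decaying: αₜ = α₀/√t → convergence rate O(σ/√T) for the averaged iterate (convex case, up to problem-dependent constants)
  • Constant: αₜ = α → fast initial progress, but convergence only to a neighborhood of the optimum whose size grows with α (for μ-strongly convex f, roughly E||θₜ − θ*||² ≲ ασ²/μ), not to the optimum itself
  • Strongly convex (μ-SC): αₜ = 2/(μ(t+1)) → O(σ²/(μT))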
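
The difference between the constant and the 2/(μ(t+1)) schedules shows up clearly on a small strongly convex problem. A sketch, assuming the same kind of hypothetical 1-D least-squares finite sum (μ is the mean of aᵢ², the second derivative of f) and an arbitrary constant step of 0.1; it reports the average squared error over the last 1000 iterates so that a single lucky iterate does not decide the comparison.

```python
import random

# Hypothetical toy problem: f(theta) = (1/n) sum_i 0.5 * (a_i * theta - b_i)^2
rng = random.Random(1)
n = 200
a = [rng.uniform(0.5, 2.0) for _ in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]
mu = sum(ai * ai for ai in a) / n        # strong-convexity constant of f (its second derivative)
theta_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)

def run_sgd(step_size, T=10_000, theta0=5.0, seed=2):
    """Plain SGD; returns the mean of (theta_t - theta*)^2 over the last 1000 iterates."""
    r = random.Random(seed)
    theta, tail = theta0, []
    for t in range(1, T + 1):
        i = r.randrange(n)                           # i_t ~ Uniform{0, ..., n-1}
        g = a[i] * (a[i] * theta - b[i])             # unbiased estimate of f'(theta)
        theta -= step_size(t) * g
        if t > T - 1000:
            tail.append((theta - theta_star) ** 2)
    return sum(tail) / len(tail)

err_const = run_sgd(lambda t: 0.1)                   # constant step
err_decay = run_sgd(lambda t: 2.0 / (mu * (t + 1)))  # 2 / (mu (t + 1)) schedule
print(err_const, err_decay)  # constant step plateaus at a noise floor; decaying keeps shrinking
```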

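Mini-batch: gₜ = (1/|B|)Σᵢ∈Bₜ ∇fᵢ(θₜ). If the indices are drawn independently, the variance drops to σ²/|B|. Up to the critical batch size B_crit ≈ σ²/||∇f||², increasing |B| reduces the number of iterations almost linearly; beyond it, larger batches barely reduce the iteration count (and therefore the wall-clock time) while the total computation keeps growing.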
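
Both formulas can be checked numerically. This sketch assumes hypothetical per-example gradients from a 1-D least-squares problem evaluated at an arbitrary point θ = 0.5; it estimates the mini-batch variance by Monte Carlo (indices drawn with replacement) and compares it with σ²/|B|, then evaluates the simple noise-scale estimate of B_crit from the text.

```python
import random

# Hypothetical toy data; per-example gradients of f_i(theta) = 0.5 * (a_i * theta - b_i)^2
rng = random.Random(3)
n = 1000
a = [rng.uniform(0.5, 2.0) for _ in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]
theta = 0.5                                       # an arbitrary point away from the optimum

per_example = [ai * (ai * theta - bi) for ai, bi in zip(a, b)]
full = sum(per_example) / n                       # exact gradient of f at theta
sigma2 = sum((g - full) ** 2 for g in per_example) / n   # variance of a one-sample gradient

def minibatch_variance(B, trials=20_000, seed=4):
    """Monte Carlo variance of the mini-batch gradient, indices drawn with replacement."""
    r = random.Random(seed)
    est = [sum(r.choice(per_example) for _ in range(B)) / B for _ in range(trials)]
    m = sum(est) / trials
    return sum((e - m) ** 2 for e in est) / trials

for B in (1, 4, 16):
    print(B, minibatch_variance(B), sigma2 / B)   # empirical variance vs sigma^2 / |B|

b_crit = sigma2 / full ** 2                       # B_crit ~ sigma^2 / ||grad f||^2
print(b_crit)
```

On this toy problem B_crit comes out small (a few samples), because the gradient is large compared with its noise at θ = 0.5; closer to the optimum ||∇f|| shrinks and the estimate grows.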

Adam: Theoretical Analysis

Adam (Kingma & Ba, 2014) is the de facto standard for training neural networks:

mₜ = β₁ mₜ₋₁ + (1−β₁) gₜ (smoothed mean of the gradient)
vₜ = β₂ vₜ₋₁ + (1−β₂) gₜ² (smoothed mean of the squared gradient, elementwise)
m̂ₜ = mₜ/(1−β₁ᵗ), v̂ₜ = vₜ/(1−β₂ᵗ) (bias correction)
θₜ₊₁ = θₜ − α · m̂ₜ/(√v̂ₜ + ε)
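
The four update lines translate almost one-to-one into code. This is a minimal pure-Python sketch for a list of scalar parameters (the `Adam` class here is hypothetical, not a library API); the usage example minimizes an assumed toy quadratic (x − 3)² + 10(y + 1)².

```python
import math

class Adam:
    """Minimal Adam for a list of float parameters (a sketch, not a library API)."""
    def __init__(self, n_params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n_params   # smoothed mean of the gradient
        self.v = [0.0] * n_params   # smoothed mean of the squared gradient
        self.t = 0

    def step(self, theta, grad):
        self.t += 1
        out = []
        for j, (p, g) in enumerate(zip(theta, grad)):
            self.m[j] = self.b1 * self.m[j] + (1 - self.b1) * g
            self.v[j] = self.b2 * self.v[j] + (1 - self.b2) * g * g
            m_hat = self.m[j] / (1 - self.b1 ** self.t)   # bias correction
            v_hat = self.v[j] / (1 - self.b2 ** self.t)
            out.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return out

# Usage: minimize f(x, y) = (x - 3)^2 + 10 * (y + 1)^2
opt = Adam(2, lr=0.05)
theta = [0.0, 0.0]
for _ in range(2000):
    x, y = theta
    theta = opt.step(theta, [2 * (x - 3), 20 * (y + 1)])
print(theta)  # close to [3, -1]
```

Because of bias correction, the very first step has m̂₁ = g₁ and v̂₁ = g₁², so it moves each coordinate by almost exactly α regardless of the gradient's scale.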

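Interpretation: when a coordinate's gradient is roughly constant over the averaging window, m̂ₜ/√v̂ₜ ≈ sgn(gₜ), so Adam takes steps of magnitude roughly α in the direction of the gradient sign; when the gradient is noisy, |m̂ₜ/√v̂ₜ| falls well below 1 and the steps shrink. The effective learning rate α/(√v̂ₜ + ε) thus adapts per parameter: parameters with a history of large gradients receive a smaller one.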
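
The sign-like behavior is easy to observe by running only the moment recursions for a single coordinate. A sketch under assumed inputs: two constant gradient streams whose magnitudes differ by six orders of magnitude, and a hypothetical noisy stream with mean 0.1 and standard deviation 1. The helper `adam_directions` is illustrative and returns the normalized step m̂ₜ/(√v̂ₜ + ε) at every step.

```python
import math
import random

def adam_directions(grads, beta1=0.9, beta2=0.999, eps=1e-8):
    """Normalized Adam step m_hat_t / (sqrt(v_hat_t) + eps) at every step t (one coordinate)."""
    m = v = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        out.append(m_hat / (math.sqrt(v_hat) + eps))
    return out

# Constant gradients of very different magnitudes: the normalized step is ~ sign(g)
small = adam_directions([1e-3] * 500)[-1]
big = adam_directions([-1e3] * 500)[-1]
print(small, big)                      # ~ +1 and ~ -1

# Noisy gradient with a small mean: normalized steps are much smaller than 1 on average
rng = random.Random(0)
noisy = adam_directions([0.1 + rng.gauss(0.0, 1.0) for _ in range(5000)])
mean_abs = sum(abs(d) for d in noisy[1000:]) / len(noisy[1000:])
print(mean_abs)
```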

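Theoretical problems: Reddi et al. (2018) constructed simple convex problems on which Adam fails to converge to the optimum. The reason: vₜ is an exponential moving average, so the influence of rare but large (and informative) gradients decays quickly, and the effective learning rate α/√v̂ₜ can grow again right after them.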

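AMSGrad (Reddi et al., 2018): keeps the running maximum v̂ₜᵐᵃˣ = max(v̂ₜ₋₁ᵐᵃˣ, v̂ₜ) and uses it in place of v̂ₜ in the update, so the effective learning rate of each parameter never increases. This restores the convergence guarantee in the convex setting analyzed by the authors.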
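
The failure and the fix can both be reproduced on a tiny online problem in the spirit of Reddi et al.: fₜ(x) = C·x on every third step and −x otherwise, with x constrained to [−1, 1]. The average gradient (C − 2)/3 is positive, so the best point is x = −1. This is a sketch with constants chosen for illustration (C = 3, β₁ = 0, β₂ = 1/(1 + C²), step α/√t, start at 0, bias correction kept); it is not a reproduction of the paper's exact theorem.

```python
import math

def run(amsgrad, C=3.0, T=3000, alpha=1.0, eps=1e-8):
    """Projected Adam (beta1 = 0) or AMSGrad on f_t(x) = C*x if t % 3 == 1 else -x, x in [-1, 1]."""
    beta2 = 1.0 / (1.0 + C * C)
    x, v, v_max = 0.0, 0.0, 0.0
    for t in range(1, T + 1):
        g = C if t % 3 == 1 else -1.0
        v = beta2 * v + (1 - beta2) * g * g
        v_hat = v / (1 - beta2 ** t)              # bias correction (beta1 = 0, so m_hat = g)
        if amsgrad:
            v_max = max(v_max, v_hat)             # the denominator can never shrink
            denom = v_max
        else:
            denom = v_hat                         # forgets the large gradient C within a step or two
        x -= alpha / math.sqrt(t) * g / (math.sqrt(denom) + eps)
        x = max(-1.0, min(1.0, x))                # projection onto [-1, 1]
    return x

x_adam, x_ams = run(amsgrad=False), run(amsgrad=True)
print(x_adam, x_ams)  # Adam ends at +1 (the worst point); AMSGrad ends near -1
```

With plain Adam, the single large gradient C is almost immediately overwritten in v, so the two −1 steps that follow are taken with a large effective learning rate and outweigh it; the running maximum in AMSGrad keeps those steps small.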

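AdamW (Loshchilov & Hutter, 2019): Adam with decoupled weight decay. Standard Adam implements L2 regularization by adding λθ to the gradient; that term then passes through the adaptive scaling m̂/√v̂, so weights with a history of large gradients are regularized less. This is not equivalent to weight decay (unlike plain SGD, where the two coincide). AdamW applies the decay directly to the weights: θ ← θ(1−αλ) − α·m̂ₜ/(√v̂ₜ + ε). It is the de facto standard for transformers.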
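
The difference is visible after a single step on a weight whose loss gradient is zero, so that only regularization acts on it. A sketch with hypothetical helper names that implement just the first step (t = 1, zero moment state, where bias correction gives m̂ = g and v̂ = g²); α = λ = 0.01 are arbitrary illustration values.

```python
import math

def adam_l2_step(theta, grad_loss, lr, wd, eps=1e-8):
    """First Adam step (t = 1) with L2 regularization added to the gradient."""
    g = grad_loss + wd * theta            # the L2 term goes through the adaptive scaling
    m_hat, v_hat = g, g * g               # bias-corrected moments after one step
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)

def adamw_step(theta, grad_loss, lr, wd, eps=1e-8):
    """First AdamW step (t = 1): decay is applied to the weight, decoupled from m and v."""
    g = grad_loss
    m_hat, v_hat = g, g * g
    return theta * (1 - lr * wd) - lr * m_hat / (math.sqrt(v_hat) + eps)

w_l2 = adam_l2_step(1.0, 0.0, lr=1e-2, wd=1e-2)
w_decoupled = adamw_step(1.0, 0.0, lr=1e-2, wd=1e-2)
print(w_l2)         # ~0.99: the normalized step has size ~lr, whatever wd is
print(w_decoupled)  # ~0.9999: the weight shrinks by exactly the factor (1 - lr * wd)
```

With L2-in-the-gradient, the regularization term is normalized like any other gradient, so its strength no longer scales with λ in the expected way; the decoupled form keeps the decay proportional to λ.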

Variance Reduction: SVRG and SARAH

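SGD problem: the variance of the stochastic gradient does not vanish near the optimum (in general ∇fᵢ(θ*) ≠ 0 even though ∇f(θ*) = 0). With a constant lr the iterates keep oscillating around θ*, so reaching the optimum requires a decaying lr, which slows convergence.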

SVRG (Johnson & Zhang, 2013): every m steps, stores a snapshot x̃ and computes the full gradient ∇f(x̃). The variance-reduced stochastic gradient is:

gₜ = ∇fᵢ(xₜ) − ∇fᵢ(x̃) + ∇f(x̃)

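As xₜ and x̃ approach x*, the variance goes to zero: ∇fᵢ(xₜ) − ∇fᵢ(x̃) → 0 and ∇f(x̃) → ∇f(x*) = 0, while the estimator stays unbiased. Result: linear convergence (the error shrinks by a constant factor per epoch of m steps) when each fᵢ is L-smooth and f is μ-strongly convex, as with deterministic GD, but with O((n + L/μ)·log(1/ε)) component-gradient evaluations in total instead of GD's O(n·(L/μ)·log(1/ε)).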
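
A minimal SVRG sketch on a hypothetical 1-D least-squares finite sum. Assumptions not taken from the text: a constant step of 0.05, m = 2n inner steps per epoch, and the last inner iterate used as the next snapshot (the common practical variant). Unlike plain SGD with a constant step, it converges to the optimum itself, not just to a neighborhood.

```python
import random

# Hypothetical toy problem: f(x) = (1/n) sum_i 0.5 * (a_i * x - b_i)^2
rng = random.Random(0)
n = 200
a = [rng.uniform(0.5, 2.0) for _ in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]
x_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)

def grad_i(x, i):
    return a[i] * (a[i] * x - b[i])

def full_grad(x):
    return sum(grad_i(x, i) for i in range(n)) / n

def svrg(x0, lr=0.05, m=2 * n, epochs=30, seed=1):
    """SVRG with a constant step; the snapshot x_tilde is refreshed every m inner steps."""
    r = random.Random(seed)
    x_tilde = x0
    for _ in range(epochs):
        mu_tilde = full_grad(x_tilde)             # full gradient at the snapshot (n evaluations)
        x = x_tilde
        for _ in range(m):
            i = r.randrange(n)
            g = grad_i(x, i) - grad_i(x_tilde, i) + mu_tilde   # variance-reduced gradient
            x -= lr * g
        x_tilde = x                               # last inner iterate becomes the new snapshot
    return x_tilde

x_hat = svrg(x0=5.0)
print(abs(x_hat - x_star))  # essentially zero: linear convergence with a constant step
```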

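SARAH (Nguyen et al., 2017): recursive variance reduction. Each outer loop starts from the full gradient g₀ = ∇f(x₀) and then updates gₜ = ∇fᵢ(xₜ) − ∇fᵢ(xₜ₋₁) + gₜ₋₁. Unlike SVRG's estimator, this one is biased, but its norm shrinks as the iterates converge, and for nonconvex problems recursive estimators of this type have better worst-case complexity bounds than SVRG.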
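
A SARAH sketch on the same kind of hypothetical 1-D least-squares problem. Assumptions: constant step 0.05, m = 2n inner steps, and the last inner iterate as the next outer starting point (the original algorithm picks a random inner iterate; the last one is a common practical choice). The only change from SVRG is the recursive estimator, which references the previous iterate instead of a fixed snapshot.

```python
import random

# Hypothetical toy problem: f(x) = (1/n) sum_i 0.5 * (a_i * x - b_i)^2
rng = random.Random(0)
n = 200
a = [rng.uniform(0.5, 2.0) for _ in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]
x_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)

def grad_i(x, i):
    return a[i] * (a[i] * x - b[i])

def full_grad(x):
    return sum(grad_i(x, i) for i in range(n)) / n

def sarah(x0, lr=0.05, m=2 * n, epochs=30, seed=1):
    r = random.Random(seed)
    x_outer = x0
    for _ in range(epochs):
        x_prev = x_outer
        g = full_grad(x_prev)                     # g_0: full gradient at the start of the loop
        x = x_prev - lr * g
        for _ in range(m - 1):
            i = r.randrange(n)
            g = grad_i(x, i) - grad_i(x_prev, i) + g   # recursive (biased) estimator
            x_prev, x = x, x - lr * g
        x_outer = x                               # last inner iterate starts the next loop
    return x_outer

x_hat = sarah(x0=5.0)
print(abs(x_hat - x_star))  # essentially zero
```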

Practice: SVRG and SARAH are typically more efficient than SGD on convex finite-sum problems such as logistic regression or SVMs with a smooth loss. For neural networks, Adam with lr scheduling is more practical: the objective is nonconvex, so the linear-rate guarantees do not apply.

Federated Learning

Motivation: data on user devices cannot be centralized (privacy). We want to train a global model without access to raw data.

FedAvg (McMahan et al., 2017):

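  1. The server sends the current model θ to the K clients (in practice, often a random subset of them)
  2. Each client k starts from θₖ = θ and runs E epochs of mini-batch SGD on its local data: θₖ ← θₖ − α·∇L_k(θₖ)
  3. The server averages the results, weighting each client by its local dataset size nₖ: θ ← Σₖ (nₖ/n) θₖ, where n = Σₖ nₖ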
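
The three steps can be sketched in a few lines. Assumptions for illustration only: a hypothetical 1-D linear-regression task with a shared true weight of 2.0, three clients of different sizes whose inputs are shifted differently, and full participation in every round (the original algorithm usually samples a fraction of clients per round).

```python
import random

def make_client(rng, n_k, true_w, shift):
    """Hypothetical client data: y = true_w * x + noise, with a client-specific
    input shift so that local input distributions differ."""
    xs = [rng.gauss(shift, 1.0) for _ in range(n_k)]
    ys = [true_w * x + rng.gauss(0.0, 0.1) for x in xs]
    return xs, ys

def local_train(w, data, epochs, lr, rng, batch=8):
    """Client update: E epochs of mini-batch SGD on 0.5*(w*x - y)^2, starting from the server model."""
    xs, ys = data
    idx = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(idx)
        for s in range(0, len(idx), batch):
            bi = idx[s:s + batch]
            g = sum((w * xs[i] - ys[i]) * xs[i] for i in bi) / len(bi)
            w -= lr * g
    return w

def fedavg(clients, rounds=20, epochs=2, lr=0.05, seed=0):
    rng = random.Random(seed)
    w = 0.0                                            # global model (a single weight here)
    n_total = sum(len(xs) for xs, _ in clients)
    for _ in range(rounds):
        # Steps 1-2: every client starts from the current global w and trains locally
        local = [local_train(w, c, epochs, lr, rng) for c in clients]
        # Step 3: average weighted by local dataset size n_k / n
        w = sum(len(c[0]) / n_total * w_k for c, w_k in zip(clients, local))
    return w

data_rng = random.Random(42)
clients = [make_client(data_rng, n_k, true_w=2.0, shift=s)
           for n_k, s in [(50, -1.0), (100, 0.0), (200, 1.5)]]
w_final = fedavg(clients)
print(w_final)  # close to the shared true weight 2.0
```

Because every client here shares the same optimum, averaging works easily; giving clients different true weights is a quick way to see the heterogeneity effects discussed next.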

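Communication efficiency: running E local epochs per round (instead of a single gradient step) moves more computation onto the clients and can cut the number of communication rounds substantially, although the saving is not a guaranteed E-fold and shrinks as client data become more heterogeneous.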

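Problems:

  • Data heterogeneity (non-IID): when clients have different data distributions, local models drift toward their own optima, so FedAvg converges more slowly, to a worse solution, and can even diverge. FedProx adds a proximal term: each client minimizes L_k(θ) + (μ/2)||θ − θ_global||², which keeps local updates close to the global model (here μ is the proximal coefficient, not the strong-convexity constant).
  • Privacy: the updates themselves can leak information about local data. With differential privacy, each update is clipped to a bounded norm and Gaussian noise calibrated to that bound is added before sending, which yields (ε,δ)-DP guarantees.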
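
Both client-side modifications are small. A sketch with hypothetical helper names: `fedprox_grad` adds the gradient of the proximal term, μ(θ − θ_global), to the local loss gradient, and `clip_and_noise` clips an update to an L2 norm bound and adds Gaussian noise scaled by that bound. Turning a noise multiplier into concrete (ε,δ) values requires a privacy accountant, which is not shown.

```python
import math
import random

def fedprox_grad(w, w_global, loss_grad, mu):
    """Gradient of the FedProx local objective L_k(w) + (mu/2)*||w - w_global||^2,
    given the gradient of L_k at w (all vectors as lists of floats)."""
    return [g + mu * (wi - wg) for g, wi, wg in zip(loss_grad, w, w_global)]

def clip_and_noise(update, clip_norm, noise_multiplier, rng):
    """Clip the update to L2 norm <= clip_norm, then add Gaussian noise with
    standard deviation noise_multiplier * clip_norm to every coordinate."""
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale + rng.gauss(0.0, noise_multiplier * clip_norm) for u in update]

print(fedprox_grad([1.0, 2.0], [0.0, 0.0], [0.5, -0.5], mu=0.1))  # approx. [0.6, -0.3]
print(clip_and_noise([3.0, 4.0], clip_norm=1.0, noise_multiplier=0.0,
                     rng=random.Random(0)))                        # approx. [0.6, 0.8]
```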

Numerical Example

Training BERT-base (110M parameters) on A100 GPUs (FP16):

  • Batch size = 256, lr = 2e-4, warmup = 10000 steps
  • Adam: β₁=0.9, β₂=0.999, ε=1e-8, weight decay=0.01
  • After 1M steps (~10 days on 8×A100): val perplexity = 3.8

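When the batch size is increased to 2048 with the linear scaling rule (a rule originally proposed for SGD: lr = 8·2e-4 = 1.6e-3), the same number of training examples is processed in 8× fewer iterations (125k instead of 1M) with similar results. With correspondingly more GPUs, the wall-clock speedup is only ≈4× rather than 8× because of communication overhead.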
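
The arithmetic behind the rule is simple enough to encode. A sketch with a hypothetical helper that keeps the number of processed examples fixed; the square-root option is an alternative heuristic sometimes suggested for adaptive optimizers and is not part of the original example.

```python
def rescale_for_batch(base_lr, base_batch, base_steps, new_batch, rule="linear"):
    """Keep the number of training examples fixed and rescale lr for a new batch size.
    rule="linear": lr * k (linear scaling rule); rule="sqrt": lr * sqrt(k) (alternative heuristic)."""
    k = new_batch / base_batch
    lr = base_lr * (k if rule == "linear" else k ** 0.5)
    steps = round(base_steps / k)
    return lr, steps

print(rescale_for_batch(2e-4, 256, 1_000_000, 2048))          # lr 1.6e-3, 125000 steps
print(rescale_for_batch(2e-4, 256, 1_000_000, 2048, "sqrt"))  # lr ~5.7e-4, 125000 steps
```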

Assignment:

  • Compare SGD, Adam, and SVRG on MNIST with multinomial logistic regression (10 classes, L2 regularization λ = 0.001). For SGD and Adam, find the best lr by grid search. For SVRG, recompute the full gradient every m = n/10 steps.
  • Plot validation accuracy against the number of per-example gradient evaluations (a fair comparison, since SVRG's full-gradient passes are counted too).
  • Implement FedAvg for MNIST with 10 clients and non-IID data (each client sees only 2 classes). How much does quality degrade compared to centralized training?
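
Two helpers may be useful for the assignment. Both are hypothetical sketches: a shard-based non-IID split (sort indices by label, cut them into equal shards, deal two random shards to each client) and a counter for SVRG's per-example gradient evaluations, assuming ∇fᵢ(x̃) is recomputed at each inner step rather than cached. The label list below is a stand-in, not real MNIST data.

```python
import random

def shard_split(labels, n_clients=10, shards_per_client=2, seed=0):
    """Non-IID split: sort indices by label, cut into equal shards,
    and give each client `shards_per_client` random shards."""
    rng = random.Random(seed)
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    n_shards = n_clients * shards_per_client
    size = len(order) // n_shards
    shards = [order[s * size:(s + 1) * size] for s in range(n_shards)]
    rng.shuffle(shards)
    return [sum(shards[c * shards_per_client:(c + 1) * shards_per_client], [])
            for c in range(n_clients)]

def svrg_grad_evals(n, m, epochs):
    """Per-example gradient evaluations for SVRG: n per snapshot + 2 per inner step."""
    return epochs * (n + 2 * m)

labels = [c for c in range(10) for _ in range(600)]   # stand-in for MNIST labels
clients = shard_split(labels)
print([sorted({labels[i] for i in idx}) for idx in clients])  # at most 2 classes per client
print(svrg_grad_evals(n=60_000, m=6_000, epochs=5))           # 5 * (60000 + 12000)
```

With this scheme a client can occasionally receive two shards of the same class; reshuffling or assigning shards deterministically avoids that if exactly two classes per client are required.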

§ Act · what next