Advanced Sampling in Diffusion Models
This post continues our exploration of diffusion model sampling, diving into numerical samplers (Euler, Heun, Runge–Kutta, DPM-Solver), classifier-free guidance (CFG) and its interpretations, and latent diffusion. We connect practical algorithms to theoretical principles, providing a rigorous yet intuitive understanding.
1. Numerical Solvers for SDEs and ODEs
Sampling from diffusion models amounts to integrating either:
- Stochastic differential equations (SDEs) — DDPM-style sampling
- Deterministic ODEs (probability flow ODEs) — DDIM-style sampling
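For reference, following the standard formulation (Song et al., 2021), if the forward process is \( d\mathbf{x} = f(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{w} \), the two generative processes are:
\[ d\mathbf{x} = \left[ f(\mathbf{x}, t) - g(t)^2 \nabla_{\mathbf{x}} \log p_t(\mathbf{x}) \right] dt + g(t)\, d\bar{\mathbf{w}} \quad \text{(reverse SDE)} \]
\[ \frac{d\mathbf{x}}{dt} = f(\mathbf{x}, t) - \tfrac{1}{2} g(t)^2 \nabla_{\mathbf{x}} \log p_t(\mathbf{x}) \quad \text{(probability flow ODE)} \]
Both share the same marginals \( p_t(\mathbf{x}) \), which is why a single learned score serves either sampler.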
These can be integrated using numerical methods:
- Euler–Maruyama (for SDEs): explicit and cheap per step, but only first-order accurate.
- Heun’s method (a 2nd-order Runge–Kutta scheme): takes an Euler predictor step, then corrects using the average of the slopes at the two endpoints.
- DPM-Solver: a high-order sampler that exploits the semi-linear structure of diffusion ODEs, solving the linear part exactly and approximating only the learned score term. See the DPM-Solver paper (Lu et al., 2022) for the full derivation.
The key benefit of advanced solvers like DPM-Solver is few-step sampling with high fidelity, thanks to treating the score model as a time-dependent vector field and integrating it as an ODE.
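The contrast between Euler and Heun steps can be sketched on a toy probability-flow ODE. Here the analytic score of \( \mathcal{N}(0, t^2 + 0.01) \) stands in for a learned score network, and all function names are illustrative, not a real library API:

```python
import numpy as np

def score(x, t):
    # Analytic score of N(0, t^2 + 0.01), standing in for a trained network
    return -x / (t**2 + 0.01)

def drift(x, t):
    # Probability-flow ODE for a variance-exploding process with sigma(t) = t:
    # dx/dt = -t * score(x, t)
    return -t * score(x, t)

def euler_step(x, t, dt):
    return x + dt * drift(x, t)

def heun_step(x, t, dt):
    # Euler predictor step, then correct with the average slope at both ends
    d1 = drift(x, t)
    x_pred = x + dt * d1
    d2 = drift(x_pred, t + dt)
    return x + dt * 0.5 * (d1 + d2)

def sample(step_fn, x_T, ts):
    x = x_T
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = step_fn(x, t, t_next - t)   # dt < 0: integrate from T down toward 0
    return x

ts = np.linspace(1.0, 0.01, 200)        # time grid from t = T = 1 down to t = 0.01
x_euler = sample(euler_step, 2.0, ts)
x_heun = sample(heun_step, 2.0, ts)
# Exact solution of this linear ODE: x(t) = x_T * sqrt((t^2 + 0.01) / (T^2 + 0.01)),
# i.e. 0.2 at t = 0.01; Heun tracks it more closely than Euler at the same step count.
```

On the same grid, the second-order Heun step lands noticeably closer to the exact trajectory, which is why halving the step count with a better solver often beats doubling the steps of Euler.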
2. Classifier-Free Guidance (CFG)
In conditional generation, we want to sample from \( p(\mathbf{x} \mid \mathbf{y}) \). One option is classifier guidance (Dhariwal & Nichol, 2021), which uses the gradient of a pretrained classifier, \( \nabla_{\mathbf{x}} \log p(\mathbf{y} \mid \mathbf{x}) \). But CFG takes a simpler approach:
- Train a single conditional score model \( s_\theta(\mathbf{x}, t, \mathbf{y}) \),
- Randomly replace \( \mathbf{y} \) with a null token \( \varnothing \) during training (e.g., 10–20% of the time),
- At inference, guide with: \[ \tilde{s}_\theta(\mathbf{x}, t, \mathbf{y}) = (1 + w) s_\theta(\mathbf{x}, t, \mathbf{y}) - w s_\theta(\mathbf{x}, t, \varnothing) \] where \( w > 0 \) is the guidance scale.
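The guided-score combination is a one-liner. A minimal sketch, using a toy analytic `score_model` stand-in (a real model would be a neural network; the Gaussian toy score below is purely illustrative):

```python
def cfg_score(score_model, x, t, y, w, null_token=None):
    """Classifier-free guidance: (1 + w) * s(x, t, y) - w * s(x, t, null)."""
    s_cond = score_model(x, t, y)
    s_uncond = score_model(x, t, null_token)
    return (1.0 + w) * s_cond - w * s_uncond

# Toy stand-in for s_theta: score of a unit Gaussian centered at y (or 0 if unconditional)
def toy_score(x, t, y):
    mean = 0.0 if y is None else y
    return -(x - mean)

x, t, y = 1.0, 0.5, 2.0
guided = cfg_score(toy_score, x, t, y, w=2.0)   # 3 * 1.0 - 2 * (-1.0) = 5.0
```

Note that both branches evaluate the same network, so CFG roughly doubles the per-step compute (or the batch size, when the two passes are batched together).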
Interpretation 1: Directional Control
This is a linear extrapolation: rewriting the update as \( \tilde{s} = s_\varnothing + (1 + w)(s_\varnothing \to s_\theta(\mathbf{x}, t, \mathbf{y})) \), i.e., \( s_\varnothing + (1 + w)\left(s_\theta(\mathbf{x}, t, \mathbf{y}) - s_\varnothing\right) \), shows that CFG pushes past the conditional score along the direction from unconditional to conditional, amplifying the drift toward high-likelihood samples under \( p(\mathbf{x} \mid \mathbf{y}) \).
Interpretation 2: Annealed Denoising
As noted in GLIDE and subsequent works, CFG acts as an annealing mechanism: the model gradually shifts from unconditional to conditional generation.
Interpretation 3: A Model vs. Its Noisy Self
CFG can also be seen as the model comparing itself to a "bad version" — the unconditional variant — and refining its outputs based on that discrepancy.
CFG is especially powerful in text-to-image models like Stable Diffusion, where nuanced conditioning (e.g., CLIP embeddings) allows for fine-grained generation control.
3. Latent Diffusion Models (LDM)
Training a diffusion model in pixel space is costly. Latent Diffusion Models (Rombach et al., 2022) propose:
- Train an autoencoder \( E(\mathbf{x}) = \mathbf{z} \), \( D(\mathbf{z}) = \hat{\mathbf{x}} \)
- Train a diffusion model in the latent space \( \mathbf{z} \in \mathbb{R}^d \) with \( d \ll HW \)
- Sample in latent space and decode to image
Formally, diffusion is applied to \( \mathbf{z}_0 \sim E(p_{\text{data}}) \), and the generative model is:
\[ p(\mathbf{x}) = \int p(\mathbf{x} \mid \mathbf{z}_0) p(\mathbf{z}_0) \, d\mathbf{z}_0 \]
The benefit is computational: both training and sampling are dramatically cheaper in the latent space, and high-frequency details are preserved via decoder upsampling.
Note that conditioning (e.g., text prompts) is typically encoded with CLIP or T5 embeddings and injected via cross-attention at each denoising step.
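The overall pipeline shape can be sketched end to end. Every component below is an illustrative stand-in: a toy linear "autoencoder" in place of a trained VAE, and a trivial latent ODE in place of a score network:

```python
import numpy as np

rng = np.random.default_rng(0)

H, W, d = 8, 8, 4                     # latent dim d = 4 << H * W = 64
P = rng.standard_normal((d, H * W)) / np.sqrt(H * W)

def encode(x):
    # E(x) = z: toy linear projection standing in for a trained encoder
    return P @ x.ravel()

def decode(z):
    # D(z) ~ x: pseudo-inverse reconstruction standing in for the decoder
    return (np.linalg.pinv(P) @ z).reshape(H, W)

def latent_sample(z_T, steps=100):
    # Toy latent "probability-flow" integration: Euler steps on dz/dt = z / t,
    # which contracts the prior sample toward the data mode as t -> 0
    ts = np.linspace(1.0, 0.01, steps)
    z = z_T
    for t, t_next in zip(ts[:-1], ts[1:]):
        z = z + (t_next - t) * z / t
    return z

z_T = rng.standard_normal(d)          # draw noise in the cheap latent space
x_hat = decode(latent_sample(z_T))    # decode the denoised latent to pixels
```

The point of the sketch is the cost structure: every diffusion step operates on a 4-dimensional vector rather than a 64-dimensional image, and the decoder is applied exactly once at the end.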
4. Summary
- Numerical samplers control quality–speed tradeoffs; DPM-Solver is among the most effective few-step solvers
- Classifier-free guidance provides flexible and powerful conditioning with minimal overhead
- Latent diffusion enables scalable generation by learning in compressed semantic spaces
These ideas enable state-of-the-art image, audio, and video synthesis and are central to models like Stable Diffusion, Imagen, Midjourney, and Sora.
In the next post, we may look at guided sampling techniques like classifier guidance, score distillation, and semantic control using additional modalities (e.g., keypoints, depth maps).