Part IV · Chapter 12

Autoencoders & Variational Autoencoders

Autoencoders learn compact representations by compressing data into a latent code and reconstructing it. Variational Autoencoders (VAEs) add a probabilistic prior on the latent space, enabling principled generation of new samples. We derive the VAE objective from the evidence lower bound, prove the reparameterisation trick, and compute the KL divergence in closed form.

1. Vanilla Autoencoder

An autoencoder is an encoder-decoder pair trained to reconstruct its input:

\[ \mathbf{z} = f_\phi(\mathbf{x}), \qquad \hat{\mathbf{x}} = g_\theta(\mathbf{z}), \qquad \mathcal{L}_\mathrm{AE} = \|\mathbf{x} - \hat{\mathbf{x}}\|^2 \]

The bottleneck (latent dimension \(d \ll D\)) forces the encoder to learn a compressed, informative representation. The decoder must reconstruct the input from this code.

A denoising autoencoder corrupts the input during training: \(\tilde{\mathbf{x}} = \mathbf{x} + \boldsymbol{\epsilon},\;\boldsymbol{\epsilon} \sim \mathcal{N}(0,\sigma^2\mathbf{I})\), then minimises \(\|\mathbf{x} - g_\theta(f_\phi(\tilde{\mathbf{x}}))\|^2\). This encourages learning robust features that are invariant to small perturbations.
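The reconstruction objective above can be sketched in a few lines of NumPy. This is a minimal linear autoencoder trained by gradient descent on synthetic low-rank data; the dataset, dimensions, and learning rate are illustrative assumptions, not taken from the chapter's script:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: 500 points in D = 5 dimensions lying near a d = 2 subspace
basis = rng.standard_normal((2, 5))
X = rng.standard_normal((500, 2)) @ basis + 0.05 * rng.standard_normal((500, 5))

W_enc = 0.1 * rng.standard_normal((2, 5))   # f_phi: R^D -> R^d
W_dec = 0.1 * rng.standard_normal((5, 2))   # g_theta: R^d -> R^D
lr = 0.01

mse0 = np.mean((X - X @ W_enc.T @ W_dec.T) ** 2)
for _ in range(3000):
    Z = X @ W_enc.T                  # latent codes, shape (500, 2)
    Xhat = Z @ W_dec.T               # reconstructions, shape (500, 5)
    G = (Xhat - X) / len(X)          # grad of 0.5 * mean_n ||x - xhat||^2
    g_enc = (G @ W_dec).T @ X        # chain rule through the decoder
    g_dec = G.T @ Z
    W_enc -= lr * g_enc
    W_dec -= lr * g_dec
mse1 = np.mean((X - X @ W_enc.T @ W_dec.T) ** 2)
print(f"reconstruction MSE: {mse0:.3f} -> {mse1:.3f}")
```

Because the bottleneck has \(d = 2\), the trained pair recovers the 2-D subspace the data lies in (for a linear autoencoder this coincides with the top principal subspace).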

2. Variational Autoencoder: Full Derivation

The VAE (Kingma & Welling 2014) is a latent variable model with prior \(p(\mathbf{z}) = \mathcal{N}(\mathbf{0},\mathbf{I})\) and likelihood \(p_\theta(\mathbf{x}|\mathbf{z})\). We want to maximise the log-likelihood \(\log p_\theta(\mathbf{x})\), but this requires integrating over all \(\mathbf{z}\):

\[ \log p_\theta(\mathbf{x}) = \log\int p_\theta(\mathbf{x}|\mathbf{z})\,p(\mathbf{z})\,d\mathbf{z} \]

This integral is intractable. Introduce an approximate posterior \(q_\phi(\mathbf{z}|\mathbf{x})\) (the encoder) and apply the identity:

ELBO decomposition

\[ \log p_\theta(\mathbf{x}) = \underbrace{\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\!\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right] - \mathrm{KL}\!\left(q_\phi(\mathbf{z}|\mathbf{x})\,\big\|\,p(\mathbf{z})\right)}_{\text{ELBO}} + \mathrm{KL}\!\left(q_\phi(\mathbf{z}|\mathbf{x})\,\big\|\,p(\mathbf{z}|\mathbf{x})\right) \]

Proof: multiply and divide inside the log by \(q_\phi(\mathbf{z}|\mathbf{x})\):

\[ \log p_\theta(\mathbf{x}) = \log \mathbb{E}_{q_\phi}\!\left[\frac{p_\theta(\mathbf{x}|\mathbf{z})\,p(\mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}\right] \geq \mathbb{E}_{q_\phi}\!\left[\log \frac{p_\theta(\mathbf{x}|\mathbf{z})\,p(\mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})}\right] \]

by Jensen's inequality (log is concave). The lower bound equals \(\log p_\theta(\mathbf{x})\) iff \(q_\phi = p(\mathbf{z}|\mathbf{x})\).

ELBO rewritten

\[ \mathcal{L}_\mathrm{ELBO} = \underbrace{\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\!\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]}_{\text{reconstruction}} - \underbrace{\mathrm{KL}\!\left(q_\phi(\mathbf{z}|\mathbf{x})\,\big\|\,p(\mathbf{z})\right)}_{\text{regularisation}} \]

The reconstruction term encourages the decoder to reproduce the input. The KL term regularises the approximate posterior toward the prior, preventing overfitting to individual data points.

Gaussian encoder: \(q_\phi(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}_\phi(\mathbf{x}), \mathrm{diag}(\boldsymbol{\sigma}^2_\phi(\mathbf{x})))\)

The KL between two Gaussians has a closed form. For diagonal \(q = \mathcal{N}(\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2))\) and \(p = \mathcal{N}(\mathbf{0}, \mathbf{I})\):

\[ \mathrm{KL}(q \| p) = \frac{1}{2}\sum_{j=1}^d \left(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2\right) \]

Derivation: \(\mathrm{KL}(q\|p) = \mathbb{E}_q[\log q - \log p] = -\frac{1}{2}\sum_j(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)\), using the entropy of a Gaussian and the log-normaliser of \(p\).
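The closed form can be sanity-checked against a Monte Carlo estimate of \(\mathbb{E}_q[\log q(\mathbf{z}) - \log p(\mathbf{z})]\). A small NumPy check, with arbitrary test values for \(\boldsymbol{\mu}\) and \(\boldsymbol{\sigma}^2\):

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 2.0])       # arbitrary posterior mean
sigma2 = np.array([0.4, 1.5, 0.9])    # arbitrary posterior variances

# Closed form: 0.5 * sum(sigma_j^2 + mu_j^2 - 1 - log sigma_j^2)
kl_closed = 0.5 * np.sum(sigma2 + mu**2 - 1 - np.log(sigma2))

# Monte Carlo estimate of E_q[log q(z) - log p(z)]
z = mu + np.sqrt(sigma2) * rng.standard_normal((1_000_000, 3))
log_q = -0.5 * np.sum((z - mu) ** 2 / sigma2 + np.log(2 * np.pi * sigma2), axis=1)
log_p = -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=1)
kl_mc = (log_q - log_p).mean()

print(f"closed form: {kl_closed:.4f}, Monte Carlo: {kl_mc:.4f}")
```

The two values agree to within Monte Carlo error, and in practice the closed form is what the training loss uses: no samples are needed for the KL term.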

Reparameterisation trick: why it enables gradients

We need to backpropagate through the sampling operation \(\mathbf{z} \sim q_\phi(\mathbf{z}|\mathbf{x})\). NaΓ―ve Monte Carlo sampling is not differentiable w.r.t. \(\phi\) because the distribution itself depends on \(\phi\).

Reparameterisation moves the randomness into a parameter-free noise variable, so that the sample becomes a deterministic function of \(\phi\):

\[ \mathbf{z} = \boldsymbol{\mu}_\phi(\mathbf{x}) + \boldsymbol{\sigma}_\phi(\mathbf{x}) \odot \boldsymbol{\varepsilon}, \qquad \boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0},\mathbf{I}) \]

Now \(\mathbf{z}\) is a deterministic function of \((\phi, \boldsymbol{\varepsilon})\). Gradients \(\partial \mathbf{z}/\partial \boldsymbol{\mu}_\phi = \mathbf{I}\) and \(\partial \mathbf{z}/\partial \boldsymbol{\sigma}_\phi = \mathrm{diag}(\boldsymbol{\varepsilon})\) flow through \(\mathbf{z}\) to \(\phi\) via the chain rule.
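A quick numerical illustration: for the scalar objective \(\mathbb{E}[z^2]\) with \(z = \mu + \sigma\varepsilon\), the pathwise (reparameterised) gradient estimates should match the analytic gradients \(\partial_\mu \mathbb{E}[z^2] = 2\mu\) and \(\partial_\sigma \mathbb{E}[z^2] = 2\sigma\). The values \(\mu = 1.5\), \(\sigma = 0.8\) are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8                  # arbitrary test values
eps = rng.standard_normal(1_000_000)
z = mu + sigma * eps                  # reparameterised samples

# Pathwise gradients: d(z^2)/d mu = 2z, d(z^2)/d sigma = 2z * eps
grad_mu = np.mean(2 * z)              # analytic value: 2 * mu = 3.0
grad_sigma = np.mean(2 * z * eps)     # analytic value: 2 * sigma = 1.6
print(f"grad wrt mu: {grad_mu:.3f}, grad wrt sigma: {grad_sigma:.3f}")
```

Without reparameterisation there is no path from \(\mu\) or \(\sigma\) to the samples, and an estimator like this would require the higher-variance score-function (REINFORCE) trick instead.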

VAE Training Objective (maximise)

\[ \mathcal{L}_\mathrm{VAE}(\phi,\theta;\mathbf{x}) = \mathbb{E}_{\boldsymbol{\varepsilon}\sim\mathcal{N}(0,\mathbf{I})}\!\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right] - \frac{1}{2}\sum_j\!\left(\sigma_{\phi,j}^2 + \mu_{\phi,j}^2 - 1 - \log\sigma_{\phi,j}^2\right) \]

For Gaussian likelihood \(p_\theta(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\hat{\mathbf{x}}, \mathbf{I})\), the reconstruction term becomes \(-\|\mathbf{x}-\hat{\mathbf{x}}\|^2/2\) (MSE loss up to constants).

3. VAE Architecture Diagram

[Figure: VAE architecture. Input \(\mathbf{x}\) → encoder \(q_\phi(\mathbf{z}|\mathbf{x})\) → \(\boldsymbol{\mu}_\phi\) and \(\log\boldsymbol{\sigma}^2_\phi\) → reparameterisation \(\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\varepsilon}\), \(\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0},\mathbf{I})\) (enables backprop) → decoder \(p_\theta(\mathbf{x}|\mathbf{z})\) → output \(\hat{\mathbf{x}}\).]

The encoder outputs \(\boldsymbol{\mu}_\phi\) and \(\log\boldsymbol{\sigma}^2_\phi\). The reparameterisation node samples \(\mathbf{z}\) differentiably, enabling gradients to flow from the decoder loss all the way back through to the encoder parameters.

4. Python: VAE on Synthetic 2D Data

Full NumPy VAE implementation with manual backpropagation through the reparameterisation trick. Trained on a 2D three-cluster dataset; we visualise the latent space, reconstructions, and new samples decoded from the prior \(\mathcal{N}(\mathbf{0},\mathbf{I})\).
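The full script is not reproduced in this excerpt. As a sketch of its core training step, here is a minimal VAE with linear encoder and decoder heads and manual backpropagation through the reparameterisation trick; the cluster centres, dimensions, and learning rate are illustrative assumptions (the chapter's script uses nonlinear networks and adds the visualisations):

```python
import numpy as np

rng = np.random.default_rng(0)

# Three-cluster 2D dataset (illustrative stand-in for the chapter's data)
centres = np.array([[2.0, 0.0], [-1.0, 2.0], [-1.0, -2.0]])
X = np.concatenate([c + 0.3 * rng.standard_normal((100, 2)) for c in centres])
N, D, d, lr = len(X), 2, 2, 0.05

We = 0.1 * rng.standard_normal((d, D)); be = np.zeros(d)  # encoder mean head
Wv = 0.1 * rng.standard_normal((d, D)); bv = np.zeros(d)  # encoder log-variance head
Wd = 0.1 * rng.standard_normal((D, d)); bd = np.zeros(D)  # decoder

losses = []
for _ in range(400):
    mu = X @ We.T + be
    logvar = X @ Wv.T + bv
    eps = rng.standard_normal(mu.shape)
    z = mu + np.exp(0.5 * logvar) * eps                  # reparameterisation
    xhat = z @ Wd.T + bd
    recon = 0.5 * np.sum((X - xhat) ** 2) / N            # -log p(x|z) up to const.
    kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1 - logvar) / N
    losses.append(recon + kl)
    # --- manual backprop (gradients of the batch-averaged negative ELBO) ---
    gxhat = (xhat - X) / N
    gz = gxhat @ Wd
    gmu = gz + mu / N                                    # recon path + dKL/dmu
    glv = gz * 0.5 * (z - mu) + 0.5 * (np.exp(logvar) - 1) / N
    Wd -= lr * gxhat.T @ z;  bd -= lr * gxhat.sum(0)
    We -= lr * gmu.T @ X;    be -= lr * gmu.sum(0)
    Wv -= lr * glv.T @ X;    bv -= lr * glv.sum(0)

print(f"negative ELBO: {losses[0]:.3f} -> {np.mean(losses[-20:]):.3f}")
```

Note the logvar gradient: since \(\mathbf{z} = \boldsymbol{\mu} + e^{\frac{1}{2}\log\boldsymbol{\sigma}^2} \odot \boldsymbol{\varepsilon}\), we have \(\partial \mathbf{z}/\partial \log\boldsymbol{\sigma}^2 = \frac{1}{2}(\mathbf{z} - \boldsymbol{\mu})\), which is exactly the `0.5 * (z - mu)` factor above.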
