Energy-based models
Huh?
#️⃣   ⌛  ~1.5 h 🤓  Intermediate
23.06.2023
upd:
#58


🎓 79/167

This post is a part of the Generative models educational series from my free course. Please keep in mind that the correct sequence of posts is outlined on the course page, while it can be arbitrary in Research.

I'm also happy to announce that I've started working on standalone paid courses, so you could support my work and get cheap educational material. These courses will be of completely different quality, with more theoretical depth and niche focus, and will feature challenging projects, quizzes, exercises, video lectures and supplementary stuff. Stay tuned!


Energy-based models (EBMs) have seen a resurgence of interest in modern machine learning research thanks to their capacity for representing complex, high-dimensional probability distributions in a principled yet flexible manner. Although these models date back several decades — with roots in statistical physics and early neural network literature — many of their core ideas have reemerged recently with new theoretical insights and improved training methodologies. The enhanced computational power and deep architectures available today have unlocked possibilities that might have been infeasible in earlier eras.

The motivation for energy-based modeling stems from the fundamental need to learn probability distributions over richly structured data, such as images, videos, text, audio signals, and even multimodal combinations. In numerous tasks, what we really desire is a mechanism to assess how "likely" or "plausible" a given data point might be. Consider image generation: we want to sample new images from the same distribution as the training data (e.g., real photographs of faces or digits). Or consider anomaly detection, where we want to know if a novel data point lies outside the typical distribution of observations. Likewise, for tasks such as classification under uncertain conditions, an energy-based perspective can provide a unifying framework that goes beyond purely discriminative approaches.

Why not just learn a direct parametric probability density like a standard Gaussian or a more specialized model? For low-dimensional or structured data, simpler parametric methods might suffice. But for modern, large-scale tasks (such as natural images with tens of thousands or millions of pixels, or language with huge vocabularies), a flexible modeling approach is required. The classical approach is to pick a parametric family of densities with tractable normalization (for instance, a latent-variable model or an autoregressive model), but each approach comes with constraints. Energy-based modeling offers an alternative view: it transforms an arbitrary "energy function" (which can be any scalar output from a neural network) into a distribution by exponentiating and normalizing. This generality provides vast expressive power and unifies many previously distinct models under a single conceptual umbrella.

In certain respects, EBMs can be viewed as bridging the gap between purely discriminative networks (where the model typically outputs scores or logits) and fully generative methods (where the model tries to output or sample from a learned distribution). The hallmark of EBMs is that they do not strictly require the direct computation of a normalized probability density at training or inference time, as they lean heavily on sampling-based approximations. In practice, this means we can incorporate a wide range of neural architectures (feed-forward networks, convolutional or recurrent layers, etc.) without needing special restrictions on how the outputs are formed, other than ensuring that we can backpropagate to obtain gradients with respect to inputs.

However, with such flexibility come challenges. Training EBMs is notoriously tricky, often requiring Markov Chain Monte Carlo (MCMC) procedures inside each iteration of gradient descent. Without careful attention to hyperparameters — such as step sizes, number of sampling steps, noise levels, architectural constraints, and additional stability tricks — the training can easily diverge or collapse. Yet, for researchers and practitioners willing to invest in carefully tuning these methods, EBMs offer exciting prospects in generative modeling, anomaly detection, interpretability, compositionality, and more.

In this article, I will present a comprehensive overview of energy-based models, from the fundamental equations to advanced applications and stable training strategies. The aim is to thoroughly explore the concepts, math, algorithms, code, and some of the typical pitfalls and solutions one may encounter when deploying EBMs in practice. Throughout, I will highlight connections to canonical examples such as Boltzmann machines and restricted Boltzmann machines, as well as more recent deep architectures that leverage convolutional or residual networks. We will dive into sampling and training procedures like Contrastive Divergence (CD) and Stochastic Gradient Langevin Dynamics (SGLD), and we will walk through an implementation that can learn to generate realistic images (e.g., MNIST digits). By the end, I hope you will see how EBMs can form a powerful and unifying framework in modern machine learning research.


2. Fundamentals of energy-based models

2.1 Defining an energy function

The central object in an energy-based model is the energy function, typically written as $E_\theta(x)$ for a data point $x$ (for instance, an image or a vector of features) under parameters $\theta$. This function maps each possible configuration of $x$ to a real value, usually unconstrained in $(-\infty, +\infty)$. Intuitively, if $E_\theta(x)$ is low (for example, negative and large in magnitude), we interpret $x$ as having high "compatibility" or likelihood. If $E_\theta(x)$ is high (large and positive), we interpret $x$ as unlikely under the model.

In physical systems, energy is a measure of how stable or probable a configuration is. Low-energy states represent stable configurations that the physical system often "prefers". By analogy, an EBM tries to learn an energy landscape that assigns low energies to regions of data space observed frequently in the training set, and high energies to regions that are rare or absent. The EBM places no inherent restriction on the functional form of $E_\theta(x)$; it simply learns to shape an energy surface so that the training data points end up in low-energy "valleys" while less plausible points are located on high-energy "peaks".

Mathematically, $E_\theta: \mathcal{X} \rightarrow \mathbb{R}$ can be realized by any neural network architecture that outputs a single scalar. For example, a feed-forward network with fully connected layers might produce a single logit. Alternatively, for image data, a convolutional network can output one scalar per image; for text, a transformer or RNN encoder might produce an embedding followed by a linear transform that yields a scalar.

Because $E_\theta(x)$ is so flexible, there is no fundamental structural constraint except that we need to be able to differentiate $E_\theta(x)$ with respect to $x$ (and of course with respect to $\theta$). This gradient with respect to $x$ is crucial in the typical sampling-based training process.
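
As a rough illustration of the input gradient mentioned above, the sketch below builds a tiny MLP energy and asks PyTorch's autograd for $\nabla_x E_\theta(x)$. The class `TinyEnergy`, the helper `energy_and_input_grad`, and all sizes are hypothetical and chosen only for this example.

```python
import torch
import torch.nn as nn

class TinyEnergy(nn.Module):
    """Hypothetical toy energy: maps a feature vector to one scalar per sample."""
    def __init__(self, dim=4, hidden=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Softplus(), nn.Linear(hidden, 1))

    def forward(self, x):
        return self.net(x).squeeze(-1)  # shape: (batch,)

def energy_and_input_grad(model, x):
    """Return E(x) and dE/dx for a batch x, without touching parameter gradients."""
    x = x.clone().detach().requires_grad_(True)
    energy = model(x)
    # summing over the batch gives each sample its own gradient row
    (grad_x,) = torch.autograd.grad(energy.sum(), x)
    return energy.detach(), grad_x

torch.manual_seed(0)
model = TinyEnergy()
x = torch.randn(5, 4)
energy, grad_x = energy_and_input_grad(model, x)
print(energy.shape, grad_x.shape)  # one energy per sample, one gradient entry per input dimension
```

Using `torch.autograd.grad` rather than `backward()` keeps the parameter `.grad` fields empty, which is convenient when the same model is later updated by an optimizer.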

2.2 Turning energy into probability (Boltzmann distribution)

While an energy function alone does not necessarily define a probability distribution, we can turn it into one by employing the Boltzmann (or Gibbs) distribution from statistical physics. The probability of a configuration $x$ under parameters $\theta$ is given by:

$$p_\theta(x) = \frac{\exp\left(-E_\theta(x)\right)}{Z(\theta)},$$

where:

  • $E_\theta(x)$ is the energy.
  • $Z(\theta)$ is the normalizing constant, also called the partition function:

$$Z(\theta) = \int \exp\left(-E_\theta(x)\right)\, dx$$

in the continuous case, or a sum $\sum_x \exp(-E_\theta(x))$ in the discrete case.

The presence of $Z(\theta)$ ensures that $p_\theta(x)$ integrates or sums to 1 over the entire domain $\mathcal{X}$. Although we have introduced the negative sign inside the exponential for convention (low energy → large exponent → high probability), the exact sign is not crucial if we carefully keep track of how we interpret the energy.
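
To make the normalization concrete, here is a minimal sketch for a discrete domain with four states. The energy values are made up for illustration, and `boltzmann_probs` is a hypothetical helper name.

```python
import numpy as np

def boltzmann_probs(energies):
    """Turn a vector of energies into probabilities p(x) = exp(-E(x)) / Z."""
    energies = np.asarray(energies, dtype=float)
    # subtracting the minimum energy is only for numerical stability; it cancels in the ratio
    unnormalized = np.exp(-(energies - energies.min()))
    return unnormalized / unnormalized.sum()

energies = np.array([0.5, 2.0, -1.0, 3.0])  # made-up energies for four states
probs = boltzmann_probs(energies)
shifted = boltzmann_probs(energies + 10.0)  # a constant offset is absorbed by Z

print(probs, probs.sum())
print(np.allclose(probs, shifted))
```

The state with the lowest energy receives the highest probability, and adding a constant to every energy leaves the distribution unchanged because the same factor appears in both the numerator and $Z(\theta)$.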

2.3 Partition function and normalization challenges

The partition function $Z(\theta)$ poses one of the principal challenges of energy-based modeling. In high-dimensional spaces, computing $Z(\theta)$ exactly is usually intractable because it entails integrating or summing over a vast range of $x$. For small discrete systems or low-dimensional cases, one might compute $Z(\theta)$ exactly, but in modern deep learning contexts (images, text, audio, etc.), it is effectively impossible.
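
A small sketch of why exact computation breaks down: for $n$ binary variables the sum has $2^n$ terms. The toy energy below has no interactions, an assumption made only so that a closed-form $Z$ exists to check the brute-force sum against.

```python
import itertools
import math

def brute_force_partition(energy_fn, n):
    """Sum exp(-E(x)) over all 2**n binary vectors; only feasible for small n."""
    return sum(math.exp(-energy_fn(x)) for x in itertools.product((0, 1), repeat=n))

biases = [0.3, -1.2, 0.8, 0.1]  # made-up parameters

def independent_energy(x):
    # toy energy E(x) = -sum_i b_i x_i, chosen because Z = prod_i (1 + exp(b_i))
    return -sum(b * xi for b, xi in zip(biases, x))

z_brute = brute_force_partition(independent_energy, len(biases))
z_closed = math.prod(1.0 + math.exp(b) for b in biases)
print(z_brute, z_closed)

# number of terms in the sum for a binarized 28x28 image
print(f"{2 ** (28 * 28):.3e}")
```

With only four variables the enumeration has 16 terms; for a binarized 28×28 image the same loop would need more terms than could ever be evaluated, and a general neural energy offers no factorization to fall back on.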

This complication means that we typically cannot directly maximize the likelihood $p_\theta(x)$ in closed form. Instead, we resort to approximate methods, often involving sampling procedures (Markov Chain Monte Carlo is common) that produce approximate samples from $p_\theta(x)$ and use them to estimate the gradient of the log-likelihood with respect to $\theta$. Contrastive Divergence (CD), introduced by Hinton, is one of the best-known sampling-based training methods for EBMs. Later sections will dive into the details.

When $Z(\theta)$ is unknown, we typically aim to shape the energy landscape $E_\theta(x)$ so that $x$ from our dataset land in low-energy regions while out-of-distribution $x$ land in high-energy regions. We do this by comparing energies of real data points against energies of points sampled from the model's own distribution. This means we do not need the exact value of $Z(\theta)$ but only a scheme that manipulates the relative energies of real versus fake points.

2.4 Canonical ensemble learning (CEL) and the statistical physics connection

The original perspective on EBMs is closely aligned with the notion of a "canonical ensemble" in statistical physics. In physics, a system in thermal equilibrium at temperature $T$ follows the Boltzmann distribution:

$$p(x) \propto \exp\left(-\frac{E(x)}{k_B T}\right),$$

where $E(x)$ is energy, $k_B$ is the Boltzmann constant, and $T$ is temperature. Drawing parallels to EBMs, we can see that $E_\theta(x)$ is a learned energy function, and the partition function normalizes the resulting exponentiated energy. In the machine learning context, we often set $k_B T = 1$ for simplicity, or treat the temperature as an additional hyperparameter that can modulate the sharpness or smoothness of the energy landscape.
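
The effect of temperature can be seen on a toy discrete system. This is a sketch with made-up energies; `temperature` stands for the product $k_B T$, and the entropy helper is included only to quantify how spread out the distribution is.

```python
import numpy as np

def boltzmann_probs(energies, temperature=1.0):
    """p(x) proportional to exp(-E(x) / T), with k_B absorbed into T."""
    scaled = -np.asarray(energies, dtype=float) / temperature
    scaled -= scaled.max()  # numerical stability only
    weights = np.exp(scaled)
    return weights / weights.sum()

def entropy(p):
    """Shannon entropy in nats; larger means closer to uniform."""
    return float(-(p * np.log(p)).sum())

energies = np.array([0.0, 1.0, 2.0, 4.0])  # made-up energies
temperatures = (0.25, 1.0, 4.0)
entropies = []
for t in temperatures:
    p = boltzmann_probs(energies, t)
    entropies.append(entropy(p))
    print(t, np.round(p, 3), round(entropies[-1], 3))
```

Low temperatures concentrate the mass on the lowest-energy state, while high temperatures flatten the distribution toward uniform.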

This analogy is deeper than it appears: many methods for sampling from EBMs (like Langevin dynamics) are effectively simulating the diffusion or random walk of a physical system in its energy landscape. By carefully introducing noise (analogous to thermal fluctuations) and performing gradient-based steps (analogous to the system relaxing toward low-energy states), we can sample from $p_\theta(x)$.


3. Core principles in practice

3.1 Energy vs. probability: a conceptual comparison

It may be useful to compare an EBM's approach to that of purely probabilistic models (e.g., normalizing flows or variational autoencoders). In a normalizing flow, for instance, one tries to construct a bijection that maps a base distribution (like a Gaussian) to the target distribution. The main challenge is to keep track of the log Jacobian determinant so that we can properly normalize. For EBMs, we do not bother with designing invertible transformations; we only define $E_\theta(x)$. The probability is implicitly derived via exponentiation and normalization, which, while powerful, leads to the computational difficulty of not knowing $Z(\theta)$.

A purely probabilistic model also tries to approximate $p(x)$ by specifying it in a parametric form from the get-go (for instance, a factorized distribution or a certain latent-variable model). In contrast, an EBM can represent highly multimodal distributions or distributions that do not factorize easily, because $E_\theta(x)$ can be any neural network. The cost of that generality is the need for specialized sampling and training protocols.

3.2 The role of latent variables in EBMs

Another dimension of design is whether the EBM has latent variables $h$ in addition to visible variables $x$. Boltzmann machines, for example, can contain hidden units that shape the energy of the visible units in ways that capture underlying structure. In a latent-variable EBM, we might define:

$$E_\theta(x) = \min_{h \in \mathcal{H}} E_\theta(x, h),$$

or we might define:

$$p_\theta(x) \propto \int \exp\bigl(-E_\theta(x, h)\bigr)\, dh.$$

Marginalizing out the hidden variable $h$ can again be complicated in high-dimensional scenarios. Restricted Boltzmann machines (RBMs) avoid some complexities by adopting a bipartite structure that permits partial factorization. However, for deep EBMs with complex latent spaces, we again rely on approximate sampling or gradient-based methods to handle the latent dimension.

3.3 Free energy and marginalizing over latent variables

In EBM literature, one often sees references to free energy. The free energy functional can appear when we integrate out latent variables. If $h$ denotes hidden states, then:

$$F_\theta(x) = -\log \int \exp\bigl(-E_\theta(x, h)\bigr)\, dh.$$

This $F_\theta(x)$ is known as the free energy of $x$. Minimizing free energy is effectively encouraging the existence of some hidden representation $h$ that yields low total energy. Indeed, for models like RBMs, training procedures often revolve around approximating gradients of the free energy. In advanced latent EBMs, free-energy-based training can also be used, but sampling or optimization in $h$ may again be nontrivial.
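
For binary hidden units the integral becomes a sum, and with the bilinear energy $E_\theta(x, h) = -x^T W h - b^T x - c^T h$ (the Boltzmann-machine form discussed in section 6.1) that sum factorizes. The sketch below compares a brute-force enumeration with the factorized expression; the sizes and random parameters are assumptions for illustration only.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
n_visible, n_hidden = 4, 3
W = rng.normal(scale=0.5, size=(n_visible, n_hidden))  # made-up parameters
b = rng.normal(scale=0.5, size=n_visible)
c = rng.normal(scale=0.5, size=n_hidden)

def joint_energy(x, h):
    """Bilinear energy E(x, h) = -x^T W h - b^T x - c^T h."""
    return -(x @ W @ h) - b @ x - c @ h

def free_energy_brute_force(x):
    """F(x) = -log sum_h exp(-E(x, h)) over all binary h; also returns min_h E(x, h)."""
    energies = np.array([joint_energy(x, np.array(h, dtype=float))
                         for h in itertools.product((0, 1), repeat=n_hidden)])
    m = energies.min()
    # -log sum exp(-E) = m - log sum exp(-(E - m)), which avoids overflow
    return m - np.log(np.exp(-(energies - m)).sum()), m

def free_energy_closed_form(x):
    """Factorized form: F(x) = -b^T x - sum_j log(1 + exp(c_j + (x^T W)_j))."""
    return -(b @ x) - np.logaddexp(0.0, c + x @ W).sum()

x = np.array([1.0, 0.0, 1.0, 1.0])
f_brute, min_energy = free_energy_brute_force(x)
f_closed = free_energy_closed_form(x)
print(f_brute, f_closed, min_energy)
```

The free energy never exceeds $\min_h E_\theta(x, h)$, since the sum over $h$ contains the lowest-energy term plus other non-negative contributions; this is the link between the two latent-variable definitions in section 3.2.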

3.4 Advantages over purely probabilistic approaches

EBMs present several advantages:

  • Flexibility: We can parametrize $E_\theta(x)$ with almost any neural network, unconstrained by invertibility or closed-form integrals.
  • Unified view: Both discriminative tasks (classification) and generative tasks (sampling) can be formulated under a single framework, often by combining the energy with additional terms or introducing class labels as part of the energy function.
  • Multimodality: Because the energy surface can be arbitrarily shaped, EBMs readily capture complicated, multimodal distributions.
  • Anomaly detection: As we will see, EBMs naturally produce a scalar that can serve as an anomaly score. Data that is truly out of distribution tends to yield higher energies.

However, these advantages come with the significant computational burden of not having direct access to $Z(\theta)$, leading to the necessity of iterative sampling.


4. Training

4.1 Maximum likelihood and why it is difficult with EBMs

When training a generative model, we typically maximize the log-likelihood of the observed data. For an EBM, the log-likelihood of a single data point $x$ is:

$$\log p_\theta(x) = -E_\theta(x) - \log Z(\theta).$$

Taking the gradient with respect to $\theta$ gives:

$$\nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) - \nabla_\theta \log Z(\theta).$$

The gradient of the log partition function $\log Z(\theta)$ is:

$$\nabla_\theta \log Z(\theta) = \frac{1}{Z(\theta)} \nabla_\theta Z(\theta) = \frac{1}{Z(\theta)} \nabla_\theta \int \exp\bigl(-E_\theta(x')\bigr)\, dx' = -\int \frac{\exp\bigl(-E_\theta(x')\bigr)}{Z(\theta)}\, \nabla_\theta E_\theta(x')\, dx'.$$

Since $\exp(-E_\theta(x'))/Z(\theta) = p_\theta(x')$, we can rewrite it as an expectation under $p_\theta(x')$:

$$\nabla_\theta \log Z(\theta) = -\mathbb{E}_{x' \sim p_\theta(x')}\bigl[\nabla_\theta E_\theta(x')\bigr].$$

Thus,

$$\nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) + \mathbb{E}_{x' \sim p_\theta(x')}\bigl[\nabla_\theta E_\theta(x')\bigr].$$

Implementing this gradient in a naive way would require sampling from $p_\theta$, which we cannot do directly; in practice we have to run an expensive Markov chain that relies on computing gradients of $E_\theta(x)$ with respect to $x$. This is why maximum-likelihood training of EBMs is typically done through approximate methods such as Monte Carlo sampling.
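
The gradient identity can be checked numerically on a domain small enough that $Z(\theta)$ is exact. In this sketch the five-state domain and the quadratic toy energy are assumptions made purely for the check; the analytic gradient is compared against central finite differences of $\log p_\theta(x)$.

```python
import numpy as np

states = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # tiny discrete domain so Z is exact

def energy(theta, x):
    """Toy energy E_theta(x) = theta[0] * x**2 + theta[1] * x."""
    return theta[0] * x**2 + theta[1] * x

def grad_energy(theta, x):
    """Gradient of the toy energy with respect to theta, stacked on the last axis."""
    return np.stack([x**2, x], axis=-1)

def log_prob(theta, x):
    log_z = np.log(np.exp(-energy(theta, states)).sum())
    return -energy(theta, x) - log_z

def log_prob_grad(theta, x):
    """-grad E(x) + E_{x' ~ p_theta}[grad E(x')], with the expectation computed exactly."""
    p = np.exp(log_prob(theta, states))
    expected_grad = (p[:, None] * grad_energy(theta, states)).sum(axis=0)
    return -grad_energy(theta, x) + expected_grad

theta = np.array([0.7, -0.3])
x_obs = 1.0
analytic = log_prob_grad(theta, x_obs)

eps = 1e-6
numeric = np.array([
    (log_prob(theta + eps * np.eye(2)[i], x_obs) - log_prob(theta - eps * np.eye(2)[i], x_obs)) / (2 * eps)
    for i in range(2)
])
print(analytic, numeric)
```

In a real EBM the exact expectation over all states is replaced by an average over MCMC samples, which is where the approximation error enters.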

4.2 Contrastive divergence (CD): the fundamental training objective

Contrastive Divergence (CD), introduced by Hinton, is arguably the most common method for training EBMs (at least historically, especially in the context of restricted Boltzmann machines). The idea is to approximate the gradient of the log-likelihood by short-run MCMC chains initialized at training data or from some buffer of previously generated samples. Specifically, the CD-k algorithm starts from observed data, runs k steps of Gibbs sampling (or another MCMC method) to produce a negative sample, and then computes:

$$\nabla_\theta \text{CD}_k \approx -\nabla_\theta E_\theta(\text{data}) + \nabla_\theta E_\theta(\text{neg sample}),$$

where the negative sample is drawn from the short-run Markov chain. The hope is that with enough training iterations, the chain states approximate $p_\theta(x)$ sufficiently for learning to succeed.

4.3 Detailed derivation of contrastive divergence

A more formal derivation frames the maximum likelihood gradient as:

$$-\nabla_\theta E_\theta(x) + \int p_\theta(x')\, \nabla_\theta E_\theta(x')\, dx'.$$

Since we cannot compute the integral exactly, we approximate $p_\theta(x')$ by a distribution $q(x')$ that we sample from in a short-run chain. The simplest approach is to start from the data point $x$ itself and take only a few MCMC steps. This yields a sample $x^{(k)}$ that is hopefully close to the modes of the distribution but is cheaper to compute than running a long chain. We then replace $\int p_\theta(x') \ldots dx'$ with the single-sample approximation from $x^{(k)}$. The difference $E_\theta(x^{(k)}) - E_\theta(x)$ is the core of the training signal: increasing it pushes down the energy of real data while pushing up the energy of samples from the model.

4.4 Intuition behind "pulling down" and "pushing up" energies

One of the most intuitive explanations of the CD objective is that we have two forces:

  • Pull down the energy of real data so these points reside in low-energy basins (i.e., are assigned high probability).
  • Push up the energy of negatively sampled points so that the model does not erroneously assign them too high probability.

This combination of attractive and repulsive forces sculpts the energy landscape in a way that matches the data distribution. Imagine each real data point creates a "well" around itself, whereas each negative sample pushes up the energy in some region, flattening it out or creating a barrier that prevents the distribution from spreading too widely in unobserved areas.

4.5 Stochastic gradient Langevin dynamics (SGLD)

One popular variant for sampling negative examples during training is Stochastic Gradient Langevin Dynamics (SGLD). Langevin dynamics is an MCMC approach that treats updates in $x$-space as a gradient descent on $E_\theta(x)$ plus a noise term that ensures exploration of other modes. Specifically, the update for $x$ can look like:

$$x_{t+1} = x_t - \alpha \nabla_x E_\theta(x_t) + \sqrt{2\alpha}\,\eta_t,$$

where $\eta_t \sim \mathcal{N}(0, I)$ is standard Gaussian noise and $\alpha$ is a small step size. In the limit of infinitely small steps and infinite sampling time, this procedure samples exactly from the Boltzmann distribution $p_\theta(x)$. In practice, we only take a finite number of steps (like 10 to 100), so the result is an approximate sample.

SGLD is often favored for neural-network-based EBMs because of its relative simplicity to implement and its strong theoretical grounding in the connection to continuous-time diffusion processes.
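
A minimal sketch of the Langevin update, run on a toy energy whose Boltzmann distribution is known. The quadratic energy $E(x) = x^2/2$, the step size, and the number of chains are assumptions chosen so the result can be compared against a standard normal.

```python
import numpy as np

def langevin_chain(grad_energy, x0, step_size, n_steps, rng):
    """Unadjusted Langevin: x <- x - alpha * grad E(x) + sqrt(2 * alpha) * eta, eta ~ N(0, I)."""
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x - step_size * grad_energy(x) + np.sqrt(2.0 * step_size) * noise
    return x

def grad_energy(x):
    # toy target: E(x) = x**2 / 2, so grad E(x) = x and exp(-E) is a standard normal
    return x

rng = np.random.default_rng(0)
x0 = np.full(5000, 3.0)  # 5000 independent chains, all started far from the mode
samples = langevin_chain(grad_energy, x0, step_size=0.01, n_steps=1000, rng=rng)
print(samples.mean(), samples.var())
```

The sample mean and variance should land near 0 and 1. A finite step size leaves a small bias in the stationary distribution, which is one reason short-run chains in EBM training only yield approximate samples.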

4.6 Other MCMC approaches (Metropolis–Hastings, Gibbs sampling)

Besides Langevin dynamics, classical MCMC methods like Metropolis–Hastings and Gibbs sampling are also widely used, especially in restricted Boltzmann machines. However, for continuous high-dimensional data like images, Metropolis–Hastings can be cumbersome to tune, and conditional distributions for Gibbs sampling might be intractable. When feasible, partial or block Gibbs updates can help, but in many modern EBM setups, short-run Langevin dynamics or variants of gradient-based MCMC are more common.
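
For comparison, here is a sketch of random-walk Metropolis–Hastings, which needs only energy differences and no gradients. The 1D quadratic energy and the proposal width are assumptions for illustration; in high dimensions the acceptance rate of such a blind proposal typically drops sharply, which is the tuning difficulty mentioned above.

```python
import numpy as np

def metropolis_hastings(energy, x0, proposal_std, n_steps, rng):
    """Random-walk MH for p(x) proportional to exp(-E(x)), run on a batch of independent chains."""
    x = np.array(x0, dtype=float)
    accepted = 0.0
    for _ in range(n_steps):
        proposal = x + proposal_std * rng.standard_normal(x.shape)
        # symmetric proposal, so the log acceptance ratio is E(x) - E(proposal)
        log_accept = energy(x) - energy(proposal)
        accept = np.log(rng.uniform(size=x.shape)) < log_accept
        x = np.where(accept, proposal, x)
        accepted += accept.mean()
    return x, accepted / n_steps

def energy(x):
    return 0.5 * x**2  # toy 1D energy; the target is a standard normal

rng = np.random.default_rng(0)
samples, acceptance_rate = metropolis_hastings(
    energy, np.zeros(4000), proposal_std=1.0, n_steps=300, rng=rng
)
print(samples.mean(), samples.var(), acceptance_rate)
```

Note that $Z(\theta)$ never appears: the acceptance test uses only the ratio of unnormalized probabilities.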


5. Implementation and tricks

5.1 Basic neural network parameterizations of energy functions

Since $E_\theta(x)$ is unconstrained, we can choose a straightforward architecture for images, such as a multi-layer convolutional network that reduces an input image to a single scalar:


import torch
import torch.nn as nn

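class EnergyModel(nn.Module):
    """CNN energy function. The 64*8*8 flatten size assumes 3x32x32 inputs (e.g., CIFAR-10)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # 3x32x32 -> 32x16x16
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 32x16x16 -> 64x8x8
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*8*8, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # squeeze only the trailing dimension so a batch of size 1 stays a 1-D tensor
        return self.features(x).squeeze(-1)  # one scalar energy per sample, shape (batch,)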

You can see that the output is a single scalar for each input $x$. We interpret this as $E_\theta(x)$. Of course, this can be extended with deeper residual blocks, skip connections, or attention mechanisms, but the principle remains: an arbitrary function that yields a scalar.

5.2 Sampling buffer or replay buffer for stabilized training

A widely used trick in modern EBM training is maintaining a "sampling buffer" or "replay buffer" from which we initialize new Markov chains. Rather than starting each chain from random noise, which can require many MCMC steps to yield a decent sample, we re-initialize a portion of the negative samples from a buffer that holds previously generated samples. Because these samples are often somewhat close to the modes, fewer MCMC steps are needed to refine them. This approach speeds up training substantially and can mitigate instability.
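
One way such a buffer might look is sketched below. The class name `ReplayBuffer`, its methods, the capacity, and the 5% share of fresh noise samples are all assumptions for illustration rather than details taken from a specific code base.

```python
import numpy as np

class ReplayBuffer:
    """Hypothetical fixed-size buffer of past negative samples with values in [-1, 1]."""

    def __init__(self, sample_shape, capacity=1000, rng=None):
        self.rng = rng if rng is not None else np.random.default_rng()
        self.capacity = capacity
        self.sample_shape = tuple(sample_shape)
        self.samples = self._noise(capacity)  # the buffer starts out as pure noise

    def _noise(self, n):
        return self.rng.uniform(-1.0, 1.0, size=(n, *self.sample_shape))

    def draw(self, batch_size, fresh_fraction=0.05):
        """Return chain initializations: mostly buffered samples, a few fresh noise samples."""
        idx = self.rng.integers(0, self.capacity, size=batch_size)
        init = self.samples[idx].copy()
        fresh = self.rng.uniform(size=batch_size) < fresh_fraction
        init[fresh] = self._noise(int(fresh.sum()))
        return init, idx

    def update(self, idx, refined):
        """Write the MCMC-refined samples back into the slots they were drawn from."""
        self.samples[idx] = refined

buffer = ReplayBuffer(sample_shape=(1, 28, 28), capacity=256, rng=np.random.default_rng(0))
init, idx = buffer.draw(batch_size=32)
refined = np.clip(init + 0.01, -1.0, 1.0)  # stand-in for a few Langevin steps
buffer.update(idx, refined)
print(init.shape, buffer.samples.shape)
```

Mixing in a small share of fresh noise keeps some chains exploring from scratch, so the buffer does not lock onto a few modes.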

5.3 Step sizes, noise schedules, and gradient clipping

Tuning the hyperparameters of the MCMC procedure is critical:

  • Step size $\alpha$: If the step size is too large, chains may diverge or skip over important modes. If it is too small, we need too many steps for mixing or mode exploration.
  • Noise variance: The random noise we add in Langevin or Metropolis updates is crucial for exploration. It can be constant or gradually decreased. Some advanced schedules gradually reduce noise as we move closer to equilibrium.
  • Gradient clipping: Because neural networks can produce large gradients, it is common to clamp gradients within a certain range (e.g., $[-0.03, 0.03]$). This helps avoid explosive updates that push the sample out of the data manifold.

5.4 Regularization terms and bias constraints

Without additional constraints, the output of $E_\theta(x)$ can drift arbitrarily. For instance, the model might systematically push the energies of all points downward by a constant offset. Such a shift leaves $p_\theta(x)$ unchanged, because $Z(\theta)$ absorbs it, but unbounded energy magnitudes tend to destabilize the gradients used for sampling and training. One common remedy is to add a penalty on the magnitude of the network output, for instance $\alpha \|E_\theta(x)\|^2$ (here $\alpha$ is a small regularization weight, not the Langevin step size), so that energies remain near zero for real samples. Some authors also apply weight decay or other stabilizers that limit the scale of the parameters themselves.
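
The interplay between the contrastive term and the magnitude penalty can be sketched on plain arrays of energies. The function name, the default weight of 0.1, and the choice to penalize both real and sampled energies are assumptions (a common variant), and the energy values are made up.

```python
import numpy as np

def contrastive_loss(e_real, e_fake, reg_weight=0.1):
    """Loss to minimize: mean E(real) - mean E(fake) + reg_weight * mean squared energies."""
    e_real = np.asarray(e_real, dtype=float)
    e_fake = np.asarray(e_fake, dtype=float)
    contrastive = e_real.mean() - e_fake.mean()
    regularizer = reg_weight * ((e_real**2).mean() + (e_fake**2).mean())
    return contrastive + regularizer, contrastive, regularizer

e_real = np.array([-1.0, 1.0])  # made-up energies of real data
e_fake = np.array([2.0, 2.0])   # made-up energies of sampled data

total, contrastive, regularizer = contrastive_loss(e_real, e_fake)
# shifting every energy by the same constant leaves the contrastive term untouched
total_s, contrastive_s, regularizer_s = contrastive_loss(e_real + 5.0, e_fake + 5.0)
print(total, contrastive, regularizer)
print(total_s, contrastive_s, regularizer_s)
```

The contrastive term alone cannot see a global offset, so the squared-energy penalty is what anchors the overall scale.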

5.5 Monitoring training progress: loss terms and sample quality

Because typical losses (like CD) can sometimes appear to converge even while the model distribution is not well-formed, it is crucial to track more than just the final loss. Researchers often monitor:

  • The average energy assigned to training data.
  • The average energy assigned to generated samples.
  • Metrics on sample quality, such as how visually coherent or diverse they are.
  • Additional metrics like the FID (Fréchet Inception Distance) or Inception Score for image models, or perplexity for text.

5.6 Code snippets

Below is a partial snippet illustrating how one might integrate the MCMC sampling step into a training loop using PyTorch. This snippet is adapted from a more complete code base:


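import torch

def langevin_sampling(model, x_init, steps=60, step_size=10.0, noise_std=0.005, grad_clip=0.03):
    """Refine x_init with short-run, Langevin-style updates and return detached samples.

    Note: the large step size and small noise are a practical heuristic that relies on
    gradient clipping; they do not follow the sqrt(2 * alpha) scaling of section 4.5.
    """
    x = x_init.clone().detach()
    for _ in range(steps):
        # add random noise, then clamp to the valid input range
        x = x + torch.randn_like(x) * noise_std
        x = x.clamp(-1.0, 1.0).requires_grad_(True)  # fresh leaf tensor for this step

        # gradient of the energy w.r.t. the input only; parameter .grad fields stay untouched
        energy = model(x)
        (grad,) = torch.autograd.grad(energy.sum(), x)
        # clip the gradient as described in section 5.3 (default range is an assumption)
        grad = grad.clamp(-grad_clip, grad_clip)

        # we want to descend the gradient of E, so we do x <- x - step_size * grad(E)
        x = (x - step_size * grad).detach()

    return x.clamp(-1.0, 1.0)

# In a typical training step:
# 1. get real data
# 2. sample from model using replay buffer or random initialization
# 3. compute E(real) and E(sampled)
# 4. compute contrastive divergence loss
# 5. backprop and update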
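
The five steps listed in the comments above can be sketched as a single function. Everything here is an assumption for illustration: the compact sampler `short_run_langevin`, the tiny fully connected energy network, the random stand-in data, and the regularization weight. It is not the complete code base the snippet was adapted from.

```python
import math
import torch
import torch.nn as nn

def short_run_langevin(model, x, steps=20, step_size=10.0, noise_std=0.005, grad_clip=0.03):
    """Compact stand-in sampler: noisy, clipped gradient descent on E with respect to x."""
    x = x.clone().detach()
    for _ in range(steps):
        x = (x + noise_std * torch.randn_like(x)).clamp(-1.0, 1.0).requires_grad_(True)
        (grad,) = torch.autograd.grad(model(x).sum(), x)
        x = (x - step_size * grad.clamp(-grad_clip, grad_clip)).detach().clamp(-1.0, 1.0)
    return x

def training_step(model, optimizer, x_real, reg_weight=0.1):
    # step 2: negatives from random initialization (a replay buffer could be used instead)
    x_init = torch.rand_like(x_real) * 2.0 - 1.0
    x_fake = short_run_langevin(model, x_init)
    # step 3: energies of real and sampled points
    e_real, e_fake = model(x_real), model(x_fake)
    # step 4: lower E(real), raise E(fake), and keep magnitudes bounded
    cd_loss = e_real.mean() - e_fake.mean()
    reg_loss = reg_weight * (e_real.pow(2).mean() + e_fake.pow(2).mean())
    loss = cd_loss + reg_loss
    # step 5: backprop and update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), e_real.mean().item(), e_fake.mean().item()

torch.manual_seed(0)
# hypothetical tiny energy network for 1x8x8 inputs; the final Flatten(0) yields shape (batch,)
model = nn.Sequential(nn.Flatten(), nn.Linear(64, 32), nn.SiLU(), nn.Linear(32, 1), nn.Flatten(0))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
x_real = torch.rand(16, 1, 8, 8) * 2.0 - 1.0  # step 1: stand-in for a real data batch
loss, mean_e_real, mean_e_fake = training_step(model, optimizer, x_real)
print(loss, mean_e_real, mean_e_fake)
```

Returning the two mean energies alongside the loss makes it easy to log the quantities suggested in section 5.5.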

6. Canonical examples and architectures

6.1 Boltzmann machines and restricted Boltzmann machines

Historically, Boltzmann machines (BMs) played a key role in popularizing EBMs in the neural network community. A BM is an undirected graphical model with visible units $x$ and possibly hidden units $h$. The energy function is often bilinear in $x$ and $h$:

$$E_\theta(x, h) = -x^T W h - b^T x - c^T h,$$

with $\theta = \{W, b, c\}$. A general BM additionally allows couplings among the visible units and among the hidden units; without restrictions on the connectivity, sampling can be extremely difficult because $x$ and $h$ are heavily coupled. A Restricted Boltzmann Machine (RBM) uses bipartite connections between the visible and hidden layers, but no connections among visible units or among hidden units themselves, which gives exactly the bilinear form above. This structure permits a faster block Gibbs sampling approach and simpler updates for contrastive divergence.
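
For binary units, the bipartite structure makes both conditionals factorize into sigmoids, which is what enables block Gibbs sampling. The sketch below implements those conditionals and one CD-1 parameter update; the sizes, learning rate, and random binary data are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def p_h_given_x(x, W, c):
    """p(h_j = 1 | x) = sigmoid(c_j + (x^T W)_j) for E(x, h) = -x^T W h - b^T x - c^T h."""
    return sigmoid(c + x @ W)

def p_x_given_h(h, W, b):
    """p(x_i = 1 | h) = sigmoid(b_i + (W h)_i)."""
    return sigmoid(b + h @ W.T)

def cd1_update(x0, W, b, c, rng, lr=0.1):
    """One CD-1 step on a batch of binary visibles x0 with shape (batch, n_visible)."""
    ph0 = p_h_given_x(x0, W, c)
    h0 = (rng.uniform(size=ph0.shape) < ph0).astype(float)
    # one block Gibbs step gives the negative sample x1
    x1 = (rng.uniform(size=x0.shape) < p_x_given_h(h0, W, b)).astype(float)
    ph1 = p_h_given_x(x1, W, c)
    n = x0.shape[0]
    # data statistics minus negative-sample statistics
    W = W + lr * (x0.T @ ph0 - x1.T @ ph1) / n
    b = b + lr * (x0 - x1).mean(axis=0)
    c = c + lr * (ph0 - ph1).mean(axis=0)
    return W, b, c

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(6, 3))
b = np.zeros(6)
c = np.zeros(3)
data = (rng.uniform(size=(32, 6)) < 0.5).astype(float)  # made-up binary data
W_new, b_new, c_new = cd1_update(data, W, b, c, rng)
print(W_new.shape, b_new.shape, c_new.shape)
```

Using the hidden probabilities rather than sampled hidden states in the update is a common variance-reduction choice, not a requirement of the method.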

RBMs found numerous applications (like pretraining of deep networks in the late 2000s) and remain a classic demonstration of energy-based modeling. However, in modern practice, deep convolutional or residual-based EBMs are typically used when dealing with high-resolution images or other complex data sources.

6.2 Deep energy-based models using convolutional networks

Recent years have seen a surge of interest in deep EBMs that employ convolutional networks as the backbone for $E_\theta(x)$. This is useful for images or volumetric data (3D). The idea is straightforward: pass $x$ through multiple convolutional layers, possibly with residual or skip connections, to produce a single scalar. The architecture can be reminiscent of discriminators from Generative Adversarial Networks (GANs), except that we interpret the scalar as an energy instead of a logit for real vs. fake classification.

6.3 Residual architectures and hybrid EBMs

Some advanced EBMs adopt ResNet-like blocks for the energy function to ease gradient flow and stabilize training. Because we rely heavily on gradient-based sampling, the smoothness of the neural network is crucial. Activation choices can matter; for instance, some researchers find that smoother nonlinearities (like softplus or Swish) can lead to more stable MCMC sampling.

A hybrid EBM might combine an energy-based objective with additional discriminative or generative losses. For example, one could design a multi-task network that outputs both a classification score for $x$ and an energy. The energy can then be used for anomaly detection, while the classification head is used for supervised learning. This synergy sometimes helps calibrate energies in a more stable way.

6.4 Overview of joint energy-based models (JEM)

Joint Energy-based Models (JEM), introduced by Grathwohl et al. (ICLR 2020), unify classification and energy-based modeling by letting $E_\theta(x, y)$ represent the joint energy of data $x$ and label $y$. Classification amounts to finding $\hat{y} = \arg\min_y E_\theta(x, y)$, while generative modeling arises from sampling $x$ from the distribution proportional to $\exp(-E_\theta(x, y))$. This approach showed that a single model could remain competitive with strong discriminative classifiers while also acting as a generative model. However, training stability remains a concern, requiring careful MCMC parameter tuning and additional regularization terms.
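
A sketch of how classifier logits can be read as energies, assuming the convention $E_\theta(x, y) = -f_\theta(x)[y]$ where $f_\theta(x)$ are the logits, so that summing out $y$ gives $E_\theta(x) = -\log \sum_y \exp(f_\theta(x)[y])$. The logits below are made up, and `jem_quantities` is a hypothetical helper.

```python
import numpy as np

def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def jem_quantities(logits):
    """Return joint energies, marginal energies, p(y | x), and the argmin-energy labels."""
    joint_energy = -logits                          # E(x, y) = -f(x)[y]
    marginal_energy = -logsumexp(logits, axis=-1)   # E(x) = -logsumexp_y f(x)[y]
    # p(y | x) = exp(-E(x, y)) / sum_y' exp(-E(x, y')), i.e. the usual softmax
    class_probs = np.exp(logits - logsumexp(logits, axis=-1)[..., None])
    predicted = joint_energy.argmin(axis=-1)
    return joint_energy, marginal_energy, class_probs, predicted

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # made-up logits for two inputs
joint_energy, marginal_energy, class_probs, predicted = jem_quantities(logits)
print(predicted, class_probs.sum(axis=-1), marginal_energy)
```

The class posterior is the ordinary softmax, so a standard classifier already defines these energies; the extra modeling work lies in training the marginal energy over $x$ with MCMC.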


7. Generative modeling applications

7.1 Image generation

One of the most common demonstrations of EBMs is image synthesis. By training an EBM on a dataset of images (MNIST, CIFAR-10, etc.), we can eventually sample new images by running MCMC in pixel space. The model attempts to shift random noise toward the modes in the distribution. For instance, if we train on handwritten digits, random initial images gradually morph into something resembling digits (0 through 9).

Unlike certain other generative methods (e.g., VAEs or GANs), EBMs do not necessarily require a separate "generator" network. Instead, we rely on iterative sampling to produce new samples. This iterative approach is computationally heavier, but it can potentially model data more flexibly.

7.2 Training setup and hyperparameter considerations

To obtain high-quality samples, you typically need:

  1. A properly sized model with enough capacity to represent the complexity of the data distribution.
  2. A carefully tuned MCMC procedure (step size, number of steps, noise level).
  3. Possibly a replay buffer to accelerate mixing.
  4. Additional regularization (e.g., penalizing large energy values).
  5. Sufficient training time (and perhaps reinitializing from stable checkpoints if divergence occurs).

For example, with MNIST (28×28 grayscale images), a small convolutional EBM can suffice. But for larger images (like ImageNet, 128×128 or higher resolution), you may need a deep ResNet or even more advanced architectures, plus careful multi-GPU training protocols.

7.3 Evaluating generation quality (FID, Inception Score)

As with any generative model, we often rely on statistics like the Fréchet Inception Distance (FID) or Inception Score to measure the realism and diversity of generated images. While EBMs can achieve competitive results with enough tuning, they are sometimes outperformed in practice by specialized methods like GANs or normalizing flows on certain benchmarks. Nonetheless, some studies (Du and Mordatch, 2019) have shown that with improved training, deep EBMs can match or surpass strong GAN baselines on image generation tasks.

7.4 Video and 3D data generation

In principle, EBMs can handle input of any dimensionality, including video frames or 3D voxel grids/point clouds, by letting $E_\theta(x)$ be defined on these structured inputs. The main difference is that sampling becomes even more computationally intensive due to the higher dimensionality. Researchers have proposed specialized architectures for 3D data or for spatiotemporal correlation in video. Many of these remain active areas of investigation, where advanced EBMs, potentially combined with latent representations, may unlock more tractable training.

7.5 Comparison with GANs, VAEs, and normalizing flows

  • GANs: Provide fast sampling at test time (just a forward pass of the generator) but can suffer from mode collapse. EBMs do iterative sampling (slower) but can represent multiple modes more naturally. Training EBMs can be more stable in some respects (no discriminator–generator two-player game) but can diverge in others if MCMC hyperparameters are not set carefully.
  • VAEs: Provide a lower-bound optimization approach, often yield smooth latent spaces, and have straightforward sampling. However, the approximate posterior might be restrictive or cause issues like posterior collapse. EBMs skip the explicit latent posterior in the simplest formulation.
  • Normalizing flows: Guarantee exact log-likelihood evaluation with direct sampling. But they require invertible transformations that can limit architecture design or memory usage. EBMs are unconstrained in that sense but pay the price with approximate partition functions.

7.6 mode collapse vs. multi-modal distributions

GANs often face the infamous problem of mode collapse, where the generator only learns a subset of all possible modes. While EBMs can also fail to discover all modes if the MCMC does not explore thoroughly, the approach is, in principle, less prone to strict collapse. Because we are pushing energies up or down in the entire space, multiple modes can remain. Nonetheless, in practice, short-run MCMC or poor hyperparameters can inadvertently fail to capture modes, so in that sense, partial collapses or slow mixing can occur.


8. other applications of energy-based models

8.1 out-of-distribution detection (anomaly detection)

One of the most promising applications for EBMs is out-of-distribution (OOD) detection. Because a well-trained $E_\theta(x)$ is lower for typical in-distribution samples and higher for anomalies, a simple threshold on energy can serve as an OOD detector. In contrast, purely discriminative methods might produce high confidence for OOD inputs.
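The thresholding idea can be sketched in a few lines. This is a minimal illustration, assuming the energies have already been computed by some trained model. The function names, the 95% quantile, and the toy energy values are all assumptions made for the example.

```python
import numpy as np

def fit_energy_threshold(in_dist_energies, quantile=0.95):
    """Energy below which `quantile` of held-out in-distribution samples fall."""
    return float(np.quantile(np.asarray(in_dist_energies, dtype=float), quantile))

def is_ood(energies, threshold):
    """Flag inputs whose energy exceeds the threshold as out-of-distribution."""
    return np.asarray(energies, dtype=float) > threshold

# Toy stand-in for energies of held-out in-distribution data (assumed values).
rng = np.random.default_rng(0)
val_energies = rng.normal(loc=-5.0, scale=1.0, size=1000)
threshold = fit_energy_threshold(val_energies, quantile=0.95)

# Two typical-looking energies and one clearly high energy.
flags = is_ood([-6.0, -4.5, 3.0], threshold)
```

Choosing the threshold from a quantile of held-out energies fixes the false-positive rate on in-distribution data (here about 5%), which is one common way to calibrate such a detector.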

8.2 classification and object recognition reinterpreted as ebms

As mentioned, a classifier can be recast as an EBM on $(x, y)$. The probability of label $y$ given $x$ might be:

$$
p_\theta(y|x) = \frac{\exp\left(-E_\theta(x, y)\right)}{\sum_{y'} \exp\left(-E_\theta(x, y')\right)}.
$$

By training $E_\theta(x, y)$ with maximum likelihood or with an approximate scheme, we effectively do supervised learning. The advantage is that we can then also try to sample from $p_\theta(x, y)$ to generate synthetic $(x, y)$ pairs or do novelty detection for label–input mismatches.
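As a small numerical illustration of the formula above, the label posterior is just a softmax over negated energies. The sketch below assumes the per-label energies for a single input are given as a plain array, and `label_posterior` is a hypothetical helper name.

```python
import numpy as np

def label_posterior(energies):
    """p(y|x) from per-label energies E(x, y): softmax of the negated energies."""
    logits = -np.asarray(energies, dtype=float)
    logits = logits - logits.max()      # subtract the max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()

# Assumed energies for three labels: the second label has the lowest energy.
probs = label_posterior([2.0, 0.5, 3.0])

# Adding a constant to every energy leaves the posterior unchanged.
probs_shifted = label_posterior([12.0, 10.5, 13.0])
```

The shift invariance shows why only relative energies across labels matter for classification, while the absolute level carries the information about how plausible $x$ itself is.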

8.3 denoising and image reconstruction tasks

In tasks like image denoising or inpainting, we want to find $x$ close to a noisy observation $x_{\text{noisy}}$ but also lying in the manifold of possible clean images. An EBM can solve:

$$
\min_x \, E_\theta(x) + \lambda \, d\bigl(x, x_{\text{noisy}}\bigr),
$$

where $d$ is some distance metric (e.g., $L_2$). In practice, a gradient-based routine can do an iterative refinement, pushing $x$ to low energy while remaining close to the observed data.
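The iterative refinement can be sketched with plain gradient descent. This is a toy illustration under stated assumptions: the energy is a hand-written quadratic with an analytic gradient (a real EBM would use autograd on a network), $d$ is the squared $L_2$ distance, and `denoise` and `toy_energy_grad` are hypothetical names.

```python
import numpy as np

def denoise(x_noisy, energy_grad, lam=1.0, step_size=0.1, steps=200):
    """Gradient descent on E(x) + lam * ||x - x_noisy||^2, starting at x_noisy."""
    x_noisy = np.asarray(x_noisy, dtype=float)
    x = x_noisy.copy()
    for _ in range(steps):
        # Gradient of the energy plus gradient of the squared-distance term.
        grad = energy_grad(x) + 2.0 * lam * (x - x_noisy)
        x = x - step_size * grad
    return x

# Toy energy E(x) = 0.5 * ||x - mu||^2, whose low-energy region sits around mu.
mu = np.array([1.0, -1.0])

def toy_energy_grad(x):
    return x - mu

x_noisy = np.array([2.0, 0.0])
x_hat = denoise(x_noisy, toy_energy_grad, lam=1.0)
```

For this quadratic case the minimizer is known in closed form, $(\mu + 2\lambda x_{\text{noisy}})/(1 + 2\lambda)$, so the result lands between the noisy observation and the low-energy point, with $\lambda$ controlling the trade-off.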

8.4 natural language processing perspectives

Text data is discrete, which complicates gradient-based MCMC. Nonetheless, if we define an EBM in the latent continuous space of a neural text encoder, we can attempt to do sampling in that representation. Another approach is to apply Gumbel-softmax reparameterizations to handle discrete tokens. Although less common than, say, transformers for language modeling, EBM-based text generation has been explored in specialized research. It remains a challenging domain, partially because of how to incorporate language structure into the energy function and sampling procedure.
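For readers unfamiliar with the Gumbel-softmax trick mentioned above, here is a minimal NumPy sketch of a single relaxed sample over a tiny vocabulary. The logits, the temperature value, and the function name are assumptions for illustration. A real text EBM would apply this per token position inside an autograd framework.

```python
import numpy as np

def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    logits = np.asarray(logits, dtype=float)
    g = rng.gumbel(size=logits.shape)       # standard Gumbel(0, 1) noise
    z = (logits + g) / tau
    z = z - z.max()                         # numerical stability
    w = np.exp(z)
    return w / w.sum()

rng = np.random.default_rng(0)
token_logits = np.array([2.0, 0.0, 0.0])    # assumed 3-token vocabulary
relaxed = gumbel_softmax(token_logits, tau=0.5, rng=rng)

# The argmax of the relaxed sample follows the categorical softmax(logits).
counts = np.zeros(3)
for _ in range(5000):
    counts[gumbel_softmax(token_logits, tau=0.5, rng=rng).argmax()] += 1
freqs = counts / counts.sum()
```

Lower temperatures push the relaxed vector toward a one-hot token, at the cost of higher-variance gradients. The argmax frequencies are independent of the temperature.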


9. handling training instabilities

9.1 common pitfalls: divergence scenarios and local maxima

A hallmark of EBM training is that it can diverge if hyperparameters are poorly tuned or if the network can trivially push down energies for both real and negative samples. Divergence may manifest as the sampling chain producing meaningless noise while the model's energies saturate at negative values. Once diverged, the model can be difficult to recover without reverting to an earlier checkpoint.

Another potential pitfall is that the MCMC chain gets trapped in local minima or spurious modes, never exploring the true distribution. This results in inaccurate negative samples that do not provide the correct push-back in the CD training objective.

9.2 hyperparameter sensitivities (learning rate, noise, step count)

Practitioners typically note that EBM training is more sensitive to hyperparameters than other deep models. The learning rate for model updates, the number of MCMC steps per iteration, the step size inside each MCMC step, and the noise level are all crucial. A small mismatch can lead to either an overly smoothed or an extremely rugged energy surface.

Tactics include:

  • Start with smaller step sizes for MCMC and gradually increase.
  • Introduce progressive noise schedules or cyclical approaches where the noise is varied in each epoch.
  • Carefully tune the ratio of real data to negative data or the frequency with which the replay buffer is updated.

9.3 checkpoint reloading strategies for recovering from divergence

When your EBM starts to diverge, you might see a sudden spike in energies or the negative sample energies saturating. As a fallback, one can revert to a stable checkpoint from a few epochs earlier. This strategy can salvage training runs that are otherwise irrecoverable, though it is essentially a manual fix. In some pipelines, automated heuristics watch for divergence signals and revert the model automatically to a stable checkpoint, perhaps with adjusted hyperparameters.
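One possible shape for such an automated heuristic is sketched below. It is only an illustration: the class name, the window size, and the "10× the recent average magnitude" rule are assumptions, and the actual checkpoint reload is left to the surrounding training code.

```python
import math
from collections import deque

class DivergenceMonitor:
    """Flags a training step as diverged when the tracked value spikes or is not finite."""

    def __init__(self, window=50, spike_factor=10.0, min_history=10):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor
        self.min_history = min_history

    def update(self, value):
        """Return True if `value` looks like divergence; otherwise record it."""
        if math.isnan(value) or math.isinf(value):
            return True
        diverged = False
        if len(self.history) >= self.min_history:
            baseline = sum(abs(v) for v in self.history) / len(self.history)
            diverged = abs(value) > self.spike_factor * max(baseline, 1e-8)
        if not diverged:
            self.history.append(value)   # only healthy values update the baseline
        return diverged

# Toy trace: a stable loss followed by a sudden spike.
monitor = DivergenceMonitor(min_history=5)
trace_flags = [monitor.update(v) for v in [0.1] * 10 + [50.0]]
nan_flag = monitor.update(float("nan"))
```

When `update` returns True, the training script would reload the last stable checkpoint, possibly with a smaller learning rate or MCMC step size.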

9.4 regularization and architectural constraints for stability

Adding an explicit penalty on $\|E_\theta(x)\|$ or bounding the network's last layer can help limit runaways. Certain architectures that produce smoother energy surfaces might also be more stable in practice. For instance, networks that avoid ReLU dead zones and rely on activations that are differentiable at zero (like Swish or softplus) can yield more stable sampling updates. Additionally, spectral normalization of the convolution weights, commonly used in GAN discriminators, can help keep the EBM's gradients within a bounded range.


10. advanced topics and extensions

10.1 compositional ebms (product of experts)

EBMs are naturally compositional because energies can be added. The product of experts framework (e.g., multiple RBMs combined) states that:

$$
E_{\text{combined}}(x) = E_{\theta_1}(x) + E_{\theta_2}(x),
$$

thus leading to $p_{\text{combined}}(x) = p_{\theta_1}(x) \, p_{\theta_2}(x) / Z$, where $Z$ renormalizes the product. This can incorporate multiple knowledge sources or constraints into a single EBM, each of which shapes the final distribution.
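The equivalence between adding energies and multiplying (then renormalizing) probabilities can be checked numerically on a small grid. The two quadratic energies below are arbitrary choices for the demonstration, and `boltzmann` is an illustrative helper name.

```python
import numpy as np

def boltzmann(energies):
    """Normalized exp(-E) over a finite grid of states."""
    w = np.exp(-(energies - energies.min()))   # shift for numerical stability
    return w / w.sum()

xs = np.linspace(-3.0, 3.0, 61)                # discretized 1D state space
E1 = 0.5 * (xs - 1.0) ** 2                     # expert 1 prefers x near +1
E2 = 0.5 * (xs + 1.0) ** 2                     # expert 2 prefers x near -1

# Route 1: add the energies, then normalize.
p_combined = boltzmann(E1 + E2)

# Route 2: multiply the experts' distributions, then renormalize.
p_product = boltzmann(E1) * boltzmann(E2)
p_product = p_product / p_product.sum()
```

Both routes give the same distribution, and its mode sits at $x = 0$, the compromise between the two experts.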

10.2 hybrid monte carlo and advanced sampling methods

Hybrid/Hamiltonian Monte Carlo (HMC) uses Hamiltonian dynamics to propose new states more efficiently in high dimensions, often resulting in better mixing than simpler methods. For EBMs with well-behaved gradients, HMC can be beneficial, though it is more complex to implement (requiring leapfrog steps, momentum variables, acceptance–rejection, etc.). Another direction is tempered transitions, in which intermediate distributions bridge from a broad distribution (high temperature) to the target distribution (temperature 1), mitigating the risk of being stuck in a single mode.
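To show what the leapfrog steps, momentum variables, and acceptance test look like together, here is a compact HMC sketch in NumPy. It assumes a toy energy $E(x) = \tfrac{1}{2}\|x\|^2$ with an analytic gradient (so the target is a standard Gaussian). For a neural EBM the gradient would come from autograd instead. The step size and number of leapfrog steps are illustrative values.

```python
import numpy as np

def hmc_step(x, energy, energy_grad, rng, step_size=0.1, n_leapfrog=20):
    """One HMC transition targeting p(x) proportional to exp(-energy(x))."""
    p0 = rng.standard_normal(x.shape)            # resample momentum
    x_new, p = x.copy(), p0.copy()

    # Leapfrog integration: half momentum step, alternating full steps, half step.
    p = p - 0.5 * step_size * energy_grad(x_new)
    for i in range(n_leapfrog):
        x_new = x_new + step_size * p
        if i < n_leapfrog - 1:
            p = p - step_size * energy_grad(x_new)
    p = p - 0.5 * step_size * energy_grad(x_new)

    # Metropolis correction on the total Hamiltonian (energy + kinetic term).
    h_old = energy(x) + 0.5 * np.sum(p0 ** 2)
    h_new = energy(x_new) + 0.5 * np.sum(p ** 2)
    if rng.random() < np.exp(min(0.0, h_old - h_new)):
        return x_new, True
    return x, False

# Toy target: E(x) = 0.5 * ||x||^2, i.e. a 2D standard Gaussian.
energy = lambda x: 0.5 * float(np.sum(x ** 2))
energy_grad = lambda x: x

rng = np.random.default_rng(0)
x = np.zeros(2)
samples, n_accept = [], 0
for _ in range(2000):
    x, accepted = hmc_step(x, energy, energy_grad, rng)
    samples.append(x.copy())
    n_accept += accepted
samples = np.array(samples)
accept_rate = n_accept / len(samples)
```

On this easy target nearly all proposals are accepted and the sample mean and variance approach 0 and 1. On a rugged learned energy the step size usually has to be much smaller to keep the acceptance rate reasonable.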

10.3 energy-based interpretation of classifiers (jem revisited)

JEM's approach can be generalized to many classification tasks: each label $y$ is associated with a sub-landscape in $x$-space, and the relative energies define $p_\theta(y|x)$. One can also attempt open-set recognition by checking the absolute energy $E_\theta(x)$. If it is too large, then $x$ might be out of distribution, even if $\arg\min_y E_\theta(x,y)$ yields some label. This gives classifiers a built-in OOD rejection feature, at least in theory.
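A minimal sketch of this "predict or reject" rule follows. It assumes the JEM convention that the classifier logits $f_\theta(x)[y]$ play the role of $-E_\theta(x, y)$, so that $E_\theta(x) = -\log \sum_y \exp f_\theta(x)[y]$. The logit values and the threshold are made up for the example, and `predict_or_reject` is a hypothetical name.

```python
import numpy as np

def logsumexp(v):
    """Numerically stable log(sum(exp(v)))."""
    v = np.asarray(v, dtype=float)
    m = v.max()
    return float(m + np.log(np.exp(v - m).sum()))

def predict_or_reject(logits, energy_threshold):
    """Return (label, energy); label is None when E(x) exceeds the threshold."""
    energy = -logsumexp(logits)                 # E(x) = -log sum_y exp(f(x)[y])
    if energy > energy_threshold:
        return None, energy                     # too implausible: reject as OOD
    return int(np.argmax(logits)), energy       # argmax logit = argmin_y E(x, y)

# Assumed logits: one confident input and one with uniformly low logits.
confident_label, confident_energy = predict_or_reject([10.0, 0.0, 0.0], energy_threshold=-3.0)
flat_label, flat_energy = predict_or_reject([0.0, 0.0, 0.0], energy_threshold=-3.0)
```

Note that the flat input would still receive a label from a plain softmax classifier. It is the absolute energy level that allows the rejection.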

10.4 incorporating domain knowledge or symbolic constraints

Because EBMs revolve around shaping an energy function, domain knowledge can be injected as an additional term in the energy. For example, if we know certain constraints on $x$ (like geometric constraints, or physical constraints for robotics states), we can add a penalty to $E_\theta(x)$ that enforces or encourages them. This can be done by a hand-designed potential or an auxiliary network that encodes domain-specific knowledge. The synergy between learned features and hand-crafted constraints can be quite powerful.
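As a toy illustration of adding a hand-designed potential, the sketch below combines a stand-in "learned" energy with a penalty that encourages states to lie on the unit circle. Both functions, the constraint itself, and the weight are assumptions chosen only to show the mechanics.

```python
import numpy as np

def learned_energy(x):
    """Hypothetical stand-in for a trained E_theta(x): lowest at the point (2, 0)."""
    return 0.5 * float(np.sum((x - np.array([2.0, 0.0])) ** 2))

def unit_norm_penalty(x):
    """Hand-designed potential: zero exactly when ||x|| = 1."""
    return float((np.linalg.norm(x) - 1.0) ** 2)

def total_energy(x, weight=5.0):
    """Learned energy plus the weighted domain-knowledge penalty."""
    return learned_energy(x) + weight * unit_norm_penalty(x)

on_circle = np.array([1.0, 0.0])    # satisfies the constraint
off_circle = np.array([2.0, 0.0])   # preferred by the learned energy alone
```

The learned term alone prefers `off_circle`, but with the penalty added the constraint-satisfying point has the lower total energy, so samplers run on `total_energy` are steered toward it.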


11. implementation walkthrough (example code)

In this section, let's go step by step through a code example of training an EBM for MNIST digit generation. While we focus on MNIST for demonstration, similar principles extend to other image datasets or even other modalities, with changes mainly in the network architecture and hyperparameters.

11.1 data loading and normalization steps

Below is a typical snippet (PyTorch-based) for loading MNIST and normalizing pixel values to the range [-1, 1] for convenience:


from torchvision.datasets import MNIST
from torchvision import transforms
import torch.utils.data as data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale to [-1, 1]
])

train_set = MNIST(root="data", train=True, transform=transform, download=True)
test_set = MNIST(root="data", train=False, transform=transform, download=True)

train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

11.2 model architecture (simple cnn or resnet)

We define a small CNN to output a single scalar for each 28×28 input image:


import torch
import torch.nn as nn
import torch.nn.functional as F

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class CNNModel(nn.Module):
    def __init__(self, hidden_features=32):
        super().__init__()
        c_hid1 = hidden_features // 2
        c_hid2 = hidden_features
        c_hid3 = hidden_features * 2
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, padding=4),
            Swish(),
            nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1),
            Swish(),
            nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, padding=1),
            Swish(),
            nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1),
            Swish(),
            nn.Flatten(),
            nn.Linear(c_hid3*4, c_hid3),
            Swish(),
            nn.Linear(c_hid3, 1)
        )

    def forward(self, x):
        return self.cnn_layers(x).squeeze(dim=-1)

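Here, $E_\theta(x)$ is effectively $-\text{output}$ if we follow the negative sign convention: the network's scalar output plays the role of $-E_\theta(x)$, so higher outputs mean more plausible images. In practice, you can also store $E_\theta(x)$ as the direct output, as long as you interpret the sign carefully in the training objective.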

11.3 contrastive divergence training loop: real vs. fake images, buffer sampling strategy, loss function details

We now define a replay buffer for negative samples:


import random
import numpy as np

class Sampler:
    def __init__(self, model, img_shape, sample_size=128, max_len=8192):
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        self.examples = [(torch.rand((1,) + img_shape) * 2 - 1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        old_imgs = torch.cat(random.choices(self.examples, k=self.sample_size-n_new), dim=0)
        inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach()
        inp_imgs = inp_imgs.to(next(self.model.parameters()).device)

        inp_imgs = self.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)

        self.examples = list(inp_imgs.cpu().chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[:self.max_len]
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        # During MCMC we only need gradients w.r.t. the input images,
        # so freeze the model parameters and switch to eval mode.
        is_training = model.training
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        inp_imgs.requires_grad = True

        # Make sure autograd is on, even if the caller disabled it.
        had_gradients_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(True)

        # Reusable noise buffer (refilled in place at every step).
        noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)
        imgs_per_step = []

        for _ in range(steps):
            # 1) Perturb the current samples with small Gaussian noise.
            noise.normal_(0, 0.005)
            inp_imgs.data.add_(noise.data)
            inp_imgs.data.clamp_(-1.0, 1.0)

            # 2) The network outputs -E(x), so negate it to get the energy,
            #    then take a clipped gradient step that lowers the energy.
            out_imgs = -model(inp_imgs)  # negative sign for convenience
            out_imgs.sum().backward()
            inp_imgs.grad.data.clamp_(-0.03, 0.03)
            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
            inp_imgs.grad.detach_()
            inp_imgs.grad.zero_()
            inp_imgs.data.clamp_(-1.0, 1.0)

            if return_img_per_step:
                imgs_per_step.append(inp_imgs.clone().detach())

        # Restore the model and autograd state expected by the training step.
        for p in model.parameters():
            p.requires_grad = True
        model.train(is_training)
        torch.set_grad_enabled(had_gradients_enabled)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

Next, we combine everything into a training loop using PyTorch Lightning for convenience:


import pytorch_lightning as pl
import torch.optim as optim

class DeepEnergyModel(pl.LightningModule):
    def __init__(self, img_shape, batch_size, alpha=0.1, lr=1e-4, beta1=0.0):
        super().__init__()
        self.save_hyperparameters()
        self.cnn = CNNModel()
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)
        self.example_input_array = torch.zeros(1, *img_shape)

    def forward(self, x):
        return self.cnn(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(self.hparams.beta1, 0.999))
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        real_imgs, _ = batch
        noise = torch.randn_like(real_imgs) * 0.005
        real_imgs = (real_imgs + noise).clamp(-1., 1.)

        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)
        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)

        # alpha * (real_out^2 + fake_out^2) is the regularization
        reg_loss = self.hparams.alpha * (real_out**2 + fake_out**2).mean()
        cdiv_loss = fake_out.mean() - real_out.mean()
        loss = reg_loss + cdiv_loss

        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        real_imgs, _ = batch
        rand_imgs = torch.rand_like(real_imgs) * 2 - 1
        inp_imgs = torch.cat([real_imgs, rand_imgs], dim=0)
        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)
        cdiv = fake_out.mean() - real_out.mean()
        self.log('val_cdiv', cdiv)

11.4 logging and visualization (tensorboard, metrics)

We might add callbacks to log images or negative samples from the replay buffer at each epoch, visualize their progression over MCMC steps, or track how out-of-distribution inputs are scored. These details are typically handled via standard PyTorch Lightning callback mechanisms or custom logging code.

11.5 example results: mnist digit generation

Trained with appropriate hyperparameters (e.g., 60 epochs, $lr = 10^{-4}$, etc.), this model begins to generate digits that look like MNIST samples. Below is a conceptual example of how one might generate images from scratch:


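def generate_digits(model, num_steps=256, batch_size=8):
    # Start from uniform noise in [-1, 1], created on the same device as the model
    start_imgs = torch.rand((batch_size, 1, 28, 28), device=model.device) * 2 - 1
    # Run the MCMC sampler and keep every intermediate step for visualization
    imgs = Sampler.generate_samples(model.cnn, start_imgs, steps=num_steps, step_size=10, return_img_per_step=True)
    return imgs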

model = DeepEnergyModel((1,28,28), batch_size=128)
# Suppose the model is trained or loaded from checkpoint

gen_imgs_per_step = generate_digits(model)
# We can now visualize gen_imgs_per_step to see the evolution from random noise to digit-like patterns.

12. conclusion

Energy-based models provide a powerful and elegant framework for learning flexible probability distributions by defining a scalar energy function $E_\theta(x)$. Through exponentiation and normalization, these models describe a probability distribution that is, in principle, capable of capturing complex, multimodal phenomena. The ability to unify discriminative and generative modeling is attractive, as is the natural anomaly detection capability derived from the scalar energy.

However, EBMs also bring computational challenges. The partition function is intractable, forcing us to rely on approximate gradient-based sampling. Training loops generally incorporate techniques such as Contrastive Divergence, replay buffers, and carefully tuned hyperparameters for MCMC steps (step sizes, noise levels, gradient clipping, etc.). Without these stabilizations, EBMs can diverge.

Despite these difficulties, a growing body of research (Du and Mordatch, 2019; Nijkamp et al., 2019; Grathwohl et al., 2020, among others) demonstrates that modern deep networks used as energy functions can rival or exceed the performance of more widely used generative methods under certain conditions. They also open interesting avenues for tasks like out-of-distribution detection, compositional modeling (e.g., product of experts), and integrated classifier–generator systems (JEM).

In practice, learning to harness EBMs can be a rewarding endeavor: the conceptual clarity of shaping an energy landscape is appealing, and the direct interpretability of energies (especially for anomaly detection and reconstructions) can be enlightening. If you are willing to invest in hyperparameter tuning, advanced sampling procedures, and possibly fallback checkpoint reloading, energy-based deep models can prove to be an invaluable addition to a modern machine learning toolkit.

The next time you face a problem where you need both generative modeling and a robust measure of sample plausibility (and you are not entirely satisfied with typical VAEs or GANs), remember that energy-based models are a strong contender. As research continues to develop improved training and inference methods, it is likely that EBMs will remain an influential pillar in both theoretical and applied machine learning for years to come.
