

This post is a part of the Probabilistic models & Bayesian methods educational series from my free course. Please keep in mind that the correct sequence of posts is outlined on the course page, while it can be arbitrary in Research.
Sampling lies at the heart of countless machine learning and statistical methods, providing a pathway for exploring stochastic latent variables and enabling many key techniques for training large-scale models. Within the realm of deep learning, discrete structures have traditionally presented challenges due to their inherently non-differentiable nature. For example, if I want to sample a discrete random variable (like a categorical variable or a subset) inside a neural network and then backpropagate a gradient signal through that sampling procedure, I immediately run into the dilemma that the argmax or discrete selection steps are non-differentiable. However, this barrier has been partially overcome by new families of continuous relaxations and reparameterization methods, which allow for approximate or exact gradient flow through random sampling steps.
These methods are particularly relevant when dealing with discrete latent variables in a deep network — for instance, when I use a variational autoencoder (VAE) with a discrete latent space. Sampling from discrete random variables used to involve non-differentiable draws, forcing the use of high-variance gradient estimators like REINFORCE or policy gradients. But with the introduction of the Gumbel-softmax reparameterization trick (Jang and gang, ICLR 2017; Maddison and gang, ICLR 2017), along with other creative continuous approximations of discrete operations, we now have an expanded toolkit that can drastically improve training stability and reduce estimator variance.
The broader theme of this article is the family of methods that revolve around the Gumbel-softmax trick and its close cousins, which facilitate reparameterizing discrete draws in a continuous fashion. We will explore the fundamental ideas of the Gumbel distribution, the Gumbel-max trick, the Gumbel-softmax distribution, Gumbel top-k sampling, and the Gumbel-Sinkhorn operator. These techniques allow for differentiable approximations of argmax, top-k, subsets, permutations, and even adjacency matrices for graphs. By carefully tuning temperature parameters or employing advanced gradient estimators, we can move seamlessly between discrete and continuous realms.
motivation for sampling in deep learning
Deep neural networks have opened new possibilities for generative modeling, reinforcement learning, and structured prediction. In many of these problems, I might want to represent a variable as discrete — for instance, a categorical latent variable indicating which cluster an observation belongs to, a subset of features in a combinatorial selection problem, or the ordering of items in a ranking or matching problem. Such discrete structures bring interpretability and can lead to simpler or more efficient solutions, but they also introduce non-differentiable operations.
The question then becomes: how can I train large neural networks that contain these discrete components end-to-end with backpropagation? The direct approach, involving discrete sampling, breaks the chain rule of calculus. This leads to gradient estimators like REINFORCE that can exhibit high variance, especially in high-dimensional settings.
Hence the need for reparameterization tricks. The general principle behind reparameterization is that I express a random variable $z$ with a distribution $p_\theta(z)$ parameterized by $\theta$ as a deterministic function of a parameter-free noise variable $\epsilon$. Formally, instead of
$$z \sim p_\theta(z),$$
I try to write
$$z = g(\theta, \epsilon),$$
where $\epsilon$ has a distribution that does not depend on $\theta$. This style of reparameterization is easy to implement for continuous distributions like the Gaussian, where $z = \mu + \sigma \epsilon$ and $\epsilon \sim \mathcal{N}(0, 1)$. But for discrete distributions, we can't so simply separate out the random part from the parameters, at least not without some creativity.
This is where the Gumbel distribution steps in. The Gumbel-max trick shows that if I have a categorical distribution with class probabilities $\pi_1, \dots, \pi_K$, I can sample exactly from that distribution by sampling $G_i \sim \text{Gumbel}(0, 1)$ i.i.d. and taking $\arg\max_i (\log \pi_i + G_i)$. That's an exact sampler, but the $\arg\max$ is not differentiable. So the Gumbel-softmax relaxation replaces the $\arg\max$ with a softmax operation, thereby making the entire process differentiable.
discrete structures in machine learning models
Some common examples of discrete structures in ML:
- Categorical variables (unstructured vectors). A straightforward scenario is a K-class latent variable $z \in \{1, \dots, K\}$. This arises in VAEs that impose a discrete latent space or in generative models that must pick from discrete sets of states.
- Subsets. I might want to pick a subset of items (features, data samples, or other sets). This is a combinatorial selection problem, often associated with the top-k or top-p selection, or used in resource-constrained optimization.
- Permutations. Another step up in complexity is a scenario where I want to sample permutations, for instance in ranking tasks, matching problems, or route planning. A permutation matrix is discrete and combinatorial, so direct backprop is not feasible.
- Graph-based latent variables. If the structure of a graph (which edges exist) is unknown, sampling from a distribution over graphs is a discrete selection problem over edges.
Each of these discrete structures can be tackled with Gumbel-based methods or other advanced gradient estimators.
continuous relaxations and reparameterization tricks
The overarching idea is to turn each discrete sample or operation (argmax, top-k, constructing adjacency, etc.) into a smooth approximation so that the forward pass yields values that, at high temperature, approximate continuous distributions, and at low temperature, approach discrete distributions. Then I can backprop through these continuous approximations.
We'll dig deep into the Gumbel distribution and the softmax relaxation, paying attention to how temperature scheduling can gradually push the model from a continuous representation (easy to train but less discrete) toward a more discrete representation (closer to the final goal but potentially harder to optimize). I'll also highlight alternative methods (like REINFORCE, RELAX, REBAR) that do not require continuous relaxations but come with their own trade-offs in variance and bias.
This article aims to give you an in-depth look at how these sampling approaches work, how to implement them, and how they can be used in real-world models. Along the way, I'll provide references to relevant papers and highlight recent research trends.
gumbel-softmax reparameterization
the gumbel-max trick and categorical distributions
Suppose I have a categorical distribution over $K$ classes with probabilities $\pi_1, \dots, \pi_K$, meaning $\pi_i \ge 0$ and $\sum_{i=1}^K \pi_i = 1$. I want to sample from that distribution. One way is to do a typical discrete sampling approach: pick class $i$ with probability $\pi_i$. But the Gumbel-max trick redefines the sampling procedure as follows:
- Sample $G_1, \dots, G_K$ i.i.d. from a Gumbel(0,1) distribution. A Gumbel(0,1) random variable can be sampled as $G = -\log(-\log U)$ where $U \sim \text{Uniform}(0, 1)$.
- Compute $z = \arg\max_i (\log \pi_i + G_i)$.
The random index $z$ follows the original categorical distribution with probability $P(z = i) = \pi_i$. Thus, $z$ is a discrete sample from the correct distribution.
However, $z$ is not differentiable with respect to $\pi$, because $\arg\max$ is not a smooth function.
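The two steps above can be checked empirically. The following is a minimal sketch using only the Python standard library; the function names, the example probabilities, and the number of draws are my own illustrative choices. It draws many Gumbel-max samples and compares the observed class frequencies with the target probabilities.

```python
import math
import random
from collections import Counter


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) variate via G = -log(-log(U)), U ~ Uniform(0, 1)."""
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def gumbel_max_sample(probs, rng):
    """Return an index drawn from Categorical(probs) using the Gumbel-max trick."""
    scores = [math.log(p) + sample_gumbel(rng) for p in probs]
    return max(range(len(probs)), key=scores.__getitem__)


rng = random.Random(0)
probs = [0.1, 0.6, 0.3]  # illustrative class probabilities
n_draws = 50_000
counts = Counter(gumbel_max_sample(probs, rng) for _ in range(n_draws))
freqs = [counts[i] / n_draws for i in range(len(probs))]
print(freqs)  # should be close to probs
```

The frequencies should agree with `probs` up to Monte Carlo noise, which shrinks as `n_draws` grows.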
the argmax operation vs. a softmax relaxation
The Gumbel-softmax trick (Jang and gang, ICLR 2017), independently introduced as the Concrete distribution (Maddison and gang, ICLR 2017), replaces $\arg\max$ with a softmax operation:
$$y_i = \frac{\exp\bigl((\log \pi_i + G_i)/\tau\bigr)}{\sum_{j=1}^{K} \exp\bigl((\log \pi_j + G_j)/\tau\bigr)}, \quad i = 1, \dots, K,$$
where $\tau > 0$ is a temperature parameter. If $\tau$ is very low, the softmax is sharply peaked, closely approximating an $\arg\max$. If $\tau$ is very high, the vector $y$ spreads out, approaching a uniform vector.
Crucially, $y$ is differentiable with respect to the logits $\log \pi$. Hence if $\pi$ is produced by some neural network (e.g., the encoder of a VAE), I can backpropagate from $y$ to $\pi$.
I note that the vector $y$ is still a continuous vector in the probability simplex $\Delta^{K-1}$. I can interpret $y$ as a "soft" one-hot vector, with each component in $(0, 1)$ but summing to 1.
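To make the role of the temperature concrete, here is a small sketch (standard library only; the helper names and the example probabilities are assumptions for illustration) that reuses one fixed draw of Gumbel noise and applies the softmax relaxation at three temperatures.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax(log_probs, gumbels, tau):
    """Relaxed sample y = softmax((log pi + G) / tau) for fixed noise G."""
    return softmax([(lp + g) / tau for lp, g in zip(log_probs, gumbels)])


rng = random.Random(0)
probs = [0.1, 0.6, 0.3]  # illustrative class probabilities
log_probs = [math.log(p) for p in probs]
gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in probs]

# Same noise, different temperatures: low tau is nearly one-hot, high tau is nearly uniform.
ys = {tau: gumbel_softmax(log_probs, gumbels, tau) for tau in (0.1, 1.0, 10.0)}
for tau, y in ys.items():
    print(tau, [round(v, 3) for v in y])
```

Holding the noise fixed isolates the effect of the temperature: the largest component grows toward 1 as the temperature drops, while every vector still sums to 1.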
temperature hyperparameter and annealing strategies
A central hyperparameter in Gumbel-softmax is the temperature $\tau$. If $\tau$ is close to 0, the softmax becomes extremely peaked, approaching a one-hot vector, but the gradients can become large (or vanish in certain regimes). If $\tau$ is large, the vector is more uniform, and the training signals flow more stably.
A common trick is to start with a relatively high temperature $\tau_0$ and then gradually reduce it during training toward a smaller value $\tau_{\min}$. This is called temperature annealing. It helps the model begin training with smoother approximations, and later it can home in on more discrete solutions.
training categorical vaes with gumbel-softmax
Variational autoencoders with continuous latent variables rely on the reparameterization trick for Gaussians. But if I want to do a categorical VAE, I can't easily reparameterize a discrete distribution. The Gumbel-softmax trick offers a solution.
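In a categorical VAE, the encoder outputs $\alpha$ (logits), from which I derive $\pi$ by $\pi = \operatorname{softmax}(\alpha)$. Then to sample $z$, I do:
- Sample $G_i \sim \text{Gumbel}(0, 1)$ for each class $i$.
- Compute $y = \operatorname{softmax}\bigl((\log \pi + G)/\tau\bigr)$.
Now $y$ is the relaxed discrete latent variable. Because $\log \pi$ and $\alpha$ differ only by a constant shared across classes, using $\alpha$ in place of $\log \pi$ gives the same $y$. The reconstruction function can condition on $y$. At the same time, I can compute a KL divergence term that approximates the difference between the approximate posterior $q_\phi(z \mid x)$ and a prior $p(z)$, which might be a uniform categorical distribution.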
the straight-through estimator
Another variant is the straight-through estimator. In this approach, I compute the argmax in the forward pass to get an actual discrete sample $z$, but in the backward pass, I treat that argmax as if it were a softmax of the same logits. This means I get a discrete sample for the forward pass (which is beneficial if I want strict category decisions), but I still propagate useful gradients back.
Concretely, if $y$ is the Gumbel-softmax output, I can create $y_{\text{hard}}$ by taking $\arg\max$ across the classes, i.e. a true one-hot vector. Then the forward pass uses $y_{\text{hard}}$ as input to the next layer, but for the backward pass, I manually replace the gradient with the gradient from $y$. In code, this is often achieved by computing $y + \operatorname{stopgrad}(y_{\text{hard}} - y)$: the value equals $y_{\text{hard}}$, while the gradient is that of $y$.
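A minimal PyTorch sketch of the straight-through variant follows. The function name `gumbel_softmax_st`, the toy logits, and the toy loss are my own assumptions for illustration; `detach()` plays the role of the stop-gradient.

```python
import torch
import torch.nn.functional as F


def gumbel_softmax_st(logits, tau):
    """One-hot sample in the forward pass, Gumbel-softmax gradient in the backward pass."""
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbels) / tau, dim=-1)
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # Value: y_hard. Gradient: that of y_soft, because the detached term is a constant.
    return y_hard - y_soft.detach() + y_soft


torch.manual_seed(0)
logits = torch.zeros(4, 3, requires_grad=True)  # toy batch of 4, K = 3
y = gumbel_softmax_st(logits, tau=0.5)
weights = torch.tensor([1.0, 2.0, 3.0])  # toy downstream "loss" weights
loss = (y * weights).sum()
loss.backward()
print(y)
print(logits.grad)
```

PyTorch also ships `torch.nn.functional.gumbel_softmax`, whose `hard=True` option applies the same idea, so in practice the hand-written version is mainly useful for understanding what happens.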
comparisons to alternative gradient estimators (reinforce, rebar, relax)
If I do not use a continuous relaxation, I might resort to REINFORCE (Williams, 1992), which treats the discrete sample as a random node in the computation graph and uses the log-derivative trick $\nabla_\theta \mathbb{E}_{p_\theta(z)}[f(z)] = \mathbb{E}_{p_\theta(z)}[f(z) \nabla_\theta \log p_\theta(z)]$. But REINFORCE can suffer from high variance, so it's common to incorporate baselines to reduce variance.
REBAR (Tucker and gang, ICLR 2017) and RELAX (Grathwohl and gang, ICLR 2018) combine a continuous relaxation with a control variate technique to reduce variance further. These methods can provide unbiased (or low-bias) gradient estimates that have lower variance than straightforward REINFORCE. However, they are also more complex to implement, requiring additional neural networks to approximate baseline functions.
Overall, the Gumbel-softmax approach is conceptually simpler to implement. It does introduce a small bias, because the sample is no longer strictly discrete, but in practice, it can work well and scale to large problems more straightforwardly.
categorical vae example
model overview and latent variable structure
Let's consider an example of a Variational Autoencoder with one or more categorical latent variables. For simplicity, imagine I have a single latent variable $z$ with $K$ categories. The generative process is:
- Sample $z$ from some prior $p(z)$.
- Generate observation $x$ from a conditional distribution $p_\theta(x \mid z)$ (e.g., a neural network that outputs parameters of a Bernoulli or Gaussian distribution over $x$, given a one-hot encoding of $z$).
In a normal continuous VAE, $z$ might be a vector from a Gaussian. Here, it's a single categorical variable. The goal is to learn $p_\theta(x \mid z)$ and also an approximate posterior $q_\phi(z \mid x)$.
gumbel-softmax sampling within the encoder
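In the encoder, I produce logits $\alpha(x)$, from which I define $\pi = \operatorname{softmax}(\alpha(x))$. To sample in a differentiable manner, I do:
$$y = \operatorname{softmax}\bigl((\log \pi + G)/\tau\bigr), \quad G_i \sim \text{Gumbel}(0, 1) \text{ i.i.d.}$$
The vector $y$ is a relaxed discrete sample. Then the decoder takes $y$ as input, e.g. $p_\theta(x \mid y)$.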
kl divergence and reconstruction loss
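The training objective for a VAE is the Evidence Lower BOund (ELBO):
$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\bigl[\log p_\theta(x \mid z)\bigr] - \mathrm{KL}\bigl(q_\phi(z \mid x) \,\|\, p(z)\bigr).$$
When $z$ is discrete, $q_\phi(z \mid x)$ is a categorical distribution with probabilities $\pi$. If the prior $p(z)$ is uniform, the KL term has a closed-form expression:
$$\mathrm{KL}\bigl(q_\phi(z \mid x) \,\|\, p(z)\bigr) = \sum_{i=1}^{K} \pi_i \log \pi_i + \log K.$$
But note, I never directly sample from the discrete distribution. Instead, I sample the continuous relaxation $y$. The reparameterization with Gumbel-softmax is used for the expectation term $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$. The KL term is computed in closed form using the distribution $\pi$.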
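The objective can be written down in a few lines of PyTorch. This is only a sketch under stated assumptions: a uniform prior over $K$ categories, a Bernoulli likelihood over inputs in $[0, 1]$, and function names (`categorical_kl_to_uniform`, `negative_elbo`) that I made up for illustration.

```python
import math

import torch
import torch.nn.functional as F


def categorical_kl_to_uniform(logits):
    """KL(q || Uniform(K)) = sum_i q_i log q_i + log K, one value per example."""
    log_q = F.log_softmax(logits, dim=-1)
    q = log_q.exp()
    k = logits.shape[-1]
    return (q * log_q).sum(dim=-1) + math.log(k)


def negative_elbo(x, reconstruction, logits):
    """Bernoulli reconstruction loss plus closed-form KL, averaged over the batch."""
    recon = F.binary_cross_entropy(reconstruction, x, reduction="none").sum(dim=-1)
    kl = categorical_kl_to_uniform(logits)
    return (recon + kl).mean()


torch.manual_seed(0)
x = torch.rand(8, 20).round()                       # toy binary inputs
reconstruction = torch.sigmoid(torch.randn(8, 20))  # stand-in for decoder output
logits = torch.randn(8, 4)                          # stand-in for encoder logits, K = 4
loss = negative_elbo(x, reconstruction, logits)
print(loss.item())
```

Minimizing this quantity maximizes the ELBO. The KL is zero when the encoder outputs uniform probabilities and approaches $\log K$ when it puts all mass on one category.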
code walk-through and implementation details
Below is a simplified code snippet illustrating how to implement a categorical VAE with Gumbel-softmax in Python (PyTorch-like pseudocode). The relevant steps are in the encoder, where I produce logits and sample them with the reparameterization trick.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelSoftmaxVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_k, temp_init=1.0):
        super().__init__()
        self.latent_k = latent_k
        self.temp = temp_init
        # Encoder
        self.encoder_fc = nn.Linear(input_dim, hidden_dim)
        self.encoder_logits = nn.Linear(hidden_dim, latent_k)
        # Decoder
        self.decoder_fc = nn.Linear(latent_k, hidden_dim)
        self.decoder_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.encoder_fc(x))
        logits = self.encoder_logits(h)
        # Return logits for the categorical distribution
        return logits

    def gumbel_softmax_sample(self, logits, temperature):
        # sample Gumbel(0, 1) noise; the small constants guard against log(0)
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
        # add to logits and apply softmax
        y = F.softmax((logits + gumbel_noise) / temperature, dim=-1)
        return y

    def decode(self, y):
        h = F.relu(self.decoder_fc(y))
        out = torch.sigmoid(self.decoder_out(h))
        return out

    def forward(self, x):
        logits = self.encode(x)
        y = self.gumbel_softmax_sample(logits, self.temp)
        reconstruction = self.decode(y)
        return reconstruction, y, logits
```
Here, `gumbel_softmax_sample` implements the Gumbel-softmax reparameterization. The temperature can be annealed over epochs by doing something like:
```python
# Suppose we have an update rule
def update_temperature(model, epoch, start_temp=1.0, end_temp=0.1, total_epochs=100):
    # Simple linear or exponential scheduling
    # For example, exponential
    ratio = min(float(epoch) / total_epochs, 1.0)
    new_temp = start_temp * (end_temp / start_temp) ** ratio
    model.temp = new_temp
```
practical considerations and tuning the temperature
In practice, starting from a moderate temperature (the schedule above defaults to `start_temp=1.0`) and gradually decaying to a small floor (`end_temp=0.1` above) can be a good starting point. If the temperature becomes too low too fast, training can destabilize or get stuck. If it remains too high, the latent variables never become discrete, and the model might not leverage the categorical structure.
Additionally, watch out for gradient magnitudes. With extremely small , the softmax can saturate, leading to vanishing or exploding gradients. Some frameworks also provide built-in Gumbel-softmax or straight-through functionalities that can handle these details.
evaluation metrics for generative quality and latent utilization
Once trained, we can evaluate the generative quality by sampling from the learned model:
- Sample $z \sim p(z)$ (or from the approximate posterior $q_\phi(z \mid x)$ if evaluating reconstruction).
- Pass the corresponding one-hot or relaxed vector to the decoder.
We can measure how well the model reconstructs test data (reconstruction error), how well the latent classes are used (e.g., how many categories remain near zero probability in the prior or encoder distribution), and if the learned categories are semantically meaningful.
gumbel top-k sampling and subset selection
from gumbel-argmax to gumbel top-k for subsets
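In many tasks, I don't just want to pick one element with $\arg\max$; I might want the top-$k$ elements out of $K$. The Gumbel-max trick extends naturally to the top-k scenario:
- Sample $G_i \sim \text{Gumbel}(0, 1)$ for $i = 1, \dots, K$.
- Compute $S$ as the indices of the $k$ largest values of $\log \pi_i + G_i$.
This yields an exact sample of an ordered $k$-subset drawn without replacement: it is equivalent to drawing the first index with probability $\pi_i$, removing it, renormalizing the remaining probabilities, and repeating (the Plackett-Luce model; see Kool and gang, 2019). The top-$k$ selection itself is not differentiable, though, so for gradient-based training I need an approximate version.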
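The exact (non-differentiable) sampler is short enough to sketch with the standard library alone. The function name `gumbel_top_k` and the example probabilities are illustrative assumptions.

```python
import math
import random


def gumbel_top_k(probs, k, rng):
    """Return k distinct indices, ordered by decreasing Gumbel-perturbed log-probability."""
    scores = []
    for p in probs:
        u = max(rng.random(), 1e-12)  # guard against log(0)
        scores.append(math.log(p) - math.log(-math.log(u)))
    order = sorted(range(len(probs)), key=scores.__getitem__, reverse=True)
    return order[:k]


rng = random.Random(0)
probs = [0.5, 0.2, 0.2, 0.1]  # illustrative item probabilities
subset = gumbel_top_k(probs, k=2, rng=rng)
print(subset)

# The first selected index alone is an ordinary Gumbel-max sample,
# so it should be index 0 roughly half of the time here.
n_draws = 20_000
first_is_zero = sum(gumbel_top_k(probs, 2, rng)[0] == 0 for _ in range(n_draws)) / n_draws
print(first_is_zero)
```

A single sort of the perturbed scores produces the whole ordered sample, which is why this trick is attractive when $k$ is more than a handful.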
sampling without replacement and top-k relaxation
The top-k version of Gumbel sampling can be relaxed in multiple ways. One approach is to produce a $K$-dimensional vector whose $k$ largest entries are significantly higher than the rest, but in a continuous manner. Another approach is iterative, where I pick the largest logit, remove it, and pick the next largest from the remaining subset, etc.
A well-known approach is to do something akin to:
$$y = \operatorname{softmax}\bigl((\log \pi + G)/\tau\bigr),$$
then only keep the top-$k$ entries in $y$ by zeroing out others or applying some continuous threshold. This approach can be more complex, but the principle is to preserve as much differentiability as possible.
subsetoperator class: iterative softmax and temperature tuning
A pseudo-implementation might look like this:
```python
import torch
import torch.nn.functional as F


def sample_topk_gumbel(logits, k, temperature):
    # logits shape: (batch_size, K)
    # we want to get a "soft" top-k selection
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y = (logits + gumbel_noise) / temperature
    # softmax
    y_soft = F.softmax(y, dim=-1)
    # potentially we can approximate top-k via a sharper distribution or partial threshold
    # for demonstration, let's keep it simple (k is not used yet)
    # we can also return the sorted indices or partial top-k distribution
    return y_soft
```
One might refine that to an iterative or thresholding approach that ensures exactly $k$ items are chosen.
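One possible iterative refinement is sketched below, similar in spirit to the relaxed subset sampling of Xie and Ermon (2019): apply the softmax $k$ times, and before each step push down the scores of entries that were (softly) selected in the previous step. I use plain Python lists to keep the logic visible; the function names and example values are assumptions, and a real model would use tensors so that gradients can flow.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_top_k(logits, k, tau, rng, eps=1e-20):
    """Return a 'soft k-hot' vector whose entries sum to k."""
    keys = [l - math.log(-math.log(max(rng.random(), 1e-12))) for l in logits]
    khot = [0.0] * len(logits)
    onehot = [0.0] * len(logits)
    for _ in range(k):
        # Down-weight entries in proportion to how strongly they were just selected.
        keys = [r + math.log(max(1.0 - a, eps)) for r, a in zip(keys, onehot)]
        onehot = softmax([r / tau for r in keys])
        khot = [h + a for h, a in zip(khot, onehot)]
    return khot


rng = random.Random(0)
logits = [2.0, 0.5, 0.0, -1.0, 1.5]  # illustrative scores
khot = relaxed_top_k(logits, k=2, tau=0.5, rng=rng)
print([round(v, 3) for v in khot])
```

The output always sums to $k$ because it is a sum of $k$ softmax vectors, but individual entries are only approximately 0 or 1, and more so at higher temperatures.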
differentiable subset selection in k-nearest neighbor classification
An interesting application is making the choice of which neighbors are considered in a k-NN model differentiable. For example, I might embed points in a latent space and then pick the top-$k$ nearest neighbors as a soft subset based on Gumbel-perturbed (negative) distances. Then backpropagation can adjust the embedding to improve classification or regression performance.
If I do a standard k-NN approach, the selection of neighbors is non-differentiable. But if I approximate it with a Gumbel top-k scheme, the entire pipeline can, in principle, be made end-to-end trainable, though this is still an area of active research.
empirical distribution checks and histogram comparison
When implementing Gumbel top-k or other discrete sampling approximations, I should always check that the empirical distribution of the selected subsets matches my intended distribution. For instance, if $\pi$ is uniform, I want each k-subset to be equally likely on average. In practice, numerical issues, finite sample sizes, or extreme temperature values may skew the distribution.
To verify correctness, one can:
- Sample a large number of subsets.
- Estimate the distribution of subsets.
- Compare it against the theoretically expected distribution.
For moderate $K$ and $k$, a histogram can confirm that the sampling is unbiased or only mildly biased.
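The check described above can be sketched as follows for the uniform case, where every one of the $\binom{K}{k}$ subsets should appear equally often. The sizes ($K = 4$, $k = 2$), the number of draws, and the helper name are illustrative assumptions.

```python
import heapq
import itertools
import math
import random
from collections import Counter


def gumbel_top_k(log_probs, k, rng):
    """Indices of the k largest Gumbel-perturbed log-probabilities."""
    scores = [lp - math.log(-math.log(max(rng.random(), 1e-12))) for lp in log_probs]
    return heapq.nlargest(k, range(len(log_probs)), key=scores.__getitem__)


rng = random.Random(0)
K, k, n_draws = 4, 2, 60_000
log_probs = [math.log(1.0 / K)] * K  # uniform distribution over items

# Step 1: sample many subsets (order inside a subset is ignored here).
counts = Counter(frozenset(gumbel_top_k(log_probs, k, rng)) for _ in range(n_draws))

# Steps 2 and 3: empirical frequencies versus the expected 1 / C(K, k).
expected = 1.0 / math.comb(K, k)
max_dev = 0.0
for subset in itertools.combinations(range(K), k):
    freq = counts[frozenset(subset)] / n_draws
    max_dev = max(max_dev, abs(freq - expected))
    print(subset, round(freq, 4))
print("expected", round(expected, 4), "max deviation", round(max_dev, 4))
```

For non-uniform probabilities the expected subset frequencies are no longer equal, so the reference values would have to come from the sequential without-replacement probabilities instead.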
extensions to resource-constrained subset selection and combinatorial optimization
Beyond picking neighbors, I might have a large set of items from which I need to pick a subset that optimizes some cost function, subject to resource constraints. Traditional combinatorial optimization algorithms are not easily integrated into deep networks, but with Gumbel-based subset sampling, I can incorporate these constraints as a differentiable approximation and optimize end-to-end with gradient-based methods. Research on differentiable subset selection and deep combinatorial optimization is growing rapidly (see Mena and gang, 2018 and Kool and gang, 2019 for example approaches).
gumbel-sinkhorn networks for permutations
doubly stochastic matrices and linear assignment
A permutation matrix is a binary matrix with exactly one "1" in each row and each column, and zeros elsewhere. If I want to sample a permutation from some distribution, that is effectively a discrete structure. But I can approximate permutations with doubly stochastic matrices, whose entries are nonnegative and whose rows and columns each sum to 1.
The Sinkhorn operator is a function that projects or normalizes any positive matrix into a doubly stochastic matrix via iterative row and column normalization.
relaxing permutations with the sinkhorn operator
The Gumbel-Sinkhorn trick (Mena and gang, ICLR 2018) extends the Gumbel-max trick to permutations. Instead of taking an $\arg\max$ across rows, I inject Gumbel noise into a matrix and then apply the Sinkhorn normalization repeatedly:
$$S = \operatorname{Sinkhorn}\bigl((X + G)/\tau\bigr),$$
where $X$ is an $n \times n$ matrix of logits (parameterizing a distribution over permutations) and $G$ is a matrix of i.i.d. Gumbel(0,1) noise. For small $\tau$ the result $S$ is close to a permutation matrix but is differentiable with respect to $X$.
gumbel-matching vs. gumbel-sinkhorn distributions
The original Gumbel-matching approach picks the single best matching with $\arg\max$ operations. Gumbel-Sinkhorn relaxes that to produce a doubly stochastic matrix. As $\tau \to 0$, $S$ becomes closer to a true permutation matrix. For moderate or high temperatures, $S$ is a fuzzy or partial assignment.
implementation of sinkhorn and the matching function
The Sinkhorn operator can be implemented as:
```python
import torch


def sinkhorn(log_alpha, n_iters=10):
    # log_alpha: (batch_size, n, n) unnormalized log-scores
    # returns a (batch_size, n, n) approximately doubly stochastic matrix
    # Normalize in log space: exponentiating log_alpha directly can overflow
    # once the scores have been divided by a small temperature.
    for _ in range(n_iters):
        # row normalization
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
        # column normalization
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
    return torch.exp(log_alpha)
```
Then, if I want to incorporate Gumbel noise, I can do:
```python
def gumbel_sinkhorn_sample(logits, tau, n_iters=10):
    # logits: (batch_size, n, n)
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    log_alpha = (logits + gumbel_noise) / tau
    # apply Sinkhorn
    S = sinkhorn(log_alpha, n_iters=n_iters)
    return S
```
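The matrix $S$ is a relaxed permutation. If I want to get a hard permutation, I can apply an $\arg\max$ row/column selection on $S$. Note that a plain row-wise $\arg\max$ can assign two rows to the same column when $S$ is still fuzzy; solving the linear assignment problem on $S$ (e.g. with the Hungarian algorithm) always returns a valid permutation.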
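The difference between the two rounding strategies shows up even on a tiny example. The sketch below uses a brute-force search over permutations, which is only viable for very small $n$; for realistic sizes a Hungarian-style solver such as `scipy.optimize.linear_sum_assignment` would be the usual choice. The matrix values and the function name are made up for illustration.

```python
import itertools


def best_permutation(scores):
    """Brute-force matching: perm[i] is the column assigned to row i, maximizing the total score."""
    n = len(scores)
    return max(
        itertools.permutations(range(n)),
        key=lambda perm: sum(scores[i][perm[i]] for i in range(n)),
    )


# A doubly stochastic matrix (rows and columns sum to 1) that is far from a permutation.
S = [
    [0.55, 0.35, 0.10],
    [0.40, 0.35, 0.25],
    [0.05, 0.30, 0.65],
]

row_argmax = [max(range(3), key=row.__getitem__) for row in S]
perm = best_permutation(S)
print(row_argmax)  # [0, 0, 2]: column 0 is chosen twice, so this is not a permutation
print(perm)        # (0, 1, 2): a valid assignment with total score 1.55
```

As the temperature goes to zero the two strategies tend to agree, since each row of $S$ then concentrates on a distinct column.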
sampling permutation matrices in a differentiable way
Since $S$ is continuous in the forward pass, I can backpropagate through it. As $\tau$ decreases, $S$ approaches a permutation matrix, but this might cause sharper gradients. One must tune or anneal $\tau$.
applications to sorting, ranking, and structured prediction tasks
The Gumbel-Sinkhorn approach has been used in tasks like:
- Neural sorting and ranking. Approximating the top positions in a ranking as a continuous operation.
- Linear assignment problems. If I want to solve a matching problem in a deep learning context, the Gumbel-Sinkhorn approach can embed that optimization in the network.
- Structured prediction. Many structured outputs (like permutations of tokens) can be handled with differentiable approximation methods.
learning latent permutations in practice
sinkhorn-based convolutional networks
Imagine a scenario where I have a convolutional neural network that's supposed to reorder or unscramble an input. For instance, I might take an image, split it into patches, shuffle the patches, and feed that scrambled image into a network that must unscramble it by predicting the correct permutation of patches.
the unscrambling mnist digits example
A canonical example is unscrambling an MNIST digit that has been cut into patches and permuted randomly. The network attempts to find the correct permutation so that the patches line up to reconstruct the original digit.
- Data: For each MNIST digit, generate random permutations of the $n$ patches.
- Model: Takes the scrambled patches as input, produces logits $X$ of shape $n \times n$.
- Apply Gumbel-Sinkhorn: Yields a doubly stochastic matrix $S$.
- Reconstruct the unscrambled image using $S$.
network architecture and training procedure
- Feature extractor: A small CNN or MLP that processes each patch or the entire scrambled image.
- Permutation predictor: Outputs a matrix of shape $n \times n$ representing the cost/logits for assigning patch $i$ to location $j$.
- Gumbel-Sinkhorn: Produces the continuous approximation $S$.
- Reconstruction loss: After applying the (relaxed) permutation to the patches, measure the reconstruction error against the original unscrambled image.
In practice, you might do an MSE or cross-entropy reconstruction loss on the unscrambled image.
evaluating permutation accuracy (kendall-tau)
A typical metric for permutations is Kendall-Tau, which is based on counting the pairs of items whose relative order differs between the predicted and the ground truth permutation. Alternatively, we can measure the fraction of patches correctly placed.
During inference, if needed, we can take row-wise or column-wise $\arg\max$ on $S$ to obtain a discrete permutation. Then we measure how many patches match their correct positions.
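Both metrics are easy to compute once a hard permutation is available. The following sketch uses the standard library only; the function names and the example permutation are illustrative assumptions, and permutations are represented as lists of patch indices in predicted order.

```python
from itertools import combinations


def kendall_tau_distance(pred, target):
    """Number of item pairs that the two permutations order differently."""
    pos_pred = {item: i for i, item in enumerate(pred)}
    pos_target = {item: i for i, item in enumerate(target)}
    return sum(
        1
        for a, b in combinations(target, 2)
        if (pos_pred[a] - pos_pred[b]) * (pos_target[a] - pos_target[b]) < 0
    )


def kendall_tau_coefficient(pred, target):
    """Rescale the distance to [-1, 1]: 1 for identical order, -1 for reversed order."""
    n = len(target)
    n_pairs = n * (n - 1) // 2
    return 1.0 - 2.0 * kendall_tau_distance(pred, target) / n_pairs


def fraction_correct(pred, target):
    """Share of positions holding the right patch."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


target = [0, 1, 2, 3]
pred = [0, 2, 1, 3]  # patches 1 and 2 are swapped
print(kendall_tau_distance(pred, target))     # 1
print(kendall_tau_coefficient(pred, target))  # 1 - 2 * 1 / 6
print(fraction_correct(pred, target))         # 0.5
```

The two metrics can disagree: a permutation shifted by one position has a low fraction of correct placements but still a high Kendall-Tau coefficient, because most pairs keep their relative order.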
sample results and visualization
One might visualize the scrambled images and their unscrambled counterparts to see if the model is truly learning a good ordering. This can be done by generating plots of the input, the predicted unscrambled output, and a difference overlay.

[Figure: Example of unscrambled MNIST digits. An illustration of applying Gumbel-Sinkhorn to unscramble image patches.]
graph sampling and neural relational inference
overview of graph-based latent variable models
Many systems, from social networks to physical systems, revolve around interactions among entities represented as graphs. In some settings, the graph is not known a priori; we want to learn the connectivity structure. For instance, in Neural Relational Inference (Kipf and gang, ICML 2018), we observe the motions of particles that might be connected by unseen springs, and we want to discover the underlying interaction graph.
encoding and decoding graph structures in vaes
One approach is to treat the adjacency matrix $A$ of the graph as a latent variable. If there are $N$ nodes, $A$ is an $N \times N$ matrix of 0/1 entries. We can place a Bernoulli distribution on each edge, or a categorical distribution if there are multiple edge types.
We'd then define:
$$p_\theta(x \mid A),$$
which might be a physics simulation or a message-passing network that uses the adjacency to define interactions. Then the posterior $q_\phi(A \mid x)$ can be approximated by a neural network. But sampling discrete adjacency leads to the same differentiability issue.
using gumbel-softmax for edge sampling
By applying a Gumbel-softmax approach to each pair of nodes, we can produce a continuous relaxation of the adjacency matrix. Each entry $A_{ij}$ is a Bernoulli or multi-class variable, so the Gumbel-softmax reparameterization can be applied entry by entry to generate a smoothed adjacency matrix.
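As a rough sketch of per-edge sampling, the snippet below uses PyTorch's built-in `torch.nn.functional.gumbel_softmax` on a tensor of edge-type logits. The random logits stand in for the output of an encoder network, and the convention that type 0 means "no edge" and type 1 means "edge" is my own assumption for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_nodes, n_edge_types = 5, 2  # assumed convention: type 0 = no edge, type 1 = edge

# Stand-in for encoder output: one logit vector per ordered node pair (i, j).
edge_logits = torch.randn(n_nodes, n_nodes, n_edge_types, requires_grad=True)

# Relaxed sample: every (i, j) entry is a point on the probability simplex.
edges_soft = F.gumbel_softmax(edge_logits, tau=0.5, hard=False, dim=-1)

# Soft adjacency: relaxed mass on the "edge" type, with self-loops masked out.
adjacency = edges_soft[..., 1] * (1.0 - torch.eye(n_nodes))

# Gradients reach the logits through the relaxed sample.
adjacency.sum().backward()
print(adjacency)
print(edge_logits.grad.shape)
```

Passing `hard=True` would instead return one-hot edge types in the forward pass with straight-through gradients, which can be preferable when the decoder should only ever see discrete graphs.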
learning interaction networks (springs example)
In the springs example, we have particles connected by unknown springs. We see their positions over time. The adjacency matrix tells us which pairs of particles are connected. By using the Gumbel-softmax adjacency, the neural network can infer which pairs are connected (or strongly interacting) in a fully differentiable manner. This forms the basis of a Neural Relational Inference VAE-like system.
Kipf and gang (ICML 2018) show that this approach can learn the correct adjacency in many synthetic physics simulations and also generalize to unseen conditions.
visualizing discovered graphs and analyzing results
A valuable step is to visualize the learned adjacency matrix, especially in a small example with a handful of particles. You can compare it to the ground truth adjacency to see if the model picks out the correct structure. Sometimes the model might produce soft adjacency values that reflect partial confidence in certain edges.
beyond gumbel: other discrete reparameterization approaches
binary concrete (relaxed bernoulli) distributions for binary latent variables
If the discrete variable is binary ($z = 0$ or $z = 1$), the binary concrete distribution is a specialized version of Gumbel-softmax for $K = 2$. Another option is sigmoid-based relaxations, which can also work but might have different gradient properties.
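For $K = 2$ the difference of two Gumbel variables follows a logistic distribution, so a relaxed binary sample needs a single logistic noise draw pushed through a tempered sigmoid. The sketch below uses the standard library; the function names and the example probability are assumptions for illustration.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def binary_concrete(logit, tau, rng):
    """Relaxed Bernoulli sample in [0, 1]; logit = log(p / (1 - p))."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    logistic_noise = math.log(u) - math.log(1.0 - u)
    return sigmoid((logit + logistic_noise) / tau)


rng = random.Random(0)
p = 0.3  # illustrative Bernoulli probability
logit = math.log(p / (1.0 - p))
samples = [binary_concrete(logit, tau=0.5, rng=rng) for _ in range(50_000)]

# Thresholding at 0.5 recovers an exact Bernoulli(p) draw, whatever the temperature.
frac_above_half = sum(s > 0.5 for s in samples) / len(samples)
print(frac_above_half)  # should be close to p
```

The temperature only changes how close the samples sit to 0 and 1, not which side of 0.5 they fall on, which is why the thresholded frequency matches $p$.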
reinforce-style estimators with variance reduction (baseline, input-dependent, etc.)
REINFORCE is a classic alternative that does not rely on relaxing the distribution. Instead, we compute $\nabla_\theta \mathbb{E}_{p_\theta(z)}[f(z)] = \mathbb{E}_{p_\theta(z)}[f(z) \nabla_\theta \log p_\theta(z)]$ for discrete $z$. It's an unbiased estimator but can have large variance. Adding a baseline or a learned value function can help.
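On a toy categorical distribution the estimator can be compared against the exact gradient. In this sketch the logits, the function values `f`, and all names are made up for illustration; the score of a softmax-parameterized categorical is $\partial \log p(z) / \partial \text{logit}_j = \mathbb{1}[z = j] - \pi_j$.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_grad(logits, f, n_samples, rng, baseline=0.0):
    """Monte Carlo estimate of d/d logits of E_{z ~ Cat(softmax(logits))}[f[z]]."""
    probs = softmax(logits)
    k = len(logits)
    grad = [0.0] * k
    for _ in range(n_samples):
        z = rng.choices(range(k), weights=probs)[0]
        for j in range(k):
            score = (1.0 if j == z else 0.0) - probs[j]
            grad[j] += (f[z] - baseline) * score / n_samples
    return grad


def exact_grad(logits, f):
    """Closed form: pi_j * (f_j - E[f])."""
    probs = softmax(logits)
    mean_f = sum(p * v for p, v in zip(probs, f))
    return [p * (v - mean_f) for p, v in zip(probs, f)]


rng = random.Random(0)
logits = [0.0, 0.5, -0.5]
f = [1.0, 2.0, 5.0]
mean_f = sum(p * v for p, v in zip(softmax(logits), f))

exact = exact_grad(logits, f)
plain = reinforce_grad(logits, f, 100_000, rng)
with_baseline = reinforce_grad(logits, f, 100_000, rng, baseline=mean_f)
print(exact)
print(plain)
print(with_baseline)
```

Subtracting a constant baseline leaves the estimator unbiased because the score has zero mean; it only changes the variance, which is why both estimates land near the exact gradient.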
relax and rebar: advanced variance-reduced gradient estimators
REBAR (Tucker and gang, ICLR 2017) and RELAX (Grathwohl and gang, ICLR 2018) attempt to combine the best of both worlds: they use the continuous Gumbel-softmax variable as a control variate and an additional function to correct for bias, yielding a gradient estimator with lower variance than pure REINFORCE while still being theoretically grounded.
trade-offs in bias vs. variance and model complexity
When choosing among Gumbel-softmax, REINFORCE, RELAX, or other gradient estimators, you weigh:
- Bias: Gumbel-softmax introduces a bias since the sample is never truly discrete. In the limit of $\tau \to 0$, the bias becomes negligible, but the gradient variance grows.
- Variance: REINFORCE can be high-variance, but sophisticated control variates can reduce it. Gumbel-softmax typically has moderate variance.
- Implementation complexity: Gumbel-softmax is straightforward to implement. RELAX or REBAR are more advanced.
practical guidelines on choosing an estimator
- Scale of the problem: Gumbel-softmax can handle moderate or large-scale discrete problems easily, but might degrade in extremely large combinatorial spaces.
- Exactness vs. efficiency: If a small bias is acceptable, Gumbel-based relaxations are often simpler. If you need unbiased estimates, consider REINFORCE or REBAR/RELAX with advanced variance-reduction.
- Computational cost: Some methods require additional neural networks to approximate baselines or local expectations.
future directions and conclusion
summary of key takeaways
Throughout this article, I've dived into how Gumbel-based reparameterization can open the door to end-to-end differentiable training for discrete random variables. The fundamental pipeline is:
- Express a discrete draw as $z = \arg\max_i (\log \pi_i + G_i)$ with $G_i \sim \text{Gumbel}(0, 1)$.
- Replace $\arg\max$, top-$k$, or matching with a continuous relaxation that can be backpropagated through.
- Use temperature to balance the trade-off between discreteness and smoothness.
This approach generalizes to:
- Categorical variables (Gumbel-softmax).
- Subsets (top-k Gumbel).
- Permutations (Gumbel-Sinkhorn).
- Graphs (edge sampling with Gumbel-softmax).
advanced topics: reinforcement learning, policy gradients, and discrete controls
Discrete sampling arises in reinforcement learning for selecting actions. Although Gumbel-based methods can be used to approximate policy gradients, the standard approach in RL is to use policy gradient estimators or Q-learning-based methods. However, there is ongoing work that merges Gumbel reparameterization with policy optimization to reduce variance or to incorporate discrete structures in more complex environments.
potential improvements and research trends
- Better temperature schedules: Learning or adapting the temperature automatically is a hot research direction, so that the model can choose the right level of discreteness.
- Hybrid methods: Combining Gumbel-softmax with a baseline network (as in RELAX) might yield lower variance.
- Structured compositional tasks: New research addresses hierarchical subsets or permutations, e.g. building trees with Gumbel-based sampling.
- Larger combinatorial spaces: Ongoing research explores parallelization and specialized approximations for extremely large discrete sets (like sequences).
final remarks and further reading
In sum, sampling in deep learning must often address the question of how to deal with discrete variables in a differentiable way. The Gumbel approach, along with related reparameterization techniques, has become a foundational tool for training discrete latent variable models. While no single solution works best in all scenarios, Gumbel-based methods are widely used due to their relative simplicity, efficiency, and flexibility.
Those interested in further details may consult:
- Jang and gang, "Categorical Reparameterization with Gumbel-Softmax" (ICLR 2017).
- Maddison and gang, "The Concrete Distribution" (ICLR 2017).
- Mena and gang, "Learning Latent Permutations with Gumbel-Sinkhorn" (ICLR 2018).
- Tucker and gang, "REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models" (ICLR 2017).
- Grathwohl and gang, "Backpropagation through the void: Optimizing control variates for black-box gradient estimation" (ICLR 2018).

[Figure: Generic schematic of Gumbel-based sampling. Illustration of Gumbel noise being added to logits, then passed through a continuous function (like a softmax or sinkhorn normalization).]
The ability to incorporate discrete structures while still retaining end-to-end trainability unlocks many new directions in generative modeling, combinatorial optimization, and structured prediction. With these methods in hand, you can design neural architectures that tackle discrete decisions at scale without sacrificing the advantages of gradient-based learning.