Causal representation learning
Causality steps in
#️⃣   ⌛  ~1.5 h 📚  Advanced
04.02.2025
#147

🎓 153/167

This post is a part of the AI theory educational series from my free course. Please keep in mind that the correct sequence of posts is outlined on the course page, while it can be arbitrary in Research.

I'm also happy to announce that I've started working on standalone paid courses, so you could support my work and get cheap educational material. These courses will be of completely different quality, with more theoretical depth and niche focus, and will feature challenging projects, quizzes, exercises, video lectures and supplementary stuff. Stay tuned!


Causal representation learning, often abbreviated as CRL, stands as a rapidly advancing branch of machine learning research that aims to identify, disentangle, and utilize latent causal variables from observed data. Unlike conventional statistical learning approaches that concentrate almost exclusively on correlations and large-scale pattern finding, CRL emphasizes the intricate relationships that encode how and why specific factors interact to shape observed outcomes. My purpose in this extensive article is to walk you through the important concepts, theoretical underpinnings, practical frameworks, relevant literature, and outstanding future directions in CRL. By the time you complete reading, I hope you will have a profound understanding of how causal representations make machine learning models more transparent, robust to out-of-domain shifts, and better able to handle interventions, planning, and decision-making in dynamic environments.

It should be acknowledged that causal representation learning has been influenced by, and indeed builds upon, a variety of intellectual contributions spanning decades of research in statistics, artificial intelligence, philosophy of science, and computer vision. These influences include, but are not limited to: Pearl's structural causal models (Pearl, 2009), Rubin's potential outcomes framework, Spirtes–Glymour–Scheines's approaches to causal discovery, and more recent works in disentangled representation learning. Over time, the synergy between representation learning and causality has produced methods—like CITRIS (Lippe and gang, 2022a)—that explicitly leverage constraints arising from interventions, domain shifts, and temporal dependencies to recover meaningful latent factors.

1.1 Motivation and scope of causal representation learning (CRL)

One of the major drivers behind CRL is the recognition that machine learning models too often rely on superficial patterns in data. While these correlations might suffice for i.i.d. (independent and identically distributed) tasks within relatively static environments, they can fall short when presented with a shift in domain conditions or when asked to reason about interventions. For instance, if you have a computer vision model that classifies objects under standard lighting conditions, the model might fail when the illumination changes dramatically. The reason is often that the model latched onto statistical regularities that do not correspond to fundamental causal relationships in the world.

This shortcoming has led many researchers to focus on the concept of causality, which is about more than just correlation. Causal representations provide a vantage point from which models can separate essential causal mechanisms from spurious associations, such as those induced by confounders. CRL is motivated by the notion that by modeling the true underlying factors of variation, we can achieve greater robustness, interpretability, and adaptability. Once we have identified these causal factors, we can intervene on them, potentially in a controlled manner, and predict how the system will behave under those interventions.

The scope of CRL, then, is incredibly broad. It strives to be applicable to a wide range of tasks: from image recognition and time-series forecasting to sequential decision-making in reinforcement learning. CRL is not merely an abstract theoretical exercise; it brings profound implications for real-world settings where we want stable, generalizable, and interpretable models across varied conditions.

1.2 Authors and context (Angelos Nalmpantis, Danilo de Goede)

Before diving more deeply into the formalisms and frameworks of CRL, it is worth noting that the treatment of CITRIS and the TRIS (TempoRal Intervened Sequences) setting here follows the tutorial material prepared by Angelos Nalmpantis and Danilo de Goede, which builds on the work of Lippe and gang (2022a, 2022b) and of other groups studying causality and representation learning. The impetus for this line of research emerges from an intersection of time-series data analysis, Bayesian networks, and advanced deep learning, with a particular interest in how minimal causal variables can be identified and leveraged for tasks requiring interventions over time.

This line of work also points to applications in robotics, healthcare, and even economic forecasting, where the interplay of representation learning and causal reasoning can lead to new insights. Its focus on the synergy between interventions and temporal dependencies is a critical piece of the broader puzzle: many real-world domains, like autonomous driving or industrial control systems, are dynamic, with sequential dependencies that cannot be neglected. It is in precisely these environments that CRL can demonstrate its full strength by generating representations that not only reflect what is happening now, but also how that might change if certain variables are intervened upon.

1.3 Overview of the tutorial notebooks and pre-trained models

To facilitate hands-on learning, there exist a variety of resources, including open-source GitHub repositories, Colab notebooks, and specialized workflows, that showcase how to implement and test CRL methods in practice. These tutorials often provide:

• Pre-trained models that have already been optimized on certain tasks, saving researchers from the initial overhead of lengthy training.
• Jupyter/Colab notebooks illustrating the decomposition of data into causal factors, how interventions are encoded, and how to visualize results.
• Guidelines on hyperparameter tuning, best practices for handling large datasets, and stable training heuristics.

Many references point to a well-documented pipeline (for instance, Lipschitz-limited architectures for CITRIS or normalizing flow-based expansions for iCITRIS). Even if you are new to the codebase, these tutorials are structured so you can gradually build your understanding of the entire system: from ingestion of observational data to the final step of performing manipulated predictions.

I recommend leveraging these resources as you progress through the rest of this article, because they will allow you to anchor the theoretical concepts presented here in practical code examples. By working through the notebooks, you can see exactly how CITRIS handles interventions in a temporal sequence, how the model's latent space representation emerges, and how to interpret the metrics that quantify disentanglement or causal identifiability.

2. Foundations of causal representation learning

Long before the era of large neural networks, causality researchers grappled with the question: "What does it mean to identify a cause-and-effect structure from data?" Early milestones in this field included the so-called "do-calculus" introduced by Judea Pearl, which provides a formal language for expressing interventions, and structural equation models that define explicit functional relationships among variables. However, as deep learning ascended to the forefront of machine learning, the combination of high-capacity representational power and rigorous causal inference frameworks became a natural focal point. Enter causal representation learning, which merges these two previously disparate paradigms.

2.1 Definition and importance of CRL

Causal representation learning (CRL) can be loosely defined as a subdiscipline that aims to discover or leverage latent causal factors in high-dimensional data, typically with the goal of enabling robust predictions, interventions, and interpretations. It transcends the purely correlation-based viewpoint to specifically target the variables that carry causal significance. In a typical scenario, you might have an observed dataset of images (each containing objects in different positions, colors, or under different lighting conditions) and you suspect that behind these images lie a handful of generative factors (shape, position, color, lighting) that combine to create the final observed outcome. CRL tries to invert this generative process by learning a representation that recovers precisely those factors.

The importance of CRL stems from the tangible benefits it brings to real-world scenarios:

  1. Out-of-domain generalization: If a model has learned a purely correlational relationship that only holds for the training distribution, any domain shift could cause performance degradation. Conversely, a model that recovers the underlying causal variables is better positioned to generalize because those variables remain invariant across domains, or at least they provide well-defined transformations for how the domain shift might affect the data.

  2. Interventions and counterfactuals: CRL also matters for scenarios in which you want to predict how the world would change if you took a specific action. For example, if you intervene on the hue of an object, how does that affect classification? A causal representation is precisely the tool you need to perform such manipulations reliably in your latent space without breaking the internal consistency of the generative factors.

  3. Interpretability and transparency: In many regulated industries—such as healthcare or finance—understanding why a model made a particular prediction is essential. Causal variables often come with real-world semantics, making the model's decision process far more interpretable than black-box correlation-based methods.

2.2 Statistical learning vs. causal learning

Statistical learning primarily focuses on patterns of association. Given a large dataset of input–output pairs, you craft a function that minimizes some form of predictive error. Causal learning, however, requires that you understand which variables are drivers of change and which are simply changing in tandem. Indeed, two variables can be strongly correlated yet have no causal relationship, for example because they share a hidden parent (a confounder). Identifying the correct direction of causation and adjusting for confounding effects is precisely where causal inference steps in.

In the classical sense, suppose you have a dataset of variables $X, Y, Z$; a purely statistical approach might find that the association $X \leftrightarrow Y$ is strong. Causal learning, however, attempts to ascertain whether $X$ causes $Y$, $Y$ causes $X$, or a confounder $Z$ is the real cause of both. The difference is crucial: if you are trying to design an intervention, you need to know which variables to push to obtain a desired effect.
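The confounding scenario can be made concrete with a small simulation. This is a minimal sketch under an assumed toy structural causal model (not taken from the source): $Z$ drives both $X$ and $Y$, while $X$ has no effect on $Y$ at all. Observationally the two are strongly correlated; under the intervention $do(X)$, which sets $X$ independently of $Z$, the correlation disappears.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Assumed toy SCM: Z ~ N(0, 1), X = Z + noise, Y = 2 Z + noise.
# X has NO causal effect on Y; they are linked only through Z.
z = rng.normal(size=n)
x_obs = z + 0.5 * rng.normal(size=n)
y_obs = 2.0 * z + 0.5 * rng.normal(size=n)
corr_obs = np.corrcoef(x_obs, y_obs)[0, 1]

# Intervention do(X): X is assigned externally, cutting the Z -> X edge.
# Y's mechanism is unchanged, so it still depends only on Z.
x_do = rng.normal(size=n)
y_do = 2.0 * z + 0.5 * rng.normal(size=n)
corr_do = np.corrcoef(x_do, y_do)[0, 1]

print(f"corr(X, Y), observational: {corr_obs:.2f}")
print(f"corr(X, Y), under do(X):   {corr_do:.2f}")
```

A purely statistical model trained on the observational regime would happily use $X$ to predict $Y$, yet pushing on $X$ would not change $Y$ at all.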

From a representation learning standpoint, bridging statistical and causal learning means designing flexible neural architectures that can untangle underlying generative processes, while also systematically enforcing constraints that lead to identifiable causal structures. This can be accomplished through domain knowledge, controlled interventions, or temporal consistency, among other approaches.

2.3 Key challenges: out-of-domain generalization and planning

Despite the attractiveness of CRL, major challenges exist in bridging theory and practice:

  1. Out-of-domain generalization: A central ambition of CRL is to ensure that the learned representation remains valid under distribution shifts, where either the distribution of inputs changes or new causal variables appear in the environment. This is significantly more challenging than the i.i.d. setting assumed by most conventional machine learning pipelines.

  2. Planning: When latency or real-time decision needs are introduced—such as in robotic control tasks or dynamic scheduling systems—having the ability to plan forward, test hypothetical scenarios, and choose interventions that yield beneficial outcomes becomes paramount. CRL frameworks must be integrated with sequential decision-making architectures so that the model can, in effect, simulate different interventions in latent space and choose the best path.

From a pipeline perspective, these two challenges demand specialized solutions: robust architectures, novel training regimes, and a synergy between knowledge-driven constraints and data-driven inference. In subsequent chapters, you'll see how the TRIS setting, CITRIS, and other frameworks attempt to address these points.

3. The TempoRal Intervened Sequences (TRIS) setting

Moving from fundamentals to a specialized scenario, we encounter the TRIS setting, short for TempoRal Intervened Sequences. This approach focuses on sequences of data—like videos or time-series logs—where one can apply interventions at specific time steps, thereby influencing future observations. In the real world, many phenomena unfold over time, so the ability to incorporate a temporal dimension into causal inference is crucial. The TRIS perspective ensures that changes induced by interventions are tracked as they propagate from one time step to the next.

3.1 Defining TRIS: temporal Bayesian networks and interventions

Arguably, the simplest conceptual mechanism for capturing temporal structure is a temporal Bayesian network. In the TRIS context, you have a set of latent causal variables $z_t$ that evolve over time, and each $z_t$ can be intervened upon to yield a new configuration at $z_{t+1}$ or subsequent time steps. Formally, you can think of the generative process as:

$$z_{t+1} = f_{\theta}(z_t, a_t, \epsilon_t)$$

where $z_t$ denotes the state of the latent causal variables at time $t$, $a_t$ is the applied intervention (or action), $\epsilon_t$ is noise or exogenous variation, and $f_{\theta}$ is a parametric function capturing the transition dynamics. Observed data $x_t$ is then generated from $z_t$ through some emission function $g_{\phi}$. Thus:

$$x_t = g_{\phi}(z_t)$$

The intervention $a_t$ effectively modifies the generative trajectory, letting you see how the system changes as a result. In standard supervised or unsupervised learning, you seldom have the capacity to systematically manipulate variables. TRIS, in contrast, sets up the scenario explicitly: you get data that show how the system evolves with and without interventions, giving the model strong cues about the causal structure.
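The generative process above can be simulated directly. The NumPy sketch below rests on illustrative assumptions that are not from the source: three causal variables with a linear chain-structured transition, hard interventions that overwrite a targeted variable with a fresh uniform value (ignoring its parents), and a fixed random `tanh` mixing as the emission. `f_theta` and `g_phi` are hypothetical names mirroring $f_{\theta}$ and $g_{\phi}$, with $a_t$ encoded as a binary target vector.

```python
import numpy as np

rng = np.random.default_rng(0)
K, D, T = 3, 8, 50  # causal variables, observation dimensions, time steps

# Hypothetical linear transition: each variable decays and is driven by its
# predecessor, giving a chain z^1 -> z^2 -> z^3 across time steps.
A = np.array([[0.9, 0.0, 0.0],
              [0.3, 0.8, 0.0],
              [0.0, 0.3, 0.8]])
W = rng.normal(size=(D, K))  # hypothetical mixing matrix for the emission


def f_theta(z, a, eps):
    """z_{t+1} = f_theta(z_t, a_t, eps_t) with hard interventions.

    a is a binary vector over causal variables; a targeted variable is
    overwritten with a fresh value that ignores its parents.
    """
    z_next = A @ z + eps
    return np.where(a == 1, rng.uniform(-2.0, 2.0, size=K), z_next)


def g_phi(z):
    """x_t = g_phi(z_t): nonlinear emission into a higher-dimensional space."""
    return np.tanh(W @ z)


z = np.zeros((T, K))
x = np.zeros((T, D))
a = np.zeros((T, K), dtype=int)
z[0] = rng.normal(size=K)
x[0] = g_phi(z[0])
for t in range(T - 1):
    if rng.random() < 0.3:            # intervene at roughly 30% of the steps...
        a[t, rng.integers(K)] = 1     # ...on one randomly chosen variable
    z[t + 1] = f_theta(z[t], a[t], 0.1 * rng.normal(size=K))
    x[t + 1] = g_phi(z[t + 1])

print("steps with an intervention:", int(a.any(axis=1).sum()))
```

A CRL method only gets to see the triples $(x_t, a_t, x_{t+1})$; the arrays `z`, `A`, and `W` play the role of the hidden ground truth it has to recover.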

3.2 Causal variables, intervention targets, and observational data

An essential concept in TRIS is causal variables. These are the distinct factors in the latent space that genuinely drive changes in the observed data. Intervention targets specify which causal variables are manipulated and how. Finally, observational data in this setting is the set of recorded sequences under various interventions. For instance:

• You might have a robot pushing objects in different directions. Each push is an intervention on position or velocity.
• You might have an environment where you systematically alter lighting conditions (an intervention on a lighting factor) and observe the impact on object appearance.

Crucially, the observational data alone (with no known interventions) is insufficient to unravel the complete causal structure. Only by pairing observational sequences with intervened data can the model separate correlation from genuine causation. This synergy between the observational dimension—what naturally happens over time—and the intervened dimension—how the system changes when you intentionally manipulate it—underpins the TRIS approach.

3.3 Multidimensional causal variables and minimal causal variables

In many real-world scenarios, each $z_t$ is not a single scalar variable but a collection of factors, like $z_t^1, z_t^2, \ldots, z_t^k$, each controlling a different aspect of the observed phenomenon. For instance, in a 3D scene, $z_t^1$ might govern the horizontal position of an object, $z_t^2$ might govern its color, $z_t^3$ might control the background hue, and so on. Each dimension can be intervened upon, either individually or collectively.

A topic of intense inquiry is identifying minimal causal variables—the smallest set of variables that suffice to describe the system's causal dynamics. This is concerned with parsimony: we prefer a minimal set of factors, as that fosters better interpretability and reduces superfluous modeling overhead. Moreover, minimality is often tied to better identifiability properties. If the model lumps multiple distinct factors into a single dimension, it can obscure which interventions specifically affected the observed changes.

3.4 Connection to the broader CRL field

TRIS is not an isolated phenomenon, but rather one instance of a broader push within CRL to incorporate real-world complexities like time, interactions, and direct manipulations. Indeed, one might see parallels with frameworks such as structural recurrent neural networks, partially observable Markov decision processes, or other sequential generative models. What sets TRIS apart is the explicit notion that we do not just passively track data over time; we actively intervene and measure the consequences. This is why TRIS is such an attractive setting for methodologists seeking to push the boundaries of causal representation discovery.

4. Introducing CITRIS

Now that you have a grasp on how CRL differs from statistical learning and how the TRIS setting encodes temporal interventions, we move on to a specific framework known as CITRIS (Causal Identifiability from Temporal Intervened Sequences). Proposed in Lippe and gang (2022a), CITRIS takes the stance that by combining temporal consistency with strategic interventions, one can learn an invertible mapping from high-dimensional observations to stable causal factors. The name CITRIS is a direct nod to the fact that it is designed to exploit the TRIS setup in order to achieve identifiability of causal variables.

4.1 Motivation: learning causal variables from high-dimensional data

The impetus behind CITRIS is straightforward if you consider the typical challenges of modern data. Real-world data—especially images or sensor readings—are massive in dimension. A single raw image might contain thousands of pixels, each with multiple color channels. Identifying the underlying causal factors from such unstructured data can feel akin to searching for a needle in a cosmic-scale haystack. CITRIS addresses this by imposing structure on the problem:

  1. Temporal consistency: Observations close in time should share consistent factors.
  2. Interventions: The moments when these factors are forcibly changed reveal which dimension is truly controlling which aspect in the observed domain.
  3. Invertibility: A structure in the latent space ensures that each dimension corresponds to a genuinely distinct causal factor.

4.2 CITRIS at a glance (Lippe and gang, 2022a)

CITRIS can be conceptualized as a family of generative or inference-based models that incorporate both a forward mapping from latent factors to observations and a backward (i.e., encoder) mapping from observations to latent factors. The forward pass can be described as:

$$x_t = G(z_t; \theta)$$

where $z_t$ is the vector of latent causal variables at time $t$ and $\theta$ are the parameters. The backward pass involves an encoder:

$$q_{\phi}(z_t \mid x_t)$$

that attempts to invert $G$ and uncover the latent interacting components from the high-dimensional data. CITRIS extends these building blocks by layering in a transition prior that captures how $z_t$ evolves under interventions, as well as constraints that enforce partial or full invertibility. Additionally:

• The approach exploits minimal causal variables to ensure a principled factorization.
• A target classifier might be tacked on to help ensure that the assigned factors are indeed relevant to the interventions or targets of interest.

4.3 CITRIS vs. earlier causal representation learning methods

Prior CRL methods—like the CausalVAE or other approaches that rely on factor disentanglement—often focus predominantly on observational data or assume that certain strong supervision signals are available. While those methods show promise, they sometimes fail to disentangle or identify the truly causal factors if the data do not exhibit enough variability or if the model does not have any direct means of seeing how the environment changes under interventions.

CITRIS, by contrast, leverages the TRIS setting, in which the intervention targets are known, as a strong source of supervision. This makes identifiability guarantees possible that purely observational approaches could not offer. In essence, CITRIS stands on the shoulders of earlier works in the sense that it uses data-driven autoencoders or generative frameworks, but it augments them with explicit mechanisms to incorporate knowledge gleaned from manipulated transitions in time.

4.4 iCITRIS (brief note on Lippe and gang, 2022b)

Shortly after CITRIS was introduced, Lippe and gang (2022b) proposed an extension named iCITRIS, where the "i" stands for instantaneous: besides temporal effects, iCITRIS also handles instantaneous effects, that is, causal relations between latent variables within the same time step, and learns the corresponding causal graph among them. Like CITRIS-NF, it relies on normalizing flow-based transformations. In practice, iCITRIS can be seen as a continuing push to handle broader classes of causal structures and data, including multi-object scenes and videos with complex lighting, where some factors influence each other faster than the observation frame rate can resolve.

5. CITRIS framework in detail

After introducing CITRIS conceptually, let us now go under the hood to see what truly makes this framework tick. A hallmark of CITRIS is that it incorporates guarantees for identifying minimal causal variables, uses invertibility to ensure a one-to-one mapping between latent factors and generative processes, and trains using an interplay between temporal consistency constraints and explicit intervention-based constraints.

5.1 Minimal causal variables and invertible mappings

At the core of CITRIS, one finds the principle that each latent dimension should be assigned to exactly one causal variable (a multidimensional causal variable may occupy several dimensions), and that each causal variable shows up in the learned representation unequivocally. In simpler terms, CITRIS tries to avoid the phenomenon in which latent dimensions assigned to different causal variables are each partially responsible for controlling the same factor. Achieving this property requires you to give up certain degrees of freedom in the representation: specifically, you impose invertibility constraints so that the mapping between observations and latents loses no information, and each group of latent dimensions can be mapped back to a unique factor.

A common technique for achieving invertibility (a one-to-one mapping between latent space and data space) is to use a normalizing flow or various bijective transformations in the generative model. Another approach is to carefully design the neural networks so that they remain invertible by construction, an approach sometimes observed in RealNVP or other flow-based frameworks. By ensuring invertibility, we drastically reduce the risk of an arbitrary collapse of multiple latent factors into a single dimension.

5.2 Learning approach: combining temporal consistency with interventions

The CITRIS training regime typically has two main objectives:

  1. Temporal consistency: Observations $x_t$ and $x_{t+1}$ that are close in time share many of the same latent factors, unless an intervention $a_t$ modifies them. This is enforced by a prior or cost function that rewards the model for mapping successive frames to similar latent representations.

  2. Intervention-based constraints: CITRIS explicitly encodes knowledge about where and how an intervention took place. If an intervention modifies only factor $z_t^k$, then the model is penalized if other factors $z_t^j$ ($j \neq k$) also move in latent space. This constraint effectively teaches the model that each dimension is only sensitive to specific interventions.

Conceptually, this can be expressed via a custom loss function. We might define:

$$\mathcal{L}(\phi, \theta) = \mathcal{L}_{\mathrm{rec}}(x, z) + \lambda_{\mathrm{temp}} \mathcal{L}_{\mathrm{temp}}(z_t, z_{t+1}) + \lambda_{\mathrm{int}} \mathcal{L}_{\mathrm{int}}(z_t, a_t)$$

Here, $\mathcal{L}_{\mathrm{rec}}$ is a reconstruction loss that ensures the model reconstructs the observed data well, $\mathcal{L}_{\mathrm{temp}}$ enforces temporal consistency, and $\mathcal{L}_{\mathrm{int}}$ penalizes incorrect changes in latent factors during interventions. The hyperparameters $\lambda_{\mathrm{temp}}$ and $\lambda_{\mathrm{int}}$ control how strongly the model weighs each constraint.
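A minimal PyTorch sketch of this conceptual objective follows. The concrete form of each term is an assumption made for illustration, not the exact CITRIS loss: mean squared error for $\mathcal{L}_{\mathrm{rec}}$, a penalty on any latent movement at steps without an intervention for $\mathcal{L}_{\mathrm{temp}}$, and a penalty on movement of non-targeted dimensions at intervened steps for $\mathcal{L}_{\mathrm{int}}$. The intervention $a_t$ is encoded as a hypothetical binary mask over latent dimensions, and the sketch also passes $z_{t+1}$ so that movement can be measured.

```python
import torch
import torch.nn.functional as F


def citris_style_loss(x, x_rec, z_t, z_next, target_mask,
                      lambda_temp=1.0, lambda_int=1.0):
    """Illustrative L = L_rec + lambda_temp * L_temp + lambda_int * L_int.

    target_mask: (batch, latent_dim), 1 where a latent dimension belongs to the
    causal variable intervened on between t and t+1 (hypothetical encoding of a_t).
    """
    target_mask = target_mask.float()
    l_rec = F.mse_loss(x_rec, x)                      # reconstruction term

    delta_sq = (z_next - z_t).pow(2)                  # squared latent movement
    # 1 for samples with at least one intervened dimension, else 0
    intervened = (target_mask.sum(dim=-1, keepdim=True) > 0).float()

    # Temporal consistency: without interventions, latents should not move.
    l_temp = (delta_sq * (1.0 - intervened)).mean()
    # Intervention constraint: on intervened steps, only targeted dims may move.
    l_int = (delta_sq * intervened * (1.0 - target_mask)).mean()

    return l_rec + lambda_temp * l_temp + lambda_int * l_int


# Usage: a batch of 4 transitions with 6 latent dims; only sample 0 is intervened on.
torch.manual_seed(0)
x = torch.randn(4, 16)
x_rec = x + 0.1 * torch.randn(4, 16)
z_t = torch.randn(4, 6)
z_next = z_t.clone()
mask = torch.zeros(4, 6)
mask[0, 2] = 1.0
z_next[0, 2] += 1.0                 # the targeted dimension moves: no penalty
loss_ok = citris_style_loss(x, x_rec, z_t, z_next, mask)

z_bad = z_next.clone()
z_bad[0, 4] += 1.0                  # a non-targeted dimension moves too: penalized
loss_bad = citris_style_loss(x, x_rec, z_t, z_bad, mask)
```

Assuming an identity "no-change" baseline for $\mathcal{L}_{\mathrm{temp}}$ is a strong simplification; with real dynamics, the consistency term would compare $z_{t+1}$ against a learned transition instead, which is exactly what the transition prior does.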

5.3 Transition prior in latent space

One of CITRIS's distinguishing features is its usage of a transition prior. Rather than letting ztz_t float freely through time, CITRIS posits a model:

$$z_{t+1} = h_\psi(z_t, a_t) + \eta_t$$

where $h_\psi$ is a trainable function capturing how the latent state changes in response to the intervention $a_t$. The residual noise $\eta_t$ accounts for unobserved variations or measurement errors. By modeling the transition explicitly, CITRIS ensures consistent evolution of latent factors, thereby reinforcing the view that these latent factors constitute true states of the system over time.
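Below is a simplified PyTorch sketch of such a transition prior. It assumes that $\eta_t$ is diagonal Gaussian noise with a learnable scale and that $h_\psi$ is a small MLP on the concatenation of $z_t$ and $a_t$; `TransitionPrior` is a hypothetical name. The CITRIS paper uses a more structured prior that factorizes over causal variables, so treat this as an illustration of the idea rather than the published architecture.

```python
import torch
import torch.nn as nn


class TransitionPrior(nn.Module):
    """p(z_{t+1} | z_t, a_t) = N(h_psi(z_t, a_t), diag(sigma^2)), a simplified sketch."""

    def __init__(self, latent_dim, action_dim, hidden=64):
        super().__init__()
        # h_psi: MLP on [z_t, a_t] predicting the mean of z_{t+1}
        self.h_psi = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, latent_dim),
        )
        # learnable log standard deviation of the residual noise eta_t
        self.log_sigma = nn.Parameter(torch.zeros(latent_dim))

    def forward(self, z_t, a_t):
        mean = self.h_psi(torch.cat([z_t, a_t], dim=-1))
        return torch.distributions.Normal(mean, self.log_sigma.exp())

    def log_prob(self, z_next, z_t, a_t):
        # sum over latent dimensions: one log-likelihood per transition
        return self(z_t, a_t).log_prob(z_next).sum(dim=-1)


torch.manual_seed(0)
prior = TransitionPrior(latent_dim=6, action_dim=6)
z_t = torch.randn(32, 6)
a_t = torch.zeros(32, 6)                     # no interventions in this toy batch
z_next = z_t + 0.1 * torch.randn(32, 6)
nll = -prior.log_prob(z_next, z_t, a_t).mean()
nll.backward()                              # gradients reach h_psi and log_sigma
```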

5.4 Identifiability guarantees (intuitive explanation)

When a model is said to be "identifiable", it means that, given enough data and certain mild assumptions, the model will recover the "correct" underlying factors rather than an arbitrary transformation of them. In the CRL context, identifiability is paramount because you want to be certain that the factor labeled "position" in your latent space is indeed the position in the real world, and not some mixture of position and color.

In CITRIS, identifiability arises from the interplay of three factors: (i) invertible mappings ensure that no factor gets lost; (ii) knowledge of interventions ensures that the model has unambiguous signals about how each factor can be altered; and (iii) temporal consistency ensures that consecutive observations remain aligned except for the factor that was manipulated. Under fairly general conditions spelled out in Lippe and gang (2022a), these constraints collectively pin down a unique mapping from data to latent factors, making CITRIS one of the first frameworks with robust identifiability claims in a practical domain.
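Why these extra ingredients are needed can be seen in a classic counterexample, sketched here with NumPy under assumed toy settings (a standard Gaussian prior and a hypothetical linear decoder $W$): rotating the latents by any orthogonal matrix $R$ and compensating in the decoder leaves both the observations and the prior density unchanged. From observational data alone, $z$ and $Rz$ are therefore indistinguishable, and a "position" axis could just as well be a mixture of position and color.

```python
import numpy as np

rng = np.random.default_rng(0)
K, D = 3, 5
W = rng.normal(size=(D, K))      # hypothetical linear decoder: x = W z
z = rng.normal(size=(1000, K))   # latents drawn from N(0, I)

# Any orthogonal matrix R (obtained here from a QR decomposition)...
R, _ = np.linalg.qr(rng.normal(size=(K, K)))
z_alt = z @ R.T                  # ...rotates the latents, z' = R z,
W_alt = W @ R.T                  # ...with a compensating decoder W' = W R^T.

x = z @ W.T
x_alt = z_alt @ W_alt.T          # W' z' = W R^T R z = W z


def log_std_normal(v):
    """Log-density of N(0, I), evaluated row-wise."""
    return -0.5 * (v ** 2).sum(axis=-1) - 0.5 * v.shape[-1] * np.log(2 * np.pi)


print(np.allclose(x, x_alt))                                  # True
print(np.allclose(log_std_normal(z), log_std_normal(z_alt)))  # True
```

Interventions break this symmetry: if an intervention is known to change only one causal variable, a rotated solution would spread that change over several latent dimensions and would be penalized.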

6. Practical implementations

With the conceptual architecture of CITRIS in mind, the next question is: How do we implement it concretely? This section covers two concrete variants, CITRIS-VAE and CITRIS-NF, which differ mainly in how they realize the mapping between observations and latent factors. Additionally, a target classifier can be integrated to further enhance the method's ability to capture meaningful causal factors.

6.1 CITRIS-VAE

CITRIS-VAE draws on the variational autoencoder (VAE) paradigm. VAEs are a mainstay in modern representation learning: they marry an encoder $q_\phi(z \mid x)$ and a decoder $p_\theta(x \mid z)$ while optimizing a variational objective known as the ELBO (Evidence Lower BOund). In CITRIS-VAE, the standard VAE structure is augmented with the CITRIS constraints around temporal transitions and interventions.

6.1.1 Encoder-decoder structure

The encoder translates the input $x_t$ into $z_t$, while the decoder attempts to reconstruct $x_t$ from $z_t$. However, CITRIS-VAE's encoder is mindful of the fact that $z_t$ evolves from $z_{t-1}$ under the transition prior, and also that certain factors in $z_t$ might have been directly intervened upon. A diagram of the architecture would therefore show multiple paths: the standard VAE encoding and decoding of each frame, the transition prior linking consecutive latent states, and the intervention information fed into that prior.
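A minimal PyTorch sketch of the encoder-decoder part follows, assuming flattened observations and small MLPs (a model for images would typically use convolutional networks instead); `Encoder`, `Decoder`, and `reparameterize` are illustrative names. Both frames of a transition pass through the same encoder, so that the transition prior can later compare their latent codes.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """q_phi(z_t | x_t) as a diagonal Gaussian: returns mean and log-variance."""

    def __init__(self, x_dim, z_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, hidden), nn.SiLU(),
                                 nn.Linear(hidden, 2 * z_dim))

    def forward(self, x):
        mean, log_var = self.net(x).chunk(2, dim=-1)
        return mean, log_var


class Decoder(nn.Module):
    """Mean of p_theta(x_t | z_t)."""

    def __init__(self, z_dim, x_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, hidden), nn.SiLU(),
                                 nn.Linear(hidden, x_dim))

    def forward(self, z):
        return self.net(z)


def reparameterize(mean, log_var):
    """Differentiable sample z = mean + sigma * noise."""
    return mean + torch.randn_like(mean) * (0.5 * log_var).exp()


torch.manual_seed(0)
enc, dec = Encoder(x_dim=64, z_dim=8), Decoder(z_dim=8, x_dim=64)
x_prev, x_t = torch.randn(16, 64), torch.randn(16, 64)   # consecutive frames

# The same encoder is applied to both frames of each transition.
mean_prev, log_var_prev = enc(x_prev)
mean_t, log_var_t = enc(x_t)
z_prev = reparameterize(mean_prev, log_var_prev)
z_t = reparameterize(mean_t, log_var_t)
x_t_rec = dec(z_t)
```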

6.1.2 ELBO objective and KL divergence under the transition prior

In a classical VAE, the ELBO objective is:

$$\mathcal{L}_{\text{ELBO}}(\phi,\theta) = \mathbb{E}_{q_\phi(z \mid x)} \big[ \log p_\theta(x \mid z) \big] - \beta \, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \parallel p(z)\big)$$

where $D_{\mathrm{KL}}$ denotes the Kullback–Leibler divergence and $\beta$ is a hyperparameter controlling the emphasis on the KL term ($\beta = 1$ recovers the standard VAE objective; other values give a $\beta$-VAE). In CITRIS-VAE, however, the prior $p(z)$ is replaced by a transition prior $p(z_{t} \mid z_{t-1}, a_{t-1})$. Consequently, the KL term becomes:

$$D_{\mathrm{KL}}\Big(q_\phi(z_t \mid x_t) \;\Big\|\; p\big(z_t \mid z_{t-1}, a_{t-1}\big)\Big).$$

This means the model must respect not just reconstruction fidelity at each time step, but also the correct trajectory of latent states across time according to the known or learned transitions.
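When both the approximate posterior and the transition prior are diagonal Gaussians (an assumption made here for illustration), the KL term has a closed form. The sketch below computes it in PyTorch and combines it with a reconstruction log-likelihood into a per-sample ELBO; `mu_p` and `log_var_p` stand for the outputs of a transition prior evaluated at $(z_{t-1}, a_{t-1})$, and both function names are hypothetical. The closed form is cross-checked against `torch.distributions.kl_divergence`.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_diag_gaussians(mu_q, log_var_q, mu_p, log_var_p):
    """KL(N(mu_q, diag(var_q)) || N(mu_p, diag(var_p))), summed over latent dims."""
    var_q, var_p = log_var_q.exp(), log_var_p.exp()
    kl = 0.5 * (log_var_p - log_var_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(dim=-1)


def elbo_with_transition_prior(log_px, mu_q, log_var_q, mu_p, log_var_p, beta=1.0):
    """E_q[log p(x_t | z_t)] - beta * KL(q(z_t | x_t) || p(z_t | z_{t-1}, a_{t-1})).

    log_px: per-sample reconstruction log-likelihood (e.g., a one-sample
    Monte Carlo estimate of the expectation).
    """
    return log_px - beta * kl_diag_gaussians(mu_q, log_var_q, mu_p, log_var_p)


torch.manual_seed(0)
mu_q, log_var_q = torch.randn(8, 4), torch.randn(8, 4)   # encoder outputs for x_t
mu_p, log_var_p = torch.randn(8, 4), torch.randn(8, 4)   # transition prior outputs
kl = kl_diag_gaussians(mu_q, log_var_q, mu_p, log_var_p)

# Reference value from torch.distributions (Normal takes the standard deviation)
kl_ref = kl_divergence(Normal(mu_q, (0.5 * log_var_q).exp()),
                       Normal(mu_p, (0.5 * log_var_p).exp())).sum(dim=-1)

log_px = torch.zeros(8)   # placeholder reconstruction term for the demo
elbo = elbo_with_transition_prior(log_px, mu_q, log_var_q, mu_p, log_var_p, beta=1.0)
```

Because the prior's mean and variance now depend on the previous latent state and the intervention, minimizing this KL pulls $q_\phi(z_t \mid x_t)$ toward states that are reachable from $z_{t-1}$ under $a_{t-1}$.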

6.1.3 Assignment function via Gumbel-Softmax

Another critical piece of CITRIS-VAE is a discrete assignment of latent dimensions to causal variables: since an intervention targets a specific causal variable, the model needs to partition the dimensions of $z$ into groups, one per causal variable. A popular technique for dealing with discrete variables in neural networks, and for enabling backpropagation through sampling steps, is the Gumbel-Softmax trick. In essence, Gumbel-Softmax reparametrizes the categorical sampling process to produce a differentiable approximation, making it possible to optimize the assignment of factors to different intervention targets in an end-to-end manner.

Below is a small code snippet illustrating how one might incorporate a Gumbel-Softmax approach in PyTorch:

<Code text={`
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature=1.0):
    # Sample Gumbel noise
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
    # Combine logits with Gumbel noise
    y = logits + gumbel_noise
    # Apply softmax with temperature
    return F.softmax(y / temperature, dim=-1)

# Example usage
logits = torch.randn((16, 10))  # 16 samples, 10 categories
samples = gumbel_softmax_sample(logits, temperature=0.5)
`}/>

In CITRIS-VAE, these discrete assignments can track which latent factor was manipulated by an intervention, or which factor belongs to which dimension if partial factorization is used.
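PyTorch also provides the trick directly as `torch.nn.functional.gumbel_softmax`. The sketch below uses it for a learnable assignment of latent dimensions to causal variables, in the spirit of CITRIS-VAE's assignment function; the class name, the extra group 0 for dimensions not tied to any intervened causal variable, and the mask construction are illustrative assumptions. With `hard=True`, the forward pass yields one-hot rows while gradients flow through the soft relaxation (the straight-through estimator).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AssignmentFunction(nn.Module):
    """Learnable map psi: latent dim -> group (0 = none, 1..K = causal variables)."""

    def __init__(self, latent_dim, num_causal_vars):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(latent_dim, num_causal_vars + 1))

    def forward(self, tau=1.0, hard=True):
        # One (straight-through) Gumbel-Softmax sample per latent dimension
        return F.gumbel_softmax(self.logits, tau=tau, hard=hard, dim=-1)


torch.manual_seed(0)
assign = AssignmentFunction(latent_dim=8, num_causal_vars=3)
psi = assign(tau=0.5)                       # (8, 4), one-hot rows in the forward pass

# Intervention targets over the 3 causal variables (here: only the second one)
targets = torch.tensor([[0.0, 1.0, 0.0]])
targets_ext = torch.cat([torch.zeros(1, 1), targets], dim=-1)  # group 0 never targeted
latent_mask = targets_ext @ psi.T           # (1, 8): 1 for dims assigned to that variable

# Gradients reach the assignment logits despite the discrete forward pass
(latent_mask * torch.randn(1, 8)).sum().backward()
```

The resulting `latent_mask` is the kind of per-dimension intervention mask that a loss or a transition prior can consume.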

6.2 CITRIS-NF

While CITRIS-VAE leverages the convenience of variational inference, a complementary approach is CITRIS-NF, which stands for CITRIS using normalizing flows. Normalizing flows have gained popularity thanks to their ability to produce invertible transformations with tractable exact likelihoods.

6.2.1 Motivation for normalizing flows

In a standard VAE, the decoder $p_\theta(x \mid z)$ might be flexible, but it is not necessarily invertible. If invertibility is a cornerstone of your CRL approach, normalizing flows are a natural fit: they provide a one-to-one mapping with exact density estimation, which, combined with the intervention and temporal constraints of CITRIS, supports a clean disentanglement of factors.

6.2.2 Autoencoder + flow decomposition

In practice, CITRIS-NF can be viewed as an autoencoder followed by a flow: after an initial compression of $x_t$ into $z_t^{(0)}$, you apply a sequence of invertible transformations $f_1, f_2, \ldots, f_L$. Each $f_i$ is designed so that it can be inverted analytically. Typical normalizing flows use coupling layers, split transformations, or masked autoregressive flows, which keep both the inverse and the Jacobian determinant tractable.

$$z_t^{(L)} = f_L \circ f_{L-1} \circ \cdots \circ f_1\big(z_t^{(0)}\big).$$
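To make the composition concrete, here is a minimal sketch of a flow built from RealNVP-style affine coupling layers in PyTorch. It is a generic textbook construction, not the exact CITRIS-NF architecture; the class names, layer sizes, tanh-bounded scales, and the alternating flip permutation are illustrative choices. Running the inverse of the composed flow recovers $z_t^{(0)}$ up to floating-point round-off.

```python
import torch
import torch.nn as nn


class AffineCoupling(nn.Module):
    """RealNVP-style affine coupling: one half of the vector conditions an
    elementwise affine transform of the other half, so the inverse is exact."""

    def __init__(self, dim: int, hidden: int = 32, flip: bool = False):
        super().__init__()
        self.flip = flip  # reverse feature order so successive layers alternate halves
        self.d = dim // 2
        self.net = nn.Sequential(
            nn.Linear(self.d, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * (dim - self.d)),
        )

    def _params(self, cond):
        log_s, t = self.net(cond).chunk(2, dim=-1)
        return torch.tanh(log_s), t  # bounded log-scales for numerical stability

    def forward(self, z):
        if self.flip:
            z = z.flip(-1)
        z1, z2 = z[..., :self.d], z[..., self.d:]
        log_s, t = self._params(z1)
        y = torch.cat([z1, z2 * torch.exp(log_s) + t], dim=-1)
        if self.flip:
            y = y.flip(-1)
        return y, log_s.sum(-1)  # log|det J| of this layer

    def inverse(self, y):
        if self.flip:
            y = y.flip(-1)
        y1, y2 = y[..., :self.d], y[..., self.d:]
        log_s, t = self._params(y1)
        z = torch.cat([y1, (y2 - t) * torch.exp(-log_s)], dim=-1)
        if self.flip:
            z = z.flip(-1)
        return z


class Flow(nn.Module):
    """Composition f_L o ... o f_1 of coupling layers."""

    def __init__(self, dim: int, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(
            [AffineCoupling(dim, flip=(i % 2 == 1)) for i in range(num_layers)]
        )

    def forward(self, z0):
        z, log_det = z0, z0.new_zeros(z0.shape[:-1])
        for f in self.layers:  # apply f_1 first, f_L last
            z, ld = f(z)
            log_det = log_det + ld
        return z, log_det

    def inverse(self, zL):
        z = zL
        for f in list(self.layers)[::-1]:  # undo f_L first, f_1 last
            z = f.inverse(z)
        return z


torch.manual_seed(0)
flow = Flow(dim=6, num_layers=4)
z0 = torch.randn(5, 6)     # stands in for the autoencoder output z_t^(0)
zL, log_det = flow(z0)     # z_t^(L) and the total log-determinant
z0_rec = flow.inverse(zL)  # exact inverse up to float round-off
```

The log-determinant is what makes exact likelihood training possible; the permutation between layers ensures every coordinate is eventually transformed.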

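Because these transformations are invertible, no information is lost between $z_t^{(0)}$ and $z_t^{(L)}$, and the training objective can encourage each dimension (or group of dimensions) of $z_t^{(L)}$ to correspond to a distinct causal factor. To map a change in that space back to an image, you invert the flow and then apply the decoder.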

6.2.3 Invertibility and disentanglement benefits

The advantage of using flows within CITRIS is that invertibility is built in: you do not need ad-hoc constraints or approximate inverses. Working with the causal variables is consequently more direct. If you find that dimension $k$ of $z_t^{(L)}$ is responsible for object size, you can intervene on that dimension, run the inverse flow, decode, and see how the entire image changes.

6.3 Target classifier (optional component)

An additional module sometimes included in CITRIS is a target classifier. This classifier is trained (or co-trained) to predict which intervention was applied based on the latent representation. By maximizing the mutual information between the latent factors and the intervention targets, you inject an even stronger signal that the representation should separate out the manipulated dimension from the unaffected ones.

6.3.1 Purpose and mutual information with intervention targets

Intuitively, if interventions are labeled (e.g., you know exactly which factor was changed), the target classifier can help supervise the representation. The model learns that dimension $z^i$ is a strong predictor of whether a color intervention happened, while dimension $z^j$ might reflect a shape intervention. Combining CITRIS's generative or flow-based structure with such a classifier helps reinforce factor disentanglement.
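As a rough illustration, a target classifier can be a small MLP that reads the latent codes of two consecutive time steps and outputs one logit per intervention target. A multi-label (binary cross-entropy) loss fits the case where several factors may be intervened on in the same step. The architecture, the sizes, and the random placeholder labels below are assumptions made for this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TargetClassifier(nn.Module):
    """Predicts which of num_targets causal variables were intervened on between
    time steps t and t+1, given the latent codes of both steps (illustrative design)."""

    def __init__(self, latent_dim: int, num_targets: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * latent_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, num_targets),  # one logit per intervention target
        )

    def forward(self, z_t: torch.Tensor, z_next: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([z_t, z_next], dim=-1))


torch.manual_seed(0)
clf = TargetClassifier(latent_dim=8, num_targets=4)
z_t, z_next = torch.randn(32, 8), torch.randn(32, 8)
targets = torch.randint(0, 2, (32, 4)).float()  # placeholder binary intervention labels
logits = clf(z_t, z_next)
# Multi-label loss: several variables may be intervened on in the same step.
loss = F.binary_cross_entropy_with_logits(logits, targets)
loss.backward()
```

In a full model, z_t and z_next would come from the encoder, so this loss would also push gradients into the representation itself.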

6.3.2 Example configurations

Depending on how your dataset is set up, you might have multiple classifiers, each focusing on a particular subset of interventions, or a single network with one output per intervention target. Either way, these classifiers can help speed up convergence, reduce identifiability ambiguities, and make it easier to debug which latent dimension corresponds to which real-world factor.

7. Experiments and evaluations

The true litmus test of any CRL framework, especially CITRIS-based methods, lies in how effectively they disentangle and identify causal factors under controlled or semi-controlled settings. While real-world data is the ultimate domain of interest, standardized datasets—like Causal3DIdent—have been curated to test systematically whether a model can discover genuine causal variables and remain robust to interventions.

7.1 The Causal3DIdent dataset

A widely cited toy dataset for CRL experiments is Causal3DIdent. It provides rendered 3D objects under various transformations, such as changes in rotation, color, background, or lighting. The CITRIS experiments use a temporal variant (Temporal Causal3DIdent) in which the factors evolve over time and interventions on them are recorded, which makes it an excellent test bed for CITRIS.

7.1.1 Variables, interventions, and dataset structure

In Causal3DIdent, each image is generated by sampling from a set of causal factors. For example:

  1. Object shape (e.g., sphere, cube).
  2. Object color (various possible hues).
  3. Background hue.
  4. Lighting angle (spotlight rotation).
  5. Object position in the 2D plane.
  6. Object scale (sometimes included in extended versions).

Interventions might specify that only the object color was changed, or that the background color was held constant while the object was rotated. The dataset includes sequences, so one can see how applying an intervention modifies a particular factor over time.

7.1.2 Loading, sampling, and visualization

Depending on the release, the dataset may be stored as NumPy (.npy/.npz) or HDF5 (.h5) files, where each entry is a sequence of images, possibly with meta-information about which interventions were taken at each time step. A typical CITRIS-based experiment might:

• Load the dataset into memory, with each entry consisting of $(x_t, a_t)$ pairs.
• Feed these pairs into the CITRIS encoder to produce $z_t$.
• Evaluate how well the model reconstructs $x_t$ and how precisely it identifies the factor affected by $a_t$.
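A minimal PyTorch sketch of the loading step follows. It assumes a hypothetical .npz layout with an images array of shape (T, C, H, W) and a binary targets array of shape (T, K), where targets[t] marks the factors intervened on to produce frame t; real releases may use different file names, keys, and shapes. To stay runnable, the example first writes a tiny random file with that layout.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class InterventionSequenceDataset(Dataset):
    """Yields (x_t, x_{t+1}, a_{t+1}): two consecutive frames and the intervention
    vector applied between them. The .npz keys 'images' and 'targets' are assumptions."""

    def __init__(self, path: str):
        with np.load(path) as data:
            self.images = torch.from_numpy(data["images"]).float()    # (T, C, H, W)
            self.targets = torch.from_numpy(data["targets"]).float()  # (T, K), 0/1 entries

    def __len__(self) -> int:
        return self.images.shape[0] - 1  # number of consecutive pairs

    def __getitem__(self, t: int):
        return self.images[t], self.images[t + 1], self.targets[t + 1]


# Write a tiny random file with the assumed layout so the example runs end to end.
rng = np.random.default_rng(0)
T, K = 20, 7
path = os.path.join(tempfile.mkdtemp(), "toy_sequence.npz")
np.savez(
    path,
    images=rng.random((T, 3, 32, 32), dtype=np.float32),
    targets=rng.integers(0, 2, size=(T, K)).astype(np.float32),
)

dataset = InterventionSequenceDataset(path)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
x_t, x_next, a_next = next(iter(loader))  # shapes (4, 3, 32, 32), (4, 3, 32, 32), (4, 7)
```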

For quick debugging, you might visualize random samples of images before and after interventions to confirm that the dataset was loaded correctly and that the model sees meaningful differences between time steps.

7.2 Triplet evaluation

One interesting method used in some CRL experiments—and specifically for CITRIS-based tasks—is a triplet evaluation approach. The rationale is to measure whether the model can encode and recombine factors in a way that respects the real-world causal structure.

7.2.1 Constructing image triplets and masks

In the triplet approach, you take three images sampled from the dataset:

  1. $x_a$: reference image
  2. $x_b$: an image that differs from $x_a$ in exactly one factor (e.g., background color)
  3. $x_c$: an image that might differ from $x_a$ in another factor (e.g., object shape)

You then encode these images into $z_a, z_b, z_c$ and produce a latent mask that indicates which dimensions differ. The expectation is that if $x_a$ and $x_b$ differ in background color only, then the difference $z_b - z_a$ should show up precisely in the dimension that controls background color, and not in dimensions controlling shape or lighting.
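One simple way to build such a latent mask, sketched here as an assumption rather than a prescribed procedure, is to flag the dimensions whose absolute change is large relative to the largest change. The relative threshold of 0.5 and the toy latent values are illustrative.

```python
import torch


def latent_difference_mask(z_a: torch.Tensor, z_b: torch.Tensor, rel_threshold: float = 0.5) -> torch.Tensor:
    """Flag latent dims whose change between z_a and z_b is at least rel_threshold
    times the largest change in that sample (threshold is an illustrative choice)."""
    diff = (z_b - z_a).abs()
    return diff >= rel_threshold * diff.max(dim=-1, keepdim=True).values


z_a = torch.tensor([[0.1, 0.5, -0.3, 0.8]])
z_b = torch.tensor([[0.1, 1.7, -0.3, 0.7]])  # mainly dimension 1 changed
mask = latent_difference_mask(z_a, z_b)      # [[False, True, False, False]]
```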

7.2.2 Encoding and recombining latent factors

After obtaining these encodings, you can try to combine them to synthesize new images that mix factors from $x_a$ and $x_b$. For instance, you might keep the shape from $x_a$ but adopt the background color from $x_b$. The newly generated image is then inspected to see if it matches the expected causal manipulation. If successful, this is strong evidence that the model has correctly disentangled each factor.
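The recombination step itself is a per-dimension selection between two latent codes. The sketch below assumes a boolean mask marking which dimensions to take from $z_b$; the decoder call is left as a comment because "decoder" here is a hypothetical stand-in for your trained model.

```python
import torch


def recombine(z_a: torch.Tensor, z_b: torch.Tensor, take_from_b: torch.Tensor) -> torch.Tensor:
    """Build a new latent code: dims where take_from_b is True come from z_b,
    all others from z_a."""
    return torch.where(take_from_b, z_b, z_a)


z_a = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
z_b = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
# Suppose dim 1 was identified as the background-color dimension.
take_from_b = torch.tensor([False, True, False, False])
z_new = recombine(z_a, z_b, take_from_b)  # [[1.0, 20.0, 3.0, 4.0]]
# x_new = decoder(z_new)  # hypothetical decoder of the trained model (not defined here)
```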

7.2.3 Qualitative and quantitative metrics

Qualitatively, you can see if the newly generated images look consistent when you attempt to combine or swap factors. Quantitatively, you can measure how well each latent dimension lines up with explicit ground-truth factors. Simple correlation or classification metrics might suffice, or you can compute metrics like the Adjusted Rand Index or Mutual Information Gap. The key takeaway is whether CITRIS is truly isolating each factor, thereby enabling easy recombination without cross-talk.
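For a quantitative check, a simplified Mutual Information Gap (MIG) can be computed in NumPy: discretize each latent dimension, estimate its mutual information with each ground-truth factor, and report, per factor, the gap between the two most informative dimensions normalized by the factor's entropy. The equal-width binning with 20 bins, the plug-in estimator, and the synthetic data are assumptions for illustration; published evaluations may differ in estimator details.

```python
import numpy as np


def discrete_mutual_info(a: np.ndarray, b: np.ndarray) -> float:
    """Plug-in estimate of I(a; b) in nats for two 1-D arrays of discrete codes."""
    _, a_idx = np.unique(a, return_inverse=True)
    _, b_idx = np.unique(b, return_inverse=True)
    a_idx, b_idx = a_idx.ravel(), b_idx.ravel()
    joint = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(joint, (a_idx, b_idx), 1.0)
    joint /= joint.sum()
    outer = joint.sum(axis=1, keepdims=True) * joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / outer[nz])))


def mutual_information_gap(latents: np.ndarray, factors: np.ndarray, num_bins: int = 20) -> float:
    """Simplified MIG: latents (N, D) continuous with D >= 2, factors (N, K) discrete."""
    # Discretize each latent dimension into equal-width bins.
    binned = np.stack(
        [np.digitize(col, np.histogram_bin_edges(col, bins=num_bins)[1:-1]) for col in latents.T],
        axis=1,
    )
    gaps = []
    for k in range(factors.shape[1]):
        f = factors[:, k]
        mi = np.sort([discrete_mutual_info(binned[:, j], f) for j in range(binned.shape[1])])
        entropy = discrete_mutual_info(f, f)      # I(f; f) = H(f)
        gaps.append((mi[-1] - mi[-2]) / entropy)  # gap between the two most informative dims
    return float(np.mean(gaps))


rng = np.random.default_rng(0)
factors = rng.integers(0, 5, size=(2000, 2))
noise = 0.01 * rng.standard_normal((2000, 2))
# Aligned: one latent dim per factor, plus an uninformative dim.
aligned = np.column_stack([factors + noise, rng.standard_normal(2000)])
# Mixed: both informative dims encode the same blend of the two factors.
mixed = np.column_stack([factors.sum(1) + noise[:, 0], factors.sum(1) + noise[:, 1], rng.standard_normal(2000)])
mig_aligned = mutual_information_gap(aligned, factors)
mig_mixed = mutual_information_gap(mixed, factors)
```

A score near 1 means each factor is captured by a single dimension; a score near 0 means the information is shared across dimensions.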

7.3 Performing interventions via latent space

Another hallmark experiment is to perform direct interventions in latent space and examine how the generated observations change. Because CITRIS is built for interventions, this step verifies its capacity to handle manipulations consistent with the real-world notion of causality.

7.3.1 Object rotation example

Consider controlling an object's rotation angle in the dataset. If $z_{\text{rotation}}$ indeed captures rotation, then you should be able to fix all other dimensions of $z$ while systematically stepping through different values of $z_{\text{rotation}}$. Decoding these values using $G(z)$ should yield images of the object rotating, with no other changes in color, background, scale, or position.

To implement this:

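<Code text={`
# Hypothetical code snippet: "decoder" stands for the trained model's decoder G(z)
import torch

@torch.no_grad()
def intervene_on_rotation(z, rotation_index, values, decoder):
    # z: latent batch of shape (batch_size, latent_dim)
    # rotation_index: the latent dimension assumed to control rotation
    # values: latent values to test (raw angles only if that latent is calibrated to angles)
    outputs = []
    for val in values:
        z_mod = z.clone()               # keep every other factor fixed
        z_mod[:, rotation_index] = val  # intervene: set z_rotation to val
        outputs.append(decoder(z_mod))  # decode back to image space
    return outputs
`}/>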

You can then visually inspect or compare these output images to see if the rotation is indeed the only change.

7.3.2 Randomizing individual causal factors

Besides rotating an object, you can selectively randomize certain latent factors while freezing the others. If the model has learned a factor $z_{\text{color}}$ that controls hue, randomizing just that factor should produce images in which only the color changes. This procedure is a useful way to check the independence of latent factors: if color is truly separate, you can vary it without inadvertently affecting shape or position.
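One simple way to randomize a single latent factor, sketched below under the assumption that a batch of encoded latents is a reasonable proxy for that factor's marginal distribution, is to shuffle that dimension across the batch. Sampling from a fixed prior such as a standard normal would be an alternative; which is more faithful depends on how the latent space was trained. The decoder call is a hypothetical placeholder.

```python
import torch


def randomize_factor(z: torch.Tensor, index: int, generator=None) -> torch.Tensor:
    """Resample latent dim `index` by shuffling it across the batch, so its new
    values follow the empirical marginal while all other dims stay fixed."""
    perm = torch.randperm(z.shape[0], generator=generator)
    z_new = z.clone()
    z_new[:, index] = z[perm, index]
    return z_new


g = torch.Generator().manual_seed(0)
z = torch.randn(64, 10)
z_rand = randomize_factor(z, index=3, generator=g)
# x_rand = decoder(z_rand)  # hypothetical decoder of the trained model (not defined here)
```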

7.3.3 Visualizing the intervention results

To confirm that everything is working properly, it helps to produce side-by-side images showing $z$ with baseline factor settings and $z$ with the altered factor. You might place these images in a grid that highlights each factor's different levels. For instance:

[Figure: Visualization of rotating object. A grid of images where each row corresponds to a different shape, and each column to a different rotation angle.]

Such visualizations provide immediate qualitative evidence of whether the interventions are performing as expected.
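A grid like the one described can be assembled from decoded outputs with a single permute and reshape. The sketch assumes the decoder returns tensors of shape (B, C, H, W) for each intervention value and uses random tensors as stand-ins; torchvision.utils.make_grid offers similar functionality with padding options.

```python
import torch


def tile_grid(outputs):
    """Arrange decoded images into one grid image.

    outputs: list of V tensors of shape (B, C, H, W), one per intervention value.
    Returns a (C, B*H, V*W) tensor: row b shows sample b, column v shows value v.
    """
    x = torch.stack(outputs, dim=1)  # (B, V, C, H, W)
    B, V, C, H, W = x.shape
    return x.permute(2, 0, 3, 1, 4).reshape(C, B * H, V * W)


# Toy stand-in for decoder outputs: 3 samples (e.g. shapes) x 5 rotation values
outputs = [torch.rand(3, 3, 16, 16) for _ in range(5)]
grid = tile_grid(outputs)
# To view it: plt.imshow(grid.permute(1, 2, 0).numpy())  (requires matplotlib)
```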

7.4 Analyzing the latent space

Last but not least, we come to the analytical tools that let you see how CITRIS organizes the latent space under the constraints of identifiability and minimality.

7.4.1 Tracking changes to position, spotlight rotation, and background hue

One approach is to track how each latent dimension correlates with a known ground-truth factor. For instance, you can measure the correlation between $z^k$ and the horizontal position of the object across many images. If $z^k$ indeed controls horizontal position, the absolute correlation should be close to 1 (for a monotonic but nonlinear relationship, a rank correlation such as Spearman's is more informative), and its correlation with the other factors should be small.

7.4.2 Bar plots of latent dimensions per causal variable

A convenient visualization is to create a bar plot that, for each ground-truth factor, shows the magnitude of correlation with each latent dimension. If CITRIS is working perfectly, you would see each factor strongly correlated with exactly one latent dimension, and near-zero correlation with all others. The bar plot might reveal a neat diagonal structure, confirming a near-perfect disentanglement.
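The factor-by-dimension correlation matrix behind such a bar plot takes only a few lines of NumPy. In this sketch the synthetic factors, the monotonic exponential encoding, the noise level, and the helper name are assumptions; rank=True switches to Spearman correlation, which tolerates nonlinear but monotonic relationships.

```python
import numpy as np


def abs_correlation_matrix(factors: np.ndarray, latents: np.ndarray, rank: bool = False) -> np.ndarray:
    """Absolute correlation between ground-truth factors (rows) and latent dims (columns).

    factors: (N, K), latents: (N, D). With rank=True, values are replaced by their ranks
    (Spearman correlation; ties are broken by order, not averaged).
    """
    if rank:
        factors = factors.argsort(axis=0).argsort(axis=0).astype(float)
        latents = latents.argsort(axis=0).argsort(axis=0).astype(float)
    f = (factors - factors.mean(axis=0)) / factors.std(axis=0)
    z = (latents - latents.mean(axis=0)) / latents.std(axis=0)
    return np.abs(f.T @ z) / factors.shape[0]


rng = np.random.default_rng(0)
factors = rng.uniform(size=(1000, 3))  # e.g. position, spotlight rotation, background hue
perm = [2, 0, 1]                       # latent j encodes factor perm[j]
latents = np.exp(factors[:, perm]) + 0.02 * rng.standard_normal((1000, 3))
corr = abs_correlation_matrix(factors, latents, rank=True)
best_dim = corr.argmax(axis=1)         # most correlated latent dim per factor
# Each row of corr can be drawn as one group of bars, e.g. with matplotlib's plt.bar.
```

With a well-disentangled model, reordering the columns by best_dim turns the matrix into the near-diagonal structure described above.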

7.4.3 Insights on disentanglement and factor independence

Ultimately, the best-case scenario is that each dimension of $z$ (or each group of dimensions, since CITRIS allows multidimensional causal variables) lines up with a single real-world cause. In such a scenario, CITRIS has successfully delivered on the promise of causal representation learning: you can manipulate a dimension confident that it corresponds to a meaningful causal property, and every other dimension remains untouched by that intervention.

8. Conclusion and future directions

CITRIS stands as one of the more compelling frameworks for bridging the gap between theoretical desiderata—like identifiability and disentanglement—and the messy reality of real-world data. By explicitly leveraging the TRIS assumption, CITRIS obtains stronger causal learning signals than purely observational or purely generative models can. The result is a method that does not merely capture correlations; it identifies minimal causal variables that can be manipulated, recombined, and tested across different domains or tasks.

8.1 Key takeaways from CITRIS

Several central insights emerge from the CITRIS blueprint:

• Temporal consistency: By relating consecutive frames to each other, and noting where interventions break the usual dynamics, CITRIS can infer stable causal factors over time.
• Interventions: Observing how interventions selectively change certain latent dimensions is a powerful way to distinguish correlation from genuine causation.
• Invertibility: Enforcing a one-to-one mapping between latent codes and data fosters disentanglement and helps CITRIS secure identifiability guarantees.
• Scalability: Although CITRIS is far from trivial to implement, certain versions (like iCITRIS and CITRIS-NF) demonstrate that normalizing flows and advanced autoencoder structures can scale to moderately complex data and interventions.

8.2 Strengths and limitations of current CRL methods

Like any new research domain, CRL, including CITRIS, is a work in progress:

• Strengths: Substantial progress has been made in formulating theoretical guarantees for identifiability, in designing frameworks that handle interventions elegantly, and in structuring networks that remain faithful to causal constraints.
• Limitations: Real-world data rarely offers neat, labeled interventions. In many domains, interventions might be partial, noisy, or confounded by unobserved external factors. Additionally, scaling CRL methods to extremely large or multi-object scenes (like entire cityscapes) remains an ongoing challenge. Finally, perfect invertibility assumptions can be strong, and approximate solutions might be needed in practice.

8.3 Potential applications and next steps

The future of CRL holds a wealth of opportunities. In reinforcement learning contexts, CRL could enable an agent to plan interventions systematically and anticipate their outcomes, leading to more sample-efficient training. In computer vision, CRL can help disentangle factors of variation for tasks such as scene editing, domain adaptation, and semantic manipulation of images and videos. In robotics, it can help decipher which aspects of a system's state truly matter for control and which aspects are merely ephemeral byproducts.

For ongoing research, questions like how to handle partial observability, how to combine CRL with large-scale foundation models, and how to systematically evaluate real vs. spurious causal factors remain frontiers. Another open-ended inquiry is bridging CITRIS with more rigorous Bayesian frameworks, thereby combining interpretability and uncertainty quantification. Researchers have begun to explore expansions of CITRIS that incorporate more flexible priors, more advanced normalizing flow architectures, or additional structured interventions.

Ultimately, the quest to discover and manipulate true causal variables in data aligns closely with the overarching ambition of AI: to build models that do not just reflect the world as it is, but that can robustly reason about it and act within it. As CRL matures, expect to see it increasingly embedded in real-world applications, from automated medical diagnosis to interactive content-generation pipelines, forging a path toward machine learning models that are stable, interpretable, and capable of informed decision-making under interventions.

I encourage you to revisit the tutorial notebooks and pre-trained models mentioned earlier. Attempt the recommended tasks, replicate the CITRIS experiments on Causal3DIdent, manipulate latent factors, and evaluate the results. Reflect on how well the method generalizes across diverse settings and whether it lives up to theoretical claims. With a solid grounding in these key ideas, you will be poised to explore the wide-open horizon of causal representation learning research, fueled by frameworks like CITRIS and its successors.
