Estimation and Inference of Heterogeneous Treatment Effects Using Random Forests

Categories: Causal Inference, Machine Learning, Random Forests, Treatment Effects
Breaking down Wager & Athey’s causal forest — how they turned random forests into a principled tool for estimating who benefits most from a treatment, complete with valid confidence intervals.
Author: Sean Lewis
Published: February 20, 2026

The Gist

Random forests are everywhere in applied machine learning, but they’ve always had an awkward relationship with statistics: great predictions, no confidence intervals. You can’t do inference with a standard random forest — there’s no valid way to say “this estimate is \(X \pm Y\) with 95% confidence.”

Wager and Athey (2018, Journal of the American Statistical Association) solve this problem for a specific and important target: the conditional average treatment effect (CATE). Their causal forest is a random forest adapted to estimate how treatment effects vary across individuals and, crucially, it comes with asymptotically valid confidence intervals. This means you can not only predict who benefits most from a treatment, but also quantify the uncertainty of that prediction with the same rigor as classical econometrics.

Why It Matters Now

Heterogeneous treatment effect estimation has become a cornerstone of modern data science. Tech companies use it to personalize interventions (which users should see which ad?), clinical researchers use it to ask which patients benefit from which drug, and policy evaluators use it to target programs. The challenge has always been: how do you move from “the average effect is X” to “the effect for people like this one is X” without overfitting or losing inferential validity?

Before this paper, you had two unsatisfying options: parametric models (linear regressions with interactions) that were interpretable but likely misspecified, or flexible ML methods (boosted trees, neural nets) that could capture complex heterogeneity but offered no uncertainty quantification. Causal forests sit at the intersection — flexible enough to discover nonlinear heterogeneity, rigorous enough to provide valid standard errors.

The paper has over 3,000 citations and spawned the widely used grf (Generalized Random Forests) R package, which has become a standard tool for CATE estimation in applied economics and the social sciences.

The Key Insight: Honest Estimation

The central innovation is honesty. In a standard random forest, the same data that determines the tree splits also generates the predictions. This double-dipping biases the estimates and makes inference impossible. Wager and Athey’s fix: split each tree’s training sample in two.

  1. Splitting sample (I): Used only to determine the tree structure — where to partition the covariate space.
  2. Estimation sample (J): Used only to estimate the treatment effect within each leaf.

Because the estimation sample played no role in choosing the partition, the leaf-level estimates are unbiased conditional on the tree structure. This separation is what makes valid inference possible.
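To make the double-dipping problem concrete, here is a minimal simulation sketch. It assumes a toy setting that is not from the paper: one covariate, a true treatment effect of zero everywhere, and a single split chosen from a small grid of thresholds. The helper names (`leaf_effect`, `best_threshold`, `child_gap`) are hypothetical. The adaptive version reuses the splitting data to estimate the gap between child nodes; the honest version estimates it on the held-out half.

```python
import numpy as np

rng = np.random.default_rng(0)

def leaf_effect(y, w):
    """Difference in mean outcomes between treated and control units."""
    return y[w == 1].mean() - y[w == 0].mean()

def child_gap(x, y, w, c):
    """Estimated effect in the right child minus the left child."""
    left = x <= c
    return leaf_effect(y[~left], w[~left]) - leaf_effect(y[left], w[left])

def best_threshold(x, y, w, candidates):
    """Pick the threshold whose children differ most in estimated effect."""
    gaps = [abs(child_gap(x, y, w, c)) for c in candidates]
    return candidates[int(np.argmax(gaps))]

n, reps = 400, 300
candidates = np.linspace(-0.8, 0.8, 17)
adaptive, honest = [], []
for _ in range(reps):
    x = rng.normal(size=n)
    w = rng.integers(0, 2, size=n)
    y = rng.normal(size=n)                      # true effect is zero everywhere
    half = n // 2
    xs, ys, ws = x[:half], y[:half], w[:half]   # splitting sample
    xe, ye, we = x[half:], y[half:], w[half:]   # estimation sample
    c = best_threshold(xs, ys, ws, candidates)
    adaptive.append(abs(child_gap(xs, ys, ws, c)))  # same data reused
    honest.append(abs(child_gap(xe, ye, we, c)))    # fresh data

mean_adaptive_gap = float(np.mean(adaptive))
mean_honest_gap = float(np.mean(honest))
print(f"adaptive: {mean_adaptive_gap:.3f}  honest: {mean_honest_gap:.3f}")
```

Because the split is selected to maximize the apparent gap, the adaptive estimate should report more spurious heterogeneity on average than the honest one, even though no heterogeneity exists in this setup.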

flowchart TD
    A["Training Data<br/>(X, Y, W)"] --> B["For each tree:<br/>Draw subsample of size s"]
    B --> C["Split subsample in half"]
    C --> D["Splitting Sample (I)<br/>Determines tree structure"]
    C --> E["Estimation Sample (J)<br/>Estimates leaf effects"]
    D --> F["Grow tree by maximizing<br/>treatment effect heterogeneity"]
    F --> G["Tree Structure Fixed"]
    G --> H["Drop estimation sample<br/>into fixed tree"]
    E --> H
    H --> I["Estimate τ(x) in each leaf<br/>using only J observations"]
    I --> J["Average across B trees<br/>→ Causal Forest prediction"]
    J --> K["Asymptotically Normal<br/>→ Valid confidence intervals"]

The Lineage: Where This Fits

The intellectual thread here braids together two traditions that rarely talk to each other:

From the ML side: Breiman’s random forests (2001) and the subsequent literature on ensemble methods. Random forests had strong empirical performance but essentially no inferential theory. Breiman himself was skeptical of the statistics community’s focus on models and inference, famously arguing in his “Two Cultures” paper (2001) that prediction accuracy was what mattered.

From the causal inference side: The Rubin causal model and the literature on treatment effect heterogeneity. Economists and biostatisticians cared deeply about uncertainty quantification but relied on parametric models that couldn’t capture complex heterogeneity without manual feature engineering.

The tension: ML had the flexibility but not the theory; econometrics had the theory but not the flexibility.

Key predecessors in bridging this gap include Athey and Imbens (2016) on “honest” recursive partitioning for causal effects (the single-tree version), the broader literature on semiparametric efficiency theory (Hahn 1998, Hirano-Imbens-Ridder 2003), and the random forest consistency results by Biau (2012) and Scornet, Biau, and Vert (2015).

| Method | Flexibility | Valid Inference | Heterogeneity |
|---|---|---|---|
| OLS with interactions | Low | Yes | Manual specification |
| LASSO/Ridge | Medium | Approximate | Linear only |
| Bayesian Additive Regression Trees (BART) | High | Bayesian credible intervals | Yes |
| Standard random forest | High | No | Yes (point estimates) |
| Causal forest | High | Yes (asymptotic) | Yes (automatic) |
| Neural net (e.g., TARNet) | Very high | No (without extra work) | Yes |

How It Works: The Technical Core

The Estimand

For individual \(i\) with covariates \(X_i\), treatment \(W_i \in \{0, 1\}\), and outcome \(Y_i\), the target is:

\[ \tau(x) = E[Y_i(1) - Y_i(0) \mid X_i = x] \]

This is the conditional average treatment effect (CATE) — how much the treatment helps someone with characteristics \(x\).
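Only one of \(Y_i(1)\) and \(Y_i(0)\) is ever observed for a given individual, so \(\tau(x)\) has to be tied to observable quantities. A standard identification step, sketched here under the assumptions of unconfoundedness (\(\{Y_i(0), Y_i(1)\}\) independent of \(W_i\) given \(X_i\)) and overlap (\(0 < P(W_i = 1 \mid X_i = x) < 1\)), runs as follows:

```latex
\[
\begin{aligned}
E[Y_i \mid X_i = x, W_i = 1] &= E[Y_i(1) \mid X_i = x, W_i = 1] = E[Y_i(1) \mid X_i = x], \\
E[Y_i \mid X_i = x, W_i = 0] &= E[Y_i(0) \mid X_i = x, W_i = 0] = E[Y_i(0) \mid X_i = x], \\
\tau(x) &= E[Y_i \mid X_i = x, W_i = 1] - E[Y_i \mid X_i = x, W_i = 0].
\end{aligned}
\]
```

The first equality in each line uses the fact that the observed outcome equals the potential outcome for the treatment actually received; the second uses unconfoundedness. Overlap guarantees that both conditional expectations are defined at \(x\).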

The Forest Estimator

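A causal forest builds \(\hat{\tau}(x)\) from leaf-level comparisons of treated and control outcomes, since individual treatment effects are never observed directly. In tree \(b\), let \(L_b(x)\) be the leaf containing \(x\). The tree's estimate is the difference in mean outcomes between the treated and control observations in that leaf, and the forest averages over \(B\) trees:

\[ \hat{\tau}_b(x) = \frac{\sum_{i:\, X_i \in L_b(x)} W_i Y_i}{\sum_{i:\, X_i \in L_b(x)} W_i} - \frac{\sum_{i:\, X_i \in L_b(x)} (1 - W_i) Y_i}{\sum_{i:\, X_i \in L_b(x)} (1 - W_i)}, \qquad \hat{\tau}(x) = \frac{1}{B} \sum_{b=1}^{B} \hat{\tau}_b(x) \]

Equivalently, the forest estimate is a weighted sum of observed outcomes:

\[ \hat{\tau}(x) = \sum_{i=1}^{n} \alpha_i(x) \cdot Y_i \]

where the weights \(\alpha_i(x)\) are positive for treated observations and negative for controls, and their magnitude reflects how often observation \(i\) falls in the same leaf as the target point \(x\) across trees. Observations that frequently “neighbor” \(x\) in the forest get large weights.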

Critically, these weights come only from the estimation samples — the observations that did not participate in building the tree structures. This is the honesty property at work.
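The link between per-tree leaf estimates and forest weights can be checked numerically. The sketch below is an illustration under simplifying assumptions, not the paper's algorithm: each "tree" is a single random axis-aligned split (a stump) drawn without looking at outcomes, which makes it trivially honest, and every leaf estimate is a treated-minus-control difference in means. The target point `x0` and all variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, n_trees = 500, 3, 200
X = rng.normal(size=(n, p))
W = rng.integers(0, 2, size=n)
Y = np.maximum(X[:, 0], 0) * W + rng.normal(size=n)
x0 = np.array([1.0, 0.0, 0.0])   # target point (hypothetical)

tree_estimates = []
alpha = np.zeros(n)              # signed forest weights alpha_i(x0)
for _ in range(n_trees):
    # Stand-in for a tree: one random split on a random feature.
    j = rng.integers(p)
    c = rng.uniform(-0.5, 0.5)
    in_leaf = (X[:, j] <= c) == (x0[j] <= c)   # shares a leaf with x0
    treated = in_leaf & (W == 1)
    control = in_leaf & (W == 0)
    # Per-tree estimate: difference in means inside the leaf of x0.
    tree_estimates.append(Y[treated].mean() - Y[control].mean())
    # The same estimate expressed as weights on individual outcomes.
    alpha += treated / treated.sum() - control / control.sum()
alpha /= n_trees

tau_from_trees = float(np.mean(tree_estimates))
tau_from_weights = float(alpha @ Y)
print(tau_from_trees, tau_from_weights)
```

The two numbers agree up to floating-point error, and the weights sum to zero because each tree contributes +1 in total to treated units and -1 in total to controls.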

The Splitting Rule

Standard random forests split to minimize prediction error (e.g., MSE of \(Y\)). Causal forests instead split to maximize heterogeneity of treatment effects. At each candidate split, the algorithm asks: does dividing the data here create child nodes with substantially different average treatment effects?

The paper proposes maximizing:

\[ \Delta(\text{split}) = \left(\hat{\tau}_{\text{left}} - \hat{\tau}_{\text{right}}\right)^2 \]

weighted by the number of observations on each side. This ensures the tree partitions the covariate space along the dimensions where treatment effect heterogeneity is greatest.
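A minimal sketch of such a split search on a single covariate follows. The exact weighting is an assumption here: the score multiplies the squared gap by \(n_L n_R / n^2\), one common way to weight by child sizes. The simulated effect is a step function at zero, so a sensible criterion should place the best threshold near zero. Function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 4000
x = rng.normal(size=n)
w = rng.integers(0, 2, size=n)
tau = np.where(x > 0, 2.0, 0.0)           # step-function effect at x = 0
y = tau * w + 0.5 * rng.normal(size=n)

def leaf_effect(y, w):
    """Difference in mean outcomes between treated and control units."""
    return y[w == 1].mean() - y[w == 0].mean()

def split_score(x, y, w, c):
    """Size-weighted squared gap between child-node effect estimates."""
    left = x <= c
    right = ~left
    gap = leaf_effect(y[left], w[left]) - leaf_effect(y[right], w[right])
    return left.mean() * right.mean() * gap ** 2   # (n_L * n_R / n^2) * gap^2

# Candidate thresholds: inner quantiles, so both children stay well populated.
candidates = np.quantile(x, np.linspace(0.1, 0.9, 81))
scores = np.array([split_score(x, y, w, c) for c in candidates])
best_threshold = float(candidates[np.argmax(scores)])
print(f"best threshold: {best_threshold:.3f}")
```

A real implementation would search over all covariates and recurse into the children; this sketch only shows how one candidate split is scored.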

The Subsampling Trick

Standard random forests use bootstrap sampling (sample \(n\) observations with replacement). Causal forests instead use subsampling without replacement — drawing \(s \ll n\) observations per tree. This is crucial for the asymptotic theory: it ensures each tree is built on a vanishing fraction of the data, which controls the correlation between trees and enables the central limit theorem to kick in.

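The theory requires \(s \to \infty\) and \(s/n \to 0\) as \(n \to \infty\), with \(s\) growing like \(n^{\beta}\) for some \(\beta < 1\). In practice, the grf package does not shrink the fraction this way: its default is sample.fraction = 0.5, so each tree is grown on half of the data.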
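The difference between the two sampling schemes is easy to see in code. This sketch draws one subsample of size \(s = \lfloor n^{\beta} \rfloor\) without replacement and one bootstrap sample of size \(n\) with replacement. The exponent `beta = 0.7` is chosen purely for illustration and is not a recommended setting.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 10_000
beta = 0.7                       # illustrative exponent, an assumption
s = int(np.floor(n ** beta))     # subsample size

subsample = rng.choice(n, size=s, replace=False)   # without replacement
bootstrap = rng.choice(n, size=n, replace=True)    # classic bagging draw

n_unique_sub = len(np.unique(subsample))    # every index appears once
n_unique_boot = len(np.unique(bootstrap))   # duplicates shrink this below n

# The subsample fraction s/n shrinks as n grows when beta < 1.
fractions = {m: m ** beta / m for m in (10**3, 10**5, 10**7)}
print(s, n_unique_sub, n_unique_boot, fractions)
```

A subsample contains no repeated observations, while a bootstrap sample repeats some observations and omits others, and the fraction of data seen by each tree falls toward zero as \(n\) increases.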

Asymptotic Normality

The main theoretical result: under regularity conditions (honesty, random splitting, subsampling, and a smoothness condition on \(\tau(x)\)), the causal forest estimator is asymptotically normal:

\[ \frac{\hat{\tau}(x) - \tau(x)}{\sqrt{\text{Var}[\hat{\tau}(x)]}} \xrightarrow{d} N(0, 1) \]

The variance can be consistently estimated using the infinitesimal jackknife, giving valid confidence intervals without bootstrap.
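The mechanics of the infinitesimal jackknife can be sketched on a much simpler ensemble. Two assumptions are made here: the estimator takes the form \(\hat{V}_{IJ} = \frac{n-1}{n} \left(\frac{n}{n-s}\right)^2 \sum_{i=1}^{n} \text{Cov}_b[\hat{\tau}_b, N_{bi}]^2\), where \(N_{bi}\) indicates whether observation \(i\) is in subsample \(b\); and the base learner is swapped for a plain subsample mean instead of a causal tree, so the answer can be compared with the textbook variance of a sample mean.

```python
import numpy as np

rng = np.random.default_rng(4)
n, s, B = 200, 50, 10_000
y = rng.normal(loc=1.0, scale=2.0, size=n)

# Draw B subsamples of size s without replacement.
order = np.argsort(rng.random((B, n)), axis=1)[:, :s]
N = np.zeros((B, n))
N[np.arange(B)[:, None], order] = 1.0    # N[b, i] = 1 if i is in subsample b

t = N @ y / s                            # base-learner estimate per subsample

# Covariance, across subsamples, between inclusion of i and the estimate.
cov = ((N - N.mean(axis=0)) * (t - t.mean())[:, None]).mean(axis=0)
v_ij = (n - 1) / n * (n / (n - s)) ** 2 * np.sum(cov ** 2)

v_classic = y.var(ddof=1) / n            # textbook variance of a sample mean
ratio = float(v_ij / v_classic)
print(f"IJ / classic variance ratio: {ratio:.3f}")
```

With a finite number of subsamples the estimate carries some Monte Carlo noise and a small upward bias, which is why `B` is set much larger than `n` in this sketch.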

Key Results

The paper validates causal forests on simulations designed to test whether the method finds heterogeneity where it exists and avoids finding it where it doesn’t.

| Simulation Setup | True Heterogeneity | Causal Forest Finding |
|---|---|---|
| Constant treatment effect | None | Correctly finds no heterogeneity |
| Effect varies with 1 covariate | Strong, 1-D | Recovers the pattern |
| Effect varies nonlinearly | Complex | Captures nonlinear CATE |
| Many irrelevant covariates | Sparse | Focuses on relevant dimensions |
| Observational data (confounded) | Varies | Works with propensity score adjustment |

The paper also applies causal forests to the National Job Training Partnership Act (JTPA) study — a randomized experiment evaluating job training programs. The forest finds meaningful heterogeneity: the program helped some subgroups substantially while having near-zero effects on others.

What to Watch Out For

  1. Overlap is still required: Like all methods under unconfoundedness, causal forests need overlap — every individual must have a nonzero probability of receiving either treatment. If some subgroups are never treated, no method can estimate their treatment effect.

  2. Subsample size tuning: The asymptotic theory requires \(s/n \to 0\), but in finite samples, choosing \(s\) matters. Too small loses power; too large departs from the conditions the theory relies on. The grf package defaults to drawing half of the data for each tree (sample.fraction = 0.5) and can tune this value by cross-validation, but be aware it’s a tuning choice.

  3. Honesty costs statistical power: Splitting the sample in half for each tree means each piece is smaller. You’re trading some efficiency for inferential validity — a worthwhile trade in most applications, but worth noting.

  4. Not a silver bullet for observational data: The paper assumes unconfoundedness. If there are unobserved confounders, the CATE estimates will be biased regardless of the forest’s flexibility. Causal forests estimate effects under the assumption, not test the assumption.

  5. Confidence intervals are pointwise, not uniform: The asymptotic normality result is for a fixed point \(x\). If you’re testing many points simultaneously (e.g., plotting a confidence band across all \(x\)), you need a multiple testing correction.
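For the last point, one simple and conservative option is a Bonferroni adjustment, shown in the sketch below. This is a generic correction, not something taken from the paper, and the number of points `m`, the estimate, and the standard error are made-up values for illustration.

```python
from statistics import NormalDist

alpha = 0.05
m = 50   # hypothetical number of covariate points examined together

# Critical values: one interval at a time vs. m intervals jointly.
z_pointwise = NormalDist().inv_cdf(1 - alpha / 2)
z_bonferroni = NormalDist().inv_cdf(1 - alpha / (2 * m))

def interval(tau_hat, sigma_hat, z):
    """Symmetric normal-approximation interval around tau_hat."""
    return tau_hat - z * sigma_hat, tau_hat + z * sigma_hat

pointwise_ci = interval(0.8, 0.2, z_pointwise)       # made-up estimate and s.e.
simultaneous_ci = interval(0.8, 0.2, z_bonferroni)
print(pointwise_ci, simultaneous_ci)
```

The adjusted intervals are wider at every point, which is the price of a joint coverage guarantee across all `m` points.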

Takeaways

Wager and Athey did something genuinely difficult: they took one of the most successful ML algorithms ever invented and gave it a rigorous inferential foundation for causal questions. The practical implication is that you can now use a single, well-understood tool to answer both “who benefits?” and “how confident are we?” — questions that previously required separate methodologies. The grf package makes this accessible with a clean API and sensible defaults.


Reproduction & Implementation

Environment Setup

# R (primary implementation)
install.packages("grf")       # Generalized Random Forests (>= 2.3.0)
install.packages("policytree") # Optimal policy learning from GRF

# R version >= 4.1
# grf >= 2.3.0

# Python alternative (run in a shell)
pip install econml        # Microsoft's EconML (>=0.14)
pip install scikit-learn  # For comparison baselines

# Python 3.10+
# econml >= 0.14
# scikit-learn >= 1.3

Core Algorithm: Causal Forest in R

library(grf)

# --- Simulated data ---
set.seed(42)                               # For reproducibility
n <- 2000; p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- rbinom(n, 1, 0.5)                     # Random treatment
tau_true <- pmax(X[,1], 0)                  # True CATE: depends on X1
Y <- tau_true * W + rnorm(n)               # Observed outcome

# --- Fit causal forest ---
cf <- causal_forest(X, Y, W,
  num.trees = 2000,
  honesty = TRUE,              # Honest splitting (default)
  honesty.fraction = 0.5       # 50/50 split
)

# --- Predict CATEs with confidence intervals ---
preds <- predict(cf, estimate.variance = TRUE)
tau_hat <- preds$predictions
sigma_hat <- sqrt(preds$variance.estimates)

# 95% confidence intervals
ci_lower <- tau_hat - 1.96 * sigma_hat
ci_upper <- tau_hat + 1.96 * sigma_hat

# --- Test for heterogeneity ---
test_calibration(cf)

# --- Variable importance ---
variable_importance(cf)

# --- Average treatment effect ---
average_treatment_effect(cf, target.sample = "all")

Python Equivalent (EconML)

from econml.dml import CausalForestDML
import numpy as np

np.random.seed(42)
n, p = 2000, 10
X = np.random.randn(n, p)
W = np.random.binomial(1, 0.5, n)
tau_true = np.maximum(X[:, 0], 0)
Y = tau_true * W + np.random.randn(n)

# Fit causal forest
cf = CausalForestDML(
    n_estimators=2000,
    min_samples_leaf=5,
    discrete_treatment=True,  # W is binary, so model it as a discrete treatment
    random_state=42
)
cf.fit(Y, W, X=X)

# Point estimates
tau_hat = cf.effect(X)

# Confidence intervals
ci = cf.effect_interval(X, alpha=0.05)
lower, upper = ci[0], ci[1]

# Summary
print(f"Mean CATE: {tau_hat.mean():.3f}")
print(f"Mean CI width: {(upper - lower).mean():.3f}")