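metahunt() predicts a function on the grid, and the conformal
routines can return a band at every grid point. Often, though, the
inferential target is a single number derived from that function, such as
a grid-average treatment effect, a positive-part mean, or the contrast
between the two ends of an ordered grid.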
For all of these, MetaHunt accepts a wrapper argument
that collapses the predicted function to a scalar
before any further calculation. The same wrapper is applied identically
to the predicted and to the observed calibration functions, and conformity
scores are computed on the resulting scalars, so conformal coverage
transfers directly to the scalar summary.
apply_wrapper(F_mat, wrapper, grid_weights) defines the
contract.
- F_mat is an n-by-G_grid numeric matrix; row j is one function evaluated on the grid.
- If wrapper is NULL, apply_wrapper() returns the weighted mean of each row, sum(grid_weights * f) / sum(grid_weights), where grid_weights defaults to the uniform 1/G_grid.
- If wrapper is a function, apply_wrapper() calls apply(F_mat, 1, wrapper). The wrapper therefore receives one row of F_mat at a time, a single numeric vector of length G_grid, and must return a single numeric value.

The contract is therefore:
wrapper :: numeric vector of length G_grid -> numeric scalar
Any function with that signature is a valid wrapper. After calling it, the package checks that the result is numeric and has exactly one entry per row of F_mat.
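To make the contract concrete, here is a minimal re-implementation of the behaviour described above. apply_wrapper_sketch() is a hypothetical name and is not MetaHunt's source code; the exact checks and error message are assumptions.

```r
# Hypothetical sketch of the documented apply_wrapper() contract
# (illustration only, not the package implementation).
apply_wrapper_sketch <- function(F_mat, wrapper = NULL,
                                 grid_weights = rep(1 / ncol(F_mat), ncol(F_mat))) {
  F_mat <- as.matrix(F_mat)
  if (is.null(wrapper)) {
    # Weighted mean of each row: sum(w * f) / sum(w)
    return(as.numeric(F_mat %*% grid_weights) / sum(grid_weights))
  }
  # The wrapper sees one length-G_grid row at a time
  out <- apply(F_mat, 1, wrapper)
  # Post-hoc check: numeric, exactly one value per row
  if (!is.numeric(out) || length(out) != nrow(F_mat)) {
    stop("wrapper must return a single numeric value per row")
  }
  as.numeric(out)
}

F_mat <- rbind(c(1, 2, 3),
               c(4, 4, 4))
apply_wrapper_sketch(F_mat)                            # default: row means
apply_wrapper_sketch(F_mat, max)                       # any vector -> scalar function
apply_wrapper_sketch(F_mat, grid_weights = c(0, 0, 2)) # all weight on the last point
```

With the default uniform weights, the NULL branch reduces to rowMeans(F_mat). A wrapper such as range, which returns two numbers per row, fails the length check.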
As a worked example with grf::causal_forest, we simulate a multi-site clinical trial with m = 8
sites. Each site has its own individual-level data \((Y, X, T)\) where \(Y\) is a continuous outcome, \(X\) is a single patient covariate
(age), and \(T\) is binary
treatment. The site-level CATE function \(\tau^{(i)}(\text{age}) = E[Y(1) - Y(0) \mid
\text{age}, \text{site} = i]\) varies across sites in a way that
depends on the site’s metadata. Each site fits its own
grf::causal_forest on its individual-level data, and shares
only the fitted model — not the patient data — with us.
m <- 8
n_per_site <- 200
G <- 30
W <- data.frame(
year = sample(2010:2020, m, replace = TRUE),
pct_treated = round(runif(m, 0.3, 0.6), 2)
)
site_data_list <- lapply(seq_len(m), function(i) {
age <- runif(n_per_site, 30, 80)
T <- rbinom(n_per_site, 1, W$pct_treated[i])
site_eff <- (W$year[i] - 2015) / 5 # site-level shift in CATE
tau_age <- 0.02 * (age - 50) + site_eff
Y0 <- 0.01 * age + rnorm(n_per_site, sd = 0.5)
Y1 <- Y0 + tau_age
Y <- ifelse(T == 1, Y1, Y0)
data.frame(Y = Y, age = age, T = T)
})
grid <- data.frame(age = seq(30, 80, length.out = G))
Each site fits its own causal_forest. We use
num.trees = 200 to keep the vignette fast; in practice you
would use the default of 2000 trees or more.
cf_models <- lapply(site_data_list, function(d)
grf::causal_forest(X = matrix(d$age, ncol = 1),
Y = d$Y,
W = d$T,
num.trees = 200))
We stack the per-site CATE estimates on the shared age
grid into the m-by-G matrix F_hat. Here we pass an explicit
predict_fn to illustrate the general pattern. The dispatch table
inside f_hat_from_models() already knows how to call
causal_forest, so with a standard grf::causal_forest the default
predict_fn is sufficient and the predict_fn argument can be omitted.
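Whatever predict_fn is used, the stacking step itself is simple: call it once per model on the shared grid and bind the results row-wise into an m-by-G matrix. The sketch below uses lm() fits on toy data so it runs without grf; stack_predictions() and lm_predict() are hypothetical helpers, not the f_hat_from_models() implementation, and lm_predict() merely follows the same (model, grid) signature used in this vignette.

```r
# Hypothetical sketch: stack one prediction vector per model into an
# m x G matrix (rows = models/sites, columns = grid points).
stack_predictions <- function(models, grid, predict_fn) {
  # vapply() checks that each predict_fn() call returns numeric(nrow(grid))
  preds <- vapply(models, predict_fn, FUN.VALUE = numeric(nrow(grid)), grid = grid)
  # preds is G x m (one column per model); reshape so each model is a row
  matrix(preds, nrow = length(models), byrow = TRUE)
}

# Toy per-site models (lm instead of causal_forest, so no grf dependency)
set.seed(1)
toy_sites <- lapply(1:3, function(i) {
  age <- runif(50, 30, 80)
  data.frame(age = age, y = 0.01 * i * age + rnorm(50, sd = 0.1))
})
models <- lapply(toy_sites, function(d) lm(y ~ age, data = d))

toy_grid <- data.frame(age = seq(30, 80, length.out = 5))
lm_predict <- function(model, grid) as.numeric(predict(model, newdata = grid))

F_toy <- stack_predictions(models, toy_grid, lm_predict)
dim(F_toy)  # 3 5
```

Using vapply() rather than lapply() plus rbind() means a predict_fn that returns the wrong length fails immediately instead of silently recycling.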
cate_predict <- function(model, grid) {
as.numeric(stats::predict(model, newdata = matrix(grid$age, ncol = 1))$predictions)
}
F_hat <- f_hat_from_models(cf_models, grid, predict_fn = cate_predict)
dim(F_hat)
#> [1] 8 30
We now fit metahunt() on (F_hat, W) and ask
for the predicted ATE at a hypothetical new site.
fit <- metahunt(F_hat, W, K = 3, dfspa_args = list(denoise = FALSE))
W_new <- data.frame(year = 2018, pct_treated = 0.45)
ate_pred <- predict(fit, newdata = W_new, wrapper = mean)
ate_pred
#> [1] 0.9247137
The scalar ate_pred is the predicted average treatment
effect for a hypothetical new site with metadata
(year = 2018, pct_treated = 0.45), computed as the unweighted
mean over the 30-point age grid.
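Because the data are simulated, the quantity that ate_pred targets can be written down for reference. In the simulation code, pct_treated only drives treatment assignment, and the true CATE at site \(i\) is

\[
\tau^{(i)}(\text{age}) = 0.02\,(\text{age} - 50) + \frac{\text{year}_i - 2015}{5}.
\]

The CATE is linear in age and the equally spaced grid from 30 to 80 has mean 55, so for W_new (year = 2018) the true grid-average ATE is

\[
\frac{1}{30}\sum_{g=1}^{30} \tau^{\text{new}}(\text{age}_g) = 0.02\,(55 - 50) + \frac{2018 - 2015}{5} = 0.1 + 0.6 = 0.7.
\]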
Below are three short, self-contained wrappers, each illustrating a
different idea. All three are applied to the F_hat,
fit, and W_new constructed in the previous
section.
mean: the built-in mean is already a function from a numeric vector
to a numeric scalar, so it is a valid wrapper. It ignores grid_weights
and simply averages the G_grid values; on an equally spaced grid like
this one, that is the grid-uniform ATE.
Suppose we only credit treatment effects that are positive (for
example, in a cost-effectiveness setting). The wrapper averages
max(f(x), 0) over the grid:
restricted_pos_mean <- function(f) sum(pmax(f, 0)) / length(f)
predict(fit, newdata = W_new, wrapper = restricted_pos_mean)
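#> [1] 0.9247137
This matches the mean result above, which indicates that the predicted
CATE at this new site is non-negative at every grid point: pmax(f, 0)
leaves such values unchanged, so the two wrappers differ only when part
of the predicted function falls below zero.
Because every row of F_mat is passed in turn,
f inside the wrapper is just a numeric vector of length
G_grid; length(f) is therefore the grid size,
and dividing by it gives a uniform-weighted average.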
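A quick check on a made-up vector that dips below zero shows where the two wrappers part ways. The values in f are invented for illustration; restricted_pos_mean is the wrapper defined above.

```r
# Toy function values on a 4-point grid (made up for illustration)
f <- c(-1, 0.5, 2, -0.5)

restricted_pos_mean <- function(f) sum(pmax(f, 0)) / length(f)

mean(f)                 # 0.25: negative values pull the average down
restricted_pos_mean(f)  # 0.625: negatives are floored at 0 before averaging
```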
The difference f(x_G) - f(x_1), where x_1 and x_G are the first and
last grid points, is a useful summary when the grid is ordered (e.g. age,
dose, or time). For our age grid it is the gap in CATE between an
80-year-old and a 30-year-old patient at the new site.
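A wrapper for this contrast only needs the first and last entries of f. endpoint_diff() below is an illustrative name, and the sketch assumes the grid is sorted in increasing order, as the age grid here is. It is checked against the simulation's true CATE, where the 80-versus-30 gap is 0.02 * (80 - 30) = 1 at every site regardless of the site shift.

```r
# Endpoint contrast f(x_G) - f(x_1); assumes an increasingly sorted grid
endpoint_diff <- function(f) f[length(f)] - f[1]

age_grid <- seq(30, 80, length.out = 30)
# True CATE from the simulation: 0.02 * (age - 50) + site shift
true_cate <- function(age, site_eff) 0.02 * (age - 50) + site_eff

F_true <- rbind(true_cate(age_grid, 0.6),   # year = 2018
                true_cate(age_grid, -0.4))  # year = 2013
contrasts <- apply(F_true, 1, endpoint_diff)
contrasts  # both approximately 1

# With the fit from the worked example, the same wrapper would be passed as:
# predict(fit, newdata = W_new, wrapper = endpoint_diff)
```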
When you pass wrapper into
split_conformal() (or cross_conformal(), or
conformal_from_fit()), conformity scores are computed
after the wrapper is applied, so calibration produces a single shared
quantile rather than one per grid point. The resulting interval covers
the wrapped scalar marginally at the nominal level, not the underlying
function pointwise.
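To show where the single shared quantile comes from, here is a generic split-conformal sketch for a scalar summary. scalar_conformal() is a hypothetical helper, not a MetaHunt function; it assumes absolute-residual conformity scores and the standard finite-sample rank ceiling((n_cal + 1) * (1 - alpha)), and MetaHunt's internal score may differ.

```r
# Generic split-conformal sketch for a wrapped (scalar) target.
# obs_cal, pred_cal: n_cal x G matrices; pred_new: n_new x G matrix.
scalar_conformal <- function(obs_cal, pred_cal, pred_new, wrapper, alpha) {
  s_obs  <- apply(obs_cal, 1, wrapper)   # wrap observed calibration functions
  s_pred <- apply(pred_cal, 1, wrapper)  # wrap predicted calibration functions
  scores <- abs(s_obs - s_pred)          # one conformity score per calibration row
  n_cal  <- length(scores)
  k      <- ceiling((n_cal + 1) * (1 - alpha))         # finite-sample rank
  q      <- if (k > n_cal) Inf else sort(scores)[k]    # single shared quantile
  p_new  <- apply(pred_new, 1, wrapper)
  list(prediction = p_new, lower = p_new - q, upper = p_new + q, q = q)
}

obs  <- rbind(c(1, 2, 3), c(2, 2, 2), c(0, 1, 5))
pred <- rbind(c(1, 1, 1), c(2, 3, 1), c(1, 1, 1))
new  <- rbind(c(0, 1, 2))

r_tight <- scalar_conformal(obs, pred, new, mean, alpha = 0.1)
r_loose <- scalar_conformal(obs, pred, new, mean, alpha = 0.5)
```

With three calibration rows and alpha = 0.1 the rank is ceiling(4 * 0.9) = 4, which exceeds the three available scores, so r_tight has an infinite quantile. With alpha = 0.5 the rank is 2, giving q = 1 and the interval [0, 2] around the wrapped prediction of 1.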
With only m = 8 sites, we hold out a single site (the
8th) and use the other seven for training plus calibration. The
calibration set is small, so we use alpha = 0.1 rather than
0.05.
# Use 7 sites for training+calibration, predict for the held-out 8th
tr_cal <- 1:7; new <- 8
res <- split_conformal(
F_hat[tr_cal, , drop = FALSE],
W[tr_cal, , drop = FALSE],
W[new, , drop = FALSE],
K = 3, wrapper = mean, alpha = 0.1, cal_frac = 0.5, seed = 1,
dfspa_args = list(denoise = FALSE)
)
#> Warning in .build_conformal_output(obs_cal = F_cal, pred_cal = pred_cal, : With
#> n_cal = 3 and alpha = 0.1, the conformal quantile is infinite; intervals are
#> unbounded. Increase calibration size or use a larger `alpha`.
data.frame(prediction = res$prediction,
lower = res$lower,
upper = res$upper)
#> prediction lower upper
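#> 1 -0.8307923 -Inf Inf
The interval is unbounded because the calibration set is too small for
this alpha. Under the usual split-conformal rule, the quantile is the
ceiling((n_cal + 1)(1 - alpha))-th smallest conformity score; with
n_cal = 3 and alpha = 0.1 that rank is ceiling(3.6) = 4, one more than the
number of scores available. A finite interval would need more calibration
sites or, under the same rule, alpha >= 0.25. Likewise, with only 8 sites
an empirical-coverage check on a single held-out site is not informative;
for coverage diagnostics, use a leave-one-out loop or simulate a larger
number of studies. See ?coverage for the helper function and the
conformal-prediction vignette for split-conformal at scale.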
| Aspect | Pointwise (wrapper = NULL) | Scalar (wrapper supplied) |
|---|---|---|
| Output shape | nrow(W_new) x G_grid matrix | length-nrow(W_new) numeric vector |
| Conformal quantile | one per grid point (length G_grid) | a single scalar |
| Coverage guarantee | per grid point, marginally (not joint over the grid) | for the scalar summary, marginally |
| Best for | visualising the predicted function with a band | reporting a single number with a valid CI |
| Example call | split_conformal(F, W, W_new, K = 3) | split_conformal(F, W, W_new, K = 3, wrapper = mean) |
A pointwise band is a visualisation aid; a scalar interval is the right object for an inferential claim about a specific functional. Pick the wrapper that matches the question you actually want to answer, and let the conformal machinery do the rest.
- vignette("data-prep"): building F_hat from per-site fitted models, including the grf::causal_forest dispatch and the predict_fn escape hatch used here.
- vignette("conformal-prediction"): split- and cross-conformal routines at scale, including empirical-coverage diagnostics that need more than a handful of held-out sites.