This vignette is a guide to policy_eval() and some of
the associated S3 methods. The purpose of policy_eval is to
estimate (evaluate) the value of a user-defined policy or a policy
learning algorithm. For details on the methodology, see the associated
paper (Nordland and Holst 2023).
As a general setup, we consider a fixed two-stage problem. We simulate
data using sim_two_stage() and create a
policy_data object using policy_data():
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
## Policy data with n = 2000 observations and maximal K = 2 stages.
## 
##      action
## stage    0    1    n
##     1 1017  983 2000
##     2  819 1181 2000
## 
## Baseline covariates: B, BB
## State covariates: L, C
## Average utility: 0.84
User-defined policies are created using policy_def(). In
this case we define a simple static policy always selecting action
'1':
As we want to apply the same policy function at both stages we set
reuse = TRUE.
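The code chunk for this definition did not survive in the text above. A minimal sketch, assuming it mirrors the p0 definition used later in this vignette (the object name p1 and the label "(A=1)" are taken from the later code and output):

```r
library(polle)

# Static policy: always select action 1.
# reuse = TRUE applies the same policy function at every stage.
# ASSUMPTION: reconstructed by analogy with the p0 definition further down.
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")
p1
```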
policy_eval() implements three types of policy
evaluation: inverse probability weighting estimation, outcome
regression estimation, and doubly robust (DR) estimation. As doubly
robust estimation is a combination of the two other types, we focus on
this approach. For details on the implementation, see Algorithm 1 in
(Nordland and Holst 2023).
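The call that produced the output below is missing from the text. A sketch, assuming it mirrors the later pe0 call (the object name pe1 is taken from the merge(pe0, pe1) call further down). The values "ipw" and "or" for type are assumed to select the other two estimators listed above:

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

# ASSUMPTION: static policy reconstructed by analogy with p0.
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")

# The two simpler estimators, for comparison (assumed type values).
pe1_ipw <- policy_eval(policy_data = pd, policy = p1, type = "ipw")
pe1_or  <- policy_eval(policy_data = pd, policy = p1, type = "or")

# Doubly robust evaluation, the focus of this vignette.
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")
pe1
```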
##                  Estimate Std.Err   2.5% 97.5%   P-value
## E[Z(d)]: d=(A=1)   0.8213  0.1115 0.6027  1.04 1.796e-13
policy_eval() returns an object of type
policy_eval which prints like a lava::estimate
object. The policy value estimate and variance are available via
coef() and vcov():
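The two calls behind the output below are not shown. A sketch, assuming pe1 is the doubly robust evaluation of the static policy selecting action 1:

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

coef(pe1)        # policy value estimate
vcov(pe1)        # 1 x 1 variance matrix of the estimate
sqrt(vcov(pe1))  # standard error, as shown in the printed summary
```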
## [1] 0.8213233
##            [,1]
## [1,] 0.01244225
policy_eval objects
The policy_eval object behaves like a
lava::estimate object, which can also be accessed directly
using estimate().
estimate objects make it easy to work with estimates
that have an iid decomposition given by the influence curve/function; see the
estimate
vignette.
The influence curve is available via IC():
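The output below shows six rows, so the missing call presumably wrapped IC() in head(). A sketch under that assumption:

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# One influence-curve value per observation; show the first six.
ic <- IC(pe1)
head(ic)
```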
##            [,1]
## [1,]  2.5515875
## [2,] -5.6787782
## [3,]  4.9506000
## [4,]  2.0661524
## [5,]  0.7939672
## [6,] -2.2932160
Merging estimate objects allows the user to get inference
for transformations of the estimates via the Delta method. Here we get
inference for the average treatment effect, both as a difference and as
a ratio:
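Before the package code, it may help to see what the Delta method computes. The base-R sketch below redoes the calculation by hand from the rounded estimates and standard errors printed in this section. The covariance between the two estimates is not printed anywhere in the vignette; the value used here is an assumption, back-solved from the reported ATE-difference standard error. The names theta, cov01 and delta_se are illustrative only.

```r
# Rounded estimates and standard errors as printed for d=(A=0) and d=(A=1).
theta <- c(a0 = -0.06123, a1 = 0.82132)
se    <- c(a0 = 0.0881,   a1 = 0.1115)

# ASSUMPTION: covariance back-solved from the reported ATE-difference
# standard error (0.1338); it is not part of the vignette output.
cov01 <- 0.001146
V <- matrix(c(se[["a0"]]^2, cov01, cov01, se[["a1"]]^2), nrow = 2)

# Delta method: Var(f(theta)) is approximately grad' V grad.
delta_se <- function(grad, V) sqrt(drop(t(grad) %*% V %*% grad))

# Difference: f(theta) = a1 - a0, gradient (-1, 1).
ate_diff <- theta[["a1"]] - theta[["a0"]]
se_diff  <- delta_se(c(-1, 1), V)

# Ratio: f(theta) = a1 / a0, gradient (-a1 / a0^2, 1 / a0).
ate_ratio  <- theta[["a1"]] / theta[["a0"]]
grad_ratio <- c(-theta[["a1"]] / theta[["a0"]]^2, 1 / theta[["a0"]])
se_ratio   <- delta_se(grad_ratio, V)

round(c(ate_diff = ate_diff, se_diff = se_diff), 4)
round(c(ate_ratio = ate_ratio, se_ratio = se_ratio), 2)

# With the merged estimate object est built below, the same transformations
# would presumably be requested as (assumed calls, not run here):
# estimate(est, function(x) x[2] - x[1], labels = "ATE-difference")
# estimate(est, function(x) x[2] / x[1], labels = "ATE-ratio")
```

The ratio has a very large standard error because the denominator estimate is close to zero, which is why the ratio's confidence interval is so wide.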
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
pe0 <- policy_eval(policy_data = pd,
                   policy = p0,
                   type = "dr")
(est <- merge(pe0, pe1))
##                  Estimate Std.Err    2.5%  97.5%   P-value
## E[Z(d)]: d=(A=0) -0.06123  0.0881 -0.2339 0.1114 4.871e-01
## ────────────────                                          
## E[Z(d)]: d=(A=1)  0.82132  0.1115  0.6027 1.0399 1.796e-13
##                Estimate Std.Err   2.5% 97.5%  P-value
## ATE-difference   0.8825  0.1338 0.6203 1.145 4.25e-11
##           Estimate Std.Err   2.5% 97.5% P-value
## ATE-ratio   -13.41    19.6 -51.83    25  0.4937
So far we have relied on the default generalized linear models for
the nuisance g-models and Q-models. By default, a single g-model is trained
across all stages using the state/Markov type history; see the
policy_data vignette. Use get_g_functions() to
get access to the fitted model:
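The call itself is missing from the text. A sketch, assuming the g-functions are extracted from the evaluation of the static policy selecting action 1 (the name gf is hypothetical):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# Fitted g-model(s) stored in the policy_eval object.
gf <- get_g_functions(pe1)
gf
```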
## $all_stages
## $model
## 
## Call:  NULL
## 
## Coefficients:
## (Intercept)            L            C            B     BBgroup2     BBgroup3  
##     0.08285      0.03094      0.97993     -0.05753     -0.13970     -0.06122  
## 
## Degrees of Freedom: 3999 Total (i.e. Null);  3994 Residual
## Null Deviance:       5518 
## Residual Deviance: 4356  AIC: 4368
## 
## 
## attr(,"full_history")
## [1] FALSE
The g-functions can be used as input to a new policy evaluation:
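The corresponding call is not shown. A sketch, assuming the fitted g-functions are passed through the g_functions argument of policy_eval() and reused for the policy that always selects action 0 (gf and pe0_reuse are hypothetical names):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")

# Fit once, then reuse the g-functions instead of refitting them.
gf <- get_g_functions(policy_eval(policy_data = pd, policy = p1, type = "dr"))
pe0_reuse <- policy_eval(policy_data = pd,
                         policy = p0,
                         g_functions = gf,
                         type = "dr")
pe0_reuse
```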
##                  Estimate Std.Err    2.5%  97.5% P-value
## E[Z(d)]: d=(A=0) -0.06123  0.0881 -0.2339 0.1114  0.4871
or we can get the associated predicted values:
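The prediction call is also missing. A sketch, assuming predict() on the g-functions takes the policy data as its second argument and that the output was truncated with head() (gf and g_pred are hypothetical names):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")
gf <- get_g_functions(policy_eval(policy_data = pd, policy = p1, type = "dr"))

# Predicted action probabilities g_0 and g_1 for each id and stage.
g_pred <- predict(gf, pd)
head(g_pred)
```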
## Key: <id, stage>
##       id stage        g_0        g_1
##    <int> <int>      <num>      <num>
## 1:     1     1 0.15628741 0.84371259
## 2:     1     2 0.08850558 0.91149442
## 3:     2     1 0.92994454 0.07005546
## 4:     2     2 0.92580890 0.07419110
## 5:     3     1 0.11184451 0.88815549
## 6:     3     2 0.08082666 0.91917334
Similarly, we can inspect the Q-functions using
get_q_functions():
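As with the g-functions, the call is not shown. A sketch, assuming the Q-functions come from the same evaluation object (qf is a hypothetical name):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")
pe1 <- policy_eval(policy_data = pd, policy = p1, type = "dr")

# One fitted Q-model per stage.
qf <- get_q_functions(pe1)
names(qf)
qf
```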
## $stage_1
## $model
## 
## Call:  NULL
## 
## Coefficients:
## (Intercept)           A1            L            C            B     BBgroup2  
##    0.232506     0.682422     0.454642     0.039021    -0.070152    -0.184704  
##    BBgroup3         A1:L         A1:C         A1:B  A1:BBgroup2  A1:BBgroup3  
##   -0.171734    -0.010746     0.938791     0.003772     0.157200     0.270711  
## 
## Degrees of Freedom: 1999 Total (i.e. Null);  1988 Residual
## Null Deviance:       7689 
## Residual Deviance: 3599  AIC: 6877
## 
## 
## $stage_2
## $model
## 
## Call:  NULL
## 
## Coefficients:
## (Intercept)           A1            L            C            B     BBgroup2  
##   -0.043324     0.147356     0.002376    -0.042036     0.005331    -0.001128  
##    BBgroup3         A1:L         A1:C         A1:B  A1:BBgroup2  A1:BBgroup3  
##   -0.108404     0.024424     0.962591    -0.059177    -0.102084     0.094688  
## 
## Degrees of Freedom: 1999 Total (i.e. Null);  1988 Residual
## Null Deviance:       3580 
## Residual Deviance: 1890  AIC: 5588
## 
## 
## attr(,"full_history")
## [1] FALSE
Note that a model is trained for each stage. Again, we can predict
from the Q-models using predict().
Usually, we want to specify the nuisance models ourselves using the
g_models and q_models arguments:
pe1 <- policy_eval(pd,
            policy = p1,
            g_models = list(
              g_sl(formula = ~ BB + L_1, SL.library = c("SL.glm", "SL.ranger")),
              g_sl(formula = ~ BB + L_1 + C_2, SL.library = c("SL.glm", "SL.ranger"))
            ),
            g_full_history = TRUE,
            q_models = list(
              q_glm(formula = ~ A * (B + C_1)), # including action interactions
              q_glm(formula = ~ A * (B + C_1 + C_2)) # including action interactions
            ),
            q_full_history = TRUE)
## Loading required namespace: ranger
Here we train a super learner g-model for each stage using the full
available history and a generalized linear model for the Q-models. The
formula argument is used to construct the model frame
passed to the model for training (and prediction). The valid formula
terms depending on g_full_history and
q_full_history are available via
get_history_names():
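The three calls behind the output below are not shown. A sketch, assuming that omitting stage returns the state/Markov history names and that supplying stage returns the full-history names at that stage:

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

get_history_names(pd)             # state/Markov type history
get_history_names(pd, stage = 1)  # full history at stage 1
get_history_names(pd, stage = 2)  # full history at stage 2
```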
## [1] "L"  "C"  "B"  "BB"
## [1] "L_1" "C_1" "B"   "BB"
## [1] "A_1" "L_1" "L_2" "C_1" "C_2" "B"   "BB"
Remember that the action variable at the current stage is always
named A. Some models like glm require
interactions to be specified via the model frame. Thus, for some models,
it is important to include action interaction terms for the
Q-models.
The value of a learned policy is an important performance measure,
and policy_eval() allows for direct evaluation of a given
policy learning algorithm. For details, see Algorithm 4 in (Nordland and Holst 2023).
In polle, policy learning algorithms are specified using
policy_learn(), see the associated vignette. These
functions can be directly evaluated in policy_eval():
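The evaluation call is missing from the text. A sketch, assuming the learner is passed through the policy_learn argument of policy_eval() and that the default nuisance models are used (pe_ql is a hypothetical name):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

# Learn the policy via Q-learning and evaluate it in one step.
pe_ql <- policy_eval(pd, policy_learn = policy_learn(type = "ql"))
pe_ql
```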
##               Estimate Std.Err  2.5% 97.5%   P-value
## E[Z(d)]: d=ql    1.306 0.06641 1.176 1.437 3.783e-86
In the above example we evaluate the policy estimated via Q-learning.
Alternatively, we can first learn the policy and then pass it to
policy_eval():
p_ql <- policy_learn(type = "ql")(pd, q_models = q_glm())
policy_eval(pd,
            policy = get_policy(p_ql))
##               Estimate Std.Err  2.5% 97.5%   P-value
## E[Z(d)]: d=ql    1.306 0.06641 1.176 1.437 3.783e-86
A key feature of policy_eval() is that it allows for
easy cross-fitting of the nuisance models as well as the learned policy.
Here we specify two-fold cross-fitting via the M
argument:
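The cross-fitting call is not shown. A sketch, assuming the Q-learning evaluation from above is repeated with M = 2 (pe_cf is a hypothetical name; the fold assignment is random, so results vary between runs):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

# Two-fold cross-fitting of the nuisance models and the learned policy.
pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)
pe_cf
```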
Specifically, both the nuisance models and the optimal policy are fitted on each training fold. Subsequently, the doubly robust value score is calculated on the validation folds.
The policy_eval object now consists of a list of
policy_eval objects associated with each fold:
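The calls that produced the output below are missing. A sketch, assuming the fold indices are stored in an element named folds and the per-fold evaluations in an element named cross_fits, each with entries named fold_1, fold_2, and so on. These element names are an assumption about the object layout and should be checked with names(pe_cf):

```r
library(polle)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)

names(pe_cf)                  # check which elements are available

# ASSUMPTION: element names folds / cross_fits / fold_1.
head(pe_cf$folds$fold_1)      # indices of the observations in fold 1
pe_cf$cross_fits$fold_1       # policy evaluation associated with fold 1
```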
## [1]  3  4  5  7  8 10
##               Estimate Std.Err  2.5% 97.5%   P-value
## E[Z(d)]: d=ql    1.261 0.09456 1.075 1.446 1.538e-40
In order to save memory, particularly when cross-fitting, it is
possible not to save the nuisance models via the
save_g_functions and save_q_functions
arguments.
future.apply
It is easy to parallelize the cross-fitting procedure via the
future.apply package:
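The parallel example is missing from the text. A sketch, assuming a local multisession plan from the future package is set before calling policy_eval() and reset afterwards. The number of folds (M = 4) and the name pe_par are illustrative choices:

```r
library(polle)
library(future.apply)

# Same simulated data as at the top of the vignette.
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"), C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))

future::plan("multisession")  # local parallel processing
pe_par <- policy_eval(pd,
                      policy_learn = policy_learn(type = "ql"),
                      M = 4)
future::plan("sequential")    # reset to sequential processing
pe_par
```

Resetting the plan afterwards keeps later code from unintentionally running in background sessions.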
sessionInfo()
## R version 4.4.1 (2024-06-14)
## Platform: aarch64-apple-darwin23.5.0
## Running under: macOS Sonoma 14.6.1
## 
## Matrix products: default
## BLAS:   /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib 
## LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib;  LAPACK version 3.12.0
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## time zone: Europe/Copenhagen
## tzcode source: internal
## 
## attached base packages:
## [1] splines   stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
## [1] ggplot2_3.5.1       data.table_1.15.4   polle_1.5          
## [4] SuperLearner_2.0-29 gam_1.22-4          foreach_1.5.2      
## [7] nnls_1.5           
## 
## loaded via a namespace (and not attached):
##  [1] sass_0.4.9          utf8_1.2.4          future_1.33.2      
##  [4] lattice_0.22-6      listenv_0.9.1       digest_0.6.36      
##  [7] magrittr_2.0.3      evaluate_0.24.0     grid_4.4.1         
## [10] iterators_1.0.14    mvtnorm_1.2-5       policytree_1.2.3   
## [13] fastmap_1.2.0       jsonlite_1.8.8      Matrix_1.7-0       
## [16] survival_3.6-4      fansi_1.0.6         scales_1.3.0       
## [19] numDeriv_2016.8-1.1 codetools_0.2-20    jquerylib_0.1.4    
## [22] lava_1.8.0          cli_3.6.3           rlang_1.1.4        
## [25] mets_1.3.4          parallelly_1.37.1   future.apply_1.11.2
## [28] munsell_0.5.1       withr_3.0.0         cachem_1.1.0       
## [31] yaml_2.3.8          tools_4.4.1         parallel_4.4.1     
## [34] colorspace_2.1-0    ranger_0.16.0       globals_0.16.3     
## [37] vctrs_0.6.5         R6_2.5.1            lifecycle_1.0.4    
## [40] pkgconfig_2.0.3     timereg_2.0.5       progressr_0.14.0   
## [43] bslib_0.7.0         pillar_1.9.0        gtable_0.3.5       
## [46] Rcpp_1.0.13         glue_1.7.0          xfun_0.45          
## [49] tibble_3.2.1        highr_0.11          knitr_1.47         
## [52] farver_2.1.2        htmltools_0.5.8.1   rmarkdown_2.27     
## [55] labeling_0.4.3      compiler_4.4.1