---
title: "Estimating ATE and CATE using Causal Forests"
subtitle: ""
format:
html:
warnings: false
error: false
messages: false
code-overflow: scroll
highlight-style: Ayu
code-line-numbers: true
code-fold: false
code-tools:
source: true
toggle: false
html-math-method: katex
reference-location: margin
citation-location: margin
cap-location: margin
code-block-border-left: true
bibliography: /Users/joseph/GIT/templates/bib/references.bib
editor_options:
chunk_output_type: console
---
```{r}
#| label: setup
#| echo: false
#| include: false
#| eval: true
# ββ initialisation and setup ββββββββββββββββββββββββββββββββββββββββββββββββββ
# save this file in your project root (e.g. 'quarto/1_setup.R')

# `here` builds project-relative paths; install it on first use
have_here <- requireNamespace("here", quietly = TRUE)
if (!have_here) {
  install.packages("here")
}
library(here)

# folders this workflow expects (these will likely already exist)
dirs <- c(
  here("quarto"),
  here("bibliography"),
  here("save_directory"),
  here("csl")
)
# create only the folders that are missing
for (missing_dir in dirs[!dir.exists(dirs)]) {
  dir.create(missing_dir, recursive = TRUE)
}

# tinytex is needed for PDF rendering; install the package and the
# TeX distribution together when the package is absent
have_tinytex <- requireNamespace("tinytex", quietly = TRUE)
if (!have_tinytex) {
  install.packages("tinytex")
  tinytex::install_tinytex()
}

# pacman is used below to load (and install if needed) all other packages
have_pacman <- requireNamespace("pacman", quietly = TRUE)
if (!have_pacman) {
  install.packages("pacman")
}
# minimum version of margot required by this workflow
min_margot_version <- "1.0.233"

# check that margot is installed *before* asking for its version:
# packageVersion() itself errors on a missing package, which would hide
# the installation hint below
margot_ok <- requireNamespace("margot", quietly = TRUE) &&
  utils::packageVersion("margot") >= min_margot_version

if (!margot_ok) {
  stop(
    "please install margot >= ", min_margot_version, " for this workflow\n",
    "run:\n",
    "devtools::install_github('go-bayes/margot')",
    call. = FALSE
  )
}

# call library
library("margot")

# check package version
packageVersion(pkg = "margot")
# attach every package used in this document; pacman installs any that
# are missing
pacman::p_load(
  boilerplate, shiny, tidyverse,
  kableExtra, glue, patchwork, stringr,
  ggplot2, ggeffects, parameters,
  table1, knitr, extrafont, here, cli
)

# register fonts (requires a prior extrafont::font_import())
have_extrafont <- requireNamespace("extrafont", quietly = TRUE)
if (!have_extrafont) {
  message("'extrafont' not installed; skipping font loading")
} else {
  extrafont::loadfonts(device = "all")
}

# fix the random seed for reproducibility
set.seed(123)

# each entry is c(source, destination): the CSL style and the BibTeX file
# are copied into the quarto folder
src_files <- list(
  c(here("csl", "apa7.csl"), here("quarto", "apa7.csl")),
  c(here("bibliography", "references.bib"), here("quarto", "references.bib"))
)
for (pair in src_files) {
  from <- pair[1]
  to <- pair[2]
  # copy only when the source exists and the destination does not
  needs_copy <- file.exists(from) && !file.exists(to)
  if (needs_copy) {
    file.copy(from, to)
  }
}
# ββ define paths and import data ββββββββββββββββββββββββββββββββββββββββββββββ
# directory holding saved model results and label mappings
push_mods <- here::here("results")

# location of the unified boilerplate text database (machine-specific path)
master_path <- "/Users/joseph/GIT/templates/boilerplate/data/boilerplate_unified.json"

# import the boilerplate database from that path
unified_db <- boilerplate_import(data_path = master_path)

# ---- import data for visualisation ----
data_dir <- here::here("data")
original_df <- readRDS(file.path(data_dir, "religious_prosocial_data.rds"))

# data used in model fitting
df_grf <- margot::here_read("data_standardised", dir_path = data_dir)

# label mapping saved alongside the model results
label_mapping <- here_read("label_mapping", dir_path = push_mods)

# # co-variates
# E <- margot::here_read('E', push_mods)
# # select covariates and drop numeric attributes
# X <- margot::remove_numeric_attributes(df_grf[E])
# ββ define nice names and regimes βββββββββββββββββββββββββββββββββββββββββββββ
# title shared by the ate plot defaults and the plot options
ate_title <- glue("ATE Effects of Belief in God on Cooperative outcomes")

# ---- set plot defaults for ate plots ----
# colour assigned to each reliability category
ate_colours <- c(
  "positive" = "#E69F00",
  "not reliable" = "grey50",
  "negative" = "#56B4E9"
)

base_defaults_binary <- list(
  # estimate type, title and e-value threshold
  type = "RD",
  title = ate_title,
  e_val_bound_threshold = 1.2,
  colors = ate_colours,
  # axis limits and label offset
  x_offset = -.5,
  x_lim_lo = -.5,
  x_lim_hi = 1,
  # text, line and point sizing
  text_size = 5,
  linewidth = 0.75,
  estimate_scale = 1,
  base_size = 20,
  point_size = 4,
  title_size = 20,
  subtitle_size = 16,
  legend_text_size = 10,
  legend_title_size = 10,
  include_coefficients = FALSE
)

# ---- create plot options for outcomes ----
outcomes_options_all <- margot_plot_create_options(
  title = ate_title,
  base_defaults = base_defaults_binary,
  subtitle = "",
  filename_prefix = "grf"
)
# ββ load model results ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# load the fitted models saved under `push_mods`
models_binary_cate <- margot::here_read_qs("models_binary_cate", push_mods)

# ------------------------------------------------------------------------------
# CRITICAL DECISION POINT: FLIPPED OUTCOMES OR NOT
# ------------------------------------------------------------------------------
# ********** READ THIS CAREFULLY **********
# if you DID NOT flip any outcomes:
#   - set: use_flipped <- FALSE
#   - this will use models_binary_cate throughout
#   - this will use label_mapping throughout
#
# if you DID flip outcomes:
#   - set: use_flipped <- TRUE
#   - ensure models_binary_flipped_all exists (it is read from push_mods if not)
#   - ensure the saved flip_outcomes and flipped_names objects exist
#   - this will create label_mapping_all_flipped from label_mapping
use_flipped <- FALSE # <- MAKE TRUE IF YOU FLIPPED OUTCOMES, `FALSE` otherwise

# the rest of the document only uses `models_for_analysis`,
# `labels_for_analysis`, `flipped_names` and `flipped_list`
if (use_flipped) {
  # load the flipped models unless they are already in the session
  if (!exists("models_binary_flipped_all")) {
    models_binary_flipped_all <- here_read_qs("models_binary_flipped_all", push_mods)
  }

  # both saved objects are required; turn a read failure into a clear error
  flip_outcomes <- tryCatch(
    here_read("flip_outcomes", push_mods),
    error = function(e) {
      stop("flip_outcomes file not found. Please ensure it exists if use_flipped = TRUE")
    }
  )
  flipped_names <- tryCatch(
    here_read("flipped_names", push_mods),
    error = function(e) {
      stop("flipped_names file not found. Please ensure it exists if use_flipped = TRUE")
    }
  )

  # create flipped label mapping from the label mapping loaded above
  # (the previous code referenced `label_mapping_all`, which is never defined)
  label_mapping_all_flipped <- margot_reversed_labels(label_mapping, flip_outcomes)

  # use flipped models and labels
  models_for_analysis <- models_binary_flipped_all
  labels_for_analysis <- label_mapping_all_flipped
  flipped_list <- paste(flipped_names, collapse = ", ")
} else {
  # use standard models and labels (no flipping)
  models_for_analysis <- models_binary_cate
  labels_for_analysis <- label_mapping

  # ensure flipped_names is a character vector, not a function:
  # remove any existing flipped_names object that might be a function
  if (exists("flipped_names") && is.function(flipped_names)) {
    rm(flipped_names)
  }
  flipped_names <- character(0) # empty character vector
  flipped_list <- ""
}
# developer convenience: when a local checkout of margot is present, load it
# over the installed package. the guard lets the document render on machines
# without this checkout (or without devtools), where the installed margot
# attached earlier is used instead
margot_dev_path <- "/Users/joseph/GIT/margot/"
if (dir.exists(margot_dev_path) && requireNamespace("devtools", quietly = TRUE)) {
  devtools::load_all(margot_dev_path)
}

# ---- average treatment effects (ate) analysis ----
ate_results <- margot_plot(
  models_binary_cate$combined_table,
  options = outcomes_options_all,
  label_mapping = label_mapping, # always use standard labels for ate
  include_coefficients = FALSE,
  order = "evaluebound_asc",
  original_df = original_df,
  e_val_bound_threshold = 1.2,
  rename_ate = TRUE,
  adjust = "bonferroni",
  alpha = 0.05
)

# check results table:
ate_results$transformed_table

# check interpretation:
cat(ate_results$interpretation)

# check plot data is correct
# plot_data <- ate_results$plot$data
# print(plot_data[plot_data$outcome == "Social Belonging", c("2.5 %", "97.5 %")])
# ββ make nice markdown table -----------------ββββββββββββββββββββββββββββββββββ
# e-value bound threshold passed to margot_bind_tables() (choose threshold)
table_e_val_threshold <- 1.2

# display names for the two e-value columns
table_col_renames <- list(
  "E-Value" = "E_Value",
  "E-Value bound" = "E_Val_bound"
)

# arguments forwarded to the kable call
table_kbl_args <- list(
  booktabs = TRUE,
  caption = NULL,
  align = NULL
)

# markdown ate table; printed later in the `tbl-outcomes` chunk
margot_bind_tables_markdown <- margot_bind_tables(
  ate_results$transformed_table,
  sort_E_val_bound = "desc",
  e_val_bound_threshold = table_e_val_threshold,
  highlight_color = NULL,
  bold = TRUE,
  rename_cols = TRUE,
  col_renames = table_col_renames,
  rename_ate = TRUE,
  threshold_col = "E_Val_bound",
  output_format = "markdown",
  kbl_args = table_kbl_args
)
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# HETEROGENEITY ANALYSIS
# this section adapts based on whether outcomes were flipped
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# the heterogeneity analysis needs at least this margot version
stopifnot(utils::packageVersion("margot") >= "1.0.233")

# helper for printing tables: rounds every numeric column of `tbl` to two
# decimal places and returns the result of kbl() in markdown format
print_rate <- function(tbl) {
  rounded <- mutate(
    tbl,
    across(where(is.numeric), function(x) round(x, 2))
  )
  kbl(rounded, format = "markdown")
}
# ββ 1. screen for heterogeneity (rate autoc + rate qini) βββββββββββββββββββββ
# screen the selected models with margot_rate()
rate_results <- margot_rate(
  models = models_for_analysis,
  policy = "treat_best",
  alpha = 0.20,
  adjust = "fdr",
  label_mapping = labels_for_analysis
)

# interpret rate results: pass the flipped outcome names only when
# outcomes were flipped, otherwise pass NULL
flipped_for_interpretation <- if (use_flipped) flipped_names else NULL
rate_interp <- margot_interpret_rate(
  rate_results,
  flipped_outcomes = flipped_for_interpretation,
  adjust_positives_only = TRUE
)

cat(rate_interp$comparison, "\n")
cli_h2("Analysis ready for Appendix β")
# collect the model-name vectors returned by margot_interpret_rate()
# under short group names
model_groups <- list(
  autoc       = rate_interp$autoc_model_names,
  qini        = rate_interp$qini_model_names,
  either      = rate_interp$either_model_names,
  exploratory = rate_interp$not_excluded_either
)

# ---- 2. plot rate autoc curves (if any exist) ----
has_autoc_models <- length(model_groups$autoc) > 0
if (!has_autoc_models) {
  # nothing to plot: keep an empty list so later chunks can test its length
  autoc_plots <- list()
  message("no significant rate autoc results found")
} else {
  autoc_plots <- margot_plot_rate_batch(
    models = models_for_analysis,
    label_mapping = labels_for_analysis,
    model_names = model_groups$autoc
  )
  # store the outcome in the first row of the autoc table, if it has rows
  if (nrow(rate_results$rate_autoc) > 0) {
    autoc_name_1 <- rate_results$rate_autoc$outcome[[1]]
  }
}
# ββ 3. qini curves + gain interpretation ββββββββββββββββββββββββββββββββββββββ
# plot settings are defined up front so they exist even if not used
policy_tree_defaults <- list(
  point_alpha = 0.5,
  title_size = 30,
  subtitle_size = 25,
  axis_title_size = 25,
  legend_title_size = 18,
  split_line_color = "red",
  split_line_alpha = 0.8,
  split_label_color = "red",
  split_label_nudge_factor = 0.007
)

decision_tree_defaults <- list(
  span_ratio = 0.1,
  text_size = 4,
  y_padding = 0.5,
  edge_label_offset = 0.02,
  border_size = 0.01
)

# initial qini analysis over every model in `models_for_analysis$results`
all_model_names <- names(models_for_analysis$results)
qini_outputs <- c("qini_plot", "diff_gain_summaries")
qini_results <- margot_policy(
  models_for_analysis,
  policy_tree_args = policy_tree_defaults,
  model_names = all_model_names,
  original_df = original_df,
  label_mapping = labels_for_analysis,
  qini_args = list(show_ci = "both"),
  max_depth = 1L,
  output_objects = qini_outputs
)

# interpret qini results
qini_gain <- margot_interpret_qini(
  qini_results,
  label_mapping = labels_for_analysis
)

# show the rounded summary table and the accompanying explanation
print_rate(qini_gain$summary_table)
cat(qini_gain$qini_explanation, "\n")

# model ids flagged as reliable; these gate the policy-tree step
reliable_ids <- qini_gain$reliable_model_ids
# ββ 4. policy trees (only if reliable models exist) ββββββββββββββββββββββββββ
has_reliable_models <- length(reliable_ids) > 0

if (!has_reliable_models) {
  # no reliable models: create empty placeholders so later chunks still run
  qini_plots <- list()
  policy_plots <- list()
  qini_names <- character(0)
  policy_text <- "No reliable heterogeneous treatment effects found."
  message("no reliable qini models found - skipping policy tree analysis")
} else {
  # recompute qini output for reliable models only
  qini_results_valid <- margot_policy(
    models_for_analysis,
    decision_tree_args = decision_tree_defaults,
    policy_tree_args = policy_tree_defaults,
    model_names = reliable_ids,
    original_df = original_df,
    qini_args = list(show_ci = "both", ylim = c(0, 1)),
    label_mapping = labels_for_analysis,
    max_depth = 1L,
    output_objects = c("qini_plot", "diff_gain_summaries")
  )
  # pull the qini plot out of each model's result
  qini_plots <- map(qini_results_valid, \(res) res$qini_plot)
  qini_names <- margot_get_labels(reliable_ids, labels_for_analysis)

  # compute policy trees for the same reliable models
  policy_results_2L <- margot_policy(
    models_for_analysis,
    decision_tree_args = decision_tree_defaults,
    policy_tree_args = policy_tree_defaults,
    model_names = reliable_ids,
    max_depth = 1L,
    original_df = original_df,
    label_mapping = labels_for_analysis,
    output_objects = c("combined_plot")
  )
  # pull the combined plot out of each model's result
  policy_plots <- map(policy_results_2L, \(res) res$combined_plot)

  # generate the text interpretation of the policy trees
  policy_text <- margot_interpret_policy_batch(
    models = models_for_analysis,
    original_df = original_df,
    model_names = reliable_ids,
    label_mapping = labels_for_analysis,
    max_depth = 1L
  )
  cat(policy_text, "\n")
}
cli::cli_h1("Finished: heterogeneity analysis complete β")
# run this code, copy and paste contents into your text
# named values passed as `global_vars` to boilerplate_generate_text() in the
# results and appendix chunks of this document
# - every name appears exactly once (earlier versions repeated
#   `sample_ratio_policy` and `appendix_explain_grf`)
# - the alpha / adjustment entries mirror the margot_plot() and margot_rate()
#   calls in this chunk; update both together
global_vars <- list(
  name_exposure_variable = "Belief in God",
  # written as a string: as.character(20,000) parses as as.character(20, 000)
  # and returns "20"
  n_total = "20,000",
  # margot_plot(): adjust = "bonferroni", alpha = 0.05
  ate_adjustment = "bonferroni",
  ate_alpha = "0.05",
  # margot_rate(): adjust = "fdr", alpha = 0.20
  cate_adjustment = "fdr",
  cate_alpha = "0.2",
  sample_ratio_policy = "50/50",
  n_participants = 20000,
  # n_censored = n_censored,
  n_iterations = as.character(1000),
  stability_threshold = as.character(10),
  appendix_outcomes = as.character(2),
  # eligibility_criteria = eligibility_criteria,
  # exposure_variable = name_exposure,
  # name_exposure_lower = name_exposure_lower,
  # name_control_regime_lower = name_control_regime_lower,
  name_outcome_variables = "Cooperation Outcomes", # <- adjust to your study
  # name_outcomes_lower = name_outcomes_lower,
  # name_exposure_capfirst = nice_name_exposure,
  # measures_exposure = measures_exposure,
  # value_exposure_regime = value_exposure_regime,
  # value_control_regime = value_control_regime,
  flipped_list = flipped_list,
  # first of the two previous definitions is kept ("4"); a second entry set
  # "6" -- NOTE(review): confirm which appendix number is intended
  appendix_explain_grf = "4",
  appendix_assumptions_grf = "5",
  name_exposure_threshold = "1",
  name_control_threshold = "0",
  appendix_measures = "2",
  # value_control = value_control,
  # value_exposure = value_exposure,
  appendix_positivity = "3",
  appendix_rate = "4",
  appendix_qini_curve = "4",
  train_proportion_decision_tree = "50 percent",
  training_proportion = "50 percent",
  sample_split = "50/50",
  # baseline_wave = baseline_wave,
  # exposure_waves = exposure_waves,
  # outcome_wave = outcome_wave,
  protocol_url = "https://osf.io/ce4t9/", # if used
  appendix_timeline = "S1" # if used
)
```
{{< pagebreak >}}
## Introduction
The following briefly walks through the results of ATE and CATE estimation and reporting.
## Method - Example Report
```{r, results='asis'}
#| eval: false # <- set to false, copy, delete, modify, and extend text as needed
# generate the "grf.simple" template section from the boilerplate database,
# then write it into the document
method_text <- boilerplate_generate_text(
  category = "template",
  sections = "grf.simple",
  # global_vars = list(
  #   name_outcomes_lower = name_outcomes_lower,
  #   name_exposure_lower = name_exposure_lower,
  #   name_exposure_capfirst = name_exposure_variable
  # ),
  db = unified_db
)
cat(method_text)
```
{{< pagebreak >}}
## Results
### Average Treatment Effects
::: {.column-screen}
```{r}
#| label: fig-ate
#| fig-cap: "Average Treatment Effects on Multi-dimensional Wellbeing"
#| eval: true
#| fig-height: 12
#| fig-width: 8
# display the `plot` element of `ate_results`, which the setup chunk builds
# with margot_plot()
# NOTE(review): the caption says "Multi-dimensional Wellbeing" while
# `ate_title` in the setup chunk says "Cooperative outcomes" -- confirm
ate_results$plot
```
:::
{{< pagebreak >}}
```{r}
#| label: tbl-outcomes
#| tbl-cap: "Average Treatment Effects on Multi-dimensional Wellbeing"
#| eval: true # - set to false/ copy and change
#| echo: false
#| include: true
# print the markdown table the setup chunk builds with margot_bind_tables()
# NOTE(review): the caption says "Multi-dimensional Wellbeing" while
# `ate_title` in the setup chunk says "Cooperative outcomes" -- confirm
margot_bind_tables_markdown
```
```{r, results = 'asis'}
#| eval: true # - set to false/ copy and change
#| echo: false
#| include: true
# write the ate interpretation text into the document
ate_interpretation_text <- ate_results$interpretation
cat(ate_interpretation_text)
```
{{< pagebreak >}}
### Heterogeneous Treatment Effects {#results-qini-curve}
We begin by examining the distribution of individual treatment effects (τᵢ) across our sample. @fig-tau-distribution presents the estimated treatment effects for each individual, revealing substantial variability in how people respond to belief in God.
::: {.column-screen}
```{r}
#| label: fig-tau-distribution
#| fig-cap: "Distribution of Individual Treatment Effects (τᵢ) Across Outcomes"
#| eval: true
#| echo: false
# the caption previously showed a mis-encoded symbol where "τᵢ" belongs
# build the tau plots from the selected models and their labels
tau_plots <- margot_plot_tau(
  models_for_analysis,
  label_mapping = labels_for_analysis
)
# display the plot
tau_plots
```
:::
The histograms above show considerable heterogeneity in the estimated effects of belief in God across individuals. To determine whether this variability is systematic (i.e., predictable based on individual characteristics) rather than random noise, we employ two complementary approaches: Qini curves to assess the reliability of heterogeneous effects, and policy trees to identify subgroups with differential treatment responses.
```{r, results='asis'}
#| eval: true # <- copy and modify text as needed
#| echo: false
# copy and paste into your text
# generate the qini interpretation section using `global_vars`
qini_interpretation_text <- boilerplate::boilerplate_generate_text(
  category = "results",
  sections = "grf.interpretation_qini",
  global_vars = global_vars,
  db = unified_db
)
cat(qini_interpretation_text)
```
##### RATE AUTOC RESULTS
```{r, results = 'asis'}
# print the autoc write-up when the autoc group has models, otherwise a
# fixed fallback sentence
has_autoc_results <- length(model_groups$autoc) > 0
autoc_text <- if (has_autoc_results) {
  rate_interp$autoc_results
} else {
  "No significant RATE AUTOC results were found."
}
cat(autoc_text)
```
```{r}
#| label: fig-rate-autoc
#| fig-cap: "RATE AUTOC Curves"
#| eval: false # <- set to true if you have autoc plots
#| echo: false
#| fig-height: 16
#| fig-width: 12
# display autoc plots if they exist, arranged in a two-column grid
if (length(autoc_plots) > 0) {
  # grid layout (2 columns preferred)
  n_plots <- length(autoc_plots)
  n_cols <- 2
  n_rows <- ceiling(n_plots / n_cols)

  # pad with blank spacers so the grid is even; rep() yields an empty list
  # when no padding is needed, so no 1:n loop is required
  n_blanks_needed <- (n_rows * n_cols) - n_plots
  plot_list <- c(autoc_plots, rep(list(plot_spacer()), n_blanks_needed))

  # combine plots in a grid; `&` applies the theme to every panel
  combined_autoc <- wrap_plots(
    plot_list,
    ncol = n_cols,
    nrow = n_rows
  ) +
    plot_layout(guides = "collect") +
    plot_annotation(
      title = "RATE AUTOC Curves for Heterogeneous Effects",
      # AUTOC is the area under the TOC curve, not an autocorrelation;
      # sprintf() also avoids the stray space before the closing bracket
      subtitle = sprintf(
        "Outcomes with significant RATE AUTOC (n = %d)",
        n_plots
      )
    ) &
    theme(
      legend.position = "bottom",
      plot.title = element_text(hjust = 0.5),
      plot.subtitle = element_text(hjust = 0.5)
    )
  print(combined_autoc)
} else {
  message("no autoc plots to display")
}
```
```{r, results = 'asis'}
#| eval: true # <- set to true and use if you have reliable results
# print the qini explanation when at least one model is reliable,
# otherwise a fixed fallback sentence
has_reliable_qini <- length(reliable_ids) > 0
qini_text <- if (has_reliable_qini) {
  qini_gain$qini_explanation
} else {
  "No significant heterogeneous treatment effects were detected using Qini curve analysis."
}
cat(qini_text)
```
```{r}
#| tbl-cap: "Qini Curve Results"
#| eval: true # <- set to true if you have qini results
# show the rounded qini summary table whenever at least one model is reliable
if (length(reliable_ids) > 0) {
  knitr::kable(
    qini_gain$summary_table |>
      mutate(across(where(is.numeric), \(x) round(x, 2))),
    format = "markdown",
    caption = "Qini Curve Results"
  )
} else {
  # the note now matches the condition above (one reliable model suffices;
  # it previously said "multiple significant results")
  cat("*Note: Qini curve table only displayed when reliable results are found.*")
}
```
::: {.column-screen}
```{r}
#| label: fig-qini-combined
#| fig-cap: "Qini Curves for Heterogeneous Treatment Effects"
#| eval: true # <- set to true if you have qini plots
#| echo: false
#| fig-height: 12
#| fig-width: 12
# display qini plots if they exist, arranged in a two-column grid
if (length(qini_plots) > 0) {
  # grid layout (2 columns preferred)
  n_plots <- length(qini_plots)
  n_cols <- 2
  n_rows <- ceiling(n_plots / n_cols)

  # pad with blank spacers so the grid is even; rep() yields an empty list
  # when no padding is needed, so no 1:n loop is required
  n_blanks_needed <- (n_rows * n_cols) - n_plots
  plot_list <- c(qini_plots, rep(list(plot_spacer()), n_blanks_needed))

  # combine plots in a grid; `&` applies the theme to every panel
  combined_qini <- wrap_plots(
    plot_list,
    ncol = n_cols,
    nrow = n_rows
  ) +
    plot_layout(guides = "collect") +
    plot_annotation(
      title = "Qini Curves for Reliable Heterogeneous Effects",
      # sprintf() avoids the stray space paste() left before the bracket
      subtitle = sprintf(
        "Models with significant treatment effect heterogeneity (n = %d)",
        n_plots
      )
    ) &
    theme(
      legend.position = "bottom",
      plot.title = element_text(hjust = 0.5),
      plot.subtitle = element_text(hjust = 0.5)
    )
  print(combined_qini)
} else {
  message("no qini plots to display")
}
```
:::
### Decision Rules (Who is Most Sensitive to Treatment?)
```{r, results='asis'}
#| eval: true # <- set to true and modify text as needed
# generate the policy tree interpretation section using `global_vars`
policy_tree_intro_text <- boilerplate::boilerplate_generate_text(
  category = "results",
  sections = "grf.interpretation_policy_tree",
  global_vars = global_vars,
  db = unified_db
)
cat(policy_tree_intro_text)
```
The following pages present policy trees for each outcome with reliable heterogeneous effects. Each tree shows: (1) the decision rules for treatment assignment, (2) the distribution of treatment effects across subgroups, and (3) visual representation of how covariates split the population into groups with differential treatment responses.
```{r}
#| label: fig-policy-trees
#| fig-cap: "Policy Trees for Treatment Assignment"
#| eval: true # <- set to true if you have policy trees
#| echo: false
#| results: asis
#| fig-height: 12
#| fig-width: 14
# `results: asis` is required: the headings and pagebreak shortcodes written
# with cat() below are markdown and must not be wrapped as code output
# display policy trees if they exist - one per page
if (length(policy_plots) > 0) {
  for (i in seq_along(policy_plots)) {
    # add page break before each plot except the first
    if (i > 1) {
      cat("\n\n{{< pagebreak >}}\n\n")
    }
    # heading naming the outcome for this tree
    cat(paste0("\n\n#### Policy Tree ", i, ": ", qini_names[[i]], "\n\n"))
    # print the policy tree
    print(policy_plots[[i]])
    # add some space after
    cat("\n\n")
  }
} else {
  message("no policy trees to display")
}
```
```{r, results = 'asis'}
#| eval: true # <- copy and paste this text and insert it just below the graph
# use this text below your decision tree graphs
# the interpretation is written only when at least one model is reliable
show_policy_text <- length(reliable_ids) > 0
if (show_policy_text) {
  cat(policy_text, "\n")
}
```
## Discussion
```{r, results='asis'}
#| eval: false # <- set to true and modify text as needed
#| echo: false
# sections requested from the "discussion" category of the database
discussion_sections <- c(
  "student_ethics",
  "student_data",
  "student_authors_statement"
)
# this chunk supplies its own single-entry global_vars list
discussion_text <- boilerplate_generate_text(
  category = "discussion",
  sections = discussion_sections,
  global_vars = list(
    exposure_variable = "belief in God"
  ),
  db = unified_db
)
cat(discussion_text)
```
{{< pagebreak >}}
## S1: Estimating and Interpreting Heterogeneous Treatment Effects with GRF {#appendix-explain-grf}
```{r, results='asis'}
#| eval: false # <- set to true as needed
#| echo: false
# generate the "explain.grf_short" appendix section using `global_vars`
appendix_grf_text <- boilerplate::boilerplate_generate_text(
  category = "appendix",
  sections = "explain.grf_short",
  global_vars = global_vars,
  db = unified_db
)
cat(appendix_grf_text)
```
#### Qini Curves
The Qini curve shows the cumulative **gain** as we expand a targeting rule down the CATE ranking.
* **Beneficial exposure:** we add individuals from the top positive CATEs downward; the baseline is 'expose everyone.'
* **Detrimental exposure:** we first flip outcome direction (so higher values represent **more harm**), then *add* the exposure starting with individuals whose CATEs show the **greatest harm**, gradually including those predicted to be more resistant to harm; the baseline is 'expose everyone.' The curve therefore quantifies the harm incurred when those most susceptible to harm are exposed.
## S2: Strengths and Limitations of Causal Forests {#appendix-strengths}
```{r, results='asis'}
#| eval: true # <- set to true and modify text as needed
# generate the "strengths.strengths_grf_short" section using `global_vars`
strengths_text <- boilerplate::boilerplate_generate_text(
  category = "discussion",
  sections = "strengths.strengths_grf_short",
  global_vars = global_vars,
  db = unified_db
)
cat(strengths_text)
```
## References {.appendix-refs}