Stability Analysis for Policy Trees
Source: R/margot_policy_tree_stability.R
Performs stability analysis of policy trees to assess their robustness and to generate consensus trees. By default, the function varies the random seed across iterations to create different train/test splits. Bootstrap resampling is optionally supported for traditional bootstrap analysis. A memory-efficient streaming approach allows the function to handle large datasets.
Usage
margot_policy_tree_stability(
model_results,
model_names = NULL,
custom_covariates = NULL,
exclude_covariates = NULL,
covariate_mode = c("original", "custom", "add", "all"),
depth = 2,
n_iterations = 300,
vary_type = c("split_only", "bootstrap", "both"),
consensus_threshold = 0.5,
train_proportion = 0.5,
vary_train_proportion = FALSE,
train_proportions = c(0.4, 0.5, 0.6, 0.7),
label_mapping = NULL,
return_consensus_trees = TRUE,
metaseed = 12345,
parallel = FALSE,
n_cores = NULL,
verbose = TRUE,
seed = 12345,
tree_method = c("fastpolicytree", "policytree"),
n_bootstrap = NULL
)
Arguments
- model_results
List returned by margot_causal_forest() or margot_flip_forests()
- model_names
Character vector of model names to analyze. NULL = all models.
- custom_covariates
Character vector of covariate names to use for policy trees. If NULL, uses the original top variables from the model.
- exclude_covariates
Character vector of covariate names or patterns to exclude. Supports exact matches and regex patterns (e.g., "_log" excludes all variables containing "_log").
- covariate_mode
Character string specifying how to handle covariates: "original" (use original top variables), "custom" (use only custom_covariates), "add" (add custom to existing), "all" (use all available covariates).
- depth
Numeric or character specifying which depth(s) to compute: 1 for a single split, 2 for two levels of splits (default), or "both" for both depths.
- n_iterations
Integer. Number of stability iterations (default 300).
- vary_type
Character. Type of variation: "split_only" (vary the train/test split via seeds), "bootstrap" (bootstrap resampling with replacement), or "both" (bootstrap resampling plus varied train/test splits). Default is "split_only".
- consensus_threshold
Numeric. Minimum inclusion frequency for consensus (default 0.5).
- train_proportion
Numeric. Proportion of observations assigned to the training set when vary_train_proportion = FALSE (default 0.5).
- vary_train_proportion
Logical. Whether to vary train proportion (default FALSE).
- train_proportions
Numeric vector. Proportions to cycle through when vary_train_proportion = TRUE (default c(0.4, 0.5, 0.6, 0.7)).
- label_mapping
Named character vector for converting variable names to readable labels.
- return_consensus_trees
Logical. Return fitted consensus trees (default TRUE).
- metaseed
Integer. Master seed for reproducibility (default 12345).
- parallel
Logical. Use parallel processing (default FALSE).
- n_cores
Integer. Number of cores for parallel processing.
- verbose
Logical. Print progress messages (default TRUE).
- seed
Integer. Additional seed parameter for compatibility (default 12345).
- tree_method
Character string specifying the package to use: "fastpolicytree" (default) or "policytree". The fastpolicytree package provides ~10x faster computation, which is particularly beneficial for stability analysis. Falls back to policytree if fastpolicytree is not installed.
- n_bootstrap
Deprecated. Use n_iterations instead.
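The way covariate_mode, custom_covariates, and exclude_covariates combine can be sketched in base R. This is an illustrative assumption about the selection logic, not the package's internal code: select_covariates() and its arguments are hypothetical, and "all" is taken here to mean every available column.

```r
# Hypothetical sketch of covariate selection; not margot's internal code.
select_covariates <- function(original, available, custom = NULL,
                              exclude = NULL,
                              mode = c("original", "custom", "add", "all")) {
  mode <- match.arg(mode)
  vars <- switch(mode,
    original = original,
    custom   = custom,
    add      = union(original, custom),
    all      = available
  )
  if (!is.null(exclude)) {
    # drop a variable if it equals a pattern exactly or matches it as a regex
    drop <- vapply(vars, function(v) {
      any(v == exclude) ||
        any(vapply(exclude, function(p) grepl(p, v), logical(1)))
    }, logical(1))
    vars <- vars[!drop]
  }
  unname(vars)
}

original  <- c("age", "income_log", "education")
available <- c(original, "hours_work", "savings_log")

select_covariates(original, available, exclude = "_log")
select_covariates(original, available, custom = "hours_work", mode = "add")
```

Under these assumptions, the first call keeps "age" and "education", and the second appends "hours_work" to the original variables.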
Value
Object of class "margot_stability_policy_tree" containing:
- results: List with consensus trees and stability metrics per model
- summary_metrics: Variable importance and convergence diagnostics
- metadata: Analysis parameters and seeds used
Details
The function uses a memory-efficient approach:
- Processes one tree at a time
- Extracts only essential split information
- Accumulates statistics without storing all trees
- Reconstructs single consensus trees for compatibility
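The accumulate-without-storing strategy, together with consensus_threshold, can be sketched as follows. This is a hypothetical illustration, not margot's implementation: fit_tree_splits() is a stand-in that draws split variables at random instead of fitting a real policy tree, and the threshold is assumed to apply to each variable's inclusion frequency across iterations.

```r
# Hypothetical sketch of streaming accumulation; not margot's internal code.
# fit_tree_splits() stands in for fitting one policy tree and extracting its
# split variables. Here it simply draws variables with unequal probabilities.
fit_tree_splits <- function(vars, prob, depth = 2) {
  n_splits <- if (depth == 1) 1 else 3  # a full depth-2 tree has 3 split nodes
  sample(vars, n_splits, replace = TRUE, prob = prob)
}

# Keep only a running count per variable; each tree is discarded after use
inclusion_frequency <- function(vars, prob, n_iterations = 300, depth = 2) {
  counts <- setNames(integer(length(vars)), vars)
  for (i in seq_len(n_iterations)) {
    used <- unique(fit_tree_splits(vars, prob, depth))  # count once per tree
    counts[used] <- counts[used] + 1L
  }
  counts / n_iterations
}

# Variables whose inclusion frequency reaches the threshold form the consensus
consensus_variables <- function(freq, threshold = 0.5) {
  names(freq)[freq >= threshold]
}

set.seed(7)
vars <- c("age", "income", "education", "region")
freq <- inclusion_frequency(vars, prob = c(0.6, 0.2, 0.1, 0.1))
consensus_variables(freq, threshold = 0.5)
```

Memory use here grows with the number of candidate variables, not with n_iterations, which is the point of the streaming design.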
By default, the function varies random seeds to create different train/test splits for each iteration. This assesses the stability of the policy tree structure without the computational overhead and statistical assumptions of bootstrap resampling. True bootstrap resampling can be enabled with vary_type = "bootstrap".
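A minimal sketch of seed-driven splits, assuming per-iteration seeds are drawn from metaseed and that train_proportions is cycled in order when varied. make_iteration_splits() is hypothetical and only mirrors the documented arguments; it is not the package's code.

```r
# Hypothetical sketch: per-iteration seeds derived from a master seed.
make_iteration_splits <- function(n, n_iterations, metaseed = 12345,
                                  train_proportion = 0.5,
                                  train_proportions = NULL) {
  set.seed(metaseed)
  seeds <- sample.int(.Machine$integer.max, n_iterations)
  lapply(seq_len(n_iterations), function(i) {
    # cycle through train_proportions when supplied (cf. vary_train_proportion = TRUE)
    p <- if (is.null(train_proportions)) train_proportion else
      train_proportions[(i - 1) %% length(train_proportions) + 1]
    set.seed(seeds[i])
    train <- sort(sample.int(n, floor(p * n)))
    list(seed = seeds[i], train = train, test = setdiff(seq_len(n), train))
  })
}

splits <- make_iteration_splits(n = 100, n_iterations = 4,
                                train_proportions = c(0.4, 0.5, 0.6, 0.7))
vapply(splits, function(s) length(s$train), integer(1))
# 40 50 60 70
```

Because every split is a deterministic function of metaseed, rerunning with the same master seed reproduces the same sequence of train/test partitions.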
Theoretical Background
Policy trees inherit the instability of decision trees, where small changes in the data can lead to completely different tree structures (Breiman, 1996). This instability is particularly pronounced when predictors are correlated, as the tree can arbitrarily choose between similar variables at split points. Athey and Wager's (2021) policy learning framework acknowledges these challenges while providing methods to extract robust insights despite the instability.
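The effect of correlated predictors can be seen in a small simulation. This is illustrative only: it uses a crude correlation-based stand-in for a split criterion rather than a policy tree, and the data-generating process is an assumption chosen so that x1 and x2 carry nearly the same signal.

```r
# Illustrative simulation (not margot code): with two highly correlated
# predictors, the variable a greedy criterion picks can flip between resamples.
set.seed(2024)
n   <- 500
x1  <- rnorm(n)
x2  <- x1 + rnorm(n, sd = 0.3)       # x2 is strongly correlated with x1
tau <- 0.5 * (x1 + x2) + rnorm(n)    # effect proxy driven equally by both

best_split_var <- function(idx) {
  # crude stand-in for a split criterion: the predictor most
  # correlated with the effect proxy in this resample
  r <- c(x1 = abs(cor(x1[idx], tau[idx])), x2 = abs(cor(x2[idx], tau[idx])))
  names(which.max(r))
}

picks <- replicate(200, best_split_var(sample.int(n, n, replace = TRUE)))
table(picks) / length(picks)
```

Both variables typically win a substantial share of resamples, even though the underlying structure never changes.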
The stability analysis helps distinguish between:
- Fundamental instability due to weak or absent treatment effect heterogeneity
- Apparent instability due to correlated predictors that capture similar information
- Robust patterns that emerge consistently across different data samples
Important: While stability is desirable, research shows that depth-2 trees typically outperform both uniform treatment assignment and more stable depth-1 trees, even when the depth-2 trees show high variability. The goal is to understand and quantify instability, not necessarily to eliminate it.
Use the companion functions `margot_assess_variable_correlation()` and `margot_stability_diagnostics()` to better understand the sources of instability.
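As a rough base-R analogue of checking for correlated candidate split variables, the sketch below lists covariate pairs whose absolute correlation exceeds a threshold. flag_correlated_pairs() and the 0.8 cutoff are hypothetical; this is not margot_assess_variable_correlation() and does not reproduce its output.

```r
# Hypothetical base-R analogue: covariate pairs with |r| above a threshold.
flag_correlated_pairs <- function(df, threshold = 0.8) {
  cm  <- cor(df)
  # upper triangle only, so each pair is reported once
  idx <- which(abs(cm) > threshold & upper.tri(cm), arr.ind = TRUE)
  data.frame(
    var1 = rownames(cm)[idx[, "row"]],
    var2 = colnames(cm)[idx[, "col"]],
    r    = cm[idx],
    stringsAsFactors = FALSE
  )
}

set.seed(1)
x1 <- rnorm(200)
df <- data.frame(x1 = x1, x2 = x1 + rnorm(200, sd = 0.2), x3 = rnorm(200))
flag_correlated_pairs(df)
```

Pairs flagged this way are candidates for the "apparent instability" case, where the tree swaps between interchangeable variables.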
Three types of variation are supported:
- "split_only": Fixed sample, only varies train/test split (default)
- "bootstrap": Bootstrap resampling with replacement
- "both": Varies both bootstrap sampling and train/test splits
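The row sets behind these options can be contrasted directly. This sketch is illustrative; in particular, treating "both" as a train/test split drawn from a bootstrap sample is an assumption based on the description above, not a description of margot's code.

```r
# Sketch contrasting the rows used under each vary_type setting (illustrative only).
set.seed(42)
n <- 1000

# "split_only": the sample is fixed; only the train/test partition changes
perm <- sample.int(n)
train_split_only <- perm[seq_len(n / 2)]

# "bootstrap": n rows drawn with replacement, so some repeat and others drop out
boot_rows   <- sample.int(n, n, replace = TRUE)
frac_unique <- length(unique(boot_rows)) / n

# "both" (assumed here): split the bootstrap sample into train/test positions
train_both <- boot_rows[sample.int(n, n / 2)]

# expected share of distinct rows in a bootstrap sample: 1 - (1 - 1/n)^n, about 0.632
expected_unique <- 1 - (1 - 1 / n)^n
round(c(observed = frac_unique, expected = expected_unique), 3)
```

Because roughly a third of the rows are absent from each bootstrap sample, bootstrap iterations perturb the data more strongly than split-only iterations do.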
Complete Workflow Example
# 1. Run causal forest (save data for correlation analysis)
cf_results <- margot_causal_forest(
  data = your_data,
  outcome_vars = c("outcome1", "outcome2"),
  # ... plus the covariates, treatment, and other arguments your analysis requires
  save_data = TRUE  # Important for correlation analysis
)

# 2. Run stability analysis to assess robustness
stability_results <- margot_policy_tree_stability(
  cf_results,
  n_iterations = 300,
  tree_method = "fastpolicytree"  # ~10x faster if available
)

# 3. Check variable correlations
cor_analysis <- margot_assess_variable_correlation(
  cf_results,  # Use original results, NOT stability_results
  "model_outcome1"
)

# 4. Identify clusters of correlated variables
clusters <- margot_identify_variable_clusters(cor_analysis)

# 5. Get comprehensive diagnostics
diagnostics <- margot_stability_diagnostics(
  stability_results = stability_results,
  model_results = cf_results,
  model_name = "model_outcome1"
)

# 6. Interpret results
interpretation <- margot_interpret_stability(
  stability_results,
  "model_outcome1",
  include_theory = TRUE  # Include theoretical context
)
References
Athey, S., & Wager, S. (2021). Policy learning with observational data. Econometrica, 89(1), 133-161.
Breiman, L. (1996). Bagging predictors. Machine Learning, 24(2), 123-140.
Zhou, Z., Athey, S., & Wager, S. (2023). Offline multi-action policy learning: Generalization and optimization. Operations Research, 71(1), 148-183.
See also
- margot_policy_tree: for computing policy trees without stability analysis
- margot_assess_variable_correlation: for correlation analysis
- margot_stability_diagnostics: for comprehensive diagnostics
Examples
if (FALSE) { # \dontrun{
  # Basic stability analysis with fixed train proportion
  stability_results <- margot_policy_tree_stability(
    causal_forest_results,
    n_iterations = 300
  )

  # Vary train proportion with default values
  stability_results <- margot_policy_tree_stability(
    causal_forest_results,
    vary_train_proportion = TRUE
  )

  # Custom train proportions
  stability_results <- margot_policy_tree_stability(
    causal_forest_results,
    vary_train_proportion = TRUE,
    train_proportions = c(0.3, 0.5, 0.7)
  )

  # Use bootstrap resampling instead of seed variation
  stability_results <- margot_policy_tree_stability(
    causal_forest_results,
    vary_type = "bootstrap",
    n_iterations = 300
  )

  # Plot consensus tree
  margot_plot_policy_tree(stability_results, "model_anxiety")

  # Get stability summary
  summary(stability_results)

  # Interpret results with theoretical context
  interpretation <- margot_interpret_stability(
    stability_results,
    "model_anxiety",
    format = "text"
  )

  # Assess variable correlations (using original causal forest results)
  cor_analysis <- margot_assess_variable_correlation(
    causal_forest_results,  # NOT stability_results
    "model_anxiety"
  )

  # Identify variable clusters
  clusters <- margot_identify_variable_clusters(cor_analysis)

  # Run comprehensive stability diagnostics
  diagnostics <- margot_stability_diagnostics(
    stability_results = stability_results,
    model_results = causal_forest_results,
    model_name = "model_anxiety"
  )
} # }