
This function visualizes policy tree results from a causal forest or multi-arm causal forest model. It generates two plots showing the relationships between the top three split variables, with points colored by the predicted optimal treatment. It also offers options for transforming variable labels and reports progress through informative CLI messages.

Usage

margot_plot_policy_tree(
  mc_test,
  model_name,
  original_df = NULL,
  color_scale = NULL,
  point_alpha = 0.5,
  theme_function = theme_classic,
  title_size = 14,
  subtitle_size = 11,
  axis_title_size = 10,
  legend_title_size = 10,
  jitter_width = 0.3,
  jitter_height = 0.3,
  split_line_color = "darkgray",
  split_line_alpha = 0.7,
  split_line_type = "dashed",
  split_line_linewidth = 0.5,
  split_label_size = 10,
  split_label_color = "darkgray",
  custom_action_names = NULL,
  legend_position = "bottom",
  plot_selection = "both",
  remove_tx_prefix = TRUE,
  remove_z_suffix = TRUE,
  use_title_case = TRUE,
  remove_underscores = TRUE,
  label_mapping = NULL
)

Arguments

mc_test

A list containing the results from a causal forest or multi-arm causal forest model.

model_name

A string specifying which model's results to plot.

original_df

Optional data frame containing the untransformed variables, used to display split values on the original data scale.
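If the split variables were standardized (as the `_z` suffix in the example model names suggests), a split threshold found on the z-scale can be mapped back to the data scale by inverting the standardization, x = mean(x) + z * sd(x). The sketch below shows that arithmetic with toy values. `unstandardize_split()` is a hypothetical helper, not part of margot, and it assumes the z-scores were computed from the sample mean and standard deviation of the matching column in `original_df`.

```r
# Hypothetical helper (not part of margot): convert a split threshold found on a
# standardized (z-scored) variable back to the variable's original scale.
# Assumes z = (x - mean(x)) / sd(x), with missing values removed.
unstandardize_split <- function(z_split, x_original) {
  mu <- mean(x_original, na.rm = TRUE)
  sigma <- stats::sd(x_original, na.rm = TRUE)
  mu + z_split * sigma
}

# Toy values for illustration only
x <- c(2, 3, 4, 5, 6)
unstandardize_split(0.5, x)  # about 4.79 (mean 4, sd sqrt(2.5))
```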

color_scale

An optional ggplot2 color scale. If NULL (default), a color scale based on the Okabe-Ito palette is created, with blue for the first level, orange for the second, and subsequent colors for additional levels.
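A sketch of a comparable Okabe-Ito scale, assuming the standard Okabe-Ito hex codes and taking sky blue (`#56B4E9`) as the "blue". The exact shade and the order of the later colors in margot's default are assumptions, and `okabe_ito_colors()` and `okabe_ito_scale_sketch()` are hypothetical names. Unnamed `values` in `scale_colour_manual()` are matched to factor levels in order, so the first action level gets blue and the second orange.

```r
library(ggplot2)

# Hypothetical helper: first n Okabe-Ito colors, ordered blue, orange, then the rest.
# The ordering after the first two is an assumption, not margot's documented default.
okabe_ito_colors <- function(n) {
  palette <- c("#56B4E9", "#E69F00", "#009E73", "#F0E442",
               "#0072B2", "#D55E00", "#CC79A7", "#000000")
  if (n > length(palette)) stop("More action levels than Okabe-Ito colors.")
  palette[seq_len(n)]
}

# Build a manual color scale with one color per action level
okabe_ito_scale_sketch <- function(n_actions) {
  scale_colour_manual(values = okabe_ito_colors(n_actions))
}

# Could be supplied as color_scale = okabe_ito_scale_sketch(3)
```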

point_alpha

Numeric value between 0 and 1 for point transparency. Default is 0.5.

theme_function

A ggplot2 theme function. Default is theme_classic.

title_size

Numeric value for the size of the plot title. Default is 14.

subtitle_size

Numeric value for the size of the plot subtitle. Default is 11.

axis_title_size

Numeric value for the size of axis titles. Default is 10.

legend_title_size

Numeric value for the size of the legend title. Default is 10.
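The size arguments presumably adjust the text elements of the chosen base theme. A minimal sketch of that pattern, assuming the sizes are applied with `ggplot2::theme()` on top of `theme_function()`; `apply_text_sizes_sketch()` is a hypothetical name, not margot's code.

```r
library(ggplot2)

# Sketch (assumption): apply the documented size arguments on top of a base theme.
apply_text_sizes_sketch <- function(theme_function = theme_classic,
                                    title_size = 14, subtitle_size = 11,
                                    axis_title_size = 10, legend_title_size = 10) {
  theme_function() +
    theme(
      plot.title = element_text(size = title_size),
      plot.subtitle = element_text(size = subtitle_size),
      axis.title = element_text(size = axis_title_size),
      legend.title = element_text(size = legend_title_size)
    )
}

th <- apply_text_sizes_sketch(title_size = 16)
```

Because `+` merges the new `element_text()` settings into the base theme, only the sizes change; other properties such as justification are inherited from `theme_function()`.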

jitter_width

Numeric value for the amount of horizontal jitter. Default is 0.3.

jitter_height

Numeric value for the amount of vertical jitter. Default is 0.3.
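The point transparency and jitter arguments correspond naturally to a jittered point layer in ggplot2. The sketch below assumes that mapping; `jitter_layer_sketch()` is a hypothetical name, and the toy data are random values for illustration only.

```r
library(ggplot2)

# Sketch (assumption): a jittered point layer driven by arguments that mirror
# point_alpha, jitter_width and jitter_height. Not margot's actual code.
jitter_layer_sketch <- function(point_alpha = 0.5, jitter_width = 0.3,
                                jitter_height = 0.3) {
  geom_jitter(width = jitter_width, height = jitter_height, alpha = point_alpha)
}

# Toy data (illustrative only): two 1-7 scale variables and a predicted action
set.seed(2024)
toy <- data.frame(
  x1 = sample(1:7, 100, replace = TRUE),
  x2 = sample(1:7, 100, replace = TRUE),
  action = factor(sample(c("control", "treat"), 100, replace = TRUE))
)
p <- ggplot(toy, aes(x = x1, y = x2, colour = action)) + jitter_layer_sketch()
```

Jitter matters most for coarse, integer-valued scales, where many respondents share identical values and would otherwise be drawn on top of one another.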

split_line_color

Color of the split lines. Default is "darkgray".

split_line_alpha

Numeric value between 0 and 1 for split line transparency. Default is 0.7.

split_line_type

Line type for split lines. Default is "dashed".

split_line_linewidth

Numeric value for the thickness of split lines. Default is 0.5.

split_label_size

Numeric value for the size of split value labels. Default is 10.

split_label_color

Color of split value labels. Default is "darkgray".
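A sketch of how the split line and split label arguments could style a vertical split line with its value label. `split_layers_sketch()` and its `label_y` argument are hypothetical. The label size is treated here as points and converted with ggplot2's exported `.pt` constant; whether margot interprets `split_label_size` as points or as ggplot2's millimetre-based text size is an assumption. The `linewidth` argument requires ggplot2 3.4.0 or later.

```r
library(ggplot2)

# Sketch (assumption): a split line plus its value label, styled with arguments
# that mirror split_line_* and split_label_*. Not margot's actual code.
split_layers_sketch <- function(split_value, label_y,
                                split_line_color = "darkgray",
                                split_line_alpha = 0.7,
                                split_line_type = "dashed",
                                split_line_linewidth = 0.5,
                                split_label_size = 10,
                                split_label_color = "darkgray") {
  list(
    geom_vline(xintercept = split_value, colour = split_line_color,
               alpha = split_line_alpha, linetype = split_line_type,
               linewidth = split_line_linewidth),
    # Convert the label size from points to ggplot2 text units (mm)
    annotate("text", x = split_value, y = label_y,
             label = format(split_value, digits = 3),
             size = split_label_size / .pt, colour = split_label_color,
             hjust = -0.1)
  )
}

base <- ggplot(data.frame(x = 1:7, y = 1:7), aes(x, y)) + geom_point()
p <- base + split_layers_sketch(4.5, label_y = 7)
```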

custom_action_names

Optional character vector of custom names for the actions. Its length must equal the number of actions in the policy tree.
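A minimal sketch of that length check, assuming predicted actions are coded as integers 1 to K (as `policytree`'s `predict()` returns action identifiers) and that K is known. `relabel_actions_sketch()` is a hypothetical helper, not part of margot.

```r
# Hypothetical helper: validate custom action names, then relabel integer
# action ids 1..n_actions as a factor with those names.
relabel_actions_sketch <- function(predicted_actions, n_actions,
                                   custom_action_names = NULL) {
  if (is.null(custom_action_names)) {
    return(factor(predicted_actions, levels = seq_len(n_actions)))
  }
  if (length(custom_action_names) != n_actions) {
    stop(sprintf("custom_action_names has %d names, but the policy tree has %d actions.",
                 length(custom_action_names), n_actions), call. = FALSE)
  }
  factor(predicted_actions, levels = seq_len(n_actions),
         labels = custom_action_names)
}

relabel_actions_sketch(c(1, 2, 2, 1), 2, c("control", "treat"))
```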

legend_position

String specifying the position of the legend. Can be "top", "bottom", "left", or "right". Default is "bottom".
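Base R's `match.arg()` offers a simple way to check a string against the documented choices before passing it to ggplot2. The sketch assumes this pattern; `legend_theme_sketch()` is hypothetical. Note that `match.arg()` also accepts unambiguous partial matches such as "b"; whether margot behaves the same way is not stated here.

```r
library(ggplot2)

# Sketch (assumption): validate legend_position, then convert it to a theme setting.
legend_theme_sketch <- function(legend_position = "bottom") {
  legend_position <- match.arg(legend_position,
                               c("top", "bottom", "left", "right"))
  theme(legend.position = legend_position)
}

th <- legend_theme_sketch("right")
```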

plot_selection

String specifying which plots to display: "both", "p1" (first plot only), or "p2" (second plot only). Default is "both".
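A sketch of dispatching on `plot_selection`. This page does not say how the two panels are combined for "both", so the `combine` argument defaults to `patchwork::wrap_plots()` as one possible choice; `select_plots_sketch()` and `combine` are hypothetical. Because R evaluates default arguments lazily, patchwork is only needed when "both" is selected.

```r
# Sketch (assumption): return one panel or a combination of both.
select_plots_sketch <- function(p1, p2, plot_selection = "both",
                                combine = function(a, b) patchwork::wrap_plots(a, b)) {
  plot_selection <- match.arg(plot_selection, c("both", "p1", "p2"))
  switch(plot_selection,
         both = combine(p1, p2),
         p1 = p1,
         p2 = p2)
}
```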

remove_tx_prefix

Logical value indicating whether to remove the "tx_" prefix from labels. Default is TRUE.

remove_z_suffix

Logical value indicating whether to remove the "_z" suffix from labels. Default is TRUE.

use_title_case

Logical value indicating whether to convert labels to title case. Default is TRUE.

remove_underscores

Logical value indicating whether to remove underscores from labels. Default is TRUE.
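A sketch of the four label transformations applied in one plausible order. `transform_label_sketch()` is hypothetical. Replacing underscores with spaces and using a simple "capitalize the first letter of each word" rule are assumptions; margot's own order and title-case rule may differ (for example, it could use `tools::toTitleCase()`).

```r
# Hypothetical helper (not part of margot): apply the documented label transformations.
transform_label_sketch <- function(label,
                                   remove_tx_prefix = TRUE,
                                   remove_z_suffix = TRUE,
                                   use_title_case = TRUE,
                                   remove_underscores = TRUE) {
  if (remove_tx_prefix) label <- sub("^tx_", "", label)
  if (remove_z_suffix) label <- sub("_z$", "", label)
  # Assumption: underscores become spaces rather than being deleted outright
  if (remove_underscores) label <- gsub("_", " ", label, fixed = TRUE)
  if (use_title_case) {
    # Simple rule: uppercase the first letter of each word (PCRE \U in the replacement)
    label <- gsub("\\b([a-z])", "\\U\\1", label, perl = TRUE)
  }
  label
}

transform_label_sketch("tx_belong_z")            # "Belong"
transform_label_sketch("t2_climate_chg_real_z")  # "T2 Climate Chg Real"
```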

label_mapping

Optional named list for custom label mappings. Keys should be original variable names (with or without "model_" prefix), and values should be the desired display labels. Default is NULL.
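A sketch of a lookup that accepts mapping keys with or without the "model_" prefix and falls back to another labelling rule when no key matches. `lookup_label_sketch()` and its `fallback` argument are hypothetical, and giving the mapping precedence over the automatic transformations is an assumption.

```r
# Hypothetical helper (not part of margot): resolve a display label from label_mapping.
lookup_label_sketch <- function(var_name, label_mapping = NULL, fallback = identity) {
  bare <- sub("^model_", "", var_name)
  # Try the name as given, without the prefix, and with the prefix
  candidates <- unique(c(var_name, bare, paste0("model_", bare)))
  if (!is.null(label_mapping)) {
    hit <- candidates[candidates %in% names(label_mapping)]
    if (length(hit) > 0) return(label_mapping[[hit[1]]])
  }
  fallback(var_name)
}

label_mapping <- list("t2_env_not_env_efficacy_z" = "Deny Personal Environmental Efficacy")
lookup_label_sketch("model_t2_env_not_env_efficacy_z", label_mapping)
# "Deny Personal Environmental Efficacy"
lookup_label_sketch("t2_belong_z", label_mapping)
# "t2_belong_z"
```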

Value

A ggplot object containing the specified plot(s) of the policy tree results.

Examples

if (FALSE) { # \dontrun{
# Default (both plots, legend at bottom)
plot <- margot_plot_policy_tree(mc_test, "model_t2_belong_z")

# Only the first plot (p1)
plot <- margot_plot_policy_tree(mc_test, "model_t2_belong_z", plot_selection = "p1")

# Both plots with legend on the right
plot <- margot_plot_policy_tree(mc_test, "model_t2_belong_z", legend_position = "right")

# Custom color scale (one color per action; requires ggplot2)
custom_scale <- ggplot2::scale_colour_manual(values = c("red", "green", "blue"))
plot <- margot_plot_policy_tree(mc_test, "model_t2_belong_z", color_scale = custom_scale)

# Customize label transformations
plot <- margot_plot_policy_tree(mc_test, "model_t2_belong_z",
                                remove_tx_prefix = FALSE,
                                remove_z_suffix = FALSE,
                                use_title_case = FALSE,
                                remove_underscores = FALSE)

# Use custom label mapping and original_df for unstandardized values
label_mapping <- list(
  "t2_env_not_env_efficacy_z" = "Deny Personal Environmental Efficacy",
  "t2_env_not_climate_chg_real_z" = "Deny Climate Change Real"
)
plot <- margot_plot_policy_tree(mc_test, "model_t2_env_not_climate_chg_concern_z",
                                label_mapping = label_mapping,
                                original_df = original_df)
} # }