AE 13: Predicting legislative policy attention

Suggested answers

Application exercise
Answers
Modified: October 18, 2024

Load packages

# load packages
library(tidyverse)
library(tidymodels)
library(arrow)
library(textdata)
library(textrecipes)
library(here)
library(themis)
library(tictoc)
library(vip)
library(tidytext)
library(hardhat)
library(glmnet)

# use parallel processing
library(future)
plan(multisession)

# preferred theme
theme_set(theme_minimal(base_size = 12, base_family = "Atkinson Hyperlegible"))

Import data

leg <- read_parquet(file = "data/legislation.parquet")

Split data for model fitting

# set a seed for reproducibility
set.seed(852)

# sample only 20000 observations
leg_samp <- slice_sample(.data = leg, n = 2e04)

# split into training and testing sets
leg_split <- initial_split(data = leg_samp, strata = policy_lab, prop = 0.9)

leg_train <- training(x = leg_split)
leg_test <- testing(x = leg_split)

# create 5-fold cross-validation sets
leg_folds <- vfold_cv(data = leg_train, v = 5)

Null model

Demonstration: Estimate a null model to serve as a baseline for evaluating performance. How does it perform?

# estimate a null model
null_spec <- null_model() |>
  set_mode("classification") |>
  set_engine("parsnip")

null_rs <- workflow() |>
  add_formula(policy_lab ~ id) |>
  add_model(null_spec) |>
  fit_resamples(
    resamples = leg_folds
  )

collect_metrics(null_rs)
# A tibble: 3 × 6
  .metric     .estimator  mean     n  std_err .config             
  <chr>       <chr>      <dbl> <int>    <dbl> <chr>               
1 accuracy    multiclass 0.129     5 0.00306  Preprocessor1_Model1
2 brier_class multiclass 0.467     5 0.000162 Preprocessor1_Model1
3 roc_auc     hand_till  0.5       5 0        Preprocessor1_Model1

Add response here. Poorly. Baseline accuracy is quite low and the Brier score is high. This is due to the imbalanced nature of the policy labels and the fact that there are 20 distinct labels. This is a hard problem to solve.
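
The imbalance is easy to confirm from the training data; a quick illustrative check counts how many bills carry each label:

# distribution of policy labels in the training set
leg_train |>
  count(policy_lab, sort = TRUE)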

Lasso regression model

Estimate a simple lasso regression model using a bag-of-words representation of the legislative descriptions.
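
As a quick reminder of what a bag-of-words representation is, here is a minimal sketch using two made-up descriptions (not drawn from the data): each document is reduced to counts of its tokens, discarding word order.

# toy example: two invented descriptions reduced to token counts
tibble(
  doc = c(1, 2),
  description = c(
    "a bill to amend the internal revenue code",
    "a bill to authorize appropriations for national defense"
  )
) |>
  unnest_tokens(output = token, input = description) |>
  count(doc, token)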

Define recipe

Your turn: Define a feature engineering recipe that uses the description to predict policy_lab. The recipe should:

  • Tokenize the description column.
  • Remove stopwords.
  • Keep only the 500 most frequent tokens in the corpus.
  • Convert the tokens to a term frequency-inverse document frequency (TF-IDF) representation.
  • Downsample using policy_lab to balance the classes.
Note

Step functions for text features are found in the {textrecipes} package.

glmnet_rec <- recipe(policy_lab ~ description, data = leg_train) |>
  # tokenize and prep text
  step_tokenize(description) |>
  step_stopwords(description) |>
  step_tokenfilter(description, max_tokens = 500) |>
  step_tfidf(description) |>
  step_downsample(policy_lab)
glmnet_rec
glmnet_rec |>
  prep() |>
  bake(new_data = NULL)
# A tibble: 4,960 × 501
   policy_lab     tfidf_description_1 tfidf_description_10 tfidf_description_18
   <fct>                        <dbl>                <dbl>                <dbl>
 1 Macroeconomics                   0                    0                    0
 2 Macroeconomics                   0                    0                    0
 3 Macroeconomics                   0                    0                    0
 4 Macroeconomics                   0                    0                    0
 5 Macroeconomics                   0                    0                    0
 6 Macroeconomics                   0                    0                    0
 7 Macroeconomics                   0                    0                    0
 8 Macroeconomics                   0                    0                    0
 9 Macroeconomics                   0                    0                    0
10 Macroeconomics                   0                    0                    0
# ℹ 4,950 more rows
# ℹ 497 more variables: tfidf_description_1930 <dbl>,
#   tfidf_description_1934 <dbl>, tfidf_description_1938 <dbl>,
#   tfidf_description_1949 <dbl>, tfidf_description_1954 <dbl>,
#   tfidf_description_1958 <dbl>, tfidf_description_1965 <dbl>,
#   tfidf_description_1970 <dbl>, tfidf_description_1974 <dbl>,
#   tfidf_description_1986 <dbl>, tfidf_description_2 <dbl>, …

Model specification

Demonstration: Specify a lasso regression model tuned over the penalty parameter.

# penalized multinomial logistic regression, tune over penalty - keep mixture = 1 (lasso)
glmnet_spec <- multinom_reg(penalty = tune(), mixture = 1) |>
  set_mode("classification") |>
  set_engine("glmnet")

Define workflow

Demonstration: Create the workflow for the model.

Sparse encoding

Recall that many of the cells in the prepared data frame contain 0s (e.g. a token does not appear in a given document). Regularized regression models fit with set_engine("glmnet") can be estimated much more efficiently if we pass the text features as a sparse matrix. We do this by specifying a non-default preprocessing blueprint with the {hardhat} package.

For more information, see this case study from Supervised Machine Learning for Text Analysis in R (SMLTAR).
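
To get a rough sense of how much is gained, a quick illustrative check (reusing glmnet_rec from above) computes the share of entries in the baked tf-idf features that are exactly zero:

# bake the bag-of-words recipe
glmnet_baked <- glmnet_rec |>
  prep() |>
  bake(new_data = NULL)

# proportion of tf-idf entries that are exactly zero
mean(select(glmnet_baked, starts_with("tfidf_")) == 0)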

# use sparse blueprint for more efficient model estimation
library(hardhat)
sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")

# define workflow
glmnet_wf <- workflow() |>
  add_recipe(glmnet_rec, blueprint = sparse_bp) |>
  add_model(glmnet_spec)

Tune the model

Demonstration: Create a tuning grid for the penalty parameter and tune the model.

# define tuning grid
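# note: penalty() is tuned on the log10 scale, so range = c(-5, 0) corresponds to
# penalty values from 1e-5 to 1, with 30 levels spaced evenly on the log scale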
penalty_grid <- grid_regular(penalty(range = c(-5, 0)), levels = 30)
# tune the model
set.seed(123)

tic()
glmnet_tune_rs <- tune_grid(
  object = glmnet_wf,
  resamples = leg_folds,
  grid = penalty_grid,
  control = control_grid(save_pred = TRUE, save_workflow = TRUE)
)
toc()
17.853 sec elapsed

Your turn: Examine the performance of the model, both overall and using the confusion matrix. Which misclassifications are most common? Why might the model have a hard time discriminating between these policy labels?

# view average metrics
autoplot(glmnet_tune_rs)

# identify best models based on assessment set
show_best(x = glmnet_tune_rs, metric = "roc_auc")
# A tibble: 5 × 7
  penalty .metric .estimator  mean     n std_err .config              
    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1 0.00259 roc_auc hand_till  0.911     5 0.00211 Preprocessor1_Model15
2 0.00386 roc_auc hand_till  0.911     5 0.00217 Preprocessor1_Model16
3 0.00174 roc_auc hand_till  0.910     5 0.00213 Preprocessor1_Model14
4 0.00574 roc_auc hand_till  0.907     5 0.00205 Preprocessor1_Model17
5 0.00117 roc_auc hand_till  0.906     5 0.00221 Preprocessor1_Model13
show_best(x = glmnet_tune_rs, metric = "accuracy")
# A tibble: 5 × 7
  penalty .metric  .estimator  mean     n std_err .config              
    <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
1 0.00386 accuracy multiclass 0.603     5 0.00529 Preprocessor1_Model16
2 0.00259 accuracy multiclass 0.601     5 0.00479 Preprocessor1_Model15
3 0.00574 accuracy multiclass 0.596     5 0.00373 Preprocessor1_Model17
4 0.00174 accuracy multiclass 0.594     5 0.00480 Preprocessor1_Model14
5 0.00853 accuracy multiclass 0.589     5 0.00311 Preprocessor1_Model18
# confusion matrix for best penalty value
conf_mat_resampled(
  x = glmnet_tune_rs,
  parameters = select_best(glmnet_tune_rs, metric = "roc_auc"),
  tidy = FALSE
) |>
  autoplot(type = "heatmap")

# most frequent misclassifications (average)
conf_mat_resampled(
  x = glmnet_tune_rs,
  parameters = select_best(glmnet_tune_rs, metric = "roc_auc")
) |>
  # filter out correct predictions and sort by frequency
  filter(Prediction != Truth) |>
  arrange(-Freq)
# A tibble: 380 × 3
   Prediction                                     Truth                     Freq
   <fct>                                          <fct>                    <dbl>
 1 Civil rights, minority issues, civil liberties Government operations     25.4
 2 Macroeconomics                                 Government operations     23.6
 3 Macroeconomics                                 Banking, finance, and d…  20.8
 4 International affairs and foreign aid          Government operations     19.4
 5 Labor and employment                           Government operations     19.2
 6 Defense                                        Government operations     18.8
 7 Public lands and water management              Government operations     18.2
 8 Environment                                    Public lands and water …  17.8
 9 Law, crime, family issues                      Government operations     17.8
10 International affairs and foreign aid          Defense                   16.6
# ℹ 370 more rows

Add response here. Overall the model has a high ROC AUC but still a fairly low accuracy. This is likely due to the number of possible policy labels and the complexity of the relationships between the descriptions and the labels. Many of the common misclassifications feel like reasonable mistakes even if a human being were coding the documents. For example, International affairs and Defense probably overlap in terms of common tokens and subject matter.

Additionally, the most common misclassifications are for bills that are actually Government operations. Since it is the most common type of legislation, there are simply more opportunities to mislabel those bills.

Variable importance

Your turn: Calculate which tokens are used by the model to predict each of the policy labels. Visualize the top 10 tokens for each policy label. Do these tokens make sense as useful predictors?

# fit the best model
glmnet_best <- glmnet_tune_rs |>
  fit_best()
# multiclass classification - need to manually extract coefficients for each
# outcome class
lasso_vip <- coef(extract_fit_engine(glmnet_best),
  s = select_best(x = glmnet_tune_rs, metric = "roc_auc")$penalty
) |>
  # coef() returns one sparse matrix per class; grab its single column of coefficients
  map(\(x) x[, 1L]) |>
  # convert to a data frame and extract the relevant pieces of information
  enframe() |>
  unnest_longer(value) |>
  rename(
    token = value_id,
    importance = value,
    class = name
  ) |>
  # ignore the intercept
  filter(token != "(Intercept)") |>
  # clean up data to focus purely on magnitude of coefficients
  mutate(
    sign = ifelse(sign(importance) == 1, "POS", "NEG"),
    importance = abs(importance),
    token = str_remove_all(token, "tfidf_description_")
  ) |>
  # remove anything with importance of 0
  filter(importance != 0)

# visualize the top 10 tokens for each class
lasso_vip |>
  # keep the top 10 coefficients for each
  filter(sign == "POS") |>
  slice_max(n = 10, order_by = importance, by = class) |>
  # change order of token levels for plotting
  mutate(token = reorder_within(token, by = importance, within = class)) |>
  ggplot(mapping = aes(
    x = importance,
    y = token
  )) +
  geom_col() +
  scale_y_reordered() +
  facet_wrap(
    facets = vars(class),
    scales = "free",
    ncol = 4,
    labeller = labeller(class = label_wrap_gen(20))
  ) +
  labs(
    title = "Most important tokens for each policy class",
    x = "Importance",
    y = NULL
  )

Add response here. A lot of these are common words or terms one would associate with the substantive policy area. However, many are numbers or other seemingly generic terms. This is where domain expertise and knowing your data are quite helpful: many of these numbers are likely part of references to prior legislative acts or sections of the U.S. Code. Allowing for longer \(n\)-grams might allow us to use them in a semantically meaningful way.

\(n\)-grams

This model will be similar to the first one, but we will now include \(n\)-grams to capture multi-word phrases in addition to individual tokens.
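
As a small sketch of what this buys us (using a made-up phrase, and ignoring the stopword removal the recipe performs first), the {tokenizers} package can show the 1- through 4-grams that a single description generates:

# 1- through 4-grams from one invented description
library(tokenizers)
tokenize_ngrams(
  "authorize appropriations for national defense programs",
  n = 4,
  n_min = 1
)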

Define recipe

Your turn: Modify your previous feature engineering recipe to include \(n\)-grams, for \(n \in \{1, 2, 3, 4\}\).

ngram_rec <- recipe(policy_lab ~ description, data = leg_train) |>
  # tokenize and prep text
  step_tokenize(description) |>
  step_stopwords(description) |>
  step_ngram(description, num_tokens = 4L, min_num_tokens = 1L) |>
  step_tokenfilter(description, max_tokens = 1000) |>
  step_tfidf(description) |>
  step_downsample(policy_lab)
ngram_rec
ngram_rec |>
  prep() |>
  bake(new_data = NULL)
# A tibble: 4,960 × 1,001
   policy_lab    tfidf_description_1 tfidf_description_10 tfidf_description_10…¹
   <fct>                       <dbl>                <dbl>                  <dbl>
 1 Macroeconomi…                   0                    0                      0
 2 Macroeconomi…                   0                    0                      0
 3 Macroeconomi…                   0                    0                      0
 4 Macroeconomi…                   0                    0                      0
 5 Macroeconomi…                   0                    0                      0
 6 Macroeconomi…                   0                    0                      0
 7 Macroeconomi…                   0                    0                      0
 8 Macroeconomi…                   0                    0                      0
 9 Macroeconomi…                   0                    0                      0
10 Macroeconomi…                   0                    0                      0
# ℹ 4,950 more rows
# ℹ abbreviated name: ¹​tfidf_description_10_united
# ℹ 997 more variables: tfidf_description_10_united_states <dbl>,
#   tfidf_description_10_united_states_code <dbl>, tfidf_description_15 <dbl>,
#   tfidf_description_18 <dbl>, tfidf_description_18_united <dbl>,
#   tfidf_description_18_united_states <dbl>,
#   tfidf_description_18_united_states_code <dbl>, …

Define workflow

Demonstration: Create the workflow for the model, reusing the previous model specification and a sparse blueprint.

# define workflow
ngram_wf <- workflow() |>
  add_recipe(ngram_rec, blueprint = sparse_bp) |>
  add_model(glmnet_spec)

Tune the model

Demonstration: Tune the model.

# tune the model
set.seed(123)

tic()
ngram_tune_rs <- tune_grid(
  object = ngram_wf,
  resamples = leg_folds,
  grid = penalty_grid,
  control = control_grid(save_pred = TRUE, save_workflow = TRUE)
)
toc()
20.829 sec elapsed

Your turn: Examine the performance of the model, both overall and using the confusion matrix. Which misclassifications are most common? Why might the model have a hard time discriminating between these policy labels?

# view average metrics
autoplot(ngram_tune_rs)

# identify best models based on assessment set
show_best(x = ngram_tune_rs, metric = "roc_auc")
# A tibble: 5 × 7
  penalty .metric .estimator  mean     n std_err .config              
    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1 0.00386 roc_auc hand_till  0.917     5 0.00230 Preprocessor1_Model16
2 0.00574 roc_auc hand_till  0.917     5 0.00244 Preprocessor1_Model17
3 0.00259 roc_auc hand_till  0.916     5 0.00217 Preprocessor1_Model15
4 0.00853 roc_auc hand_till  0.915     5 0.00222 Preprocessor1_Model18
5 0.00174 roc_auc hand_till  0.913     5 0.00225 Preprocessor1_Model14
show_best(x = ngram_tune_rs, metric = "accuracy")
# A tibble: 5 × 7
  penalty .metric  .estimator  mean     n std_err .config              
    <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
1 0.00386 accuracy multiclass 0.627     5 0.00498 Preprocessor1_Model16
2 0.00259 accuracy multiclass 0.626     5 0.00500 Preprocessor1_Model15
3 0.00574 accuracy multiclass 0.624     5 0.00467 Preprocessor1_Model17
4 0.00174 accuracy multiclass 0.618     5 0.00415 Preprocessor1_Model14
5 0.00853 accuracy multiclass 0.618     5 0.00424 Preprocessor1_Model18
# confusion matrix for best penalty value
conf_mat_resampled(
  x = ngram_tune_rs,
  parameters = select_best(ngram_tune_rs, metric = "roc_auc"),
  tidy = FALSE
) |>
  autoplot(type = "heatmap")

# most frequent misclassifications (average)
conf_mat_resampled(
  x = ngram_tune_rs,
  parameters = select_best(ngram_tune_rs, metric = "roc_auc")
) |>
  # filter out correct predictions and sort by frequency
  filter(Prediction != Truth) |>
  arrange(-Freq)
# A tibble: 380 × 3
   Prediction                                     Truth                     Freq
   <fct>                                          <fct>                    <dbl>
 1 Macroeconomics                                 Government operations     24.4
 2 Macroeconomics                                 Banking, finance, and d…  24  
 3 Civil rights, minority issues, civil liberties Government operations     22.8
 4 Law, crime, family issues                      Government operations     19  
 5 Environment                                    Public lands and water …  17  
 6 International affairs and foreign aid          Defense                   16.2
 7 Public lands and water management              Government operations     15  
 8 International affairs and foreign aid          Government operations     14.8
 9 Defense                                        Government operations     14.6
10 Labor and employment                           Government operations     14.4
# ℹ 370 more rows

Add response here. Only a modest improvement over the unigram model: accuracy increases by a couple of percentage points, ROC AUC barely moves, and it still appears to make many of the same types of mistakes.

Variable importance

Your turn: Calculate which tokens are used by the model to predict each of the policy labels. Visualize the top 10 tokens for each policy label. Do these tokens make sense as useful predictors?

# fit the best model
ngram_best <- ngram_tune_rs |>
  fit_best()
# multiclass classification - need to manually extract coefficients for each
# outcome class
ngram_vip <- coef(extract_fit_engine(ngram_best),
  s = select_best(x = ngram_tune_rs, metric = "roc_auc")$penalty
) |>
  # coef() returns one sparse matrix per class; grab its single column of coefficients
  map(\(x) x[, 1L]) |>
  # convert to a data frame and extract the relevant pieces of information
  enframe() |>
  unnest_longer(value) |>
  rename(
    token = value_id,
    importance = value,
    class = name
  ) |>
  # ignore the intercept
  filter(token != "(Intercept)") |>
  # clean up data to focus purely on magnitude of coefficients
  mutate(
    sign = ifelse(sign(importance) == 1, "POS", "NEG"),
    importance = abs(importance),
    token = str_remove_all(token, "tfidf_description_")
  ) |>
  # remove anything with importance of 0
  filter(importance != 0)

# visualize the top 10 tokens for each class
ngram_vip |>
  # keep the top 10 coefficients for each
  filter(sign == "POS") |>
  slice_max(n = 10, order_by = importance, by = class) |>
  # change order of token levels for plotting
  mutate(token = reorder_within(token, by = importance, within = class)) |>
  ggplot(mapping = aes(
    x = importance,
    y = token
  )) +
  geom_col() +
  scale_y_reordered() +
  facet_wrap(
    facets = vars(class),
    scales = "free",
    ncol = 4,
    labeller = labeller(class = label_wrap_gen(20))
  ) +
  labs(
    title = "Most important tokens for each policy class",
    x = "Importance",
    y = NULL
  )

Add response here. These are an improvement over the unigram model. Now we see extended phrases such as “Title 38”, “28 United States” (most likely referring to “28 United States Code”), and “labor standards”. These phrases are more likely to be unique to a particular policy area and thus more useful for prediction.

Word embeddings

Finally, we will fit a lasso regression model that represents each description using word embeddings from the pre-trained GloVe model.

Define recipe

Your turn: Import the pre-trained GloVe 6B embeddings with 100 dimensions. Define a recipe that uses these embeddings to predict policy_lab.

##### uncomment if you are running RStudio on your personal computer
# load the 100-dimensional GloVe embeddings
glove6b <- embedding_glove6b(dimensions = 100)

# ##### uncomment if you are running RStudio on the Workbench
# # hacky way to make it work on RStudio Workbench
# glove6b <- read_delim(
#     file = "/rstudio-files/glove6b/glove.6B.100d.txt",
#     delim = " ",
#     quote = "",
#     col_names = c(
#       "token",
#       paste0("d", seq_len(100))
#     ),
#     col_types = paste0(
#       c(
#         "c",
#         rep("d", 100)
#       ),
#       collapse = ""
#     )
#   )
# initialize recipe
embed_rec <- recipe(policy_lab ~ description, data = leg_train) |>
  # tokenize
  step_tokenize(description) |>
  # convert to word embeddings
  step_word_embeddings(description, embeddings = glove6b) |>
  # remove zero-variance predictors, then normalize the columns
  step_zv(all_predictors()) |>
  step_normalize(all_predictors()) |>
  # downsample to keep same number of rows for each policy focus
  step_downsample(policy_lab)
embed_rec |>
  prep() |>
  bake(new_data = NULL)
# A tibble: 4,960 × 101
   policy_lab     wordembed_description_d1 wordembed_description_d2
   <fct>                             <dbl>                    <dbl>
 1 Macroeconomics                    0.727                   -0.458
 2 Macroeconomics                   -0.497                   -0.685
 3 Macroeconomics                   -0.409                   -0.832
 4 Macroeconomics                   -0.461                   -0.325
 5 Macroeconomics                    0.605                    1.13 
 6 Macroeconomics                   -0.486                   -0.882
 7 Macroeconomics                   -1.47                     0.234
 8 Macroeconomics                    0.679                   -0.283
 9 Macroeconomics                   -0.213                   -0.702
10 Macroeconomics                   -0.381                   -0.185
# ℹ 4,950 more rows
# ℹ 98 more variables: wordembed_description_d3 <dbl>,
#   wordembed_description_d4 <dbl>, wordembed_description_d5 <dbl>,
#   wordembed_description_d6 <dbl>, wordembed_description_d7 <dbl>,
#   wordembed_description_d8 <dbl>, wordembed_description_d9 <dbl>,
#   wordembed_description_d10 <dbl>, wordembed_description_d11 <dbl>,
#   wordembed_description_d12 <dbl>, wordembed_description_d13 <dbl>, …
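
For intuition, the transformation performed by step_word_embeddings() can be sketched by hand: each token in a description is looked up in glove6b and the per-token vectors are aggregated (summed, by default in {textrecipes}) into a single document vector before the later normalization step. A minimal sketch with an invented description and only the first five dimensions:

# hand-rolled version of the embedding lookup and aggregation (illustration only)
tibble(description = "tax credit for renewable energy") |>
  unnest_tokens(output = token, input = description) |>
  inner_join(glove6b, by = "token") |>
  summarize(across(d1:d5, sum))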

Tune the penalized regression model

Demonstration: Tune the model.

Note

Since this model uses word embeddings, we will not use the sparse blueprint for the recipe. The embeddings are already in a dense format.
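
As an illustrative check (not part of the exercise), we can confirm the contrast with the tf-idf features by computing the share of exactly-zero entries among the baked embedding columns; it should be tiny, whereas most tf-idf entries are zero.

# bake the embedding recipe
embed_baked <- embed_rec |>
  prep() |>
  bake(new_data = NULL)

# proportion of embedding entries that are exactly zero
mean(select(embed_baked, starts_with("wordembed_")) == 0)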

# define workflow
embed_wf <- workflow() |>
  add_recipe(embed_rec) |>
  add_model(glmnet_spec)

tic()
embed_tune_rs <- tune_grid(
  object = embed_wf,
  resamples = leg_folds,
  grid = penalty_grid,
  control = control_grid(save_pred = TRUE, save_workflow = TRUE)
)
toc()
29.58 sec elapsed

Your turn: Examine the performance of the model, both overall and using the confusion matrix. Which misclassifications are most common? Why might the model have a hard time discriminating between these policy labels?

# view average metrics
autoplot(embed_tune_rs)

# identify best models based on assessment set
show_best(x = embed_tune_rs, metric = "roc_auc")
# A tibble: 5 × 7
   penalty .metric .estimator  mean     n  std_err .config              
     <dbl> <chr>   <chr>      <dbl> <int>    <dbl> <chr>                
1 0.000788 roc_auc hand_till  0.943     5 0.000581 Preprocessor1_Model12
2 0.000530 roc_auc hand_till  0.943     5 0.000371 Preprocessor1_Model11
3 0.00117  roc_auc hand_till  0.942     5 0.000820 Preprocessor1_Model13
4 0.000356 roc_auc hand_till  0.942     5 0.000298 Preprocessor1_Model10
5 0.000240 roc_auc hand_till  0.941     5 0.000274 Preprocessor1_Model09
show_best(x = embed_tune_rs, metric = "accuracy")
# A tibble: 5 × 7
   penalty .metric  .estimator  mean     n std_err .config              
     <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
1 0.000530 accuracy multiclass 0.608     5 0.00314 Preprocessor1_Model11
2 0.000356 accuracy multiclass 0.608     5 0.00342 Preprocessor1_Model10
3 0.000240 accuracy multiclass 0.608     5 0.00345 Preprocessor1_Model09
4 0.000788 accuracy multiclass 0.605     5 0.00195 Preprocessor1_Model12
5 0.000161 accuracy multiclass 0.603     5 0.00308 Preprocessor1_Model08
# confusion matrix for best penalty value
conf_mat_resampled(
  x = embed_tune_rs,
  parameters = select_best(embed_tune_rs, metric = "roc_auc"),
  tidy = FALSE
) |>
  autoplot(type = "heatmap")

# most frequent misclassifications (average)
conf_mat_resampled(
  x = embed_tune_rs,
  parameters = select_best(embed_tune_rs, metric = "roc_auc")
) |>
  # filter out correct predictions and sort by frequency
  filter(Prediction != Truth) |>
  arrange(-Freq)
# A tibble: 380 × 3
   Prediction                                     Truth                     Freq
   <fct>                                          <fct>                    <dbl>
 1 Labor and employment                           Government operations     36.4
 2 Civil rights, minority issues, civil liberties Government operations     36  
 3 Defense                                        Government operations     29.2
 4 Environment                                    Public lands and water …  28.4
 5 Law, crime, family issues                      Government operations     24  
 6 Macroeconomics                                 Banking, finance, and d…  23.6
 7 International affairs and foreign aid          Defense                   21.6
 8 Macroeconomics                                 Government operations     21.4
 9 Banking, finance, and domestic commerce        Government operations     20.8
10 Civil rights, minority issues, civil liberties Law, crime, family issu…  20.4
# ℹ 370 more rows

Add response here. This model took a bit longer to train even though it has far fewer features, most likely because the embedding features are dense, so we could not take advantage of a sparse matrix representation. ROC AUC increases by roughly \(0.03\) compared to the first two models, although accuracy is no better than the \(n\)-gram model, and the confusion matrix shows similar types of mistakes.

Variable importance

Your turn: Calculate which dimensions are used by the model to predict each of the policy labels. Visualize the top 10 dimensions for each policy label. How useful is this analysis?

# fit the best model
embed_best <- embed_tune_rs |>
  fit_best()
# multiclass classification - need to manually extract coefficients for each
# outcome class
embed_vip <- coef(extract_fit_engine(embed_best),
  s = select_best(x = embed_tune_rs, metric = "roc_auc")$penalty
) |>
  # coef() returns one sparse matrix per class; grab its single column of coefficients
  map(\(x) x[, 1L]) |>
  # convert to a data frame and extract the relevant pieces of information
  enframe() |>
  unnest_longer(value) |>
  rename(
    token = value_id,
    importance = value,
    class = name
  ) |>
  # ignore the intercept
  filter(token != "(Intercept)") |>
  # clean up data to focus purely on magnitude of coefficients
  mutate(
    sign = ifelse(sign(importance) == 1, "POS", "NEG"),
    importance = abs(importance),
    token = str_remove_all(token, "wordembed_description_")
  ) |>
  # remove anything with importance of 0
  filter(importance != 0)

# visualize the top 10 tokens for each class
embed_vip |>
  # keep the top 10 coefficients for each
  filter(sign == "POS") |>
  slice_max(n = 10, order_by = importance, by = class) |>
  # change order of token levels for plotting
  mutate(token = reorder_within(token, by = importance, within = class)) |>
  ggplot(mapping = aes(
    x = importance,
    y = token
  )) +
  geom_col() +
  scale_y_reordered() +
  facet_wrap(
    facets = vars(class),
    scales = "free",
    ncol = 4,
    labeller = labeller(class = label_wrap_gen(20))
  ) +
  labs(
    title = "Most important dimensions for each policy class",
    x = "Importance",
    y = NULL
  )

Add response here. Unfortunately this form of analysis is now essentially useless. There is no intuitive interpretation of the individual word embedding dimensions since they have no inherent semantic meaning, and there are far too many of them to interpret even if they did. If we want to identify the most important or useful tokens, we have to use another method.
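
One illustrative workaround (a sketch, not the method used in this exercise): because each dimension comes from the GloVe table, we can at least inspect which tokens take the most extreme values on a dimension the model leans on, for example d1:

# tokens with the largest values on a single embedding dimension (here d1)
glove6b |>
  slice_max(order_by = d1, n = 10) |>
  select(token, d1)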

sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.4.1 (2024-06-14)
 os       macOS Sonoma 14.6.1
 system   aarch64, darwin20
 ui       X11
 language (EN)
 collate  en_US.UTF-8
 ctype    en_US.UTF-8
 tz       America/New_York
 date     2024-10-18
 pandoc   3.4 @ /usr/local/bin/ (via rmarkdown)

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version    date (UTC) lib source
 arrow        * 17.0.0     2024-09-18 [1] https://apache.r-universe.dev (R 4.4.1)
 assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.3.0)
 backports      1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
 bit            4.0.5      2022-11-15 [1] CRAN (R 4.3.0)
 bit64          4.0.5      2020-08-30 [1] CRAN (R 4.3.0)
 broom        * 1.0.6      2024-05-17 [1] CRAN (R 4.4.0)
 class          7.3-22     2023-05-03 [1] CRAN (R 4.4.0)
 cli            3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
 codetools      0.2-20     2024-03-31 [1] CRAN (R 4.4.1)
 colorspace     2.1-1      2024-07-26 [1] CRAN (R 4.4.0)
 data.table     1.15.4     2024-03-30 [1] CRAN (R 4.3.1)
 dials        * 1.3.0      2024-07-30 [1] CRAN (R 4.4.0)
 DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.1)
 digest         0.6.35     2024-03-11 [1] CRAN (R 4.3.1)
 doFuture       1.0.1      2023-12-20 [1] CRAN (R 4.3.1)
 dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.1)
 ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.0)
 evaluate       0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
 fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.1)
 farver         2.1.2      2024-05-13 [1] CRAN (R 4.3.3)
 fastmap        1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
 forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.3.0)
 foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
 fs             1.6.4      2024-04-25 [1] CRAN (R 4.4.0)
 furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
 future       * 1.33.2     2024-03-26 [1] CRAN (R 4.3.1)
 future.apply   1.11.2     2024-03-28 [1] CRAN (R 4.3.1)
 generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
 ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.3.1)
 glmnet       * 4.1-8      2023-08-22 [1] CRAN (R 4.3.0)
 globals        0.16.3     2024-03-08 [1] CRAN (R 4.3.1)
 glue           1.8.0      2024-09-30 [1] CRAN (R 4.4.1)
 gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
 GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
 gtable         0.3.5      2024-04-22 [1] CRAN (R 4.3.1)
 hardhat      * 1.4.0      2024-06-02 [1] CRAN (R 4.4.0)
 here         * 1.0.1      2020-12-13 [1] CRAN (R 4.3.0)
 hms            1.1.3      2023-03-21 [1] CRAN (R 4.3.0)
 htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.3.1)
 htmlwidgets    1.6.4      2023-12-06 [1] CRAN (R 4.3.1)
 infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.3.1)
 ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
 iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
 janeaustenr    1.0.0      2022-08-26 [1] CRAN (R 4.3.0)
 jsonlite       1.8.9      2024-09-20 [1] CRAN (R 4.4.1)
 knitr          1.47       2024-05-29 [1] CRAN (R 4.4.0)
 labeling       0.4.3      2023-08-29 [1] CRAN (R 4.3.0)
 lattice        0.22-6     2024-03-20 [1] CRAN (R 4.4.0)
 lava           1.8.0      2024-03-05 [1] CRAN (R 4.3.1)
 lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
 lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.1)
 listenv        0.9.1      2024-01-29 [1] CRAN (R 4.3.1)
 lubridate    * 1.9.3      2023-09-27 [1] CRAN (R 4.3.1)
 magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
 MASS           7.3-61     2024-06-13 [1] CRAN (R 4.4.0)
 Matrix       * 1.7-0      2024-03-22 [1] CRAN (R 4.4.0)
 modeldata    * 1.4.0      2024-06-19 [1] CRAN (R 4.4.0)
 munsell        0.5.1      2024-04-01 [1] CRAN (R 4.3.1)
 nnet           7.3-19     2023-05-03 [1] CRAN (R 4.4.0)
 parallelly     1.37.1     2024-02-29 [1] CRAN (R 4.3.1)
 parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.3.1)
 pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
 pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
 prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
 purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
 R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
 rappdirs       0.3.3      2021-01-31 [1] CRAN (R 4.3.0)
 Rcpp           1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
 readr        * 2.1.5      2024-01-10 [1] CRAN (R 4.3.1)
 recipes      * 1.0.10     2024-02-18 [1] CRAN (R 4.3.1)
 rlang          1.1.4      2024-06-04 [1] CRAN (R 4.3.3)
 rmarkdown      2.27       2024-05-17 [1] CRAN (R 4.4.0)
 ROSE           0.0-4      2021-06-14 [1] CRAN (R 4.3.0)
 rpart          4.1.23     2023-12-05 [1] CRAN (R 4.4.0)
 rprojroot      2.0.4      2023-11-05 [1] CRAN (R 4.3.1)
 rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.3.1)
 rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.3.1)
 scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.4.0)
 sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.3.0)
 shape          1.4.6.1    2024-02-23 [1] CRAN (R 4.3.1)
 SnowballC      0.7.1      2023-04-25 [1] CRAN (R 4.3.0)
 stopwords      2.3        2021-10-28 [1] CRAN (R 4.3.0)
 stringi        1.8.4      2024-05-06 [1] CRAN (R 4.3.1)
 stringr      * 1.5.1      2023-11-14 [1] CRAN (R 4.3.1)
 survival       3.7-0      2024-06-05 [1] CRAN (R 4.4.0)
 textdata     * 0.4.5      2024-05-28 [1] CRAN (R 4.4.0)
 textrecipes  * 1.0.6      2023-11-15 [1] CRAN (R 4.3.1)
 themis       * 1.0.2      2023-08-14 [1] CRAN (R 4.3.0)
 tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
 tictoc       * 1.2.1      2024-03-18 [1] CRAN (R 4.4.0)
 tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.3.1)
 tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.3.1)
 tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.3.1)
 tidytext     * 0.4.2      2024-04-10 [1] CRAN (R 4.4.0)
 tidyverse    * 2.0.0      2023-02-22 [1] CRAN (R 4.3.0)
 timechange     0.3.0      2024-01-18 [1] CRAN (R 4.3.1)
 timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.1)
 tokenizers     0.3.0      2022-12-22 [1] CRAN (R 4.3.0)
 tune         * 1.2.1      2024-04-18 [1] CRAN (R 4.3.1)
 tzdb           0.4.0      2023-05-12 [1] CRAN (R 4.3.0)
 utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.1)
 vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.1)
 vip          * 0.4.1      2023-08-21 [1] CRAN (R 4.3.0)
 withr          3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
 workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.3.1)
 workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.3.1)
 xfun           0.45       2024-06-16 [1] CRAN (R 4.4.0)
 yaml           2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
 yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.3.1)

 [1] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library

──────────────────────────────────────────────────────────────────────────────