AE 14: Predicting legislative policy attention using a dense neural network

Suggested answers

Application exercise · Answers

Modified: October 29, 2024

Load packages

library(tidyverse)
library(tidymodels)
library(keras3)
library(textrecipes)
library(textdata)
library(arrow)
library(tidytext)

set.seed(123)

# preferred theme
theme_set(theme_minimal(base_size = 12, base_family = "Atkinson Hyperlegible"))

Import data

leg <- read_parquet(file = "data/legislation.parquet")

Split data for model fitting

set.seed(123)

# split into training and testing sets
leg_split <- initial_split(data = leg, strata = policy_lab, prop = 0.75)

legislation_train <- training(x = leg_split)
legislation_test <- testing(x = leg_split)
1
Since we're using the entire dataset, there is no need to use cross-validation. Deep learning models are trained on large datasets, so there should not be significant variation between the training and validation sets. It's also generally too time-consuming to implement some form of resampling or cross-validation with deep learning models.

Typical length of legislative description

legislation_train |>
  unnest_tokens(output = word, input = description) |>
  anti_join(y = stop_words) |>
  count(id) |>
  ggplot(mapping = aes(x = n)) +
  geom_histogram(binwidth = 10, color = "white") +
  geom_rug() +
  labs(x = "Number of words per description",
       y = "Number of bills")
1
We need to know the typical length of the legislative descriptions since the neural network requires a fixed-length input (i.e. all descriptions must be padded or truncated to the same length). This plot shows the number of words in each description. We can see that most descriptions are fewer than 50 words long once we remove common stop words.
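As a supplementary check (not part of the original exercise), we can quantify this claim directly by reusing the same tokenization pipeline:

# median description length and share of descriptions within 50 tokens
legislation_train |>
  unnest_tokens(output = word, input = description) |>
  anti_join(y = stop_words) |>
  count(id) |>
  summarize(
    median_words = median(n),
    pct_within_50 = mean(n <= 50)
  )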

Simple dense neural network

Preprocessing

max_words <- 3e4
max_length <- 50

# generate basic recipe for sequential one-hot encoding
leg_rec <- recipe(policy_lab ~ description, data = legislation_train) |>
  step_tokenize(description) |>
  step_stopwords(description) |>
  step_stem(description) |>
  step_tokenfilter(description, max_tokens = max_words) |>
  step_sequence_onehot(description, sequence_length = max_length) |>
  # convert policy_lab to a zero-based integer encoding
  step_integer(policy_lab, zero_based = TRUE)
leg_rec
1
Maximum number of words to include in the dictionary. Note this is much larger than our previous attempts using shallow ML techniques.
2
Maximum length of each sequence. Each document needs to be an identical length based on the number of tokens. Too long and it will take longer to train, plus we have to fill in the gaps with zeros which are non-informative. Too short and we lose information.
3
Convert the text into a sequence of integers. This is a requirement for the embedding layer in the neural network. Each token is assigned a unique integer. The sequence length is fixed to max_length and any tokens beyond this are truncated. Any sequences shorter than max_length (50) tokens will be padded with zeros at the beginning of the sequence. 1 is reserved as a placeholder for “unknown” tokens. (Note this is not a problem in the current approach since we defined the entire dictionary, but it is used nonetheless as it becomes important later.) A toy example after these notes illustrates the padding and truncation behavior.
4
Convert the policy labels to integers. This is required for {keras3} to understand the outcome. We also start counting at 0 since Keras is really a Python library and Python indexes from 0.
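To make the sequence encoding concrete, here is a toy illustration (hypothetical documents, not the exercise data) of how step_sequence_onehot() pads and truncates:

# toy corpus: one short and one long "description"
toy <- tibble(text = c("energy tax credit",
                       "a bill to provide a tax credit for renewable energy production"))

recipe(~ text, data = toy) |>
  step_tokenize(text) |>
  step_sequence_onehot(text, sequence_length = 5) |>
  prep() |>
  bake(new_data = NULL)
# the short document is padded with zeros at the start of the sequence;
# the long document is cut down to 5 tokens
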
# prep and apply feature engineering recipe
leg_prep <- prep(leg_rec)

# what did we make?
leg_prep |> bake(new_data = slice(legislation_train, 1:5))

# bake to get outcome only
leg_train_outcome <- bake(leg_prep, new_data = NULL, starts_with("policy_lab")) 

# get weights for each policy topic
leg_train_weights <- count(leg_train_outcome, policy_lab) |>
  mutate(pct = n / sum(n)) |>
  select(-n) |>
  deframe() |>
  as.list()

# get outcome as vector of numeric integers for each topic
leg_train_outcome <- to_categorical(leg_train_outcome$policy_lab)
dim(leg_train_outcome)
head(leg_train_outcome)

# bake to get features only
leg_train <- bake(leg_prep, new_data = NULL, composition = "matrix", -starts_with("policy_lab"))
dim(leg_train)
head(leg_train)
1
Prepare the recipe for use. We aren’t using {tidymodels} to fit the model so we have to manually prep the recipe.
2
Calculate the weights for each policy topic. These are passed to Keras as class weights, an alternative to rebalancing the data with something like step_downsample(). We still use all the observations in the training set, but Keras multiplies each observation's contribution to the loss function (during training only) by the weight of its class. We calculate the proportion of each policy topic in the training set and structure it as a list object. (A quick peek at the resulting weights appears after the output below.)
3
Convert the outcome to a one-hot encoded matrix (array/tensor). This is required for the model to understand the outcome, since the final softmax layer produces one probability per class.
4
Bake the recipe to get the features only. We will pass them separately to the model. Note we structure it as a matrix (array/tensor) for the model to understand.
# A tibble: 5 × 51
  policy_lab seq1hot_description_1 seq1hot_description_2 seq1hot_description_3
       <int>                 <int>                 <int>                 <int>
1          0                     0                     0                     0
2         14                     0                     0                     0
3         14                     0                     0                     0
4         14                     0                     0                     0
5         14                     0                     0                     0
# ℹ 47 more variables: seq1hot_description_4 <int>,
#   seq1hot_description_5 <int>, seq1hot_description_6 <int>,
#   seq1hot_description_7 <int>, seq1hot_description_8 <int>,
#   seq1hot_description_9 <int>, seq1hot_description_10 <int>,
#   seq1hot_description_11 <int>, seq1hot_description_12 <int>,
#   seq1hot_description_13 <int>, seq1hot_description_14 <int>,
#   seq1hot_description_15 <int>, seq1hot_description_16 <int>, …
[1] 287989     20
     [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14]
[1,]    1    0    0    0    0    0    0    0    0     0     0     0     0     0
[2,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0
[3,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0
[4,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0
[5,]    0    0    0    0    0    0    0    0    0     0     0     0     0     0
[6,]    0    0    0    0    1    0    0    0    0     0     0     0     0     0
     [,15] [,16] [,17] [,18] [,19] [,20]
[1,]     0     0     0     0     0     0
[2,]     1     0     0     0     0     0
[3,]     1     0     0     0     0     0
[4,]     1     0     0     0     0     0
[5,]     1     0     0     0     0     0
[6,]     0     0     0     0     0     0
[1] 287989     50
     seq1hot_description_1 seq1hot_description_2 seq1hot_description_3
[1,]                     0                     0                     0
[2,]                     0                     0                     0
[3,]                     0                     0                     0
[4,]                     0                     0                     0
[5,]                     0                     0                     0
[6,]                     0                     0                     0
     seq1hot_description_4 seq1hot_description_5 seq1hot_description_6
[1,]                     0                     0                     0
[2,]                     0                     0                     0
[3,]                     0                     0                     0
[4,]                     0                     0                     0
[5,]                     0                     0                     0
[6,]                     0                     0                     0
     seq1hot_description_7 seq1hot_description_8 seq1hot_description_9
[1,]                     0                     0                     0
[2,]                     0                     0                     0
[3,]                     0                     0                     0
[4,]                     0                     0                     0
[5,]                     0                     0                     0
[6,]                     0                     0                     0
     seq1hot_description_10 seq1hot_description_11 seq1hot_description_12
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                      0                      0                      0
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_13 seq1hot_description_14 seq1hot_description_15
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                      0                      0                      0
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_16 seq1hot_description_17 seq1hot_description_18
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                      0                      0                      0
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_19 seq1hot_description_20 seq1hot_description_21
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                      0                      0                      0
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_22 seq1hot_description_23 seq1hot_description_24
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                      0                      0                      0
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_25 seq1hot_description_26 seq1hot_description_27
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                      0                      0                   5310
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_28 seq1hot_description_29 seq1hot_description_30
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                   6063                  14612                  19173
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_31 seq1hot_description_32 seq1hot_description_33
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                   4584                   1416                  23929
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_34 seq1hot_description_35 seq1hot_description_36
[1,]                      0                      0                      0
[2,]                      0                      0                      0
[3,]                  15024                   5883                  12853
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                      0                      0
     seq1hot_description_37 seq1hot_description_38 seq1hot_description_39
[1,]                      0                      0                      0
[2,]                   5310                   6063                  14612
[3,]                   9842                  27916                  19173
[4,]                      0                      0                      0
[5,]                      0                      0                      0
[6,]                      0                   5310                  25090
     seq1hot_description_40 seq1hot_description_41 seq1hot_description_42
[1,]                      0                      0                      0
[2,]                  19173                   4584                   1416
[3,]                   8567                  19363                   7585
[4,]                      0                  24773                  16633
[5,]                      0                      0                      0
[6,]                  25667                   4584                   1399
     seq1hot_description_43 seq1hot_description_44 seq1hot_description_45
[1,]                      0                      0                      0
[2,]                  24789                  23289                   8567
[3,]                  24789                   7585                  18049
[4,]                  11979                  29236                  25774
[5,]                  20004                   5924                  13421
[6,]                  24789                  23782                   7044
     seq1hot_description_46 seq1hot_description_47 seq1hot_description_48
[1,]                  25335                  17395                  17291
[2,]                  23782                  27916                  19173
[3,]                   4584                  26525                  25309
[4,]                  18278                  25384                  18278
[5,]                  21842                   8723                  29255
[6,]                  25628                  21369                   5684
     seq1hot_description_49 seq1hot_description_50
[1,]                  27794                  23782
[2,]                   7585                  24896
[3,]                  28183                  24896
[4,]                  18871                  23316
[5,]                  24377                  23311
[6,]                  23861                  11585
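
For reference, here is a quick peek (supplementary, not in the original exercise) at the class weights that will be handed to Keras:

# the weights are simply the training-set proportion of each policy topic
enframe(unlist(leg_train_weights), name = "policy_lab", value = "weight") |>
  arrange(desc(weight)) |>
  slice_head(n = 5)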

Simple flattened dense network

dense_model <- keras_model_sequential(input_shape = c(max_length)) |>
  layer_embedding(input_dim = max_words + 1,
                  output_dim = 12) |>
  layer_flatten() |>
  layer_dense(units = 64, activation = "relu") |>
  layer_dense(units = 20, activation = "softmax")

summary(dense_model)
1
Declare the model as a sequential model. This is the simplest type of neural network where each layer feeds into the next.
2
Add an embedding layer. This layer maps each integer-encoded token to a dense vector; it is essentially a lookup table of learned weights. The output dimension is the number of dimensions for the dense embedding. This is a hyperparameter that needs to be “tuned”.
3
Flatten the output of the embedding layer. This is required to feed the output into a dense layer.
4
Add a dense layer with 64 units and a ReLU activation function. The number of units is a tuning parameter.
5
Add a dense layer with 20 units and a softmax activation function. This is the output layer and the number of units is the number of policy topics.
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                      ┃ Output Shape             ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding (Embedding)             │ (None, 50, 12)           │       360,012 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ flatten (Flatten)                 │ (None, 600)              │             0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense (Dense)                     │ (None, 64)               │        38,464 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense)                   │ (None, 20)               │         1,300 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
 Total params: 399,776 (1.53 MB)
 Trainable params: 399,776 (1.53 MB)
 Non-trainable params: 0 (0.00 B)
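As a quick sanity check, the parameter counts in the summary can be reproduced by hand:

(max_words + 1) * 12            # embedding: one 12-dimensional vector per index = 360,012
(max_length * 12) * 64 + 64     # dense layer on the flattened 600-length vector = 38,464
64 * 20 + 20                    # softmax output layer = 1,300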
dense_model |>
  compile(
    optimizer = "adam",
    loss = "categorical_crossentropy",
    metrics = c("accuracy", "auc")
  )
1
Compile the model with the Adam optimizer. This is a popular optimizer for neural networks.
2
Use categorical cross-entropy as the loss function. This is the standard loss function for multi-class classification problems. Cross-entropy measures the distance between the predicted probability distribution and the actual (one-hot encoded) outcome; for a single observation it reduces to the negative log of the probability assigned to the true class (see the toy calculation after these notes).
3
Use accuracy and AUC as metrics to monitor during training. There are lots of other metrics available through Keras (see the metric_*() functions).
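As a toy illustration (not part of the exercise) of what categorical cross-entropy computes for a single observation with three classes:

truth <- c(0, 1, 0)           # one-hot encoded outcome: the true class is class 2
pred  <- c(0.2, 0.7, 0.1)     # predicted class probabilities from a softmax layer
-sum(truth * log(pred))       # cross-entropy = -log(0.7), roughly 0.36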
dense_history <- dense_model |>
  fit(
    x = leg_train,
    y = leg_train_outcome,
    batch_size = 1024,
    epochs = 10,
    validation_split = 0.25,
    class_weight = leg_train_weights
  )

plot(dense_history)
1
Fit the model to the training data. We pass the features and the outcome separately.
2
Set the batch size to 1024. This is the number of observations to pass through the model before updating the weights. Many optimization procedures for neural networks work more efficiently by processing the sample in smaller batches rather than all observations at once. The larger the batch, the longer it takes to update the weights for each batch (but fewer batches are required per epoch). It can be used to control how fast the model learns from the training set. (The "211/211" in the training log below is the number of batches per epoch; a short calculation after the log shows where it comes from.)
3
Set the number of epochs to 10. This is the number of times the model will see the entire training set. Each epoch is a complete pass through the training set. The model will update the weights after each batch. There is no hard and fast rule for the number of epochs to use. As epochs increase, the model’s performance can improve but at a certain point it will begin to overfit the training set.
4
Set the validation split to 0.25. This is the proportion of the training set to use as a validation set. The model will not be trained on this data but will use it to evaluate performance. This is a way to monitor overfitting. We can do this using Keras rather than manually partitioning the data using {tidymodels}, but we can also do it manually if we want to ensure the exact same observations are used for validation to compare models.
5
Pass the class weights to the model. This balances the classes during training instead of rebalancing the data with something like step_downsample(). We still use all the observations in the training set, but Keras weights each observation's contribution to the loss function (during training only) by its class weight.
6
We typically plot the loss and evaluation metrics over the epochs to see when overfitting becomes a problem.

Epoch 1/10
211/211 - 2s - 9ms/step - accuracy: 0.3219 - auc: 0.7839 - loss: 0.1243 - val_accuracy: 0.4799 - val_auc: 0.8755 - val_loss: 1.8742
Epoch 2/10
211/211 - 1s - 4ms/step - accuracy: 0.6451 - auc: 0.9394 - loss: 0.0606 - val_accuracy: 0.6273 - val_auc: 0.9358 - val_loss: 1.3572
Epoch 3/10
211/211 - 1s - 5ms/step - accuracy: 0.7296 - auc: 0.9639 - loss: 0.0459 - val_accuracy: 0.6751 - val_auc: 0.9492 - val_loss: 1.1889
Epoch 4/10
211/211 - 1s - 5ms/step - accuracy: 0.7739 - auc: 0.9734 - loss: 0.0389 - val_accuracy: 0.7012 - val_auc: 0.9553 - val_loss: 1.0972
Epoch 5/10
211/211 - 1s - 5ms/step - accuracy: 0.8026 - auc: 0.9786 - loss: 0.0344 - val_accuracy: 0.7140 - val_auc: 0.9575 - val_loss: 1.0592
Epoch 6/10
211/211 - 1s - 5ms/step - accuracy: 0.8220 - auc: 0.9816 - loss: 0.0314 - val_accuracy: 0.7273 - val_auc: 0.9600 - val_loss: 1.0160
Epoch 7/10
211/211 - 1s - 5ms/step - accuracy: 0.8358 - auc: 0.9837 - loss: 0.0290 - val_accuracy: 0.7368 - val_auc: 0.9609 - val_loss: 0.9947
Epoch 8/10
211/211 - 1s - 5ms/step - accuracy: 0.8469 - auc: 0.9852 - loss: 0.0271 - val_accuracy: 0.7502 - val_auc: 0.9631 - val_loss: 0.9547
Epoch 9/10
211/211 - 1s - 5ms/step - accuracy: 0.8545 - auc: 0.9865 - loss: 0.0256 - val_accuracy: 0.7406 - val_auc: 0.9608 - val_loss: 0.9966
Epoch 10/10
211/211 - 1s - 5ms/step - accuracy: 0.8611 - auc: 0.9875 - loss: 0.0243 - val_accuracy: 0.7490 - val_auc: 0.9618 - val_loss: 0.9764
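
The "211/211" at the start of each line in the log is the number of batches per epoch. It follows directly from the batch size and the validation split:

n_obs <- nrow(leg_train)              # 287,989 training observations
n_fit <- floor(n_obs * (1 - 0.25))    # observations Keras actually fits on after the validation split
ceiling(n_fit / 1024)                 # batches (weight updates) per epoch = 211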

Pre-trained word embeddings

Prep embeddings

# download and extract embeddings
##### use this version if you are running RStudio on your personal computer
# extract 100 dimensions from GloVe
glove6b <- embedding_glove6b(dimensions = 100)

# ##### uncomment if you are running RStudio on the Workbench
# # hacky way to make it work on RStudio Workbench
# glove6b <- read_delim(
#     file = "/rstudio-files/glove6b/glove.6B.100d.txt",
#     delim = " ",
#     quote = "",
#     col_names = c(
#       "token",
#       paste0("d", seq_len(100))
#     ),
#     col_types = paste0(
#       c(
#         "c",
#         rep("d", 100)
#       ),
#       collapse = ""
#     )
#   )

# filter to only include tokens in policy description vocab
glove6b_matrix <- tidy(leg_prep, 5) |>
  select(token) |>
  left_join(glove6b) |>
  # add row to capture all tokens not in GloVe
  add_row(.before = 1) |>
  mutate(across(.cols = starts_with("d"), .fns = \(x) replace_na(data = x, replace = 0))) |>
  select(-token) |>
  as.matrix()
1
Add a row to the top of the matrix to capture all tokens not in the GloVe embeddings. This row will absorb any token without a pre-trained vector. (An optional coverage check follows these notes.)
2
Replace any missing values with 0 so every token in the dictionary has an embedding vector, even if GloVe does not include it.
3
Convert the data frame to a matrix. This is the format the embedding layer expects for its weights.
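
An optional coverage check (not in the original exercise): how much of our 30,000-token dictionary actually has a pre-trained GloVe vector?

tidy(leg_prep, 5) |>
  select(token) |>
  left_join(glove6b, by = "token") |>
  summarize(
    n_tokens = n(),
    pct_matched = mean(!is.na(d1))    # share of dictionary tokens found in GloVe
  )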

Train model

# declare model specification
dense_model_pte <- keras_model_sequential(input_shape = c(max_length)) |>
  layer_embedding(input_dim = max_words + 1,
                  output_dim = ncol(glove6b_matrix),
                  weights = glove6b_matrix,
                  trainable = FALSE) |>
  layer_flatten() |>
  layer_dense(units = 64, activation = "relu") |>
  layer_dense(units = 20, activation = "softmax")
summary(dense_model_pte)
1
Set the output dimension to the number of columns in the GloVe matrix (100).
2
Pass the GloVe matrix to the embedding layer. It will prepopulate the layer with the weights from GloVe rather than having to learn them from scratch.
3
Set the weights to be non-trainable. This keeps the pre-trained GloVe weights fixed during training: we don't want to adjust the weights in the embedding layer since they are pre-trained, and freezing them also saves computation time.
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)                  ┃ Output Shape           ┃     Param # ┃ Trai… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━┩
│ embedding_1 (Embedding)       │ (None, 50, 100)        │   3,000,100 │   N   │
├───────────────────────────────┼────────────────────────┼─────────────┼───────┤
│ flatten_1 (Flatten)           │ (None, 5000)           │           0 │   -   │
├───────────────────────────────┼────────────────────────┼─────────────┼───────┤
│ dense_2 (Dense)               │ (None, 64)             │     320,064 │   Y   │
├───────────────────────────────┼────────────────────────┼─────────────┼───────┤
│ dense_3 (Dense)               │ (None, 20)             │       1,300 │   Y   │
└───────────────────────────────┴────────────────────────┴─────────────┴───────┘
 Total params: 3,321,464 (12.67 MB)
 Trainable params: 321,364 (1.23 MB)
 Non-trainable params: 3,000,100 (11.44 MB)
dense_model_pte |>
  compile(
    optimizer = "adam",
    loss = "categorical_crossentropy",
    metrics = c("accuracy", "auc")
  )
dense_history_pte <- dense_model_pte |>
  fit(
    x = leg_train,
    y = leg_train_outcome,
    batch_size = 1024,
    epochs = 10,
    validation_split = 0.25,
    class_weight = leg_train_weights
  )
Epoch 1/10
211/211 - 3s - 12ms/step - accuracy: 0.4580 - auc: 0.8710 - loss: 0.1021 - val_accuracy: 0.4917 - val_auc: 0.8922 - val_loss: 1.7991
Epoch 2/10
211/211 - 2s - 8ms/step - accuracy: 0.6072 - auc: 0.9357 - loss: 0.0721 - val_accuracy: 0.5318 - val_auc: 0.9120 - val_loss: 1.6399
Epoch 3/10
211/211 - 2s - 8ms/step - accuracy: 0.6535 - auc: 0.9489 - loss: 0.0637 - val_accuracy: 0.5600 - val_auc: 0.9182 - val_loss: 1.5747
Epoch 4/10
211/211 - 2s - 8ms/step - accuracy: 0.6834 - auc: 0.9557 - loss: 0.0582 - val_accuracy: 0.5622 - val_auc: 0.9186 - val_loss: 1.5813
Epoch 5/10
211/211 - 2s - 8ms/step - accuracy: 0.7045 - auc: 0.9603 - loss: 0.0542 - val_accuracy: 0.5807 - val_auc: 0.9224 - val_loss: 1.5380
Epoch 6/10
211/211 - 2s - 8ms/step - accuracy: 0.7215 - auc: 0.9635 - loss: 0.0511 - val_accuracy: 0.5760 - val_auc: 0.9211 - val_loss: 1.5607
Epoch 7/10
211/211 - 2s - 8ms/step - accuracy: 0.7340 - auc: 0.9662 - loss: 0.0485 - val_accuracy: 0.5797 - val_auc: 0.9205 - val_loss: 1.5763
Epoch 8/10
211/211 - 2s - 8ms/step - accuracy: 0.7463 - auc: 0.9682 - loss: 0.0463 - val_accuracy: 0.5951 - val_auc: 0.9231 - val_loss: 1.5342
Epoch 9/10
211/211 - 2s - 8ms/step - accuracy: 0.7565 - auc: 0.9698 - loss: 0.0445 - val_accuracy: 0.5851 - val_auc: 0.9192 - val_loss: 1.5976
Epoch 10/10
211/211 - 2s - 7ms/step - accuracy: 0.7642 - auc: 0.9713 - loss: 0.0429 - val_accuracy: 0.5917 - val_auc: 0.9200 - val_loss: 1.5903
plot(dense_history_pte)

Allow weights to adjust

# declare model specification
dense_model_pte2 <- keras_model_sequential(input_shape = c(max_length)) |>
  layer_embedding(input_dim = max_words + 1,
                  output_dim = ncol(glove6b_matrix),
                  weights = glove6b_matrix) |>
  layer_flatten() |>
  layer_dense(units = 64, activation = "relu") |>
  layer_dense(units = 20, activation = "softmax")
summary(dense_model_pte2)
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                      ┃ Output Shape             ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding_2 (Embedding)           │ (None, 50, 100)          │     3,000,100 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ flatten_2 (Flatten)               │ (None, 5000)             │             0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_4 (Dense)                   │ (None, 64)               │       320,064 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_5 (Dense)                   │ (None, 20)               │         1,300 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
 Total params: 3,321,464 (12.67 MB)
 Trainable params: 3,321,464 (12.67 MB)
 Non-trainable params: 0 (0.00 B)
dense_model_pte2 |>
  compile(
    optimizer = "adam",
    loss = "categorical_crossentropy",
    metrics = c("accuracy", "auc")
  )
dense_pte2_history <- dense_model_pte2 |>
  fit(
    x = leg_train,
    y = leg_train_outcome,
    batch_size = 1024,
    epochs = 10,
    validation_split = 0.25,
    class_weight = leg_train_weights
  )
Epoch 1/10
211/211 - 4s - 20ms/step - accuracy: 0.5288 - auc: 0.9006 - loss: 0.0848 - val_accuracy: 0.6305 - val_auc: 0.9398 - val_loss: 1.3194
Epoch 2/10
211/211 - 3s - 16ms/step - accuracy: 0.7503 - auc: 0.9711 - loss: 0.0435 - val_accuracy: 0.6974 - val_auc: 0.9571 - val_loss: 1.0823
Epoch 3/10
211/211 - 3s - 16ms/step - accuracy: 0.8077 - auc: 0.9811 - loss: 0.0336 - val_accuracy: 0.7239 - val_auc: 0.9611 - val_loss: 1.0043
Epoch 4/10
211/211 - 3s - 16ms/step - accuracy: 0.8381 - auc: 0.9855 - loss: 0.0282 - val_accuracy: 0.7368 - val_auc: 0.9629 - val_loss: 0.9710
Epoch 5/10
211/211 - 4s - 17ms/step - accuracy: 0.8589 - auc: 0.9882 - loss: 0.0244 - val_accuracy: 0.7323 - val_auc: 0.9603 - val_loss: 1.0057
Epoch 6/10
211/211 - 3s - 16ms/step - accuracy: 0.8744 - auc: 0.9901 - loss: 0.0215 - val_accuracy: 0.7495 - val_auc: 0.9619 - val_loss: 0.9693
Epoch 7/10
211/211 - 3s - 16ms/step - accuracy: 0.8866 - auc: 0.9915 - loss: 0.0191 - val_accuracy: 0.7556 - val_auc: 0.9623 - val_loss: 0.9625
Epoch 8/10
211/211 - 3s - 16ms/step - accuracy: 0.8972 - auc: 0.9927 - loss: 0.0171 - val_accuracy: 0.7566 - val_auc: 0.9606 - val_loss: 0.9877
Epoch 9/10
211/211 - 4s - 17ms/step - accuracy: 0.9057 - auc: 0.9936 - loss: 0.0155 - val_accuracy: 0.7581 - val_auc: 0.9596 - val_loss: 1.0033
Epoch 10/10
211/211 - 3s - 16ms/step - accuracy: 0.9128 - auc: 0.9945 - loss: 0.0141 - val_accuracy: 0.7571 - val_auc: 0.9574 - val_loss: 1.0430
plot(dense_pte2_history)
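
Note that all of the comparisons above rely on the validation split created by Keras; none of the models have been evaluated on the held-out test set. A minimal sketch of test-set evaluation for the final model (assuming the objects created above):

# apply the trained recipe to the test set
leg_test <- bake(leg_prep, new_data = legislation_test,
                 composition = "matrix", -starts_with("policy_lab"))
leg_test_outcome <- bake(leg_prep, new_data = legislation_test, starts_with("policy_lab"))
leg_test_outcome <- to_categorical(leg_test_outcome$policy_lab, num_classes = 20)

# loss, accuracy, and AUC on the test set
dense_model_pte2 |>
  evaluate(x = leg_test, y = leg_test_outcome)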

sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.4.1 (2024-06-14)
 os       macOS Sonoma 14.6.1
 system   aarch64, darwin20
 ui       X11
 language (EN)
 collate  en_US.UTF-8
 ctype    en_US.UTF-8
 tz       America/New_York
 date     2024-10-29
 pandoc   3.4 @ /usr/local/bin/ (via rmarkdown)

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version     date (UTC) lib source
 arrow        * 17.0.0      2024-09-18 [1] https://apache.r-universe.dev (R 4.4.1)
 assertthat     0.2.1       2019-03-21 [1] CRAN (R 4.3.0)
 backports      1.5.0       2024-05-23 [1] CRAN (R 4.4.0)
 base64enc      0.1-3       2015-07-28 [1] CRAN (R 4.3.0)
 bit            4.0.5       2022-11-15 [1] CRAN (R 4.3.0)
 bit64          4.0.5       2020-08-30 [1] CRAN (R 4.3.0)
 broom        * 1.0.6       2024-05-17 [1] CRAN (R 4.4.0)
 class          7.3-22      2023-05-03 [1] CRAN (R 4.4.0)
 cli            3.6.3       2024-06-21 [1] CRAN (R 4.4.0)
 codetools      0.2-20      2024-03-31 [1] CRAN (R 4.4.1)
 colorspace     2.1-1       2024-07-26 [1] CRAN (R 4.4.0)
 data.table     1.15.4      2024-03-30 [1] CRAN (R 4.3.1)
 dials        * 1.3.0       2024-07-30 [1] CRAN (R 4.4.0)
 DiceDesign     1.10        2023-12-07 [1] CRAN (R 4.3.1)
 digest         0.6.35      2024-03-11 [1] CRAN (R 4.3.1)
 dplyr        * 1.1.4       2023-11-17 [1] CRAN (R 4.3.1)
 ellipsis       0.3.2       2021-04-29 [1] CRAN (R 4.3.0)
 evaluate       0.24.0      2024-06-10 [1] CRAN (R 4.4.0)
 fansi          1.0.6       2023-12-08 [1] CRAN (R 4.3.1)
 farver         2.1.2       2024-05-13 [1] CRAN (R 4.3.3)
 fastmap        1.2.0       2024-05-15 [1] CRAN (R 4.4.0)
 forcats      * 1.0.0       2023-01-29 [1] CRAN (R 4.3.0)
 foreach        1.5.2       2022-02-02 [1] CRAN (R 4.3.0)
 fs             1.6.4       2024-04-25 [1] CRAN (R 4.4.0)
 furrr          0.3.1       2022-08-15 [1] CRAN (R 4.3.0)
 future         1.33.2      2024-03-26 [1] CRAN (R 4.3.1)
 future.apply   1.11.2      2024-03-28 [1] CRAN (R 4.3.1)
 generics       0.1.3       2022-07-05 [1] CRAN (R 4.3.0)
 ggplot2      * 3.5.1       2024-04-23 [1] CRAN (R 4.3.1)
 globals        0.16.3      2024-03-08 [1] CRAN (R 4.3.1)
 glue           1.8.0       2024-09-30 [1] CRAN (R 4.4.1)
 gower          1.0.1       2022-12-22 [1] CRAN (R 4.3.0)
 GPfit          1.0-8       2019-02-08 [1] CRAN (R 4.3.0)
 gtable         0.3.5       2024-04-22 [1] CRAN (R 4.3.1)
 hardhat        1.4.0       2024-06-02 [1] CRAN (R 4.4.0)
 here           1.0.1       2020-12-13 [1] CRAN (R 4.3.0)
 hms            1.1.3       2023-03-21 [1] CRAN (R 4.3.0)
 htmltools      0.5.8.1     2024-04-04 [1] CRAN (R 4.3.1)
 htmlwidgets    1.6.4       2023-12-06 [1] CRAN (R 4.3.1)
 infer        * 1.0.7       2024-03-25 [1] CRAN (R 4.3.1)
 ipred          0.9-14      2023-03-09 [1] CRAN (R 4.3.0)
 iterators      1.0.14      2022-02-05 [1] CRAN (R 4.3.0)
 janeaustenr    1.0.0       2022-08-26 [1] CRAN (R 4.3.0)
 jsonlite       1.8.9       2024-09-20 [1] CRAN (R 4.4.1)
 keras3       * 1.2.0       2024-09-05 [1] CRAN (R 4.4.1)
 knitr          1.47        2024-05-29 [1] CRAN (R 4.4.0)
 labeling       0.4.3       2023-08-29 [1] CRAN (R 4.3.0)
 lattice        0.22-6      2024-03-20 [1] CRAN (R 4.4.0)
 lava           1.8.0       2024-03-05 [1] CRAN (R 4.3.1)
 lhs            1.1.6       2022-12-17 [1] CRAN (R 4.3.0)
 lifecycle      1.0.4       2023-11-07 [1] CRAN (R 4.3.1)
 listenv        0.9.1       2024-01-29 [1] CRAN (R 4.3.1)
 lubridate    * 1.9.3       2023-09-27 [1] CRAN (R 4.3.1)
 magrittr       2.0.3       2022-03-30 [1] CRAN (R 4.3.0)
 MASS           7.3-61      2024-06-13 [1] CRAN (R 4.4.0)
 Matrix         1.7-0       2024-03-22 [1] CRAN (R 4.4.0)
 mgcv           1.9-1       2023-12-21 [1] CRAN (R 4.4.0)
 modeldata    * 1.4.0       2024-06-19 [1] CRAN (R 4.4.0)
 munsell        0.5.1       2024-04-01 [1] CRAN (R 4.3.1)
 nlme           3.1-165     2024-06-06 [1] CRAN (R 4.4.0)
 nnet           7.3-19      2023-05-03 [1] CRAN (R 4.4.0)
 parallelly     1.37.1      2024-02-29 [1] CRAN (R 4.3.1)
 parsnip      * 1.2.1       2024-03-22 [1] CRAN (R 4.3.1)
 pillar         1.9.0       2023-03-22 [1] CRAN (R 4.3.0)
 pkgconfig      2.0.3       2019-09-22 [1] CRAN (R 4.3.0)
 png            0.1-8       2022-11-29 [1] CRAN (R 4.3.0)
 prodlim        2023.08.28  2023-08-28 [1] CRAN (R 4.3.0)
 purrr        * 1.0.2       2023-08-10 [1] CRAN (R 4.3.0)
 R6             2.5.1       2021-08-19 [1] CRAN (R 4.3.0)
 rappdirs       0.3.3       2021-01-31 [1] CRAN (R 4.3.0)
 Rcpp           1.0.13      2024-07-17 [1] CRAN (R 4.4.0)
 readr        * 2.1.5       2024-01-10 [1] CRAN (R 4.3.1)
 recipes      * 1.0.10      2024-02-18 [1] CRAN (R 4.3.1)
 reticulate     1.39.0      2024-09-05 [1] CRAN (R 4.4.1)
 rlang          1.1.4       2024-06-04 [1] CRAN (R 4.3.3)
 rmarkdown      2.27        2024-05-17 [1] CRAN (R 4.4.0)
 rpart          4.1.23      2023-12-05 [1] CRAN (R 4.4.0)
 rprojroot      2.0.4       2023-11-05 [1] CRAN (R 4.3.1)
 rsample      * 1.2.1       2024-03-25 [1] CRAN (R 4.3.1)
 rstudioapi     0.17.0      2024-10-16 [1] CRAN (R 4.4.1)
 scales       * 1.3.0       2023-11-28 [1] CRAN (R 4.4.0)
 sessioninfo    1.2.2       2021-12-06 [1] CRAN (R 4.3.0)
 SnowballC      0.7.1       2023-04-25 [1] CRAN (R 4.3.0)
 stopwords      2.3         2021-10-28 [1] CRAN (R 4.3.0)
 stringi        1.8.4       2024-05-06 [1] CRAN (R 4.3.1)
 stringr      * 1.5.1       2023-11-14 [1] CRAN (R 4.3.1)
 survival       3.7-0       2024-06-05 [1] CRAN (R 4.4.0)
 tensorflow     2.16.0.9000 2024-10-18 [1] Github (rstudio/tensorflow@2ed2029)
 textdata     * 0.4.5       2024-05-28 [1] CRAN (R 4.4.0)
 textrecipes  * 1.0.6       2023-11-15 [1] CRAN (R 4.3.1)
 tfruns         1.5.3       2024-04-19 [1] CRAN (R 4.4.0)
 tibble       * 3.2.1       2023-03-20 [1] CRAN (R 4.3.0)
 tidymodels   * 1.2.0       2024-03-25 [1] CRAN (R 4.3.1)
 tidyr        * 1.3.1       2024-01-24 [1] CRAN (R 4.3.1)
 tidyselect     1.2.1       2024-03-11 [1] CRAN (R 4.3.1)
 tidytext     * 0.4.2       2024-04-10 [1] CRAN (R 4.4.0)
 tidyverse    * 2.0.0       2023-02-22 [1] CRAN (R 4.3.0)
 timechange     0.3.0       2024-01-18 [1] CRAN (R 4.3.1)
 timeDate       4032.109    2023-12-14 [1] CRAN (R 4.3.1)
 tokenizers     0.3.0       2022-12-22 [1] CRAN (R 4.3.0)
 tune         * 1.2.1       2024-04-18 [1] CRAN (R 4.3.1)
 tzdb           0.4.0       2023-05-12 [1] CRAN (R 4.3.0)
 utf8           1.2.4       2023-10-22 [1] CRAN (R 4.3.1)
 vctrs          0.6.5       2023-12-01 [1] CRAN (R 4.3.1)
 whisker        0.4.1       2022-12-05 [1] CRAN (R 4.3.0)
 withr          3.0.1       2024-07-31 [1] CRAN (R 4.4.0)
 workflows    * 1.1.4       2024-02-19 [1] CRAN (R 4.3.1)
 workflowsets * 1.1.0       2024-03-21 [1] CRAN (R 4.3.1)
 xfun           0.45        2024-06-16 [1] CRAN (R 4.4.0)
 yaml           2.3.10      2024-07-26 [1] CRAN (R 4.4.0)
 yardstick    * 1.3.1       2024-03-21 [1] CRAN (R 4.3.1)
 zeallot        0.1.0       2018-01-28 [1] CRAN (R 4.3.0)

 [1] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library

─ Python configuration ───────────────────────────────────────────────────────
 python:         /Users/soltoffbc/.virtualenvs/r-keras/bin/python
 libpython:      /Users/soltoffbc/.pyenv/versions/3.10.15/lib/libpython3.10.dylib
 pythonhome:     /Users/soltoffbc/.virtualenvs/r-keras:/Users/soltoffbc/.virtualenvs/r-keras
 version:        3.10.15 (main, Oct 21 2024, 09:00:09) [Clang 15.0.0 (clang-1500.3.9.4)]
 numpy:          /Users/soltoffbc/.virtualenvs/r-keras/lib/python3.10/site-packages/numpy
 numpy_version:  1.26.4
 keras:          /Users/soltoffbc/.virtualenvs/r-keras/lib/python3.10/site-packages/keras
 
 NOTE: Python version was forced by RETICULATE_PYTHON

──────────────────────────────────────────────────────────────────────────────