AE 16: Classifying legislative texts with LLMs

Application exercise

Modified: October 31, 2024

Load packages

library(tidyverse)
library(tidyllm)
library(tidymodels)
library(arrow)
library(tictoc)
library(scales)
library(colorspace)

# preferred theme
theme_set(theme_minimal(base_size = 12, base_family = "Atkinson Hyperlegible"))
Warning

If you have not already completed the pre-class preparation to set up your API key, do this now.

Classifying legislative policy attention

# import data
leg <- read_parquet(file = "data/legislation.parquet")

Our goal is to classify these legislative bill descriptions into one of the Comparative Agendas Project’s major policy categories.

cap_codes <- leg |>
  distinct(policy, policy_lab) |>
  arrange(policy)
cap_codes

We previously attempted to complete this task by training shallow and deep learning models to predict the policy topic. In this AE we will use a large language model (LLM) to classify these descriptions into one of the policy categories.

Sub-sample and partition data

We start by keeping only the distinct legislative descriptions. Eliminating duplicates makes the classification process more efficient and reliable.

Note

For our purposes today (and to avoid excessive costs), we will only classify a small sub-sample of the data. In a real-world scenario, you would start with a small sub-sample but eventually label all the observations.

set.seed(521)

leg_lite <- leg |>
  distinct(description, policy) |>
  slice_sample(n = 500L)
leg_lite

# distribution of sample
leg_lite |>
  count(policy)

Next, we divide the bills into training/test sets to evaluate the LLMs’ performance.

set.seed(783)

leg_split <- initial_split(leg_lite, prop = 0.8)

leg_train <- training(leg_split)
leg_test <- testing(leg_split)

dim(leg_train)
dim(leg_test)

Construct an initial classifier

Create a basic classifier function

Demonstration: To classify this data, we write a custom function that wraps llm_message(). This function sends each legislative description to an LLM and prompts it to assign one of the pre-defined policy codes.

classify_policy <- function(description) {
  # output what the model is currently doing to the console
  str_glue("Classifying: {description}\n") |> message()

  # generate the prompt
  prompt <- str_glue("
      Classify this legislative description from the U.S. Congress: {description}

      Pick one of the following numerical codes from this list.
      Respond only with the code!
      1 = Macroeconomics
      2 = Civil rights, minority issues, civil liberties
      3 = Health
      4 = Agriculture
      5 = Labor and employment
      6 = Education
      7 = Environment
      8 = Energy
      9 = Immigration
      10 = Transportation
      12 = Law, crime, family issues
      13 = Social welfare
      14 = Community development and housing issues
      15 = Banking, finance, and domestic commerce
      16 = Defense
      17 = Space, technology, and communications
      18 = Foreign trade
      19 = International affairs and foreign aid
      20 = Government operations
      21 = Public lands and water management")

  # list of valid codes as strings
  valid_codes <- as.character(cap_codes$policy)

  # attempt to classify the description
  classification <- tryCatch(
    {
      # get the assistant's reply
      assistant_reply <- llm_message(prompt) |>
        openai(.model = "gpt-4o-mini", .temperature = 0) |>
        last_reply() |>
        str_squish()

      # validate the assistant's reply
      if (assistant_reply %in% valid_codes) {
        as.integer(assistant_reply)
      } else {
        # if the reply is not a valid code, set code 98
        98L
      }
    },
    error = function(e) {
      # if there's an error with the model, set code 97
      97L
    }
  )

  # output a tibble
  return(
    tibble(
      description = description,
      .pred = classification
    )
  )
}

Let’s test the function on a single observation:

classify_policy("To amend the Internal Revenue Code of 1986 to provide for the treatment of certain direct primary care service arrangements as medical care.")

Your turn: Apply the classifier to the entire training set iteratively using an appropriate purrr::map_*() function, and collapse the results into a single tibble for further analysis.

tic()
leg_train_pred <- leg_train$description |>
  TODO
toc()
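
One possible solution, as a sketch: map classify_policy() over the descriptions and row-bind the one-row tibbles it returns.

# a possible solution: iterate over the descriptions with map() and combine
# the per-description tibbles returned by classify_policy()
tic()
leg_train_pred <- leg_train$description |>
  map(classify_policy) |>
  list_rbind()
toc()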

Evaluate the performance

Your turn: Examine the performance of these predictions. What are appropriate metrics to use? How does the classifier perform?

# add labels for API-generated codes
cap_codes_api <- cap_codes |>
  bind_rows(
    tribble(
      ~policy, ~policy_lab,
      97, "API-connection failure",
      98, "Invalid response",
      99, "Missing (no clear policy)"
    )
  )

# combine predictions with true values
leg_train_pred_labels <- leg_train |>
  bind_cols(.pred = leg_train_pred$.pred) |>
  # convert truth and estimates to factors for evaluating performance
  mutate(across(
    .cols = c(policy, .pred),
    .fns = \(x) factor(
      x,
      levels = cap_codes_api$policy,
      labels = cap_codes_api$policy_lab
    )
  ))
leg_train_pred_labels

# choose class-based metrics
llm_metrics <- metric_set(TODO)

leg_train_pred_labels |>
  llm_metrics(truth = policy, estimate = .pred)
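
One reasonable choice, as a sketch: accuracy plus macro-averaged sensitivity and specificity (other class metrics such as precision, recall, or F-measure would also work).

# a possible metric set for this multiclass problem; sensitivity and
# specificity are macro-averaged across the policy classes by default
llm_metrics <- metric_set(accuracy, sensitivity, specificity)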

# confusion matrix
leg_train_pred_labels |>
  conf_mat(truth = policy, estimate = .pred) |>
  autoplot(type = "heatmap") +
  scale_fill_continuous_sequential() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

Add response here.

Evaluate multiple classifiers

Demonstration: To test different prompts and models systematically, we need a more flexible classifier function. We move the prompt-building logic out of the function and pass the prompt, API function, and model as arguments:

# numeric code list for reuse
cap_code_list <- c("
      1 = Macroeconomics
      2 = Civil rights, minority issues, civil liberties
      3 = Health
      4 = Agriculture
      5 = Labor and employment
      6 = Education
      7 = Environment
      8 = Energy
      9 = Immigration
      10 = Transportation
      12 = Law, crime, family issues
      13 = Social welfare
      14 = Community development and housing issues
      15 = Banking, finance, and domestic commerce
      16 = Defense
      17 = Space, technology, and communications
      18 = Foreign trade
      19 = International affairs and foreign aid
      20 = Government operations
      21 = Public lands and water management
      ")

# classification function that accepts a prompt, api_function, and model,
# as well as the true value to pass through as arguments
classify_policy_compare <- function(description,
                                    policy,
                                    prompt,
                                    prompt_id,
                                    api_function,
                                    model) {
  # print (message) what the model is currently doing to the console
  str_glue("Classifying: {model} - {prompt_id} - {description}\n") |> message()

  # list of valid codes as strings
  valid_codes <- as.character(cap_codes$policy)

  # attempt to classify the description
  classification <- tryCatch(
    {
      # get the assistant's reply using the dynamically provided API function and model
      assistant_reply <- llm_message(prompt) |>
        api_function(.model = model, .temperature = 0) |>
        last_reply() |>
        str_squish()

      # validate the assistant's reply
      if (assistant_reply %in% valid_codes) {
        as.integer(assistant_reply)
      } else {
        98L # return 98 for invalid responses
      }
    },
    error = function(e) {
      97L # return 97 in case of an error (e.g., API failure)
    }
  )

  # Return a tibble containing the original legislative description and classification result
  return(
    tibble(
      description = description,
      .pred = classification,
      .truth = policy,
      model = model,
      prompt_id = prompt_id
    )
  )
}

Defining the prompt and model grid

We’ll define a set of prompts and models that we want to test. This will allow us to apply the classifier across different configurations and compare results. Here’s how the prompts and models are set up:

  1. Prompts:
  • Prompt 1: A detailed prompt explaining the purpose of this classification task and what policy attention means.
  • Prompt 2: A prompt that explicitly asks the model to avoid guessing by returning a special code (99) when it is unsure.
  • Prompt 3: A shorter, more concise version to test whether the model performs similarly with less detailed instructions.
  2. Models:1
  • GPT-4o mini
  • GPT-4o

1 This would also be an appropriate time to evaluate models from other providers, such as Anthropic, Mistral, or open models running locally via Ollama.

We set up a grid combining all the prompts and models. The expand_grid() function is a useful tool here to create every possible combination of prompts and models, which we will use to evaluate the classifier.

Your turn: Develop appropriate prompts that meet the requirements above.

prompts <- tibble(
  prompt =
    c( # original prompt
      "TODO",
      # more explanation about the task
      "TODO",
      # shorter prompt
      "TODO"
    ),
  prompt_id = 1:3
)
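
One possible set of prompts, as a sketch. The wording is purely illustrative; {description} and {cap_code_list} are left as literal placeholders here and are interpolated by str_glue() when the grid is built below.

# illustrative prompts only; {description} and {cap_code_list} stay literal
# here and are filled in row-by-row by str_glue() in the grid below
prompts <- tibble(
  prompt = c(
    # 1: detailed prompt explaining the task and what policy attention means
    "You are classifying legislative bill descriptions from the U.S. Congress
    according to the policy topic that receives the most attention in each bill.
    Classify this description: {description}
    Pick one of the following numerical codes and respond only with the code!
    {cap_code_list}",
    # 2: explicitly discourage guessing
    "Classify this legislative description from the U.S. Congress: {description}
    Pick one of the following numerical codes and respond only with the code.
    If you are unsure which code applies, respond with 99.
    {cap_code_list}",
    # 3: shorter, more concise prompt
    "Assign one policy code to this U.S. bill description: {description}
    {cap_code_list}
    Respond only with the code."
  ),
  prompt_id = 1:3
)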

grid <- expand_grid(
  leg_train,
  prompts,
  model = c("gpt-4o-mini", "gpt-4o")
) |>
  arrange(model) |>
  rowwise() |> # glue together prompts and descriptions row-by-row
  mutate(prompt = str_glue(prompt)) |>
  ungroup() |> # ungroup after the rowwise operation
  select(description, policy, prompt, prompt_id, model)

Generate predictions

Demonstration: To run the classification across the entire grid, we use pmap() from the {purrr} package, which allows us to iterate over multiple arguments simultaneously. Each combination of legislative description, prompt, and model is passed into the classify_policy_compare() function, and the results are concatenated into a single tibble:

grid_results <- grid |>
  pmap(classify_policy_compare, api_function = openai) |>
  list_rbind()

Assess performance

Your turn: Evaluate the performance of the classifiers across different prompts and models. How does each perform?

grid_factors <- grid_results |>
  mutate(across(
    .cols = c(.pred, .truth),
    .fns = \(x) factor(x, levels = cap_codes_api$policy, labels = cap_codes_api$policy_lab)
  ))
grid_factors |>
  mutate(prompt_id = factor(prompt_id,
    labels = c(
      "Detailed", "No guessing",
      "Shortest instructions"
    )
  )) |>
  group_by(prompt_id, model) |>
  llm_metrics(truth = .truth, estimate = .pred) |>
  ggplot(mapping = aes(x = prompt_id, y = .estimate, fill = model)) +
  geom_col(position = "dodge") +
  scale_fill_discrete_qualitative() +
  scale_x_discrete(labels = label_wrap(width = 15)) +
  facet_wrap(facets = vars(.metric)) +
  labs(
    title = "Accuracy by prompt and model",
    x = "Prompt",
    y = "Metric value",
    fill = "Model"
  ) +
  theme(legend.position = "top")

Add response here.

Which policy topics cause the most confusion?

grid_factors |>
  filter(prompt_id == 1, model == "gpt-4o-mini") |>
  conf_mat(truth = .truth, estimate = .pred) |>
  autoplot(type = "heatmap") +
  scale_fill_continuous_sequential() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

Add response here.

Multistep chain-of-thought prompting

This requires a major change to the classification function because we now send two messages to the model: the first elicits a reasoning step, and the second asks for the final code based on the answer to the first.

classify_policy_cot <- function(description,
                                policy,
                                prompt,
                                prompt_id,
                                api_function,
                                model,
                                stream = FALSE) {
  # Output what the model is currently doing to the console
  str_glue("Classifying with CoT: {model} - {description}\n") |> message()

  # Step 1: Ask the model to think through the problem
  prompt_reasoning <- str_glue('
    Think about which of the following policy topic codes would best describe this legislative description from the U.S. Congress: "{description}"

    {cap_code_list}

    Explain your reasoning for the 3 top candidate codes step by step. Then evaluate which seems best.
  ')

  reasoning_response <- tryCatch(
    {
      conversation <<- llm_message(prompt_reasoning) |>
        api_function(.model = model, .temperature = 0, .stream = stream)

      conversation |>
        last_reply()
    },
    error = function(e) {
      conversation <<- llm_message("Please classify this policy topic: {description}")
      "Error in reasoning step."
    }
  )

  # Step 2: Ask the model to provide the final answer
  prompt_final <- str_glue("
    Based on your reasoning, which code do you pick? Answer only with a numerical code!

  ")

  final_response <- tryCatch(
    {
      conversation |>
        llm_message(prompt_final) |>
        api_function(.model = model, .temperature = 0, .stream = stream) |>
        last_reply() |>
        str_squish()
    },
    error = function(e) {
      "97"
    }
  )

  # Validate the model's final response
  valid_codes <- as.character(cap_codes$policy)

  classification <- if (final_response %in% valid_codes) {
    as.integer(final_response)
  } else {
    98L # Return 98 for invalid responses
  }

  # Return a tibble containing the original legislative description and classification result
  tibble(
    description = description,
    .pred = classification,
    .truth = policy,
    model = str_glue("{model}_cot"),
    reasoning = reasoning_response,
    final_response = final_response
  )
}

Let’s run this function with GPT-4o:

results_cot <- grid |>
  filter(model == "gpt-4o", prompt_id == 1) |>
  select(-prompt, -prompt_id) |>
  pmap(classify_policy_cot, api_function = openai, stream = FALSE) |>
  list_rbind()

Your turn: Evaluate the performance of the chain-of-thought method compared to the earlier classifiers. How does it perform?

results_cot |>
  mutate(prompt_id = 4) |>
  bind_rows(grid_results) |>
  mutate(across(
    .cols = c(.pred, .truth),
    .fns = \(x) factor(x, levels = cap_codes_api$policy, labels = cap_codes_api$policy_lab)
  )) |>
  mutate(prompt_id = factor(prompt_id,
    labels = c(
      "Detailed", "No guessing",
      "Shortest instructions", "Chain-of-thought"
    )
  )) |>
  group_by(prompt_id, model) |>
  accuracy(truth = .truth, estimate = .pred) |>
  ggplot(mapping = aes(x = prompt_id, y = .estimate, fill = model)) +
  geom_col(position = "dodge") +
  scale_fill_discrete_qualitative() +
  scale_y_continuous(labels = label_percent()) +
  labs(
    title = "Accuracy by prompt and model",
    x = "Prompt",
    y = "Accuracy",
    fill = "Model"
  ) +
  theme(legend.position = "top")
results_cot |>
  mutate(across(
    .cols = c(.pred, .truth),
    .fns = \(x) factor(x, levels = cap_codes_api$policy, labels = cap_codes_api$policy_lab)
  )) |>
  conf_mat(truth = .truth, estimate = .pred) |>
  autoplot(type = "heatmap") +
  scale_fill_continuous_sequential() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

Add response here.

Acknowledgments

Additional resources