AE 06: Predicting hotel prices with boosting models

Suggested answers

Application exercise
Answers
Modified

September 19, 2024

Setup

# Metrics used throughout: mean absolute error and R^2
reg_metrics <- metric_set(mae, rsq)

# Import the hotel_rates data and take a reproducible random sample of
# 5,000 rows.
data(hotel_rates)
set.seed(295)
hotel_rates <- hotel_rates |>
  # slice_sample() is the documented replacement for the superseded
  # sample_n(). NOTE(review): both sample row indices with sample.int(), so
  # the same seed should select the same rows -- confirm against the output.
  slice_sample(n = 5000) |>
  # Put the rows in arrival-date order, then drop the date column itself
  arrange(arrival_date) |>
  select(-arrival_date) |>
  # Rebuild these factors from their character values so their levels are
  # only the values present in the sample (unused levels are dropped)
  mutate(across(c(company, country, agent), \(x) factor(as.character(x))))

# Split the data into training and test sets, stratified on the outcome
set.seed(421)
hotel_split <- hotel_rates |>
  initial_split(strata = avg_price_per_room)

hotel_train <- hotel_split |> training()
hotel_test <- hotel_split |> testing()

# Resample the training set with stratified 10-fold cross-validation
set.seed(531)
hotel_rs <- hotel_train |>
  vfold_cv(v = 10, strata = avg_price_per_room)

# Baseline feature engineering recipe (fixed hash sizes)
hash_rec <-
  recipe(avg_price_per_room ~ ., data = hotel_train) |>
  # Yeo-Johnson transformation of lead_time
  step_YeoJohnson(lead_time) |>
  # Feature hashing for agent and company: 32 signed indicator columns each
  # (these are the step's default settings, written out explicitly)
  step_dummy_hash(agent, signed = TRUE, num_terms = 32L) |>
  step_dummy_hash(company, signed = TRUE, num_terms = 32L) |>
  # Ordinary indicator columns for every remaining nominal predictor
  step_dummy(all_nominal_predictors()) |>
  # Remove predictors that take only a single value
  step_zv(all_predictors())

Boosting model specification

# Tunable recipe: the number of hash columns for agent and company is tuned
# (parameter ids "agent hash" and "company hash"). Unlike the baseline recipe
# there is no step_dummy(), so the other nominal predictors stay factors.
# NOTE(review): presumably lightgbm handles those factors natively -- confirm.
hash_rec <-
  recipe(avg_price_per_room ~ ., data = hotel_train) |>
  step_YeoJohnson(lead_time) |>
  step_dummy_hash(agent, num_terms = tune("agent hash")) |>
  step_dummy_hash(company, num_terms = tune("company hash")) |>
  step_zv(all_predictors())

# Boosted trees fit with lightgbm; number of trees and learning rate are tuned
lgbm_spec <-
  boost_tree(learn_rate = tune(), trees = tune()) |>
  set_mode("regression") |>
  set_engine("lightgbm")

lgbm_wflow <- workflow(preprocessor = hash_rec, spec = lgbm_spec)

Create a grid

Demonstration: Create a space-filling grid for the boosting workflow.

# Space-filling design with 25 candidates over all four tuning parameters
set.seed(12)
grid <- grid_space_filling(
  extract_parameter_set_dials(lgbm_wflow),
  size = 25
)

grid
# A tibble: 25 × 4
   trees learn_rate `agent hash` `company hash`
   <int>      <dbl>        <int>          <int>
 1     1   7.50e- 6          574            574
 2    84   1.78e- 5         2048           2298
 3   167   5.62e-10         1824            912
 4   250   4.22e- 5         3250            512
 5   334   1.78e- 8          512           2896
 6   417   1.33e- 3          322           1625
 7   500   1   e- 1         1448           1149
 8   584   1   e- 7         1290            256
 9   667   2.37e-10          456            724
10   750   1.78e- 2          645            322
# ℹ 15 more rows

Your turn: Try creating a regular grid for the boosting workflow.

# Regular grid: 4 levels per parameter, every combination of levels
# (4 parameters -> 4^4 = 256 candidates)
set.seed(12)
grid <- grid_regular(
  extract_parameter_set_dials(lgbm_wflow),
  levels = 4
)

grid
# A tibble: 256 × 4
   trees   learn_rate `agent hash` `company hash`
   <int>        <dbl>        <int>          <int>
 1     1 0.0000000001          256            256
 2   667 0.0000000001          256            256
 3  1333 0.0000000001          256            256
 4  2000 0.0000000001          256            256
 5     1 0.0000001             256            256
 6   667 0.0000001             256            256
 7  1333 0.0000001             256            256
 8  2000 0.0000001             256            256
 9     1 0.0001                256            256
10   667 0.0001                256            256
# ℹ 246 more rows

Your turn: What advantage would a regular grid have?

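Each parameter's values are uniformly spaced, and every combination of those values is evaluated, which makes the marginal effect of each parameter easy to see. The trade-off is size: the number of candidates grows exponentially with the number of parameters (4 levels for 4 parameters gives 256 candidates here).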

Update parameter ranges

# Narrow the default ranges: between 1 and 100 trees, and learning rates from
# 10^-5 to 10^-1 (learn_rate() bounds are given on the log10 scale)
lgbm_param <- update(
  extract_parameter_set_dials(lgbm_wflow),
  trees = trees(c(1L, 100L)),
  learn_rate = learn_rate(c(-5, -1))
)

# 25-candidate space-filling design over the narrowed ranges
set.seed(712)
grid <- grid_space_filling(lgbm_param, size = 25)

grid
# A tibble: 25 × 4
   trees learn_rate `agent hash` `company hash`
   <int>      <dbl>        <int>          <int>
 1     1  0.00147            574            574
 2     5  0.00215           2048           2298
 3     9  0.0000215         1824            912
 4    13  0.00316           3250            512
 5    17  0.0001             512           2896
 6    21  0.0147             322           1625
 7    25  0.1               1448           1149
 8    29  0.000215          1290            256
 9    34  0.0000147          456            724
10    38  0.0464             645            322
# ℹ 15 more rows

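Choose a parameter combination

The tuning results used below (`lgbm_res`) come from a tuning run that is not shown in this document; it tuned `trees`, `min_n`, `learn_rate`, and both hash sizes.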

# Top candidates ranked by cross-validated R^2
lgbm_res |> show_best(metric = "rsq")
# A tibble: 5 × 11
  trees min_n learn_rate `agent hash` `company hash` .metric .estimator  mean
  <int> <int>      <dbl>        <int>          <int> <chr>   <chr>      <dbl>
1  1890    10    0.0159           115            174 rsq     standard   0.950
2   774    12    0.0441            27             95 rsq     standard   0.949
3  1638    36    0.0409            15            120 rsq     standard   0.948
4   963    23    0.00556          157             13 rsq     standard   0.937
5   590     5    0.00320           85             73 rsq     standard   0.911
# ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Top candidates ranked by cross-validated MAE
lgbm_res |> show_best(metric = "mae")
# A tibble: 5 × 11
  trees min_n learn_rate `agent hash` `company hash` .metric .estimator  mean
  <int> <int>      <dbl>        <int>          <int> <chr>   <chr>      <dbl>
1  1890    10    0.0159           115            174 mae     standard    9.80
2   774    12    0.0441            27             95 mae     standard    9.86
3  1638    36    0.0409            15            120 mae     standard   10.0 
4   963    23    0.00556          157             13 mae     standard   11.4 
5   590     5    0.00320           85             73 mae     standard   17.4 
# ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Keep the single candidate with the lowest cross-validated MAE
lgbm_best <- lgbm_res |> select_best(metric = "mae")
lgbm_best
# A tibble: 1 × 6
  trees min_n learn_rate `agent hash` `company hash` .config              
  <int> <int>      <dbl>        <int>          <int> <chr>                
1  1890    10     0.0159          115            174 Preprocessor12_Model1

Checking calibration

# Calibration check: resampled (held-out) predictions for the selected
# candidate plotted against the observed prices
cal_plot_regression(
  collect_predictions(lgbm_res, parameters = lgbm_best),
  truth = avg_price_per_room,
  estimate = .pred
)

Tune on stop_iter

Your turn: Try early stopping. Set `trees = 2000` and tune the `stop_iter` parameter!

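Note that you will need to regenerate `lgbm_param` from your new workflow, because its set of tuning parameters has changed (`trees` is now fixed, while `min_n` and `stop_iter` are tuned).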

# Early stopping: fix a generous upper bound on the number of trees and tune
# stop_iter, the number of boosting rounds without improvement after which
# training halts.
# Early stopping needs data the model is not trained on to monitor.
# `validation = 0.1` (a bonsai train_lightgbm() argument passed through
# set_engine()) holds out 10% of each analysis set for that purpose. Without
# it, the training data itself is monitored. Its loss keeps falling, so
# stop_iter would rarely (if ever) trigger and tuning it would be meaningless.
lgbm_spec <- boost_tree(
  trees = 2000, learn_rate = tune(),
  min_n = tune(), stop_iter = tune()
) |>
  set_mode("regression") |>
  set_engine("lightgbm", validation = 0.1)

lgbm_wflow <- workflow(hash_rec, lgbm_spec)

# Narrow both hash-size ranges to 2^3 = 8 through 2^8 = 256 columns
# (num_hash() bounds are given in log-2 units)
lgbm_param <- update(
  extract_parameter_set_dials(lgbm_wflow),
  `agent hash` = num_hash(c(3, 8)),
  `company hash` = num_hash(c(3, 8))
)

# Tune the early-stopping workflow over 25 candidate parameter combinations.
# `param_info`, `control`, and `metrics` are optional tune_grid() arguments.
# Here they supply the narrowed hash ranges, the shared control settings,
# and the MAE/R^2 metric set.
# NOTE(review): `ctrl` is not defined in the code shown in this document.
# It presumably comes from a hidden setup chunk and saves predictions
# (collect_predictions() is used on tuning results earlier) -- confirm.
lgbm_res <- tune_grid(
  lgbm_wflow,
  resamples = hotel_rs,
  grid = 25,
  param_info = lgbm_param,
  control = ctrl,
  metrics = reg_metrics
)
autoplot(lgbm_res)

# Best early-stopping candidates ranked by cross-validated MAE
lgbm_res |> show_best(metric = "mae")
# A tibble: 5 × 11
  min_n learn_rate stop_iter `agent hash` `company hash` .metric .estimator
  <int>      <dbl>     <int>        <int>          <int> <chr>   <chr>     
1     9    0.0712         12           61             28 mae     standard  
2    12    0.0180          6           13              9 mae     standard  
3    30    0.0409         13           37             44 mae     standard  
4    24    0.00495         4           92             28 mae     standard  
5    33    0.00200        14           23             11 mae     standard  
# ℹ 4 more variables: mean <dbl>, n <int>, std_err <dbl>, .config <chr>

Acknowledgments

# Print the R version, platform details, and package versions used to
# render this document, for reproducibility
sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.4.1 (2024-06-14)
 os       macOS Sonoma 14.6.1
 system   aarch64, darwin20
 ui       X11
 language (EN)
 collate  en_US.UTF-8
 ctype    en_US.UTF-8
 tz       America/New_York
 date     2024-09-20
 pandoc   3.3 @ /usr/local/bin/ (via rmarkdown)

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version    date (UTC) lib source
 backports      1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
 bonsai       * 0.2.1      2022-11-29 [1] CRAN (R 4.3.0)
 broom        * 1.0.6      2024-05-17 [1] CRAN (R 4.4.0)
 class          7.3-22     2023-05-03 [1] CRAN (R 4.4.0)
 cli            3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
 codetools      0.2-20     2024-03-31 [1] CRAN (R 4.4.1)
 colorspace     2.1-1      2024-07-26 [1] CRAN (R 4.4.0)
 crayon         1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
 data.table     1.15.4     2024-03-30 [1] CRAN (R 4.3.1)
 dials        * 1.3.0      2024-07-30 [1] CRAN (R 4.4.0)
 DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.1)
 digest         0.6.35     2024-03-11 [1] CRAN (R 4.3.1)
 doFuture       1.0.1      2023-12-20 [1] CRAN (R 4.3.1)
 dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.1)
 evaluate       0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
 fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.1)
 farver         2.1.2      2024-05-13 [1] CRAN (R 4.3.3)
 fastmap        1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
 float          0.3-2      2023-12-10 [1] CRAN (R 4.3.1)
 foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
 furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
 future       * 1.33.2     2024-03-26 [1] CRAN (R 4.3.1)
 future.apply   1.11.2     2024-03-28 [1] CRAN (R 4.3.1)
 generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
 ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.3.1)
 globals        0.16.3     2024-03-08 [1] CRAN (R 4.3.1)
 glue           1.7.0      2024-01-09 [1] CRAN (R 4.3.1)
 gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
 GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
 gtable         0.3.5      2024-04-22 [1] CRAN (R 4.3.1)
 hardhat        1.4.0      2024-06-02 [1] CRAN (R 4.4.0)
 here           1.0.1      2020-12-13 [1] CRAN (R 4.3.0)
 htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.3.1)
 htmlwidgets    1.6.4      2023-12-06 [1] CRAN (R 4.3.1)
 infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.3.1)
 ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
 iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
 jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.3.1)
 knitr          1.47       2024-05-29 [1] CRAN (R 4.4.0)
 labeling       0.4.3      2023-08-29 [1] CRAN (R 4.3.0)
 lattice        0.22-6     2024-03-20 [1] CRAN (R 4.4.0)
 lava           1.8.0      2024-03-05 [1] CRAN (R 4.3.1)
 lgr            0.4.4      2022-09-05 [1] CRAN (R 4.3.0)
 lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
 lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.1)
 lightgbm       4.4.0      2024-06-15 [1] CRAN (R 4.4.0)
 listenv        0.9.1      2024-01-29 [1] CRAN (R 4.3.1)
 lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.3.1)
 magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
 MASS           7.3-61     2024-06-13 [1] CRAN (R 4.4.0)
 Matrix         1.7-0      2024-03-22 [1] CRAN (R 4.4.0)
 mgcv           1.9-1      2023-12-21 [1] CRAN (R 4.4.0)
 mlapi          0.1.1      2022-04-24 [1] CRAN (R 4.3.0)
 modeldata    * 1.4.0      2024-06-19 [1] CRAN (R 4.4.0)
 munsell        0.5.1      2024-04-01 [1] CRAN (R 4.3.1)
 nlme           3.1-165    2024-06-06 [1] CRAN (R 4.4.0)
 nnet           7.3-19     2023-05-03 [1] CRAN (R 4.4.0)
 parallelly     1.37.1     2024-02-29 [1] CRAN (R 4.3.1)
 parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.3.1)
 pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
 pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
 probably     * 1.0.3      2024-02-23 [1] CRAN (R 4.3.1)
 prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
 purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
 R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
 Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.3.1)
 recipes      * 1.0.10     2024-02-18 [1] CRAN (R 4.3.1)
 RhpcBLASctl    0.23-42    2023-02-11 [1] CRAN (R 4.3.0)
 rlang          1.1.4      2024-06-04 [1] CRAN (R 4.3.3)
 rmarkdown      2.27       2024-05-17 [1] CRAN (R 4.4.0)
 rpart          4.1.23     2023-12-05 [1] CRAN (R 4.4.0)
 rprojroot      2.0.4      2023-11-05 [1] CRAN (R 4.3.1)
 rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.3.1)
 rsparse        0.5.1      2022-09-11 [1] CRAN (R 4.3.0)
 rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.3.1)
 scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.4.0)
 sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.3.0)
 sfd            0.1.0      2024-01-08 [1] CRAN (R 4.4.0)
 survival       3.7-0      2024-06-05 [1] CRAN (R 4.4.0)
 text2vec       0.6.4      2023-11-09 [1] CRAN (R 4.3.1)
 textrecipes  * 1.0.6      2023-11-15 [1] CRAN (R 4.3.1)
 tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
 tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.3.1)
 tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.3.1)
 tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.3.1)
 timechange     0.3.0      2024-01-18 [1] CRAN (R 4.3.1)
 timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.1)
 tune         * 1.2.1      2024-04-18 [1] CRAN (R 4.3.1)
 utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.1)
 vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.1)
 withr          3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
 workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.3.1)
 workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.3.1)
 xfun           0.45       2024-06-16 [1] CRAN (R 4.4.0)
 yaml           2.3.8      2023-12-11 [1] CRAN (R 4.3.1)
 yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.3.1)

 [1] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library

──────────────────────────────────────────────────────────────────────────────