AE 06: Predicting hotel price with boosting models
Suggested answers
Application exercise
Answers
Setup
# Metrics ----
# MAE and R-squared are the two metrics reported for every candidate model.
reg_metrics <- metric_set(mae, rsq)

# Data ----
# Take a seeded 5,000-row subsample, order it by arrival date, then drop
# the date column.
data(hotel_rates)
set.seed(295)
hotel_rates <- hotel_rates |>
  sample_n(5000) |>
  arrange(arrival_date) |>
  select(-arrival_date) |>
  mutate(
    # Re-create each factor from its character values.
    across(c(company, country, agent), \(x) factor(as.character(x)))
  )

# Training / test split, stratified on the outcome ----
set.seed(421)
hotel_split <- hotel_rates |>
  initial_split(strata = avg_price_per_room)
hotel_train <- hotel_split |> training()
hotel_test <- hotel_split |> testing()

# 10-fold cross-validation on the training set, stratified on the outcome ----
set.seed(531)
hotel_rs <- hotel_train |>
  vfold_cv(strata = avg_price_per_room)

# Feature engineering recipe ----
hash_rec <- recipe(avg_price_per_room ~ ., data = hotel_train) |>
  step_YeoJohnson(lead_time) |>
  # Feature hashing for agent and company
  # (defaults to 32 signed indicator columns)
  step_dummy_hash(agent) |>
  step_dummy_hash(company) |>
  # Regular indicator columns for the remaining nominal predictors
  step_dummy(all_nominal_predictors()) |>
  # Remove any predictor that ends up with zero variance
  step_zv(all_predictors())
Boosting model specification
# Recipe: the number of hashed columns for agent and company is now tuned,
# each under its own parameter id.
hash_rec <- recipe(avg_price_per_room ~ ., data = hotel_train) |>
  step_YeoJohnson(lead_time) |>
  step_dummy_hash(agent, num_terms = tune("agent hash")) |>
  step_dummy_hash(company, num_terms = tune("company hash")) |>
  step_zv(all_predictors())

# Boosted-tree regression fit with lightgbm; the number of trees and the
# learning rate are marked for tuning.
lgbm_spec <- boost_tree(
  mode = "regression",
  engine = "lightgbm",
  trees = tune(),
  learn_rate = tune()
)

# Bundle the preprocessor and the model into a single workflow.
lgbm_wflow <- workflow() |>
  add_recipe(hash_rec) |>
  add_model(lgbm_spec)
Create a grid
Demonstration: Create a space filling grid for the boosting workflow.
# Space-filling design with 25 candidate points over the workflow's
# tuning parameters.
set.seed(12)
grid <- grid_space_filling(
  extract_parameter_set_dials(lgbm_wflow),
  size = 25
)
grid
# A tibble: 25 × 4
trees learn_rate `agent hash` `company hash`
<int> <dbl> <int> <int>
1 1 7.50e- 6 574 574
2 84 1.78e- 5 2048 2298
3 167 5.62e-10 1824 912
4 250 4.22e- 5 3250 512
5 334 1.78e- 8 512 2896
6 417 1.33e- 3 322 1625
7 500 1 e- 1 1448 1149
8 584 1 e- 7 1290 256
9 667 2.37e-10 456 724
10 750 1.78e- 2 645 322
# ℹ 15 more rows
Your turn: Try creating a regular grid for the boosting workflow.
# Regular grid: 4 levels for each tuning parameter, fully crossed.
set.seed(12)
grid <- grid_regular(
  extract_parameter_set_dials(lgbm_wflow),
  levels = 4
)
grid
# A tibble: 256 × 4
trees learn_rate `agent hash` `company hash`
<int> <dbl> <int> <int>
1 1 0.0000000001 256 256
2 667 0.0000000001 256 256
3 1333 0.0000000001 256 256
4 2000 0.0000000001 256 256
5 1 0.0000001 256 256
6 667 0.0000001 256 256
7 1333 0.0000001 256 256
8 2000 0.0000001 256 256
9 1 0.0001 256 256
10 667 0.0001 256 256
# ℹ 246 more rows
Your turn: What advantage would a regular grid have?
A regular grid spaces the candidate values of each parameter uniformly, and every combination of those values is explored.
Update parameter ranges
# Narrow the default ranges before building the grid: at most 100 trees,
# and a learning rate between 10^-5 and 10^-1 (range given in log-10 units).
lgbm_param <- update(
  extract_parameter_set_dials(lgbm_wflow),
  trees = trees(range = c(1L, 100L)),
  learn_rate = learn_rate(range = c(-5, -1))
)

# Space-filling design over the updated parameter set.
set.seed(712)
grid <- grid_space_filling(lgbm_param, size = 25)
grid
# A tibble: 25 × 4
trees learn_rate `agent hash` `company hash`
<int> <dbl> <int> <int>
1 1 0.00147 574 574
2 5 0.00215 2048 2298
3 9 0.0000215 1824 912
4 13 0.00316 3250 512
5 17 0.0001 512 2896
6 21 0.0147 322 1625
7 25 0.1 1448 1149
8 29 0.000215 1290 256
9 34 0.0000147 456 724
10 38 0.0464 645 322
# ℹ 15 more rows
Grid search
Let’s take our previous model and tune more parameters:
# Same lightgbm regression model, now also tuning min_n.
lgbm_spec <- boost_tree(
  mode = "regression",
  engine = "lightgbm",
  trees = tune(),
  learn_rate = tune(),
  min_n = tune()
)

lgbm_wflow <- workflow() |>
  add_recipe(hash_rec) |>
  add_model(lgbm_spec)

# Restrict both feature-hash sizes; the range is given in log-2 units.
lgbm_param <- update(
  extract_parameter_set_dials(lgbm_wflow),
  `agent hash` = num_hash(range = c(3, 8)),
  `company hash` = num_hash(range = c(3, 8))
)
Run the grid search:
set.seed(9)

# Keep the out-of-sample predictions so they can be inspected afterwards.
ctrl <- control_grid(save_pred = TRUE, verbose = FALSE)

# Evaluate 25 candidate parameter combinations on the resamples.
lgbm_res <- tune_grid(
  lgbm_wflow,
  resamples = hotel_rs,
  grid = 25,
  # The remaining arguments are optional
  param_info = lgbm_param,
  control = ctrl,
  metrics = reg_metrics
)
Inspect results:
# Plot the tuning results.
lgbm_res |> autoplot()

# Metrics for each candidate, summarized across the resamples.
lgbm_res |> collect_metrics()
# A tibble: 50 × 11
trees min_n learn_rate `agent hash` `company hash` .metric .estimator mean
<int> <int> <dbl> <int> <int> <chr> <chr> <dbl>
1 298 19 4.15e- 9 222 36 mae standard 53.5
2 298 19 4.15e- 9 222 36 rsq standard 0.816
3 1394 5 5.82e- 6 28 21 mae standard 53.2
4 1394 5 5.82e- 6 28 21 rsq standard 0.817
5 774 12 4.41e- 2 27 95 mae standard 9.86
6 774 12 4.41e- 2 27 95 rsq standard 0.949
7 1342 7 6.84e-10 71 17 mae standard 53.5
8 1342 7 6.84e-10 71 17 rsq standard 0.816
9 669 39 8.62e- 7 141 145 mae standard 53.5
10 669 39 8.62e- 7 141 145 rsq standard 0.817
# ℹ 40 more rows
# ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# One row per resample, candidate, and metric (no averaging).
lgbm_res |> collect_metrics(summarize = FALSE)
# A tibble: 500 × 10
id trees min_n learn_rate `agent hash` `company hash` .metric .estimator
<chr> <int> <int> <dbl> <int> <int> <chr> <chr>
1 Fold01 298 19 4.15e-9 222 36 mae standard
2 Fold01 298 19 4.15e-9 222 36 rsq standard
3 Fold02 298 19 4.15e-9 222 36 mae standard
4 Fold02 298 19 4.15e-9 222 36 rsq standard
5 Fold03 298 19 4.15e-9 222 36 mae standard
6 Fold03 298 19 4.15e-9 222 36 rsq standard
7 Fold04 298 19 4.15e-9 222 36 mae standard
8 Fold04 298 19 4.15e-9 222 36 rsq standard
9 Fold05 298 19 4.15e-9 222 36 mae standard
10 Fold05 298 19 4.15e-9 222 36 rsq standard
# ℹ 490 more rows
# ℹ 2 more variables: .estimate <dbl>, .config <chr>
Choose a parameter combination
# Top candidates ranked by R-squared.
lgbm_res |> show_best(metric = "rsq")
# A tibble: 5 × 11
trees min_n learn_rate `agent hash` `company hash` .metric .estimator mean
<int> <int> <dbl> <int> <int> <chr> <chr> <dbl>
1 1890 10 0.0159 115 174 rsq standard 0.950
2 774 12 0.0441 27 95 rsq standard 0.949
3 1638 36 0.0409 15 120 rsq standard 0.948
4 963 23 0.00556 157 13 rsq standard 0.937
5 590 5 0.00320 85 73 rsq standard 0.911
# ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Top candidates ranked by mean absolute error.
lgbm_res |> show_best(metric = "mae")
# A tibble: 5 × 11
trees min_n learn_rate `agent hash` `company hash` .metric .estimator mean
<int> <int> <dbl> <int> <int> <chr> <chr> <dbl>
1 1890 10 0.0159 115 174 mae standard 9.80
2 774 12 0.0441 27 95 mae standard 9.86
3 1638 36 0.0409 15 120 mae standard 10.0
4 963 23 0.00556 157 13 mae standard 11.4
5 590 5 0.00320 85 73 mae standard 17.4
# ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Keep the single best parameter combination according to MAE.
lgbm_best <- lgbm_res |> select_best(metric = "mae")
lgbm_best
# A tibble: 1 × 6
trees min_n learn_rate `agent hash` `company hash` .config
<int> <int> <dbl> <int> <int> <chr>
1 1890 10 0.0159 115 174 Preprocessor12_Model1
Checking calibration
# Calibration plot of observed versus predicted price, using the saved
# predictions for the selected parameter combination.
cal_plot_regression(
  collect_predictions(lgbm_res, parameters = lgbm_best),
  truth = avg_price_per_room,
  estimate = .pred
)
Tune on stop_iter
Your turn: Try early stopping. Set `trees = 2000` and tune the `stop_iter` parameter!
Note that you will need to regenerate `lgbm_param` with your new workflow!
# Early stopping: fix the number of trees at 2000 and tune stop_iter
# together with the learning rate and min_n.
# NOTE(review): no `validation` engine argument is set here, so early
# stopping is presumably assessed on the data used for fitting; confirm
# against the bonsai lightgbm engine documentation.
lgbm_spec <- boost_tree(
  mode = "regression",
  engine = "lightgbm",
  trees = 2000,
  learn_rate = tune(),
  min_n = tune(),
  stop_iter = tune()
)

lgbm_wflow <- workflow() |>
  add_recipe(hash_rec) |>
  add_model(lgbm_spec)

# Rebuild the parameter set from the new workflow and restrict both
# feature-hash sizes; the range is given in log-2 units.
lgbm_param <- update(
  extract_parameter_set_dials(lgbm_wflow),
  `agent hash` = num_hash(range = c(3, 8)),
  `company hash` = num_hash(range = c(3, 8))
)
# tune the model
# Seed the RNG first, mirroring the set.seed() call that precedes the
# earlier grid search, so this search can be reproduced run to run.
set.seed(9)
lgbm_res <- lgbm_wflow |>
  tune_grid(
    resamples = hotel_rs,
    grid = 25,
    # The options below are not required by default
    param_info = lgbm_param,
    control = ctrl,
    metrics = reg_metrics
  )
# Plot the early-stopping tuning results.
lgbm_res |> autoplot()

# Top candidates ranked by mean absolute error.
lgbm_res |> show_best(metric = "mae")
# A tibble: 5 × 11
min_n learn_rate stop_iter `agent hash` `company hash` .metric .estimator
<int> <dbl> <int> <int> <int> <chr> <chr>
1 9 0.0712 12 61 28 mae standard
2 12 0.0180 6 13 9 mae standard
3 30 0.0409 13 37 44 mae standard
4 24 0.00495 4 92 28 mae standard
5 33 0.00200 14 23 11 mae standard
# ℹ 4 more variables: mean <dbl>, n <int>, std_err <dbl>, .config <chr>
Acknowledgments
- Materials derived in part from Machine learning with {tidymodels} and licensed under a Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA) License.
Session information
# Print the R version, platform details, and package versions used to
# render this document.
sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.4.1 (2024-06-14)
os macOS Sonoma 14.6.1
system aarch64, darwin20
ui X11
language (EN)
collate en_US.UTF-8
ctype en_US.UTF-8
tz America/New_York
date 2024-09-20
pandoc 3.3 @ /usr/local/bin/ (via rmarkdown)
─ Packages ───────────────────────────────────────────────────────────────────
package * version date (UTC) lib source
backports 1.5.0 2024-05-23 [1] CRAN (R 4.4.0)
bonsai * 0.2.1 2022-11-29 [1] CRAN (R 4.3.0)
broom * 1.0.6 2024-05-17 [1] CRAN (R 4.4.0)
class 7.3-22 2023-05-03 [1] CRAN (R 4.4.0)
cli 3.6.3 2024-06-21 [1] CRAN (R 4.4.0)
codetools 0.2-20 2024-03-31 [1] CRAN (R 4.4.1)
colorspace 2.1-1 2024-07-26 [1] CRAN (R 4.4.0)
crayon 1.5.3 2024-06-20 [1] CRAN (R 4.4.0)
data.table 1.15.4 2024-03-30 [1] CRAN (R 4.3.1)
dials * 1.3.0 2024-07-30 [1] CRAN (R 4.4.0)
DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.3.1)
digest 0.6.35 2024-03-11 [1] CRAN (R 4.3.1)
doFuture 1.0.1 2023-12-20 [1] CRAN (R 4.3.1)
dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.3.1)
evaluate 0.24.0 2024-06-10 [1] CRAN (R 4.4.0)
fansi 1.0.6 2023-12-08 [1] CRAN (R 4.3.1)
farver 2.1.2 2024-05-13 [1] CRAN (R 4.3.3)
fastmap 1.2.0 2024-05-15 [1] CRAN (R 4.4.0)
float 0.3-2 2023-12-10 [1] CRAN (R 4.3.1)
foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.0)
furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.0)
future * 1.33.2 2024-03-26 [1] CRAN (R 4.3.1)
future.apply 1.11.2 2024-03-28 [1] CRAN (R 4.3.1)
generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.0)
ggplot2 * 3.5.1 2024-04-23 [1] CRAN (R 4.3.1)
globals 0.16.3 2024-03-08 [1] CRAN (R 4.3.1)
glue 1.7.0 2024-01-09 [1] CRAN (R 4.3.1)
gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.0)
gtable 0.3.5 2024-04-22 [1] CRAN (R 4.3.1)
hardhat 1.4.0 2024-06-02 [1] CRAN (R 4.4.0)
here 1.0.1 2020-12-13 [1] CRAN (R 4.3.0)
htmltools 0.5.8.1 2024-04-04 [1] CRAN (R 4.3.1)
htmlwidgets 1.6.4 2023-12-06 [1] CRAN (R 4.3.1)
infer * 1.0.7 2024-03-25 [1] CRAN (R 4.3.1)
ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.0)
iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.0)
jsonlite 1.8.8 2023-12-04 [1] CRAN (R 4.3.1)
knitr 1.47 2024-05-29 [1] CRAN (R 4.4.0)
labeling 0.4.3 2023-08-29 [1] CRAN (R 4.3.0)
lattice 0.22-6 2024-03-20 [1] CRAN (R 4.4.0)
lava 1.8.0 2024-03-05 [1] CRAN (R 4.3.1)
lgr 0.4.4 2022-09-05 [1] CRAN (R 4.3.0)
lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.0)
lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.3.1)
lightgbm 4.4.0 2024-06-15 [1] CRAN (R 4.4.0)
listenv 0.9.1 2024-01-29 [1] CRAN (R 4.3.1)
lubridate 1.9.3 2023-09-27 [1] CRAN (R 4.3.1)
magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
MASS 7.3-61 2024-06-13 [1] CRAN (R 4.4.0)
Matrix 1.7-0 2024-03-22 [1] CRAN (R 4.4.0)
mgcv 1.9-1 2023-12-21 [1] CRAN (R 4.4.0)
mlapi 0.1.1 2022-04-24 [1] CRAN (R 4.3.0)
modeldata * 1.4.0 2024-06-19 [1] CRAN (R 4.4.0)
munsell 0.5.1 2024-04-01 [1] CRAN (R 4.3.1)
nlme 3.1-165 2024-06-06 [1] CRAN (R 4.4.0)
nnet 7.3-19 2023-05-03 [1] CRAN (R 4.4.0)
parallelly 1.37.1 2024-02-29 [1] CRAN (R 4.3.1)
parsnip * 1.2.1 2024-03-22 [1] CRAN (R 4.3.1)
pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.0)
pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
probably * 1.0.3 2024-02-23 [1] CRAN (R 4.3.1)
prodlim 2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
Rcpp 1.0.12 2024-01-09 [1] CRAN (R 4.3.1)
recipes * 1.0.10 2024-02-18 [1] CRAN (R 4.3.1)
RhpcBLASctl 0.23-42 2023-02-11 [1] CRAN (R 4.3.0)
rlang 1.1.4 2024-06-04 [1] CRAN (R 4.3.3)
rmarkdown 2.27 2024-05-17 [1] CRAN (R 4.4.0)
rpart 4.1.23 2023-12-05 [1] CRAN (R 4.4.0)
rprojroot 2.0.4 2023-11-05 [1] CRAN (R 4.3.1)
rsample * 1.2.1 2024-03-25 [1] CRAN (R 4.3.1)
rsparse 0.5.1 2022-09-11 [1] CRAN (R 4.3.0)
rstudioapi 0.16.0 2024-03-24 [1] CRAN (R 4.3.1)
scales * 1.3.0 2023-11-28 [1] CRAN (R 4.4.0)
sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.0)
sfd 0.1.0 2024-01-08 [1] CRAN (R 4.4.0)
survival 3.7-0 2024-06-05 [1] CRAN (R 4.4.0)
text2vec 0.6.4 2023-11-09 [1] CRAN (R 4.3.1)
textrecipes * 1.0.6 2023-11-15 [1] CRAN (R 4.3.1)
tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.0)
tidymodels * 1.2.0 2024-03-25 [1] CRAN (R 4.3.1)
tidyr * 1.3.1 2024-01-24 [1] CRAN (R 4.3.1)
tidyselect 1.2.1 2024-03-11 [1] CRAN (R 4.3.1)
timechange 0.3.0 2024-01-18 [1] CRAN (R 4.3.1)
timeDate 4032.109 2023-12-14 [1] CRAN (R 4.3.1)
tune * 1.2.1 2024-04-18 [1] CRAN (R 4.3.1)
utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.1)
vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.3.1)
withr 3.0.1 2024-07-31 [1] CRAN (R 4.4.0)
workflows * 1.1.4 2024-02-19 [1] CRAN (R 4.3.1)
workflowsets * 1.1.0 2024-03-21 [1] CRAN (R 4.3.1)
xfun 0.45 2024-06-16 [1] CRAN (R 4.4.0)
yaml 2.3.8 2023-12-11 [1] CRAN (R 4.3.1)
yardstick * 1.3.1 2024-03-21 [1] CRAN (R 4.3.1)
[1] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
──────────────────────────────────────────────────────────────────────────────