如何保存 Tidymodels Lightgbm 模型以供重复使用

2024-01-11

我用以下代码创建了一个包含 lightgbm 模型的 tidymodels 工作流(workflow)。但是,当我尝试把它保存为 .rds 对象并用于预测时,遇到了一些问题。

library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()

### Model ###
# Builds the data set, the model specification, the recipe and the workflow,
# fits the workflow, and predicts in the same R session.

# data
# Load the Ames housing data via make_ames() and pass it through
# janitor::clean_names() to normalise the column names.
data <- make_ames() %>%
  janitor::clean_names()

# Keep only the outcome (sale_price) and the columns listed below.
data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
                                full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
                                lot_frontage, year_built, year_remod_add, year_sold))

# Add a row identifier based on the row position.
data$id <- c(1:nrow(data))

# Store the identifier as character and move it to the first column.
data <- data %>%
  mutate(id = as.character(id)) %>%
  select(id, everything())

# model specification
# boost_tree() with fixed hyperparameter values, regression mode, and the
# "lightgbm" engine; objective = "regression" is passed as an engine argument.

lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139,
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")

# recipe and workflow
# sale_price is the outcome and every other column is on the right-hand side.
# id is given the "ID" role, and step_corr() is applied to all predictors with
# a threshold of 0.7.
# NOTE(review): the recipe is prep()-ed here before being added to the
# workflow; the working version later on this page omits prep(). Some
# versions of workflows may reject an already-trained recipe - confirm.

lgbm_recipe <- recipe(sale_price ~., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7) %>%
  prep()

lgbm_workflow <- workflow() %>% 
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)  
  
# fit workflow
# Fit the recipe + model workflow on the full data set.

fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)

# predict
# Drop the outcome column and predict on the same rows that were used for
# fitting, using the in-memory fitted workflow.

data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)


### CASE 1: Save the workflow with saveRDS()
# Round-trip the whole fitted workflow through saveRDS() / readRDS() and try
# to predict with the reloaded object.

saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")

# Predict - error: Attempting to use a Booster which no longer exists
# (error text as reported by the author for the reloaded workflow)

predict(new_lgbm_workflow, new_data = data_predict)



### CASE 2: Save the workflow and the fitted model separately
# The engine-level fit is taken from the parsnip fit's $fit element and saved
# with lightgbm's own saveRDS.lgb.Booster(); the workflow itself is saved
# with plain saveRDS().

fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")


# Reload both files and put the reloaded booster back into the workflow.
# NOTE(review): the booster was extracted from extract_fit_parsnip()$fit, but
# it is assigned to $fit$fit here, whereas the working version later on this
# page assigns to $fit$fit$fit. This presumably overwrites the parsnip fit
# instead of the engine fit inside it and would explain the error below -
# confirm.
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model


# Predict - error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’
# (error text as reported by the author)

predict(new_lgbm_workflow, new_data = data_predict)

似乎只有使用 lightgbm 模型的工作流存在这个问题。对于其他类型的模型(随机森林、xgboost、glm 等),我可以用 saveRDS() 保存拟合好的工作流,用 readRDS() 读回,然后正常地对新数据进行预测。

对于情况 2,底层的预测函数显然变成了 predict.lgb.Booster(),它要求输入是一个 matrix。但我的 id 变量是 character 类型,而 matrix 中的所有列必须具有相同的类型。

有没有办法保存整个workflow以供将来使用?


经过大量搜索,我在一个已关闭的 issue 中找到了解决方案:https://github.com/tidymodels/bonsai/issues/44#issuecomment-1208463927

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
library(bonsai)
library(lightgbm)
#> Warning: package 'lightgbm' was built under R version 4.2.1
#> Loading required package: R6
#> 
#> Attaching package: 'lightgbm'
#> The following object is masked from 'package:dplyr':
#> 
#>     slice

# data
# Use the Ames housing data shipped as modeldata::ames and pass it through
# janitor::clean_names() to normalise the column names.

data <- modeldata::ames %>%
  janitor::clean_names()

# Keep only the outcome (sale_price) and the columns listed below.
data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
                                full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
                                lot_frontage, year_built, year_remod_add, year_sold))

# Add a row identifier based on the row position.
data$id <- c(1:nrow(data))

# Store the identifier as character and move it to the first column.
data <- data %>%
  mutate(id = as.character(id)) %>%
  select(id, everything())

# model specification
# boost_tree() with fixed hyperparameter values, regression mode, and the
# "lightgbm" engine; objective = "regression" is passed as an engine argument.

lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139,
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")

# recipe and workflow
# sale_price is the outcome; id is given the "ID" role; step_corr() is applied
# to all predictors with a threshold of 0.7. The recipe is not prep()-ed
# here: it is handed to workflow() as-is.

lgbm_recipe <- recipe(sale_price ~., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7)

lgbm_workflow <- workflow(preprocessor = lgbm_recipe,
                          spec = lgbm_model)

# fit workflow
# Fit the recipe + model workflow on the full data set.

fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)

# predict
# Drop the outcome column and predict with the in-memory fitted workflow;
# the printed tibble below is the reprex output.

data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)
#> # A tibble: 2,930 × 1
#>      .pred
#>      <dbl>
#>  1 201911.
#>  2 124695.
#>  3 138983.
#>  4 221095.
#>  5 198972.
#>  6 188613.
#>  7 198730.
#>  8 170893.
#>  9 243899.
#> 10 196875.
#> # … with 2,920 more rows

# save the trained workflow and lgb.booster object separately
# The workflow goes through plain saveRDS(); the engine-level fit returned by
# extract_fit_engine() is written with saveRDS.lgb.Booster().

saveRDS(fit_lgbm_workflow, "lgbm_wflw.rds")
saveRDS.lgb.Booster(extract_fit_engine(fit_lgbm_workflow), "lgbm_booster.rds")

# load trained workflow and merge it with lgb.booster
# The reloaded booster is assigned to $fit$fit$fit of the reloaded workflow,
# replacing the engine-level fit that was stored in the workflow's .rds file.

new_lgbm_wflow <- readRDS("lgbm_wflw.rds")
new_lgbm_wflow$fit$fit$fit <- readRDS.lgb.Booster("lgbm_booster.rds")

# Predict with the restored workflow; the first ten printed values match the
# predictions printed before saving.
predict(new_lgbm_wflow, data_predict)
#> # A tibble: 2,930 × 1
#>      .pred
#>      <dbl>
#>  1 201911.
#>  2 124695.
#>  3 138983.
#>  4 221095.
#>  5 198972.
#>  6 188613.
#>  7 198730.
#>  8 170893.
#>  9 243899.
#> 10 196875.
#> # … with 2,920 more rows

Created on 2022-09-07 with reprex v2.0.2 https://reprex.tidyverse.org

在上面的 reprex 中,我拟合的是一个工作流(workflow)。如果您拟合的是 parsnip 模型对象,请改用以下方法:


# Variant for a fitted parsnip model rather than a fitted workflow.
# bonsai_fit, path1 and path2 are placeholders that are not defined in this
# snippet - presumably a fitted model object and two file paths; supply your
# own.
# Save the fitted object and its engine-level fit to separate files.
saveRDS(bonsai_fit, path1)
saveRDS.lgb.Booster(extract_fit_engine(bonsai_fit), path2)
# Reload both and put the reloaded booster back into the $fit element.
bonsai_fit_read <- readRDS(path1)
bonsai_fit_engine_read <- readRDS.lgb.Booster(path2)
bonsai_fit_read$fit <- bonsai_fit_engine_read

更多细节请参考这条评论:https://github.com/tidymodels/bonsai/issues/44#issuecomment-1206816030

好消息(见 https://github.com/tidymodels/bonsai/issues/44#issuecomment-1208489161)是:

Just want to add to this conversation that since December 2021, {lightgbm}'s development version has supported using readRDS() / saveRDS() directly for {lightgbm} models

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何保存 Tidymodels Lightgbm 模型以供重复使用 的相关文章

随机推荐