After a lot of digging, I found the solution in a closed issue: https://github.com/tidymodels/bonsai/issues/44#issuecomment-1208463927.
# Load the tidymodels meta-package (parsnip, recipes, workflows, dplyr, etc.).
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
# bonsai supplies the "lightgbm" engine used in set_engine() below.
library(bonsai)
# lightgbm is needed directly for saveRDS.lgb.Booster() / readRDS.lgb.Booster().
# NOTE: it masks dplyr::slice, so load it after tidymodels deliberately.
library(lightgbm)
#> Warning: package 'lightgbm' was built under R version 4.2.1
#> Loading required package: R6
#>
#> Attaching package: 'lightgbm'
#> The following object is masked from 'package:dplyr':
#>
#> slice
# data
# Ames housing data with snake_case column names; keep the outcome plus the
# numeric predictors of interest, then add a character row id (given the "ID"
# role in the recipe below) as the first column.
data <- modeldata::ames %>%
  janitor::clean_names() %>%
  select(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath,
         enclosed_porch, fireplaces, full_bath, half_bath, kitchen_abv_gr,
         garage_area, garage_cars, gr_liv_area, lot_area, lot_frontage,
         year_built, year_remod_add, year_sold) %>%
  # row_number() replaces the c(1:nrow(data)) idiom (redundant c(), unsafe 1:n).
  mutate(id = as.character(row_number())) %>%
  select(id, everything())
# model specification
# Gradient-boosted trees fit with the {lightgbm} engine (via {bonsai}).
# Hyperparameter values are fixed here — presumably from an earlier tuning
# run not shown in this reprex.
lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139  # trailing comma removed: it passed an empty argument
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")
# recipe and workflow
# Predict sale_price from everything else; the id column is demoted to an
# "ID" role so it is carried along but not used as a predictor, and highly
# correlated predictors (|r| > 0.7) are dropped.
lgbm_recipe <- recipe(sale_price ~ ., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7)
# Bundle preprocessing and model into a single workflow.
lgbm_workflow <- workflow() %>%
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)
# fit workflow
# Train recipe + model on the full data set.
fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)
# predict
# Score the same rows with the outcome column removed; output captured below.
data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)
#> # A tibble: 2,930 × 1
#> .pred
#> <dbl>
#> 1 201911.
#> 2 124695.
#> 3 138983.
#> 4 221095.
#> 5 198972.
#> 6 188613.
#> 7 198730.
#> 8 170893.
#> 9 243899.
#> 10 196875.
#> # … with 2,920 more rows
# save the trained workflow and lgb.booster object separately
# An lgb.Booster holds a handle that does not survive a plain saveRDS()
# round-trip (see the linked issue), so the booster is serialized with
# lightgbm's own saveRDS.lgb.Booster() and the rest of the workflow normally.
saveRDS(fit_lgbm_workflow, "lgbm_wflw.rds")
saveRDS.lgb.Booster(extract_fit_engine(fit_lgbm_workflow), "lgbm_booster.rds")
# load trained workflow and merge it with lgb.booster
new_lgbm_wflow <- readRDS("lgbm_wflw.rds")
# $fit$fit$fit is where the engine-level fit (the lgb.Booster) lives inside a
# fitted workflow; re-attach the separately restored booster there.
new_lgbm_wflow$fit$fit$fit <- readRDS.lgb.Booster("lgbm_booster.rds")
# Predictions from the reloaded workflow match the original fit (output below).
predict(new_lgbm_wflow, data_predict)
#> # A tibble: 2,930 × 1
#> .pred
#> <dbl>
#> 1 201911.
#> 2 124695.
#> 3 138983.
#> 4 221095.
#> 5 198972.
#> 6 188613.
#> 7 198730.
#> 8 170893.
#> 9 243899.
#> 10 196875.
#> # … with 2,920 more rows
Created on 2022-09-07 with reprex v2.0.2 https://reprex.tidyverse.org
In the reprex above I used a fitted workflow. If you fitted a parsnip model object instead, use the following approach:
# Save the parsnip wrapper and the raw booster separately, then re-attach.
saveRDS(bonsai_fit, path1)                                  # parsnip model object
saveRDS.lgb.Booster(extract_fit_engine(bonsai_fit), path2)  # underlying lgb.Booster
bonsai_fit_read <- readRDS(path1)
bonsai_fit_engine_read <- readRDS.lgb.Booster(path2)
# For a parsnip fit the engine object lives one level down, at $fit.
bonsai_fit_read$fit <- bonsai_fit_engine_read
See this comment for more details: https://github.com/tidymodels/bonsai/issues/44#issuecomment-1206816030.
The silver lining (https://github.com/tidymodels/bonsai/issues/44#issuecomment-1208489161) is:
Just want to add to this conversation that since December 2021, {lightgbm}'s development version has supported using saveRDS() / readRDS() directly for {lightgbm} models.