Lime的R版本可以用count:poisson目标函数解释xgboost模型吗?

2024-03-30

我使用 xgb.train 和“count:poisson”目标函数生成了一个模型,在尝试创建解释器时出现以下错误:

Error: Unsupported model type

当我用其他东西(例如 reg:logistic)替换目标时,Lime 就起作用了。

有没有办法解释 count:poisson 石灰? 谢谢

可重现的例子:

library(xgboost)
library(dplyr)
library(caret)
library(insuranceData) # example dataset https://cran.r-project.org/web/packages/insuranceData/insuranceData.pdf
library(lime) # Local Interpretable Model-Agnostic Explanations
set.seed(123)
data(dataCar)
mydb <- dataCar %>% select(clm, exposure, veh_value, veh_body,
                           veh_age, gender, area, agecat)

label_var <- "clm"  
offset_var <- "exposure"
feature_vars <- mydb %>% 
  select(-one_of(c(label_var, offset_var))) %>% 
  colnames()

#preparing data for xgboost (one hot encoding of categorical (factor) data
myformula <- paste0( "~", paste0( feature_vars, collapse = " + ") ) %>% as.formula()
dummyFier <- caret::dummyVars(myformula, data=mydb, fullRank = TRUE)
dummyVars.df <- predict(dummyFier,newdata = mydb)
mydb_dummy <- cbind(mydb %>% select(one_of(c(label_var, offset_var))), 
                    dummyVars.df)
rm(myformula, dummyFier, dummyVars.df)


feature_vars_dummy <-  mydb_dummy  %>% select(-one_of(c(label_var, offset_var))) %>% colnames()

xgbMatrix <- xgb.DMatrix(
  data = mydb_dummy %>% select(feature_vars_dummy) %>% as.matrix, 
  label = mydb_dummy %>% pull(label_var),
  missing = "NAN")


#model 1: this does not
myParam <- list(max.depth = 2,
                eta = .01,
                gamma = 0.001,
                objective = 'count:poisson',
                eval_metric = "poisson-nloglik")


booster <- xgb.train(
  params = myParam, 
  data = xgbMatrix, 
  nround = 50)

explainer <- lime(mydb_dummy %>% select(feature_vars_dummy), 
                  model = booster)

explanation <- explain(mydb_dummy %>% select(feature_vars_dummy) %>% head,
                       explainer,
                       n_labels = 1, 
                       n_features = 2)
#Error: Unsupported model type
#model 2 : this works
myParam2 <- list(max.depth = 2,
                eta = .01,
                gamma = 0.001,
                objective = 'reg:logistic',
                eval_metric = "logloss")


booster2 <- xgb.train(
  params = myParam2, 
  data = xgbMatrix, 
  nround = 50)

explainer <- lime(mydb_dummy %>% select(feature_vars_dummy), 
                  model = booster)

explanation <- explain(mydb_dummy %>% select(feature_vars_dummy) %>% head,
                       explainer,
                       n_features = 2)


plot_features(explanation)

None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Lime的R版本可以用count:poisson目标函数解释xgboost模型吗? 的相关文章

  • data.table:从不存在的列到现有列的“get”失败,静默失败

    gt d lt data table x 1 5 gt d x 6 y get i 9 Error in get i 9 object i 9 not found gt d y 1 add a new column y gt d x 6 y
  • 使用 sapply 的列表和矩阵

    我有一个也许是基本的问题 我在网上搜索过 我在读取文件时遇到问题 尽管如此 我还是按照 Konrad的建议设法读取了我的文件 我很欣赏这一点 How to get R to read in files from multiple subdi
  • R 中的聚类分析:确定最佳聚类数

    如何选择最佳的聚类数量来进行 k 均值分析 绘制以下数据的子集后 多少个簇比较合适 如何进行聚类树突分析 n 1000 kk 10 x1 runif kk y1 runif kk z1 runif kk x4 sample x1 lengt
  • 如何导入 .tsv 文件

    我需要读取一个表 tsvR 中的文件 test lt read table file drug info tsv Error in scan file what nmax sep dec quote skip nlines na strin
  • 如何替换R中的“意外转义字符”

    当我尝试从 Facebook URL 的字符对象解析 JSON 时 我收到 fromJSON data 中的错误 位置 130 处出现意外的转义字符 o 看一下这个 library RCurl library rjson data lt g
  • 从频率表生成 data.frame

    我在 2 4 数组中有包含 500 个观察值的合成数据 datax array c 120 181 50 43 41 33 24 8 dim c 2 4 dimnames datax list gender c male female pu
  • 为 Linux 安装 R 包时出错

    我试图在 R 3 3 上安装一个名为 rgeos 的包 但是当我输入 install packages rgeos 但它返回给我以下错误 其他包也会发生同样的情况 但不是所有包 gt installing source package rg
  • 双向条形图,两侧带有正标签ggplot2

    我尝试在 ggplot 中创建一个双向条形图 其中轴上方和下方的轴标签和数据标签均为正值 例如 如果您的数据是 myData lt data frame category c yes yes no no month c Jan Feb Ja
  • R.scale() 和 sklearn.preprocessing.scale() 之间的区别

    我目前正在将数据分析从 R 转移到 Python 当在 R 中缩放数据集时 我将使用 R scale 根据我的理解 它将执行以下操作 x mean x sd x 为了替换该函数 我尝试使用 sklearn preprocessing sca
  • 如何在R中绘制仪表图表?

    如何在 R 中绘制以下图 Red 30 Yellow 40 Green 30 Needle at 52 所以这里有一个完整的ggplot解决方案 注意 从原始帖子中编辑 在仪表中断处添加数字指示器和标签 这似乎是OP在评论中所要求的 如果不
  • 如何使用 RODBC 将数据帧保存到数据库生成的主键表

    我想使用 R 脚本将数据框输入到数据库中的现有表中 并且希望数据库中的表具有顺序主键 我的问题是 RODBC 似乎不允许主键约束 这是创建我想要的表的 SQL CREATE TABLE dbo results ID INT IDENTITY
  • ggplot2 中的小时刻度

    我正在处理就寝时间和醒来时间 因此我想创建一个具有 24 小时 x 轴的图表 从第一天中午 12 点开始 到第二天中午 12 点结束 这意味着晚上 11 59 之后 它应该再次从 0 开始 同样的问题 仅涉及数字 我想创建一个从 10 到
  • 合并具有一个共同元素的集合 R

    我有一个这样的列表 lista list lista 1 c 1 2 4 6 8 9 10 11 12 19 32 34 35 36 37 38 lista 2 c 7 8 lista 3 c 13 14 16 26 27 28 29 30
  • dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran

    The solution现已上线RCPP画廊 http gallery rcpp org articles dmvnorm arma 我从 RcppArmadillo 中的 mvtnorm 包重新实现了 dmvnorm 我有点喜欢犰狳 但我
  • 自定义 colorRampPalette 中的颜色条

    我定义了一个 colorRampPalette my colors colorRampPalette c light green yellow orange red 如何为其绘制颜色条 图例 项目 最好仅使用基本包 我正在寻找一个充满该颜色
  • 选择一个单元格内的最小值或最大值(分隔字符串)

    我有一个数据框 其中每个样本的列可以有多个值 例如 Gene Pvalue1 Pvalue2 Pvalue3 Beta Ace 0 0381 0 00357 0 01755 0 001385 0 0037 NA 0 039 0 03 1 1
  • 使用 RMySQL 会干扰 RPostgreSQL

    我有一个 R 脚本 我想从 MySQL 数据库中提取一些数据 然后从 PostgreSQL 数据库中提取一些数据 但是 从 RMySQL 加载 MySQL 驱动程序会阻止我从以下位置加载 PostgreSQL 驱动程序 PostgreSQL
  • 使用 purrr::map() 更改和分配新变量名称

    我刚刚开始掌握编写函数并使用 lapply purrr map 使我的代码更加简洁 但显然还没有完全理解它 在我当前的示例中 我想重命名 lm robust 对象的系数名称 然后更改 lm robust 对象以合并新名称 我目前这样做 li
  • 列槽不足

    当尝试为 data table 中的每个变量 108 个变量 创建 12 个滞后时 我收到一条错误 指出列槽不足 此操作应创建大约 1200 个变量或列 Data A as data table Datos A Varnames names
  • 按列分组的数据帧上 R 中的行之间的差异

    我希望通过 app name 获得不同版本的计数差异 我的数据集如下所示 app name version id count difference 这是数据集 data structure list app name structure c

随机推荐