caret:结合分层的 createMultiFolds (repeatedCV) 和 groupKFold

2023-12-19

我的问题与中提出的问题非常相似插入符号:结合 createResample 和 groupKFold https://stackoverflow.com/questions/48142617/caret-combine-createresample-and-groupkfold?answertab=active#tab-top

唯一的区别:我需要在分组后创建分层折叠(也重复 10 次),而不是引导重新采样(据我所知,这不是分层的),以便将其与插入符的 trainControl 一起使用。 以下代码适用于 10 倍重复的 CV,但我无法包含基于“ID”的数据分组(df$ID).

# creating indices
cv.10.folds <- createMultiFolds(rf_label, k = 10, times = 10)
# creating folds    
ctrl.10fold <- trainControl(method = "repeatedcv", number = 10, repeats = 10, index = cv.10.folds)
# train
rf.ctrl10 <- train(rf_train, y = rf_label, method = "rf", tuneLength = 6,
                       ntree = 1000, trControl = ctrl.10fold, importance = TRUE)

这是我的实际问题:我的数据包含许多组,每个组由 20 个实例组成,具有相同的“ID”。因此,当使用 10 倍 CV 重复 10 次时,我在训练中得到了一些组的实例,在验证集中得到了一些组的实例。我想避免这种情况,但总的来说,我需要对预测值进行分层分区(df$Label)。 (具有相同“ID”的所有实例也具有相同的预测/标签值。)

在上面的链接提供和接受的答案中(参见下面的部分)我想我必须修改folds2包含分层 10 倍 CV 而不是引导的行

folds <- groupKFold(x)
folds2 <- lapply(folds, function(x) lapply(1:10, function(i) sample(x, size = length(x), replace = TRUE)))

但不幸的是我不知道到底是怎么回事。你能帮我吗?


这是一种通过阻塞执行分层重复 K 倍 CV 的方法。

library(caret)
library(tidyverse)

一些假数据,其中 id 将成为阻塞因素:

id <- sample(1:55, size = 1000, replace = T)
y <- rnorm(1000)
x <- matrix(rnorm(10000), ncol = 10)
df <- data.frame(id, y, x)

按分块因子总结观察结果:

df %>%
  group_by(id) %>%
  summarise(mean = mean(y)) %>%
  ungroup() -> groups1 

根据分组数据创建分层折叠:

folds <- createMultiFolds(groups1$mean, 10, 3)

返回将原始 df 连接到组数据并获取 df 行 id

folds <- lapply(folds, function(i){
  data.frame(id = i) %>%
    left_join(df %>%
                rowid_to_column()) %>%
    pull(rowid) 
})

检查测试中的数据 ID 是否不在火车中:

lapply(folds, function(i){
  sum(df[i,1] %in% df[-i,1])
})

输出是一堆零,这意味着测试折叠中的 id 不存在于训练折叠中。

如果您的组 ID 不是数字,有两种方法可以实现此目的:
1 将它们转换为数字:

首先一些数据

id <- sample(1:55, size = 1000, replace = T)
y <- rnorm(1000)
x <- matrix(rnorm(10000), ncol = 10)
df <- data.frame(id = paste0("id_", id), y, x) #factor id's

df %>%
  mutate(id = as.numeric(id)) %>% #convert to numeric
  group_by(id) %>%
  summarise(mean = mean(y)) %>%
  ungroup() -> groups1 

folds <- createMultiFolds(groups1$mean, 10, 3)

folds <- lapply(folds, function(i){
  data.frame(id = i) %>%
    left_join(df %>%
                mutate(id = as.numeric(id)) %>% #also need to convert to numeric in the original data frame
                rowid_to_column()) %>%
    pull(rowid) 
})  

2 根据折叠索引过滤分组数据中的id,然后按id进行连接

df %>%
  group_by(id) %>%
  summarise(mean = mean(y)) %>%
  ungroup() -> groups1 

folds <- createMultiFolds(groups1$mean, 10, 3)

folds <- lapply(folds, function(i){
  groups1 %>% #start from grouped data
    select(id) %>% #select id's
    slice(i) %>% #filter id's according to fold index
    left_join(df %>% #join by id 
               rowid_to_column()) %>%
    pull(rowid) 
})

它对插入符有用吗?

ctrl.10fold <- trainControl(method = "repeatedcv", number = 10, repeats = 3, index = folds)

rf.ctrl10 <- train(x = df[,-c(1:2)], y = df$y, data = df, method = "rf", tuneLength = 1,
                   ntree = 20, trControl = ctrl.10fold, importance = TRUE)

rf.ctrl10$results
#output
  mtry     RMSE    Rsquared       MAE     RMSESD  RsquaredSD      MAESD
1    3 1.041641 0.007534611 0.8246514 0.06953668 0.009488169 0.05934975

我还建议你去图书馆看看mlr,它有许多不错的功能,包括阻塞 -这是关于SO的一个答案 https://stackoverflow.com/questions/40422377/how-can-a-blocking-factor-be-included-in-makeclassiftask-from-mlr-package。它在很多方面都有非常好的教程things https://mlr-org.github.io/mlr-tutorial/devel/html/index.html。很长一段时间我认为你要么使用caret or mlr但它们非常互补。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

caret:结合分层的 createMultiFolds (repeatedCV) 和 groupKFold 的相关文章

  • 如何在knitr和RStudio中为word和html设置不同的全局选项?

    我正在使用 RStudio 0 98 932 和 knitr 1 6 想要为word和html设置不同的全局knitr选项 例如 想要将word的fig width和fig height设置为6 html的fig width和fig hei
  • 一段 R 代码会影响 foreach 输出中的随机数吗?

    我使用运行模拟foreach and doParallel并与随机数 名为random在代码中 简而言之 我模拟一个足球联赛 随机生成所有比赛的获胜者以及相应的结果 在dt base没有比赛进行 在dt ex1 and dt ex24场比赛
  • 无法更新/编辑从 R 中的包(`gratia`)导出的 ggplot2 对象

    我希望我在这里遗漏了一些令人痛苦的明显的东西 我希望更新 例如 修复标题 实验室等 由 生成的 ggplot 对象gratia draw 不太确定为什么我无法更新该对象 有一个简单的解决方案吗 devtools install github
  • R、Rcpp 与 Armadillo 中矩阵 rowSums() 与 colSums() 的效率

    背景 来自 R 编程 我正在扩展到 C C 形式的编译代码Rcpp 作为循环交换 以及一般的 C C 效果的实践练习 我实现了 R 的等效项rowSums and colSums 矩阵的函数Rcpp 我知道它们以 Rcpp 糖的形式存在 并
  • 如何对数字进行四舍五入并使其显示零?

    R 中将数字四舍五入到小数点后 2 位的常用代码是 gt a 14 1234 gt round a digits 2 gt a gt 14 12 但是 如果该数字的前两位小数位为零 则 R 会在显示中抑制零 gt a 14 0034 gt
  • 将字符串列拆分为多个虚拟变量

    作为 R 中 data table 包的相对缺乏经验的用户 我一直在尝试将一个文本列处理为大量指示符列 虚拟变量 每列中的 1 表示特定的子字符串是在字符串列中找到 例如我想处理这个 ID String 1 a b 2 b c 3 c 进入
  • 列出 R 数据文件的内容而不加载

    我有时用print load myDataFile RData 当我加载数据文件时列出它的内容 有没有办法列出内容而不加载数据文件中包含的对象 我认为如果不加载对象就无法做到这一点 解决方案可能是使用包装器将 R 对象保存到save 该函数
  • R 中两个时间戳之间的左连接

    我的目标是执行左连接intervals哪里的bike id比赛和created at时间戳在records在 之间start and end in the intervals table gt class records 1 data ta
  • 不同编程语言中的浮点数学

    我知道浮点数学充其量可能是丑陋的 但我想知道是否有人可以解释以下怪癖 在大多数编程语言中 我测试了 0 4 到 0 2 的加法会产生轻微的错误 而 0 4 0 1 0 1 则不会产生错误 两者计算不平等的原因是什么 在各自的编程语言中可以采
  • `dplyr::_join` 函数的命名向量“by”参数[重复]

    这个问题在这里已经有答案了 我正在写一个函数dplyr join两个数据框by不同的列 第一个数据帧的列名称动态指定为函数参数 我相信我需要使用rlang准引用 元编程 但未能找到可行的解决方案 我很感激任何建议 library dplyr
  • `as.matrix` 和 `as.data.frame` S3 方法与 S4 方法

    我注意到定义as matrix or as data frame作为 S4 类的 S3 方法 使例如lm formula objS4 and prcomp object 开箱即用 如果它们被定义为 S4 方法 则这不起作用 为什么将方法定义
  • 更新 R6 对象实例中的方法定义

    如何更新 R6 类实例的方法定义 正如我所期望的 S3 使用当前的方法定义 对于 R5 参考类 我可以使用 myInstance myInstance copy 在 R6 中 我尝试了 myInstance myInstance clone
  • 如何声明包含 M 个元素的列表对象

    我想声明一个包含 M 3 x 3 矩阵的列表 如果我事先知道数字 M 那么我可以通过以下方式声明这样的列表 elm lt matrix NA 3 3 Say M 7 myList lt list elm elm elm elm elm el
  • 如何按定义的顺序将图像合并到一个文件中

    我有大约 100 张图像 png 我不想手动执行此操作 而是希望将它们按照定义的顺序 基于文件名 并排放置在一个 pdf 中 每行 12 个图像 有人有什么建议吗 我按照下面托马斯告诉我的方法尝试了 它把它们贴在旁边有一个黑边 我怎样才能去
  • 在包加载之前如何知道 R 中特定函数属于哪个包?

    例如 我知道许多流行的功能 例如tbl df 我通常不记得它属于哪个包 即data table or dplyr 所以我必须始终记住并加载一个包 但我做不到 tbl df除非我加载了正确的包 在 R 控制台本身加载或安装包之前 有没有办法知
  • Keras model.predict 函数给出输入形状错误

    我已经在 Tensorflow 中实现了通用句子编码器 现在我正在尝试预测句子的类概率 我也将字符串转换为数组 Code if model model type universal classifier basic class probs
  • 增加雷达图中长轴标签的空间

    我想创建一个雷达图ggirahExtra ggRadar 问题是我的标签很长并且被剪掉了 我想我可以通过添加在标签和绘图之间创建更多空间margin margin 0 0 2 0 cm to element text in axis tex
  • 要在子集中显示的非数字条目的维恩图

    我有以下数据框 SET1 SET2 SET3 par1 par2 par1 par2 par3 par2 par3 par4 par5 我想制作一个维恩图 其中所有这些 parX 元素都显示在各自的子集中 即作为标签 而不仅仅是重叠元素的数
  • 如何为自定义 S3 类实现提取/取子集 ([ [<-, [[ [[<-)] 函数?

    我有一个自定义的 S3 类foo 它在正常的基础上添加了一些自定义行为data frame foo object lt data frame class foo object lt c foo data frame 对于这个类 还应该有一个
  • 需要在R中跳过不同数量的行

    我正在使用以下代码来处理我的数据 但最近我意识到使用skip 27 在数据开始之前跳过存储在我的文件中的信息 不是一个好的选择 因为每个文件中要跳过的行数不同我的目标是读取存储在多个文件夹中的各种txt文件 并非所有文件都有相同的列数 列的

随机推荐