问题总结
我正在装修一个brms::brm_multiple()
模型到一个大型数据集,其中缺失的数据已使用mice
包裹。数据集的大小使得并行处理的使用非常可取。但是,我不清楚如何最好地利用计算资源,因为我不清楚如何brms
在核心之间划分估算数据集的采样。
如何选择以下选项以最大限度地有效利用计算资源?
- 插补数 (
m
)
- 链数(
chains
)
- 核心数量(
cores
)
概念示例
假设我天真地(或者为了举例而故意愚蠢地)选择m = 5
, chains = 10
, cores = 24
。因此,需要在 HPC 上保留的 24 个核心之间分配 5 x 10 = 50 个链。如果没有并行处理,这将需要大约 50 个时间单位(不包括编译时间)。
我可以想象三种并行化策略brms_multiple()
,但可能还有其他:
场景一:并行估算数据集,串行关联链
这里,5 个插补中的每一个都分配给它自己的处理器,该处理器串行运行 10 个链。处理时间为 10 个单位(与非并行处理相比,速度提高了 5 倍),但糟糕的规划浪费了 19 个核心 x 10 个时间单位 = 190 个核心时间单位(ctu;= 80% 的预留计算资源)。有效的解决方案是设置cores
= m
.
场景2:串行推算数据集,并行关联链
在这里,采样首先获取第一个估算数据集,并在 10 个不同核心上运行该数据集的一个链。然后对其余四个估算数据集重复此操作。该处理需要 5 个时间单位(比串行处理速度提高 10 倍,比场景 1 提高 2 倍)。然而,这里的计算资源也被浪费了:14 个核心 x 5 个时间单位 = 70 ctu。有效的解决方案是设置cores
= chains
场景3:混战,其中每个核心在可用时承担待定的插补/链组合,直到所有核心都被处理为止。
此处,采样首先分配所有 24 个核心,每个核心分配给 50 个待处理链之一。完成迭代后,将处理第二批 24 条链,使处理的链总数达到 48 条。但现在只有 2 条链待处理,22 个核心闲置 1 个时间单位。总处理时间为3个时间单位,浪费的计算资源为22ctu。有效的解决方案是设置cores
到多个m
x chains
.
最小可重复示例
此代码使用修改自以下示例的示例来比较计算时间brms小插图 https://cran.r-project.org/web/packages/brms/vignettes/brms_missings.html。这里我们要设置m
= 10, chains
= 6,并且cores
= 4。这使得总共需要处理 60 个链。在这些条件下,我预计速度改进(相对于串行处理)如下*:
- 场景 1:60/(6 条链 x 天花板(10 m / 4 芯)) = 3.3x
- 场景2:60/(天花板(6条链/4芯)x 10 m)= 3.0x
- 场景 3:60/天花板((6 条链 x 10 m)/ 4 核)= 4.0x
*(使用上限/向上取整是因为链不能在核心之间细分)
library(brms)
library(mice)
library(tictoc) # convenience functions for timing
# Load data
data("nhanes", package = "mice")
# There are 10 imputations x 6 chains = 60 total chains to be processed
imp <- mice(nhanes, m = 10, print = FALSE, seed = 234023)
# Fit the model first to get compilation out of the way
fit_base <- brm_multiple(bmi ~ age*chl, data = imp, chains = 6,
iter = 10000, warmup = 2000)
# Use update() function to avoid re-compiling time
# Serial processing (127 sec on my machine)
tic() # start timing
fit_serial <- update(fit_base, .~., cores = 1L)
t_serial <- toc() # stop timing
t_serial <- diff(unlist(t_serial)[1:2]) # calculate seconds elapsed
# Parallel processing with 3 cores (82 sec)
tic()
fit_parallel <- update(fit_base, .~., cores = 4L)
t_parallel <- toc()
t_parallel <- diff(unlist(t_parallel)[1:2]) # calculate seconds elapsed
# Calculate speed up ratio
t_serial/t_parallel # 1.5x
显然我错过了一些东西。我无法用这种方法区分场景。