dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran

2024-05-13

The solution现已上线RCPP画廊 http://gallery.rcpp.org/articles/dmvnorm_arma/


我从 RcppArmadillo 中的 mvtnorm 包重新实现了 dmvnorm。我有点喜欢犰狳,但我想它也可以在普通的 Rcpp 中工作。 dmvnorm 的方法基于马哈拉诺比斯距离,因此我有一个函数,然后是多元正态密度函数。

让我向你展示我的代码:

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::mat mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }



// [[Rcpp::export]]
arma::vec dmvnorm ( arma::mat x,  arma::mat mean,  arma::mat sigma, bool log){ 

arma::vec distval = mahalanobis_arma(x,  mean, sigma);

    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = 1.8378770664093454835606594728112352797227949472755668;
    arma::vec logretval = -( (x.n_cols * log2pi + logdet + distval)/2  ) ;

       if(log){ 
         return(logretval);

       }else { 
       return(exp(logretval));
         }
}

所以,并没有让我大失望:

模拟一些数据

sigma <- matrix(c(4,2,2,3), ncol=2)
x <- rmvnorm(n=5000000, mean=c(1,2), sigma=sigma, method="chol")

和基准

system.time(mvtnorm::dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.05    0.02    0.06 

system.time(dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.12    0.02    0.14 

不!!!!!! :-(

[EDIT]

The 问题是: 1) 为什么 RcppArmadillo 实现比普通 R 实现慢? 2) 如何创建击败 R 实现的 Rcpp/RcppArmadillo 实现?

[EDIT 2]

我将 mahalanobis_arma 放入 mvtnorm::dmvnorm 函数中,它也会减慢速度。


如果您想要更快地实现马哈拉诺比斯距离,您只需重写算法并模仿 R 使用的算法即可。这非常简单

我稍微修改了你的功能mahalanobis_arma转动mu to a rowvec.

基本上我只是将 R 代码翻译为 RcppArmadillo

mahalanobis
function (x, center, cov, inverted = FALSE, ...) 
{
    x <- if (is.vector(x)) 
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted) 
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>

这里是

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov){
    int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    for (int i=0; i < n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    return sum((x_cen * cov.i()) % x_cen, 1);    
}


// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::rowvec mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }

现在,让我们来比较一下这个新的犰狳版本(Mahalanobis),你的第一个版本(mahalanobis_arma)和 R 实现(mahalanobis).

我将此 Cpp 代码保存为mahalanobis.cpp

require(RcppArmadillo)
sourceCpp("mahalanobis.cpp")

set.seed(1)
x <- matrix(rnorm(10000 * 10), ncol = 10)
Sx <- cov(x)


all.equal(c(Mahalanobis(x, colMeans(x), Sx))
          ,mahalanobis(x, colMeans(x), Sx))
## [1] TRUE

all.equal(mahalanobis_arma(x, colMeans(x), Sx)
          ,Mahalanobis(x, colMeans(x), Sx))
## [1] TRUE


require(rbenchmark)
benchmark(Mahalanobis(x, colMeans(x), Sx),
          mahalanobis(x, colMeans(x), Sx),
          mahalanobis_arma(x, colMeans(x), Sx),
          order = "elapsed")


##                                   test replications elapsed
## 1      Mahalanobis(x, colMeans(x), Sx)          100   0.124
## 2      mahalanobis(x, colMeans(x), Sx)          100   0.741
## 3 mahalanobis_arma(x, colMeans(x), Sx)          100   4.509
##   relative user.self sys.self user.child sys.child
## 1    1.000     0.173    0.077          0         0
## 2    5.976     0.804    0.670          0         0
## 3   36.363     4.386    4.626          0         0

正如您所看到的,新的实现比 R 的实现更快。 我非常确定,通过使用乔列斯基分解来求解协方差矩阵或使用其他矩阵分解,我们可以做得更好。

最后,我们可以将其插入Mahalanobis功能进入你的dmvnorm并测试它:

require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")


all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2+diag(2), FALSE)))
## [1] TRUE

benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2+diag(2), FALSE),
          order = "elapsed")

##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0

现在几乎快了一倍。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran 的相关文章

  • 如何很好地注释 ggplot2(手册)

    Using ggplot2我通常使用geom text和类似的东西position jitter注释我的情节 然而 对于一个漂亮的情节 我经常发现手动注释是值得的 像下面这样 data2 lt structure list type str
  • 将公式传递给 R 中的函数?

    对此的任何帮助将不胜感激 我正在使用 Lumley 调查包 并试图简化我的代码 但遇到了一些小障碍 在我的代码中调用包中的 svymean 函数如下 其中第一个参数是指示我想要哪些变量的公式 第二个参数是该数据集 svymean hq eh
  • 在 R 中编写多重积分函数

    为了将以下内容转换为函数 我想知道如何用 R 代码编写以下二重积分 bar x mu 假设pi0 and pi1以向量化方式实现函数 pi 0 和 pi 1 可能的解决方案是 integral lt function n mu s pi0
  • 直接来自数据的马尔可夫模型图(makovchain 或 deemod 包?)

    我想读取一堆因子数据并从中创建一个可以很好地可视化的转换矩阵 我发现了一个非常好的软件包 称为 heemod 它与 diagram 一起工作得不错 对于我的第一个快速而肮脏的方法 我运行了一段 Python 代码来获取矩阵 然后使用这个 R
  • 如何计算满足条件的行数

    假设我有以下数据框 Data1 X1 X2 1 15 1 2 3 1 3 7 0 4 11 1 5 1 0 6 9 0 7 18 0 8 6 1 9 3 1 我想知道如何找到观察的总数X1大于 9 并且X2等于1 我想我需要使用sum 但我
  • 挖泥机子集 (MuMIn) - 如果存在主效应,则必须包括交互作用

    我正在使用 dredge MuMIn 进行一些探索性工作 在此过程中 我想将两个变量设置为仅当它们之间存在相互作用时才允许一起出现 即它们不能仅作为主要效果一起出现 使用样本数据 我想挖掘模型 fm1 尽管它可能没有意义 如果变量 GNP
  • R 中数据帧的条件求和

    我正在努力将在 Excel 中进行的分析迁移到 R 因为我的数据集已达到 Excel 的限制 在 Excel 中 我有一个工作表 状态 它执行 sumifs 函数 对另一个工作表 成员 中 状态 中具有相同状态 周组合的值求和 我想在 R
  • 使用 ggplot 将条形图的列与线图的点对齐

    当线图的点与条形图的条具有相同的 x 轴时 有什么方法可以使用 ggplot 将它们对齐 这是我尝试使用的示例数据 library ggplot2 library gridExtra data data frame x rep 1 27 e
  • 如何解决在Windows中运行R时出现“剪贴板缓冲区已满且输出丢失”错误?

    我正在尝试将一些数据直接从 R 复制到我的 Windows 计算机中的剪贴板 我发现在一些网站上使用 file clipboard 可以工作 确实如此 但对于非常小的数据集 例如 如果我复制一个小数据集 100 个 obs 它会顺利工作 d
  • R 中带有变音符号的字符列表

    我试图将字符串中的电话 字符 出现次数制成表格 但变音符号单独作为字符制成表格 理想情况下 我有一个国际音标的单词列表 其中包含大量变音符号以及它们与基本字符的几种组合 我在这里给出了仅包含一个单词的 MWE 但对于单词列表和更多类型的组合
  • Sweave + RweaveHTML:cat 输出未出现在输出中

    我对 Sweave RweaveHTML 有疑问 我希望 cat 的输出最终出现在正在生成的 html 文件中 我有一个案例 它没有 我不明白为什么 test function bla bla cat Result is 然后在 Rnw 文
  • R中不重复的组合

    我试图获取变量元素长度为 3 的所有可能组合 虽然它部分地与combn 一起工作 但我没有完全得到我正在寻找的输出 这是我的例子 x lt c a b c d e t combn c x x 3 我得到的输出看起来像这样 1 2 3 1 a
  • 将从数据透视表包生成的数据透视表转换为数据帧

    我正在尝试制作一个数据透视表pivottabler包裹 我想将数据透视表对象转换为数据框 以便我可以将其转换为数据表 带有 DT 并在 Shiny 应用程序中渲染它 以便可以下载 library pivottabler pt qpvt mt
  • R 中的金字塔图

    对于示例数据集 我按国家 地区创建了一个金字塔图 显示人口中男性和女性超重的水平 library plotrix xy males overweight lt c 23 2 33 5 43 6 33 6 43 5 43 5 43 9 33
  • 使用示例代码继续在 ggplot2 中遇到错误“loop_apply”未从当前命名空间(plyr)解析”

    我今天一直遇到这个错误 我已经从 github 下载了 plyr 但它仍然不起作用 安装 plyr 后 我重新启动了 R studio 甚至我的电脑 看来问题可能是由于 R 解析对外部 DLL 的引用的方式发生了变化 正如线程中途提到的he
  • 按元素名称组合/合并列表

    我有两个列表 其元素的名称部分重叠 我需要将其逐个元素合并 组合成一个列表 gt lst1 lt list integers c 1 7 letters letters 1 5 words c two strings gt lst2 lt
  • 错误:列索引必须最多为 1,如果... heatmap.2

    我在 heatmap 2 中收到错误 我在这里发现了类似的错误R knnImputation 给出错误 https stackoverflow com questions 45117125 r knnimputation giving er
  • R 版本 4.0.0 上的 ROracle

    当尝试使用 ROracle 时 我收到以下错误消息 gt library ROracle Error package or namespace load failed for ROracle package ROracle was inst
  • 在r中的数据框中循环线性回归输出

    我有一个下面的数据集 我想在其中对每个国家和州进行线性回归 然后绑定数据集中的预测值 添加另外三列后的最终数据框 我已经对一个国家和一个地区进行了此操作 但想对每个国家和地区进行此操作 并将预测值 上限值和下限值放回到cbind的数据集中
  • 如何创建具有特定于每个方面的标题和副标题的分面图?

    生成一个图 该图与每列的单独图相结合 带有标题和副标题 以及每个图的垂直线 我使用直方图创建了带有垂直线的列 library ggplot2 library gridExtra library tidyr actualIris lt dat

随机推荐

  • 如何在嵌入式tomcat中配置valve?

    我需要在嵌入式tomcat中配置valvehttp tomcat apache org tomcat 8 0 doc config valve html Remote IP Valve http tomcat apache org tomc
  • MongoDB 在仅返回 _id 时使用 COLLSCAN

    我想返回 MongoDB 集合中的所有 ID 我使用了以下代码 db coll find id 1 但MongoDB扫描整个集合而不是从默认读取信息index id 1 从日志中 find collection filter project
  • 每第 n 个字符分割一个字符串

    在 JavaScript 中 这就是我们如何在每 3 个字符处分割一个字符串 foobarspam match 1 3 g 我正在尝试弄清楚如何在 Java 中做到这一点 有什么指点吗 你可以这样做 String s 1234567890
  • 在 Delphi 2007 中将具有透明度的位图保存为 PNG

    我有一个包含透明度信息的 Delphi 位图 32 位 我需要将其转换并保存为 PNG 文件 同时保留透明度 我目前拥有的工具是graphics32 Library GR32 PNG 由Christian Budde 提供 和PNGImag
  • 并行启动服务

    我有一个脚本可以检查不同服务器上的某些服务是否已启动 如果没有启动 该脚本应该启动该服务 问题是 它不会并行启动服务 而是等待每个服务启动 Code server list Get Content path D Path list of s
  • Google G-Suite API 控制台未显示启用 G Suite 域范围委派

    我正在与客户合作设置服务帐户凭据 以便通过 API 读取 G Suite 目录信息 我之前已经这样做了十几次 没有任何问题 现在我遇到了一个问题 设置没有向客户端显示 下面的图片显示了我通常会看到的内容 阅读中圈出的区域是启用域范围委派的能
  • 使用 VNext 构建后,TFS tbl_Content 开始快速增长

    直到一个月前我们一直在使用旧样式 XAML 构建 然后开始使用 vNext 构建 之后我注意到 TFS 数据库中的 tbl Content 表开始快速增长 例如 在过去 8 小时内 它增长了 10 GB 但我不明白为什么会这样做 有谁知道它
  • 我可以通过链接分享我的私人 GitHub 存储库吗?

    我在 GitHub 上的私人存储库中有一个 Java 应用程序 我想与没有帐户的人共享它 我在网站上没有找到任何与此相关的选项 有没有办法做到这一点 协作者只能是 GitHub 用户 无法在非 Github 用户之间共享私有存储库 您需要
  • 使用 XPath 3.1 fn:serialize 进行 JSON 序列化

    我在 Saxon HE 9 8 中使用 XSLT 3 0 并且希望将 JSON 文档用作链接数据JSON LD https json ld org 在 JSON LD 中 完整的 HTTP URI 通常显示为值 当我使用 XPath 3 1
  • 插入并发问题-多线程环境

    我有一个问题 即使用完全相同的参数在完全相同的时间调用相同的存储过程 存储过程的目的是获取记录 如果存在 或创建并获取记录 如果不存在 问题是两个线程都在检查记录是否存在并报告错误 然后都插入新记录 在数据库中创建重复记录 我尝试将操作保留
  • 钛金 Android 导航组

    您好 我是钛合金新手 它允许开发人员创建跨平台应用程序 我需要创建一个适用于 Android 和 iOS 的导航组 有没有明确的解决方案 因为 Ti UI iPhone createNavigationGrou 仅适用于 iphone 谢谢
  • Itunes Connect 测试飞行公共链接有效性

    苹果最近为试飞版本启用了公共链接功能 我们可以与任何人共享此链接 他可以使用此公共链接安装应用程序 此公共链接背后的构建有效期为 90 天 我的问题是 与用户共享公共链接后 我们可以增加构建的到期时间吗 这样公共链接的有效性就会增加 我们不
  • 将颜色映射到plotly go.饼图中的标签

    我正在使用 make subplots 和 go Pie 绘制一系列 3 个饼图 我想最终将它们放入破折号应用程序中 用户可以在其中过滤数据并且数字将更新 如何将特定颜色映射到变量 以便男性始终为蓝色 女性始终为粉红色 等等 您可以使用 c
  • 使用 Terraform 管理访问 RDS 数据库的凭据时出现问题

    我通过 Terraform 创建了一个秘密 该秘密用于访问也在 Terraform 中定义的 RDS 数据库 并且在秘密中 我不想包含username and password 因此我创建了一个空密钥 然后在 AWS 控制台中手动添加凭证
  • 在继承的ctypes.Structure类的构造函数中调用from_buffer_copy

    我有以下代码 class MyStruct ctypes Structure fields id ctypes uint perm ctypes uint 定义类后 我可以直接从缓冲区复制数据到我的字段上 例如 ms MyStruct fr
  • 一个新的 JavaScript 数组长度是否无法使用? [复制]

    这个问题在这里已经有答案了 根据MDN 文档new Array length https developer mozilla org en US docs Web JavaScript Reference Global Objects Ar
  • 将数值数据更改为分类数据 - Pandas [重复]

    这个问题在这里已经有答案了 我有一个 pandas 数据框 其中有一个数字列 金额 金额从 0 到 20000 不等 我想将其更改为定义范围的分类变量 因此 分类变量将是 0 1000 之间 1000 2000 美元之间 依此类推 直到 1
  • 多个where条件codeigniter

    如何将此查询转换为活动记录 UPDATE table user SET email email last ip last ip where username username and status status 我尝试将上面的查询转换为 d
  • JavascriptCore:在 JSExport 中将 javascript 函数作为参数传递

    JavascriptCore是iOS7中支持的新框架 我们可以使用 JSExport 协议将 objc 类的部分内容公开给 JavaScript 在javascript中 我尝试将函数作为参数传递 像这样 function getJsonC
  • dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran

    The solution现已上线RCPP画廊 http gallery rcpp org articles dmvnorm arma 我从 RcppArmadillo 中的 mvtnorm 包重新实现了 dmvnorm 我有点喜欢犰狳 但我