如何在 CUDA 中执行多个矩阵乘法?

2024-04-27

我有一个方阵数组int *M[10];以便M[i]定位第一个元素i-th 矩阵。我想将所有矩阵相乘M[i]通过另一个矩阵N,这样我就收到了方阵数组int *P[10]作为输出。

我看到有不同的可能性:

  1. 分配不同元素的计算M[i]到不同的线程;例如,我有10矩阵,4x4大小,以便涉及的线程数为160;如何使用CUDA来实现这种方法?
  2. 在上面例子的框架中,创建一个复合矩阵大小40x40(即收集10, 4x4大小矩阵在一起)并使用40x40线程;但这种方法似乎需要更多时间;我正在尝试使用矩阵数组,但我认为我做错了;我怎样才能使用这种方法10矩阵?如何在内核函数中编写它?

这就是我正在尝试的;

void GPU_Multi(int *M[2], int *N, int *P[2], size_t width)
{

    int *devM[2];
    int *devN[2];
    int *devP[2];
    size_t allocasize =sizeof(int) *width*width;

    for(int i = 0 ; i < 10 ; i ++ ) 
    {
        cudaMalloc((void**)&devM[ i ], allocasize );
        cudaMalloc((void**)&devP[ i ], allocasize ); 
    }

    cudaMalloc((void**)&devN, allocasize );

    for(int i = 0 ; i < 10 ; i ++ ) {

        cudaMemcpy(devM[ i ],M[ i ], allocasize , cudaMemcpyHostToDevice);
        cudaMemcpy(devN, N, allocasize , cudaMemcpyHostToDevice);
        dim3 block(width*2, width*2);
        dim3 grid(1,1,1);
        Kernel_Function<<<grid, block>>>  (devM[2], devN, devP[2],width);

        for(int i = 0 ; i < 10 ; i ++ ) 
        {
            cudaMemcpy(P[ i ], P[ i ], allocatesize, cudaMemcpyDeviceToHost);
            cudaFree(devM[ i ]);   
            cudaFree(devP[ i ]);
        }

    }

我认为使用以下方法可能会实现最快的性能CUBLAS批量gemm函数 http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemmbatched它是专门为此目的而设计的(执行大量“相对较小”的矩阵乘法运算)。

Even though you want to multiply your array of matrices (M[]) by a single matrix (N), the batch gemm function will require you to pass also an array of matrices for N (i.e. N[]), which will all be the same in your case.

EDIT:现在我已经完成了一个示例,对我来说很清楚,通过对下面的示例进行修改,我们可以传递一个N矩阵并有GPU_Multi函数只需发送单个N矩阵到设备,同时传递一个指针数组N, i.e. d_Narray在下面的示例中,所有指针都指向同一个N设备上的矩阵。

这是一个完整的批量 GEMM 示例:

#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <assert.h>

#define ROWM 4
#define COLM 3
#define COLN 5

#define cudaCheckErrors(msg) \
    do { \
        cudaError_t __err = cudaGetLastError(); \
        if (__err != cudaSuccess) { \
            fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
                msg, cudaGetErrorString(__err), \
                __FILE__, __LINE__); \
            fprintf(stderr, "*** FAILED - ABORTING\n"); \
            exit(1); \
        } \
    } while (0)


typedef float mytype;
// Pi = Mi x Ni
// pr = P rows = M rows
// pc = P cols = N cols
// mc = M cols = N rows
void GPU_Multi(mytype **M, mytype **N, mytype **P
  , size_t pr, size_t pc, size_t mc
  , size_t num_mat, mytype alpha, mytype beta)
{

    mytype *devM[num_mat];
    mytype *devN[num_mat];
    mytype *devP[num_mat];
    size_t p_size =sizeof(mytype) *pr*pc;
    size_t m_size =sizeof(mytype) *pr*mc;
    size_t n_size =sizeof(mytype) *mc*pc;
    const mytype **d_Marray, **d_Narray;
    mytype **d_Parray;
    cublasHandle_t myhandle;
    cublasStatus_t cublas_result;

    for(int i = 0 ; i < num_mat; i ++ )
    {
        cudaMalloc((void**)&devM[ i ], m_size );
        cudaMalloc((void**)&devN[ i ], n_size );
        cudaMalloc((void**)&devP[ i ], p_size );
    }
    cudaMalloc((void**)&d_Marray, num_mat*sizeof(mytype *));
    cudaMalloc((void**)&d_Narray, num_mat*sizeof(mytype *));
    cudaMalloc((void**)&d_Parray, num_mat*sizeof(mytype *));
    cudaCheckErrors("cudaMalloc fail");
    for(int i = 0 ; i < num_mat; i ++ ) {

        cudaMemcpy(devM[i], M[i], m_size , cudaMemcpyHostToDevice);
        cudaMemcpy(devN[i], N[i], n_size , cudaMemcpyHostToDevice);
        cudaMemcpy(devP[i], P[i], p_size , cudaMemcpyHostToDevice);
    }
    cudaMemcpy(d_Marray, devM, num_mat*sizeof(mytype *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Narray, devN, num_mat*sizeof(mytype *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Parray, devP, num_mat*sizeof(mytype *), cudaMemcpyHostToDevice);
    cudaCheckErrors("cudaMemcpy H2D fail");
    cublas_result = cublasCreate(&myhandle);
    assert(cublas_result == CUBLAS_STATUS_SUCCESS);
    // change to    cublasDgemmBatched for double
    cublas_result = cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N
      , pr, pc, mc
      , &alpha, d_Marray, pr, d_Narray, mc
      , &beta, d_Parray, pr
      , num_mat);
    assert(cublas_result == CUBLAS_STATUS_SUCCESS);

    for(int i = 0 ; i < num_mat ; i ++ )
    {
        cudaMemcpy(P[i], devP[i], p_size, cudaMemcpyDeviceToHost);
        cudaFree(devM[i]);
        cudaFree(devN[i]);
        cudaFree(devP[i]);
    }
    cudaFree(d_Marray);
    cudaFree(d_Narray);
    cudaFree(d_Parray);
    cudaCheckErrors("cudaMemcpy D2H fail");

}

int main(){

  mytype h_M1[ROWM][COLM], h_M2[ROWM][COLM];
  mytype h_N1[COLM][COLN], h_N2[COLM][COLN];
  mytype h_P1[ROWM][COLN], h_P2[ROWM][COLN];
  mytype *h_Marray[2], *h_Narray[2], *h_Parray[2];
  for (int i = 0; i < ROWM; i++)
    for (int j = 0; j < COLM; j++){
      h_M1[i][j] = 1.0f; h_M2[i][j] = 2.0f;}
  for (int i = 0; i < COLM; i++)
    for (int j = 0; j < COLN; j++){
      h_N1[i][j] = 1.0f; h_N2[i][j] = 1.0f;}
  for (int i = 0; i < ROWM; i++)
    for (int j = 0; j < COLN; j++){
      h_P1[i][j] = 0.0f; h_P2[i][j] = 0.0f;}

  h_Marray[0] = &(h_M1[0][0]);
  h_Marray[1] = &(h_M2[0][0]);
  h_Narray[0] = &(h_N1[0][0]);
  h_Narray[1] = &(h_N2[0][0]);
  h_Parray[0] = &(h_P1[0][0]);
  h_Parray[1] = &(h_P2[0][0]);

  GPU_Multi(h_Marray, h_Narray, h_Parray, ROWM, COLN, COLM, 2, 1.0f, 0.0f);
  for (int i = 0; i < ROWM; i++)
    for (int j = 0; j < COLN; j++){
      if (h_P1[i][j] != COLM*1.0f)
      {
        printf("h_P1 mismatch at %d,%d was: %f should be: %f\n"
          , i, j, h_P1[i][j], COLM*1.0f); return 1;
      }
      if (h_P2[i][j] != COLM*2.0f)
      {
        printf("h_P2 mismatch at %d,%d was: %f should be: %f\n"
          , i, j, h_P2[i][j], COLM*2.0f); return 1;
      }
    }
  printf("Success!\n");
  return 0;
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 CUDA 中执行多个矩阵乘法? 的相关文章

  • 如何从RichTextBox中获取显示的文本?

    如何获得显示的RichTextBox 中的文本 我的意思是 如果 RichTextBox 滚动到末尾 我只想接收那些对我来说可见的行 P S 获得第一个显示的字符串就足够了 您想使用 RichTextBox GetCharIndexFrom
  • 将图像文件从网址复制到本地文件夹?

    我有该图像的网址 例如 http testsite com web abc jpg http testsite com web abc jpg 我想将该 URL 复制到 c images 中的本地文件夹中 而且当我将该文件复制到文件夹中时
  • 在Application_AquireRequestState事件中用POST数据重写Url

    我有一个在其中注册路线的代码Application AcquireRequestState应用程序的事件 注册路由后 我会在 Http 运行时缓存中设置一个标志 这样我就不会再次执行路由注册代码 在此事件中注册路线有特定原因Applicat
  • 为什么这个函数指针赋值在直接赋值时有效,但在使用条件运算符时无效?

    本示例未使用 include 在 MacOS10 14 Eclipse IDE 上编译 使用 g 选项 O0 g3 Wall c fmessage length 0 假设这个变量声明 int fun int 这无法通过 std touppe
  • 访问“if”语句之外的变量

    我怎样才能使insuranceCost以外可用if陈述 if this comboBox5 Text Third Party Fire and Theft double insuranceCost 1 在 if 语句之外定义它 double
  • 导出类时编译器错误

    我正在使用 Visual Studio 2013 但遇到了一个奇怪的问题 当我导出一个类时 它会抛出 尝试引用已删除的函数 错误 但是 当该类未导出时 它的行为会正确 让我举个例子 class Foo note the export cla
  • 方法“xxx”不能是事件的方法,因为该类派生的类已经定义了该方法

    我有一个代码 public class Layout UserControl protected void DisplayX DisplayClicked object sender DisplayEventArgs e CurrentDi
  • 在 MATLAB 中创建共享库

    一位研究人员在 MATLAB 中创建了一个小型仿真 我们希望其他人也能使用它 我的计划是进行模拟 清理一些东西并将其变成一组函数 然后我打算将其编译成C库并使用SWIG https en wikipedia org wiki SWIG创建一
  • 无法加载文件或程序集“EntityFramework,版本=6.0.0.0”

    我究竟做错了什么 我该如何解决这个问题 我有一个包含多个项目的解决方案 它是一个 MVC NET 4 5 Web 应用程序 在调试模式下启动后调用其中一个项目时 出现此错误 导致此错误的项目具有以下参考 两个都是版本6 0 0 0 应用程序
  • 格式化货币

    在下面的示例中 逗号是小数点分隔符 我有这个 125456 89 我想要这个 125 456 89 其他示例 23456789 89 gt 23 456 789 89 Thanks 看看这个例子 double value 12345 678
  • 在VisualStudio DTE中,如何获取ActiveDocument的内容?

    我正在 VisualStudio 中编写脚本 并尝试获取当前 ActiveDocument 的内容 这是我当前的解决方案 var visualStudio new API VisualStudio 2010 var vsDTE visual
  • 如何用 C 语言练习 Unix 编程?

    经过五年的专业 Java 以及较小程度上的 Python 编程并慢慢感觉到我的计算机科学教育逐渐消失 我决定要拓宽我的视野 对世界的一般用处 并做一些 对我来说 感觉更重要的事情就像我真的对机器有影响一样 我选择学习 C 和 Unix 编程
  • 正确使用“extern”关键字

    有一些来源 书籍 在线材料 解释了extern如下 extern int i declaration has extern int i 1 definition specified by the absence of extern 并且有支
  • 使用(linq to sql)更新错误

    我有两个表 通过外键 CarrierID 绑定 Carrier CarrierID CarrierName CarrierID 1 CarrierName DHL CarrierID 2 CarrierName Fedex Vendor V
  • 如何访问窗口?

    我正在尝试使用其句柄访问特定窗口 即System IntPtr value Getting the process of Visual Studio program var process Process GetProcessesByNam
  • “int i=1,2,3”和“int i=(1,2,3)”之间的区别 - 使用逗号运算符的变量声明[重复]

    这个问题在这里已经有答案了 int i 1 2 3 int i 1 2 3 int i i 1 2 3 这些说法有什么区别 我无法找出任何具体原因 Statement 1 Result Compile error 运算符的优先级高于 运算符
  • 纯虚函数可能没有内联定义。为什么?

    纯虚函数是那些虚函数并且具有纯说明符 0 第 10 4 条第 2 款C 03 的内容告诉我们什么是抽象类 顺便说一句 如下 注意 函数声明不能 同时提供纯说明符和定义 尾注 示例 struct C virtual void f 0 ill
  • 在 C++ 和 Windows 中使用 XmlRpc

    我需要在 Windows 平台上使用 C 中的 XmlRpc 尽管我的朋友向我保证 XmlRpc 是一种 广泛可用的标准技术 但可用的库并不多 事实上 我只找到一个库可以在 Windows 上执行此操作 另外一个库声称 您必须做很多工作才能
  • 如何设置 CMake 与 clang 交叉编译 Windows 上的 ARM 嵌入式系统?

    我正在尝试生成 Ninja makefile 以使用 Clang 为 ARM Cortex A5 CPU 交叉编译 C 项目 我为 CMake 创建了一个工具链文件 但似乎存在错误或缺少一些我无法找到的东西 当使用下面的工具链文件调用 CM
  • 启动画面后主窗口出现在其他窗口后面

    我有一个带有启动屏幕的 Windows 窗体应用程序 当我运行该应用程序时 启动屏幕显示正常 消失并加载应用程序的主窗体 但是 当我加载主窗体时 它出现在包含该应用程序的 Windows 资源管理器目录下 这是运行启动画面然后运行主窗体的代

随机推荐

  • accept() 创建一个新套接字是什么意思?

    我的问题基于以下理解 套接字由 ip port 定义 服务器和客户端都有自己的套接字 Socket连接由五组server ip server port client ip client port protocol定义 套接字描述符是标识套接
  • 如何将带有嵌套节点(父/子关系)的 XML 导入 Access?

    我正在尝试将 XML 文件导入 Access 但它创建了 3 个不相关的表 也就是说 子记录被导入到子表中 但无法知道哪些子记录属于哪个父记录 如何导入数据来维护父子节点 记录 之间的关系 以下是 XML 数据的示例
  • 将目录从 Assets 复制到本地目录

    我正在尝试使用资产文件夹中的目录并将其作为File 是否可以访问 Assets 目录中的某些内容File 如果没有 如何将 Assets 文件夹中的目录复制到应用程序的本地目录 我会像这样复制一个文件 try InputStream str
  • Tkinter 嵌套主循环

    我正在写一个视频播放器tkinter python 所以基本上我有一个可以播放视频的 GUI 现在 我想实现一个停止按钮 这意味着我将有一个mainloop 对于 GUI 还有另一个嵌套mainloop 播放 停止视频并返回 GUI 启动窗
  • JyNI Eclipse 设置

    我在 Eclipse 中有以下 Java 文件 package java python tutorial import org python core PyInstance import org python util PythonInte
  • 仅使用 NumPy einsum 处理上三角元素

    我使用 numpy einsum 来计算形状为 3 N 的列向量 pts 数组与其自身的点积 从而得到形状为 N N 的矩阵 dotps 与所有点积 这是我使用的代码 dotps np einsum ij ik gt jk pts pts
  • 为什么 Ruby 解析文件时常量不像局部变量那样被初始化?

    在 Ruby 中 我知道我可以做这样的事情 if false var Hello end puts var 应用程序不会崩溃 并且var只需设置为nil 我读到 这种情况的发生是由于 Ruby 解析器的工作方式造成的 为什么同样的方法不适用
  • 在 MVC 5 中,如何在单个 Ajax POST 请求中发送 ViewModel 和文件?

    我有一个 ASP NET MVC 5 应用程序 我正在尝试发送带有模型数据的 POST 请求 并且还包括用户选择的文件 这是我的 ViewModel 为了清晰起见进行了简化 public class Model public string
  • 给GAC,还是不给GAC?

    我有一个用 ASP NET 3 5 编写的数据访问层 DAL 并使用 Microsoft 模式和实践库 以下简称 P P 来完成其数据访问 我安装了 P P 它驻留在我的 GAC 中 因此 从逻辑上讲 我的 DAL 在 GAC 中引用它 因
  • `checkout` = `reset` + `symbolic ref`?

    Suppose a branch是一个现有分支 指向与之前不同的提交HEAD指着 HEAD可能直接或通过某些方式指向提交branch 以下命令等效吗 git checkout a branch and git symbolic ref HE
  • 分布式张量流中的并行进程

    我有带有训练参数的张量流神经网络 它是代理的 策略 网络正在核心程序的主张量流会话的训练循环中进行更新 在每个训练周期结束时 我需要将该网络传递给几个并行进程 工作人员 这些进程将使用它来从代理策略与环境的交互中收集样本 我需要并行执行 因
  • 没有传输安全性的 WCF 可靠会话不会按时发生故障事件

    我遇到了可靠会话的一个非常有趣的行为 我使用的是netTcp绑定 双工通道 可靠会话 当我尝试侦听channel faulted时 如果安全模式设置为transport 则当客户端断开连接时 故障事件将立即触发 但是 当我将绑定的安全模式设
  • 在实体框架中附加集合

    使用实体框架 我可以使用附加单个对象 entity Attach 但是 我没有看到任何方法允许我将多个对象的集合 数组添加到实体 我必须循环遍历集合中的每个项目并调用entity Attach 每一次 是的 您必须循环遍历子集合并Attac
  • 在 MySQL 中存储 IPv6 地址

    正如 需要支持 ipv6 的 inet aton 和 inet ntoa 函数 http bugs mysql com bug php id 34037 目前没有用于存储 IPv6 地址的 MySQL 函数 用于存储 插入的推荐数据类型 函
  • 如何在 CSS 中用 SVG 图标替换 Web 字体(Font Awesome)?

    我注意到在我的 CSS 文件中 有一些使用 Font Awesome Web 字体的规则 如下所示 ul fancy li before category page ul li before display none font style
  • 删除URL参数而不刷新页面

    我试图删除 之后的所有内容在文档准备好的浏览器 URL 中 这是我正在尝试的 jQuery document ready function var url window location href url url split 0 我可以做到
  • toLocaleLowerCase() 和 toLowerCase() 之间的区别[重复]

    这个问题在这里已经有答案了 我试图fiddle http jsfiddle net xameeramir kr33b0aL with toLocaleLowerCase http www w3schools com jsref jsref
  • 如何退出 Instagram API?

    Instagram API 身份验证页面没有任何有关如何注销用户的信息 在使用 API 的 iOS 应用程序上 我该如何允许用户注销 要注销用户 您只需删除令牌即可 如果用户不希望您的应用访问他们的数据 他们将取消您的应用访问权限 如果您想
  • 编写无 BOM 的 UTF-8

    这段代码 OutputStream out new FileOutputStream new File C file test txt out write A getBytes 和这个 OutputStream out new FileOu
  • 如何在 CUDA 中执行多个矩阵乘法?

    我有一个方阵数组int M 10 以便M i 定位第一个元素i th 矩阵 我想将所有矩阵相乘M i 通过另一个矩阵N 这样我就收到了方阵数组int P 10 作为输出 我看到有不同的可能性 分配不同元素的计算M i 到不同的线程 例如 我