CUDA系列三:矩阵相乘

2023-11-19

本博文主要讲解下基于cuda的矩阵相乘,cuda特别擅长的就是矩阵乘法,而且也比较容易实现。通过矩阵乘法的实现,可以比较容易理解cuda的核心思想。网上也有很多基于cuda实现的矩阵乘法,但是感觉都不完成,要不就是有错,本文给出的代码都是经过验证可行的,希望能够帮助到大家。

矩阵乘法实现方式一:矩阵乘法的逐点实现方式,具体如下图所示

                           

对应实现代码:

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>


__global__ void MatMul(int *M,int *N,int *P,int width)
{
	int x = threadIdx.x;
	int y = threadIdx.y;
	
	float Pervalue = 0;
	
	float elem1 = 0.0,elem2 = 0.0,value = 0.0;
	for(int i = 0;i < width;i++)
	{
		elem1 = M[y * width + i];//取M矩阵的一行
		elem2 = N[i * width + x];//取N矩阵的一列
		
		value += elem1 * elem2;//求和
	}
	
	P[y * width + x] = value;
}

int main()
{
	const int ND = 30;
	int a[ND][ND],b[ND][ND],c[ND][ND];
	int *M,*N,*P;
	
	int width = ND;
	int NUM = 900;
	dim3 blockSize(ND,ND);
	
	cudaEvent_t start,stop;
	float elapsedTime = 0;
	cudaEventCreate(&start);
	cudaEventCreate(&stop);
	
	//设备端内存分配
	cudaMalloc((void**)&M,ND * ND * sizeof(int));
	cudaMalloc((void**)&N,ND * ND * sizeof(int));
	cudaMalloc((void**)&P,ND * ND * sizeof(int));
	
	//初始化
	for(int i = 0;i < ND;i++)
	{
		for(int j = 0;j < ND;j++)
		{
			a[i][j] = 2;
			b[i][j] = 3;
		}
	}
	
	int Size = ND * ND;
	//数据拷贝,主机到设备
	cudaMemcpy(M,a,Size * sizeof(int),cudaMemcpyHostToDevice);
	cudaMemcpy(N,b,Size * sizeof(int),cudaMemcpyHostToDevice);
	
	cudaEventRecord(start,0);
	MatMul<<<1,blockSize>>>(M,N,P,width);//调用核函数
	cudaThreadSynchronize();
	cudaEventRecord(stop,0);
	cudaEventSynchronize(stop);
	cudaEventElapsedTime(&elapsedTime,start,stop);
	
	cudaMemcpy(c,P,Size * sizeof(int),cudaMemcpyDeviceToHost);
	
	printf("c0 = %d \n",c[0][0]);
	
	//释放设备内存
	cudaFree(M);
	cudaFree(N);
	cudaFree(P);
	
	return 0;
}

运行结果:

 

矩阵相乘实现方式二:矩阵乘法分块实现,具体如下图所示

                 

具体代码实现:

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>


#define TILE_WIDTH 10

//核函数的具体实现
__global__ void matmul(int *M,int *N,int *P,int width)
{
	__shared__ float Mds[TILE_WIDTH][TILE_WIDTH];
	__shared__ float Nds[TILE_WIDTH][TILE_WIDTH];
	
	int bx = blockIdx.x;
	int by = blockIdx.y;
	int tx = threadIdx.x;
	int ty = threadIdx.y;
	
	int Col = bx * TILE_WIDTH + tx;
	int Row = by * TILE_WIDTH + ty;
	
	int Pervalue = 0;
	
	for(int i = 0;i < width / TILE_WIDTH;i++)  //有多少个TILE_WIDTH,每个循环计算一个块的大小
	{
		Mds[ty][tx] = M[Row * width + (i * TILE_WIDTH + tx)];
		Nds[ty][tx] = N[Col + (i * TILE_WIDTH + ty) * width];
		__syncthreads();
		
		
		for(int k = 0;k < TILE_WIDTH;k++) //TILE_WIDTH相乘
			Pervalue += Mds[ty][k] * Nds[k][tx];
		__syncthreads();
	}
	
	P[Row * width + Col] = Pervalue;
}


int main()
{
	const int Nd = 30;
	int Size = Nd * Nd;
	int *M,*N,*P;
	int width = Nd / 3;
	
	int a[Nd][Nd];
	int b[Nd][Nd];
	int c[Nd][Nd];
	
	//线程块以及线程的划分
	dim3 gridSize(Nd / width,Nd / width);
	dim3 blockSize(width,width);
	
	cudaEvent_t start,stop;
	float elapsedTime;
	cudaEventCreate(&start);
	cudaEventCreate(&stop);
	
	//设备内存分配
	cudaMalloc((void**)&M,Size * sizeof(int));
	cudaMalloc((void**)&N,Size * sizeof(int));
	cudaMalloc((void**)&P,Size * sizeof(int));
	
	//初始化
	for(int i = 0;i < Nd;i++)
	{
		for(int j = 0;j < Nd;j++)
		{
			a[i][j] = 2;
			b[i][j] = 3;
		}
	}
	
	//数据拷贝,主机到设备
	cudaMemcpy(M,a,Size * sizeof(int),cudaMemcpyHostToDevice);
	cudaMemcpy(N,b,Size * sizeof(int),cudaMemcpyHostToDevice);
	
	cudaEventRecord(start,0);
	matmul<<<gridSize,blockSize>>>(M,N,P,Nd); //调用核函数
	cudaThreadSynchronize();
	cudaEventRecord(stop,0);
	cudaEventSynchronize(stop);
	cudaEventElapsedTime(&elapsedTime,start,stop);
	
	
	cudaMemcpy(c,P,Size * sizeof(int),cudaMemcpyDeviceToHost);
	printf("c0 = %d\n",c[0][0]);
	
	
	cudaFree(M);
	cudaFree(N);
	cudaFree(P);
	
	return 0;
}

运行结果:

本文也参考了网上的一些资料,主要是做了一定的修改以及程序的完备,图片就直接网上copy的,水平有限,有不当之处,请指教,谢谢!

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CUDA系列三:矩阵相乘 的相关文章

随机推荐

  • 微软服务器的主要功能,数据库服务器主要功能

    数据库服务器主要功能 内容精选 换一换 HANA全称High performanceAnalyticAppliance是由SAP开发的基于内存的面向行 列存储的关系型数据库管理系统 其作为数据库服务器的主要功能是根据应用程序的要求存储和检索
  • jdk17下载

    官网下载 https download oracle com java 17 latest jdk 17 windows x64 bin zip
  • 也想做一个绝地求生版的汽车控制移动,进来瞧瞧?(干货满满)

    控制车子移动 效果图附上 1 首先4个车轮复制一遍为车轮2备用 2 给车轮2全部添加wheel collider 只剩下车轮碰撞器和transform组件 3 给原版4个车轮添加脚本wheel 变量共有 面板赋值 依次添加车轮2里面的车轮c
  • c#图解教程和c#高级编程电子书链接

    链接 https pan baidu com s 1y TM08JvyBh8kQ0v7uT5hg 提取码 b0cq
  • Python的多维空数组赋值

    Python里面的list tuple默认都是一维的 创建二维数组或者多维数组也是比较简单 可以这样 list1 1 2 list1 append 3 4 可以这样 list2 1 2 3 4 还可以这样 list3 1 2 list3 i
  • android界面监控,防劫持

    1 首先要对自己应用的activity建立一个白名单 2 权限
  • http协议从客户端提交数据给服务器并返回数据

    老罗视频学习 本例从客户端提交数据给服务器 服务器接收到数据之后 看是否匹配 匹配返回字符串 login is success 失败返回 login is error 一 客户端 初始化url地址 private static String
  • Git如何比较不同分支的差异

    前两天 良许在做集成的时候碰到了一件闹心事 事情是这样的 良许的一位同事不小心把一个错误的 dev 分支 merge 到了 master 分支上 导致了良许编译不通过 于是 我们需要将版本回退到 merge 之前的状态 如果是下面这个状态
  • 电子设计竞赛(三)-SPWM与PID

    1 SPWM波调制技术 逆变电路的控制方式主要是采用SPWM 正弦脉宽调制技术 IR2104控制开关管的通断来实现正弦调制 SPWM的基本思路是将一个正弦波按等宽间距分成N等份 对于每一个波形以一个等面积的脉冲来对应 使脉冲的中点与相应正弦
  • python3 hashlib库sha256、pbkdf2_hmac、blake2b基本用法

    hashlib sha256 import hashlib x hashlib sha256 x update b asd print x 1 x hexdigest x hashlib sha256 x update asd encode
  • 数据下载网站整理

    数据十分重要 如何找到理想的数据显得更重要了 这里记录自己经过网上查询到的数据 进行整理 如果侵权 请联系我删除 再次感谢网友大佬们提供的资料 1 中国气象站点数据 下载地址 https www resdc cn data aspx DAT
  • 递归算法中的时间复杂度分析

    对于一种算法的时间复杂度分析还是特别重要的 在一些非递归算法中 我们仅仅看运算次数最多的那一行代码可能执行多少次就可以 实际就是看在循环中变量的变化 但是对于递归算法中该怎么分析呢 下面介绍几种递归函数中的算法时间复杂度分析的方法 0 递推
  • 使用paramiko跨服务器传输文件/文件夹

    一些概念 SSH Secure Shell 安全外壳协议 是建立在应用层基础上的安全协议 专为远程登录和其他网络服务提供安全性的协议 SFTP SSH 文件传输协议 Secret File Transfer Protocol SFTP 安全
  • window.location.href的用法

    window location href的用法 一 前言 二 常见用例 一 前言 window location href 是一个用于获取当前页面 URL 或让浏览器跳转到新 URL 的重要方法 是 window location 对象的属
  • 【gis系列】等高线创建dem,以及高程分析,坡度分析,坡向分析

    绝对原创 首先 我们要整理一份cad的文件格式 这里我不说那么多 就是在某某地图下载后 方法很多 可以通过qgis globalmapper来操作数据 以及一些普通的地图软件直接生成 这里呢 然后进入cad 把里面的高程标注信息给删除掉 图
  • 机器学习资源大全

    C 计算机视觉 CCV 基于C语言 提供缓存 核心的机器视觉库 新颖的机器视觉库 OpenCV 它提供C C Python Java 以及 MATLAB接口 并支持Windows Linux Android and Mac OS操作系统 通
  • SD卡初始化以及命令详解

    SD卡是嵌入式设备中很常用的一种存储设备 体积小 容量大 通讯简单 电路简单所以受到很多设备厂商的欢迎 主要用来记录设备运行过程中的各种信息 以及程序的各种配置信息 很是方便 有这样几点是需要知道的 SD 卡是基于 flash 的存储卡 S
  • Visual Studio 创建DLL 、LIB及调用

    一 前言 在工程中 经常会根据不同的场景需求将类封装成库文件 以供他人使用 那么如何利用VS进行库 动态库 的生成呢 以下简要演示实现过程 开发环境 VS2019 二 生成DLL动态库 1 创建控制台工程 添加类库函数 2 添加函数代码 d
  • vue打包及运行白屏,Android低版本适配

    版本支持 对于Android 4 X无法打开的问题 具体表现 1 运行后低版本谷歌浏览器打开后白屏 2 打包后低版本Android系统打不开 白屏 打包前npm run build后低版本浏览器打开白屏 如果低版本打开白屏那么打包后低版本A
  • CUDA系列三:矩阵相乘

    本博文主要讲解下基于cuda的矩阵相乘 cuda特别擅长的就是矩阵乘法 而且也比较容易实现 通过矩阵乘法的实现 可以比较容易理解cuda的核心思想 网上也有很多基于cuda实现的矩阵乘法 但是感觉都不完成 要不就是有错 本文给出的代码都是经