CUDA矩阵乘法优化

2023-11-01

前言

纸上的来终觉浅,绝知此事要躬行。

naive写法

一个矩阵的乘法简单如下:C=A*B, 一般用gemm(A,B,C,M,N,K)来表示,其中的m,n,k代表的位置如下,默认是k表示消失的纬度。
在这里插入图片描述
上图的红色虚线围起来的是一个block要负责的数据区域,具体的代码如下:

__global__ 
void matrixMul(const float *A, const float *B, float *C, int M, int N, int K) 
{
	int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if(row < M && col < N) 
    {
        float c = 0;
        for(int i = 0; i < K; ++i)
        {
            c += A[row * K + i] * B[i * N + col];
        }
        C[row * N + col] = c;
    }
}

第一步先写一个简单的矩阵乘法,不要觉得简单就不写,后面不管怎么耍花招,都是基于当下的核心逻辑和计算流程的。

share memory优化

代码我们可以看到,每一个线程要k次乘法和K次加法,每做计算一次 FMA(乘累加)之前需要读一次 A 和读一次 B,读取 Global Memory 的代价很大,通常都需要几百个 cycle(时钟周期),而计算一次 FMA 通常只需要几个 cycle,大量的时间被花费在了访存上, 那么如何减少访问global memory呢?我们知道share_memory是片上内存,访问速度比较快,我们考虑把A和B中的虚线内的数据放在share_memory上,然后计算的时候从share_memeory上取,这样的话其实会多一次数据从global转移到share上的操作,但是每次做乘加计算取数的时候,省的时间会远远多于这一次操作。
在这里插入图片描述
上面我们假设是把A、B中的虚线部分数据导入到share中,但其实share的大小有限,所以还是继续分片,原理如下:
在这里插入图片描述
具体的伪代码是:

//分为j块
for(int i = 0; i < j; i+1)
{
        load_gmem_to_smem(A, i, smemA);
        load_gmem_to_smem(B, i, smemB);
        __syncthreads();
        //compute
        C_i=gemm(a_i,b_i);
        //累加
        C += C_i;

 
}

可以运行的代码:

template <int BLOCK_SIZE>
__global__
void MatrixMulCUDA(
    const float * __restrict__ A,
    const float * __restrict__ B,
    float * __restrict__ C,
    const int M,
    const int K,
    const int N) {
    // Block index
    int bx = blockIdx.x;
    int by = blockIdx.y;
    // Thread index
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    //对应上图a2的左上角位置
    int aBegin = K * BLOCK_SIZE * by;
    //对应上图aj的右上角位置(注意上图有溢出到A矩阵外,这里代码默认是刚刚在A矩阵里的)
    int aEnd   = aBegin + K - 1;
	//每一个a矩阵的左上角位置距离差
    int aStep  = BLOCK_SIZE;

    // 对应上图b2的左上角位置
    int bBegin = BLOCK_SIZE * bx;
    // 每一个b矩阵的左上角位置距离差
    int bStep  = BLOCK_SIZE * N;

    float Csub = 0;
    // 存储aj的share memory
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    // 存储bj的share memory
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
     //对应上图就是循环3次
    for (int a = aBegin, b = bBegin; a < aEnd; a += aStep, b += bStep) {
        //每个线程下载一个元素到share memory
        As[ty][tx] = A[a + K * ty + tx];
        Bs[ty][tx] = B[b + N * ty + tx];

        // Synchronize 为了确保share memory需要的数据都下载到位
        __syncthreads();
		// 每个线程计算cj中的一个位置的结果
        #pragma unroll
        for (int k = 0; k < BLOCK_SIZE; ++k) {
            Csub = fma(As[ty][k], Bs[k][tx], Csub);
        }
        // 确保cj中的所有位置的结果都算好了
        __syncthreads();
     }
    //最后把每个线程负责的算的结果从寄存器导入到device memory
    C[N * ( BLOCK_SIZE * by + ty ) + BLOCK_SIZE * bx + tx ] = Csub;
}

实验结果如下,当M=160,K=160, N=160的时候,实现效果显示还不错,比naive好但是比cublas还是有差距。
在这里插入图片描述

第二次优化

第一次优化后,访存代价从几百 cycle 降低到几十 cycle,并不改变问题的本质。问题的关键在于主体循环由两条 Load 指令与一条 FMA 指令构成,计算指令只占总体的 1/3,计算访存比过低,最终导致了访存延迟不能被隐藏,从而性能不理想.

这里解释一下,线程的指令会发给调度器,调度器分配对应的执行单元。这里比如有20个线程,机器上一个调度器,一个计算单元和访存单元,此刻线程1告诉调度器要执行加法计算,计算单元需要10s可以得到结果, 等待期间调度器就会问问其他线程有需要访存的吗?毕竟有一个访存单元在闲着,这时候线程8说他需要访存操作,不过一次访存需要200s, 因为计算单元速度很快,所以当20个线程的计算任务都完成时,只有一个线程的访存任务完成,所以后面还要200s*19这么长的时间。这里可以发现一共需要200S*20的时间,我们计算时间10s*20被完成隐藏了用户感知不到,这就是所谓的隐藏延迟,知道了原理,我们的任务说白了就是别让计算单元闲着,如果一个线程的计算时间如果和访存时间一模一样,那么完全就可以隐藏计算或者是访存的时间了,这不是美滋滋?但是但是,这里的前提是只有一个访存单元和计算单元,实际上底层硬件还是差距很大的,不同架构和型号的显卡也不一样,这也是为啥同样的的代码在不同机器上的性能不一样,此外调度器的规则,计算与访存任务的依赖等等都有可能导致性能差异。

float c[4][4] = {{0}};
float a_reg[4];
float b_reg[4];
for(int i = 0; i < K; i += tile_with)
{
        load_gmem_tile_to_smem(A, i, smemA);
        load_gmem_tile_to_smem(B, i, smemB);
        __syncthreads();
		//compute
        for(int j = 0; j < tile_with; ++j) 
        {
            // load tile from shared mem to register 
            load_smem_tile_to_reg(smemA, j, a_reg);
            load_smem_tile_to_reg(smemB, j, b_reg);
            // compute matrix multiply accumulate 4x4
            mma4x4(a_reg, b_reg, c)}
        // 累加
        C += C_i;
}

未完待续。。。
参考:https://zhuanlan.zhihu.com/p/410278370

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CUDA矩阵乘法优化 的相关文章

随机推荐

  • std::set 的删元素 c++

    程序演示std set的删除操作 set类模板中 它和vector list不同 set map都是关联式容易 set 内部是基于红黑树实现的 插入和删除操作效率比较高 下面测试一下怎么删除set的里面的元素 include
  • 台资企业管理职的中英文称谓以及级别

    台资企业管理职的中英文称谓以及级别 从低到高 组长 team leader 课长 supervisor 专理 assistant manager 也就是经理助理 经理 manager 资深经理 senior manager 即我们说的高级经
  • UVA-215 电子表格计算器 题解答案代码 算法竞赛入门经典第二版

    GitHub jzplp aoapc UVA Answer 算法竞赛入门经典 例题和习题答案 刘汝佳 第二版 题目并不难 数据量也不大 一次数据最多是20 10是200个 因此即使最长的嵌套引用关系 也只有200层 我们使用暴力 循环200
  • 【ubuntu】ubuntu添加或删除用户

    文章目录 1 创建新用户 2 为新用户填加超级用户权限 方法一 填加新用户到sudo group 方法二 在 etc sudoers中指定用户的权限 3 删除用户 创建新用户的意义不再多述 最直观的就是多个人用同一台机器 要为每个人创建一个
  • tomcat自动加载改变的class文件,且无需重启

    不重启Tomcat有两种方式 热部署 热加载 热部署 容器在运行时重新部署整个项目 这类环境下 一般整个内存会被清空 重新加载 这类方式有可能造成sessin丢失等问题 tomcat 6以上已解决该问题 热加载 最好是在调试过程中使用 以免
  • Caffe源码:math_functions 解析

    目录 目录 主要函数 caffe cpu gemm 函数 caffe cpu gemv 函数 caffe axpy 函数 caffe set 函数 caffe add scalar 函数 caffe copy 函数 caffe scal 函
  • 基于ChatGPT-API实现聊天机器人服务

    1 背景 要基于GPT自己去实现一个聊天机器人服务功能实现上其实特别简单 将上游服务过来的请求转换为GPT接口请求发出去然后直接返回或者回调给上游服务即可 但是其中的一些其他问题不知道大家有没有考虑过 1 搞成一个大同步的实现 当并发真的上
  • 集合方法的代码

    创建一个集合 获取从 某一索引开始到某一索引的前一位结束的代码 class b public static void main String args List
  • GO语言学习-变量2和常量与iota枚举

    变量进阶 1 多重赋值 从左至右依次匹配 如果有不需要的数据用匿名变量处理 2 匿名变量 下划线 丢弃数据不处理 匿名变量主要用于配合函数返回值使用 注 go语言的函数返回值可以有多个 使用匿名变量可以舍去不需要的返回值 package m
  • 下载安装Android Studio教程

    步骤1 下载Android Studio 访问Android Studio官方网站 https developer android com studio 点击 下载Android Studio 按钮 选择适用于您操作系统的版本 然后下载安装
  • Unix Shell 范例精解——awk课后题

    题目数据如下 Mike Harrington 510 548 1278 250 100 175 Christian Dobbins 408 538 2358 155 90 201 Susan Dalsass 206 654 6279 250
  • Three.js 基础- 第 2 章 - 几何体BufferGeometry

    Three js 基础 第 2 章 几何体BufferGeometry Three js教程 几何体BufferGeometry 在Three js中 几何体是3D对象的基本形状 本教程将介绍如何使用缓冲类型几何体BufferGeometr
  • 【独家源码】ssm高校试卷管理系统i0lzr应对计算机毕业设计困难的解决方案

    本项目包含程序 源码 数据库 LW 调试部署环境 文末可获取一份本项目的java源码和数据库参考 系统的选题背景和意义 选题背景 高校试卷管理是教学工作中的重要环节 涉及到试卷的编写 存储 分发和评阅等多个方面 然而 传统的试卷管理方式存在
  • [转]Java 线程池的原理与实现

    最近在学习线程池 内存控制等关于提高程序运行性能方面的编程技术 在网上看到有一哥们写得不错 故和大家一起分享 分享 Java 线程池的原理与实现这几天主要是狂看源程序 在弥补了一些以前知识空白的同时 也学会了不少新的知识 比如 NIO 或者
  • SoftwareSerial库的使用——Arduino软件模拟串口通信

    除HardwareSerial外 Arduino还提供了SoftwareSerial类库 它可以将你的其他数字引脚通过程序模拟成串口通信引脚 通常我们将Arduino UNO上自带的串口称为硬件串口 而使用SoftwareSerial类库模
  • 如何开启计算机cpu虚拟化,电脑开启虚拟化设置的方法 如何开启虚拟化设置

    虚拟化设置的开启其实很简单 因为大家没有接触和操作过 所以一开始会不知所措 虚拟化设置的开启其实很简单 因为大家没有接触和操作过 所以一开始会不知所措 小编在这里为广大玩家深度总计虚拟化开启方法 方便大家在电脑端更流畅的体验手机游戏 虚拟化
  • maven工程依赖的jar包,在本地仓库有,但是pom.xml文件却报错找不到jar包

    例如 Missing artifact com ibm db2 db2jcc license cisuz jar 10 1 但在我本地的仓库中却存在这个jar包 查找了很多的资料发现了两种解决方法 第一种 在eclipse中的window
  • 透彻了解inlining的里里外外——条款30

    Inline函数 多棒的点子 它们看起来像函数 动作像函数 比宏好得多 见条款2 可以调用它们又不需要蒙受函数调用所招致的额外开销 你还能要求更多吗 你实际获得的比想到的还多 因为 免除函数调用成本 只是故事的一部分而已 编译器最优化机制通
  • 2021美赛F题

    2021年 问题E 重新优化食物系统 最近的事件向我们表明 我们的全球粮食系统即使在世界的某些地区也是不稳定的 它通常服务于全世界 这些不稳定的部分原因是我们目前的全球气候变化 庞大的国内和国际食品生产商和经销商体系 这个食物系统 使食物的
  • CUDA矩阵乘法优化

    前言 纸上的来终觉浅 绝知此事要躬行 naive写法 一个矩阵的乘法简单如下 C A B 一般用gemm A B C M N K 来表示 其中的m n k代表的位置如下 默认是k表示消失的纬度 上图的红色虚线围起来的是一个block要负责的