cutlass入门: 调用cutlass做通用矩阵乘法Gemm(附代码)

2023-11-11

cutlass是CUDA C++模板抽象的集合,用于实现CUDA中所有级别和规模的高性能矩阵乘法(GEMM)和相关计算。相较于cuBLAS和cuDNN,cutlass中包含了更多可重用的模块化软件组件,这使得cutlass相较于前两者更为灵活。

cutlass项目官方网站:GitHub - NVIDIA/cutlass: CUDA Templates for Linear Algebra Subroutines

本文将展示如何用cutlass实现最基本的矩阵计算。

cutlass的使用流程与普通kernel大致相同:先在host端分配空间生成数据,再将host端的数据传入device端的buffer中,输入参数调用cutlass模块进行运算,最后将device端的数据传回Host端。

本代码实现的矩阵计算为D = α * A * B + β * C。其中,标量α = β = 1,矩阵A的大小为3840*4096,矩阵B的大小为4096*4096,矩阵C的大小为3840*4096。这三个矩阵均被初始化为全部元素为1的矩阵。那么经过简单计算我们可以知道矩阵D所包含的全部元素应该为4097。

代码及注释如下:

#include <iostream>                                           // 标准输入输出流
#include "cutlass/gemm/device/gemm.h"                         // 引入cutlass头文件

using ColumnMajor = cutlass::layout::ColumnMajor;             // 列主序存储方式
using RowMajor    = cutlass::layout::RowMajor;                // 行主序存储方式

using CutlassGemm = cutlass::gemm::device::Gemm<float,        // A矩阵数据类型
                                                RowMajor,     // A矩阵存储方式
                                                float,        // B矩阵数据类型
                                                RowMajor,     // B矩阵存储方式
                                                float,        // C矩阵数据类型
                                                RowMajor>;    // C矩阵存储方式

void generate_tensor_2D(float *ptr, int i_M, int i_N){        // 二维矩阵填充函数(此处全部填充1)
    for(int i = 0; i < i_M; i++){
        for(int j = 0; j < i_N; j++){
            *(ptr + i*i_N + j ) = 1.0;
        }
    }
}

int main(int argc, const char *arg[]) {

    int M = 3840;           //M
    int N = 4096;           //N
    int K = 4096;           //K

    int lda = K;
    int ldb = K;
    int ldc = N;
    int ldd = N;

    float alpha = 1.0;      //alpha
    float beta = 1.0;       //beta

    float *A;               //申明A矩阵host端指针
    float *B;               //申明B矩阵host端指针
    float *C;               //申明C矩阵host端指针
    float *D;               //申明D矩阵host端指针

    size_t A_mem_size = sizeof(float) * M * K; //memory size of matrix A = M * K * sizeof(float)
    size_t B_mem_size = sizeof(float) * K * N; //memory size of matrix B = K * N * sizeof(float)
    size_t C_mem_size = sizeof(float) * M * N; //memory size of matrix C = M * N * sizeof(float)
    size_t D_mem_size = sizeof(float) * M * N; //memory size of matrix C = M * N * sizeof(float)

    A = (float*)malloc(A_mem_size);  // host端A矩阵分配内存
    B = (float*)malloc(B_mem_size);  // host端B矩阵分配内存
    C = (float*)malloc(C_mem_size);  // host端C矩阵分配内存
    D = (float*)malloc(D_mem_size);  // host端D矩阵分配内存

    generate_tensor_2D(A, M, K);     // 填充A矩阵
    generate_tensor_2D(B, K, N);     // 填充B矩阵  
    generate_tensor_2D(C, M, N);     // 填充C矩阵  

    float *d_A;            // 申明device端A矩阵的指针
    float *d_B;            // 申明device端B矩阵的指针
    float *d_C;            // 申明device端C矩阵的指针
    float *d_D;            // 申明device端D矩阵的指针

    cudaMalloc((void**)&d_A, A_mem_size);  // device端为A矩阵分配内存
    cudaMalloc((void**)&d_B, B_mem_size);  // device端为B矩阵分配内存
    cudaMalloc((void**)&d_C, C_mem_size);  // device端为C矩阵分配内存
    cudaMalloc((void**)&d_D, D_mem_size);  // device端为D矩阵分配内存

    cudaMemcpy(d_A, A, A_mem_size, cudaMemcpyHostToDevice); // 将矩阵A的数据传递到device端
    cudaMemcpy(d_B, B, B_mem_size, cudaMemcpyHostToDevice); // 将矩阵B的数据传递到device端
    cudaMemcpy(d_C, C, C_mem_size, cudaMemcpyHostToDevice); // 将矩阵C的数据传递到device端

    CutlassGemm gemm_operator;                  // 申明cutlassgemm类
    CutlassGemm::Arguments args({M, N, K},      // Gemm Problem dimensions
                                {d_A, lda},     // source matrix A
                                {d_B, ldb},     // source matrix B
                                {d_C, ldc},     // source matrix C
                                {d_D, ldd},     // destination matrix D
                                {alpha, beta}); // alpha & beta
    gemm_operator(args); //运行Gemm

    cudaMemcpy(D, d_D, D_mem_size, cudaMemcpyDeviceToHost);  //将运行结果D矩阵传回host端
    std::cout << D[0] << std::endl;                          //打印D中第一行第一个数据
    std::cout << D[M * N - 1] << std::endl;                  //打印D中最后一行最后一个数据

    return 0;
}   

代码编译:

$nvcc -I <PATH TO CUTLASS>/include <YOUR SOURCE FILE>

运行结果如下:

 可见,本程序运行成功利用cutlass实现了D = α * A * B + β * C的矩阵计算。

*本文章及程序仅供交流学习,请勿用作商业用途。转载请告知作者。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

cutlass入门: 调用cutlass做通用矩阵乘法Gemm(附代码) 的相关文章

  • Postsharp 不登录跟踪级别

    我喜欢在跟踪级别记录一些 Postsharp 消息 不幸的是 日志到这个级别没有打印任何输出 所有其他级别都在工作 与控制台或 NLog 后端或从其他类登录时的行为相同 如何启用跟踪级别 应用程序 xaml cs Log Attribute
  • 将 ARGB 拆分为字节值

    我有一个 ARGB 值存储为 int 类型 它是通过调用 ToArgb 来存储的 我现在想要来自 int 值的各个颜色通道的字节值 例如 int mycolor 16744448 byte r g b a GetBytesFromColor
  • 如何转发声明要在 unique_ptr 的标准容器中使用的类

    在智能指针的标准容器中使用它时 是否可以避免完整的类定义可见 例如 我无法编译以下内容 include
  • 命名管道客户端无法连接到作为网络服务运行的服务器

    我有一个服务在网络服务帐户下运行 该服务只是设置一个命名管道并侦听连接 NamedPipeServerStream listeningPipe new NamedPipeServerStream ourservicepipe PipeDir
  • 表达式访问者仅为某些 lambda 表达式调用 VisitParameter

    我希望能够使用嵌套扩展方法将 EF 中的实体投影到相应的视图模型 参见我之前的问题使用扩展方法在 EF 中投影单个实体 https stackoverflow com questions 39585427 projection of sin
  • 更改图像颜色与透明背景

    我需要使用 c System Drawings 将透明背景上带有绿色圆圈的图像加载到位图图像中 这是最简单的部分 但是 我需要在将其添加到更大的图像之前更改圆圈的颜色 而不影响周围的透明度 就我而言 我需要将圆圈颜色更改为黄色并将其添加为太
  • C++:字符串流有什么好处?

    谁能告诉我一些在 C 中使用字符串流的实际例子 即使用流插入和流提取运算符输入和输出到字符串流 您可以使用字符串流来转换任何实现operator lt lt 到一个字符串 include
  • 提取单花括号内的值

    我想要一个收藏 value 一个字符串使用正则表达式 例如 lorem ipsum field1 lorem ipsum field2 lorem ipsum field1 lorem ipsum field2 field3 我会得到 fi
  • 带有嵌入 Flash 视频的 PDF 示例?

    有谁知道我在哪里可以查看嵌入 Flash 视频的 PDF 示例 我知道问这个问题很愚蠢 因为你会认为任何面向技术的用户都应该能够使用谷歌找到一个 但我真的找不到 我的另一个问题是 使用 C 中的 API 将 Flash 视频嵌入 PDF 文
  • 将视频上传/保存到数据库或文件系统

    我以前从未尝试过保存视频 所以我对此了解不多 我知道如果视频很小 我可以转换为字节数组并保存到数据库 但是为了提高效率 我想了解如何将任何上传的视频保存到我的服务器文件中 然后只保存该文件的文件路径我的数据库表中的视频 我完全不知道如何开始
  • 用 OpenCL C 编写快速线性系统求解器

    我正在编写一个 OpenCL 内核 它将涉及求解线性系统 目前我的内核太慢了 提高线性系统部分的性能似乎是一个不错的起点 我还应该注意 我并没有尝试使我的线性求解器并行 我正在研究的问题在宏观层面上已经是令人尴尬的并行 以下是我编写的 C
  • 可以通过模板间接访问基类中的私有类型

    我试图在编译时根据类型是否在给定范围内公开可用来选择要使用的类型 最好直接看代码 include
  • _MM_TRANSPOSE4_PS 在 GCC 中导致编译器错误?

    我第一次在 GCC 而不是 MSVC 中编译我的数学库 并经历了所有的小错误 我遇到了一个根本没有意义的错误 Line 284 error lvalue required as left operand of assignment 284号
  • 如何解析多态 JSON 数组?

    我有一个 JSON 格式的文件 其中包含个人用户的记录 一些用户的记录中间有一个评论字段 我只想解析顶级项目 全名 贡献者姓名 电子邮件 使用 Newtonsoft JSON 解析器 但我似乎无法让它识别单个对象 当我将整个字符串解析为一个
  • C中使用JNI从对象获取对象

    public class Student private People people private Result result private int amount 这是 Java 中类的示例 在C中 我试图获取 学生 中的 人 但失败了
  • 如何在realm-dotnet中存储System.Collections.Generic.Dictionary

    我正在尝试将 Realm NET 集成到我的 uwp 项目中 我想知道是否有任何方法可以在 Realm dotnet 库中存储 System Collections Generic Dictionary 我试过这个 public class
  • NSubstitute - 测试特定的 linq 表达式

    我在当前正在开发的 MVC 3 应用程序中使用存储库模式 我的存储库界面如下所示 public interface IRepository
  • Membership.ValidateUser() 的目的是什么

    我一直在学习有关MembershipProvider类 我认为Membership ValidateUser 方法应该用于登录用户 然而我刚刚了解到有一个FormsAuthentication Authenticate 目的是什么Valid
  • 推断“x => { throw .. }”的 Lambda 与重载方法中的 Func 匹配吗?

    我不明白为什么 C 最终在以下 LINQPad 代码中执行不正确的扩展方法 void Main Actual Sync Action Expected Sync Action Run x gt x Dump Actual Async Tas
  • 从其对象获取结构体字段的名称和类型

    例如 我有一个类似这样的结构 struct Test int i float f char ch 10 我有一个该结构的对象 例如 Test obj 现在 我想以编程方式获取字段名称和类型obj 是否可以 顺便说一句 这是 C 你正在要求C

随机推荐

  • 去除自定义AlertDialog黑边

    http blog csdn net mwj 88 article details 45482421 1 现象描述 html view plain copy View view LayoutInflater from getActivity
  • java学习笔记——day1

    java笔记 字面量 变量 数据类型 命名规则 类型转换 运算符operator API 程序的流程控制 数组 字面量 变量 字面量 计算机用来处理数据的 字面量就是告诉程序员 数据在程序中的书写格式 字符 单引号 一个字符 字符串 双引号
  • python+selenium自动化软件测试(第3章):unittes

    3 1 unittest简介 前言 python基础比较弱的 建议大家多花点时间把基础语法学好 这里有套视频 可以照着练习下 http pan baidu com s 1i44jZdb 密码 92fs 熟悉java的应该都清楚常见的单元测试
  • 分层测试(一):什么是分层测试?

    什么是分层测试 分层测试是通过对质量问题分类 分层来保证整体系统质量的测试体系 模块内通过接口测试保证模块质量 多模块之间通过集成测试保证通信路径和模块间交互质量 整体系统通过端到端用例对核心业务场景进行验证 用户体验通过手工测试确保无妨碍
  • Unity开发(2)建片草地

    文章目录 1 简述 2 创建 2 1 创建项目 2 2 进入开发窗体 3 建个地面 3 1 新建地面 3 2 调整地面大小 3 3 添加草地 3 3 1 初识Unity图片资源 3 3 2 添加图片资源 3 3 3 修改图片在场景中大小 1
  • C语言入门知识1(零基础新手适用)

    C语言入门知识1 零基础新手适用 程序语言 1 机器语言 机器语言是低级语言 是用01码来编写的二进制代码语言 2 汇编语言 汇编语言也是低级语言 是用英文字母和符号串编写的 3 高级语言 由于汇编语言依赖于硬件体系且符合较多 为了方便高级
  • Go中 defer的使用

    文章目录 简介 示例 使用场景 捕获异常 文件操作 简介 defer 是 Golang 中的一个非常有用的关键字 它用于注册延迟调用 也就是一个函数的执行被延迟到调用它的函数返回之后 常用于资源清理 异常处理等场景 示例 defer 是注册
  • python实现电子邮件编程

    一 几个专业名词 MUA MTA MDA 假设我们自己的电子邮件地址是me 163 com 对方的电子邮件地址是friend sina com 注意地址都是虚构的哈 现在我们用Outlook或者Foxmail之类的软件写好邮件 填上对方的E
  • C++提高8: 类模板成员函数类外实现和类模板分文件编写

    1 类模板成员函数类外实现 类外实现主要有三个关键点 作用域 识别T的数据类型 告诉编译器这是一个类模板 剩下的 就还是基础的类内声明类外定义实现了 直接上代码观察一下 include
  • redis后台实现投票功能

    原创文章 转载请注明出处https blog csdn net qq 41969845 article details 108406059 一 前言 本文以投票功能为例 从实际例子中熟练掌握redis的应用 阅读本文需要有一定的Java基础
  • SparkStreaming与Kafka010之05之01 Consumer

    package Kafka010 import Kafka010 Utils MyKafkaUtils import org apache kafka clients consumer ConsumerRecord import org a
  • 常用网络数据帧格式

    常用网络数据帧格式 1 ARP帧格式 2 ICMP帧格式 3 UDP帧格式 4 TCP帧格式 本文主要介绍ARP ICMP UDP TCP等常用网络数据帧格式 1 ARP帧格式 当一个应用层的数据在网络中传输时 会被逐步封装成链路层的帧 而
  • ffplay源码解析-main入口函数

    main入口函数 初始化 变量 缓存区 SDL窗口初始化等 int main int argc char argv int flags VideoState is av log set level AV LOG TRACE init dyn
  • L1-086 斯德哥尔摩火车上的题(15分) Python

    上图是新浪微博上的一则趣闻 是瑞典斯德哥尔摩火车上的一道题 看上去是段伪代码 s a 1112031584 for i 1 i lt length a i if a i 2 a i 1 2 s max a i a i 1 goto url
  • 2020-11-24-ElasticSearch7.x学习笔记

    笔记记录 B站狂神说Java的ElasticSearch课程 https www bilibili com video BV17a4y1x7zq 在学习ElasticSearch之前 先简单了解一下Lucene Doug Cutting开发
  • 根据PV或者QPS来计算需要多少台机器

    QPS 单个进程每秒请求服务器成功的次数 req sec 总请求数 进程总数 请求时间 一般使用http load进行统计 每台服务器每天的PV QPS x 3600 x 6 或者乘以8小时 一天按照6或者8小时计算 晚上可能没人访问 服务
  • Conda环境 下载Jupyter Lab并使用

    1 下载Jupyter Lab conda 安装方式 conda install jupyterlab conda install c conda forge jupyterlab python 安装方式 pip install jupyt
  • python waitress_python 角度理解web服务器

    概述 web服务器实际上就是一个运行在物理机上的网络服务器 它等待客户端给他发送请求 成功接收后将客户端请求的资源响应给它 客户端与服务端的通信通过http协议实现 客户端可以是浏览器或者可以发送请求的一段程序 一 一个简单的web服务器
  • Android11 热点设置永不关闭

    Android11 热点设置永不关闭 文章目录 Android11 热点设置永不关闭 一 前言 二 framework设置热点永不超时关闭 三 基于 SoftApManager java 研究超时逻辑 三 总结 1 设置热点不关闭的方法 1
  • cutlass入门: 调用cutlass做通用矩阵乘法Gemm(附代码)

    cutlass是CUDA C 模板抽象的集合 用于实现CUDA中所有级别和规模的高性能矩阵乘法 GEMM 和相关计算 相较于cuBLAS和cuDNN cutlass中包含了更多可重用的模块化软件组件 这使得cutlass相较于前两者更为灵活