使用 CUBLAS 库给矩阵运算提速

2023-11-08

前言

  编写 CUDA 程序真心不是个简单的事儿,调试也不方便,很费时。那么有没有一些现成的 CUDA 库来调用呢?

  答案是有的,如 CUBLAS 就是 CUDA 专门用来解决线性代数运算的库。

  本文将大致介绍如何使用 CUBLAS 库,同时演示一个使用 CUBLAS 库进行矩阵乘法的例子。

CUBLAS 内容

  CUBLAS 是 CUDA 专门用来解决线性代数运算的库,它分为三个级别:

  Lev1. 向量相乘

  Lev2. 矩阵乘向量

  Lev3. 矩阵乘矩阵

  同时该库还包含状态结构和一些功能函数。

CUBLAS 用法

  大体分成以下几个步骤:

  1. 定义 CUBLAS 库对象

  2. 在显存中为待运算的数据以及需要存放结果的变量开辟显存空间。( cudaMalloc 函数实现 )

  3. 将待运算的数据传输进显存。( cudaMemcpy,cublasSetVector 等函数实现 )

  3. 调用 CUBLAS 库函数 ( 根据 CUBLAS 手册调用需要的函数 )

  4. 从显存中获取结果变量。( cudaMemcpy,cublasGetVector 等函数实现 )

  5. 释放申请的显存空间以及 CUBLAS 库对象。( cudaFree 及 cublasDestroy 函数实现 )

代码示例

  如下程序使用 CUBLAS 库进行矩阵乘法运算,请仔细阅读注释,尤其是 API 的参数说明:

  1 // CUDA runtime 库 + CUBLAS 库 
  2 #include "cuda_runtime.h"
  3 #include "cublas_v2.h"
  4 
  5 #include <time.h>
  6 #include <iostream>
  7 
  8 using namespace std;
  9 
 10 // 定义测试矩阵的维度
 11 int const M = 5;
 12 int const N = 10;
 13 
 14 int main() 
 15 {   
 16     // 定义状态变量
 17     cublasStatus_t status;
 18 
 19     // 在 内存 中为将要计算的矩阵开辟空间
 20     float *h_A = (float*)malloc (N*M*sizeof(float));
 21     float *h_B = (float*)malloc (N*M*sizeof(float));
 22     
 23     // 在 内存 中为将要存放运算结果的矩阵开辟空间
 24     float *h_C = (float*)malloc (M*M*sizeof(float));
 25 
 26     // 为待运算矩阵的元素赋予 0-10 范围内的随机数
 27     for (int i=0; i<N*M; i++) {
 28         h_A[i] = (float)(rand()%10+1);
 29         h_B[i] = (float)(rand()%10+1);
 30     
 31     }
 32     
 33     // 打印待测试的矩阵
 34     cout << "矩阵 A :" << endl;
 35     for (int i=0; i<N*M; i++){
 36         cout << h_A[i] << " ";
 37         if ((i+1)%N == 0) cout << endl;
 38     }
 39     cout << endl;
 40     cout << "矩阵 B :" << endl;
 41     for (int i=0; i<N*M; i++){
 42         cout << h_B[i] << " ";
 43         if ((i+1)%M == 0) cout << endl;
 44     }
 45     cout << endl;
 46     
 47     /*
 48     ** GPU 计算矩阵相乘
 49     */
 50 
 51     // 创建并初始化 CUBLAS 库对象
 52     cublasHandle_t handle;
 53     status = cublasCreate(&handle);
 54     
 55     if (status != CUBLAS_STATUS_SUCCESS)
 56     {
 57         if (status == CUBLAS_STATUS_NOT_INITIALIZED) {
 58             cout << "CUBLAS 对象实例化出错" << endl;
 59         }
 60         getchar ();
 61         return EXIT_FAILURE;
 62     }
 63 
 64     float *d_A, *d_B, *d_C;
 65     // 在 显存 中为将要计算的矩阵开辟空间
 66     cudaMalloc (
 67         (void**)&d_A,    // 指向开辟的空间的指针
 68         N*M * sizeof(float)    // 需要开辟空间的字节数
 69     );
 70     cudaMalloc (
 71         (void**)&d_B,    
 72         N*M * sizeof(float)    
 73     );
 74 
 75     // 在 显存 中为将要存放运算结果的矩阵开辟空间
 76     cudaMalloc (
 77         (void**)&d_C,
 78         M*M * sizeof(float)    
 79     );
 80 
 81     // 将矩阵数据传递进 显存 中已经开辟好了的空间
 82     cublasSetVector (
 83         N*M,    // 要存入显存的元素个数
 84         sizeof(float),    // 每个元素大小
 85         h_A,    // 主机端起始地址
 86         1,    // 连续元素之间的存储间隔
 87         d_A,    // GPU 端起始地址
 88         1    // 连续元素之间的存储间隔
 89     );
 90     cublasSetVector (
 91         N*M, 
 92         sizeof(float), 
 93         h_B, 
 94         1, 
 95         d_B, 
 96         1
 97     );
 98 
 99     // 同步函数
100     cudaThreadSynchronize();
101 
102     // 传递进矩阵相乘函数中的参数,具体含义请参考函数手册。
103     float a=1; float b=0;
104     // 矩阵相乘。该函数必然将数组解析成列优先数组
105     cublasSgemm (
106         handle,    // blas 库对象 
107         CUBLAS_OP_T,    // 矩阵 A 属性参数
108         CUBLAS_OP_T,    // 矩阵 B 属性参数
109         M,    // A, C 的行数 
110         M,    // B, C 的列数
111         N,    // A 的列数和 B 的行数
112         &a,    // 运算式的 α 值
113         d_A,    // A 在显存中的地址
114         N,    // lda
115         d_B,    // B 在显存中的地址
116         M,    // ldb
117         &b,    // 运算式的 β 值
118         d_C,    // C 在显存中的地址(结果矩阵)
119         M    // ldc
120     );
121     
122     // 同步函数
123     cudaThreadSynchronize();
124 
125     // 从 显存 中取出运算结果至 内存中去
126     cublasGetVector (
127         M*M,    //  要取出元素的个数
128         sizeof(float),    // 每个元素大小
129         d_C,    // GPU 端起始地址
130         1,    // 连续元素之间的存储间隔
131         h_C,    // 主机端起始地址
132         1    // 连续元素之间的存储间隔
133     );
134     
135     // 打印运算结果
136     cout << "计算结果的转置 ( (A*B)的转置 ):" << endl;
137 
138     for (int i=0;i<M*M; i++){
139             cout << h_C[i] << " ";
140             if ((i+1)%M == 0) cout << endl;
141     }
142     
143     // 清理掉使用过的内存
144     free (h_A);
145     free (h_B);
146     free (h_C);
147     cudaFree (d_A);
148     cudaFree (d_B);
149     cudaFree (d_C);
150 
151     // 释放 CUBLAS 库对象
152     cublasDestroy (handle);
153 
154     getchar();
155     
156     return 0;
157 }

运行测试

  

  PS:矩阵元素是随机生成的

小结

  1. 使用 CUDA 库固然方便,但也要仔细的参阅函数手册,其中每个参数的含义都要很清晰才不容易出错。

  2. 如果程序仅使用 CUDA 库的话,用 .cpp 源码文件即可 (不用 .cu)

转载于:https://www.cnblogs.com/scut-fm/p/3756242.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 CUBLAS 库给矩阵运算提速 的相关文章

  • 如何将base64字符串直接解码为二进制音频格式

    音频文件通过 API 发送给我们 该文件是 Base64 编码的 PCM 格式 我需要将其转换为 PCM 然后再转换为 WAV 进行处理 我能够使用以下代码解码 gt 保存到 pcm gt 从 pcm 读取 gt 保存为 wav decod
  • 切片稀疏(scipy)矩阵

    我将不胜感激任何帮助 以理解从 scipy sparse 包中切片 lil matrix A 时的以下行为 实际上 我想根据行和列的任意索引列表提取子矩阵 当我使用这两行代码时 x1 A list 1 x2 x1 list 2 一切都很好
  • Python有条件求解时滞微分方程

    我在用dde23 of pydelay包来求解延迟微分方程 我的问题 如何有条件地编写方程 例如目标方程有两个选项 when x gt 1 dx dt 0 25 x t tau 1 0 pow x t tau 10 0 0 1 x othe
  • 将 numpy 数组写入文本文件的速度

    我需要将一个非常 高 的两列数组写入文本文件 而且速度非常慢 我发现如果我将数组改造成更宽的数组 写入速度会快得多 例如 import time import numpy as np dataMat1 np random rand 1000
  • 更新 Sqlalchemy 中的多个列

    我有一个在 Flask 上运行的应用程序 并使用 sqlalchemy 与数据库交互 我想用用户指定的值更新表的列 我正在使用的查询是 def update table value1 value2 value3 query update T
  • 根据开始列和结束列扩展数据框(速度)

    我有一个pandas DataFrame含有start and end列 加上几个附加列 我想将此数据框扩展为一个时间序列 从start值并结束于end值 但复制我的其他专栏 到目前为止 我想出了以下内容 import pandas as
  • numpy 使用 datetime64 进行数字化

    我似乎无法让 numpy digitize 与 datetime64 一起使用 date bins np array np datetime64 datetime datetime 2014 n 1 s for n in range 1 1
  • 可以用 Django 制作移动应用程序吗?

    我想知道我是否可以在我的网站上使用 Django 代码 并以某种方式在移动应用程序 Flutter 等框架中使用它 那么是否可以使用我现在拥有的 Django 后端并在移动应用程序中使用它 所以就像models views etc 是的 有
  • PySide6.1 与 matplotlib 3.4 不兼容

    当我只安装PySide6时 GUI程序运行良好 但是一旦我安装了matplotlib及其依赖包 包括pyqt5 则GUI程序将无法运行并输出以下错误消息 This application failed to start because no
  • `list()` 被认为是一个函数吗?

    list显然是内置类型 https docs python org 3 library stdtypes html list在Python中 我看到底下有一条评论this https stackoverflow com a 53645813
  • Python多处理错误“ForkAwareLocal”对象没有属性“连接”

    下面是我的代码 我面临着多处理问题 我看到这个问题之前已经被问过 我已经尝试过这些解决方案 但它似乎不起作用 有人可以帮我吗 from multiprocessing import Pool Manager Class X def init
  • django-admin.py makemessages 不起作用

    我正在尝试翻译一个字符串 load i18n trans Well Hello there how are you to Hola amigo que tal 我的 settings py 文件有这样的内容 LOCALE PATHS os
  • Python 惰性迭代器

    我试图了解迭代器表达式如何以及何时被求值 以下似乎是一个懒惰的表达 g i for i in range 1000 if i 3 i 2 然而 这个在构造上失败了 g line strip for line in open xxx r if
  • Pandas style.bar 颜色基于条件?

    如何渲染其中一列的 Pandas dfstyle bar color属性是根据某些条件计算的 Example df style bar subset before after color ff781c vmin 0 0 vmax 1 0 而
  • 解析根元素内元素之间的 XML 文本

    我正在尝试用 Python 解析 XML 以下是 XML 结构的示例 a aaaa1 b bbbb b aaaa2 a
  • 是否可以将 pd.Series 分配给无序 pd.DataFrame 中的列而不映射到索引(即不重新排序值)?

    在 Pandas 中创建或分配新列时 我发现了一些意外的行为 当我对 pd DataFrame 进行过滤或排序 从而混合索引 然后从 pd Series 创建新列时 Pandas 会重新排序该系列以映射到 DataFrame 索引 例如 d
  • 对数据帧的每 2 小时数据进行 Groupby

    我有一个数据框 Time T201FN1ST2010 T201FN1VT2010 1791 2017 12 26 00 00 00 854 69 0 87 1792 2017 12 26 00 20 00 855 76 0 87 1793
  • 处理大文件的最快方法?

    我有多个 3 GB 制表符分隔文件 每个文件中有 2000 万行 所有行都必须独立处理 任何两行之间没有关系 我的问题是 什么会更快 逐行阅读 with open as infile for line in infile 将文件分块读入内存
  • 在 virtualenvwrapper 中激活环境

    我安装了virtualenv and virtualenvwrapper用这个命令我创建了一个环境 mkvirtualenv cv 它有效 创建后我就处于新环境中 现在我重新启动了我的电脑 我想activate又是那个环境 但是怎么样 我使
  • 缓存 Flask-登录 user_loader

    我有这个 login manager user loader def load user id None return User query get id 在我引入 Flask Principal 之前它运行得很好 identity loa

随机推荐

  • # 解析bt文件_PC端BT资源搜索及下载,诸位请节制!

    Hello大家好 这里是TopOne软件管家 毕竟要求的人太多了 今天将我测试最好的搭配给大家分享一下 当然 这个是站在我的角度 大家可以根据自己的使用情况进行调整 今天分享的是PC端 由于Mac限制 苹果电脑现只提供BT搜索软件 BT搜索
  • Windows下基于WSL2的Ubuntu开发环境搭建

    1 背景介绍 Windows是市场占有率最高的桌面操作系统 嵌入式开发领域一般需要搭建ubuntu虚拟机环境以实现linux下的交叉编译等工作 传统的Vmvare Ubuntu虚拟机安装过程繁琐且资源消耗巨大 自从Windows提供WSL2
  • 数据分析08——Pandas中对数据进行数据清洗

    0 前言 使用pandas修改数据是否会改变源数据 Pandas 对 DataFrame 的操作通常是针对原始数据本身而不是其副本的 例如 当我们使用 loc 或 iloc 方法选择 DataFrame 中的某行或某列并进行修改时 实际上是
  • python实现手势识别

    python实现手势识别 入门 使用open cv实现简单的手势识别 刚刚接触python不久 看到了很多有意思的项目 尤其时关于计算机视觉的 网上搜到了一些关于手势处理的实验 我在这儿简单的实现一下 PS 和那些大佬比起来真的是差远了 毕
  • Flink Sql使用mysql-cdc捕获多个表失败的问题

    问题描述 要捕获同一个库里的多个表的binlog 程序不报错 但是修改某个表后没有结果没有任何改变 fllinkSql的with语句 WITH connector mysql cdc hostname s port s username s
  • Linux安装anaconda3是否初始化的区别

    Linux安装anaconda3提示是否希望安装程序通过运行conda init来初始化Anaconda3 Do you wish the installer to initialize Anaconda3 by running conda
  • 数据结构1.1.1单链表的实现

    1 初始化链表节点内容 typedef struct char isbn 20 char name 10 double price Book typedef struct list Book date struct list next Li
  • GIT——! [rejected] master -> master (non-fast-forward)

    问题 rejected master gt master non fast forward error failed to push some refs to ssh 192 168 137 64 29418 51Selling git h
  • Maven的安装与使用

    一 简介 1 什么是Maven Maven翻译为 专家 内行 的意思 是著名Apache公司下基于Java开发的开源项目 Maven项目对象模型 POM 是一个项目管理工具软件 可以通过简短的中央信息描述来管理项目的搭建 报告和文档等步骤
  • JS+AES解密(CBC模式、输出HEX)

    if tokenMsgs const response await getMqttMsgService let mqttMsg response data msg state mqttconfigs mqttMsg const aesKey
  • 【工具类】发送邮件表格html生成类

    发送邮件的时候 有时候要自己拼html画一个表格 嫌麻烦就写了个工具类 核心类MailTableBuilder import java util MailTableCell author zgd date 2022 8 25 17 43 p
  • 【JAVA】垃圾回收详解

    文章目录 垃圾回收 调用垃圾回收器的方法 finalize 方法 判断对象是否可回收 引用计数算法 根搜索算法 引用的分类 垃圾回收算法 标记 清除算法 标记 整理算法 复制算法 分代收集算法 分配内存与回收策略 Minor GC 和 Fu
  • 使用UDP实现下载上传

    include
  • python基础知识点汇总

    本文包括python基本知识 简单数据结构 数据结构类型 可变 列表 字典 集合 不可变 数值类型 字符串 元组 分支循环和控制流程 类和函数 文件处理和异常等等 python控制语句 if语句 当条件成立时运行语句块 经常与else el
  • 纯js原生实现图片批量下载

    前端纯js实现图片批量下载到本地 图片转base64 getImageBase64 image const canvas document createElement canvas canvas width image width canv
  • 4行Python代码打败美图秀秀

    我们平时使用一些图像处理软件时 经常会看到其对图像的亮度 对比度 色度或者锐度进行调整 你是不是觉得这种技术的底 层实现很高大上 其实最基础的实现原理 用 Python 实现只需要几行代码 学会后你也可以进行简单的图像增强处理了 图像增强哪
  • wsl 内突然不能上网了

    现象 1 一开始是间歇性无法联网 无法连接外网 表现为 apt get update 时请求失败 国内源 2 尝试 ping www baidu com等外网域名 超时 nslookup能够正常解析域名 IP 在 host 主机上也能够正常
  • 进程和线程:进程的开销比线程大在了哪里?

    进程和线程 进程 Process 顾名思义就是正在执行的应用程序 是软件的执行副本 而线程是轻量级的进程 进程是分配资源的基础单位 线程很长一段时间被称作轻量级进程 Light Weighted Process 是程序执行的基本单位 在计算
  • pyltp安装教程windows11

    我是用anaconda创建一个环境 这个比较容易管理 第一步 anaconda创建环境 网上很多教程 第二步 安装pyltp 第一种方法 pip install pyltp 用这个多半失败 第二种方法 用wheel安装 下载wheel 参考
  • 使用 CUBLAS 库给矩阵运算提速

    前言 编写 CUDA 程序真心不是个简单的事儿 调试也不方便 很费时 那么有没有一些现成的 CUDA 库来调用呢 答案是有的 如 CUBLAS 就是 CUDA 专门用来解决线性代数运算的库 本文将大致介绍如何使用 CUBLAS 库 同时演示