PyOpenCL 矩阵乘法

2024-03-01

我有使用 pyopenCL 进行矩阵乘法的代码。 我的问题是某些矩阵的结果是错误的,我不明白为什么。 经过一番研究后,我认为它与类似的全球规模有关,但我不明白如何设置该值。

例如:

使用 numpy dtype = float32 的矩阵

矩阵1:

[[ 0.99114645  0.09327769  0.90075564  0.8913309 ]
[ 0.59739089  0.13906649  0.94246316  0.65673178]
[ 0.24535166  0.68942326  0.41361505  0.5789603 ]
[ 0.31962237  0.17714553  0.49025267  0.21861202]]

matrix2:

[[ 0.41509482  0.82779616  0.74143827  0.37681136]
[ 0.88058949  0.01039944  0.4342753   0.45752665]
[ 0.60375261  0.21243185  0.88312167  0.97394323]
[ 0.60855824  0.69482827  0.61627114  0.57155776]]

预期结果:

[[ 1.57981943  1.63210835  2.12016045  1.80288424]
[ 1.3391085   1.15248911  1.7403561   1.58199609]
[ 1.31099532  0.70041376  1.20338154  1.14162762]
[ 0.71769556  0.52246746  0.88158722  0.8039138 ]]

脚本结果:

[[ 1.20828819  0.73175305  1.64546931  1.42526579]
[ 1.13179159  0.46403384  1.20692348  1.14317513]
[ 1.25328159  0.86723316  1.58679342  1.40186214]
[ 1.35214019  0.6795128   1.73811913  1.48048854]]

script:

def openCL_multiplication(matrix1, matrix2, res):

import pyopencl as cl
import numpy as np
import numpy.linalg as la

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)

mf = cl.mem_flags
a_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=matrix1)
b_buf = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=matrix2)
dest_buf = cl.Buffer(ctx, mf.WRITE_ONLY, matrix1.nbytes )


prg = cl.Program(ctx, """
    __kernel void multiplymatrices(const unsigned int size, __global float * matrix1, __global float * matrix2, __global float * res) {

    int i = get_global_id(1); 
    int j = get_global_id(0);

    res[i + size * j] = 0;

    for (int k = 0; k < size; k++)
    {
        res[i + size * j] += matrix1[i + size * k] * matrix2[k + size * j];
    }

    }
    """).build()

t0 = datetime.datetime.now()

prg.multiplymatrices(queue, matrix1.shape, None,np.int32(len(matrix1)) ,a_buf, b_buf, dest_buf)

final_matrix = np.empty_like(matrix1)
cl.enqueue_copy(queue, final_matrix , dest_buf)

print  final_matrix


delta_t = datetime.datetime.now() - t0
print 'OpenCL Multiplication: ' + str(delta_t)

return final_matrix

谢谢你!


嗯,我认为内核做得很好。 我什至可以认为脚本结果是正确的。这完全取决于你如何对待你的矩阵:-) 如果你想要你预期的结果。我会改变这个:

res[i + size * j] += matrix1[i + size * k] * matrix2[k + size * j];

to this:

res[i + size * j] += matrix1[k + size * i] * matrix2[j + size * k];

希望这可以帮助。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyOpenCL 矩阵乘法 的相关文章

  • 如何在Python中选择要写入(.csv)的列

    import csv f csv reader open lmt csv r open input file for reading Date Open Hihh mLow Close Volume zip f s plit it into
  • 如何使用 conda 在一行中安装多个包?

    我需要使用 conda 安装以下多个软件包 我不确定 conda forge 是什么 有些使用 conda forge 有些不使用它 是否可以将它们安装成一行而不需要一一安装 谢谢 conda install c conda forge d
  • pandas Wide_to_long 后缀参数

    我对在 pandas 中使用 Wide to long 时的参数有疑问 有一个参数叫suffix我不明白 在文档中它说 后缀 str 默认 d 捕获所需后缀的正则表达式 d 捕获数字后缀 没有数字的后缀可以用否定字符类 D 指定 您还可以进
  • 如何让python优雅地失败?

    我只是想知道如何让 python 在所有可能的错误中以用户定义的方式失败 例如 我正在编写一个处理 大 项目列表的程序 并且某些项目可能不符合我定义的格式 如果 python 检测到错误 它目前只会输出一条丑陋的错误消息并停止整个过程 但是
  • 以矢量化方式在另一个 DataFrame 中查找包含值子集的行

    如何匹配此 DataFrame 中的值source car id lat lon 0 100 10 0 15 0 1 100 12 0 10 0 2 100 09 0 08 0 3 110 23 0 12 0 4 110 18 0 32 0
  • 无法使用 BeautifulSoup 和 Requests 抓取下拉菜单

    我想抓取百年灵网站上的产品页面以获取各种信息 示例页面 https www breitling com gb en watches navitimer b01 chronograph 46 AB0127211C1A1 https www b
  • 将一维数组转换为下三角矩阵

    我想将一维数组转换为较低的零对角矩阵 同时保留所有数字 我知道numpy tril函数 但它用零替换了一些元素 我需要扩展矩阵以包含所有原始数字 例如 10 20 40 46 33 14 12 46 52 30 59 18 11 22 30
  • numpy:大量线段/点的快速规则间隔平均值

    我沿着一维线有许多 约 100 万个 不规则间隔的点 P 这些标记线段 这样 如果点是 0 x a x b x c x d 则线段从 0 gt x a x a gt x b x b gt x c x c gt x d 等 我还有每个段的 y
  • scikit-learn 和tensorflow 有什么区别?可以一起使用它们吗?

    对于这个问题我无法得到满意的答案 据我了解 TensorFlow是一个数值计算库 经常用于深度学习应用 而Scikit learn是一个通用机器学习框架 但它们之间的确切区别是什么 TensorFlow 的目的和功能是什么 我可以一起使用它
  • 在 iPython/pandas 中绘制多条线会生成多个图

    我试图了解 matplotlib 的状态机模型 但在尝试在单个图上绘制多条线时遇到错误 据我了解 以下代码应该生成包含两行的单个图 import pandas as pd import pandas io data as web aapl
  • Jupyter Notebook 中的深色模式绘图 - Python

    我正在使用 Jupyter Notebook 目前正在使用 JupyterThemes 的深色日光主题 我注意到我的绘图不是处于黑暗模式 并且文本仍然是黑色并且在日光照射的背景上无法读取 JupyterThemes 的自述文件建议在 ipy
  • 线性同余生成器 - 如何选择种子和统计检验

    我需要做一个线性同余生成器 它将成功通过所选的统计测试 我的问题是 如何正确选择发电机的数字以及 我应该选择哪些统计检验 我想 均匀性的卡方频率测试 每代收集10 000个号码的方法 将 0 1 细分为10个相等的细分 柯尔莫哥洛夫 斯米尔
  • 计算 pyspark df 列中子字符串列表的出现次数

    我想计算子字符串列表的出现次数 并根据 pyspark df 中包含长字符串的列创建一个列 Input ID History 1 USA UK IND DEN MAL SWE AUS 2 USA UK PAK NOR 3 NOR NZE 4
  • Django Rest Framework POST 更新(如果存在或创建)

    我是 DRF 的新手 我阅读了 API 文档 也许这是显而易见的 但我找不到一个方便的方法来做到这一点 我有一个Answer与 a 具有一对一关系的对象Question 在前端 我曾经使用 POST 方法来创建发送到的答案api answe
  • 更换壳牌管道[重复]

    这个问题在这里已经有答案了 在 subprocess 模块的 Python 2 7 文档中 我找到了以下片段 p1 Popen dmesg stdout PIPE p2 Popen grep hda stdin p1 stdout stdo
  • python dicttoxml 多次使用相同的键

    我正在尝试做如下所示的 xml
  • 在 Python 中访问 argparse 的参数值

    我正在尝试为我的程序设置一些简单的标志参数 但无法弄清楚如何访问它们 我有 argparser parser argparse ArgumentParser description Simple PostScript Interpreter
  • 在 HDF5 (PyTables) 中存储 numpy 稀疏矩阵

    我在使用 PyTables 存储 numpy csr matrix 时遇到问题 我收到此错误 TypeError objects of type csr matrix are not supported in this context so
  • 如何更改matplotlib中双头注释的头大小?

    Below figure shows the plot of which arrow head is very small 我尝试了下面的代码 但它不起作用 它说 引发 AttributeError 未知属性 s k 属性错误 未知属性头宽
  • Python 中的字符串slugification

    我正在寻找 slugify 字符串的最佳方法 蛞蝓 是什么 https stackoverflow com questions 427102 in django what is a slug 我当前的解决方案基于这个食谱 http code

随机推荐