torch.mm、torch.matmul 和 torch.mul 有什么区别?

2024-04-15

阅读完 pytorch 文档后,我仍然需要帮助来理解之间的区别torch.mm, torch.matmul and torch.mul。由于我不完全理解它们,所以我无法简明地解释这一点。

B = torch.tensor([[ 1.1207],
        [-0.3137],
        [ 0.0700],
        [ 0.8378]])

C = torch.tensor([[ 0.5146,  0.1216, -0.5244,  2.2382]])

print(torch.mul(B,C))

print(torch.matmul(B,C))

print(torch.mm(B,C))

所有三个都会产生以下输出(即它们执行矩阵乘法):

tensor([[ 0.5767,  0.1363, -0.5877,  2.5084],
        [-0.1614, -0.0381,  0.1645, -0.7021],
        [ 0.0360,  0.0085, -0.0367,  0.1567],
        [ 0.4311,  0.1019, -0.4393,  1.8752]])
A = torch.tensor([[1.8351,2.1536], [-0.8320,-1.4578]])
B = torch.tensor([[2.9355, 0.3450], [0.5708, 1.9957]])
print(torch.mul(A,B))
print(torch.matmul(A,B))
print(torch.mm(A,B))

不同的产生输出。 torch.mm 不再执行矩阵乘法(而是广播并执行逐元素乘法,而其他两个仍然执行矩阵乘法。

tensor([[ 5.3869,  0.7430],
        [-0.4749, -2.9093]])
tensor([[ 6.6162,  4.9310],
        [-3.2744, -3.1964]])
tensor([[ 6.6162,  4.9310],
        [-3.2744, -3.1964]])

Inputs

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)

tensor1 = 
tensor([[[-0.2267,  0.6311, -0.5689,  1.2712],
         [-0.0241, -0.5362,  0.5481, -0.4534],
         [-0.9773, -0.6842,  0.6927,  0.3363]],

        [[-2.6759,  0.7817,  2.6821,  0.7037],
         [ 0.1804,  0.3938, -1.2235,  0.8729],
         [-1.9873, -0.5030,  0.0945,  0.2688]],

        [[ 0.4244,  1.7350,  0.0558, -0.1861],
         [-0.9063, -0.4737, -0.4284, -0.3883],
         [ 0.4827, -0.2628,  1.0084,  0.2769]],

        [[ 0.2939,  0.4604,  0.8014, -1.8760],
         [ 1.8807,  0.1623,  0.2344, -0.6221],
         [ 1.3964,  3.1637,  0.7889,  0.1195]],

        [[-0.7202,  1.4250,  2.4302,  1.4811],
         [-0.2301,  0.6280,  0.5379,  0.5178],
         [-2.1073, -1.4399, -0.9451,  0.8534]],

        [[ 2.8178, -0.4451, -0.7871, -0.5198],
         [ 0.2825,  1.0692,  0.1559,  1.2945],
         [-0.5828, -1.6287, -2.0661, -0.4107]],

        [[ 0.5077, -0.6349, -0.0160, -0.4477],
         [-0.8070,  0.3746,  1.1852,  0.0351],
         [-0.6454,  1.5877,  0.8561,  1.1021]],

        [[ 0.1191,  1.0116,  0.5807,  1.2105],
         [-0.5403,  1.2404,  1.1532,  0.6537],
         [ 1.4757, -1.3648, -1.7158, -1.0289]],

        [[-0.1326,  0.3715,  0.2429, -0.0794],
         [ 0.3224, -0.3064,  0.1963,  0.7276],
         [ 0.9098,  1.5984, -1.4953,  0.0420]],

        [[ 0.1511,  0.9691, -0.5204,  0.3858],
         [ 0.4566,  1.5482, -0.3401,  0.5960],
         [-0.9998,  0.7198,  0.9286,  0.4498]]])

tensor2 =
tensor([-1.6350,  1.0335, -0.9023,  0.0696])
print(torch.mul(tensor1,tensor2))
print(torch.matmul(tensor1,tensor2))
print(torch.mm(tensor1,tensor2))

输出是一切都不同。我认为torch.mul广播矩阵的每 4 个元素并将其乘以向量,tensor2,即[-0.2267, 0.6311, -0.5689, 1.2712] x tensor 2 元素方面,[-0.0241, -0.5362, 0.5481, -0.4534] x tensor 2元素方面等等。我不明白什么 torch.matmul是在做。我认为这与文档的第五个要点有关(如果两个参数......),但我无法理解这一点。https://pytorch.org/docs/stable/ generated/torch.matmul.html https://pytorch.org/docs/stable/generated/torch.matmul.html

我认为原因torch.mm无法产生输出是因为它无法广播(如果我错了,请纠正我)。

tensor([[[ 3.7071e-01,  6.5221e-01,  5.1335e-01,  8.8437e-02],
         [ 3.9400e-02, -5.5417e-01, -4.9460e-01, -3.1539e-02],
         [ 1.5979e+00, -7.0715e-01, -6.2499e-01,  2.3398e-02]],

        [[ 4.3752e+00,  8.0790e-01, -2.4201e+00,  4.8957e-02],
         [-2.9503e-01,  4.0699e-01,  1.1040e+00,  6.0723e-02],
         [ 3.2494e+00, -5.1981e-01, -8.5253e-02,  1.8701e-02]],

        [[-6.9397e-01,  1.7931e+00, -5.0379e-02, -1.2945e-02],
         [ 1.4818e+00, -4.8954e-01,  3.8657e-01, -2.7010e-02],
         [-7.8920e-01, -2.7163e-01, -9.0992e-01,  1.9265e-02]],

        [[-4.8055e-01,  4.7582e-01, -7.2309e-01, -1.3051e-01],
         [-3.0750e+00,  1.6770e-01, -2.1146e-01, -4.3281e-02],
         [-2.2832e+00,  3.2697e+00, -7.1183e-01,  8.3139e-03]],

        [[ 1.1775e+00,  1.4727e+00, -2.1928e+00,  1.0304e-01],
         [ 3.7617e-01,  6.4900e-01, -4.8534e-01,  3.6025e-02],
         [ 3.4455e+00, -1.4882e+00,  8.5277e-01,  5.9369e-02]],

        [[-4.6072e+00, -4.6005e-01,  7.1024e-01, -3.6160e-02],
         [-4.6191e-01,  1.1051e+00, -1.4067e-01,  9.0053e-02],
         [ 9.5283e-01, -1.6833e+00,  1.8643e+00, -2.8571e-02]],

        [[-8.3005e-01, -6.5622e-01,  1.4461e-02, -3.1148e-02],
         [ 1.3195e+00,  3.8716e-01, -1.0694e+00,  2.4421e-03],
         [ 1.0553e+00,  1.6409e+00, -7.7250e-01,  7.6669e-02]],

        [[-1.9477e-01,  1.0455e+00, -5.2398e-01,  8.4209e-02],
         [ 8.8343e-01,  1.2820e+00, -1.0405e+00,  4.5478e-02],
         [-2.4128e+00, -1.4106e+00,  1.5482e+00, -7.1578e-02]],

        [[ 2.1675e-01,  3.8391e-01, -2.1914e-01, -5.5219e-03],
         [-5.2707e-01, -3.1668e-01, -1.7711e-01,  5.0619e-02],
         [-1.4876e+00,  1.6520e+00,  1.3493e+00,  2.9198e-03]],

        [[-2.4706e-01,  1.0015e+00,  4.6955e-01,  2.6842e-02],
         [-7.4663e-01,  1.6001e+00,  3.0685e-01,  4.1462e-02],
         [ 1.6347e+00,  7.4395e-01, -8.3792e-01,  3.1291e-02]]])
tensor([[ 1.6247, -1.0409,  0.2891],
        [ 2.8120,  1.2767,  2.6630],
        [ 1.0358,  1.3518, -1.9515],
        [-0.8583, -3.1620,  0.2830],
        [ 0.5605,  0.5759,  2.8694],
        [-4.3932,  0.5925,  1.1053],
        [-1.5030,  0.6397,  2.0004],
        [ 0.4109,  1.1704, -2.3467],
        [ 0.3760, -0.9702,  1.5165],
        [ 1.2509,  1.2018,  1.5720]])

简而言之:

  • torch.mm- 执行矩阵乘法没有广播-(二维张量)by(二维张量)
  • torch.mul- 执行一个逐元素乘法与广播-(张量)by(张量或数)
  • torch.matmul- 矩阵乘积与广播-(张量)by(张量)根据张量形状(点积、矩阵积、批量矩阵积)具有不同的行为。

一些细节:

  1. torch.mm- 执行矩阵乘法没有广播

它需要两个 2D 张量,所以n×m * m×p = n×p

从文档中https://pytorch.org/docs/stable/ generated/torch.mm.html https://pytorch.org/docs/stable/generated/torch.mm.html:

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
  1. torch.mul- 执行一个逐元素乘法与广播-(张量)by(张量或数)

Docs: https://pytorch.org/docs/stable/ generated/torch.mul.html https://pytorch.org/docs/stable/generated/torch.mul.html

torch.mul不执行矩阵乘法。它广播两个张量并执行元素乘法。因此,当您将其与张量 1x4 * 4x1 一起使用时,其工作原理类似于:

import torch

a = torch.FloatTensor([[1], [2], [3]])
b = torch.FloatTensor([[1, 10, 100]])
a, b = torch.broadcast_tensors(a, b)
print(a)
print(b)
print(a * b)
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]])
tensor([[  1.,  10., 100.],
        [  1.,  10., 100.],
        [  1.,  10., 100.]])
tensor([[  1.,  10., 100.],
        [  2.,  20., 200.],
        [  3.,  30., 300.]])
  1. torch.matmul

最好还是看看官方文档https://pytorch.org/docs/stable/ generated/torch.matmul.html https://pytorch.org/docs/stable/generated/torch.matmul.html因为它根据输入张量使用不同的模式。它可以通过广播执行点积、矩阵-矩阵积或批量矩阵积。

至于您关于以下产品的问题:

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)

它是产品的批量版本。请检查这个简单的例子来理解:

import torch

# 3x1x3
a = torch.FloatTensor([[[1, 2, 3]], [[3, 4, 5]], [[6, 7, 8]]])
# 3
b = torch.FloatTensor([1, 10, 100])
r1 = torch.matmul(a, b)

r2 = torch.stack((
    torch.matmul(a[0], b),
    torch.matmul(a[1], b),
    torch.matmul(a[2], b),
))
assert torch.allclose(r1, r2)

因此它可以看作是跨批次维度堆叠在一起的多个操作。

阅读有关广播的内容也可能有用:

https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.mm、torch.matmul 和 torch.mul 有什么区别? 的相关文章

  • 并发.futures问题:为什么只有1个worker?

    我正在尝试使用concurrent futures ProcessPoolExecutor并行化串行任务 串行任务涉及从数字范围中查找给定数字的出现次数 我的代码如下所示 在执行过程中 我从任务管理器 系统监视器 顶部注意到 尽管给定了 m
  • 如何从 Django 中的 POST 获取之前的 URL

    我有一个 Post 模型 在添加到数据库之前需要特定的类别 并且我希望自动生成该类别 单击 addPost 按钮会将您带到不同的页面 因此将通过上一页 URL 的一部分来确定类别 有没有办法以字符串形式获取上一页 URL 我在这里添加了 A
  • 带参数和不带参数的 super() 有什么区别?

    我遇到了一个使用的代码super 方法有两种不同的方式 我不明白逻辑上有什么区别 我现在正在学习pygame模块 我有一个任务来创建一个类Ball它继承自Sprite这是一个来自pygame模块 如果我没记错的话 我遇到了这段代码 impo
  • Python 滚动文本模块

    我想使用scrolledtext模块创建一个ScrolledText小部件 以便在python中创建GUI 我已经成功创建了 ScrolledText 小部件 但是我无法向其添加水平滚动条 e3 ScrolledText window3 w
  • 间歇性的“事件循环在 Future 完成之前停止。”

    我一直在为这个问题抓狂 有问题的代码是此处开源项目的一部分 aiosmtpd https github com pepoluan aiosmtpd 我的实际 FOSS 项目的分支 here https github com aio libs
  • Python:Facebook Graph API - 使用 facebook-sdk 的分页请求

    我正在尝试向 Facebook 查询不同的信息 例如 好友列表 它工作得很好 但当然它只能给出有限数量的结果 如何获取下一批结果 import facebook import json ACCESS TOKEN def pp o with
  • Python 3:资源警告:未关闭的文件 <_io.TextIOWrapper name='PATH_OF_FILE'

    当我在 python 中运行测试用例时 python 规范化器 setup py 测试 我收到以下异常 ResourceWarning unclosed file lt io TextIOWrapper name Users workspa
  • 查找其他列表项中列表项的列表索引

    我有一个长字符串列表 我想获取与另一个列表中的字符串子字符串匹配的列表元素的索引 使用列表理解可以轻松检查列表项是否包含列表中的单个字符串 例如这个问题 https stackoverflow com questions 4843158 c
  • 如何在 Tensorflow 中计算 Spearman 相关性

    Problem 我需要计算 Pearson 和 Spearman 相关性 并将其用作张量流中的指标 对于皮尔逊来说 这是微不足道的 tf contrib metrics streaming pearson correlation y pre
  • 如何让文字显示5秒然后消失并显示按钮?

    我正在努力做到这一点 以便当您在我的问答游戏中得到正确答案时 它会摆脱您看到的大问题并说 干得好 5秒钟 然后返回到主菜单 其中随机有4个选定的问题 问题从 quizfile csv 加载并包含 What colour is elon mu
  • 从另一个文件导入函数,在哪里导入其他库?

    很简单的问题 我搜了一下没有结果 假设我有一个文件 funcs py 其中有一个我想调用当前脚本的函数 该函数使用另一个库 例如 pandas 我在哪里导入该库 约定是什么 我是否将它放在 funcs py 的函数内 funcs py de
  • selenium.common.exceptions.SessionNotCreatedException:消息:未从选项卡创建的会话使用 ChromeDriver Chrome Selenium Python 崩溃

    当我尝试访问脚本请求的没有特定的 url 时 显然出现此错误 我不明白为什么会出现这个错误 但我想对其进行处理 以免在发生错误时中止脚本 这会重复 但不能解决我的问题 如何避免错误 selenium common exceptions Se
  • 如何测试具有多个输入调用的循环?

    我正在尝试测试一个依赖多个用户输入来返回某个值的函数 我已经在这里寻找了多个答案 但没有一个能够解决我的问题 我看到了参数化 模拟和猴子补丁的东西 但没有任何帮助 我认为很大程度上是因为我没有清楚地理解正在做的事情背后的概念 并且我无法适应
  • 在Python中单击按钮时隐藏标签

    在 Python Tkinter 中单击按钮时如何隐藏现有标签 这实际上取决于您使用的几何管理器 如果你使用 lbl Tkinter Label parent 要创建标签 您将使用以下方法之一来隐藏它 lbl grid forget lbl
  • Python 模块导入对一个文件有效,对另一个文件则失败

    我面临着一个非常奇怪的问题 我有三个文件 第一个包含基类 其他两个文件中的类继承自该基类 奇怪的是 昨天一切都工作正常 但今天其中一个文件不再工作了 在此期间我还没有接触过进口 orangecontrib init py prototype
  • Python 3 中的“raw_input()”和“input()”有什么区别? [复制]

    这个问题在这里已经有答案了 有什么区别raw input and input 在 Python 3 中 不同之处在于raw input Python 3 x 中不存在 而input 做 其实 老raw input 已更名为input 和旧的
  • pip install 找不到包,但 pip search 找到

    我想安装hdbcli https pypi org project hdbcli 包 SAP HANA 连接器 当我搜索时pip正在找到该包 但是当我想安装它时 pip找不到包裹 指定当前包也不会产生任何结果 pip install hdb
  • 为 hist2d 子图添加一个颜色条并使它们相邻

    我正在努力调整情节 我一直在努力 我面临两个问题 这些图应该是相邻的并且 wspace 和 hspace 为 0 我将两个值都设置为零 但图之间仍然有一些空格 我想为所有子图使用一个颜色条 它们的范围都相同 现在 代码向最后一个子图添加了一
  • 模拟导入失败

    我该如何制作import pkg失败moduleA py 我可以打补丁pkg如果从中导入某些内容则会失败 否则不会失败 test py import os import moduleA from unittest mock import p
  • 插入失败“OperationalError:没有这样的列”

    我尝试使用我尝试修复的姓名和电话创建一个数据库 但它会随时向我重播 File exm0 py line 14 in

随机推荐

  • Rails 2 到 Rails 3,控制器中的方法验证消失了吗?

    来自 Rails 2 的我的大多数控制器都会有这些行 verify method gt post only gt create render gt text gt 405 HTTP POST required status gt 405 a
  • 错误的泛型转换没有 ClassCastException [Java]

    看一下下面这个类的main方法 public class Outer static class A
  • Google 地图,设置最小和最大滑块控件

    我正在开发一个谷歌地图 除了我似乎无法设置最大和最小缩放这一事实之外 它一切正常 我想将默认缩放视图的级别限制为几个级别 我尝试过使用 map getMimimumResolution 但这似乎不起作用 有什么想法吗 function in
  • 如何在 gdb 中打印长字符串的完整值?

    我想在 GDB 中打印 C 字符串的完整长度 默认情况下它是缩写的 如何强制 GDB 打印整个字符串 set print elements 0 来自GDB手册 https sourceware org gdb onlinedocs gdb
  • 将 Jersey 日志输出到文件?

    我们已将这些添加到 web xml 中
  • 如何在 Ruby 中计算字符串的宽度?

    String length只会告诉我字符串中有多少个字符 事实上 在Ruby 1 9之前 它只会告诉我有多少字节 这更没有什么用处 我真的很想知道一个字符串有多少 en 宽 例如 foo width gt 3 moo width gt 3
  • git 将上游设置为原点

    我一直在阅读和使用 git 但仍然对 起源 一词感到困惑 我有一个本地存储库 并在其上创建了一个新分支 这个新分支是我主人的副本 我的主控是原始主控的分叉 是其他人制作的另一个存储库 并且在某些提交方面领先于原始主控 而在其他方面则落后于原
  • 在 64 位计算机上使用 Redemption dll (Outlook)

    我在 32 位机器上安装了一个 exe 它循环访问登录的用户收件箱并且工作正常 注意我仍然没有让它为另一个用户工作 see here https stackoverflow com questions 589254 using redemp
  • 防止隐藏状态栏重新布局(伪造 SYSTEM_UI_FLAG_LAYOUT_STABLE)

    我正在开发具有列表视图和详细信息视图的应用程序 并且我从列表视图到详细视图进行动画处理 在执行此操作时 我想在某个阶段隐藏状态栏 最好仍然在后台显示列表视图 问题是使用隐藏状态栏 getWindow setFlags WindowManag
  • 如何使用 java.nio.ByteBuffer 从 C++ 返回到 Java

    这几乎是问题的重复如何使 Swig 正确包装在 C 中修改为 Java Something or other 的 char 缓冲区 https stackoverflow com questions 2740068 how can i ma
  • 更新的标题:为什么 ICommand.CanExecute 一直被调用,而不是像事件一样工作?

    我在 WPF 中采用 MVVM 模式并学习了使用Command 但在我的实现中 我分配来实现的代表CanExecute总是被调用 我的意思是 如果我在委托函数内放置一个断点 它表明该函数不断被调用 根据我的理解 也是一种自然的思维方式 但我
  • 如何使用 CSS 实现这种视觉效果

    我需要仅使用 css 和一个高度和宽度为 300px 的 div 创建上述视觉效果 我尝试了渐变但无法得到任何相同的东西 有人可以帮忙吗 渐变是一个好主意 您甚至可以添加内容 无论渐变的大小如何 只要将其大小设置为正方形即可 div bac
  • 使用带有几个字符串的 Ionic Storage 的 QuotaExceededError

    我在离子存储方面遇到了这个问题 这是完整的堆栈跟踪 core es5 js 1084 ERROR Error Uncaught in promise QuotaExceededError at c polyfills js 3 at c p
  • 如何使用词袋进行训练和预测?

    我有一个文件夹 里面有汽车各个角度的图像 我想使用词袋方法来训练系统识别汽车 训练完成后 我希望如果给出那辆车的图像 它应该能够识别它 我一直在尝试学习 opencv 中的 BOW 函数 以便完成这项工作 并且已经达到了我现在不知道该怎么做
  • vue.js 可以绑定内联样式吗?

    我很好奇 Vue js 中是否可以绑定内联样式 我熟悉类绑定 但是如果有时由于某种原因您想要内联绑定样式语句 是否可以像对待类一样绑定它 例如
  • 枚举和字典<枚举,操作>

    我希望我能以每个人都清楚的方式解释我的问题 我们需要您对此的建议 我们有一个枚举类型 它定义了超过 15 个常量 我们收到来自 Web 服务的报告 并将其一列转换为此枚举类型 根据我们从该网络服务收到的信息 我们使用以下命令运行特定功能 字
  • 有什么例子可以说明了解 C 语言可以让我用任何其他语言编写更好的代码?

    在 Stack Overflow 播客中 Joel Spolsky 不断地抱怨 Jeff Atwood 不知道如何用 C 语言编写代码 他的说法是 了解 C 可以帮助你编写更好的代码 他还总是使用某种涉及字符串操作的故事 以及了解 C 如何
  • Azure 数据工厂 v2:活动执行管道输出

    有没有办法在活动 执行管道 中引用已执行管道的输出 即 主管道按顺序执行2个管道 第一个管道生成一个自己创建的 run id 需要将其作为参数转发到第二个管道 我已阅读文档并检查主管道是否记录了第一个管道的输出 但看起来这不可能直接实现 到
  • 断言:exportArchive:“Test.app”需要配置文件

    当我尝试在 Xcode9 中使用 Xcode 服务器集成持续集成时 我可以成功创建 BOT 并尝试集成 然后我总是收到类似的错误 断言 exportArchive Test app 需要配置文件 如何解决这个问题 我遇到了同样的问题 并按照
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137