torch.mul()
函数功能:逐个对 input 和 other 中对应的元素相乘。
本操作支持广播,因此 input 和 other 均可以是张量或者数字
import torch
a = torch.randn((1,2))
b = torch.randn((2,1))
print(a,b)
torch.mul(a,b)
torch.multiply()
torch.mul() 的别称
torch.matmul()
matmul可以进行张量乘法, 输入可以是高维.
torch.dot()
函数功能:计算 input 和 output 的点乘,此函数要求 input 和 output 都必须是一维的张量(其 shape 属性中只有一个值)!并且要求两者元素个数相同!
import torch
x = torch.Tensor([1,2])
y = torch.Tensor([3,4])
z = torch.dot(x,y)
z
torch.mm()
函数功能:实现线性代数中的矩阵乘法(matrix multiplication):(n×m) × (m×p) = (n×p) 。
本函数不允许广播!
import torch
x = torch.randn((3,4))
y = torch.randn((4,5))
z = torch.mm(x,y)
z
torch.mv()
函数功能:实现矩阵和向量(matrix × vector)的乘法,要求 input 的形状为 n×m,output 为 torch.Size([m])的一维 tensor.
import torch
x = torch.randn((3,4))
y = torch.randn(4)
z = torch.mv(x,y)
z
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)