Pytorch/TensorFlow/Numpy常用函数汇总

2023-11-16

一、Pytorch

1.枚举:enumerate

基本语法:enumerate(iterable,start=0)
返回结果:enumerate is useful for obtaining an indexed list:
(0, seq[0]), (1, seq[1]), (2, seq[2]), …

2.维度交换:permute

基本语法:tensor.permute(arr,dims)
举例:

arr=torch.randint(10,size=(2,3,4))
print(arr)
arr=torch.permute(arr,dims=(0,2,1))
print(arr.size())
print(arr)
tensor([[[7, 1, 7, 9],
         [6, 4, 4, 2],
         [9, 0, 7, 7]],

        [[0, 5, 3, 1],
         [3, 0, 6, 0],
         [1, 3, 4, 5]]])
torch.Size([2, 4, 3])
tensor([[[7, 6, 9],
         [1, 4, 0],
         [7, 4, 7],
         [9, 2, 7]],

        [[0, 3, 1],
         [5, 0, 3],
         [3, 6, 4],
         [1, 0, 5]]])

3.尺寸形状:size/shape

基本语法:
(1)tensor.size()
(2)tensor.shape
举例:

arr=torch.ones(size=(2,2,3,3))
print(arr.size())  # torch.Size([2, 2, 3, 3])
print(arr.size(0))  # 2
print(arr.shape)  # torch.Size([2, 2, 3, 3])

4.数据填充:full/fill_

1)创建指定形状的张量,并使用特定值填充该张量
基本语法:torch.full(size, fill_value)
举例:

arr=torch.full(size=(2,3),fill_value=2.0)

生成用2.0填充,形状为2*3的Tensor:

tensor([[2., 2., 2.],
        [2., 2., 2.]])

2)使用特定值填充指定向量
基本语法:tensor.fill_(value)
举例:

arr=torch.full(size=(2,3),fill_value=2.0)
print(arr)
print('--------------')
arr.fill_(value=3.0)
print(arr)
tensor([[2., 2., 2.],
        [2., 2., 2.]])
--------------
tensor([[3., 3., 3.],
        [3., 3., 3.]])

5.阻断反向梯度传播:detach

基本语法:tensor.detach()
举例:

arr=torch.full(size=(2,3),fill_value=2.0,requires_grad=True)
print(arr.requires_grad)
arr2=arr.detach()
print(arr2.requires_grad)
True
False

6.形状调整:view

基本语法:tensor.view(shape)
举例:

value=torch.arange(6)
print(value)
arr=value.view((-1,2))
print(arr)
tensor([0, 1, 2, 3, 4, 5])
tensor([[0, 1],
        [2, 3],
        [4, 5]])

7.维度扩张/减少:unsqueeze/squeeze

基本语法:
(1)torch.unsqueeze(tensor,dim)
(2)torch.squeeze(tensor,dim)
举例:

arr=torch.ones(size=(5,6,3))
print(arr.size())
arr=torch.unsqueeze(arr,dim=0)
print(arr.size())
arr=torch.squeeze(arr,dim=0)
print(arr.size())
torch.Size([5, 6, 3])
torch.Size([1, 5, 6, 3])
torch.Size([5, 6, 3])

二、Numpy

1.维度交换:transpose

基本语法:arr.transpose(axis)
举例:

arr=np.ones(shape=(2,3,4))
print(arr.transpose(2,1,0).shape) # 其中数字对应坐标轴,如2表示4、1表示3、0表示2
(4, 3, 2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch/TensorFlow/Numpy常用函数汇总 的相关文章

随机推荐