在自己写pytorch的数据集加载函数时,会遇到一个问题,如何将多张图片张量合并到一起,提升迷你批次batch的纬度,但是不影响图片的大小和RGB通道数
解决方法:
函数torch.cat(inputs,dim)
这里的inputs是你要合并的图像,dim=多少代表你打算让他们在哪个纬度上融合
import cv2
import torch
image_path = 'F:\\Pythontest\\cnn\\data\\train\\n01440764\\n01440764_13316.JPEG'
image_path1 = 'F:\\Pythontest\\cnn\\data\\train\\n01440764\\n01440764_13360.JPEG'
image_path2 = 'F:\\Pythontest\\cnn\\data\\train\\n01440764\\n01440764_13375.JPEG'
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
img = cv2.resize(img, (6, 6))
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).float()
img = img.unsqueeze(0)
print(img.size())
print(img)
img1 = cv2.imread(image_path1, cv2.IMREAD_COLOR)
img1 = cv2.resize(img1, (6, 6))
img1 = img1.transpose(2, 0, 1)
img1 = torch.from_numpy(img1).float()
img1 = img1.unsqueeze(0)
print(img1.size())
print(img1)
result1 = torch.cat((img,img1), dim=0)
print(result1.size())
print(result1)
img2 = cv2.imread(image_path2, cv2.IMREAD_COLOR)
# 对图像的预处理可以放到这个位置
img2 = cv2.resize(img2, (6, 6))
img2 = img2.transpose(2, 0, 1)
img2 = torch.from_numpy(img2).float()
img2 = img2.unsqueeze(0)
print(img2.size())
print(img2)
result0 = torch.cat((result1, img2), dim=0)
print(result0.size())
print(result0)
下面这是结果,可以看到,只有batch的维数增加了,其他位置没有变
在这里插入代码片
torch.Size([1, 3, 6, 6])
tensor([[[[112., 122., 126., 162., 76., 80.],
[ 87., 31., 36., 17., 84., 138.],
[ 21., 85., 186., 209., 228., 179.],
[ 32., 158., 199., 9., 78., 169.],
[ 27., 36., 25., 43., 38., 128.],
[ 28., 25., 53., 71., 105., 153.]],
[[162., 168., 158., 150., 67., 77.],
[145., 22., 21., 14., 75., 135.],
[ 18., 117., 210., 231., 237., 172.],
[ 23., 197., 226., 32., 92., 205.],
[ 20., 38., 20., 43., 65., 169.],
[ 54., 13., 52., 90., 145., 178.]],
[[151., 185., 157., 140., 57., 67.],
[135., 12., 15., 6., 66., 130.],
[ 10., 112., 216., 233., 241., 162.],
[ 19., 219., 237., 50., 88., 208.],
[ 11., 32., 11., 38., 55., 172.],
[ 34., 7., 42., 91., 164., 196.]]]])
torch.Size([1, 3, 6, 6])
tensor([[[[ 71., 63., 127., 75., 96., 62.],
[ 57., 232., 155., 185., 164., 117.],
[ 48., 51., 40., 43., 53., 37.],
[ 43., 78., 165., 208., 204., 37.],
[ 33., 76., 126., 114., 125., 100.],
[ 54., 58., 23., 15., 10., 19.]],
[[ 80., 64., 147., 97., 151., 101.],
[ 76., 238., 144., 169., 148., 173.],
[ 49., 43., 39., 50., 99., 41.],
[ 60., 128., 204., 216., 200., 67.],
[ 47., 88., 115., 100., 126., 93.],
[ 84., 111., 43., 12., 21., 58.]],
[[ 53., 58., 144., 102., 156., 102.],
[ 57., 228., 136., 156., 140., 170.],
[ 40., 44., 33., 57., 112., 59.],
[ 80., 128., 181., 194., 177., 76.],
[ 79., 76., 118., 102., 114., 86.],
[103., 119., 48., 18., 19., 56.]]]])
torch.Size([2, 3, 6, 6])
tensor([[[[112., 122., 126., 162., 76., 80.],
[ 87., 31., 36., 17., 84., 138.],
[ 21., 85., 186., 209., 228., 179.],
[ 32., 158., 199., 9., 78., 169.],
[ 27., 36., 25., 43., 38., 128.],
[ 28., 25., 53., 71., 105., 153.]],
[[162., 168., 158., 150., 67., 77.],
[145., 22., 21., 14., 75., 135.],
[ 18., 117., 210., 231., 237., 172.],
[ 23., 197., 226., 32., 92., 205.],
[ 20., 38., 20., 43., 65., 169.],
[ 54., 13., 52., 90., 145., 178.]],
[[151., 185., 157., 140., 57., 67.],
[135., 12., 15., 6., 66., 130.],
[ 10., 112., 216., 233., 241., 162.],
[ 19., 219., 237., 50., 88., 208.],
[ 11., 32., 11., 38., 55., 172.],
[ 34., 7., 42., 91., 164., 196.]]],
[[[ 71., 63., 127., 75., 96., 62.],
[ 57., 232., 155., 185., 164., 117.],
[ 48., 51., 40., 43., 53., 37.],
[ 43., 78., 165., 208., 204., 37.],
[ 33., 76., 126., 114., 125., 100.],
[ 54., 58., 23., 15., 10., 19.]],
[[ 80., 64., 147., 97., 151., 101.],
[ 76., 238., 144., 169., 148., 173.],
[ 49., 43., 39., 50., 99., 41.],
[ 60., 128., 204., 216., 200., 67.],
[ 47., 88., 115., 100., 126., 93.],
[ 84., 111., 43., 12., 21., 58.]],
[[ 53., 58., 144., 102., 156., 102.],
[ 57., 228., 136., 156., 140., 170.],
[ 40., 44., 33., 57., 112., 59.],
[ 80., 128., 181., 194., 177., 76.],
[ 79., 76., 118., 102., 114., 86.],
[103., 119., 48., 18., 19., 56.]]]])
torch.Size([1, 3, 6, 6])
tensor([[[[ 54., 88., 19., 106., 79., 164.],
[ 17., 39., 95., 20., 129., 70.],
[ 29., 150., 76., 121., 91., 78.],
[ 58., 50., 21., 33., 89., 99.],
[ 50., 17., 110., 53., 22., 24.],
[ 52., 84., 55., 40., 83., 19.]],
[[137., 130., 34., 153., 133., 201.],
[ 87., 65., 110., 62., 217., 168.],
[ 69., 153., 111., 141., 123., 95.],
[ 95., 91., 102., 106., 121., 174.],
[ 85., 68., 131., 100., 73., 95.],
[ 94., 114., 66., 92., 118., 94.]],
[[ 92., 105., 37., 120., 87., 184.],
[ 46., 40., 149., 31., 182., 134.],
[ 44., 177., 116., 142., 133., 89.],
[ 87., 90., 109., 114., 129., 144.],
[ 85., 81., 174., 105., 42., 54.],
[ 74., 141., 74., 62., 96., 46.]]]])
torch.Size([3, 3, 6, 6])
tensor([[[[112., 122., 126., 162., 76., 80.],
[ 87., 31., 36., 17., 84., 138.],
[ 21., 85., 186., 209., 228., 179.],
[ 32., 158., 199., 9., 78., 169.],
[ 27., 36., 25., 43., 38., 128.],
[ 28., 25., 53., 71., 105., 153.]],
[[162., 168., 158., 150., 67., 77.],
[145., 22., 21., 14., 75., 135.],
[ 18., 117., 210., 231., 237., 172.],
[ 23., 197., 226., 32., 92., 205.],
[ 20., 38., 20., 43., 65., 169.],
[ 54., 13., 52., 90., 145., 178.]],
[[151., 185., 157., 140., 57., 67.],
[135., 12., 15., 6., 66., 130.],
[ 10., 112., 216., 233., 241., 162.],
[ 19., 219., 237., 50., 88., 208.],
[ 11., 32., 11., 38., 55., 172.],
[ 34., 7., 42., 91., 164., 196.]]],
[[[ 71., 63., 127., 75., 96., 62.],
[ 57., 232., 155., 185., 164., 117.],
[ 48., 51., 40., 43., 53., 37.],
[ 43., 78., 165., 208., 204., 37.],
[ 33., 76., 126., 114., 125., 100.],
[ 54., 58., 23., 15., 10., 19.]],
[[ 80., 64., 147., 97., 151., 101.],
[ 76., 238., 144., 169., 148., 173.],
[ 49., 43., 39., 50., 99., 41.],
[ 60., 128., 204., 216., 200., 67.],
[ 47., 88., 115., 100., 126., 93.],
[ 84., 111., 43., 12., 21., 58.]],
[[ 53., 58., 144., 102., 156., 102.],
[ 57., 228., 136., 156., 140., 170.],
[ 40., 44., 33., 57., 112., 59.],
[ 80., 128., 181., 194., 177., 76.],
[ 79., 76., 118., 102., 114., 86.],
[103., 119., 48., 18., 19., 56.]]],
[[[ 54., 88., 19., 106., 79., 164.],
[ 17., 39., 95., 20., 129., 70.],
[ 29., 150., 76., 121., 91., 78.],
[ 58., 50., 21., 33., 89., 99.],
[ 50., 17., 110., 53., 22., 24.],
[ 52., 84., 55., 40., 83., 19.]],
[[137., 130., 34., 153., 133., 201.],
[ 87., 65., 110., 62., 217., 168.],
[ 69., 153., 111., 141., 123., 95.],
[ 95., 91., 102., 106., 121., 174.],
[ 85., 68., 131., 100., 73., 95.],
[ 94., 114., 66., 92., 118., 94.]],
[[ 92., 105., 37., 120., 87., 184.],
[ 46., 40., 149., 31., 182., 134.],
[ 44., 177., 116., 142., 133., 89.],
[ 87., 90., 109., 114., 129., 144.],
[ 85., 81., 174., 105., 42., 54.],
[ 74., 141., 74., 62., 96., 46.]]]])
进程已结束,退出代码0