所以我正在关注这个文档中的教程 https://pytorch.org/tutorials/beginner/data_loading_tutorial.html在自定义数据集上。我使用的是 MNIST 数据集,而不是教程中的奇特数据集。这是Dataset
我写的课:
class KaggleMNIST(Dataset):
def __init__(self, csv_file, transform=None):
self.pixel_frame = pd.read_csv(csv_file)
self.transform = transform
def __len__(self):
return len(self.pixel_frame)
def __getitem__(self, index):
if torch.is_tensor(index):
index = index.tolist()
image = self.pixel_frame.iloc[index, 1:]
image = np.array([image])
if self.transform:
image = self.transform(image)
return image
它有效,直到我尝试对其使用转换:
tsf = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
trainset = KaggleMNIST('train/train.csv', transform=tsf)
image0 = trainset[0]
我查看了堆栈跟踪,看起来规范化正在这行代码中发生:
c:\program files\python38\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace)
--> 218 tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
所以我不明白为什么除以零std
应该是 0.5,远不是一个小值。
感谢您的帮助!
EDIT:
这并没有回答我的问题,但我发现如果我更改这些代码行:
image = self.pixel_frame.iloc[index, 1:]
image = np.array([image])
to
image = self.pixel_frame.iloc[index, 1:].to_numpy(dtype='float64').reshape(1, -1)
本质上,确保数据类型是float64
解决了问题。我仍然不确定为什么这个问题首先存在,所以我仍然很高兴得到一个解释清楚的答案!