Dataset
的用法可以参考:pytorch 构造读取数据的工具类 Dataset 与 DataLoader (pytorch Data学习一)
DataLoader
的封装方法可以参考:Pytorch DataLoader一次性封装多种数据集(pytorch Data学习六)
这里博主提供的是一个工具,整个封装流程是:
- 构造
Dataset
用以定义数据集x与y的模板
- 使用
sklearn.datasets.make_regression
生成回归任务的数据
- 使用
pytorch
的Tensor
格式封装产生的数据
- 将
Tensor
格式数据封装入Dataset
- 将
Dataset
封装入DataLoader
示例代码
from torch.utils.data import Dataset
import torch
class DatasetXY(Dataset):
    """Map-style dataset pairing a feature container with a target container.

    Indexing returns the tuple ``(x[item], y[item])``; the dataset length is
    the length of ``x``.

    Args:
        x: Indexable, sized container of features (e.g. a tensor or list).
        y: Indexable, sized container of targets with the same length as ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y):
        # Fail fast: with mismatched lengths the error would otherwise only
        # show up as an IndexError in the middle of iteration (y shorter), or
        # the extra targets would be silently ignored (y longer).
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self._x = x
        self._y = y
        self._len = len(x)

    def __getitem__(self, item):
        """Return the ``(feature, target)`` pair stored at position ``item``."""
        # This is the value yielded for each sample when the dataset is iterated.
        return self._x[item], self._y[item]

    def __len__(self):
        """Return the number of samples, fixed at construction time."""
        return self._len
def load_data(samples=1000, n_features=10, split_train_size: float = 0.3,
              batch_size: int = 10, random_state=None):
    """Generate synthetic regression data and wrap it in train/test DataLoaders.

    The data comes from ``sklearn.datasets.make_regression`` and is split with
    ``train_test_split``. Both parts are converted to float32 tensors and
    wrapped in ``DatasetXY`` before being handed to a ``DataLoader``.

    Args:
        samples: Total number of samples to generate.
        n_features: Number of features per sample.
        split_train_size: Passed to ``train_test_split`` as ``train_size``.
            With the default of 0.3, 30% of the samples go to the training
            set and the rest to the test set.
        batch_size: Batch size used by both loaders (default 10, the value
            that used to be hard-coded).
        random_state: Seed forwarded to ``make_regression``. The default
            ``None`` keeps the previous behaviour of generating different
            data on every call; pass an int for reproducible data.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Both loaders iterate in
        order (``shuffle=False``) and drop the final incomplete batch
        (``drop_last=True``).
    """
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from torch.utils.data import DataLoader

    # Generate the regression data with sklearn.
    data_x, data_y = make_regression(n_samples=samples, n_features=n_features,
                                     random_state=random_state)
    # The split itself is always seeded, so a given dataset is always
    # partitioned the same way.
    x_train, x_test, y_train, y_test = train_test_split(
        data_x, data_y, train_size=split_train_size, random_state=0)

    def _to_loader(features, targets):
        # One explicit float32 conversion replaces torch.Tensor(...).float().
        dataset = DatasetXY(torch.as_tensor(features, dtype=torch.float32),
                            torch.as_tensor(targets, dtype=torch.float32))
        return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                          drop_last=True, num_workers=0)

    # Wrap both splits as DataLoaders with identical settings.
    return _to_loader(x_train, y_train), _to_loader(x_test, y_test)
def main():
    """Print the first batch of the training loader and of the test loader.

    Serves as a smoke test / usage example for ``load_data``.
    """
    train_loader, test_loader = load_data()

    # Only the first batch is shown, hence the immediate break.
    for train_x, train_y in train_loader:
        print("打印训练数据:")
        print("train_x:", train_x)
        print('train_y:', train_y)
        break

    for test_x, test_y in test_loader:
        # Fixed label: this loop prints test data, but the header previously
        # said "training data" (copy-paste from the loop above).
        print("打印测试数据:")
        print("test_x:", test_x)
        print('test_y:', test_y)
        break
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()