【PyTorch】torch.utils.data.Dataset 介绍与实战

2023-11-05


一、前言

训练模型一般都是先处理 数据的输入问题预处理问题 。Pytorch提供了几个有用的工具:torch.utils.data.Dataset 类和 torch.utils.data.DataLoader 类 。

流程是先把原始数据转变成 torch.utils.data.Dataset 类,随后再把得到的 torch.utils.data.Dataset 类当作一个参数传递给 torch.utils.data.DataLoader 类,得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用。

在 pytorch 中,提供了一种十分方便的数据读取机制,即使用 torch.utils.data.DatasetDataloader 组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个 batch 数据,并能在输出时对数据进行相应的预处理或数据增广操作。

本文我们主要介绍对 torch.utils.data.Dataset 的理解,对 Dataloader 的介绍请参考我的另一篇文章:【PyTorch】torch.utils.data.DataLoader 简单介绍与使用

在本文的最后将给出 torch.utils.data.DatasetDataloader 结合使用处理数据的实战代码。


二、torch.utils.data.Dataset 是什么

1. 干什么用的?

  1. pytorch 提供了一个数据读取的方法,其由两个类构成:torch.utils.data.Dataset 和 DataLoader。
  2. 如果我们要自定义自己读取数据的方法,就需要继承类 torch.utils.data.Dataset ,并将其封装到DataLoader 中。
  3. torch.utils.data.Dataset 是一个 Dataset 。通过重写定义在该类上的方法,我们可以实现多种数据读取及数据预处理方式。

2. 长什么样子?

torch.utils.data.Dataset 的源码:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

注释翻译:

表示一个数据集的抽象类。

所有其他数据集都应该对其进行子类化。 所有子类都应该重写提供数据集大小的 __len____getitem__ ,支持从 0 到 len(self) 独占的整数索引。

理解:

就是说,Dataset 是一个 数据集 抽象类,它是其他所有数据集类的父类(所有其他数据集类都应该继承它),继承时需要重写方法 __len____getitem____len__ 是提供数据集大小的方法, __getitem__ 是可以通过索引号找到数据的方法。


三、通过继承 torch.utils.data.Dataset 定义自己的数据集类

torch.utils.data.Dataset 是代表自定义数据集的抽象类,我们可以定义自己的数据类抽象这个类,只需要重写__len__和__getitem__这两个方法就可以。

要自定义自己的 Dataset 类,至少要重载两个方法:__len__, __getitem__

  1. __len__返回的是数据集的大小
  2. __getitem__实现索引数据集中的某一个数据

下面将简单实现一个返回 torch.Tensor 类型的数据集:

from torch.utils.data import Dataset
import torch

class TensorDataset(Dataset):
    # TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__
    # 实现将一组Tensor数据对封装成Tensor数据集
    # 能够通过index得到数据集的数据,能够通过len,得到数据集大小

    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)    # size(0) 返回当前张量维数的第一维

# 生成数据
data_tensor = torch.randn(4, 3)   # 4 行 3 列,服从正态分布的张量
print(data_tensor)
target_tensor = torch.rand(4)     # 4 个元素,服从均匀分布的张量
print(target_tensor)

# 将数据封装成 Dataset (用 TensorDataset 类)
tensor_dataset = TensorDataset(data_tensor, target_tensor)

# 可使用索引调用数据
print('tensor_data[0]: ', tensor_dataset[0])

# 可返回数据len
print('len os tensor_dataset: ', len(tensor_dataset))

输出结果:

tensor([[ 0.8618,  0.4644, -0.5929],
        [ 0.9566, -0.9067,  1.5781],
        [ 0.3943, -0.7775,  2.0366],
        [-1.2570, -0.3859, -0.3542]])
tensor([0.1363, 0.6545, 0.4345, 0.9928])
tensor_data[0]:  (tensor([ 0.8618,  0.4644, -0.5929]), tensor(0.1363))
len os tensor_dataset:  4

四、为什么要定义自己的数据集类?

因为我们可以通过定义自己的数据集类并重写该类上的方法 实现多种多样的(自定义的)数据读取方式

比如,我们重写 __init__ 实现用 pd.read_csv 读取 csv 文件:

from torch.utils.data import Dataset
import pandas as pd  # 这个包用来读取CSV数据

# 继承Dataset,定义自己的数据集类 mydataset
class mydataset(Dataset):
    def __init__(self, csv_file):   # self 参数必须,其他参数及其形式随程序需要而不同,比如(self,*inputs)
        self.csv_data = pd.read_csv(csv_file)
    def __len__(self):
        return len(self.csv_data)
    def __getitem__(self, idx):
        data = self.csv_data.values[idx]
        return data

data = mydataset('spambase.csv')
print(data[3])
print(len(data))

输出结果:

[0.000e+00 0.000e+00 0.000e+00 0.000e+00 6.300e-01 0.000e+00 3.100e-01
 6.300e-01 3.100e-01 6.300e-01 3.100e-01 3.100e-01 3.100e-01 0.000e+00
 0.000e+00 3.100e-01 0.000e+00 0.000e+00 3.180e+00 0.000e+00 3.100e-01
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 1.370e-01 0.000e+00 1.370e-01 0.000e+00 0.000e+00 3.537e+00 4.000e+01
 1.910e+02 1.000e+00]
4601

要点:

  1. 自己定义的 dataset 类需要继承 Dataset。
  2. 需要实现必要的魔法方法:
    __init__ 方法里面进行 读取数据文件
    __getitem__ 方法里支持通过下标访问数据。
    __len__ 方法里返回自定义数据集的大小,方便后期遍历。

五、实战:torch.utils.data.Dataset + Dataloader 实现数据集读取和迭代

实例 1

数据集 spambase.csv 用的是 UCI 机器学习存储库里的垃圾邮件数据集,它一条数据有57个特征和1个标签。

import torch.utils.data as Data
import pandas as pd  # 这个包用来读取CSV数据
import torch


# 继承Dataset,定义自己的数据集类 mydataset
class mydataset(Data.Dataset):
    def __init__(self, csv_file):   # self 参数必须,其他参数及其形式随程序需要而不同,比如(self,*inputs)
        data_csv = pd.DataFrame(pd.read_csv(csv_file))   # 读数据
        self.csv_data = data_csv.drop(axis=1, columns='58', inplace=False)  # 删除最后一列标签
    def __len__(self):
        return len(self.csv_data)
    def __getitem__(self, idx):
        data = self.csv_data.values[idx]
        return data


data = mydataset('spambase.csv')
x = torch.tensor(data[:5])         # 前五个数据
y = torch.tensor([1, 1, 1, 1, 1])  # 标签


torch_dataset = Data.TensorDataset(x, y)  # 对给定的 tensor 数据,将他们包装成 dataset

loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset = torch_dataset,       # torch TensorDataset format
    batch_size = 2,                # mini batch size
    shuffle=True,                  # 要不要打乱数据 (打乱比较好)
    num_workers=2,                 # 多线程来读数据
)

def show_batch():
    for step, (batch_x, batch_y) in enumerate(loader):
        print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

show_batch()

输出结果:

steop:0, batch_x:tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.3000e-01, 0.0000e+00,
         3.1000e-01, 6.3000e-01, 3.1000e-01, 6.3000e-01, 3.1000e-01, 3.1000e-01,
         3.1000e-01, 0.0000e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00,
         3.1800e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.3500e-01, 0.0000e+00, 1.3500e-01, 0.0000e+00, 0.0000e+00,
         3.5370e+00, 4.0000e+01, 1.9100e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.3000e-01, 0.0000e+00,
         3.1000e-01, 6.3000e-01, 3.1000e-01, 6.3000e-01, 3.1000e-01, 3.1000e-01,
         3.1000e-01, 0.0000e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00,
         3.1800e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.3700e-01, 0.0000e+00, 1.3700e-01, 0.0000e+00, 0.0000e+00,
         3.5370e+00, 4.0000e+01, 1.9100e+02]], dtype=torch.float64), batch_y:tensor([1, 1])
steop:1, batch_x:tensor([[2.1000e-01, 2.8000e-01, 5.0000e-01, 0.0000e+00, 1.4000e-01, 2.8000e-01,
         2.1000e-01, 7.0000e-02, 0.0000e+00, 9.4000e-01, 2.1000e-01, 7.9000e-01,
         6.5000e-01, 2.1000e-01, 1.4000e-01, 1.4000e-01, 7.0000e-02, 2.8000e-01,
         3.4700e+00, 0.0000e+00, 1.5900e+00, 0.0000e+00, 4.3000e-01, 4.3000e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         7.0000e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.3200e-01, 0.0000e+00, 3.7200e-01, 1.8000e-01, 4.8000e-02,
         5.1140e+00, 1.0100e+02, 1.0280e+03],
        [6.0000e-02, 0.0000e+00, 7.1000e-01, 0.0000e+00, 1.2300e+00, 1.9000e-01,
         1.9000e-01, 1.2000e-01, 6.4000e-01, 2.5000e-01, 3.8000e-01, 4.5000e-01,
         1.2000e-01, 0.0000e+00, 1.7500e+00, 6.0000e-02, 6.0000e-02, 1.0300e+00,
         1.3600e+00, 3.2000e-01, 5.1000e-01, 0.0000e+00, 1.1600e+00, 6.0000e-02,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0000e-02, 0.0000e+00, 0.0000e+00,
         1.2000e-01, 0.0000e+00, 6.0000e-02, 6.0000e-02, 0.0000e+00, 0.0000e+00,
         1.0000e-02, 1.4300e-01, 0.0000e+00, 2.7600e-01, 1.8400e-01, 1.0000e-02,
         9.8210e+00, 4.8500e+02, 2.2590e+03]], dtype=torch.float64), batch_y:tensor([1, 1])
steop:2, batch_x:tensor([[  0.0000,   0.6400,   0.6400,   0.0000,   0.3200,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.6400,   0.0000,   0.0000,
           0.0000,   0.3200,   0.0000,   1.2900,   1.9300,   0.0000,   0.9600,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.7780,   0.0000,   0.0000,   3.7560,  61.0000,
         278.0000]], dtype=torch.float64), batch_y:tensor([1])

一共 5 条数据,batch_size 设为 2 ,则数据被分为三组,每组的数据量为:2,2,1。

实例 2:进阶

import torch.utils.data as Data
import pandas as pd  # 这个包用来读取CSV数据
import numpy as np

# 继承Dataset,定义自己的数据集类 mydataset
class mydataset(Data.Dataset):
    def __init__(self, csv_file):   # self 参数必须,其他参数及其形式随程序需要而不同,比如(self,*inputs)
        # 读取数据
        frame = pd.DataFrame(pd.read_csv('spambase.csv'))
        spam = frame[frame['58'] == 1]
        ham = frame[frame['58'] == 0]
        SpamNew = spam.drop(axis=1, columns='58', inplace=False)  # 删除第58列,inplace=False不改变原数据,返回一个新dataframe
        HamNew = ham.drop(axis=1, columns='58', inplace=False)
        # 数据
        self.csv_data = np.vstack([np.array(SpamNew), np.array(HamNew)])  # 将两个N维数组进行连接,形成X
        # 标签
        self.Label = np.array([1] * len(spam) + [0] * len(ham))  # 形成标签值列表y
    def __len__(self):
        return len(self.csv_data)
    def __getitem__(self, idx):
        data = self.csv_data[idx]
        label = self.Label[idx]
        return data, label


data = mydataset('spambase.csv')
print(len(data))

loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset = data,       # torch TensorDataset format
    batch_size = 460,                # mini batch size
    shuffle=True,                  # 要不要打乱数据 (打乱比较好)
    num_workers=2,                 # 多线程来读数据
)

def show_batch():
    for step, (batch_x, batch_y) in enumerate(loader):
        print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

show_batch()

输出结果:

4601
steop:0, batch_x:tensor([[0.0000e+00, 2.4600e+00, 0.0000e+00,  ..., 2.1420e+00, 1.0000e+01,
         7.5000e+01],
        [0.0000e+00, 0.0000e+00, 1.6000e+00,  ..., 2.0650e+00, 1.2000e+01,
         9.5000e+01],
        [0.0000e+00, 0.0000e+00, 3.6000e-01,  ..., 3.7220e+00, 2.0000e+01,
         2.6800e+02],
        ...,
        [7.7000e-01, 3.8000e-01, 7.7000e-01,  ..., 1.4619e+01, 5.2500e+02,
         9.2100e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         5.0000e+00],
        [4.0000e-01, 1.8000e-01, 3.2000e-01,  ..., 3.3050e+00, 1.8100e+02,
         1.6130e+03]], dtype=torch.float64), batch_y:tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1,
        0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0,
        1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0,
        1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1,
        1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0,
        0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1,
        0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        0, 1, 0, 1])
steop:1, batch_x:tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         2.0000e+00],
        [4.9000e-01, 0.0000e+00, 7.4000e-01,  ..., 3.9750e+00, 4.7000e+01,
         4.8500e+02],
        [0.0000e+00, 0.0000e+00, 7.1000e-01,  ..., 4.0220e+00, 9.7000e+01,
         5.4300e+02],
        ...,
        [0.0000e+00, 1.4000e-01, 1.4000e-01,  ..., 5.3310e+00, 8.0000e+01,
         1.0290e+03],
        [0.0000e+00, 0.0000e+00, 3.6000e-01,  ..., 3.1760e+00, 5.1000e+01,
         2.7000e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.1660e+00, 2.0000e+00,
         7.0000e+00]], dtype=torch.float64), batch_y:tensor([0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0,
        0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,
        0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
        1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0,
        0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0,
        1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,
        1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1,
        1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1,
        1, 0, 0, 0])
steop:2, batch_x:tensor([[0.0000e+00, 0.0000e+00, 1.4700e+00,  ..., 3.0000e+00, 3.3000e+01,
         1.7700e+02],
        [2.6000e-01, 4.6000e-01, 9.9000e-01,  ..., 1.3235e+01, 2.7200e+02,
         1.5750e+03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.0450e+00, 6.0000e+00,
         4.5000e+01],
        ...,
        [4.0000e-01, 0.0000e+00, 0.0000e+00,  ..., 1.1940e+00, 5.0000e+00,
         1.2900e+02],
        [2.6000e-01, 0.0000e+00, 0.0000e+00,  ..., 1.8370e+00, 1.1000e+01,
         1.5800e+02],
        [5.0000e-02, 0.0000e+00, 1.0000e-01,  ..., 3.7150e+00, 1.0700e+02,
         1.3860e+03]], dtype=torch.float64), batch_y:tensor([1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0,
        1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0,
        0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0,
        0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1,
        0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
        1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
        0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0,
        1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
        0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1,
        1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0,
        1, 1, 0, 0])
steop:3, batch_x:tensor([[2.6000e-01, 0.0000e+00, 5.3000e-01,  ..., 2.6460e+00, 7.7000e+01,
         1.7200e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.4280e+00, 5.0000e+00,
         1.7000e+01],
        [3.4000e-01, 0.0000e+00, 1.7000e+00,  ..., 6.6700e+02, 1.3330e+03,
         1.3340e+03],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         7.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.7010e+00, 2.0000e+01,
         1.8100e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.0000e+00, 1.1000e+01,
         3.6000e+01]], dtype=torch.float64), batch_y:tensor([0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,
        0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
        1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0,
        1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0,
        0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,
        0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1,
        0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1,
        0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1,
        1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0,
        1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
        1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
        1, 0, 0, 1])
steop:4, batch_x:tensor([[  0.0000,   0.0000,   0.3100,  ...,   5.7080, 138.0000, 274.0000],
        [  0.0000,   0.0000,   0.3400,  ...,   2.2570,  17.0000, 158.0000],
        [  1.0400,   0.0000,   0.0000,  ...,   1.0000,   1.0000,  17.0000],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   4.0000,  12.0000,  28.0000],
        [  0.3300,   0.0000,   0.0000,  ...,   1.7880,   6.0000,  93.0000],
        [  0.0000,  14.2800,   0.0000,  ...,   1.8000,   5.0000,   9.0000]],
       dtype=torch.float64), batch_y:tensor([1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1,
        0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0,
        1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0,
        1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1,
        1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0,
        1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1,
        0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1,
        1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0,
        1, 1, 0, 0])
steop:5, batch_x:tensor([[7.0000e-01, 0.0000e+00, 1.0500e+00,  ..., 1.1660e+00, 1.3000e+01,
         1.8900e+02],
        [0.0000e+00, 3.3600e+00, 1.9200e+00,  ..., 6.1370e+00, 1.0700e+02,
         1.7800e+02],
        [5.4000e-01, 0.0000e+00, 1.0800e+00,  ..., 5.4540e+00, 6.8000e+01,
         1.8000e+02],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.8330e+00, 9.0000e+00,
         2.3000e+01],
        [6.0000e-02, 6.5000e-01, 7.1000e-01,  ..., 4.7420e+00, 1.1700e+02,
         1.3420e+03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.6110e+00, 1.2000e+01,
         4.7000e+01]], dtype=torch.float64), batch_y:tensor([1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1,
        1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
        0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,
        0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1,
        0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0,
        0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1,
        1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1,
        0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1,
        1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1,
        0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0,
        0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1,
        0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
        1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 1, 1, 1])
steop:6, batch_x:tensor([[0.0000e+00, 1.4280e+01, 0.0000e+00,  ..., 1.8000e+00, 5.0000e+00,
         9.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.9280e+00, 1.5000e+01,
         5.4000e+01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0692e+01, 6.5000e+01,
         1.3900e+02],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.5000e+00, 5.0000e+00,
         2.4000e+01],
        [7.6000e-01, 1.9000e-01, 3.8000e-01,  ..., 3.7020e+00, 4.5000e+01,
         1.0700e+03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.0000e+00, 1.2000e+01,
         8.8000e+01]], dtype=torch.float64), batch_y:tensor([0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1,
        0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
        1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0,
        0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1,
        0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1,
        1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 1, 0])
steop:7, batch_x:tensor([[0.0000e+00, 2.7000e-01, 0.0000e+00,  ..., 5.8020e+00, 4.3000e+01,
         4.1200e+02],
        [0.0000e+00, 3.5000e-01, 7.0000e-01,  ..., 3.6390e+00, 6.1000e+01,
         3.1300e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.5920e+00, 7.0000e+00,
         1.2900e+02],
        ...,
        [8.0000e-02, 1.6000e-01, 8.0000e-02,  ..., 2.7470e+00, 8.6000e+01,
         1.9950e+03],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.6130e+00, 1.1000e+01,
         7.1000e+01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.9110e+00, 1.5000e+01,
         6.5000e+01]], dtype=torch.float64), batch_y:tensor([0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0,
        0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
        1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1,
        0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0,
        0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,
        0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0,
        1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1,
        0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1,
        0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1,
        1, 0, 0, 0])
steop:8, batch_x:tensor([[1.7000e-01, 0.0000e+00, 1.7000e-01,  ..., 1.7960e+00, 1.2000e+01,
         4.5800e+02],
        [3.7000e-01, 0.0000e+00, 6.3000e-01,  ..., 1.1810e+00, 4.0000e+00,
         1.0400e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         7.0000e+00],
        ...,
        [2.3000e-01, 0.0000e+00, 4.7000e-01,  ..., 2.4200e+00, 1.2000e+01,
         3.3400e+02],
        [0.0000e+00, 0.0000e+00, 1.2900e+00,  ..., 1.3500e+00, 4.0000e+00,
         2.7000e+01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.3730e+00, 1.1000e+01,
         1.6900e+02]], dtype=torch.float64), batch_y:tensor([1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1,
        0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0,
        1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
        0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1,
        0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1,
        1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0,
        1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1,
        0, 0, 0, 0])
steop:9, batch_x:tensor([[0.0000e+00, 6.3000e-01, 0.0000e+00,  ..., 2.2150e+00, 2.2000e+01,
         1.1300e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 1.0000e+00,
         5.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.0000e-01,  ..., 1.1870e+00, 1.1000e+01,
         1.1400e+02],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.3070e+00, 1.6000e+01,
         3.0000e+01],
        [5.1000e-01, 4.3000e-01, 2.9000e-01,  ..., 6.5900e+00, 7.3900e+02,
         2.3330e+03],
        [6.8000e-01, 6.8000e-01, 6.8000e-01,  ..., 2.4720e+00, 9.0000e+00,
         8.9000e+01]], dtype=torch.float64), batch_y:tensor([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1,
        1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0,
        0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,
        0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1,
        0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1,
        1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0,
        1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,
        1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
        1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 1])
steop:10, batch_x:tensor([[0.0000e+00, 2.5000e-01, 7.5000e-01, 0.0000e+00, 1.0000e+00, 2.5000e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5000e-01, 2.5000e-01,
         1.2500e+00, 0.0000e+00, 0.0000e+00, 2.5000e-01, 0.0000e+00, 1.2500e+00,
         2.5100e+00, 0.0000e+00, 1.7500e+00, 0.0000e+00, 2.5000e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 2.5000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 4.2000e-02, 0.0000e+00, 0.0000e+00,
         1.2040e+00, 7.0000e+00, 1.1800e+02]], dtype=torch.float64), batch_y:tensor([0])

一共 4601 条数据,按 batch_size = 460 来分:能划分为 11 组,前 10 组的数据量为 460,最后一组的数据量为 1 。


参考链接

  1. torch.Tensor.size()方法的使用举例
  2. Pytorch笔记05-自定义数据读取方式orch.utils.data.Dataset与Dataloader
  3. pytorch 可训练数据集创建(torch.utils.data)
  4. Pytorch的第一步:(1) Dataset类的使用
  5. pytorch中的torch.utils.data.Dataset和torch.utils.data.DataLoader
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【PyTorch】torch.utils.data.Dataset 介绍与实战 的相关文章

随机推荐

  • drools规则引擎的在项目中的使用手记

    需求 按照登录用户的会员等级 和签到周期 根据一定的计算规则送积分 由于之前都是通过if else去做的控制 规则变更的时候可能需要重新调整代码甚至发布服务 由于不想再每次规则变更后需要调整代码 于是最近在确认方案 于是最好找到了规则引擎
  • python3 条件语句

    条件语句 if 判断条件 执行语句 else 执行语句 if 判断条件1 执行语句1 elif 判断条件2 执行语句2 elif 判断条件3 执行语句3 else 执行语句4 python 并不支持 switch 语句 如果多个条件需同时判
  • uniapp打包app,对接华为厂商,实现unipush离线消息推送

    今天终于可以抽出点时间 来记录一下这几天心塞的心情 上周公司派过来一个活 说是使用uniapp制作一个app 同时要实现在线消息推送和离线消息推送 啥话没说就揽了下来 不过说实在的 从来没有开发过app 好歹会点vue 可想而知 接下来的几
  • arduino基础25个实验代码

    arduino基础25个实验代码 双色LED灯项目源码 int redPin 11 红色LED引脚 int greenPin 10 绿色LED引脚 int val 0 PWM输出值 void setup pinMode redPin OUT
  • 华为打造狼性团队的22条军规

    打造狼性团队的22条军规 领导者要读3遍 打出来 贴到桌子上 做老板的 无不对华为公司的狼性团队推崇有加 华为的狼性文化之所以如此成功 三大因素缺一不可 一是具有诱惑力的薪酬 这是自驱力 二是内部竞争机制 这是推动力 三是执行力文化 这是牵
  • 类的静态成员变量为什么不能再h文件类外初始化

    h文件 class Image public static void AddProtoType Image iamge Prototype nsize iamge private static Image Prototype 10 stat
  • pandas练习题

    按要求创建Dataframe df 并通过分组得到以下结果 以A分组 求出C D的分组平均值 以A B分组 求出D E的分组求和 以A分组 得到所有分组 以字典显示 按照数值类型分组 求和 将C D作为一组分出来 并计算求和 以B分组 求出
  • git撤销一次代码提交方法

    以下方法亲测有效 但是根据需求选择哦 友情提示 注意备份 方法一 1 删除上一次提交 或者撤销上一次合并 reset方式是将HEAD指针指到指定提交 历史记录则不会出现你删除的上步commit记录等 合并时间线等都会删除彻底 并删除 mer
  • 分享是个好习惯

    无止境的求索 把脚印记下来 累了 迷茫了 回头望望 记住来时的路 收拾收拾行囊 云淡风轻
  • C++字符指针的特殊

    如果我们对一个非字符的指针进行操作 方法是这样的 注意 int p 则p i 等价于 p i 定义 1 int a 7 int p a 或者 2 int a 7 int p p a 或者 3 int a 7 int p p a 1 这样定义
  • matlab将一个数组中的元素转换为整型_科学计算

    最近期末考试结束了 自己立下了一个flag 自学MATLAB 写这篇文章的目的就在于将自己所学的知识输出 希望能够帮到你 大家一起相互学习吧 话不多说 下面直接进入主题 01matlab系统环境 1 matlab操作界面的组成 matlab
  • Go语言 之 变量声明

    声明变量 能声明的数据类型 https www runoob com go go data types html 变量可以声明为全局变量和局部变量 在函数外声明为全局变量 函数内声明为局部变量 声明变量方式为 var 变量名 数据类型 pa
  • 这个 Chrome 插件,让你的 ChatGPT 不再报错

    ChatGPT的官网最近几天报错越来越频繁了 相信大家都发现了 一旦你离开页面时间比较久 再度返回跟它进行对话 就会出现如下报错 虽然这个报错信息以前也出现过 但现在的频率确实过高 对于每天需要使用 ChatGPT 处理大量任务的用户来说
  • [ kvm ] 进程的处理器亲和性和vCPU的绑定

    cpu调用进程或线程的方式 Linux内核的进程调度器根据自有的调度策略将系统中的一个进程调度到某个CPU上执行 一个进程在前一个执行时间是在cpuM上运行 而在后一个执行时间则是在cpuN上运行 这样的情况在cpu中是很可能发生的 因为l
  • 有关CSS3 3D盒子模型的一些总结

    以前就想学CSS3动画 觉得挺高级的 但后来因为一些原因 没能理解好 也没有时间 最近重新学了一波 为了帮助那些像我一样理解能力不太好的人 同时也使自己更好的理解知识点 这里做一下总结 主要是我在学习过程中遇到的一些问题 如果有写的不清晰的
  • 安防天下5、6——视频编码器技术DVS、网络录像机(NVR)技术

    视频编码器技术DVS DVS Digital vedio server 的出现 标志着视频监控系统进入了网络时代 编码器的主要功能是编码压缩及网络传输 适合应用再监控点比较分散的应用环境中 但从本质上讲 DVS还不是纯粹的网络监控设备 因为
  • Linux常用命令大全(详细版)

    目录 1 Linux管理文件和目录的命令 2 有关磁盘空间的命令 3 文件备份和压缩命令 4 有关关机和查看系统信息的命令 5 管理使用者和设立权限的命令 6 线上查询的命令 7 文件阅读的命令 8 网络操作命令 9 其他命令 详细版本 1
  • Spring常用注解(绝对经典),全靠这份Java知识点PDF大全

    Bean public ColorFactoryBean colorFactoryBean return new ColorFactoryBean 创建一个spring定义的FactoryBean public class ColorFac
  • 原生js实现导航栏拖拽滑动(适用于pc端和手机端)

    先贴一张动图看看效果吧 下面把代码贴上注释都在代码边上
  • 【PyTorch】torch.utils.data.Dataset 介绍与实战

    文章目录 一 前言 二 torch utils data Dataset 是什么 1 干什么用的 2 长什么样子 三 通过继承 torch utils data Dataset 定义自己的数据集类 四 为什么要定义自己的数据集类 五 实战