对数据集进行批处理后,最后一批的形状可能与其余批次的形状不同。例如,如果数据集中共有 100 个元素,并且批处理的大小为 6,则最后一批的大小仅为 4。(100 = 6 * 16 + 4)。
因此,在这种情况下,您将无法直接将数据集转换为 numpy。因此,您将不得不使用drop_remainder https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#batch批处理方法中的参数为 True。如果最后一批尺寸不正确,它将丢弃它。
之后,我附上了有关如何将数据集转换为 Numpy 的代码。
import tensorflow as tf
import numpy as np
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
TRAIN_BUF=1000
BATCH_SIZE=64
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).
shuffle(TRAIN_BUF).batch(BATCH_SIZE, drop_remainder=True)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).
shuffle(TRAIN_BUF).batch(BATCH_SIZE, drop_remainder=True)
# print(train_dataset, type(train_dataset), test_dataset, type(test_dataset))
train_np = np.stack(list(train_dataset))
test_np = np.stack(list(test_dataset))
print(type(train_np), train_np.shape)
print(type(test_np), test_np.shape)
Output:
<class 'numpy.ndarray'> (937, 64, 28, 28)
<class 'numpy.ndarray'> (156, 64, 28, 28)