我试图了解如何读取本地图像,将它们用作 TensorFlowDataset https://www.tensorflow.org/api_docs/python/tf/data/Dataset并使用 TF 数据集训练 Keras 模型。我正在关注 TF Keras MNIST TPUtutorial https://github.com/tensorflow/tpu/blob/master/tools/colab/keras_mnist_tpu.ipynb。唯一的区别是我想读取我的图像集并对其进行训练。
假设我有图像列表(文件名)和相应的标签列表。
files = [...] # list of file names
labels = [...] # list of labels (integers)
images = tf.constant(files) # or tf.convert_to_tensor(files)
labels = tf.constant(labels) # or tf.convert_to_tensor(labels)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(len(files))
dataset = dataset.repeat()
dataset = dataset.map(parse_function).batch(batch_size)
The parse_function
是一个简单的函数,它读取输入文件名并生成图像数据和相应的标签,例如
def parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image = tf.cast(image_decoded, tf.float32)
return image, label
此时我有一个dataset
这是一个 tf.data.Dataset 类型(更准确地说是 tf.data.BatchDataset),我将它传递给 keras 模型trained_model
from tutorial https://github.com/tensorflow/tpu/blob/master/tools/colab/keras_mnist_tpu.ipynb, e.g.
history = trained_model.fit(dataset, ...)
但此时代码因以下错误而中断:
AttributeError: 'BatchDataset' object has no attribute 'ndim'
该错误来自 keras,它对给定的输入执行检查
from keras import backend as K
K.is_tensor(dataset) # which returns false
Keras 尝试确定输入的类型,并且由于它不是张量,因此它假设它是 numpy 数组并尝试获取其维度。这就是错误发生的原因。
我的问题如下:
- 我正确读取 TF 数据集吗?我在互联网上查找了很多例子,看来我正在按照人们的建议阅读它
- 为什么我的数据集不是张量?可能我需要执行额外的转换,但TF不是这样的tutorial https://github.com/tensorflow/tpu/blob/master/tools/colab/keras_mnist_tpu.ipynb
- 为什么在TFtutorial https://github.com/tensorflow/tpu/blob/master/tools/colab/keras_mnist_tpu.ipynb一切都适用于 tf 数据集,我真的看不出他们读取 MNIST 数据的方式(数据格式不同,但最终他们得到图像)和我在这里所做的有任何区别。
任何建议将不胜感激。
请注意,即使是 TFtutorial https://github.com/tensorflow/tpu/blob/master/tools/colab/keras_mnist_tpu.ipynb是关于 TPU 的,它的结构使其可以在 TPU 和 CPU/GPU 上运行。