我正在使用新的tf.data
API为 CIFAR10 数据集创建迭代器。我正在读取两个数据.tfrecord文件。一个保存训练数据 (train.tfrecords),另一个保存测试数据 (test.tfrecords)。这一切都很好。然而,在某些时候,我需要两个数据集(训练数据和测试数据)作为numpy 数组.
是否可以从 numpy 数组中检索数据集tf.data.TFRecordDataset
目的?
您可以使用tf.data.Dataset.batch() https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch转变和tf.contrib.data.get_single_element() https://www.tensorflow.org/api_docs/python/tf/contrib/data/get_single_element去做这个。
作为复习,dataset.batch(n)
将需要长达n
的连续元素dataset
并通过连接每个组件将它们转换为一个元素。这要求所有元素的每个组件都具有固定的形状。如果n
大于中的元素数量dataset
(or if n
没有精确地划分元素数量),那么最后一批可以更小。因此,您可以选择较大的值n
并执行以下操作:
import numpy as np
import tensorflow as tf
# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...) # A dataset of tf.string records.
dataset = dataset.map(...) # Extract components from each tf.string record.
# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)
# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)
# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
whole_dataset_arrays = sess.run(whole_dataset_tensors)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)