在tensorflow 2.0-beta之前,要从tf.data.Dataset中检索第一个元素,我们可以使用迭代器,如下所示:
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
# 1.0 will be printed.
print (sess.run(iterator.get_next()))
在tensorflow 2.0-beta中,似乎上面的一次性迭代器现已弃用。要打印出整个元素,我们可以使用以下命令for方法。
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
for data in train_dataset:
# 1.0, 2.0, 3.0, and 4.0 will be printed.
print (data.numpy())
然而,如果我们只想从 tf.data.Dataset 中检索一个元素,那么我们该如何使用 TensorFlow 2.0 beta 呢?看起来next(train_dataset)
不支持。使用旧的一次性迭代器可以轻松完成,如上所示,但对于新的迭代器来说不是很明显for基于的方法。
欢迎任何建议。
启用即时执行(TF 2.0 中默认)的情况是:
elem = next(iter(train_dataset))
说明:数据集有一个__iter__
会员功能支持for elem in dataset:
方法。这将返回一个迭代器。 Python 函数iter
就是这样做的:基本上调用__iter__
功能。next
然后返回迭代器产生的第一个元素。
我还没有找到适用于非急切执行的解决方案,因为目前会引发RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)