我正在尝试将 TensorFlow 中的数据集转换为具有多个单值张量。数据集目前如下所示:
[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...
转换后应该是这样的:
[12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ...
我最初的想法是使用flat_map
在数据集上,然后使用将每个张量转换为张量列表reshape
and unstack
:
output_labels = self.dataset.flat_map(convert_labels)
...
def convert_labels(tensor):
id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
return tf.data.Dataset.from_tensors(id_list)
然而,每个张量的形状仅部分已知(即。(?, 1)
) 这就是 unstack 操作失败的原因。有没有办法仍然“连接”不同的张量而不显式地迭代它们?
你的解决方案非常接近,但是Dataset.flat_map()接受一个返回 a 的函数tf.data.Dataset
对象,而不是张量列表。幸运的是,Dataset.from_tensor_slices()方法完全适合您的用例,因为它可以将张量拆分为可变数量的元素:
output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)
请注意,tf.contrib.data.unbatch()转换实现了相同的功能,并且在 TensorFlow 当前的 master 分支中具有稍微更高效的实现(将包含在 1.9 版本中):
output_labels = self.dataset.apply(tf.contrib.data.unbatch())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)