为了后代,我提出这个问题的最终解决方案。下面的代码是一个复制/粘贴示例,可在该问题所解决的最复杂的条件下工作(请注意,其他两个答案不是可复制/粘贴的代码示例):
代码的目标是:
- 获取(大)文件列表并将其分成块(文件名/索引对)
- 使用映射操作处理每个块(生成器在这里不是一个可行的解决方案,请参阅:https://github.com/tensorflow/tensorflow/issues/16343)
- 从仅采用 1 个文件/块作为输入的映射操作输出多个样本。
- 在整个过程中维护元素命名
复制/粘贴 Tensorflow 1.5 / Python 3.x 的工作示例
import tensorflow as tf
import numpy as np
files = [b'testA', b'testB', b'testC']
def mymap1(x):
result_tensors = tf.py_func(func=mymap2, inp=[x], Tout=[tf.string, tf.int64])
return {'filename': result_tensors[0], 'value': result_tensors[1]}
def mymap2(x):
return np.array([x, x, x]), np.array([10, 20, 30])
def myflatmap(named_elements):
return tf.data.Dataset.zip({
'filename': tf.data.Dataset.from_tensor_slices(named_elements['filename']),
'value': tf.data.Dataset.from_tensor_slices(named_elements['value'])
})
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.map(map_func=mymap1)
ds = ds.flat_map(map_func=myflatmap)
element = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for _ in range(9):
print(sess.run(element))
Output:
{'filename': b'testA', 'value': 10}
{'filename': b'testA', 'value': 20}
{'filename': b'testA', 'value': 30}
{'filename': b'testB', 'value': 10}
{'filename': b'testB', 'value': 20}
{'filename': b'testB', 'value': 30}
{'filename': b'testC', 'value': 10}
{'filename': b'testC', 'value': 20}
{'filename': b'testC', 'value': 30}