我正在将 TensorFlow 代码从旧的队列接口更改为新的数据集API https://www.tensorflow.org/api_docs/python/tf/data/Dataset。使用旧界面我可以指定num_threads
论证tf.train.shuffle_batch
队列。然而,控制 Dataset API 中线程数量的唯一方法似乎是在map
函数使用num_parallel_calls
争论。但是,我正在使用flat_map
函数代替,它没有这样的参数。
Question: 有没有办法控制线程/进程的数量flat_map
功能?或者有什么办法可以使用map
结合flat_map
并仍然指定并行调用的数量?
请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在 CPU 上运行大量预处理。
那里有两个 (here https://github.com/tensorflow/tensorflow/issues/7951#issuecomment-305796971 and here https://github.com/tensorflow/tensorflow/issues/7951#issuecomment-326098305)GitHub 上的相关帖子,但我认为他们没有回答这个问题。
这是我的用例的最小代码示例以供说明:
with tf.Graph().as_default():
data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)
def pre_processing_func(data_):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)
dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
# do something with 'dataset'
据我所知,目前flat_map
不提供并行选项。
鉴于大部分计算是在pre_processing_func
,您可以使用并行作为解决方法map
调用后进行一些缓冲,然后使用flat_map
使用负责平坦化输出的恒等 lambda 函数进行调用。
In code:
NUM_THREADS = 5
BUFFER_SIZE = 1000
def pre_processing_func(data_):
# data-augmentation here
# generate new samples starting from the sample `data_`
artificial_samples = generate_from_sample(data_)
return atificial_samples
dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
map(pre_processing_func, num_parallel_calls=NUM_THREADS).
prefetch(BUFFER_SIZE).
flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
shuffle(BUFFER_SIZE)) # my addition, probably necessary though
注意(对我自己和任何试图理解管道的人):
Since pre_processing_func
从初始样本开始生成任意数量的新样本(以形状矩阵组织)(?, 512)
), the flat_map
需要调用才能将所有生成的矩阵转换为Dataset
s 包含单个样本(因此tf.data.Dataset.from_tensor_slices(x)
在 lambda 中),然后将所有这些数据集扁平化为一个大数据集Dataset
包含单独的样本。
这可能是个好主意.shuffle()
该数据集或生成的样本将打包在一起。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)