假设您使用的是最新的 Tensorflow(撰写本文时为 1.4),您可以保留生成器并使用tf.data.* https://www.tensorflow.org/api_docs/python/tf/dataAPI如下(我为线程数、预取缓冲区大小、批量大小和输出数据类型选择任意值):
NUM_THREADS = 5
sceneGen = SceneGenerator()
dataset = tf.data.Dataset.from_generator(sceneGen.generate_data, output_types=(tf.float32, tf.int32))
dataset = dataset.map(lambda x,y : (x,y), num_parallel_calls=NUM_THREADS).prefetch(buffer_size=1000)
dataset = dataset.batch(42)
X, y = dataset.make_one_shot_iterator().get_next()
为了表明它实际上是从生成器中提取的多个线程,我修改了您的类,如下所示:
import threading
class SceneGenerator(object):
def __init__(self):
# some inits
pass
def generate_data(self):
"""
Generator. Yield data X and labels y after some preprocessing
"""
while True:
# opening files, selecting data
X,y = threading.get_ident(), 2 #self.preprocess(some_params, filenames, ...)
yield X, y
这样,创建一个 Tensorflow 会话并获取一批即可显示获取数据的线程的线程 ID。在我的电脑上,运行:
sess = tf.Session()
print(sess.run([X, y]))
prints
[array([ 8460., 8460., 8460., 15912., 16200., 16200., 8460.,
15912., 16200., 8460., 15912., 16200., 16200., 8460.,
15912., 15912., 8460., 8460., 6552., 15912., 15912.,
8460., 8460., 15912., 9956., 16200., 9956., 16200.,
15912., 15912., 9956., 16200., 15912., 16200., 16200.,
16200., 6552., 16200., 16200., 9956., 6552., 6552.], dtype=float32),
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]
Note:您可能想尝试删除map
调用(我们仅用于多线程)并检查是否prefetch
的缓冲区足以消除输入管道中的瓶颈(即使只有一个线程,输入预处理通常比实际图形执行速度更快,因此缓冲区足以使预处理尽可能快地进行)。