有多种方法可以实现您想要的目标。我将尝试在这里勾勒出一种选择。
系统的总体视图是:你有n
Loader
异步加载数据并送入队列。然后该层读取batch_size
队列中的项目并送入网络forward()
功能。
import caffe, multiprocessing
class Loader(multiprocessing.Process):
def __init__(self, outq, *args, **kwargs):
super(Loader, self).__init__()
self.daemon = True
self.outq = outq
self.start() # start working
def run(self):
while True: # read and never stop at all!
try:
# do your magic here
# assuming you load x,y pairs
self.outq.put((x[None, ...], y[None, ...])) # add singleton "batch" dimension
except Exception as e:
# handle errors?
pass
class MultiProcessInputLayer(caffe.Layer):
def setup(self, bottom, top):
# verify no bottoms, right number of tops etc.
self.dataQ = multiprocessing.Queue()
for _ in xrange(n):
Loader(self.dataQ) # start n Loaders
# some other stuff here...
def reshape(self, bottom, top):
# reshape the inputs to the right sizes
def forward(self, bottom, top):
for i in xrange(batch_size):
item = self.dataQ.get()
top[0].data[i, ...] = item[0]
top[1].data[i, ...] = item[1]
def backward(self, top, propagate_down, bottom):
pass # no backward for data layer
我通过艰难的方式学到的一些提示和技巧:
1. Use multiprocessing
并不是threading
包因为GIL.
2. 有时(例如,如果batch_size
非常大)需要很长时间forward()
从队列中逐项读取以形成每个批次。在这种情况下,您可以添加另一个multiprocessing.Process
这将异步读取batch_size
物品来自self.dataQ
并将整批写入self.batchQ
. Then forward()
只会等待一个single项目来自self.batchQ
每次通话时。
3. 注意不要复制太多数据。使用大图像/标签可能会使所有这些复制成为瓶颈。