You should be able to use tf.data.Dataset.shuffle (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) to achieve what you want. Here is a quick summary of the objectives:
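- load very large images, produce smaller random crops from them, and batch the crops together
- make the pipeline efficient by creating multiple patches from each large image once it is loaded
- shuffle enough that a batch of patches is diverse (ideally all of its patches come from different images)
- avoid holding too many large images in the cache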
You can achieve all of this with the tf.data API by following these steps:
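1. shuffle the filenames of the large images
2. read the large images
3. generate multiple patches from each image
4. shuffle all of these patches again with a large enough buffer size (see this answer on buffer size: https://stackoverflow.com/a/48096625/5098368). Choosing the buffer size is a trade-off between good shuffling and the amount of memory taken by the cached patches
5. batch them
6. prefetch one batch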
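To get a feel for the trade-off in step 4, here is a small pure-Python sketch that imitates a fixed-size shuffle buffer (fill the buffer, emit a random element, replace it with the next incoming one) and measures how many distinct source images end up in each batch. It does not use TensorFlow at all; `buffered_shuffle` and `mean_distinct_images` are hypothetical names made up for this illustration, and each patch is represented only by the index of the image it came from.

```python
import random


def buffered_shuffle(items, buffer_size, rng):
    """Yield items in the order a fixed-size shuffle buffer would emit them."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def mean_distinct_images(num_images, num_patches, buffer_size, batch_size, seed=0):
    """Average number of distinct source images per batch (illustrative metric)."""
    rng = random.Random(seed)
    # Patches arrive grouped by image, as they do right after unbatching.
    stream = [img for img in range(num_images) for _ in range(num_patches)]
    shuffled = list(buffered_shuffle(stream, buffer_size, rng))
    batches = [shuffled[i:i + batch_size]
               for i in range(0, len(shuffled), batch_size)]
    return sum(len(set(batch)) for batch in batches) / len(batches)


# Compare buffers that hold the patches of 1, 5 and 50 images.
for images_in_buffer in (1, 5, 50):
    score = mean_distinct_images(num_images=200, num_patches=100,
                                 buffer_size=images_in_buffer * 100,
                                 batch_size=10)
    print(images_in_buffer, round(score, 2))
```

A larger buffer should give batches that draw on more distinct images, at the cost of keeping more patches in memory.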
Here is the code for this pipeline:
import tensorflow as tf  # the code below uses the TensorFlow 1.x API

filenames = ...  # filenames containing the big images
num_samples = len(filenames)
# Parameters
num_patches = 100 # number of patches to extract from each image
patch_size = 32 # size of the patches
buffer_size = 50 * num_patches # shuffle patches from 50 different big images
num_parallel_calls = 4 # number of threads
batch_size = 10 # size of the batch
get_patches_fn = lambda image: get_patches(image, num_patches=num_patches, patch_size=patch_size)
# Create a Dataset serving batches of random patches in our images
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
    .shuffle(buffer_size=num_samples)  # step 1: all the filenames into the buffer ensures good shuffling
    .map(parse_fn, num_parallel_calls=num_parallel_calls)  # step 2
    .map(get_patches_fn, num_parallel_calls=num_parallel_calls)  # step 3
    .apply(tf.contrib.data.unbatch())  # unbatch the patches we just produced
    .shuffle(buffer_size=buffer_size)  # step 4
    .batch(batch_size)  # step 5
    .prefetch(1)  # step 6: make sure you always have one batch ready to serve
)
iterator = dataset.make_one_shot_iterator()
patches = iterator.get_next() # shape [None, patch_size, patch_size, 3]
sess = tf.Session()
res = sess.run(patches)
The functions parse_fn and get_patches are defined as follows:
def parse_fn(filename):
    """Decode the jpeg image from the filename and convert to [0, 1]."""
    image_string = tf.read_file(filename)

    # Don't use tf.image.decode_image, or the output shape will be undefined
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)

    # This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image_decoded, tf.float32)

    return image


def get_patches(image, num_patches=100, patch_size=16):
    """Get `num_patches` random crops from the image"""
    patches = []
    for i in range(num_patches):
        patch = tf.image.random_crop(image, [patch_size, patch_size, 3])
        patches.append(patch)

    patches = tf.stack(patches)
    assert patches.get_shape().dims == [num_patches, patch_size, patch_size, 3]

    return patches
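
The code above is written against TensorFlow 1.x (tf.contrib, tf.Session and make_one_shot_iterator are not available in TensorFlow 2.x). If you are on TensorFlow 2.x, the same six steps can probably be written along these lines. This is only a sketch under that assumption: `make_patch_dataset` and its `images_in_buffer` parameter are hypothetical names introduced here, tf.read_file becomes tf.io.read_file, and the Dataset.unbatch() method replaces tf.contrib.data.unbatch().

```python
import tensorflow as tf


def parse_fn(filename):
    """Read a JPEG file and convert it to float32 values in [0, 1]."""
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    return tf.image.convert_image_dtype(image_decoded, tf.float32)


def get_patches(image, num_patches=100, patch_size=16):
    """Stack `num_patches` random crops taken from one image."""
    patches = [
        tf.image.random_crop(image, [patch_size, patch_size, 3])
        for _ in range(num_patches)
    ]
    return tf.stack(patches)


def make_patch_dataset(filenames, num_patches=100, patch_size=32,
                       images_in_buffer=50, batch_size=10,
                       num_parallel_calls=4):
    """Hypothetical helper that chains the six steps for TF 2.x."""
    return (
        tf.data.Dataset.from_tensor_slices(filenames)
        .shuffle(buffer_size=len(filenames))                    # step 1
        .map(parse_fn, num_parallel_calls=num_parallel_calls)   # step 2
        .map(lambda image: get_patches(image, num_patches, patch_size),
             num_parallel_calls=num_parallel_calls)             # step 3
        .unbatch()                                              # one patch per element
        .shuffle(buffer_size=images_in_buffer * num_patches)    # step 4
        .batch(batch_size)                                      # step 5
        .prefetch(1)                                            # step 6
    )
```

With eager execution there is no session or iterator to create: you can loop over the result directly, for example `for patches in make_patch_dataset(filenames): ...`, where each `patches` tensor has shape [batch_size, patch_size, patch_size, 3] (the last batch may be smaller).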