TF 数据 API:如何有效地从图像中采样小块

2024-04-23

考虑创建从高分辨率图像目录中采样随机小图像块的数据集的问题。 Tensorflow 数据集 API 提供了一种非常简单的方法来实现此目的,即构建图像名称的数据集,对它们进行排序,将其映射到加载的图像,然后映射到随机裁剪的补丁。

然而,这种幼稚的实现效率非常低,因为将加载并裁剪单独的高分辨率图像以生成每个补丁。理想情况下,图像可以加载一次并重复使用以生成许多补丁。

前面讨论的一种简单方法是从图像生成多个补丁并将它们展平。然而,这会带来数据偏差太大的不幸影响。我们希望每个训练批次都来自不同的图像。

理想情况下,我想要的是一个“随机缓存过滤器”转换,它采用底层数据集并将其 N 个元素缓存到内存中。它的迭代器将从缓存中返回一个随机元素。此外,它还会以预定义的频率将缓存中的随机元素替换为基础数据集中的新元素。该过滤器将允许更快的数据访问,但代价是更少的随机化和更高的内存消耗。

有这样的功能可用吗?

如果不是,是否应该将其实现为新的数据集转换或只是一个新的迭代器?看来一个新的迭代器就足够了。关于如何创建新的数据集迭代器(最好是用 C++)的任何指示?


你应该能够使用tf.data.Dataset.shuffle https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle实现你想要的。以下是目标的快速摘要:

  • 加载非常大的图像,从图像中生成较小的随机裁剪并将它们批处理在一起
  • 加载图像后,通过从大图像创建多个补丁来提高管道效率
  • 添加足够的随机播放,使得一批补丁是多样化的(所有补丁都来自不同的图像)
  • 不要在缓存中加载太多大图像

您可以使用以下方法实现所有这些tf.dataAPI 通过执行以下步骤:

  1. 打乱大图像的文件名
  2. 阅读大图
  3. 从此图像生成多个补丁
  4. 再次用足够大的缓冲区大小重新打乱所有这些补丁(请参阅这个答案 https://stackoverflow.com/a/48096625/5098368缓冲区大小)。调整缓冲区大小是良好洗牌和缓存补丁大小之间的权衡
  5. 批处理它们
  6. 预取一批

这是相关代码:

filenames = ...  # filenames containing the big images
num_samples = len(filenames)

# Parameters
num_patches = 100               # number of patches to extract from each image
patch_size = 32                 # size of the patches
buffer_size = 50 * num_patches  # shuffle patches from 50 different big images
num_parallel_calls = 4          # number of threads
batch_size = 10                 # size of the batch

get_patches_fn = lambda image: get_patches(image, num_patches=num_patches, patch_size=patch_size)

# Create a Dataset serving batches of random patches in our images
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
    .shuffle(buffer_size=num_samples)  # step 1: all the  filenames into the buffer ensures good shuffling
    .map(parse_fn, num_parallel_calls=num_parallel_calls)  # step 2
    .map(get_patches_fn, num_parallel_calls=num_parallel_calls)  # step 3
    .apply(tf.contrib.data.unbatch())  # unbatch the patches we just produced
    .shuffle(buffer_size=buffer_size)  # step 4
    .batch(batch_size)  # step 5
    .prefetch(1)  # step 6: make sure you always have one batch ready to serve
)

iterator = dataset.make_one_shot_iterator()
patches = iterator.get_next()  # shape [None, patch_size, patch_size, 3]


sess = tf.Session()
res = sess.run(patches)

功能parse_fn and get_patches定义如下:

def parse_fn(filename):
    """Decode the jpeg image from the filename and convert to [0, 1]."""
    image_string = tf.read_file(filename)

    # Don't use tf.image.decode_image, or the output shape will be undefined
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)

    # This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image_decoded, tf.float32)

    return image


def get_patches(image, num_patches=100, patch_size=16):
    """Get `num_patches` random crops from the image"""
    patches = []
    for i in range(num_patches):
        patch = tf.image.random_crop(image, [patch_size, patch_size, 3])
        patches.append(patch)

    patches = tf.stack(patches)
    assert patches.get_shape().dims == [num_patches, patch_size, patch_size, 3]

    return patches
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TF 数据 API:如何有效地从图像中采样小块 的相关文章

随机推荐