TensorFlow 中的高效图像膨胀

2024-01-03

我正在寻找一种有效的实施方式形态学图像膨胀 https://en.wikipedia.org/wiki/Dilation_(morphology)在 TensorFlow 中使用方形内核。正如 OpenCV 所示,与实际效果相比,显而易见的方法似乎效率极低。查看粘贴在底部的运行源代码的结果 - 即使最快的方法也比 OpenCV 慢 30 倍左右。这些来自配备 M1 芯片组的 MacBook Air。

Dilation of 640x480 image with a 25x25 kernel took: 
  0.61ms using opencv
  545.40ms using tf.nn.max_pool2d
  228.66ms using tf.nn.dilation2d naively
  17.63ms using tf.nn.dilation2d with row-col

Question:有谁知道一种使用 TensorFlow 进行图像膨胀的方法,而且效率不是极低?

当前解决方案的源代码:

import numpy as np
import cv2
import tensorflow as tf
import time


def tf_dilate(heatmap, width: int, method: str = 'rowcol'):
    """ Dilate the heatmap with a square kernel """
    if method=='maxpool':
        return tf.nn.max_pool2d(heatmap[None, :, :, None], ksize=width, padding='SAME', strides=(1, 1))[0, :, :, 0]
    elif method == 'naive_dilate':
        return tf.nn.dilation2d(heatmap[None, :, :, None], filters=tf.zeros((width, width, 1), dtype=heatmap.dtype),
                                        strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))[0, :, :, 0]
    elif method == 'rowcol_dilate':

        row_dilation = tf.nn.dilation2d(heatmap[None, :, :, None], filters=tf.zeros((1, width, 1), dtype=heatmap.dtype),
                                        strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))
        full_dilation = tf.nn.dilation2d(row_dilation, filters=tf.zeros((width, 1, 1), dtype=heatmap.dtype),
                                         strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))
        return full_dilation[0, :, :, 0]
    else:
        raise NotImplementedError(f'No method {method}')


def test_dilation_options(img_shape=(480, 640), kernel_size=25):

    img = np.random.randn(*img_shape).astype(np.float32)**2

    def get_result_and_time(version: str):

        tf_image = tf.constant(img, dtype=tf.float32)
        t_start = time.time()
        if version=='opencv':
            result = cv2.dilate(img, kernel=np.ones((kernel_size, kernel_size), dtype=np.float32))
            return time.time()-t_start, result
        else:
            result = tf_dilate(tf_image, width=kernel_size, method=version)
            return time.time()-t_start, result.numpy()

    t_opencv, result_opencv = get_result_and_time('opencv')
    t_maxpool, result_maxpool = get_result_and_time('maxpool')
    t_naive_dilate, result_naive_dilate = get_result_and_time('naive_dilate')
    t_rowcol_dilate, result_rowcol_dilate = get_result_and_time('rowcol_dilate')
    assert np.array_equal(result_opencv, result_maxpool), "Maxpool result did not match opencv result"
    assert np.array_equal(result_opencv, result_naive_dilate), "Naive dilation result did not match opencv result"
    assert np.array_equal(result_opencv, result_rowcol_dilate), "Row-col dilation result did not match opencv result"
    print(f'Dilation of {img_shape[1]}x{img_shape[0]} image with a {kernel_size}x{kernel_size} kernel took: '
          f'\n  {t_opencv*1000:.2f}ms using opencv'
          f'\n  {t_maxpool*1000:.2f}ms using tf.nn.max_pool2d'
          f'\n  {t_naive_dilate*1000:.2f}ms using tf.nn.dilation2d naively'
          f'\n  {t_rowcol_dilate*1000:.2f}ms using tf.nn.dilation2d with row-col'
          )


if __name__ == '__main__':
    test_dilation_options()

好吧,如果你没问题的话近似解决方案中,总是存在“穷人的扩张”,它使用加权局部平均值(盒式滤波器)来近似扩张,其中通过对图像求幂来获取权重。它是O((H+K)*(W+K)) where W,H是图像的宽度、高度和K是内核大小。

它还具有以下优点:梯度不仅流过局部最大值,还流过竞争者直至抛出。

参见代码:

TensorImage = NewType('TensorImage', tf.Tensor)  # A (height, width, n_colors) uint8 image
TensorFloatImage = NewType('TensorFloatImage', tf.Tensor)
TensorHeatmap = NewType('TensorHeatmap', tf.Tensor)  # A (height, width) heatmap

def tf_box_filter(image: Union[TensorImage, TensorFloatImage, TensorHeatmap], width: int, normalize: bool = True, weights: Optional[TensorHeatmap] = None,
                  weight_eps: float = 1e-6, norm_weights: bool = True):
    image = tf.cast(image, tf.float32) if image.dtype != tf.float64 else image
    if weights is not None:
        if norm_weights:
            weights = weights/(width**2)
        if len(image.shape) == 3:
            weights = weights[:, :, None]  # Lets us broadcast weights against image

        image = image * weights

    lwidth = width // 2 + 1
    rwidth = width - lwidth

    integral_image_container = tf.pad(image,
                                      paddings=[(lwidth, rwidth), (lwidth, rwidth)] + [(0, 0)] * (len(image.shape) - 2))
    integral_image_container = tf.cumsum(tf.cumsum(integral_image_container, axis=0), axis=1)
    box_image = integral_image_container[width:, width:] \
                - integral_image_container[width:, :-width] \
                - integral_image_container[:-width, width:] \
                + integral_image_container[:-width, :-width]

    if not normalize:
        return box_image if (weights is None or not norm_weights) else box_image*(width**2)
    elif weights is None:
        return box_image / (width ** 2)
    else:
        box_weights = tf_box_filter(weights, width=width, normalize=False)
        return (box_image + weight_eps) / (box_weights + weight_eps)


def tf_poor_mans_dilate(heatmap: TensorHeatmap, width: int, power: int = 4, cast_to_64 = False) -> TensorHeatmap:
    """ A 'poor man's' version of dilation, whise runtime is O((image_height+kernel_width), (image_width+kernel_width))"""
    if cast_to_64:
        heatmap = tf.cast(heatmap, tf.float64)
    return tf_box_filter(heatmap, width, weights=heatmap**power, weight_eps=1e-9)


测试表明它比问题中的解决方案快大约 3 倍(当内核很大时速度更快)。


def test_poor_mans_dilate(show=False):
    """ Can be faster for large images and kernels

    Dilating image of shape (1280, 720) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.09009s
        Poor Man's Dilate: Elapsed time is 0.02953s

    Dilating image of shape (640, 480) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.03089s
        Poor Man's Dilate: Elapsed time is 0.008736s

    Dilating image of shape (640, 480) with kernel of shape 20x20
        Real Dilate: Elapsed time is 0.01475s
        Poor Man's Dilate: Elapsed time is 0.009809s
    """
    img = tf.random.Generator.from_seed(1234).normal((640, 480))**4
    width = 20
    print(f'Dilating image of shape {img.shape} with kernel of shape {width}x{width}')
    with profile_context('Real Dilate', print_result=True):
        dil_img = tf_dilate(img, width=width)
    with profile_context("Poor Man's Dilate", print_result=True):
        poor_dil_img = tf_poor_mans_dilate(img, width=width)

    assert np.allclose(dil_img.numpy().max(), poor_dil_img.numpy().max(), rtol=0.001)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow 中的高效图像膨胀 的相关文章

随机推荐

  • 后台进程的 cy.exec 超时

    我正在尝试使用启动服务器cy exec并像这样后台处理 cy exec nohup python m my module arg 1 failOnNonZeroExit false then result gt if result code
  • 如何防止密码和其他敏感信息出现在 ASP.NET 转储中?

    如何防止在 IIS ASP NET 转储文件中向 ASP NET 网页提交和接收密码和其他敏感数据 重现步骤 使用 Visual Studio 2010 创建 ASP NET MVC 3 Intranet 应用程序 将其配置为使用 IIS
  • Spring嵌套事务

    在我的 Spring Boot 项目中 我实现了以下服务方法 Transactional public boolean validateBoard Board board boolean result false if inProgress
  • 更新更改 svn 时出错

    我安装了 PHPStorm 并使用 SVN 打开包含 PHP 项目的目录 在 更改 的 SVN 选项卡下 我遇到以下错误 Error updating changes svn E155021 The client is too old to
  • Spring JPA Repository - 在服务器重启时保留数据

    我目前正在尝试学习如何使用 Spring Boot 但遇到一个问题 我不确定如何解决 我已经按照使用 JPA 访问数据 http spring io guides gs accessing data jpa 指导 一切正常 但是 如果我重新
  • Pandas 和 Matplotlib - fill_ Between() 与 datetime64

    有一个 Pandas 数据框
  • ggplot 中的热图,每组不同的颜色

    我正在尝试在 ggplot 中生成热图 我希望每个组都有不同的颜色渐变 但不知道该怎么做 我当前的代码如下所示 dummy data data lt data frame group sample c Direct Patient Care
  • OL3:强制重绘图层

    我目前正在将 OpenLayers 客户端版本 2 13 1 升级为新版本的 OpenLayers OL3 我的设置包括作为 WMS 映射服务器的 Mapserver 和前面提到的 OpenLayers 客户端 在旧系统中 我支持用户交互
  • R 中百分比格式表

    我想获取一个百分比表 将值格式化为百分比并以良好的格式显示它们 如果重要的话 我正在使用 RStudio 并编织为 PDF 我看过其他关于此的帖子 但它们看起来都不干净 而且效果不佳 例如 下面的 apply 语句确实采用百分比格式 但是
  • 检索两个字符之间的子字符串

    我有这样的字符串 var str it itA itB et etA etB etC etD 如何检索 和 之间的元素 截至目前 我正在用新行分割文本 但无法解决这个问题 请帮我解决这个问题 请使用这个小提琴http jsfiddle ne
  • IronPython - JSON 选择

    在 IronPython 2 0 1 中处理 JSON 的最佳方法是什么 原生 Python 标准库 json 看起来尚未实现 如果我想使用 Newtonsoft Json NET 库 我该怎么做 我可以将程序集添加到 GAC 但我还有其他
  • 如何使用 php 渲染远程图像?

    这是一个 jpg https i stack imgur com PIFN0 jpg 假设我希望这个渲染自 img php file name PIFN0 jpg 以下是我尝试完成这项工作的方法 样本 php p Here s my ima
  • UICollectionView 启用取消选择单元格,同时禁用 allowedMultipleSelection

    When collectionView allowsMultipleSelection YES 我可以取消选择已选择的单元格 when collectionView allowsMultipleSelection NO 我无法取消选择已选择
  • Fortran 中不提升数组的标量参数

    为什么 Fortran 会将标量表达式提升为数组表达 但不作为过程的参数 特别是 为什么标准机构做出这样的设计决定 仅仅是因为含糊不清 程序就应该超载吗 在这种情况下 错误消息是否可以作为替代方法 例如 在下面的代码中 最后一条语句 x f
  • Jsoup,在执行表单POST之前获取值

    这是我用来提交表单的代码 Connection Response res Jsoup connect http example com data id myID data username myUsername data code MyAu
  • iPhone:cocos2d 中相机跟随玩家

    我正在用 cocos2d 制作 iPhone 游戏 我想知道如何使相机 视图遵循特定的精灵 我会使用 CCCamera 类吗 是的 CCCamera 可以工作 然而 它有一些缺点 使其不适合某些用途 相对于该精灵移动图层以及所有其他对象可能
  • 在 StructureMap 中注册一个默认实例

    我有一堂课 MyService 具有静态属性 MyService Context 代表当前上下文 特定于当前登录的用户 因此它会发生变化 我想要实现的目标 ObjectFactory Initialize x gt x For
  • 在 WPF 中,我们如何将 Duration 定义为资源?

    我在许多动画中使用了一个持续时间 0 0 0 5 并且我想仅在一个位置定义该数字 我可以将双精度定义为
  • 在 Win32 API 中绘制格式化文本的最快方法是什么?

    我正在使用普通 Win32 API 在 C 中实现一个文本编辑器 并且我正在尝试找到实现语法突出显示的最佳方法 我知道有像 scintilla 这样的现有控件 但我这样做是为了好玩 所以我想自己完成大部分工作 我还希望它又快又轻 从我到目前
  • TensorFlow 中的高效图像膨胀

    我正在寻找一种有效的实施方式形态学图像膨胀 https en wikipedia org wiki Dilation morphology 在 TensorFlow 中使用方形内核 正如 OpenCV 所示 与实际效果相比 显而易见的方法似