不了解类 UNET 架构中的数据流,并且 Conv2DTranspose 层的输出存在问题

2024-04-06

我对修改后的 U-Net 架构的输入维度有一两个问题。为了节省您的时间并更好地理解/重现我的结果,我将发布代码和输出尺寸。修改后的U-Net架构是来自的MultiResUNet架构https://github.com/nibtehaz/MultiResUNet/blob/master/MultiResUNet.py https://github.com/nibtehaz/MultiResUNet/blob/master/MultiResUNet.py。并基于本文https://arxiv.org/abs/1902.04049 https://arxiv.org/abs/1902.04049请不要因为这段代码的长度而关闭。您只需复制粘贴即可,重现我的结果的时间不会超过 10 秒。此外,您不需要为此提供数据集。使用 TF.v1.9 Keras v.2.20 进行测试。

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add
from tensorflow.keras.models import Model
from tensorflow.keras.activations import relu 

###{ 2D Convolutional layers

   # Arguments: ######################################################################
   #     x {keras layer} -- input layer                                   #
   #     filters {int} -- number of filters                                        #
   #     num_row {int} -- number of rows in filters                               #
   #     num_col {int} -- number of columns in filters                           #

    # Keyword Arguments:
   #     padding {str} -- mode of padding (default: {'same'})
  #      strides {tuple} -- stride of convolution operation (default: {(1, 1)})
 #       activation {str} -- activation function (default: {'relu'})
#        name {str} -- name of the layer (default: {None})

  #  Returns:
  #          [keras layer] -- [output layer]}

      # # ############################################################################


def conv2d_bn(x, filters ,num_row,num_col, padding = "same", strides = (1,1), activation = 'relu', name = None):

    x = Conv2D(filters,(num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)
    if(activation == None):
        return x
    x = Activation(activation, name=name)(x)

    return x

# our 2D transposed Convolution with batch normalization

 # 2D Transposed Convolutional layers

 #   Arguments:      #############################################################
 #       x {keras layer} -- input layer                                         #
 #       filters {int} -- number of filters                                    #
 #       num_row {int} -- number of rows in filters                           #
 #       num_col {int} -- number of columns in filters

 #   Keyword Arguments:
 #       padding {str} -- mode of padding (default: {'same'})
 #       strides {tuple} -- stride of convolution operation (default: {(2, 2)}) 
 #       name {str} -- name of the layer (default: {None})

  #  Returns:
  #      [keras layer] -- [output layer] ###################################

def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None): 

    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    return x

# Our Multi-Res Block 

# Arguments: ############################################################
#        U {int} -- Number of filters in a corrsponding UNet stage     #
#        inp {keras layer} -- input layer                             #

#    Returns:                                                       #
#        [keras layer] -- [output layer]                           #
###################################################################

def MultiResBlock(U, inp, alpha = 1.67):

    W = alpha * U

    shortcut = inp

    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) +
                         int(W*0.5), 1, 1, activation=None, padding='same')

    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3,
                        activation='relu', padding='same')

    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3,
                        activation='relu', padding='same')

    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3,
                        activation='relu', padding='same')

    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out

# Our ResPath:
# ResPath

#    Arguments:#######################################
#        filters {int} -- [description]
#        length {int} -- length of ResPath
#        inp {keras layer} -- input layer 

#    Returns:
#        [keras layer] -- [output layer]#############



def ResPath(filters, length, inp):
    shortcut = inp
    shortcut = conv2d_bn(shortcut, filters, 1, 1,
                         activation=None, padding='same')

    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    for i in range(length-1):

        shortcut = out
        shortcut = conv2d_bn(shortcut, filters, 1, 1,
                             activation=None, padding='same')

        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')

        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)

    return out



#    MultiResUNet

#    Arguments: ############################################
#        height {int} -- height of image 
#        width {int} -- width of image 
#        n_channels {int} -- number of channels in image

#    Returns:
#        [keras model] -- MultiResUNet model###############




def MultiResUnet(height, width, n_channels):



    inputs = Input((height, width, n_channels))

    # downsampling part begins here 

    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)

    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)

    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)

    mresblock4 = MultiResBlock(32*8, pool3)


    # Upsampling part 

    up5 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock4), mresblock3], axis=3)
    mresblock5 = MultiResBlock(32*8, up5)

    up6 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock2], axis=3)
    mresblock6 = MultiResBlock(32*4, up6)

    up7 = concatenate([Conv2DTranspose(
        32*2, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock1], axis=3)
    mresblock7 = MultiResBlock(32*2, up7)


    conv8 = conv2d_bn(mresblock7, 1, 1, 1, activation='sigmoid')

    model = Model(inputs=[inputs], outputs=[conv8])

    return model

现在回到 UNet 架构中输入/输出维度不匹配的问题。

如果我选择过滤器高度/宽度 (128,128) 或 (256,256) 或 (512,512) 并执行以下操作:

 model = MultiResUnet(128, 128,3)
 display(model.summary()) 

Tensorflow 为我提供了整个架构的完美结果。现在如果我这样做

     model = MultiResUnet(36, 36,3)
     display(model.summary()) 

我收到此错误:

-------------------------------------------------- -------------------------- ValueError Traceback(最近调用 最后)在 ----> 1 模型 = MultiResUnet(36, 36,3) 2 显示(model.summary())

在 MultiResUnet 中(高度、宽度、 n_通道) 25 26 up5 = 连接([Conv2DTranspose( ---> 27 32*4, (2, 2), 步长=(2, 2), 填充='相同')(mresblock4), mresblock3], 轴=3) 28 mresblock5 = MultiResBlock(32*8, up5) 29

〜/miniconda3/envs/MastersTheenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/merge.py 连接(输入,轴,**kwargs) 第682章 一个张量,输入沿轴的串联axis。 第683章 --> 684 返回连接(轴=轴,**kwargs)(输入) 第685章 第686章

〜/miniconda3/envs/MastersTheenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py 在call(自身、输入、*args、**kwargs) 第694章 第695章 --> 696 self.build(input_shapes) 第697章 第698章输入形状。

〜/miniconda3/envs/MastersTheenv/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py 在包装器中(实例,input_shape) 146 其他: [第 147 章] --> 148 输出形状 = fn(实例, 输入形状) 149如果output_shape不是None: 150 if isinstance(输出形状,列表):

〜/miniconda3/envs/MastersTheenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/merge.py 在构建(自身,input_shape)中 第388章 [第 389 章] ' --> 390 '获得输入形状:%s' % (input_shape)) 第391章 第392章

值错误:AConcatenate层需要具有匹配形状的输入 除了连续轴之外。获得输入形状:[(None, 8, 8, 128), (无、9、9、128)]

为什么 Conv2DTranspose 给我错误的尺寸

(无、8、8、128)

代替

(无、9、9、128)

为什么当我选择 (128,128)、(256,256) 等过滤器大小(32 的倍数)时 Concat 函数不会抱怨因此,为了概括这个问题,我怎样才能使这个 UNet 架构适用于任何过滤器大小我该如何处理 Conv2DTranspose 层,生成比实际需要的尺寸少一维(宽度/高度)的输出(当过滤器大小不是 32 的倍数或不对称时),为什么不这样做如果其他过滤器尺寸是 32 的倍数,则会发生这种情况。如果我有的话会怎样可变输入尺寸 ??

任何帮助将不胜感激。

干杯, H


U-Net 系列模型(例如上面的 MultiResUNet 模型)遵循编码器-解码器架构。Encoder是具有特征提取的下采样路径,而decoder一个上采样的。来自编码器的特征图是串联的在解码器处通过跳过连接。这些特征图在最后一个轴上连接起来,即'渠道'轴(考虑特征的尺寸 [batch_size、高度、宽度、通道])。现在,对于要在任何轴(在我们的例子中为“通道”轴)连接的特征,所有其他轴的尺寸must match.

在上述模型架构中,有3 下采样/最大池化正在执行的操作(通过MaxPooling2D)在编码器路径中。在解码器路径上3 上采样/转置转换执行操作,旨在将图像恢复到完整尺寸。然而,为了发生串联(通过跳过连接),下采样和上采样的特征维度高度、宽度和批量大小在模型的每个“级别”都应该保持相同。我将用您在问题中提到的示例来说明这一点:

1st case:输入尺寸(128,128,3):128 -> 64 -> 32 -> 16 -> 32 -> 64 -> 128

2nd case:输入尺寸(36,36,3): 36 -> 18 ->9 -> 4 -> 8-> 16 -> 32

在第二种情况下,当height and width特征图的数量达到9在编码器路径中,进一步下采样会导致尺寸变化(损失)无法在解码器中恢复上采样时。因此,由于无法连接维度的特征图,它会抛出错误[(无、8、8、128)] & [(无、9、9、128)].

一般来说,对于一个简单的编码器-解码器模型(具有跳过连接)具有'n' 下采样(MaxPooling2D) 层,输入维度必须是2^n 的倍数能够在解码器处连接模型的编码器特征。在这种情况下n=3,因此输入必须是8以免遇到这些尺寸不匹配错误。

希望这可以帮助! :)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

不了解类 UNET 架构中的数据流,并且 Conv2DTranspose 层的输出存在问题 的相关文章

随机推荐

  • 如何将unix时间戳转换为日期时间

    我正在尝试转换这个unix时间戳1415115303410在日期时间中 这样 private static DateTime UnixTimeStampToDateTime long unixTimeStamp System DateTim
  • 无法在詹金斯奴隶上运行 gradle

    我已经配置了一个 jenkins ubuntu 从机 我想在它上面运行我的 gradle 构建 使用 gradle 插件 问题是 当运行 jenkins 构建作业时 我得到 gradle no daemon info clean build
  • 如何在 VB.NET 中覆盖文本

    我曾经被教导如何使用以下代码附加文本文件 但是每次按下按钮一时如何覆盖该文件 没有人教我 Private Sub Button1 Click ByVal sender As System Object ByVal e As System E
  • 保存和恢复片段状态

    我有一系列的片段 我使用 上一个 和 下一个 按钮在该片段中进行导航 该片段中有许多编辑文本和单选按钮 当通过单击 上一个 按钮加载上一个片段时 我想保存和恢复这些编辑文本和单选按钮中的用户输入 截图 片段1 https i stack i
  • 用 Java 8 Streams 替换传统的 newForLoop

    因此 最终从 Java 6 到 Java 8 有了相对较大的跳跃 我阅读了大量的 Java 8 Streams API 不幸的是 几乎所有被问到的例子都几乎接近我试图弄清楚如何做的事情 但还不够接近 我拥有的是 final List
  • TextView textColor 中的数据绑定选择器

    我正在尝试根据频道中未读消息的数量从文本视图设置颜色 就像这样 android textColor channel unreadCount gt 0 color selector conversation row title unread
  • MediaEncodingProfile.CreateWmv 给出“未找到合适的转换来编码或解码内容。”错误

    我正在创建一个 Windows Phone 应用程序 XAML C 用于将音频和视频上传到服务器 在 Windows Phone 8 0 上使用 VideoCaptureDevice 效果很好 但它只允许设备支持的分辨率 在诺基亚 625
  • 使用 py2exe 隐藏 Python GUI 应用程序的控制台窗口

    我有一个使用 Qt 实际上是 PyQt4 的 Python 程序 当我从 main py 启动它时 我会得到一个控制台窗口和 GUI 窗口 当然 在 Windows 上 然后我用 py2exe 编译我的程序并成功创建 main exe 但是
  • 如何获得批号的可用数量

    如何获取多个仓库中批号的可用数量 假设我有3个仓库A B和C 批号 LOT0001 我想要所有三个位置的 LOT 0001 目前可用的总数量 在 odoo 中 您可以在上下文中传递过滤器 ex context lot id owner id
  • 导入错误:您必须是 root

    我尝试在 python 3 中使用键盘库 但仍然出现导入错误 我在 Thonny 的 Windows 中运行了该程序 它工作正常 但我无法在 pi 中运行它 我尝试以 root 身份运行它并使用 sudo 命令运行它 得到相同的结果 下面是
  • nhibernate 交替批量大小

    当使用 NHibernate 执行查询时 如果批处理大小设置为大于实际返回的结果 则似乎不考虑批处理大小 我正在使用最新版本的 NHibernate 2 1 0 4000 和 Linq to NHibernate 的 GA 我有一个类似于
  • 为什么在使用 Microsoft.Bcl - 无法等待'System.Threading.Tasks.Task 时,我不能在 Windows Phone 7.1 MvvmCross 项目中使用 wait 关键字?

    使用 Microsoft Bcl Microsoft BCL Portability Pack 时 我无法在 MvvmCross Windows Phone 7 1 项目中使用 wait 关键字 我已经发布了下面描述的示例项目的代码GitH
  • 只有创建视图层次结构的原始线程才能触摸其视图错误

    一切正常 除非到达代码的最后部分 注册成功 然后标题中提到的错误出现在registerDialog消息部分中 我做错了什么吗 谁能帮我检查我的代码 非常感谢 该应用程序没有崩溃 尽管它只是退出回到应用程序主页 如果我再次按下注册按钮 它将返
  • 如何根据周对 pandas 数据框进行分区并保存为 CSV?

    我有一个熊猫数据框 如下所示 这个数据框大约一个月的时间段 如何根据周对该数据框进行分区 我需要每 4 周保存为 4 个单独的 CSV 文件 Time Stamp Id Latitude Longitude 01 10 2016 15 22
  • 使用 AngularJS ngTable 自定义过滤器

    我正在尝试使用 ngTable 构建一个表 但使用与中描述的不同的自定义过滤ngTable 页面的示例 http bazalt cms com ng table example 11 我希望进行适当的过滤 但我不希望 ngTable 呈现过
  • Cypress:在第一次失败时中断所有测试

    如何在第一次测试失败时中断所有赛普拉斯测试 我们使用信号量为每个 PR 与 Cypress 启动完整的 e2e 测试 但这需要太多时间 我想在第一次测试失败时中断所有测试 获取完整的错误是每个开发人员在开发时的职责 如果在部署之前出现任何问
  • 如何使用grep提取子字符串? [复制]

    这个问题在这里已经有答案了 可能的重复 从字符串中提取正则表达式结果并将其写入变量 https stackoverflow com questions 3148558 extract regexp result from string an
  • 生成字符串列表的所有组合

    我想生成一个字符串列表的所有可能组合的列表 它实际上是一个对象列表 但为了简单起见 我们将使用字符串 我需要这个列表 以便我可以在单元测试中测试每种可能的组合 例如 如果我有一个列表 var allValues new List
  • 在 C# 中使用派生返回类型覆盖抽象属性

    我有四节课 请求 派生请求 处理程序 派生处理程序 Handler 类有一个带有以下声明的属性 public abstract Request request get set DerivedHandler 需要重写此属性 以便它返回 Der
  • 不了解类 UNET 架构中的数据流,并且 Conv2DTranspose 层的输出存在问题

    我对修改后的 U Net 架构的输入维度有一两个问题 为了节省您的时间并更好地理解 重现我的结果 我将发布代码和输出尺寸 修改后的U Net架构是来自的MultiResUNet架构https github com nibtehaz Mult