进行预测时 conv2d_transpose 取决于 batch_size

2024-01-07

我目前有一个在张量流中实现的神经网络,但我在训练后进行预测时遇到问题,因为我有 conv2d_transpose 操作,并且这些操作的形状取决于批量大小。我有一个需要 output_shape 作为参数的层:

def deconvLayer(input, filter_shape, output_shape, strides):
    W1_1 = weight_variable(filter_shape)

    output = tf.nn.conv2d_transpose(input, W1_1, output_shape, strides, padding="SAME")

    return output

这实际上用在我构建的一个更大的模型中,如下所示:

 conv3 = layers.convLayer(conv2['layer_output'], [3, 3, 64, 128], use_pool=False)

 conv4 = layers.deconvLayer(conv3['layer_output'],
                                    filter_shape=[2, 2, 64, 128],
                                    output_shape=[batch_size, 32, 40, 64],
                                    strides=[1, 2, 2, 1])

问题是,如果我使用经过训练的模型进行预测,我的测试数据必须具有相同的批量大小,否则我会收到以下错误。

tensorflow.python.framework.errors.InvalidArgumentError: Conv2DBackpropInput: input and out_backprop must have the same batch size

有什么方法可以预测具有可变批量大小的输入吗?当我查看训练后的权重时,似乎没有任何东西取决于批量大小,所以我不明白为什么这会成为问题。


所以我发现了一个基于张量流问题论坛的解决方案https://github.com/tensorflow/tensorflow/issues/833 https://github.com/tensorflow/tensorflow/issues/833.

在我的代码中

 conv4 = layers.deconvLayer(conv3['layer_output'],
                                    filter_shape=[2, 2, 64, 128],
                                    output_shape=[batch_size, 32, 40, 64],
                                    strides=[1, 2, 2, 1])

我在训练时传递给 deconvLayer 的输出形状是用预定的批量形状进行硬编码的。通过将其更改为以下内容:

def deconvLayer(input, filter_shape, output_shape, strides):
    W1_1 = weight_variable(filter_shape)

    dyn_input_shape = tf.shape(input)
    batch_size = dyn_input_shape[0]

    output_shape = tf.pack([batch_size, output_shape[1], output_shape[2], output_shape[3]])

    output = tf.nn.conv2d_transpose(input, W1_1, output_shape, strides, padding="SAME")

    return output

这允许在运行时动态推断形状,并且可以处理可变的批量大小。

运行代码,在传递任何批量大小的测试数据时,我不再收到此错误。我认为这是必要的,因为转置运算的形状推断目前并不像普通卷积运算那么简单。因此,在正常的卷积运算中,我们通常会使用 None 作为批量大小,但我们必须提供一个形状,并且由于这可能会根据输入而变化,因此我们必须努力动态确定它。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

进行预测时 conv2d_transpose 取决于 batch_size 的相关文章

  • 如何指定聚类的距离函数?

    我想对给定距离的点进行聚类 奇怪的是 似乎 scipy 和 sklearn 聚类方法都不允许指定距离函数 例如 在sklearn cluster AgglomerativeClustering 我唯一可以做的就是输入一个亲和力矩阵 这将非常
  • Flask中使用的路由装饰器是如何工作的

    我熟悉 Python 装饰器的基础知识 但是我不明白这个用于 Flask 路由的特定装饰器是如何工作的 以下是 Flask 网站上的代码片段 from flask import Flask escape request app Flask
  • Tensorflow conv2d_transpose 大小错误“out_backprop 的行数与计算的不匹配”

    我正在张量流中创建一个卷积自动编码器 我得到了这个确切的错误 tensorflow python framework errors InvalidArgumentError Conv2DBackpropInput Number of row
  • Tkinter 菜单删除项

    如何删除任何菜单项 例如我想删除 播放 self menubar Menu self root self root config menu self menubar self filemenu2 Menu self menubar self
  • 从字典的元素创建 Pandas 数据框

    我正在尝试从字典创建一个 pandas 数据框 字典设置为 nvalues y1 1 2 3 4 y2 5 6 7 8 y3 a b c d 我希望数据框仅包含 y1 and y2 到目前为止我可以使用 df pd DataFrame fr
  • 数据框 - 平均列

    我在 pandas 中有以下数据框 Column 1 Column 2 Column3 Column 4 2 2 2 4 1 2 2 3 我正在创建一个数据框 其中包含第 1 列和第 2 列 第 3 列和第 4 列等的平均值 ColumnA
  • 从 pyspark.sql 中的列表创建数据框

    我完全陷入了有线的境地 现在我有一个清单li li example data map lambda x get labeled prediction w x collect print li type li 输出就像 0 0 59 0 0
  • 使用 Pytest 的参数化添加测试功能的描述

    当其中一个测试失败时 可以在测试正在测试的内容的参数化中添加描述 快速了解测试失败的原因 有时您不知道测试失败的原因 您必须查看代码 通过每个测试的描述 您就可以知道 例如 pytest mark parametrize num1 num2
  • 是否有一个包可以维护所有带有符号的货币列表?

    是否有一个 python 包提供所有 或相当完整 货币的列表与符号 如美元的 有优秀的pycountry 贪财的 https github com limist py moneyed and ccy http code google com
  • 如何使用 Homebrew 在 Mac 上安装 Python 2 和 3?

    我需要能够在 Python 2 和 3 之间来回切换 我如何使用 Homebrew 来做到这一点 因为我不想弄乱路径并陷入麻烦 现在我已经通过 Homebrew 安装了 2 7 我会用pyenv https github com yyuu
  • python 中的 h2o 框架子集

    如何在 python 中对 h2o 框架进行子集化 如果 x 是一个 df 并且 Origin 是一个变量 那么在 pandas 中我们通常可以通过以下方式进行子集化 x x Origin AAF 但使用 h2o 框架会出现以下错误 H2O
  • Python“非规范化”unicode 组合字符

    我正在寻找标准化 python 中的一些 unicode 文本 我想知道是否有一种简单的方法可以在 python 中获得组合 unicode 字符的 非规范化 形式 例如如果我有序列u o xaf i e latin small lette
  • 使用标签或 href 传递 Django 数据

    我有一个包含链接的表 当单击该链接进行更多操作时 我想将一些数据传递给我的函数 my html table tbody for query in queries tr td value a href internal my func que
  • 如何从列表类别中对 pandas 数据框进行排序?

    所以我在下面有这个数据集 我想根据我的列表从 名称 列进行排序 以及按 A 升序和按 B 降序排序 import pandas as pd import numpy as np df1 pd DataFrame from items A 1
  • 基于值而不是类型的单次调度

    我在 Django 上构建 SPA 并且有一个庞大的功能 其中包含许多功能if用于检查我的对象字段的状态名称的语句 像这样 if self state new do some logic if self state archive do s
  • 在 numpy 中连接维度

    我有x 1 2 3 4 5 6 7 8 9 10 11 12 shape 2 2 3 I want 1 2 3 4 5 6 7 8 9 10 11 12 shape 2 6 也就是说 我想连接中间维度的所有项目 在这种特殊情况下我可以得到这
  • 高效创建抗锯齿圆形蒙版

    我正在尝试创建抗锯齿 加权而不是布尔 圆形掩模 以制作用于卷积的圆形内核 radius 3 no of pixels to be 1 on either side of the center pixel shall be decimal a
  • Jupyter Notebook:带有小部件的交互式绘图

    我正在尝试生成一个依赖于小部件的交互式绘图 我遇到的问题是 当我使用滑块更改参数时 会在前一个绘图之后完成一个新绘图 而我预计只有一个绘图会根据参数发生变化 Example from ipywidgets import interact i
  • Python 中的 Unix cat 函数 (cat * > merged.txt)? [复制]

    这个问题在这里已经有答案了 一旦建立了目录 有没有办法在Python中使用Unix中的cat函数或类似的函数 我想将 files 1 3 合并到 merged txt 我通常会在 Unix 中找到该目录 然后运行 cat gt merged
  • 如何通过点击复制 folium 地图上的标记位置?

    I am able to print the location of a given marker on the map using folium plugins MousePosition class GeoMap def update

随机推荐