InvalidArgumentError:loc 处需要可广播形状(未知)

2024-01-11

背景

我对 Python 和机器学习完全陌生。我只是尝试根据在互联网上找到的代码建立一个 UNet,并希望将其适应我正在处理的情况。当试图.fit将UNet训练数据,我收到以下错误:

InvalidArgumentError:  required broadcastable shapes at loc(unknown)
     [[node Equal (defined at <ipython-input-68-f1422c6f17bb>:1) ]] [Op:__inference_train_function_3847]

当我搜索它时,我得到了很多结果,但大多数都是不同的错误。

这是什么意思?更重要的是,我该如何解决这个问题?

导致错误的代码

该错误的上下文如下: 我想分割图像并标记不同的类别。 我为训练、测试和验证数据设置了目录“trn”、“tst”和“val”。这dir_dat()功能适用os.path.join()获取各自的完整路径data set https://www.dropbox.com/sh/g4yr4qjgjtsytmi/AABJXVehoCw5In9wQvL7Llu2a?dl=0。这 3 个文件夹中的每一个都有每个类的子目录,并用整数标记。每个文件夹里都有一些.tif各个类别的图像。

我定义了以下图像数据生成器(训练数据稀疏,因此增强):

classes = np.array([ 0,  2,  4,  6,  8, 11, 16, 21, 29, 30, 38, 39, 51])
bs = 15 # batch size

augGen = ks.preprocessing.image.ImageDataGenerator(rotation_range = 365,
                                                   width_shift_range = 0.05,
                                                   height_shift_range = 0.05,
                                                   horizontal_flip = True,
                                                   vertical_flip = True,
                                                   fill_mode = "nearest") \
    .flow_from_directory(directory = dir_dat("trn"),
                         classes = [str(x) for x in classes.tolist()],
                         class_mode = "categorical",
                         batch_size = bs, seed = 42)
    
tst_batches = ks.preprocessing.image.ImageDataGenerator() \
    .flow_from_directory(directory = dir_dat("tst"),
                         classes = [str(x) for x in classes.tolist()],
                         class_mode = "categorical",
                         batch_size = bs, shuffle = False)

val_batches = ks.preprocessing.image.ImageDataGenerator() \
    .flow_from_directory(directory = dir_dat("val"),
                         classes = [str(x) for x in classes.tolist()],
                         class_mode = "categorical",
                         batch_size = bs)

然后我根据以下内容设置了UNet这个例子 https://github.com/bnsreenu/python_for_microscopists/blob/master/074-Defining%20U-net%20in%20Python%20using%20Keras.py。在这里,我更改了一些参数以使 UNet 适应情况(多个类),即最后一层的激活和损失函数:

layer_in = ks.layers.Input(shape = (imgr, imgc, imgdim))
# convert pixel integer values to float
inVals = ks.layers.Lambda(lambda x: x / 255)(layer_in)

# Contraction path
c1 = ks.layers.Conv2D(16, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(inVals)
c1 = ks.layers.Dropout(0.1)(c1)
c1 = ks.layers.Conv2D(16, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c1)
p1 = ks.layers.MaxPooling2D((2, 2))(c1)

c2 = ks.layers.Conv2D(32, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(p1)
c2 = ks.layers.Dropout(0.1)(c2)
c2 = ks.layers.Conv2D(32, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c2)
p2 = ks.layers.MaxPooling2D((2, 2))(c2)
 
c3 = ks.layers.Conv2D(64, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(p2)
c3 = ks.layers.Dropout(0.2)(c3)
c3 = ks.layers.Conv2D(64, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c3)
p3 = ks.layers.MaxPooling2D((2, 2))(c3)
 
c4 = ks.layers.Conv2D(128, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(p3)
c4 = ks.layers.Dropout(0.2)(c4)
c4 = ks.layers.Conv2D(128, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c4)
p4 = ks.layers.MaxPooling2D(pool_size = (2, 2))(c4)
 
c5 = ks.layers.Conv2D(256, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(p4)
c5 = ks.layers.Dropout(0.3)(c5)
c5 = ks.layers.Conv2D(256, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c5)

# Expansive path 
u6 = ks.layers.Conv2DTranspose(128, (2, 2), strides = (2, 2), padding = "same")(c5)
u6 = ks.layers.concatenate([u6, c4])
c6 = ks.layers.Conv2D(128, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(u6)
c6 = ks.layers.Dropout(0.2)(c6)
c6 = ks.layers.Conv2D(128, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c6)
 
u7 = ks.layers.Conv2DTranspose(64, (2, 2), strides = (2, 2), padding = "same")(c6)
u7 = ks.layers.concatenate([u7, c3])
c7 = ks.layers.Conv2D(64, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(u7)
c7 = ks.layers.Dropout(0.2)(c7)
c7 = ks.layers.Conv2D(64, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c7)
 
u8 = ks.layers.Conv2DTranspose(32, (2, 2), strides = (2, 2), padding = "same")(c7)
u8 = ks.layers.concatenate([u8, c2])
c8 = ks.layers.Conv2D(32, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(u8)
c8 = ks.layers.Dropout(0.1)(c8)
c8 = ks.layers.Conv2D(32, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c8)
 
u9 = ks.layers.Conv2DTranspose(16, (2, 2), strides = (2, 2), padding = "same")(c8)
u9 = ks.layers.concatenate([u9, c1], axis = 3)
c9 = ks.layers.Conv2D(16, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(u9)
c9 = ks.layers.Dropout(0.1)(c9)
c9 = ks.layers.Conv2D(16, (3, 3), activation = "relu",
                            kernel_initializer = "he_normal", padding = "same")(c9)
 
out = ks.layers.Conv2D(1, (1, 1), activation = "softmax")(c9)
 
model = ks.Model(inputs = layer_in, outputs = out)
model.compile(optimizer = "adam", loss = "sparse_categorical_crossentropy", metrics = ["accuracy"])
model.summary()

最后,我定义了回调并运行了训练,这产生了错误:

cllbs = [
    ks.callbacks.EarlyStopping(patience = 4),
    ks.callbacks.ModelCheckpoint(dir_out("Checkpoint.h5"), save_best_only = True),
    ks.callbacks.TensorBoard(log_dir = './logs'),# log events for TensorBoard
    ]

model.fit(augGen, epochs = 5, validation_data = val_batches, callbacks = cllbs)

完整控制台输出

这是运行最后一行时的完整输出(如果它有助于解决问题):

trained = model.fit(augGen, epochs = 5, validation_data = val_batches, callbacks = cllbs)
Epoch 1/5
Traceback (most recent call last):

  File "<ipython-input-68-f1422c6f17bb>", line 1, in <module>
    trained = model.fit(augGen, epochs = 5, validation_data = val_batches, callbacks = cllbs)

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1183, in fit
    tmp_logs = self.train_function(iterator)

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\eager\def_function.py", line 950, in _call
    return self._stateless_fn(*args, **kwds)

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\eager\function.py", line 3023, in __call__
    return graph_function._call_flat(

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\eager\function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\eager\function.py", line 591, in call
    outputs = execute.execute(

  File "c:\users\manuel\python\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError:  required broadcastable shapes at loc(unknown)
     [[node Equal (defined at <ipython-input-68-f1422c6f17bb>:1) ]] [Op:__inference_train_function_3847]

Function call stack:
train_function

当类标签的数量与输出层的输出形状不匹配时,我遇到了这个问题。

例如,如果有 10 个类标签,并且我们将输出层定义为:

output = tf.keras.layers.Conv2D(5, (1, 1), activation = "softmax")(c9)

由于类标签的数量 (10) 不等于输出形状 (5)。 然后,我们就会得到这个错误。

确保类标签的数量与输出层的输出形状匹配。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

InvalidArgumentError:loc 处需要可广播形状(未知) 的相关文章

  • sudo 和 pip 不在同一路径上

    pip and sudo不在我的计算机上的同一路径上 因此当 基本上一直 我需要运行这两个命令时 如下所示 sudo pip install xxx I get sudo pip command not found pip下载软件包 但由于
  • 从 asyncio 子进程获取实时输出

    我正在尝试使用 Python asyncio 子进程来启动交互式 SSH 会话并自动输入密码 实际用例并不重要 但它有助于说明我的问题 这是我的代码 proc await asyncio create subprocess exec ssh
  • Pandas ParserError:标记数据时出错。 C 错误:字符串内有 EOF

    我的数据超过 400 000 行 运行此代码时 f pd read csv filename error bad lines False 我收到以下错误 pandas errors ParserError Error tokenizing
  • 使用 keras 澄清 Yolo v3 模型输出

    我将 yolo v3 模型与 keras 一起使用 该网络为我提供了形状如下的输出容器 1 13 13 255 1 26 26 255 1 52 52 255 所以我找到了这个link https www cyberailab com ho
  • 如何进行重定向并保留查询字符串?

    我想进行重定向并保留查询字符串 就像是self redirect加上发送的查询参数 那可能吗 newurl my new route urllib urlencode self request params self redirect ne
  • 如何用pygame画一条虚线?

    我需要在坐标系上绘制正弦波和余弦波 就像在this https i stack imgur com DGI8g png图片 除了没能代表以外 我所有的工作都做得很好虚线和曲线与 pygame 一致 我有与我需要的类似的东西 但我怎样才能让它
  • 清理 MongoDB 的输入

    我正在为 MongoDB 数据库程序编写 REST 接口 并尝试实现搜索功能 我想公开整个 MongoDB 接口 我确实有两个问题 但它们是相关的 所以我将它们放在一篇文章中 使用 Python json 模块解码不受信任的 JSON 是否
  • Python正则表达式替换引号中的文本(引号本身除外)

    例如 我有一个测试字符串 content I opened my mouth Good morning I said cheerfully 我想使用正则表达式删除双语音标记之间的文本 但不删除语音标记本身 所以它会返回 I opened m
  • Flask 和 Reactjs 抛出 JSX 转换错误

    我已经开始将 ReactJS 与 Python Flask 后端结合使用 通过 Flask 渲染模板时 我在 Chrome 控制台中收到以下客户端错误 错误 找不到模块 jstransform visitors es6 templates
  • Plotly:如何设置文本格式(下划线、粗体、斜体)

    使用注释时 我尝试在绘图中为文本添加下划线 我使用添加注释 import plotly graph objects as go g go FigureWidget make subplots rows 1 cols 1 g update l
  • 如何停止 PythonShell

    如何终止 停止 Node js 中 PythonShell 执行的 Python 脚本的执行 我在交互模式下运行 输出通过 socket io 发送到给定的房间 如果没有更多的客户端连接到这个房间 我想停止 python 脚本的执行 这是我
  • python os.fork 使用相同的 python 解释器吗?

    据我所知 Python 中的线程使用相同的 Python 解释器实例 我的问题是与创建的流程相同os fork 或者每个进程创建的os fork有自己的翻译吗 每当你 fork 时 整个 Python 进程都会在内存中复制 包括Python
  • 使用最新值进行采样

    考虑以下系列 created at 2014 01 27 21 50 05 040961 80000 00 2014 03 12 18 46 45 517968 79900 00 2014 09 05 20 54 17 991260 636
  • 如何在 tkinter 后台运行函数[重复]

    这个问题在这里已经有答案了 我是 GUI 编程新手 我想用 tkinter 编写一个 Python 程序 我想要它做的就是在后台运行一个可以通过 GUI 影响的简单函数 该函数从 0 计数到无穷大 直到按下按钮为止 至少这是我想要它做的 但
  • 使用 Popen 打开进程并获取 PID

    我正在开发一个漂亮的小功能 def startProcess name path Starts a process in the background and writes a PID file returns integer pid Ch
  • Python:如何“杀死”类实例/对象?

    我希望 Roach 类在达到一定量的 饥饿 时 死亡 但我不知道如何删除该实例 我的术语可能有误 但我的意思是 窗户上有大量 蟑螂 我希望特定的蟑螂完全消失 我会向您展示代码 但它很长 我将蟑螂类添加到策划者类蟑螂种群列表中 一般来说 每个
  • 如何使用 QAbstractTableModel(模型/视图)将数据设置到 QComboBox?

    我希望能够设置itemData of a combobox当使用填充时QAbstractTableModel 但是 我只能从模型返回一个字符串data method 通常 当不使用模型时 可以像这样执行 Set text and data
  • 如何在Python中不使用库函数将字符串转换为整数?

    我正在尝试转换 a 546 to a 546 不使用任何库函数 我能想到的 最纯粹 gt gt gt a 546 gt gt gt result 0 gt gt gt for digit in a result 10 for d in 01
  • 如何将另一整列作为参数传递给 pandas fillna()

    我想用另一列中的值填充一列中的缺失值 使用fillna方法 我读到循环遍历每一行将是非常糟糕的做法 最好一次完成所有事情 但我不知道如何使用fillna 之前的数据 Day Cat1 Cat2 1 cat mouse 2 dog eleph
  • nltk 标记化和缩写

    我用 nltk 对文本进行标记 只是将句子输入到 wordpunct tokenizer 中 这会拆分缩写 例如 don t 到 don t 但我想将它们保留为一个单词 我正在改进我的方法 以实现更精确的文本标记化 因此我需要更深入地研究

随机推荐

  • 离子运行什么也不做

    执行时ionic run android device verbose 没有任何反应 它记录以下消息 但没有任何反应 离子编译虽然有效 ConfigXml setConfigXml path to my project resetConte
  • 在vue.js组件中,如何在css中使用props?

    我是 vue js 新手 这是我的问题 在 vue 文件中 如下所示
  • UIVisualEffectView 不会模糊它的 SKView 超级视图

    我正在编写一个 SpriteKit 游戏 并遇到了 SKView 上的视图模糊的问题 当游戏暂停时 它应该从右侧滑动 并且应该模糊其父视图 SKView 的内容 就像 iOS 7 中的控制中心面板一样 这是所需的外观 我实际得到的是 事实上
  • 动画后移动 Android 视图点击边界

    我在RelativeLayout中有两个视图 它们都填满了屏幕 所以视图B位于视图A的顶部 我还定义了一个动画 可以将视图B部分移出屏幕以在下面显示视图A 动画工作正常 但我遇到了视图边界不随视图移动的经典问题 因此我用来触发动画的按钮 位
  • 为什么 ko.mapping.fromJS 一个可以工作,而另一个却不能?

    我有一个名为 Foo 的类 并且 Foo 包含模型 我正在 Foo 上执行 ko applyBinding Foo 类有一个从服务器检索 JSON 的函数 然后我这样做 self Model ko mapping fromJS result
  • 使用 groovy 解析 Jenkin 的 shell 脚本中的 JSON 对象

    假设我有一个 JSON 如下 id 1 0 0 6 version 1 0 0 build 6 tag android v1 0 0 6 commitHash 5a78c4665xxxxxxxxxxe1b62c682f84 dateCrea
  • 在具有两个(或多个)片段的单个活动上实现 MVP

    我正在开发一个显示列表的小型应用程序 当单击某个项目时 它会打开一个包含项目详细信息的辅助屏幕 我想实现 MVP 作为我这个应用程序的架构 并且当我有能力时我一直在努力弄清楚如何做到这一点具有 2 个片段的单个 Activity 出现了一些
  • Qt - 几个 QTextBlock 内联

    是否可以将 QTextDocument 中的多个 QTextBlock 排列在一根水平线上 我需要知道单击了哪个文本块 并且 QTextBlock 很好用 因为它的方法 setUserState int 可以用来保存特定块的 id 有更好的
  • Google 端点 - Android GoogleAuthIOException Tic Tac Toe - 删除了 clientIds

    我下载了 Google Endpoints Tic Tac Toe 示例 Java 中的服务器代码 为了快速运行它 我从 API 定义中删除了 clientId 这样我就可以快速看到它在 API Explorer 中运行 Api name
  • ASP.Net MVC 3 中的远程验证:如何在操作方法中使用AdditionalFields

    我一直在使用新的 ASP Net MVC 3 RemoteAttribute 将远程调用发送到具有单个参数的操作方法 现在我想使用AdditionalFields 属性传入第二个参数 Remote IsEmailAvailable User
  • 在 Objective C 中将十六进制转换为 base64?

    我使用以下函数创建了字符串的 SHA256 编码 const char s 123456 cStringUsingEncoding NSASCIIStringEncoding NSData keyData NSData dataWithBy
  • 如何同时导入同一个python模块的两个版本?

    假设我有两个版本的 python 包 比如 lib 一个在文件夹里 version1 lib另一个是在 version2 lib 我试图通过这样做在一个会话中加载这两个包 sys path insert 0 version1 import
  • $model->get() 上 Eloquent 模型时间戳的时区错误,但使用 print_r() 正确

    对于任何使用 Laravel 7 的 Eloquent 模型来说 我都遇到了一些非常奇怪的事情 P S 我已经运行了我所做的每一个测试 php artisan optimize clear 我不知道我在这里错过了什么 我不会发布任何代码 因
  • Python相当于node.js的npm链接使用本地开发版本的要求?

    在 Node js 中 我习惯使用npm link让项目使用依赖项的自定义版本 来自节点文档 First npm link在包文件夹中将创建一个全局安装的符号链接prefix package name到当前文件夹 接下来 在其他一些地方 n
  • 获取代数项的系数

    给定代数项的输入 我试图获取变量的系数 输入中唯一的运算符是 并且只有一个变量 例子 2x 2 3x 4 gt 2 3 4 3 x gt 1 3 x 2 x gt 1 1 0 x x 3 gt 1 0 1 0 输入无效 2x 2 2x 2
  • 在 Windows 上的 python 2.7.8 上安装 pip

    我正在尝试安装 python 2 7 8 的模块 pip 即 arcGIS 为您安装的模块 我正在使用安装 pip 的引导方法 当我运行时遇到错误get pip py使用命令提示符 我收到以下错误 Warning from warnings
  • 如何复制特征矩阵

    我有两个Eigen MatrixXd他们总是有一排 输入矩阵是A我想将这个矩阵复制到另一个矩阵中B 但矩阵之间的列数可以不同 下面是一个例子 A 0 5 我需要创建一个B1行4列的矩阵 因此它是 B 0 5 0 5 0 5 0 5 But
  • Angular 2:视图未在数组推送时更新

    我有两个子组件 他们正在共享我使用 http get subscribe 方法加载的 json 文件中的数据 由于某种原因 当我将数据推入数组时 它不会在视图中更新 但它在控制台中显示了更新后的数组 应用程序组件从服务加载数据 this d
  • 海龟图形颜色检测

    有什么方法可以检测python中乌龟站立的颜色吗 例如 如果乌龟在黑色空间上 他就会向前移动 快速扫了一眼turtle文档 不 没有办法检测颜色 您可能应该记录迄今为止绘制的空间 每当绘制新空间时将其添加到集合中 那么 当你想知道一个旧空间
  • InvalidArgumentError:loc 处需要可广播形状(未知)

    背景 我对 Python 和机器学习完全陌生 我只是尝试根据在互联网上找到的代码建立一个 UNet 并希望将其适应我正在处理的情况 当试图 fit将UNet训练数据 我收到以下错误 InvalidArgumentError required