张量流(使用 Keras)中出现“InvalidArgumentError:形状不兼容:[10,2] vs. [10]”的原因是什么?

2024-02-02

我正在尝试使用 CNN 来使用 Tensorflow 和 Keras 进行对象检测。我对此相当陌生,所以我使用教程作为指南,但有我自己的设置和其他一些东西。我得到的错误是 Tensorflow 的形状与 [x,2] 与 [x] 不兼容,其中 x 是我拥有的任意数量的训练图像,2 是类的数量。我使用少量图像只是为了测试,但我很确定这不是问题?

我尝试了不同倍数的训练图像,但没有成功,并且我查看了 model.summary() 以查看模型是否完全按照我想要的方式布局。另外,我还打印了训练图像的形状及其标签,它们看起来是正确的。

图像大小为 28 x 28 像素,平面大小为 784,完整形状为 (28,28,1),1 是通道数(灰度)。我只有两个类,总共只有 10 张训练图像(如果这被认为是问题所在,我可以得到更多)。

model = Sequential()

model.add(InputLayer(input_shape=(img_size_flat,)))

model.add(Reshape(img_shape_full))

model.add(Conv2D(kernel_size=5, strides=1, filters=16, padding='same',
                 activation='relu', name='layer_conv1'))
model.add(MaxPooling2D(pool_size=2, strides=2))

model.add(Conv2D(kernel_size=5, strides=1, filters=36, padding='same',
                 activation='relu', name='layer_conv2'))
model.add(MaxPooling2D(pool_size=2, strides=2))

model.add(Flatten())

model.add(Dense(128, activation='relu'))

model.add(Dense(num_classes, activation='softmax'))

from tensorflow.python.keras.optimizers import Adam
optimizer = Adam(lr=1e-3)

model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

from tensorflow.python.keras.utils import to_categorical
model.fit(x=data.train,
    y=to_categorical(data.train_labels),
    batch_size=128, epochs=1)

我在标签上使用 to_categorical() 只是因为它们以某种方式转换为整数。我检查了它们是否保留了正确的值等。

我打印了模型摘要来检查布局:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
layer_conv1 (Conv2D)         (None, 28, 28, 16)        416       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 16)        0         
_________________________________________________________________
layer_conv2 (Conv2D)         (None, 14, 14, 36)        14436     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 36)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1764)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               225920    
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 258       
=================================================================
Total params: 241,030
Trainable params: 241,030
Non-trainable params: 0
_________________________________________________________________
None

我打印了 numpy 数据的大小:

print(data.train.shape)
print(data.train_labels.shape)

打印

(10, 784) #This is the shape of the images
(10, 2) #This is the shape of the labels

Error:

2019-04-08 10:46:40.239226: I tensorflow/stream_executor/dso_loader.cc:152] successfully opened CUDA library cublas64_100.dll locally
Traceback (most recent call last):
  File "C:/Users/bunja/Dev/testCellDet/project/venv/main.py", line 182, in <module>
    batch_size=128, epochs=1)
  File "C:\Users\bunja\Miniconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 880, in fit
    validation_steps=validation_steps)
  File "C:\Users\bunja\Miniconda3\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py", line 329, in model_iteration
    batch_outs = f(ins_batch)
  File "C:\Users\bunja\Miniconda3\lib\site-packages\tensorflow\python\keras\backend.py", line 3076, in __call__
    run_metadata=self.run_metadata)
  File "C:\Users\bunja\Miniconda3\lib\site-packages\tensorflow\python\client\session.py", line 1439, in __call__
    run_metadata_ptr)
  File "C:\Users\bunja\Miniconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [10,2] vs. [10]
     [[{{node metrics/acc/Equal}}]]
     [[{{node loss/mul}}]]

可以看出,摘要显示dense_1 的输出形状为(None, 2)。这是我遇到问题的地方吗,因为我有不兼容形状的错误:[x,2] vs. [x]?我检查了我最初用来学习这些东西的教程,发现没有重大差异。我对此仍然很陌生,所以可能有点小,而且我可能会丢失一些信息,所以如果您有任何疑问,请询问。谢谢你!!!!!

额外信息:

GPU:GeForce GTX 1080 主要:6 次要:1 内存时钟频率(GHz):1.7335

张量流版本:1.13.1

Python版本:Python 3.7.3

以下是对 to_categorical 形状进行注释的代码:

print(data.train_labels.shape)
print()
print(to_categorical(data.train_labels).shape)

Output:

(10, 2)

(10, 2, 2)

我有一种感觉这可能是我错误的根源?但我不知道如何解决它......


to_categorical通常当你有列表格式的标签并且需要执行时使用one-hot编码以便将其转换为正确的形状,以便在训练期间将其提供给模型。

但就您而言,您的标签已经与您在模型中定义的形状相同,因此one-hot编码不是必需的。

您可以查看None as batch_size这将使您更清楚地了解数据如何从输入转换为输出。

谢谢!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

张量流(使用 Keras)中出现“InvalidArgumentError:形状不兼容:[10,2] vs. [10]”的原因是什么? 的相关文章

随机推荐

  • Laravel / Eloquent 模型属性可见性

    以前我使用的 ORM 将数据库列直接映射到类属性 这允许您特定的属性可见性 就像您通常限制对某些属性的访问一样 密码 使用 Eloquent 我似乎无法复制这一点 因为数据库列映射到不包含可见性的内部属性数组 我的愿望是将用户密码的访问范围
  • 在 Pandas DataFrame 中的字符串内漂亮地打印换行符

    我有一个 Pandas DataFrame 其中一列包含字符串元素 而这些字符串元素包含我想按字面打印的新行 但它们只是表现为 n在输出中 也就是说 我想打印这个 pos bidder 0 1 1 2 2 3 lt alice lt bob
  • Google API/获取目录联系人

    我需要从谷歌企业目录列表中获取联系人 电话列表 我尝试过 Google Contacts api 它对 我的联系人 下的所有联系人都可以正常工作 但不允许显示 目录 联系人 我有什么用途 如何访问这些联系人 公司联系人 要将用户添加到全局地
  • 跨活动访问领域数据库

    我有 3 项不同的活动 1 扩展了此活动中配置的应用程序和领域 2 数据从第二个活动添加到领域 3 数据将在第三个活动中显示 我无法完成第三部分 我无法在第三个活动中获取 Realm 实例 以下是应用程序 我提到的第一个活动 Overrid
  • Azure 物联网中心反馈接收器 ReceiveAsync 非常慢(15 秒)高延迟

    如果我通过 IoT 中心发送消息 Cloud 2 设备 var serviceMessage new Message Encoding ASCII GetBytes Hello Device serviceMessage Ack Deliv
  • 帮助创建带有弯曲标题部分的 HTML 页面

    我想知道 创建一个顶部标题部分看起来是斜角而不是直角的网页的最佳方法是什么 使用 html css 和图形 请参阅下图作为示例 我不确定如何使用图像 以便它们根据不同的浏览器大小 分辨率扩展 收缩 有人能给我一些帮助吗 或者也许给我指出一个
  • 如何知道我的 Android 设备上是否存在传感器?

    我想知道我的 Android 设备上是否存在传感器 例如加速度计 我正在处理 SensorManager 类 这是我正在使用的代码 sensorMgr SensorManager getSystemService SENSOR SERVIC
  • 由于命名空间为空,Python XPath lxml 无法读取 SVG 路径元素?

    我有一个 SVG Xml 文件 我想从中选择一些元素 为了 MCRE 我已将文件缩减为以下内容
  • pip install 语法允许不安全

    我试着跑 pip install upgrade allow insecure setuptools 但似乎不起作用 我的语法错误吗 这是在 ubuntu 13 10 上 我需要 allow insecure 因为我无法获得 公司代理 SS
  • 移动 Highstock 导航器位置

    是否可以将 Highstock 图表导航器从图表底部移至顶部 是的 这是可能的 请看示例 http jsfiddle net jBUGN http jsfiddle net jBUGN navigator top 40
  • php.ini 不允许我禁用_functions

    我把它放在 php ini 文件中 disable functions popen exec system passthru proc open shell exec show source phpinfo 但我仍然可以调用它们 测试了 e
  • 从 Excel 导入到数据表,跳过最后一列值

    我正在尝试将数据从 Excel 文件导入到数据表 但问题是最后一列值被跳过 其余列的值是完美的 我的 Excel 文件包含以下内容 导入后数据表中的数据如下 我的代码如下 Dim connExcel As New OleDbConnecti
  • 如何加快从栅格中提取缓冲区中土地覆盖类型比例的速度?

    我想提取 10 公里缓冲区中大约 30 000 个 SpatialLines 类对象的空间数据 并计算缓冲线周围每种土地覆盖类型的比例 我第一次使用这个功能crop裁剪我的光栅 然后 我使用了该功能extract 包栅格 计算 10 种土地
  • 我想要为我的第一个可可应用程序提供一个漂亮的自定义窗口[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我是一名图形设计师 转型为网页设计师 转型为网页开发人员 目前正在尝试转型为 Mac 开发人员 我的第一个应用程序即将完成 这是一个非常非
  • 用范围替换连续数字

    如何在单元格中查找连续数字 并将其替换为范围 例如 改变 1 3 5 15 16 17 25 28 29 31 to 1 3 5 15 17 25 28 29 31 这些数字已经排序 即按升序排列 Thanks 我想看看一个有趣的问题 无需
  • 如何在 XAML 中使用枚举类型?

    我在学习WPF时遇到了以下问题 我在 XAML 之外的另一个命名空间中有一个枚举类型 public enum NodeType Type SYSTEM 1 System Type DB 2 Database Type ROOT 512 Ro
  • mockito - 模拟接口 - 抛出 NullPointerException

    我在模拟后也收到空指针异常 请找到我的项目结构 this is the pet interface public interface Pet An implementation of Pet public class Dog extends
  • 在java中读取Doc或Docx文件的问题[关闭]

    这个问题不太可能对任何未来的访客有帮助 它只与一个较小的地理区域 一个特定的时间点或一个非常狭窄的情况相关 通常不适用于全世界的互联网受众 为了帮助使这个问题更广泛地适用 访问帮助中心 help reopen questions 我在阅读时
  • 相关模型的Build Model.query

    我需要构建一个查询 列出所有用户 最好的朋友和朋友总数 该列表必须按用户拥有的好友总数排序 我希望生成的查询具有以下结构 users id users userName users userEmail users userPhone tot
  • 张量流(使用 Keras)中出现“InvalidArgumentError:形状不兼容:[10,2] vs. [10]”的原因是什么?

    我正在尝试使用 CNN 来使用 Tensorflow 和 Keras 进行对象检测 我对此相当陌生 所以我使用教程作为指南 但有我自己的设置和其他一些东西 我得到的错误是 Tensorflow 的形状与 x 2 与 x 不兼容 其中 x 是