模型返回错误 - ValueError:logits 和标签必须具有相同的形状 ((None, 18) vs (None, 1))

2024-06-23

我正在使用基于 keras 的多标签分类器。我创建了一个加载训练和测试数据的函数,然后在函数本身内处理/拆分 X/Y。我在运行模型时遇到错误,但不太确定其含义:

这是我的代码:

def KerasClassifer(df_train, df_test):
  X_train = df_train[columnType].copy()
  y_train = df_train[variableToPredict].copy()
  labels = y_train.unique()
  print(X_train.shape[1])
  #using keras to do classification
  from tensorflow import keras
  from tensorflow.keras.models import Sequential
  from tensorflow.keras.layers import Dense, Dropout, Activation
  from tensorflow.keras.optimizers import SGD

  model = Sequential()
  model.add(Dense(5000, activation='relu', input_dim=X_train.shape[1]))
  model.add(Dropout(0.1))
  model.add(Dense(600, activation='relu'))
  model.add(Dropout(0.1))
  model.add(Dense(len(labels), activation='sigmoid'))

  sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
  model.compile(loss='binary_crossentropy',
                optimizer=sgd)

  model.fit(X_train, y_train, epochs=5, batch_size=2000)

  preds = model.predict(X_test)
  preds[preds>=0.5] = 1
  preds[preds<0.5] = 0

  score = model.evaluate(X_test, y_test, batch_size=2000)
  score

以下是我的数据的属性(如果有帮助的话):

x train shape  (392436, 109)
y train shape  (392436,)
len of y labels 18

如何修复代码以避免此错误?


如果你有 18 个类别,形状为y_train应该(392436, 18)。您可以使用tf.one_hot为了那个原因:

import tensorflow as tf

y_train = tf.one_hot(y_train, depth=len(labels))

如果您从一列中获取值,我怀疑这不是“多标签”,而是多类。一个样本真的可以属于多个类别吗?如果没有,您还需要更改其他一些内容。例如,您需要 softmax 激活:

model.add(Dense(len(labels), activation='softmax'))

还有分类交叉熵损失:

model.compile(loss='categorical_crossentropy', optimizer=sgd)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

模型返回错误 - ValueError:logits 和标签必须具有相同的形状 ((None, 18) vs (None, 1)) 的相关文章

随机推荐

  • 重置 JDBC Kafka 连接器以从头开始提取行?

    Kafka 连接器可以利用主键和时间戳来确定需要处理哪些行 我正在寻找一种重置连接器的方法 以便它从一开始就进行处理 因为要求是在分布式模式下运行 所以最简单的做法是将连接器名称更新为新值 这将提示在 connect offsets 主题中
  • 如何克隆 Angular UI 树中的节点?

    如何克隆 Angular UI 树中所有子节点的节点 现在我使用事件点击 ng click newSubItem this where newSubItem是函数 scope newSubItem function scope var no
  • Xcode 9:如何安装 ios 10 sdk

    鉴于目前 Xcode 9 是测试版 而今天的主要兴趣是了解 iOS 11 这个问题无疑很奇怪 在 Xcode 9 beta 中工作时 有没有办法将 iOS 10 作为基础 sdk Apple 是否需要像在 Xc8 中为以前的操作系统打包 X
  • 如何同时执行 typescript watch 和运行服务器?

    我正在用nodejs开发我的项目 我发现如果我需要编码和测试api 我会运行两个控制台 一个是执行typescript watch 另一个是执行server 我觉得这样太麻烦了我发现 github 上的其他开发人员已经编写了脚本packag
  • 使用神经网络包保存神经网络图时遇到问题 - R

    我正在使用neuralnetR 中的包 但是将绘图保存到磁盘时遇到问题 data iris attach iris library neuralnet nn lt neuralnet as numeric Species Sepal Len
  • flex:1 和 flex-grow:1 的区别

    In mdn https developer mozilla org en US docs Web CSS flex flex 1 意思是一样的 flex grow 1 但实际上它在浏览器中的显示有所不同 你可以在这个尝试一下jsFiddl
  • MVC 2 中不明确的操作方法

    我在 MVC 2 中遇到一些不明确的操作方法问题 我尝试实现此处找到的解决方案 ASP NET MVC 不明确的操作方法 https stackoverflow com questions 1045316 asp net mvc ambig
  • 如何使用 jQuery 获取文本框中输入的文本长度?

    如何使用 jQuery 获取文本框中输入的文本长度 var myLength myTextbox val length
  • 将 __m256 值设置为所有 1 位的最快方法

    如何将值中的所有位设置为 1 m256价值 使用 AVX 或 AVX2 内在函数 要获得全零 您可以使用 mm256 setzero si256 为了获得所有这些 我目前正在使用 mm256 set1 epi64x 1 但我怀疑这比全零情况
  • JSR-310 - 解析可变长度的秒分数

    有没有办法创建 JSR 310 格式化程序 能够解析以下具有可变长度秒分数的日期 时间 2015 05 07 13 20 22 276052 or 2015 05 07 13 20 22 276 示例代码 DateTimeFormatter
  • 生成 k 个成对独立的哈希函数

    我正在尝试实施一个计数最小草图 http en wikipedia org wiki Count Min sketchScala中的算法 所以我需要生成k个成对独立的哈希函数 这是一个比我以前编写过的任何东西都低的级别 除了算法类之外 我对
  • 如何创建关键点来计算 SIFT?

    我正在使用 OpenCV Python 我已经使用确定角点cv2 cornerHarris 输出的类型为dst 我需要计算角点的 SIFT 特征 输入到sift compute 必须是以下类型KeyPoint 我不知道如何使用cv2 Key
  • MY SQL - 错误代码:1010。删除数据库时出错(无法 rmdir;errno:13)

    当尝试删除 MySQL 中的数据库时 DROP DATABASE IF EXISTS temporarydata 我收到以下错误 Error Code 1010 Error dropping database can t rmdir tem
  • 如何通过map[string]interface{}递归迭代

    我遇到了一个问题 如何在附加条件下递归地迭代 map string interface 1 如果一个值是一个映射 递归调用该方法 2 如果一个值是一个数组 调用数组的方法 3 如果一个值不是一个映射 处理它 现在当方法尝试执行时doc th
  • 哪个版本的 SQLite 添加了对 Lead() 和 lag() 函数的支持?

    我正在尝试使用以下查询作为我的 Android SQLite 数据库中更大查询的一部分 但在我看来 我收到的错误表明 Android SQLite 尚不支持 Lead 函数 我尝试查看 sqlite org 上的发布日志 但无法找到何时添加
  • mysqli_real_escape_string - 100% 安全的示例

    我知道已经有人就这个话题提出了很多问题 我也知道要走的路是准备好的陈述 然而 我仍然没有完全理解以下是否或如何可能成为安全问题 mysqli new mysqli localhost root myDatabase mysqli gt se
  • swift 中的延迟函数[重复]

    这个问题在这里已经有答案了 我没有可供采样的代码或任何东西 因为我不知道该怎么做 但是有人可以告诉我如何使用 swift 将函数延迟一定的时间吗 您可以使用 GCD 在示例中延迟 10 秒 Swift 2 let triggerTime I
  • 使用身份验证令牌的 Axios 请求有时会在 Safari 中失败

    我正在开发一个使用 axios 0 19 0 的 React 16 9 0 单页应用程序 axios 请求使用令牌身份验证来访问运行 django rest framework 3 6 4 和 django cors headers 3 1
  • 在 iPhone 5s 或 64 位模拟器上测试 32 位 iOS 应用程序

    我有一个使用第三方库的应用程序 64 位版本的库存在错误 因此我不得不恢复到 32 位版本的框架 我想在 5s 上测试这个版本 但从 XCode 中 它将尝试在 64 位中构建 并且由于这个 32 位框架 构建将失败 我需要发布一个版本 但
  • 模型返回错误 - ValueError:logits 和标签必须具有相同的形状 ((None, 18) vs (None, 1))

    我正在使用基于 keras 的多标签分类器 我创建了一个加载训练和测试数据的函数 然后在函数本身内处理 拆分 X Y 我在运行模型时遇到错误 但不太确定其含义 这是我的代码 def KerasClassifer df train df te