Tensorflow、Keras:在多类分类中,准确率很高,但大多数类别的精度、召回率和 f1 分数为零

2023-12-29

一般说明:我的代码工作正常,但结果是有线的。我不知道问题出在

  • 网络结构,
  • 或者我向网络提供数据的方式,
  • 或其他任何东西。

我为这个错误苦苦挣扎了几个星期,到目前为止我已经改变了损失函数、优化器、数据生成器等,但我无法解决它。我很感激任何帮助。 如果以下信息还不够,请告诉我。

研究领域:我正在使用张量流、keras 进行多类分类。该数据集有 36 个二元人类属性。我使用了resnet50,然后对于身体的每个部分(头部,上半身,下半身,鞋子,配件),我都在网络中添加了一个单独的分支。该网络有 1 个输入图像,带有 36 个标签和 36 个输出节点(具有 sigmoid 激活的 36 个定义层)。

Problem:问题是 keras 报告的准确性很高,但大多数输出​​的 f1-score 非常低或为零(即使我在编译网络时使用 f1-score 作为指标,用于验证的 f1-socre 是很坏)。

a训练结束后,当我在预测模式下使用网络时,对于某些类,它始终返回一/零。这意味着网络无法学习(即使我使用加权损失函数或焦点损失函数。)

为什么奇怪呢?因为,即使在第一个 epoch 之后,最先进的方法也会报告较高的 f1 分数(例如https://github.com/chufengt/iccv19_attribute https://github.com/chufengt/iccv19_attribute,我已经在我的电脑上运行它并在一个时期后获得了良好的结果)。

部分代码:

        print("setup model ...")
        input_image = KL.Input(args.img_input_shape, name= "input_1")
        C1, C2, C3, C4, C5 = resnet_graph(input_image, architecture="resnet50", stage5=False, train_bn=True)
        output_layers = merged_model (input_features=C4)
        model = Model(inputs=input_image, outputs=output_layers, name='SoftBiometrics_Model')

...

        print("model compiling ...")
        OPTIM = optimizers.Adadelta(lr=args.learning_rate, rho=0.95)
        model.compile(optimizer=OPTIM, loss=binary_focal_loss(alpha=.25, gamma=2), metrics=['acc',get_f1])
        plot_model(model, to_file='model.png')

...

        img_datagen = ImageDataGenerator(rotation_range=6, width_shift_range=0.03, height_shift_range=0.03, brightness_range=[0.85,1.15], shear_range=0.06, zoom_range=0.09, horizontal_flip=True, preprocessing_function=preprocess_input_resnet, rescale=1/255.)
        img_datagen_test = ImageDataGenerator(preprocessing_function=preprocess_input_resnet, rescale=1/255.)

        def multiple_outputs(generator, dataframe, batch_size, x_col):
          Gen = generator.flow_from_dataframe(dataframe=dataframe,
                                               directory=None,
                                               x_col = x_col,
                                               y_col = args.Categories,
                                               target_size = (args.img_input_shape[0],args.img_input_shape[1]),
                                               class_mode = "multi_output",
                                               classes=None,
                                               batch_size = batch_size,
                                               shuffle = True)
          while True:
            gnext = Gen.next()
            # return image batch and 36 sets of lables
            labels = gnext[1]
            output_dict = {"{}_output".format(Category): np.array(labels[index]) for index, Category in enumerate(args.Categories)}
            yield {'input_1':gnext[0]}, output_dict

    trainGen = multiple_outputs (generator = img_datagen, dataframe=Train_df_img, batch_size=args.BATCH_SIZE, x_col="Train_Filenames")
    testGen = multiple_outputs (generator = img_datagen_test, dataframe=Test_df_img, batch_size=args.BATCH_SIZE, x_col="Test_Filenames")

    STEP_SIZE_TRAIN = len(Train_df_img["Train_Filenames"]) // args.BATCH_SIZE
    STEP_SIZE_VALID = len(Test_df_img["Test_Filenames"]) // args.BATCH_SIZE

    ...

    print("Fitting the model to the data ...")
            history = model.fit_generator(generator=trainGen,
                                         epochs=args.Number_of_epochs,
                                         steps_per_epoch=STEP_SIZE_TRAIN,
                                         validation_data=testGen,
                                         validation_steps=STEP_SIZE_VALID,
                                         callbacks= [chekpont],
                                         verbose=1)

您有可能将二进制 f1-score 传递给compile功能。这应该可以解决问题 -

pip install tensorflow-addons

...

import tensorflow_addons as tfa 

f1 = tfa.metrics.F1Score(36,'micro' or 'macro')

model.compile(...,metrics=[f1])

您可以阅读有关如何计算 f1-micro 和 f1-macro 以及哪些内容有用的更多信息here https://towardsdatascience.com/a-tale-of-two-macro-f1s-8811ddcf8f04.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow、Keras:在多类分类中,准确率很高,但大多数类别的精度、召回率和 f1 分数为零 的相关文章

随机推荐

  • android studio onMapReady 未调用

    我想将地图视图集成到我的一个视图中 我已经生成了一个新的地图片段 它以不同的视角出现 并且像魅力一样发挥作用 然后 我尝试将代码集成到正常活动中 带有操作栏等 它有点有效 在屏幕上显示得很好 但 onMapReady 在那种环境中永远不会被
  • django中的自定义用户模型不允许在admin中设置密码

    我创建了一个自定义用户模型 并在我的应用程序中成功使用了该模型 问题是 在管理中 在用户编辑屏幕上 我显示当前密码哈希 而不是用于设置密码的非常有用的界面 我在 Python 2 7 上使用 Django 1 5b1 为了管理用户界面 如何
  • 如何在 Java 8 中从有限流构建无限重复流?

    我怎样才能转动有限的事物流Stream
  • 更改 ionic 2 应用程序中的 iOS 状态栏颜色

    我正在按照 ionic 2 文档设置 iOS 状态栏颜色 但它不起作用 状态栏文本是白色的 这意味着在我的白色背景上它是不可见的 我在应用程序构造函数中放入的代码是 StatusBar overlaysWebView true Status
  • 从 Access DB 发送包含动态名称附件的电子邮件

    我不知道如何让这个东西继续工作 下面的代码发送一封电子邮件 其中包含 MS Access 2010 的附件 问题是 如果它需要固定的文件名 那么当我使用每个文件末尾的日期时 我的文件名会发生变化 示例 green 12 04 2012 cs
  • 使用 AWK 中的第一个字段作为文件名

    该数据集是一个包含三列的大文件 一个部分的 ID 一些不相关的内容和一行文本 示例可能如下所示 A01 001 This is a simple test A01 002 Just for exemplary purpose A01 003
  • 将 NServiceBus 与 Asp.Net MVC 2 结合使用

    有没有办法将 NServiceBus 与 Asp Net MVC 2 一起使用 我想将请求消息从 Asp Net MVC2 应用程序发送到服务 该服务处理该消息并回复响应消息 有没有办法清楚地做到这一点 NServiceBus 仅支持注册状
  • Jquery 冲突导致错误

    从事具有多种功能的项目 例如 谷歌翻译 图像滑块 使用画廊 弹出窗口 使用阴影框 JavaScript 水平菜单栏 Now we are getting jquery conflict in it and error message suc
  • 从 Docker 容器获取 Mac 地址

    是否可以从Docker容器中获取主机的MAC地址并将其写入文本文件中 docker inspect
  • GCS - Python 下载具有目录结构的 blob

    我使用 GCS python SDK 和 google API 客户端的组合来循环启用版本的存储桶并根据元数据下载特定对象 from google cloud import storage from googleapiclient impo
  • 计算负载并避免光标

    给出下面的表结构 它表示乘客通过门磁上下车的公交路线 而且 有一个人坐在那辆公共汽车上 手里拿着一个记着点数的剪贴板 CREATE TABLE BusLoad ROUTE CHAR 4 NOT NULL StopNumber INT NOT
  • 从 Powershell 调用 AppDomain.DoCallback

    这是基于 Stack Overflow 问题 如何在新的 AppDomain 中将程序集加载为仅反射 https stackoverflow com questions 35249342 how to load an assembly as
  • 选择 Plsql 中的第二行

    假设我有下表 SomeTable id price 如何从此表中选择价格第二高的行 注意 这必须在 Pl SQL 中以与数据库无关的方式完成 是否可以在没有任何循环的情况下做到这一点 我知道这是如何使用 Oracle 结构来完成的 例如ro
  • “不要在设计中使用抽象基类;但在建模/分析中”

    虽然我在 OOAD 方面有一些经验 但我是 SOA 的新手 SOA 设计的指导原则之一是 仅使用抽象类进行建模 从设计中省略它们 抽象的使用有助于建模 分析阶段 在分析阶段 我提出了一个 BankAccount 基类 从它派生的专门类是 F
  • 将 Java 7 与官方 Google Appengine Maven 插件结合使用

    我在使用时遇到问题官方 Maven 插件 https developers google com appengine docs java tools maven以及带有 Google Appengine 的 Java 7 配置 我的项目配置
  • 优先级队列数据结构

    假设我有一个优先级队列 它按升序删除元素 并且存储在该队列中的是元素1 1 3 0 1 递增的顺序是0 then 1 then 3 但是有三个元素1s 当我打电话时remove它会首先删除0 但如果我打电话remove它会再次删除所有三个吗
  • 提高功能性能

    我正在编写一个小程序来检查以下问题的解决方案布罗卡的问题 http en wikipedia org wiki Brocard s problem或所谓的棕色数字我首先用 ruby 创建了一个草稿 class Integer def fac
  • 在 Xcode 中创建和编辑 plist 文件的步骤

    我想添加密钥对值plist 我不知道如何在 XCode 中添加 plist 文件 只是我想将这些详细信息添加到名为 的 plist 文件中 Mobile plist Apple iPhone iPod iPad Samsung Galaxy
  • Java 中可以使用 C# 风格的对象初始化吗?

    在 C 中可以这样写 MyClass obj new MyClass field1 hello field2 world field3 new MyOtherClass etc 我可以看到数组初始化可以用类似的方式完成 但是在 Java 中
  • Tensorflow、Keras:在多类分类中,准确率很高,但大多数类别的精度、召回率和 f1 分数为零

    一般说明 我的代码工作正常 但结果是有线的 我不知道问题出在 网络结构 或者我向网络提供数据的方式 或其他任何东西 我为这个错误苦苦挣扎了几个星期 到目前为止我已经改变了损失函数 优化器 数据生成器等 但我无法解决它 我很感激任何帮助 如果