迁移学习之resnet50——解决过拟合及验证集上准确率上不去问题

2023-11-15

keras之resnet50迁移学习做分类

问题1描述:迁移学习用resnet50做分类,验证集上的准确率一直是一个大问题,有时候稳定在一个低的准确率,我的一次是一直在75%上下波动。

问题2描述:resnet50迁移学习,训练集上的准确率一直在攀升,但验证集上的准确率一直上不去,一定程度上出现了过拟合现象,但加很多的BN、dropout、l1和l2正则化手段都不能有效的解决问题。

***问题1答案:***这个问题网络设计没有问题的话,一般出现在训练的数据量上,数据量偏少就会出现验证集上准确率一直很低。

问题2答案:

2020/10/12更新

根本原因可参考:https://github.com/keras-team/keras/pull/9965
解决方案:在这里插入图片描述其实就是加一个参数:layers=tf.keras.layers

下面的采坑可不看

--------------------------------------------

-----------------------------------------------

先来看一下一般resnet50迁移学习的网络设计:

 base_model = ResNet50(weights='imagenet', include_top=False,
                              input_shape=(image_size, image_size, 3), )
x = base_model.output
x = GlobalAveragePooling2D(name='average_pool')(x)
x = Flatten(name='flatten')(x)

这是一个典型的残差网络做迁移学习的套路,很多人都是这么做的,但真的有很高的准确率吗?反正我试了很多次,一直出现验证集上的准确率很低上不去的问题。不管怎么用防止过拟合的手段,效果都不是很好。后来研究BN层看了几篇相关的论文,发现包括resnet,inception等模型都包含了Batch Normalization层,如果使用pretrained参数进行finetune,这些BN层一般情况下使用了K.learning_phase的值作为is_training参数的默认值,因此导致训练的时候使用的一直是mini batch的平均值 ,由于trainable在finetune时候一般设置为false了导致整个layer 不会update,因此moving_mean\variance根本没有更新。导致你在test时用的moving_mean\variance全是imagenet数据集上的值。
参考链接:https://github.com/keras-team/keras/pull/9965
修正后的代码:

K.set_learning_phase(0)
base_model = ResNet50(weights='imagenet', include_top=False,
                              input_shape=(image_size, image_size, 3), )
 K.set_learning_phase(1)
x = base_model.output
x = GlobalAveragePooling2D(name='average_pool')(x)
x = Flatten(name='flatten')(x)
x = Dense(2048, activation='relu', kernel_regularizer=regularizers.l2(0.0001), )(x)
x = BatchNormalization()(x)
x = Dense(1024, activation='relu', kernel_regularizer=regularizers.l2(0.0001))(x)
x = BatchNormalization(name='bn_fc_01')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

这样更正后可以使得准确率有一个良好的上升,但优化的不够彻底,再有一些小的技巧可以让你的验证集上的准确率有更好的提升。这是一个试验,没有理论支持,如果上述的方案不能满足你对准确率的要求,不妨试试下面这个方案:

K.set_learning_phase(0)
Inp = Input((224, 224, 3))
base_model = ResNet50(weights='imagenet', include_top=False,
                              input_shape=(image_size, image_size, 3), )
 K.set_learning_phase(1)
 x = base_model(Inp)
x = GlobalAveragePooling2D(name='average_pool')(x)
x = Flatten(name='flatten')(x)
...
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=Inp, outputs=predictions)

可看出区别了吗?下面代码将输入变更了,不用resnet的输出做输入,直接定义自己的输入,我有测试过,这样做确实对准确率有一定的提升。基本上resnet50迁移学习的坑就踩到这里。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

迁移学习之resnet50——解决过拟合及验证集上准确率上不去问题 的相关文章

随机推荐

  • js 实现多重罗盘转动

    引子 这几天一直在忙一个可滑动的转盘的demo 网上也有类似的例子 但是根据老板的需求来改他们的代码 还不如重新写个完全符合需求的插件 想法很美好 但是新手上路 效果链接文末 需求 image 这个demo给的非常简单 能转动的地方有三处
  • 方差、标准差、均方差、均方误差区别总结

    转载 http blog csdn net Leyvi Hsing article details 54022612 一 百度百科上方差是这样定义的 variance 是在概率论和统计方差衡量随机变量或一组数据时离散程度的度量 概率论中方差
  • Flutter 重写原生App -- 02 基础知识 一路踩坑

    Pubspec Assist 插件 快速添加 pubspec yaml 的依赖 device info 0 4 0 3 可查看当前 链接的设备是 Ios Android 并且获得设备信息 Dart 语法 https dart dev gui
  • SVN配置

    1 SVN插件下载地址 http subclipse tigris org update 1 4 x http subclipse tigris org servlets ProjectDocumentList expandFolder 2
  • 【特征工程】特征选择与特征学习

    特征选择与特征学习 在机器学习的具体实践任务中 选择一组具有代表性的特征用于构建模型是非常重要的问题 特征选择通常选择与类别相关性强 且特征彼此间相关性弱的特征子集 具体特征选择算法通过定义合适的子集评价函数来体现 在现实世界中 数据通常是
  • Python实现ACO蚁群优化算法优化LightGBM分类模型(LGBMClassifier算法)项目实战

    说明 这是一个机器学习实战项目 附带数据 代码 文档 视频讲解 如需数据 代码 文档 视频讲解可以直接到文章最后获取 1 项目背景 蚁群优化算法 Ant Colony Optimization ACO 是一种源于大自然生物世界的新的仿生进化
  • 背景差分法《python图像处理篇》

    引言 背景差分常用于运动目标检测 是一种动态检测的方法 即观察两帧图像间的差距 哪个物体存在相对运动 其基本原理就是将两幅图像做减法 只不过这里的两幅图像分为输入图像和背景图像 此方法对于动态常见特别敏感 例如监控环境下的下雪 刮风时的树叶
  • ❤ jeecgboot 使用

    jeecgboot 使用 JDictSelectTag 字典下拉去掉请选择 JDictSelectTag
  • 最近大火的ChatGPT和RPA机器人相结合会带来什么前景?

    ChatGPT是由人工智能技术驱动的自然语言处理工具 它可以通过理解和学习人类语言进行对话 并根据聊天的上下文进行互动 真正像人类一样进行聊天和交流 甚至完成撰写电子邮件 视频脚本 文案 翻译 代码 写论文等任务 ChatGPT和RPA都是
  • line-height行高的解析

  • golang 框架_Go Web 框架 Gin 实践9—将Golang应用部署到Docker

    Go语言中文网 致力于每日分享编码知识 欢迎关注我 每天一起进步 项目地址 https github com EDDYCJY go gin example 注 开始前你需要安装好 docker 配好镜像源 本章节源码在 f 20180324
  • 同花顺某v参数详解

    声明 本文章中所有内容仅供学习交流 抓包内容 敏感网址 数据接口均已做脱敏处理 严禁用于商业用途和非法用途 否则由此产生的一切后果均与作者无关 若有侵权 请联系我立即删除 目标站点 aHR0cDovL3EuMTBqcWthLmNvbS5jb
  • 自定义控件中 wrap_content 属性无效的分析解决

    问题 在自定义一个类似锁屏页面时间日期样式的控件 继承 View 的时候 发现在 xml 中使用 wrap content 属性相当于使用了 match parent 属性 原因分析 进入View的源码 可以看到 onMeasure 的方法
  • jdbc连接数据库(MySQL 8.0.19)url设置

    本文只针对下述版本的url设置问题 我的JDK版本是11 0 1 MySQL版本8 0 19 MySQL的8系列版本应该都可以 一般连接失败的原因是url没设置好 这里我所设置的url亲测有效 String urlString jdbc m
  • Kali之渗透攻击

    渗透攻击是指黑客为了获得非法利益 通过各种手段进入网络系统 计算机系统中 在未经授权的情况下获取信息 利用漏洞控制系统和执行越权操作的一种行为 其目的在于获取非法利益 破坏或者窃取关键数据 以及对网络系统进行控制 在学习渗透攻击这一知识点过
  • world特殊符号

    world特殊符号 论文需要 和圆里面一个乘号 论文需要 和圆里面一个乘号 1 首先打开word文档 找到要 插入符号 的地方 2 选择插入功能下面的 符号按钮 3 选择符号下面的 其他符号 4 将字体选为 symbol 这个字体 5 在这
  • Hive 一文读懂

    Hive 简介 1 1 什么是Hive 1 hive简介 Hive 由Facebook开源用于解决海量结构化日志的数据统计 Hive是基于Hadoop的一个数据仓库工具 可以将结构化的数据文件映射为一张表 并提供类SQL查询功能 2 Hiv
  • opencv边缘检测-拉普拉斯算子

    sobel算子一文说了 索贝尔算子是模拟一阶求导 导数越大的地方说明变换越剧烈 越有可能是边缘 那如果继续对f t 求导呢 可以发现 边缘处 的二阶导数 0 我们可以利用这一特性去寻找图像的边缘 注意有一个问题 二阶求导为0的位置也可能是无
  • python报错AttributeError module ‘scipy.misc‘ has no attribute ‘imresize‘和 ‘imread‘

    python报错AttributeError module scipy misc has no attribute imresize 和 imread 报错原因 scipy版本过高 解决方案 降低scipy版本 如下 pip install
  • 迁移学习之resnet50——解决过拟合及验证集上准确率上不去问题

    keras之resnet50迁移学习做分类 问题1描述 迁移学习用resnet50做分类 验证集上的准确率一直是一个大问题 有时候稳定在一个低的准确率 我的一次是一直在75 上下波动 问题2描述 resnet50迁移学习 训练集上的准确率一