预训练网络的模型微调方法

2023-11-05

  • 是什么
    神经网络需要数据来训练,从数据中获得信息,进而转化成相应的权重。这些权重能够被提取出来,迁移到其他的神经网络中。
    迁移学习:通过使用之前在大数据集上经过训练的预训练模型,我们可以直接使用相应的结构和权重,将他们应用在我们正在面对的问题上。即将预训练的模型“迁移”到我们正在应对的特定问题中。
    在选择预训练模型时需要注意,如果我们的问题与预训练模型训练情景有很大出入,那么模型所得到的的预测结果会非常不准确。举例来说,如果把一个原本用于语音识别的模型用作用户识别,那结果肯定是不理想的。
    ImageNet数据集已经被广泛用作训练集,因为它规模足够大(包括120万张图片),有助于训练普适模型。在迁移学习中,这些预训练的网络对于ImageNet以外的图片表现出了很好的泛化性能。
    微调(fine tuning)可以省去大量的计算资源和计算时间,提高计算效率,甚至提高准确率。

  • 什么时候用
    使用的数据集和预训练模型的数据集相似;
    自己搭建或使用的CNN模型正确率太低;
    数据集相似,但数据集数量少;
    计算资源少。

  • 怎么用
    数据量少,且数据高度相似: - 在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别。
    数据量少,但数据相似度低: 在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。
    数据量大,数据相似度低:此时最好根据我们自己的数据从头开始训练神经网络(Training from scatch)。
    数据量大,数据相似度高: 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。

  • 注意事项

  1. 使用较小的学习率来训练网络。由于我们预计预先训练的权重相对于随机初始化的权重已经相当不错,我们不想过快地扭曲它们太多。通常的做法是使初始学习率比用于从头开始训练(Training from scratch)的初始学习率小10倍。

  2. 如果数据集数量过少,我们进来只训练最后一层,如果数据集数量中等,冻结预训练网络的前几层的权重也是一种常见做法。这是因为前几个图层捕捉了与我们的新问题相关的通用特征,如曲线和边。我们希望保持这些权重不变。相反,我们会让网络专注于学习后续深层中特定于数据集的特征。

  • 预训练模型修剪+微调:
    1. 在已经训练好的基网络上添加自定义网络;
    2. 冻结基网络,训练自定义网络;
    3. 解冻部分基网络,联合训练解冻层和自定义网络。

注意,在联合训练解冻层和自定义网络之前,通常要先训练自定义网络,否则,随机初始化的自定义网络权重会将误差信号传到解冻层,破坏解冻层以前学到的表示,使得训练成本增大。

pytorch四种冻结层的方式:
假设模型定义如下:

class Char3SeqModel(nn.Module):
    
    def __init__(self, char_sz, n_fac, n_h):
        super().__init__()
        self.em = nn.Embedding(char_sz, n_fac)
        self.fc1 = nn.Linear(n_fac, n_h)
        self.fc2 = nn.Linear(n_h, n_h)
        self.fc3 = nn.Linear(n_h, char_sz)
        
    def forward(self, ch1, ch2, ch3):
        # do something
        out = #....
        return out

model = Char3SeqModel(10000, 50, 25)

假设需要冻结FC1

  1. 方法1:设置requires_grad为False
# 冻结
model.fc1.weight.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1)
# 
# compute loss 
# loss.backward()
# optmizer.step()

# 解冻
model.fc1.weight.requires_grad = True
optimizer.add_param_group({'params': model.fc1.parameters()})

  1. 方法2:最简单的方式是在定义optimizer的时候,不要加入你想冻结的那一层的参数。
# 冻结
optimizer = optim.Adam([{'params':[ param for name, param in model.named_parameters() if 'fc1' not in name]}], lr=0.1)
# compute loss
# loss.backward()
# optimizer.step()

# 解冻
optimizer.add_param_group({'params': model.fc1.parameters()})

  1. 方法3:将原来layer的weight缓存下来,每次反向传播之后,再将原来的weight赋值给相应的layer
fc1_old_weights = Variable(model.fc1.weight.data.clone())
# compute loss
# loss.backward()
# optimizer.step()
model.fc1.weight.data = fc1_old_weights.data

  1. 方法4:使用 torch.no_grad()
    这种方式只需要在网络定义中的forward方法中,将需要冻结的层放在使用 torch.no_grad()下。
class xxnet(nn.Module):
    def __init__():
        ....
        self.layer1 = xx
        self.layer2 = xx
        self.fc = xx

    def forward(self.x):
        with torch.no_grad():
            x = self.layer1(x)
            x = self.layer2(x)
        x = self.fc(x)
        return x

这种方式则是将layer1和layer2定义的层冻结,只训练fc层的参数。
5. 终极方法实现

作者:肥波喇齐
链接:https://www.zhihu.com/question/311095447/answer/589307812
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

from collections.abc import Iterable

def set_freeze_by_names(model, layer_names, freeze=True):
    if not isinstance(layer_names, Iterable):
        layer_names = [layer_names]
    for name, child in model.named_children():
        if name not in layer_names:
            continue
        for param in child.parameters():
            param.requires_grad = not freeze
            
def freeze_by_names(model, layer_names):
    set_freeze_by_names(model, layer_names, True)

def unfreeze_by_names(model, layer_names):
    set_freeze_by_names(model, layer_names, False)

def set_freeze_by_idxs(model, idxs, freeze=True):
    if not isinstance(idxs, Iterable):
        idxs = [idxs]
    num_child = len(list(model.children()))
    idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))
    for idx, child in enumerate(model.children()):
        if idx not in idxs:
            continue
        for param in child.parameters():
            param.requires_grad = not freeze
            
def freeze_by_idxs(model, idxs):
    set_freeze_by_idxs(model, idxs, True)

def unfreeze_by_idxs(model, idxs):
    set_freeze_by_idxs(model, idxs, False)
# 冻结第一层
freeze_by_idxs(model, 0)
# 冻结第一、二层
freeze_by_idxs(model, [0, 1])
#冻结倒数第一层
freeze_by_idxs(model, -1)
# 解冻第一层
unfreeze_by_idxs(model, 0)
# 解冻倒数第一层
unfreeze_by_idxs(model, -1)


# 冻结 em层
freeze_by_names(model, 'em')
# 冻结 fc1, fc3层
freeze_by_names(model, ('fc1', 'fc3'))
# 解冻em, fc1, fc3层
unfreeze_by_names(model, ('em', 'fc1', 'fc3'))

代码参考地址

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

预训练网络的模型微调方法 的相关文章

  • 十大图像数据标注工具大合集

    给大家推荐十大标注工具 1 常见的标注方法 人工数据标注 的好处是标注结果比较可靠 自动数据标注 一般都需要二次复核 避免程序错误 外包数据标注 很多时候会面临数据泄密与流失风险 2 人工标注工具 可以分为客户端与WEB端标注工具 推荐使用
  • CUDA 7.5 安装及配置(WIN7 64 英伟达G卡 VS2013)

    第一步 下载cuda 7 5最新版本 https developer nvidia com cuda downloads 第二步 运行安装程序 安装过程中选择自定义 第三步 安装完毕 可以看到系统中多了CUDA PATH和CUDA PATH
  • Python算法--求1-100之间所有的偶数和奇数

    i 1 sum1 0 sum2 0 while i lt 100 if i 2 0 sum1 i else sum2 i i 1 print 1 100之间偶数和为 d sum1 print 1 100之间奇数和为 d sum2
  • MySQL递归查询

    在平常开发过程中 递归查询随处可见 话不多说 本人在项目中遇到的是编码和父级编码 需要逐渐查询 1 表结构 2 SQL SELECT id SELECT REPLACE GROUP CONCAT code FROM ddm file dir
  • windows下网卡绑定多个IP地址的方法

    在Windows下 尤其是Windows的服务器 有时候需要在一张网卡上绑定多个IP地址 有两种方法可以完成 一是使用控制面板可视化配置 进入控制面板 网络设置 Network and Internet 网络和共享中心 Network an

随机推荐

  • Deep Learning for Anomaly Detection: A Review

    本文是对 Deep Learning for Anomaly Detection A Review 的翻译 深度学习进行异常检测 综述 摘要 1 引言 2 异常检测 问题复杂性和挑战 2 1 主要问题复杂性 2 2 深度异常检测面临的主要挑
  • 产品经理实战--抖音

    目录 抖音 短视频发展历程 短视频概念 短视频在马斯洛理论上的应用 短视频行业发展历程 规划设计一款产品的思路和流程 报名学习 后台产品 前端 包含app 顶部导航 侧边导航 流程图 设计一款APP的思路和流程 抖音 我规划设计APP的思路
  • cvc-complex-type.2.4.a: 发现了以元素 ‘base-extension‘ 开头的无效内容。(解决方案的最全整理)

    记录一下 新电脑安装新版的Android Studio 小蜜蜂版本 导入那些gradle还是5 1 1 distributionUrl https services gradle org distributions gradle 5 1 1
  • Java字符串不相同但HashCode相同的例子(算法)

    相关文章 为什么重写equals方法时必须重写hashcode方法 Java字符串不相同但HashCode相同的例子 算法 Java字符串不相同但HashCode相同的例子 public static void main String ar
  • 符号 ?. 是什么

    在修改问题的时候 看到别人的代码是这样的 如图 不太懂 是干嘛的 于是去查关键字 js 发现这是ES2020 ES11 新增的 可选链操作符 可选链操作符 允许读取位于连接对象链深处的属性的值 而不必明确验证链中的每个引用是否有效 操作符的
  • 遗传算法求解TSP及其变式

    刚刚接触遗传算法 主要学习的是以下几位老师的文章 抱拳 链接附上 https blog csdn net u010451580 article details 51178225 https blog csdn net wangqiuyun
  • 【案例教学】华为云API图像搜索ImageSearch的快捷性—AI帮助您快速归类图片

    云服务 API SDK 调试 查看 我都行 阅读短文您可以学习到 人工智能AI同类型的相片合并归类 1 IntelliJ IDEA 之API插件介绍 API插件支持 VS Code IDE IntelliJ IDEA等平台 以及华为云自研
  • 数据科学家:在实际工作后,我深刻认识到的五点

    我从事数据科学工作了已经将近半年了 我一路上成长了很多 也犯了很多错误 并在这一过程中从学习了很多 不存在没有失败 只有反馈 而现实世界就是一种反馈机制 是的 学习之旅并不容易 我们该做的就是继续努力 不断学习和改进 通过这段时间的学习历程
  • xshell session配置文件转移

    问题描述 笔者一直用的xshell5有天无聊升级了xshell6 结果发现只能打开4个会话而且xshell5的session配置也都没了 用了一段时间后想解决此问题 结果xshell6搞成收费的版本了 气死人果断装回xshell5合并ses
  • 查看npm模块的版本列表以及版本发布日期,解决模块版本不兼容问题。

    最近想用 vue2 0 less 写一个demo 加载 less loader 时 因为版本太高 项目报错了 也不知道less loaser用什么版本 于是就有了以下操作 查看vue3发布前 vue2的最后一个版本的发布日期 再找到这个离这
  • 那么多的数据可视化图表,你选对了吗?

    作者 邻川 程序员懂画图 一宝变三宝 本期 菜鸟国际物流技术部高级开发工程师邻川将分享他在数据可视化图标方面的积累 常听到一句话 能用图描述的就不用表 能用表就不用文字 这句话也直接的表明了 在认知上 大家对于图形的敏感度远比文字高 但同时
  • 自定义ViewGroup--浮动标签的实现

    前面在学习鸿洋大神的一些自定义的View文章 看到了自定义ViewGroup实现浮动标签 初步看了下他的思路以及结合自己的思路完成了自己的浮动标签的自定义ViewGroup 目前实现的可以动态添加标签 可点击 效果图如下 1 思路 首先在o
  • Spring AOP报错之通配符的匹配很全面, 但无法找到元素 'aop:config' 的声明

    问题 配置完aop后 运行时报错 如下 Library Java JavaVirtualMachines jdk1 8 0 111 jdk Contents Home bin java ea Didea launcher port 7534
  • 2021年网络空间安全学院预推免面试经验总结

    2021年网络空间安全学院预推免面试经验总结 建议 结合学科评估 跟着自己的判断走 https www dxsbb com news 1797 html https www zhihu com question 19825429 answe
  • 精品微信小程序ssm校友录网站+后台管理系统

    博主介绍 在职Java研发工程师 专注于程序设计 源码分享 技术交流 专注于Java技术领域和毕业设计 温馨提示 文末有 CSDN 平台官方提供的老师 Wechat QQ 名片 项目名称 精品微信小程序ssm校友录网站 后台管理系统 演示视
  • 【网络编程】传输层协议——UDP协议

    文章目录 一 传输层的意义 二 端口号 2 1 五元组标识一个通信 2 2 端口号范围划分 2 3 知名端口号 2 4 绑定端口号数目问题 2 5 pidof netstat命令 三 UDP协议 3 1 UDP协议格式 3 2 如何理解报头
  • TortoiseGit如何迁移项目地址

    大家工作中可能会遇到项目迁服务器 那么在以前老服务器上的git项目也需要迁到新服务器 如果大家使用TortoiseGit 那么该如何迁移呢 很简单 一 首先在新服务器git上建个项目 然后把项目地址复制下来 二 在本地项目里找到 git文件
  • 你需要知道的企业网页制作流程

    企业网页制作是企业建立线上形象和宣传的重要手段之一 它不仅可以提高企业的品牌知名度 还可以扩大企业的影响力和拓展客户群 下面 我们将介绍一些企业网页制作的基本流程和技巧 并结合一个案例来详细解析 企业网页制作的基本流程可以分为以下几个步骤
  • Android自定义View(三) Scroller与平滑滚动

    目录 一 什么是Scroller 二 认识scrollTo和scrollBy方法 2 1 scrollTo scrollBy对View内容的影响 2 2 思考为什么移动负数距离会向坐标正方向移动 2 3 scrollTo scrollBy对
  • 预训练网络的模型微调方法

    是什么 神经网络需要数据来训练 从数据中获得信息 进而转化成相应的权重 这些权重能够被提取出来 迁移到其他的神经网络中 迁移学习 通过使用之前在大数据集上经过训练的预训练模型 我们可以直接使用相应的结构和权重 将他们应用在我们正在面对的问题