模型微调技术

2023-10-26

一、迁移学习中的常见技巧:微调(fine-tuning)

1.1 概念

  1. 将在大数据集上训练得到的weights作为特定任务(小数据集)的初始化权重,重新训练该网络(根据需要,修改全连接层输出);至于训练的方式可以是:
    1.微调所有层;
    2.固定网络前面几层权重,只微调网络的后面几层,这样做有两个原因:A. 避免因数据量小造成过拟合现象;B.CNN前几层的特征中包含更多的一般特征(比如,边缘信息,色彩信息等),这对许多任务来说是非常通用的,但是CNN后面几层的特征学习注重高层特征,也就是语义特征,这是针对于数据集而言的,不同的数据集后面几层学习的语义特征也是完全不同的;

1.2 步骤

  1. 在源数据集上训练神经网络模型或将已经在大数据集上训练好的模型保存的模型,即源模型;
  2. 创建新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计(即模型层数设计)及其参数(输出层除外)。假定模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集;
  3. 想目标模型中添加输出层,其输出类别数目是目标数据集中的类别数, 然后随机初始化该层的模型参数;
  4. 在目标数据集上训练目标模型,输出层从头开始训练,其他所有层的参数将根据源模型的参数进行微调。

在这里插入图片描述

1.3 训练

  • 源数据集远复杂于目标数据,通常微调效果更好;
  • 通常使用更小的学习率和更少的数据迭代;

1.4 实现

#热狗识别
#导入所需包
from d2l import torch as d2l
from torch import nn
import torchvision
import torch
import os
%matplotlib inline
#获取数据集
"""
我们使用的热狗数据集来源于网络。 
该数据集包含1400张热狗的“正类”图像,以及包含尽可能多的其他食物的“负类”图像。
含着两个类别的1000张图片用于训练,其余的则用于测试。
"""
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
print(data_dir)
#输出..\data\hotdog
train_imgs=torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs=torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
hotdogs=[train_imgs[i][0] for i in range(8)]
not_hotdogs=[train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs+not_hotdogs,2,8,scale=1.4)

在这里插入图片描述

# 使用RGB通道的均值和标准差,以标准化每个通道
"""
在训练期间,我们首先从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为\(224*224\)输入图像。 
在测试过程中,我们将图像的高度和宽度都缩放到256像素,然后裁剪中央\(224*224\)区域作为输入。
此外,对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 
具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。
"""
normalize=torchvision.transforms.Normalize([0.485,0.456,0.406],
                                           [0.229,0.224,0.225])
train_augs=torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),#随机裁剪,并resize成224
                                          torchvision.transforms.RandomHorizontalFlip(),
                                          torchvision.transforms.ToTensor(),
                                          normalize])
test_augs=torchvision.transforms.Compose([torchvision.transforms.Resize(256),
                                          torchvision.transforms.CenterCrop(224),#将图片从中心裁剪成224*224
                                          torchvision.transforms.ToTensor(),
                                          normalize])
#我们使用在ImageNet数据集上预训练的ResNet-18作为源模型。 在这里,我们指定pretrained=True以自动下载预训练的模型参数。 
#如果你首次使用此模型,则需要连接互联网才能下载。
pretrained_net=torchvision.models.resnet18(pretrained=True)
"""
预训练的源模型实例包含许多特征层和一个输出层fc(全连接层)。 
此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调。 
下面给出了源模型的成员变量fc。
"""
pretrained_net.fc
#输出
#Linear(in_features=512, out_features=1000, bias=True)
finetune_net=torchvision.models.resnet18(pretrained=True)
finetune_net.fc=nn.Linear(finetune_net.fc.in_features,2)#全连接层的输入神经元数量是特征数量,因为是2分类,所以输出是2
nn.init.xavier_uniform_(finetune_net.fc.weight)#随机初始化全连接层权重
#Parameter containing:
tensor([[ 0.0378,  0.0630, -0.0080,  ..., -0.0220, -0.0511,  0.0959],
        [ 0.0556,  0.0227, -0.0262,  ..., -0.1059, -0.0171,  0.0051]],
       requires_grad=True)
#微调模型
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)
train_fine_tuning(finetune_net, 5e-5)                  

在这里插入图片描述

#为了进行比较,我们定义了一个相同的模型,但是将其所有模型参数初始化为随机值。 
#由于整个模型需要从头开始训练,因此我们需要使用更大的学习率。
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

模型微调技术 的相关文章

  • 一文读懂 QUIC 协议:更快、更稳、更高效的网络通信

    作者 李龙彦 来源 infoQ 你是否也有这样的困扰 打开 APP 巨耗时 刷剧一直在缓冲 追热搜打不开页面 信号稍微差点就直接加载失败 如果有一个协议能让你的上网速度 在不需要任何修改的情况下就能提升 20 特别是网络差的环境下能够提升
  • 万得Wind量化与东方财富Choice量化接口使用

    接口需要付费 这里接口的付费和配置就不展开了 wind相对容易配置 直接用软件就可以点击并配置 东财请参考 Mac使用Python接入东方财富量化接口Choice 调试与获取数据 但有一点需要注意 wind使用量化接口的时候wind终端需要
  • 王炸功能ChatGPT 联网插件功能放开,视频文章一键变思维导图

    就在上周5月13日 Open AI 发文称 我们将在下周向所有ChatGPT Plus 用户开放联网功能和众多插件 这意味着什么 首先联网功能将使得ChatGPT不再局限于回答2021年9月之前的信息 能直接联网查询最新消息 而插件功能就可

随机推荐

  • BIOS启动过程详解

    BIOS 工作原理 最近几天在看 UNIX 操作系统设计 突然想到计算机是如何启动的呢 那就得从 BIOS 说起 其实这个冬冬早已是 n 多人写过的了 今天就以自己的理解来写写 权当一个学习笔记 一 预备知识 很多人将 BIOS 与 CMO
  • 19.3剪裁

    1 在固定管线中 裁剪是在世界坐标系中 2 在可编程管线中 裁剪是在规格化坐标系中 步骤 1 按照法向量和空间点定义裁剪平面 并归一化 2 根据世界观察投影变换矩阵相乘 求逆转置 即为需要的变换矩阵 3 变换矩阵与裁剪平面变换后就是需要的裁
  • numpy模块(2)

    1 利用布尔值来取元素 import numpy as np mask np array 1 0 1 dtype bool 1表示取对应的元素 0表示不取 arr np array 1 2 3 4 5 6 7 8 9 print arr m
  • Hadoop学习心得---二

    大数据运算解决方案MapReduce Hadoop的分布式计算模型MapReduce 最早是Google提出的 主要用于搜索领域 解决海量数据的计算问题 MapReduce有两个阶段组成 Map和Reduce 用户只需实现map 和redu
  • Three.js(学习)

    在vue项目中使用Three js的流程 1 首先利用npm安装Q three js 具体操作代码如下 npm install three 2 接下来利用npm安装轨道控件插件 npm install three orbit control
  • 表、栈和队列

    表 栈和队列 表 增强的for循环 List
  • DM6437 C64X+ EDMA 疑惑总结记录

    总结一下DM6437中的EDMA的使用出现的问题 方便以后再开发定位问题 1 EDMA Link 和 Chain的区别 link实现了DMA的自动重加载 非静态模式 需要两个param chain是不更新param set表 直接event
  • qt界面叠加视频OSD双层显示

    最终代码存放于 http download csdn net detail lzh445096 8849147 本人负责的是UI界面 提供给底层应用程序接口函数 此接口函数功能为向指定路径的文件中写入命令字符 应用程序去到该文件中读取到相应
  • 基于Protobuf协议的Dubbo与SpringBoot的结合

    文章目录 工程概况 父pom dubbo provider 通过proto3定义服务 打包发布服务 dubbo provider service实现服务 dubbo provider web提供服务 dubbo consumer dubbo
  • 依赖注入和控制反转的理解,写的太好了

    学习过Spring框架的人一定都会听过Spring的IoC 控制反转 DI 依赖注入 这两个概念 对于初学Spring的人来说 总觉得IoC DI这两个概念是模糊不清的 是很难理解的 今天和大家分享网上的一些技术大牛们对Spring框架的I
  • 互联网产品运作模式详解

    互联网产品运作模式详解 https www infoq cn article 3EVku39xVhJYs7ba9uk7 本文主要总结下移动互联网产品的市场运作模式 因为本身我是技术出身 对运作模式中的开发体系这 块相对熟悉 但是其他阶段也是
  • js: for in 循环对象

    var peopleObj man 2 2 2 woman 1 1 1 womanDoctor 100 100 100 for const prop in peopleObj if peopleObj hasOwnProperty prop
  • java将图片转为base64后出现的一些问题

    因为需要对接第三方接口 需要将图片转换为base64编码传参 手动转换base64使用postman完全是OK的 结果java中转换出来死活不行 p 将文件转成base64 字符串 p param path 文件路径 return thro
  • Linux下ps命令实现

    include
  • 思科实验-生成树协议STP

    生成树协议 英语 Spanning Tree Protocol STP 是一种工作在OSI网络模型中的第二层 数据链路层 的通信协议 基本应用是防止交换机冗余链路产生的环路 用于确保以太网中无环路的逻辑拓扑结构 从而避免了广播风暴 大量占用
  • vivado2021.1安装

    首先需要在官网注册一个账号 安装软件时需要使用 账号注册连接 xilink账号注册 vivado下载链接 xilink官网下载 使用官网下载需要注册账号 下载免费 vivado阿里云盘下载 vivado licence阿里云盘下载 官网下载
  • QStringLiteral(str)

    在看项目代码的时候 总会看到下面这种情况 QString str QStringLiteral 123rt QString用QStringLiteral str 来初始化 有点好奇 就查了下 记录一下 这是用QStringLiteral初始
  • Java:记录一下第一次面试经历(新希望六和)

    记录一下本菜鸡两个月前第一次面试新希望六合这家公司 那时的我很多都回答不上来 非常尴尬 不过这第一次面试经历也算是给足了我动力继续努力 记录一下这个第一次面试的题目 也算是记录一下那时候的我 做过什么样的项目 简单介绍一下你的项目 项目的整
  • 客户端请求的端口号是什么?

    我们知道服务器端是要指定和开放端口号的 比如 web 服务 http 请求的 80 https 的 443 端口 都要开放 否则无法请求成功 我们知道通信是由两端组成的 既然服务器需要指定端口 那么客户端呢 比方说我用 chrome 浏览器
  • 模型微调技术

    模型微调 一 迁移学习中的常见技巧 微调 fine tuning 1 1 概念 1 2 步骤 1 3 训练 1 4 实现 一 迁移学习中的常见技巧 微调 fine tuning 1 1 概念 将在大数据集上训练得到的weights作为特定任