24模型微调(finetune)

2023-11-04

一、Transfer Learning & Model Finetune

1.1 Transfer Learning

Transfer Learning:机器学习分支,研究源域(source domain)的知识如何应用到目标域(targetdomain)
在这里插入图片描述

传统的机器学习:
对不同的任务分别训练学习得到不同的learning system,即模型,如上图有三个不同任务,就得到三个不同的模型

迁移学习:
先对源任务进行学习,得到知识,然后在目标任务中,会使用再源任务上学习得到的知识来学习训练模型,也就是说该模型不仅用到了target tasks,也用到了source tasks

1.2 Model Finetune

1.2.1 Model Finetune概念

Model Finetune:模型的迁移学习在这里插入图片描述
模型微调:
模型微调就是一个迁移学习的过程,模型中训练学习得到的权值,就是迁移学习中所谓的知识,而这些知识是可以进行迁移的,把这些知识迁移到新任务中,这就完成了迁移学习

微调的原因:
在新任务中,数据量太小,不足以去训练一个较大的模型,从而选择Model Finetune去辅助训练一个较好的模型,使得训练更快

卷积神经网络的迁移:
在这里插入图片描述
将卷积神经网络分成两部分:features extractor + classifier

  • features extractor:模型的共性部分,通常对其进行保留
  • classifier:根据不同任务要求对输出层进行finetune

1.2.2 Model Finetune步骤

在这里插入图片描述
Model Finetune:
先进行模型微调,加载模型参数,并根据任务要求修改模型,此过程称预训练,然后进行正式训练,此时要注意预训练的参数的保持,具体步骤和方法如下

模型微调步骤:

  1. 获取预训练模型参数
  2. 加载模型( load_state_dict)
  3. 修改输出层

模型微调训练方法:

  • 固定预训练的参数,两种方法:
    • requires_grad =False
    • lr=0
  • Features Extractor部分设置较小学习率( params_group)

说明:
优化器中可以管理不同的参数组,这样就可以为不同的参数组设置不同的超参数,对Features Extractor部分设置较小学习率

二、Pytorch中的Finetune

2.1 Model Finetune实例

在这里插入图片描述
数据: https://download.pytorch.org/tutorial/hymenoptera_data.zip
模型: https://download.pytorch.org/models/resnet18-5c106cde.pth

2.1.1 目录结构

在这里插入图片描述
模型和数据的存放位置如上图所示

2.1.1 代码详解

my_dataset.py

# -*- coding: utf-8 -*-
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {
   "1": 0, "100": 1}


class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {
   "ants": 0, "bees": 1}
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img,label

    def __len__(self):
        return len(self.data_info)

    def get_img_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = self.label_name[sub_dir]
                    data_info.append((path_img, 
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

24模型微调(finetune) 的相关文章

  • Manifest.json文档说明

    Manifest json文件是5 移动App的配置文件 用于指定应用的显示名称 图标 应用入口文件地址及需要使用的设备权限等信息 是扩展的配置文件 指明了扩展的各种信息 一个manifest json格式如下 必须的字段3个 name M
  • Spring bean生命周期详解

    Spring Bean的完整生命周期从创建Spring容器开始 直到最终Spring容器销毁Bean 这其中包含了一系列关键点 Spring bean生命周期 四个阶段 Bean的实例化阶段 Bean的设置属性阶段 Bean的 初始化阶段
  • token续期

    需求 项目前后端分离 采用token jwt生成 方式作为登录及接口验证 自然而然就会涉及token超时 影响用户体验的问题 要解决的就是如果用户一直点击页面 就不应该出现超时及重新登录 只有用户在设置的超时时间内 一次页面操作都没有 才定
  • 大白话给你说清楚什么是过拟合、欠拟合以及对应措施

    开始我是很难弄懂什么是过拟合 什么是欠拟合以及造成两者的各自原因以及相应的解决办法 学习了一段时间机器学习和深度学习后 分享下自己的观点 方便初学者能很好很形象地理解上面的问题 同时如果有误的地方希望大家在评论区留下你们的砖头 我会进行纠正
  • 计算机故障诊断知识,故障诊断

    利用各种检查和测试方法 发现系统和设备是否存在故障的过程是故障检测 而进一步确定故障所在大致部位的过程是故障定位 故障检测和故障定位同属网络生存性范畴 要求把故障定位到实施修理时可更换的产品层次 可更换单位 的过程称为故障隔离 故障诊断就是
  • UE4Material_材质属性(1)

    材质中的属性 物理材质 Phys Material 物理材质 与该材质关联的物理材质 物理材质 Physical Material 提供了物理属性的定义 例如碰撞 弹力 以及其他基于物理的方面会保留多少能量 物理材质 Physical Ma

随机推荐

  • 一篇教会你,Redis主从、哨兵、 Cluster集群。

    前言 大家好 今天跟小伙伴们一起学习Redis的主从 哨兵 Redis Cluster集群 Redis主从 Redis哨兵 Redis Cluster集群 1 Redis 主从 面试官经常会问到Redis的高可用 Redis高可用回答包括两
  • TCP的三次握手及四次挥手总结(从抓包角度理解)

    目录 TCP报文首部 TCP连接 传输及断开过程图 TCP状态图 三次握手过程理解 四次挥手过程理解 从抓包来理解TCP建立连接 数据传输以及断开连接的过程 建立连接过程 数据传输过程 连接断开过程 为什么连接的时候是三次握手 关闭的时候却
  • Keras查看model weights .h5 文件的内容

    Keras的模型是用hdf5存储的 如果想要查看模型 keras提供了get weights的函数可以查看 for layer in model layers weights layer get weights list of numpy
  • 多进程浏览器框架

    为什么浏览器采用多进程模型 转载于 http www wtoutiao com p s57age html Google Chrome源码剖析 一 多线程模型 转载于 http www ha97 com 2908 html 主流浏览器多进程
  • 【广州华锐互动】无人值守变电站AR虚拟测控平台

    无人值守变电站AR虚拟测控平台是一种基于增强现实技术的电力设备巡检系统 它可以利用增强现实技术将虚拟信息叠加在真实场景中 帮助巡检人员更加高效地完成巡检任务 这种系统的出现 不仅提高了巡检效率和准确性 还降低了巡检成本和风险 传统的变电站巡
  • TPM功能介绍

    文章来源 TPM功能介绍 百度文库 http wenku baidu com link url bQMQyb0A3gto0CCC2CN5ojpUrgHsh8BMXmejpFaqLS52v 013bXPHoRr36r0F0UrgPr8U6rv
  • MATLAB——FFT(快速傅里叶变换)

    基础知识 FFT即快速傅里叶变换 利用周期性和可约性 减少了DFT的运算量 常见的有按时间抽取的基2算法 DIT FFT 按频率抽取的基2算法 DIF FFT 1 利用自带函数fft进行快速傅里叶变换 若已知序列 x 4 3
  • 利用ChatGPT协助编写单元测试

    ChatGPT自从2022年推出以来受到很多人的喜欢 此篇博客重点介绍如何修改Prompt来自动生成较理想的单元测试 如下图所示的一段代码 该class中有一个public方法toLocale 其余都是private方法 toLocale
  • 编写代码的几个tip

    使用的大多是MVC的模式 那么视图就只管视图 逻辑就只管逻辑 一个自定义的cell 上面放了一个button button的点击事件用一个delegate在viewcontroller中来实现 比如先要变化cell的样式 那么代理方法中 不
  • 攻防世界WEB入门

    1 view source X老师让小宁同学查看一个网页的源代码 但小宁同学发现鼠标右键好像不管用了 WP 按F12即可在elements中看到flag 2 robots X老师上课讲了Robots协议 小宁同学却上课打了瞌睡 赶紧来教教小
  • ​LeetCode刷题实战214:最短回文串

    算法的重要性 我就不多说了吧 想去大厂 就必须要经过基础知识和业务逻辑面试 算法面试 所以 为了提高大家的算法能力 这个公众号后续每天带大家做一道算法题 题目就从LeetCode上面选 今天和大家聊的问题叫做 最短回文串 我们先来看题面 h
  • UE4 UMG中使用富文本

    UE4 UMG中使用富文本 一 新建DateTable 二 添加字体样式 注意 第一个RowName必须为Default 字体样式必须赋值 否则会乱码 我们将Default的字体改为白色 Red字体改为红色 字号改小 三 使用 拖一个富文本
  • serverTimezone设置

    在安装完mysql第一次使用IDEA进行数据库连接发现 You must configure either the server or JDBC driver via the serverTimezone configuration pro
  • uview2.0封装网络请求(微信小程序最新登录方式)

    一 网络请求和相应拦截器 此vm参数为页面的实例 可以通过它引用vuex中的变量 module exports vm gt 初始化请求配置 uni u http setConfig config gt config baseURL http
  • 计算机二级C语言三天能过吗,学工干货丨如何三天通过计算机二级

    原标题 学工干货丨如何三天通过计算机二级 不可能的 想都不要想 三天怎么可能过 不信你问考完了的 他们马上就可以 进行计算机二级考试成绩查询啦 计算机二级是什么 计算机二级怎么考 迷惘 彷徨 别怕 今天学工菌来助攻 考生在考后50个工作日
  • 博客园自定义主题代码

    发一下好看的 要开通js权限 皮肤用simple memory 最好禁用模板 侧边栏
  • 基于小脑模型神经网络的轨迹跟踪研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现 1 概述 1 在对人类神经学的研究中 得知它由一些神
  • class加载过程

    loading class文件 从硬盘 加载到 内存 linking 1 verification 校验 检查满不满足class文件的格式 2 preparation 将 静态变量 赋默认值 默认值是0 3 resolution 将 常量池
  • Code Embedding研究系列11-ContraFlow

    Path Sensitive Code Embedding via Contrastive Learning for Software Vulnerability Detection 一 引言 1 1 现有方法及其局限 1 2 作者的解决方
  • 24模型微调(finetune)

    一 Transfer Learning Model Finetune 1 1 Transfer Learning Transfer Learning 机器学习分支 研究源域 source domain 的知识如何应用到目标域 targetd