pytorch:如何修改加载了预训练权重的模型的输入或输出--(原权重文件修改参数)

2023-11-18

在使用pytorch的过程中,我们往往会使用官方发布的预训练模型,并在此基础上训练自己的模型。为了适配训练数据,有时候需要局部修改这类预训练模型的结构,本文将分别以修改输入的通道数和输出的分类数为例,讲解一种通用的方法来修练模型的结构。

加载模型

在修改之前需要加载预训练模型,这里以mobilenet v2为例

import torchvision.models as models
model = models.mobilenet_v2(pretrained=True)

修改模型结构

修改之前需要查看模型结构

print(model)

则可以看到一长串的模型输入,这里因为篇幅原因只截了开头部分和结尾部分
在这里插入图片描述
仔细观察输出的模型结构,卷积层(特别是括号中的features,classifier,(0) 等标志性词可以得知模型的第一层为:

model.features[0][0] = Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

分类层为

model.classifier[1] = Linear(in_features=1280, out_features=1000, bias=True)

通过这些信息接下来就可以修改输入/输出了

修改模型输入

#输入为单通道

model.features[0][0] = Conv2d(1 ,32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

#修改预训练模型权重的结构,使得模型可以使用修改后的预训练模型权重

#加载预训练模型
pre_trained_model = models.mobilenet_v2(pretrained=True)
#获取预训练权重文件的字典
pretrained_dict = pre_trained_model.state_dict()
#打印权重信息
print(pretrained_dict.items())
'''
打印显示为:
dict_items([('features.0.weight', tensor([[[[-2.8656e-03,  4.1653e-02,  5.7146e-02,  ...,  5.2015e-03,
           -5.7198e-03, -2.3688e-02],...
这里可以看到第一层的key为‘features.0.weight’,接下来就可以通过这个名称访问pretrained_dict中对应的权重
'''
#获取第一层权重
layer1 = pretrained_dict['features.0.0.weight']
#创建一个新的张量,这个张量后面将替代pretrain_dict中的第一层,以适应修改为单通道的模型
new = torch.zeros(32,1,3, 3)
#这里修改第一层
for i,output_channel in enumerate(layer1):
	# Grey = 0.299R + 0.587G + 0.114B, 这个公式参考了RGB图转灰度图的方式
    new[i] = 0.299 * output_channel[0] + 0.587 * output_channel[1] + 0.114 * output_channel[2]
#现在第一层的shape为(32,1,3,3)了
pretrained_dict['features.0.0.weight'] = new 
#修改模型结构
model.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.load_state_dict(pretrained_dict)

修改模型输出

这里修改输出的方式不像修改输入这么繁琐

fc_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(fc_features, 2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch:如何修改加载了预训练权重的模型的输入或输出--(原权重文件修改参数) 的相关文章

  • 出现( linker command failed with exit code 1)错误总结

    这种问题 通常出现在添加第三方库文件或者多人开发时 这种问题一般是找不到文件而导致的链接错误 我们可以从如下几个方面着手排查 1 以如下错误为例 如果是多人开发 你同步完成后发现出现如下的错误 Undefined symbols for a
  • ABAP DOI详解

    导语 DOI是SAP与Office集成的一种技术 是早期OLE的升级版本 把Excel嵌套在程序当中进行展示 需要提前上传模板 在Excel模板中 可以事先设计好公式 在SAP将数据写入Excel中之后会自动用公式进行计算 对于习惯于用Ex
  • [Android] 拍照、截图、保存并显示在ImageView控件中

    最近在做Android的项目 其中部分涉及到图像处理的内容 这里先讲述如何调用Camera应用程序进行拍照 并截图和保存显示在ImageView控件中以及遇到的困难和解决方法 PS 作者购买了本 Android第一行代码 著 郭霖 参照里面
  • python程序里一定要有一个主函数吗_Python 为什么没有 main 函数?为什么我不推荐写 main 函数?...

    在开始正题之前 先要来回答这两个问题 所谓的 main 函数 是指什么 为什么有些编程语言需要强制写一个 main 函数 某些编程语言以 main 函数作为程序的执行入口 例如 C C C Java Go 和 Rust 等 它们具有特定的含
  • JS属性defer

    JS属性defer 利用defer属性 让浏览器读js脚本的时候完全不等脚本 就开始读取图片和html代码 给外链JS脚本添加defer true
  • Android源码分析 - Framework层的ContentProvider全解析

    开篇 本篇以android 11 0 0 r25作为基础解析 在四大组件中 可能我们平时用到最少的便是ContentProvider了 ContentProvider是用来帮助应用管理其自身和其他应用所存储数据的访问 并提供与其他应用共享数
  • Rocky9.2 第一次配置virtualbox报错Kernel driver not installed (rc=-1908)

    完整报错信息如下 Kernel driver not installed rc 1908 The VirtualBox Linux kernel driver is either not loaded or not set up corre
  • PDF文件转化成mobi格式,亲测kindle或者iReader可用!

    convertfiles 点击连接 然后选择要转换的文件 比如我的是MySQL的 选择输入文件和输出文件的格式 转换 对了记得输入邮箱号码 转化完毕会发送连接到邮箱提供下载 或者 网络流畅的情况下转化完毕会自动重定向到下载页面
  • Vue3之路--Less教学

    概览 Less Leaner Style Sheets 的缩写 是一门向后兼容的 CSS 扩展语言 这里呈现的是 Less 的官方文档 中文版 包含了 Less 语言以及利用 JavaScript 开发的用于将 Less 样式转换成 CSS
  • 关于table的selectedRowKeys和selectedRows

    项目使用的组件库是antd 页面中有很多table 有的table有行前面的复选框 于是就有了selectedRowkeys和selectedRows的事 他们两个都是数组 selectedRowkeys存的是table的rowKey 也就
  • .Net/C#: 实现支持断点续传多线程下载的 Http Web 客户端工具类 (C# DIY HttpWebClient)

    选择自 playyuer 的 Blog Net C 实现支持断点续传多线程下载的 Http Web 客户端工具类 C DIY HttpWebClient Reflector 了一下 System Net WebClient 重载或增加了若干
  • 论文阅读-NOLANet多模态伪造检测

    一 论文信息 题目 Deepfake Video Detection Based on Spatial Spectral and Temporal Inconsistencies UsingMultimodal Deep Learning
  • SQL学习笔记——limit用法(limit使用一个参数,limit使用两个参数)

    Product表 limit语法 select lt 列名 gt lt 列名 gt from lt 表名 gt limit lt 参数值 gt select from product limit 3 product id product n
  • 第三方平台代微信公众号开发

    第三方平台代微信公众号开发流程 一 准备工作 微信开放平台相关 申请微信开放平台账号后 需前往微信开放平台 创建第三方平台 填写开发相关配置 填写授权流程相关配置 注意事项 授权发起页域名 为项目开发使用域名 调用公众号二维码授权页时 必须
  • test is not a function (js正则表达式匹配问题)

    js中正则表达式匹配时 如果使用test函数 就必须不带引号 并且必须是 定义的规则变量 test 要测试的string 定义变量规则不要带引号 会错误的 如果不使用test 使用match则可以带引号 var re 1 9 d 4 10
  • Android 组件逻辑漏洞漫谈

    前言 随着社会越来越重视安全性 各种防御性编程或者漏洞缓解措施逐渐被加到了操作系统中 比如代码签名 指针签名 地址随机化 隔离堆等等 许多常见的内存破坏漏洞在这些缓解措施之下往往很难进行稳定的利用 因此 攻击者们的目光也逐渐更多地投入到逻辑

随机推荐

  • QT4信号连接与QT5的区别

    QT4信号连接与QT5的区别 QT4信号与槽 1 申明槽函数必须增加public slots 2 SIGNAL SLOT 将函数转为字符串 不进行错误检查 connect中信号和槽需要增加SIGNAL 和SLOT 3 槽函数和信号一致 参数
  • 常用的表格正则验证 + 省份选择 JS JQ

    常用的表格正则验证 轮子 let receiverNameReg u4e00 u9fa5 2 6 reg 收货人姓名 let receiverName receiverName val 收货人姓名 let phoneNumberReg d
  • TCP的几个状态 SYN, FIN, ACK, PSH, RST, URG

    2019独角兽企业重金招聘Python工程师标准 gt gt gt TCP的几个状态对于我们分析所起的作用 在TCP层 有个FLAGS字段 这个字段有以下几个标识 SYN FIN ACK PSH RST URG 其中 对于我们日常的分析有用
  • 数据挖掘技术-绘制散点图

    绘制散点图 前置步骤 准备数据guomin npz 下载数据guomin npz到Linux本地的 course DataAnalyze data目录 绘制散点图 绘制2000 2017年各季度的国民生产总值散点图 如代码 41所示 代码
  • 【华为OD机试真题 JAVA】执行时长

    JS版 华为OD机试真题 JS 执行时长 标题 执行时长 时间限制 1秒 内存限制 262144K 语言限制 不限 为了充分发挥GPU算力 需要尽可能多的将任务交给GPU执行 现在有一个任务数组 数组元素表示在这1秒内新增的任务个数且每秒都
  • Python脚本报错AttributeError: ‘module’ object has no attribute’xxx’解决方法

    最近在编写Python脚本过程中遇到一个问题比较奇怪 Python脚本完全正常没问题 但执行总报错 AttributeError module object has no attribute xxx 这其实是 pyc文件存在问题 问题定位
  • #C++矩阵类的实现

    C 矩阵类的实现 环境 Win10 VS2017 最近老师布置一个简单的C 作业 实现一个矩阵类 并且实现矩阵运算 主要实现运算为矩阵的加 减 乘 除以及求行列式 伴随矩阵 代数余子式和逆矩阵等 在参考网上的一些前辈的代码后 写出了这些运算
  • 信号与系统复习题

    选择题 2分 题 1 频谱与时域的关系 时域压缩 频域展宽 时域有限 频域无限 2 填空题 20分 2分 空 1 冲击信号的性质 抽样性 尺度变换性 奇偶性 2 线性时不变的概念 线性 齐次性 输入夸大多少倍 输出扩大多少倍 可加性 相应的
  • HFP协议

    通话专题HFP协议学习总结 一 配置和角色 二 HFP的连接 2 1服务级连接建立 2 1 1 服务发现和RFCOMM的连接 2 1 2 支持的特性交换 2 1 3 codec协商 2 1 4 HF指示器 2 1 5 AG指示器 2 1 6
  • ctfshow 文件上传 web151~170

    目录 web151 web 152 web 153 web 154 web 155 web 156 web 157 159 web 160 web 161 web 162 163 web 164 web 165 web 166 web 16
  • STM32F030C8T6 多通道ADC采集

    void adc init void ADC InitTypeDef ADC InitStructure GPIO InitTypeDef GPIO InitStructure RCC ADCCLKConfig RCC ADCCLK PCL
  • 动态规划算法解决背包问题(Java实现)

    文章收藏的好句子 你在书本上花的任何时间 都会在某一个时刻给你回报 目录 1 动态规划算法的概述 2 背包问题 3 动态规划算法解决背包问题 3 1 不可重复装入商品 3 2 思路分析 1 动态规划算法的概述 1 动态规划算法的思想是 将大
  • Python psycopg2使用SimpleConnectionPool数据库连接池以及execute_batch批量插入数据

    有关快速插入大量数据到数据库的一个比较好的博文如下 Fastest Way to Load Data Into PostgreSQL Using Python 其中文末还有提到几种不同方式的对比 效率十分的震撼 可以看看 1 连接池和批量插
  • MYSQL 安装

    MySQL8安装Installer 图文教程 编程宝库 Windows10 MySQL Installer 安装 编程宝库
  • shell提取字符串中的数字保存到变量中

    1 提取数字到变量 temp echo helloworld20180719 tr cd 0 9 echo temp 输出 20180719 2 重定向到文件 echo helloworld20180719 tr cd 0 9 gt mid
  • 【数据结构与算法】--排序

    目录 一 排序的概念及其运用 二 常见的排序算法 2 2选择排序 2 3 交换排序 2 3 4 1 快速排序优化 一 排序的概念及其运用 1 1 排序的概念 排序 所谓排序 就是使一串记录 按照其中的某个或某些关键字的大小 递增或递减的排列
  • [OpenAirInterface实战-14] :OAI nFAPI VNF/PNV持续集成测试的xml配置文件详解

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 120850348 目录 1 nFAPI
  • 23种设计模式之装饰模式

    装饰模式 一个简陋的房子 它可以让人在里面居住 为人遮风避雨 但如果给它进行装修 那么它的居住环境就更加宜人了 程序中的对象也与房子十分类似 首先有一个相当于 房子 的对象 然后经过不断装饰 不断对其增加功能 它就变成了使用功能更加强大的对
  • Unity Cinemachine插件学习笔记,实现单目标和多目标之间切换

    Cinemachine在2017版中正式加入 结合Timeline可以轻松的制作出一下相机动画 相比Unity自带的标准相机 这个Cinemachine插件可操作的变量更多 不同虚拟相机 用来控制相机的 可以平滑转换等 具体可以参考上篇 U
  • pytorch:如何修改加载了预训练权重的模型的输入或输出--(原权重文件修改参数)

    在使用pytorch的过程中 我们往往会使用官方发布的预训练模型 并在此基础上训练自己的模型 为了适配训练数据 有时候需要局部修改这类预训练模型的结构 本文将分别以修改输入的通道数和输出的分类数为例 讲解一种通用的方法来修练模型的结构 加载