pytorch 将模型作为特征提取器(提取中间层特征)

2023-11-05

目的

需要加载自己训练好的最好模型作为一个特征提取器,也就是说需要提取最后一层全连接层输出的内容。

解决方法

参考了两个方法(详见文末)

设参数直接提取

准备一个toy model来说明。

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.cl1 = nn.Linear(25, 60)
        self.cl2 = nn.Linear(60, 16)

    def forward(self, x):
        x = F.relu(self.cl1(x))
        x = self.cl2(x)
        ################################
        self.last_feature = x.detach()
        ################################
        x =F.relu(x)
        return x

x = torch.randn(1, 25)
model = MyModel()
output = model(x)
print(model.last_feature)
"""
tensor([[-0.0670,  0.1209,  0.5386, -0.0052, -0.2690, -0.0397, -0.0492,  0.0916,
          0.3837, -0.5325,  0.3419, -0.3190,  0.0589, -0.1058, -0.1944, -0.0929]])
"""

在两个注释条中间,通过设置了一个self.last_feature来保存cl2层的输出结果。

设hook函数提取

同样使用上方的toy model,但是需要额外增加几行代码。

activation = {}
# 告诉模型在哪一层需要detach
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
    
model.cl2.register_forward_hook(get_activation('cl2'))
print(activation['cl2'])
"""
tensor([[-0.0670,  0.1209,  0.5386, -0.0052, -0.2690, -0.0397, -0.0492,  0.0916,
          0.3837, -0.5325,  0.3419, -0.3190,  0.0589, -0.1058, -0.1944, -0.0929]])
"""

总结

hook函数相对通用,举个例子,还是以上面的model为例,但是稍微修改了一下forward函数的表达形式:

def forward(self, x):
        x = F.relu(self.cl1(x))
        x =F.relu(self.cl2(x))
        self.last_feature = x.detach()
        return x
 # print(model.last_feature)
"""
tensor([[0.0000, 0.1209, 0.5386, 0.0000, 0.0000, 0.0000, 0.0000, 0.0916, 0.3837,
         0.0000, 0.3419, 0.0000, 0.0589, 0.0000, 0.0000, 0.0000]])
"""

如果习惯性将激活函数连写,那么hook函数还是能够提取出正确的特征值,但是人为设置参数的方法则需要放在cl2层直接输出的后面。

Reference:

  1. https://www.zhihu.com/question/68384370
  2. https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 将模型作为特征提取器(提取中间层特征) 的相关文章

随机推荐

  • activiti7的网关

    工作流 activiti7网关 1 排他网关 排他网关 也叫异或 XOR 网关 或叫基于数据的排他网关 用来在流程中实现决策 当流程执行到这个网关 所有分支都会判断条件是否为true 如果为 true 则执行该分支 注意 排他网关只会选择一
  • Hbase基础入门

    HBase 1 HBase是什么 1 1 HBase的概念 1 2 HBase的特点 2 HBase集群安装部署 2 1 准备安装包 2 2 修改HBase配置文件 2 2 1 hbase env sh 2 2 2 hbase site x
  • RocketMQ入门

    1 认识MQ 1 1 什么是MQ MQ全称为Message Queue 即消息队列 是一种提供消息队列服务的中间件 也称为消息中间件 是一套提供了消息生 产 存储 消费全过程的软件系统 遵循FIFO原则 1 2 为什么用MQ 并发量高时 当
  • 微信小程序 组件生命周期

    完整微信小程序 Java后端 技术贴目录清单页面 必看 组件的生命周期 指的是组件自身的一些函数 这些函数在特殊的时间点或遇到一些特殊的框架事件时被自动触发 其中 最重要的生命周期是 created attached detached 包含
  • yolov5详解与改进

    https github com z1069614715 objectdetection script YOLOV5改进 Optimal Transport Assignment Optimal Transport Assignment O
  • 49天精通Java,第43天,缓冲区数据结构bytebuffer

    目录 专栏导读 一 缓冲区 二 常用方法 三 通道获取 1 从 FileInputStream FileOutputStream 中获取 2 从 RandomAccessFile 中获取 3 通过 FileChannel open 获取 四
  • 如何创建A/B Test谷歌广告实验(3种类型)

    为了更精细化的测试广告 我们需要做一些测试 谷歌广告实验 我们也经常会叫A B Test 目前谷歌支持搜索广告 展示广告和视频广告三种广告系列类型的A B Test 在谷歌广告实验中分为广告变体 自定义实验和视频实验三种类型 广告变体主要用
  • Hadoop集群完全分布式搭建

    本人也只是hadoop学习的一个萌新 在这段时间内因为课程的需要 安装了一下hadoop集群 里面遇到了一些问题 找到了一些解决办法 如果文章内有什么错误 欢迎大家与我交流 下面就开始搭建hadoop集群吧 搭建环境为win10 虚拟机为V
  • 在Linux环境搭建Java版Minecraft(我的世界)服务器

    文章目录 前言 一 帮助轻松开服的工具 1 Xshell 2 XFTP 二 开服步骤 1 准备一个可以满足你需要的Linux服务器 2 安装工具 3 连接服务器 4 配置服务器 确保你已经完成第三步 成功连接上了服务器 1 安装Java 如
  • HDFS入门和应用开发场景案例:如何模拟实现分布式存储?

    如何解决海量数据存的下问题 1 传统式存储方式 应对文件存储服务 传统做法是在服务器上部署文件服务比如FTP 但是随着数据变多 会遇到存储瓶颈 此时 本能的操作反应是 内存不够加内存 磁盘不够加磁盘 单机纵向扩展 但是单机能够扩展的内存磁盘
  • 使用Python进行名片OCR(识别姓名,职务,电话,Email邮箱)

    上一篇博客介绍了如何通过以下方式自动OCR和扫描收据 检测输入图像中的接收 应用透视变换以获得收据的自顶向下视图 利用Tesseract对收据上的文本进行OCR 使用正则表达式提取价格数据 这篇博客将介绍如何使用Python对名片进行OCR
  • JAVA面试题 整合版

    1 List Set和Map 的区别 List 以索引来存取元素 有序的 元素是允许重复的 可以插入多个null Set 不能存放重复元素 无序的 只允许一个null Map 保存键值对映射 List 底层实现有数组 链表两种方式 Set
  • 无需MS Office!使用Aspose在C ++中以编程方式将 DOCX 转换为 DOC

    Microsoft Word 文档有两种格式 DOC 和 DOCX DOC 是一种较旧的格式 而 DOCX 是它的继任者 可以将 DOCX 文件转换为 DOC 格式 反之亦然 在本文中 将学习如何将 DOCX 文件转换为 DOC 格式以及如
  • Guava学习之Multisets

    今天谈谈Guava类库中的Multisets数据结构 虽然它不怎么经常用 但是还是有必要对它进行探讨 我们知道Java类库中的Set不能存放相同的元素 且里面的元素是无顺序的 而List是能存放相同的元素 而且是有顺序的 而今天要谈的Mul
  • 【软考初级指南】软考网络管理员如何备考通过,一个月足矣。

    文章目录 写在前面 涉及知识 1 前期准备 A 教材重难点 小于一个星期 B 远古真题测试 一周时间 2 温习阶段 一周时间 A 常错知识点文章整理 B 错题集的整理 3 冲刺刷题 一周时间 A 近几年真题 B 近几年调整 4 考场复习 写
  • Qt制作一个简单的电子时钟

    电子时钟 新建桌面应用程序 项目名LCDClock 类名Clock 基类QDialog 取消产生界面文件 当前项目添加C 类DigitalClock 基类QLCDNumber 编辑digitalclock h文件 clock h ifnde
  • 动手学深度学习_个人笔记01_李沐(更新中......)

    序言 神经网络 本书中关注的DL模型的前身 被认为是过时的工具 深度学习在近几年推动了CV NLP和ASR等领域的快速发展 关于本书 让DL平易近人 教会概念 背景和代码 一种结合了代码 数学和HTML的媒介 测试深度学习 DL 的潜力带来
  • 桌面点击:右键点击-显示设置,提示“该文件没有与之关联的程序来执行该操作“解决方法总结

    解决方法1 解决了我的问题的方案 1 WIN R组合键 运行 输入regedit 2 打开注册表 regedit 3 定位到 HKEY CURRENT USER software classes 4 找到ms settings 5 删除或重
  • Java 接口

    目录 1 接口的概念 2 接口的格式 3 接口的使用 4 接口的特性 4 1接口是一种引用类型 但是不能通过直接new接口的对象 4 2接口中的方法只能public修饰 4 3接口中的方法不能在接口实现 4 4重写接口方法时 不能使用默认的
  • pytorch 将模型作为特征提取器(提取中间层特征)

    目的 需要加载自己训练好的最好模型作为一个特征提取器 也就是说需要提取最后一层全连接层输出的内容 解决方法 参考了两个方法 详见文末 设参数直接提取 准备一个toy model来说明 class MyModel nn Module def