基于检测代码库detectron2的蒸馏应用

2023-05-16

基于检测代码库detectron2和蒸馏代码库RepDistiller,完成将蒸馏方法应用在目标检测的代码库,完整代码已开源。

1. 参数添加

config/defaults.py里面添加蒸馏参数的默认值,同时类似于定义了变量

#==================================蒸馏参数=====================================
_C.DISTILL = CN()
_C.DISTILL.DO =  False
_C.DISTILL.PATH_T = None
_C.DISTILL.CFG_T = None
_C.DISTILL.DISTILL =  'kd'
# choices =  ['kd', 'hint', 'attention', 'similarity','correlation', 'vid', 'crd', 'kdsvd', 'fsp',
#             'rkd', 'pkt', 'abound', 'factor', 'nst']
# weight balance for other losses
_C.DISTILL.B =  None 

在实际使用的yaml文件中添加实际使用的蒸馏参数

#==================================蒸馏参数=====================================
DISTILL:
  ENABLE: True
  PATH_T: 'detectron2/teacher_models/FasterRCNN-R101-FPN-lr3x.pkl'
  CFG_T: 'configs/COCO-Detection/teacher_models/faster_rcnn_R_101_FPN_3x.yaml'
  DISTILL: 'hint'
  # choices =  ['hint', 'attention', 'similarity', 'nst', 'rkd', 'pkt']
  # trial id
  TRIAL: 1.0    
  # weight balance for other losses
  B: 1.0

2. 导入教师模型

根据detectron2官方提供的下载地址下载教师模型,并放置在detectron2/teacher_models文件夹下:

在这里插入图片描述
配置教师模型对应的参数,放置在configs/COCO-Detection/teacher_models文件夹下:

在这里插入图片描述

3. 计算蒸馏损失

GeneralizedRCNNfoward函数中添加蒸馏损失,初始化和调用时添加蒸馏参数

# 调用蒸馏损失函数
from detectron2.distiller_helper.distill import Distill

class GeneralizedRCNN(nn.Module):
    @configurable
    def __init__(
        self,
        ...
        distill_cfg: dict = None # 添加蒸馏参数
    ):
        super().__init__()
        ...
        # ==================================蒸馏====================================
        # 添加蒸馏参数
        self.distill_cfg = distill_cfg

        # 构建教师模型
        if distill_cfg.ENABLE:
            self.distill_model = Distill(self.distill_cfg)

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            ...
            "distill_cfg": cfg.DISTILL # 添加蒸馏参数
        }
        
    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        ...
        # distill_losses
        if self.distill_cfg.ENABLE:
            distill_losses = self.distill_model.compute_distill_loss(batched_inputs, features)
            losses.update(distill_losses)
        
        return losses

创建get_features函数,用来输出教师特征图

    def get_features(self, batched_inputs):
        '''
        return: dict
            keys: 'p2'-'p6'
            values: 分别对应FPN第n层的特征层
        '''
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        return features

创建distill.py文件,用来构建教师模型和计算蒸馏损失,放置在detectron2/distiller_helper文件夹下。

from detectron2.config import get_cfg
from detectron2.modeling.meta_arch import build_model
from detectron2.distiller_zoo import HintLoss, Attention, Similarity, NSTLoss, RKDLoss, PKT
from detectron2.checkpoint import DetectionCheckpointer
import torch

class Distill():
    def __init__(self, distill_cfg):
        self.opt = distill_cfg
        self.model_t = self.build_teacher_model()

    def build_teacher_model(self):
        '''构建教师模型'''
        # 模型创建
        teacher_cfg = get_cfg() 
        teacher_cfg.merge_from_file(self.opt.CFG_T)
        teacher_cfg['student_identity'] = False
        model_t = build_model(teacher_cfg)
        # 参数加载
        DetectionCheckpointer(model_t).resume_or_load(self.opt.PATH_T, resume=False)
        return model_t
        
    def compute_distill_loss(self, batched_inputs, logit_s): 
        '''计算蒸馏损失'''
        
        # 获取教师特征图
        with torch.no_grad():
            logit_t = self.model_t.get_features(batched_inputs)  

        # ==========================蒸馏损失函数==========================      
        # kd 损失函数
        if self.opt.DISTILL == 'hint':
            criterion_kd = HintLoss()
        elif self.opt.DISTILL == 'attention':
            criterion_kd = Attention()
        elif self.opt.DISTILL == 'nst':
            criterion_kd = NSTLoss()
        elif self.opt.DISTILL == 'similarity':
            criterion_kd = Similarity()
        elif self.opt.DISTILL == 'rkd':
            criterion_kd = RKDLoss()
        elif self.opt.DISTILL == 'pkt':
            criterion_kd = PKT()
        elif self.opt.DISTILL == 'kdsvd':
            criterion_kd = KDSVD()
        else:
            raise NotImplementedError(self.opt.DISTILL)

        # 对五层FPN的损失取平均
        loss_kd = 0
        for t,s in zip(logit_t.values(), logit_s.values()):
            loss_kd += criterion_kd(s, t.detach())
            
        loss_kd = loss_kd/5 * self.opt.B

        distill_losses = {'loss_kd': loss_kd}

        return distill_losses

还有一些其它细节,就不赘述了,完整代码见我的Github

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于检测代码库detectron2的蒸馏应用 的相关文章

  • FreeRTOS——流和消息缓冲区

    FreeRTOS 基础系列文章 基本对象 FreeRTOS 任务 FreeRTOS 队列 FreeRTOS 信号量 FreeRTOS 互斥量 FreeRTOS 任务通知 FreeRTOS 流和消息缓冲区 FreeRTOS 软件定时器 Fre
  • FreeRTOS——静态与动态内存分配

    FreeRTOS 基础系列文章 基本对象 FreeRTOS 任务 FreeRTOS 队列 FreeRTOS 信号量 FreeRTOS 互斥量 FreeRTOS 任务通知 FreeRTOS 流和消息缓冲区 FreeRTOS 软件定时器 Fre
  • CAS 6.5.5项目初始化搭建运行

    一 项目背景介绍 公司项目重构 xff0c 决定使用CAS中央认证系统 在GitHub上找到最新的稳定版本6 5 5 CAS项目在5 x版本的运行环境是jdk8 xff0c 使用maven做的项目管理 6 x使用的是jdk11作为运行环境
  • GoogleTest中gMock的使用

    GoogleTest中的gMock是一个库 xff0c 用于创建mock类并使用它们 当你编写原型或测试 prototype or test 时 xff0c 完全依赖真实对象通常是不可行或不明智的 not feasible or wise
  • 基于Autoware制作高精地图(一)

    基于Autoware制作高精地图 xff08 一 xff09 开始进入正题 xff0c 也是最近在忙的一件事 xff0c 制作高精地图 高精地图的制作大概分为以下四个流程 xff08 不一定完全正确 xff09 xff1a 1 构建点云地图
  • Ubuntu sh文件编写,开多终端,自动读取密码

    Ubuntu sh文件编写 xff0c 开多终端 xff0c 自动读取密码 开启多个终端自动读取密码 在最近的项目调试中经常需要开多个终端启动多个launch xff0c 这样的操作多了难免会感到烦躁并且时间一长再回去使用一些功能包的时候就
  • 控制理论——自动控制原理若干概念

    1 对自动控制系统的基本要求 稳定性 被控量因扰动偏离期望值后 xff0c 经过过渡过程可以恢复到原来的期望值状态 快速性 包含两方面 xff1a 过渡过程的时间 最大超调量 xff08 震荡幅度 xff09 准确性 指稳态误差 xff1a
  • Optitrack下通过mavros实现offbord控制

    参考文章 xff1a 树莓派通过MAVROS与Pixhawk PX4通信 PX4使用Optitrack进行室内定位 通过optitrack与妙算连接在同一局域网下 xff0c 关闭防火墙 xff0c 并设置刚体发布 vrpn安装 cd ca
  • 【场景图生成】Unbiased Scene Graph Generation from Biased Training

    文章下载地址 xff1a https arxiv org pdf 2002 11949 pdf 代码地址 xff1a GitHub KaihuaTang Scene Graph Benchmark pytorch 发表地点 xff1a CV
  • 【场景图生成】Graphical Contrastive Losses for Scene Graph Parsing

    文章下载地址 xff1a Graphical Contrastive Losses for Scene Graph Parsing 代码地址 xff1a https github com NVIDIA ContrastiveLosses4V
  • jquery无法获取到textarea中的值详解

    问题描述 xff1a 今天在springboot中jquery读取前端的值通过jquery打包为json传入后端 xff0c 发现其中textarea区域中的内容无法获取 解决办法 xff1a 首先看你的textarea中是否有 name属
  • 阿里云大学——Java语言基础自测考试 - 初级难度

    1 假设有如下程序 xff1a span class token keyword public span span class token keyword class span span class token class name Dem
  • could not transfer artifact org.springframework.boot:spring-boot-starter-parent

    Springboot异常 could not transfer artifact org springframework boot spring boot starter parent pom 2 3 0 RELEASE from to c
  • 阿里云ECS搭建个人简历网站

    能在自己的网站上搭建简历是不是很酷 xff0c 今天我就教大家如何在自己的服务器上搭建一个个人简历网站 因为主流网站的搭站环境是LAMP环境 xff0c 所以第一步就是先去把服务器环境 一 修改为LAMP环境 停止ECS实例运行 点击使用就
  • GitHub加速神器FastGithub的使用

    clone GitHub上的项目时经常超时 pull或push的时候也有类似情况 有时GitHub也打不开 xff0c 这里推荐GitHub上的一个工具FastGithub xff0c 开启它后 xff0c 可大大减少超时情况的发生 这里介
  • 阿里云ECS打造属于自己的WEB——IDE编程环境

    首先感谢 64 1430059860老哥的指导 xff0c 在阿里的官方视频卡着以后就一直进去入不了下一步了 xff0c 特向我的组长老哥带带 xff0c 最终搭建成功 停止实例选择更换操作系统 xff08 如果使用centoS建议更换ub
  • 给阿里云服务器装一个图形化界面——Gnome

    我这里使用的是ubantu系统 第一步 xff1a apt get update更新一下源第二步 下载Gnome图形化界面 apt get install gnome shell ubuntu gnome desktop第三步 下载完成 a
  • 0基础使用阿里云打造自己的私人云盘

    平时我们使用云盘例如有百度云 xff0c 蓝奏云 xff0c 小米云盘 xff0c 虽然给我们带来不少的便利 xff0c 但是也存在私人数据泄露和文件下载速度过慢的风险 xff0c 所以 xff0c 打造一款属于自己的私人云盘是一个很好的选
  • Redis无法加载配置文件中日志文件的解决方法

    Can t open the log file Permission denied logfile usr local redis etc redis6380 log Can t open the log file Permission d
  • Request method ‘PUT‘ not supported

    今天写后端接口出现问题 xff0c 出现Request method PUT not supported 可能是springboot的bug xff0c 在修改无果后 xff0c 关闭程序 xff0c 进行rebuild多次后 xff0c

随机推荐