MMDet——EMA更新hook详解

2023-11-14

Hook

首先需要明白mmdet中hook机制,EMA就是建立在Hook机制上的,推荐一个Hook详解:

EMA(指数平均 exponential mean average)

  • 一般来说,在Semi-supervised半监督学习任务中,EMA是指通过Student model学习多个样本后获得的参数,来对Student model进行参数更新的策略,从而使得Student model的权重更新更稳定。
  • 在mmdet中,官方对于EMA就有相关实现mmdetection/mmdet/core/hook/ema.py
class BaseEMAHook(Hook):
    """Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of model in training
    process. All parameters have a ema backup, which update by the formula
    as below. EMAHook takes priority over EvalHook and CheckpointHook. Note,
    the original model parameters are actually saved in ema field after train.

    Args:
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameter are updated with the formula:
           `ema_param = (1-momentum) * ema_param + momentum * cur_param`.
            Defaults to 0.0002.
        skip_buffers (bool): Whether to skip the model buffers, such as
            batchnorm running stats (running_mean, running_var), it does not
            perform the ema operation. Default to False.
        interval (int): Update ema parameter every interval iteration.
            Defaults to 1.
        resume_from (str, optional): The checkpoint path. Defaults to None.
        momentum_fun (func, optional): The function to change momentum
            during early iteration (also warmup) to help early training.
            It uses `momentum` as a constant. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 skip_buffers=False,
                 resume_from=None,
                 momentum_fun=None):
        assert 0 < momentum < 1
        self.momentum = momentum
        self.skip_buffers = skip_buffers
        self.interval = interval
        self.checkpoint = resume_from
        self.momentum_fun = momentum_fun

    def before_run(self, runner):
        """To resume model with it's ema parameters more friendly.

        Register ema parameter as ``named_buffer`` to model.
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        if self.skip_buffers:
            self.model_parameters = dict(model.named_parameters())
        else:
            self.model_parameters = model.state_dict()
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers())
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def get_momentum(self, runner):
        return self.momentum_fun(runner.iter) if self.momentum_fun else \
                        self.momentum

    def after_train_iter(self, runner):
        """Update ema parameter every self.interval iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        momentum = self.get_momentum(runner)
        for name, parameter in self.model_parameters.items():
            # exclude num_tracking
            if parameter.dtype.is_floating_point:
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                buffer_parameter.mul_(1 - momentum).add_(
                    parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """We recover model's parameter from ema backup after last epoch's
        EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameter of model with parameter in ema_buffer."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MMDet——EMA更新hook详解 的相关文章

  • PyTorch:如何使用 DataLoaders 自定义数据集

    如何利用torch utils data Dataset and torch utils data DataLoader根据您自己的数据 不仅仅是torchvision datasets 有没有办法使用内置的DataLoaders他们使用的
  • PyTorch 中的截断反向传播(代码检查)

    我正在尝试在 PyTorch 中实现随时间截断的反向传播 对于以下简单情况K1 K2 我下面有一个实现可以产生合理的输出 但我只是想确保它是正确的 当我在网上查找 TBTT 的 PyTorch 示例时 它们在分离隐藏状态 将梯度归零以及这些
  • 无法使用 torch.Tensor 创建张量

    我试图创建一个张量 如下所示 import torch t torch tensor 2 3 我收到以下错误 类型错误回溯 最近调用 最后 在 gt 1 a torch tensor 2 3 类型错误 tensor 需要 1 个位置参数 但
  • Pytorch 数据加载器:错误的文件描述符和 EOF > 0

    问题描述 在使用由自定义数据集制作的 Pytorch 数据加载器进行神经网络训练期间 我遇到了奇怪的行为 数据加载器设置为workers 4 pin memory False 大多数时候 训练都顺利完成 有时 训练会随机停止 并出现以下错误
  • PyTorch - 参数不变

    为了了解 pytorch 的工作原理 我尝试对多元正态分布中的一些参数进行最大似然估计 然而 它似乎不适用于任何协方差相关的参数 所以我的问题是 为什么这段代码不起作用 import torch def make covariance ma
  • PoseWarping:如何矢量化此 for 循环(z 缓冲区)

    我正在尝试使用地面真实深度图 姿势信息和相机矩阵将帧从视图 1 扭曲到视图 2 我已经能够删除大部分 for 循环并将其矢量化 除了一个 for 循环 扭曲时 由于遮挡 视图 1 中的多个像素可能会映射到视图 2 中的单个位置 在这种情况下
  • 如何使用 torch.stack?

    我该如何使用torch stack将两个张量与形状堆叠a shape 2 3 4 and b shape 2 3 没有就地操作 堆叠需要相同数量的维度 一种方法是松开并堆叠 例如 a size 2 3 4 b size 2 3 b torc
  • 将 CNN Pytorch 中的预训练权重传递到 Tensorflow 中的 CNN

    我在 Pytorch 中针对 224x224 大小的图像和 4 个类别训练了这个网络 class CustomConvNet nn Module def init self num classes super CustomConvNet s
  • PyTorch 教程错误训练分类器

    我刚刚开始 PyTorch 教程使用 PyTorch 进行深度学习 60 分钟闪电战我应该补充一点 我之前没有编写过任何 python 但其他语言 如 Java 现在 我的代码看起来像 import torch import torchvi
  • 如何平衡 GAN 中生成器和判别器的性能?

    这是我第一次使用 GAN 我面临着判别器多次优于生成器的问题 我正在尝试重现PA模型来自本文 http openaccess thecvf com content ICCV 2017 papers Sajjadi EnhanceNet Si
  • torchvision.transforms.Normalize 是如何操作的?

    我不明白如何标准化Pytorch works 我想将平均值设置为0和标准差1跨越张量中的所有列x形状的 2 2 3 一个简单的例子 gt gt gt x torch tensor 1 2 3 4 5 6 7 8 9 10 11 12 gt
  • 使用 KL 散度时,变分自动编码器为每个输入 mnist 图像提供相同的输出图像

    当不使用 KL 散度项时 VAE 几乎完美地重建 mnist 图像 但在提供随机噪声时无法正确生成新图像 当使用 KL 散度项时 VAE 在重建和生成图像时都会给出相同的奇怪输出 这是损失函数的 pytorch 代码 def loss fu
  • 如何在pytorch中查看DataLoader中的数据

    我在 Github 上的示例中看到类似以下内容 如何查看该数据的类型 形状和其他属性 train data MyDataset int 1e3 length 50 train iterator DataLoader train data b
  • PyTorch:如何检查训练期间某些权重是否没有改变?

    如何检查 PyTorch 训练期间某些权重是否未更改 据我了解 一种选择可以是在某些时期转储模型权重 并检查它们是否通过迭代权重进行更改 但也许有一些更简单的方法 有两种方法可以解决这个问题 First for name param in
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • Pytorch ValueError:优化器得到一个空参数列表

    当尝试创建神经网络并使用 Pytorch 对其进行优化时 我得到了 ValueError 优化器得到一个空参数列表 这是代码 import torch nn as nn import torch nn functional as F fro
  • 如何使用pytorch构建多任务DNN,例如超过100个任务?

    下面是使用 pytorch 为两个回归任务构建 DNN 的示例代码 这forward函数返回两个输出 x1 x2 用于大量回归 分类任务的网络怎么样 例如 100 或 1000 个输出 对所有输出 例如 x1 x2 x100 进行硬编码绝对
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • Pytorch“展开”等价于 Tensorflow [重复]

    这个问题在这里已经有答案了 假设我有大小为 50 50 的灰度图像 在本例中批量大小为 2 并且我使用 Pytorch Unfold 函数 如下所示 import numpy as np from torch import nn from
  • 将 Pytorch LSTM 的状态参数转换为 Keras LSTM

    我试图将现有的经过训练的 PyTorch 模型移植到 Keras 中 在移植过程中 我陷入了LSTM层 LSTM 网络的 Keras 实现似乎具有三种状态类型的状态矩阵 而 Pytorch 实现则具有四种状态矩阵 例如 对于hidden l

随机推荐

  • 利用树莓派搭建简易服务器

    读研以来笔者一直负责实验室的网络维护 可以说是实验室名副其实的首席大网管 整个实验室是从学校网络中心购买了一个教育网的公网IP地址和带宽 公网IP绑定了实验室的主路由器 而主路由器就在笔者卡位的旁边 有一天笔者突发奇想 拿了手头的树莓派3结
  • Micropython——报错解决:TypeError: object with buffer protocol required

    报错 检查报错处代码 仔细检查可以发现 是括号放错位置 导致函数无法正常执行 故报错 一般情况下 Micropython除硬件如定时器中断内存溢出等硬件本身报错外 其他均为语法错误
  • 统计学习方法学习笔记(一)————统计学习方法概论

    1 统计学习 1 统计学习概念 统计学习 statistical learning 是关于计算机基于数据构建概率统计模型并运用模型对数据进行预测与分析的一门学科 统计学习也称为统计机器学习 statistical machine learn
  • MLIR入门系列系列学习笔记

    目录 1 名字解释 这一定义包含3个关键元素 2 代码演示 2 1 环境准备 2 2 编译llvm project 2 3 测试解析 2 3 1 源程序 2 3 2 将源程序生成抽象语法树 AST 3 MLIR三要素 3 1 MLIRGen
  • 为什么在组件内部data是一个函数而不是一个对象?

    为什么在组件内部data是一个函数而不是一个对象 因为在组件复用的时候会重新生成一个对象 而data是一个对象的话 因为对象是引用数据类型 data数据会被复用 而当data是一个函数的时候每次调用的时候就会返回一个新的data对象 vue
  • 安装--centos7上使用kubeadm安装三节点的k8s集群

    安装文档 https kubernetes io zh cn docs setup production environment tools kubeadm install kubeadm 参考 https blog csdn net qq
  • 瑞吉外卖业务开发

    一 软件开发整体介绍 软件开发流程 需求分析 产品原型 需求规格说明书 设计 产品文档 UI界面设计 概要设计 详细设计 数据库设计 编码 项目代码 单元测试 测试 测试用例 测试报告 上线运维 软件环境安装 配置 角色分工 项目经理 对整
  • 2023华为OD机试真题【垃圾短信识别】

    题目描述 大众对垃圾短信深恶痛绝 希望能对垃圾短信发送者进行识别 为此 很多软件增加了垃圾短信的识别机制 经分析 发现正常用户的短信通常具备交互性 而垃圾短信往往都是大量单向的短信 按照如下规则进行垃圾短信识别 本题中 发送者A符合以下条件
  • unity3d FPS 枪的后座力

    实现枪开枪后 向上偏移一段距离 再缓慢下移复位 模仿cs 调小后座力 using UnityEngine using System Collections public class Camera2Follower MonoBehaviour
  • Linux驱动开发--平台总线id和设备树匹配

    目录 一 ID匹配之框架代码 二 ID匹配之led驱动 三 设备树匹配 四 设备树匹配之led驱动 五 一个编写驱动用的宏 一 ID匹配之框架代码 id匹配 可想象成八字匹配 一个驱动可以对应多个设备 优先级次低 注意事项 device模块
  • 学会QT从这里开始——教你快速学会QT

    为了提高提高推文质量 最近又再翻看QT书籍 不知道大家有没有发现 QT书籍大多都是从环境 安装 控件开始讲解 好 现在开始学习吧 1 环境安装 2 新建项目 3 学习控件 QButton QLable QLineEdit QTextEdit
  • nacos2.2.1集成达梦数据库

    nacos2 2 1集成达梦数据库 1 下载源码 https github com alibaba nacos 2 新增达梦驱动依赖 父pom xml
  • openwrt篇修改WiFi热点默认名称和主机名

    在如下图文件中 修改ssid 在如下图文件中修改hostname
  • Linux的用户空间与内核空间

    一 简介 Linux 操作系统和驱动程序运行在内核空间 应用程序运行在用户空间 两者不能简单地使用指针传递数据 因为Linux使用的虚拟内存机制 用户空间的数据可能被换出 当内核空间使用用户空间指针时 对应的数据可能不在内存中 用户空间的内
  • vue3项目引入高德地图详细方法教程

    项目需求需要引入地图 对于目前最新的Vue3 0 无论是百度 高德 腾讯地图目前还没有适配 只有Vue 2 x版本的 目前只有谷歌地图的Vue3 0适配 但是没有适配并不代表不能使用 下面就来教大家如何使用 1 在高德开发平台申请你的key
  • react定义函数,默认函数参数的方式

    参数是 对象 有传入参数用传入参数作为入参数 无传入参数用默认值 getTableData async pageData gt const params Object assign currPage 1 pageSize this stat
  • 网传字节跳动实习生删除GB以下所有机器学习模型,差点没上头条

    作者 陈大鑫 陈彩娴 来源 AI科技评论 昨晚脉脉上有网友爆料 字节跳动一位实习生删除了公司所有轻量级别的机器学习模型 什么是lite模型 该楼主表示 lite模型就是公司内几乎所有GB大小以下的机器学习模型 且全部被删除了 实习生直接删除
  • 公司固定资产怎么明细管理

    固定资产的管理是一个至关重要的环节 它不仅影响到企业的运营效率和经济效益 也直接影响到公司的长期发展 因此 对固定资产进行精细化管理 是每一个负责任的企业都应该做到的 本文将探讨如何通过创新的方式 实现公司固定资产的明细管理 我们需要明确什
  • 设置vscode终端的最大输出行

    使用vscode终端输出的时候 如果输出的行数很多 之前打印的东西就看不到了 因此需要设置一下终端输出的最大行数来保留之前的信息 terminal integrated bell scrollback
  • MMDet——EMA更新hook详解

    Hook 首先需要明白mmdet中hook机制 EMA就是建立在Hook机制上的 推荐一个Hook详解 深度理解目标检测 MMdetection HOOK机制 EMA 指数平均 exponential mean average 一般来说 在