pytorch 中register_buffer()

2023-11-17

 

今天在看DSSINet代码的ssim.py时,遇到了一个用法

class NORMMSSSIM(torch.nn.Module):

    def __init__(self, sigma=1.0, levels=5, size_average=True, channel=1):
        super(NORMMSSSIM, self).__init__()
        self.sigma = sigma
        self.window_size = 5
        self.levels = levels
        self.size_average = size_average
        self.channel = channel
        self.register_buffer('window', create_window(self.window_size, self.channel, self.sigma))
        self.register_buffer('weights', torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))

那么这个register_buffer()是干什么用呢?官方解释如下

nn.modules.module.py
Adds a persistent buffer to the module.向模块添加持久缓冲区。

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm的“running_mean”不是参数,而是持久状态的一部分。
        Buffers can be accessed as attributes using given names.
缓冲区可以使用给定的名称作为属性访问。 
        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name 名称(字符串):缓冲区的名称。可以使用给定的名称从该模块访问缓冲区
            tensor (Tensor): buffer to be registered.
        Example::
            >>> self.register_buffer('running_mean', torch.zeros(num_features))        

应该就是在内存中定一个常量,同时,模型保存和加载的时候可以写入和读出。

pytorch一般情况下,是将网络中的参数保存成orderedDict形式的,这里的参数其实包含两种,一种是模型中各种module含的参数,即nn.Parameter,我们当然可以在网络中定义其他的nn.Parameter参数,另一种就是buffer,前者每次optim.step会得到更新,而不会更新后者。

class myModel(nn.Module):
    def __init__(self, kernel_size=3):
        super(Depth_guided1, self).__init__()
        self.kernel_size = kernel_size
        self.back_end = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, padding=1),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(64, 3, 3, padding=1),
            torch.nn.ReLU(True),
        )

        mybuffer = np.arange(1,10,1)
        self.mybuffer_tmp = np.randn((len(mybuffer), 1, 1, 10), dtype='float32')
        self.mybuffer_tmp = torch.from_numpy(self.mybuffer_tmp)
        # register preset variables as buffer
        # So that, in testing , we can use buffer variables.
        self.register_buffer('mybuffer', self.mybuffer_tmp)

        # Learnable weights
        self.conv_weights = nn.Parameter(torch.FloatTensor(64, 10).normal_(mean=0, std=0.01))
        # Other code
        def forward(self):
            ...
            # 这里使用 self.mybuffer!

注记:

1.定义parameter和buffer都只需要传入Tensor即可。也不需要将其转成gpu,这是因为,当网络进行.cuda时候,会自动将里面的层的参数,buffer等转换成相应的GPU上。

2. self.register_buffer可以将tensor注册成buffer,在forward中使用self.mybuffer,而不是self.mybuffer_tmp

3.网络存储时也会将buffer存下,当网络load模型时,会将存储的模型的buffer也进行赋值。

4.buffer的更新在forward中,optim.step只能更新nn.parameter类型的参数。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 中register_buffer() 的相关文章

  • 基于Python实现 传感器的随机布置 传感网覆盖仿真

    代码演示 import tkinter as tk import random import win32gui import cv2 import time import math from PIL import Image ImageGr
  • 黑客游戏Hacknet下载(游戏分享一)

    OK Shall we begin Hacknet中文版下载 百度网盘 添加链接描述 夸克网盘 添加链接描述 注 解压后直接点击Hacknet exe进行游戏 英文版下载 百度网盘 添加链接描述 夸克网盘 添加链接描述 难关过不了自行上b站
  • AndroidStudio链接手机的步骤

    1 设置手机为开发者模式 设置 gt 关于手机 gt 连续点击MIUI版本 开启成功 2 在更多设置中选择开发者选项 在开发者选项中同时勾选USB调试和USB安装的开关 3 数据线与电脑连接 4 打开AndroidStudio 等待程序加载
  • centos7关闭防火墙

    出现物理机ping不通虚拟机 但虚拟机可以ping通物理机 排查的方向 一个是虚拟机的防火墙问题 1 查看防火墙的状态 systemctl status firewalld 2 关闭防火墙 如果还是不通 第二个排查方向是虚拟机的链接模式 桥
  • Springboot集成activiti的配置文件ActivitiConfig

    Configuration public class ActivitiConfig Bean public ProcessEngineConfiguration processEngineConfiguration DataSource d
  • Stable Diffusion教程

    什么是Stable Diffusion Stable Diffusion是一种潜在扩散模型 Latent Diffusion Model 能够从文本描述中生成详细的图像 它还可以用于图像修复 图像绘制 文本到图像和图像到图像等任务 简单地说
  • radius认证服务

    radius认证服务 RADIUS是一种分布的 客户端 服务器系统 实现安全网络 反对未经验证的访问 在cisco实施中 RADIUS客户端运行在cisco路由器上上 发送认证请求到中心RADIUS服务器 服务器上包含了所有用户认证和网络服
  • cuda测试集编译linux,linux下使cmake编译cuda(附列子,亲测可用)

    在网上百度 并没有找到什么合适的教程 让我等小白着急不已 借助于GOOGLE的强大能力 发现原来cmake已经支持了cuda 于是乎 赶紧 http www cmake org 下载了最新的cmake 调用了里面的一个FindCUDA cm
  • ApiPost 开源接口调试工具使用大全

    ApiPost使用 简介 接口调试 API请求参数 Header 参数 Query 参数 Body 参数 API 请求响应 返回Headers 响应结果分屏展示 生成调试代码 参数 全局参数 目录参数 参数的优先级 变量 环境变量 环境变量
  • 运行 AppImage软件:Running AppImages (***)

    How to run an AppImage Running AppImages 使用 AppImage appImagetool 进行 Linux 软件包管理 带笔记 要点 1 需要运行权限 通常 linux软件的运行 都需要运行权限 B
  • 基于OpenCASCADE自制三维建模软件(一)介绍

    一 制作背景 目前工作的项目中 需要三维建模作为其中一个模块 而本人刚接触三维建模 因而借助制作一个简单的三维建模软件学习相关的知识 并在此作笔记 在调研过程中 我了解到开源的Open CASCADE软件平台 Open CASCADE简称O
  • ChatGPT 类 AI 软件供应链的安全及合规风险

    AIGC将成为重要的软件供应链 近日 OpenAI推出的ChatGPT通过强大的AIGC 人工智能生产内容 能力让不少人认为AI的颠覆性拐点即将到来 基于AI将带来全新的软件产品体验 而AI也将会成为未来软件供应链中非常重要的一环 在Ope
  • colab 导出csv文件

    生成之后download即可 from google colab import files files download train csv
  • 使用nodejs接入封装第三方短信验证码工具类

    简介 使用nodejs接入封装第三方短信验证码工具类 第三方短信运营商接入 安装axiosyarn add axios 1 1 3 配置 在项目的config文件夹新建一个文件 命名为接入短信验证码平台的名字 例如aliyunMessage
  • js实现5秒后跳转页面

    提示 文章写完后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 一 使用js代码实现延时跳转 二 使用步骤 1 定时器 setInterval 2 location跳转 3 整体实现 总结 提示 以下是本篇文章正文内容 下面案例
  • 数据结构之 栈(C语言实现)

    数据结构之 栈 C语言实现 1 栈的模型 栈 stack 是限制插入和删除只能在一个位置上进行的表 该位置是表的末端 叫做栈的顶 top 对栈的基本操作有push 进栈 和pop 出栈 前者相当于插入 后者则是删除最后插入的元素 最后插入的
  • Shell脚本攻略:Linux防火墙(一)

    目录 一 理论 1 安全技术 2 防火墙 3 通信五元素和四元素 4 总结 二 实验 1 iptables基本操作 2 扩展匹配 3 自定义链接 一 理论 1 安全技术 1 安全技术 入侵检测系统 Intrusion Detection S
  • 一分钟秒懂公有云、私有云、混合云......

    近几年随着云计算技术的逐渐普及 越来越多的企业开始选择了部署云计算方案 当运营赖于数据结构和网络管理业务时 云计算的灵活性 易用性 定制性给企业带来的优势是毋庸置疑的 但是公有云 私有云 混合云等等到底都是什么呢 公有云 私有云 混合云 这
  • NPOI 单元格设置边框

    很多表格中都要使用边框 本节将为你重点讲解NPOI中边框的设置和使用 边框和其他单元格设置一样也是调用ICellStyle接口 ICellStyle有2种和边框相关的属性 分别是 边框相关属性 说明 范例 Border 方向 边框类型 Bo
  • SourceInsight保存文件时自动去除多余的空格

    在用source insight 写代码后提交git 如果有一些多余的空格不删除就提交会出现标红的界面 在source insight 中可以设置保存时自动去除多余的空格 Options gt gt Files gt gt Remove e

随机推荐

  • Yahoo(雅虎)宣布停止开发YUI

    转载至 http www infoq com cn news 2014 09 yahoo drop axe YUI utm campaign infoq content utm source infoq utm medium feed ut
  • DoTween使用

    using System Collections using UnityEngine using DG Tweening using UnityEngine UI DOTween真的比iTween好很多 1 编写方面更加人性化 2 效率高很
  • 供应链金融三大类模式

    供应链金融三类模式的最全对比分析 2017 08 25 15 56 供应链金融可以解决中小企业供应链中资金分配的不平衡问题 打通上下游物流链 资金链 商流 信息流 提升整个供应链的群体竞争力 因此 供应链金融 备受中小企业青睐 在 供应链金
  • V4l2框架基础知识(三)

    V4L2框架概述 V4L2框架主要部分组成 V4L2 device 管理所有设备 media device media device框架管理运行时的pipeline V4L2 device 这个是整个输入设备的总结构体 可以认为他是整个V4
  • ROS节点运行管理launch文件

    launch 文件是一个 XML 格式的文件 可以启动本地和远程的多个节点 还可以在参数服务器中设置参数 作用 可以简化节点的配置与启动 提高ROS程序的启动效率 一 新建 1 新建launch文件 如 turtlesim 在功能包下添加
  • gdb

    100个gdb技巧 Debugging with GDB gdb调试基础 g选项 在编译时要加上 g选项 生成的可执行文件才能用gdb进行源码级调试 g选项的作用是在可执行文件中加入源代码的信息 比如可执行文件中第几条机器指令对应源代码的第
  • Windows10+ubuntu 双系统安装(针对联想小新air14)

    联想小新air14 Windows10 ubuntu 双系统安装 一 准备工作 1 查看电脑配置 1 查看BIOS模式 2 搞清楚硬盘单双 2 制作系统盘 1 资源准备 2 写盘 3 磁盘分区 二 安装过程 1 用做好的系统盘安装系统 2
  • LaTeX的基本使用

    看前说明 说明 这篇文章介绍了latex的基本使用 基本覆盖了latex入门的知识点 由本人自己学习研究整理出来 不可被他人拿来进行不当的商用等等 违者必究 大家利用下面完整的latex文档 在编译器中编译 对比latex文档和生成文件之间
  • MyEclipse中关闭项目的作用及操作方法

    1 关闭项目的操作方式 选中项目 右键 点击Close Project 便可关闭当前项目 如图 关闭后的项目状态如图所示 2 开启项目的操作方式 双击项目或右键项目 点击Open Project 弹出如下窗口 点击 No 按钮 开启当前项目
  • JUC 之 线程局部变量 ThreadLocal

    ThreadLocal 基本概念 ThreadLocal 提供线程局部变量 这些变量与正常的变量不同 因为每一个线程在访问 ThreadLocal 实例的时候 通过其get 或者 set 方法 都有自己的 独立初始化的变副本 ThreadL
  • react、umi、dva

    React 一 React的简介 1 介绍 React 是一个用于构建用户界面的 JAVASCRIPT 库 React主要用于构建UI 很多人认为 React 是 MVC 中的 V 视图 React 起源于 Facebook 的内部项目 用
  • Mac升级Catalina(10.15)后 clion不能运行,提示「xcrun: error: invalid active developer path ...」

    Mac升级Catalina 10 15 后 使用clion 运行失败 提示内容如下 xcrun error invalid active developer path Library Developer CommandLineTools m
  • MQTT-保留消息和遗嘱消息

    遗嘱消息 为什么需要遗嘱消息 MQTT的订阅发布机制 解耦了消息的发送方和接收方 这使我们没有办法获取对端的状态 为了解决该问题 MQTT提供了遗嘱消息 为意外断线的客户端提供了对外发出通知的能力 如何使用遗嘱消息 使用遗嘱消息 客户端需要
  • 【笔记】关于win导入外部动态磁盘时“包名称无效”的解决办法

    网易博客搬家 原贴时间 2015 02 22 一 问题背景 硬盘闲置 电脑重装系统以后装上硬盘 计算机 中无盘符 磁盘管理中显示硬盘为 外部动态磁盘 右键 导入外部磁盘 提示 包名称错误 二 尝试过程 1 重启电脑 无效 2 换sata接口
  • DBeaver教程:连接达梦数据库DM8

    本文介绍如何通过dbeaver连接达梦数据库进行管理 DBeaver 是一个基于 Java 开发 免费开源的通用数据库管理和开发 DBeaver 采用 Eclipse 框架开发 支持插件扩展 并且提供了许多数据库管理工具 ER 图 数据导入
  • linux查看所有文件

    这本阿里P8撰写的算法笔记 再次推荐给大家 身边不少朋友学完这本书最后加入大厂 Github 疯传 史上最强悍 阿里大佬 LeetCode刷题手册 开放下载了 1 linux文件结构 linux文件结构是树形的 根目录是 其它所有文件都是在
  • OCR加持白描App,让AI成为视障者的眼睛

    现实中 你可以轻松无障碍地阅读各类平面印刷文字以及身边的一切 或许你未曾想过 视障人群该怎么办呢 统计数据显示 中国大约有1700万的视障群体 相当于每100个人中就有超过1位是视障人士 但我们在日常生活中却很少见到他们 那是因为视障群体在
  • 3、ARIMA序列预测Matlab代码、可视化(可做算法对比)

    1 文件包中程序均收集 整理 汇总自网络 2 文件包完整内容 1 ARIMA 功能函数 仅包含一个ARIMA算法函数 需要调用到自己的程序中使用 函数部分代码及预览图 function result ARIMA algorithm data
  • 应用程序本地化

    一 简介 使用本地化功能 可以轻松地将应用程序翻译成多种语言 甚至可以翻译成同一语言的多种方言 如果要添加本地化功能 需要为每种支持的语言创建一个子目录 称为 本地化文件夹 通常使用 lproj作为拓展名 当本地化的应用程序需要载入某一资源
  • pytorch 中register_buffer()

    今天在看DSSINet代码的ssim py时 遇到了一个用法 class NORMMSSSIM torch nn Module def init self sigma 1 0 levels 5 size average True chann