pytorch学习笔记(十二):详解 Module 类

2023-05-16

Modulepytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单。

本文主要关注 Module 类的内部是怎么样的。

初始化方法中做了什么

def __init__(self):
    self._backend = thnn_backend
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
    self.training = True

这是 Module 的初始化方法:

  • self._parameters 用来存放注册的 Parameter 对象
  • self._buffers 用来存放注册的 Buffer 对象。(pytorch 中 buffer 的概念就是 不需要反向传导更新的值)
  • self._modules 用来保存注册的 Module 对象。
  • self.training 标志位,用来表示是不是在 training 状态下
  • ...hooks 用来保存 注册的 hook

__setattr____getattr__

__setattr__ 每次给属性赋值的时候,都会调用这个方法。

__setattr__ 的代码比较多,我们一点一点看。

  • remove_from :工具函数, 用来从 self.__dict__, self._buffers, self._modules 中删除对象。

第一种情况: value 的类型是 Paramter

  • 从 三大 字典中将 同名的 对象删掉
  • 然后,注册 paramter

第二种情况: value不是 Parameter对象, nameself._parameter

  • self._parameters[name] = None

已经考虑了 valueParameter对象,剩下的就是考虑 valuebufferModule

第三种情况:value不是 Parameter对象, valueModule 对象

  • 从三大字典里面移除同名 对象
  • 然后直接向 self._modules 字典里添加 value

第四种情况:value不是Parameter对象, value不为 Module对象, 但是 nameself._modules

  • self._modules[name]=None

第五种情况:value不是Parameter对象, value不为 Module对象, name 存在 self._buffers

  • self._buffers[name]=None

最后一种情况: 就是 普通的属性了。

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

    params = self.__dict__.get('_parameters')

    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not torch.is_tensor(value):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)

__getattr__ : 当获取 self.__dict__ 中没有的键所对应的值的时候,就会调用这个方法

因为 parameter, module, buffer 的键值对存在与 self._parameters, self._modules, self.buffer 中,所以,当想获取这些 值时, 就会调用这个方法。

def __getattr__(self, name):
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

register_parameter

向模型中注册 Parameter

def register_parameter(self, name, param):
    """Adds a parameter to the module.

    The parameter can be accessed as an attribute using given name.
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Variable to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another variable, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param

Module.training 标志 如何影响 前向过程

nn.Dropout 来看 Module.training

class Dropout(Module):
    def __init__(self, p=0.5, inplace=False):
        super(Dropout, self).__init__()
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.inplace = inplace

    def forward(self, input):
        return F.dropout(input, self.p, self.training, self.inplace)

可以看出,在forward 过程中,直接获取,父类的training的值。

我们 通常通过 module.train()module.eval() 来切换模型的 训练测试阶段。

def train(self, mode=True):
    """Sets the module in training mode.
    This has any effect only on modules such as Dropout or BatchNorm.
    """
    self.training = mode

    for module in self.children():
        # 递归调用子模块 train 函数, 来设定所有 module 的 training 值。
        module.train(mode)
        return self

需要注意的是:module.eval() 仅仅设置 moduletraining 属性,如果我们想获得最快的推断速度, 还需要 设置 输入 Variablevolatile 属性为 True

参考资料

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch学习笔记(十二):详解 Module 类 的相关文章

  • MinGW-w64安装教程——著名C/C++编译器GCC的Windows版本

    MinGW w64安装教程 著名C C 43 43 编译器GCC的Windows版本 MinGW w64安装教程 著名C C 43 43 编译器GCC的Windows版本 本文主要讲述如何安装 C语言 编译器 MinGW w64 xff0c
  • RT-Thread实时操作系统简介

    目录 一 概述 二 架构 三 版本选择 四 内核启动流程 五 自动初始化机制 六 内核对象模型 七 I O设备模型 1 框架 2 设备驱动使用序列图 3 设备类型 八 FinSH控制台 九 ENV工具 1 menuconfig 2 Scon
  • PCIe RAS

    对于Linux系统针对RAS的AER错误处理机制完成 PCIe RAS简单来讲就是PCIe的错误检测 纠正以及汇报的机制 它可以方便我们准确的定位 xff0c 纠正和分析错误增强系统的健壮性和可靠性 PCIe错误的分类 PCIe错误分为可校
  • Linux下的regulator调试

    先看regulator使用的小demo 如 i2c8 touchscreen 64 28 vddcama supply 61 lt amp xxxxx gt int ret struct regulator power static int
  • 关于添加系统调用遇到 Unable to handle kernel paging request at virtual address 的解决

    Unable to handle kernel paging request at virtual address 是内存访问异常的错误 xff0c 原因通常有三种 xff1a virtual address 为 0x00000000 时
  • vscode安装配置clang-format插件及使用

    vscode安装配置clang format插件及使用 首先安装插件 在vscode扩展里搜索clang format xff0c 安装排名第一的xaver clang format 确认clang format可执行程序路径 window
  • 简历中项目描述怎么写啊

    http wenda tianya cn question 7ade6dc9324bed88
  • 树莓派(Raspberry Pi 3) - 系统烧录及系统使用

    树莓派 xff08 Raspberry pi xff09 是一块集成度极高的ARM开发板 xff0c 不仅包含了HDMI xff0c RCA xff0c CSI xff0c HDMI xff0c GPIO等端口 xff0c 还支持蓝牙以及无
  • flashcache原理

    介绍flashcache的文章很多 xff0c 我就不废话了 使用上 xff0c 有余峰老哥的 文章 xff1b 原理上 xff0c 有ningoo同学的 flashcache系列 但是ningoo同学漏掉了device mapper和fl
  • 无人机算法之PID

    xff08 未完成 xff09 一 PID介绍 xff08 百度百科 xff09 PID 控制器 xff08 比例 积分 微分控制器 xff09 是一个在工业控制应用中常见的反馈回路部件 xff0c 由比例单元 P 积分单元 I 和微分单元
  • java:接口、lambda表达式与内部类

    接口 xff08 interface 接口用来描述类应该做什么 xff0c 而不指定他们具体应该如何做 接口不是类 xff0c 而是对符合这个接口的类的一组需求 接口定义的关键词是interface span class token key
  • 卫星系统算法课程设计 - 第二部分 qt的安装与创建项目

    上一篇文章只讲了基本的东西 xff0c 这一篇要完成qt的安装 xff0c 构建项目 xff0c 并且将上一篇的代码导入进去 某比利比例搜qt安装 xff0c 看到qt5 14 2的下载安装 xff0c 跟着做 1 创建项目 创建新项目 x
  • 无人机-材料准备

    xff08 未完成 xff09 一 使用空心杯电机 xff0c 型号8520 xff0c 1S版本 xff0c 约5G每只 二 空心杯机架 xff0c 型号QX90 xff0c 约8 5g 三 使用55MM桨 四 1S 600MA电池 五
  • CMake中链接库的顺序问题

    原文链接 xff1a https blog csdn net lifemap article details 7586363 cmake中链接库的顺序是a依赖b xff0c 那么b放在a的后面 例如进程test依赖a库 b库 a库又依赖b
  • 鸿蒙wifi Demo运行

    title 鸿蒙Wi Fi Demo运行 date 2021 1 1 22 25 10 categories harmony 本文首发于LHM s notes 欢迎关注我的博客 坑有点多 由于之前没有看过wifi的内核态代码 xff0c 所
  • 将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)

    将TensorFlow训练好的模型迁移到Android APP上 xff08 TensorFlowLite xff09 1 写在前面 最近在做一个数字手势识别的APP xff08 关于这个项目 xff0c 我会再写一篇博客仔细介绍 xff0
  • 汉诺塔代码图文详解(递归入门)

    游戏规则 xff1a 已知条件存在A B C三根柱子 xff0c A上套有N片圆盘 如下图 目的将A上的所有圆盘移到C上约束条件每次只能移动一片圆盘 xff0c 且整个过程中只能出现小圆盘在大圆盘之上的情况 首先我们模拟 N 61 2 xf
  • STM32 最小系统电路简析

    文章目录 一 最小系统的组成1 供电电路2 外部晶振3 BOOT选择4 复位电路 二 最小系统实例1 STM32F103C8T6最小系统 三 各部分组成简析1 供电电路设计2 外部晶振原理3 BOOT设计4 复位电路设计 一 最小系统的组成
  • 带参数的宏的问题

    include 34 iostream 34 using namespace std define COMPUTE XX a a a 43 a 2 int main int a 61 2 int test1 61 COMPUTE XX 43
  • python_imbalanced-learn非平衡学习包_02_Over-sampling过采样

    python imbalanced learn非平衡学习包 01 简介 python imbalanced learn非平衡学习包 02 Over sampling过采样 后续章节待定 希望各位认可前面已更 您的认可是我的动力 Over s

随机推荐

  • TX2+JetPack3.2.1+opencv3.3.1+caffe+realsense2.0环境配置教程

    TX2 开箱 一共6样 xff0c 开机之后自带ubuntu16 04LTS的系统 xff0c ARMv8的处理器 xff0c 所以有些指令 xff0c 安装包必须与arm结构保持一致 开机之后 xff0c 按照指示进入图形界面 xff1a
  • 初视openwrt

    openwrt是一个微型的嵌入式操作系统 在编译的时候需要安装许多的工具和库 预置环境 xff1a sudo apt get install g 43 43 libncurses5 dev zlib1g dev bison flex unz
  • 滑动窗口详解

    前言 滑动窗口是双指针的一种特例 xff0c 可以称为左右指针 xff0c 在任意时刻 xff0c 只有一个指针运动 xff0c 而另一个保持静止 滑动窗口路一般用于解决特定的序列中符合条件的连续的子序列的问题 滑动窗口的时间复杂度是线性的
  • RT-Thread入门教程,环境配置和第一个代码

    1 前言 RT Thread这一个操作系统获得很多工程师的好评 xff0c 使用简单 xff0c 支持多 xff0c 有软件包可以下载 xff0c 甚至未来会有更多MicroPython的支持 xff0c 能够兼容主流的一些MCU xff0
  • DHT12温湿度传感器IIC,I2C接口调试心得和代码说明

    来源 xff1a http www fuhome net bbs forum php mod 61 viewthread amp tid 61 2141 DHT11那个单总线的温湿度传感器用的很多了 xff0c aosong推出了DHT12
  • 升级windows11如何在电脑上启用TPM2.0

    本文适用于无法升级到 Windows 11 xff0c 因为他们的电脑当前未启用 TPM 2 0 或其电脑能够运行 TPM 2 0 xff0c 但并未设置为运行 TPM 2 0 1 下载微软电脑健康状况检查 下载地址为 xff1a Wind
  • python调用谷歌翻译

    from GoogleFreeTrans import Translator if name 61 61 39 main 39 translator 61 Translator translator src 61 39 en 39 dest
  • C++(4) 运算符重载

    C 43 43 学习心得 xff08 1 xff09 运算符重载 from 谭浩强 C 43 43 面向对象程序设计 第一版 2014 10 6 4 1什么是运算符重载 用户根据C 43 43 提供的运算符进行重载 xff0c 赋予它们新的
  • C++学习心得(3)多态性与虚函数

    C 43 43 学习心得 xff08 3 xff09 多态性与虚函数 from 谭浩强 C 43 43 面向对象程序设计 第一版 2014 10 13 6 1 多态性的概念 在C 43 43 中 xff0c 多态性是指具有不同功能的函数可以
  • C发送http请求

    C语言发送http请求和普通的socket通讯 原理是一样的 无非就三步connect 连上服务器 send 发送数据 recv 接收数据 只不过发送的数据有特定的格式 下面的是简单发送一个http请求的例子 span class hljs
  • tensorflow(四十七):tensorflow模型持久化

    模型保存 span class token keyword from span tensorflow span class token keyword import span graph util graph def span class
  • git subtree使用

    在一个git项目下引用另一个项目的时 xff0c 我们可以使用 git subtree 使用 git subtree 时 xff0c 主项目下包含子项目的所有代码 使用 git subtree 主要关注以下几个功能 一个项目下如何引入另一个
  • tensorflow(四十八): 使用tensorboard可视化训练出的文本embedding

    对应 tensorflow 1 15版本 log dir span class token operator 61 span span class token string 34 logdir 34 span metadata path s
  • java中数组之间的相互赋值

    前言 本文考虑的研究对象是数组 xff0c 需要明确的是在java中 xff0c 数组是一种对象 xff0c java的所有对象的定义都是放在堆当中的 xff0c 对象变量之间的直接赋值会导致引用地址的一致 在java中声明一个数组 spa
  • tensorflow学习笔记(十):sess.run()

    session run fetch1 fetch2 关于 session run fetch1 fetch2 xff0c 请看http stackoverflow com questions 42407611 how tensorflow
  • tensorflow学习笔记(二十三):variable与get_variable

    Variable tensorflow中有两个关于variable的op xff0c tf Variable 与tf get variable 下面介绍这两个的区别 tf Variable与tf get variable tf Variab
  • pytorch 学习笔记(一)

    pytorch是一个动态的建图的工具 不像Tensorflow那样 xff0c 先建图 xff0c 然后通过feed和run重复执行建好的图 相对来说 xff0c pytorch具有更好的灵活性 编写一个深度网络需要关注的地方是 xff1a
  • pytorch学习笔记(五):保存和加载模型

    span class hljs comment 保存和加载整个模型 span torch save model object span class hljs string 39 model pkl 39 span model 61 torc
  • tensorflow:自定义op简单介绍

    本文只是简单的翻译了 https www tensorflow org extend adding an op 的简单部分 xff0c 高级部分请移步官网 可能需要新定义 c 43 43 operation 的几种情况 xff1a 现有的
  • pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类 xff0c 每次我们要 搭建 自己的神经网络的时候都要继承这个类 xff0c 继承这个类会使得我们 搭建网络的过程变得异常简单 本文主要关注 Module 类的内部是怎么样的 初始化方法中做