PyTorch 入坑七:模块与nn.Module学习

2023-10-29

模型创建概述

本章开始正式整理深度学习网络相关基础知识。
模型创建分为两个部分:模型结构创建权值初始化

模型结构创建从粒度上讲:有层layer网络Net 两个粒度。前者是指构成CNN的基础结构,如卷积层、池化层、BN层、激活函数层、损失函数层等;后者是指实现某一功能的网络结构如LeNet,AlexNet和ResNet等。

创建好模型后,需要对模型进行权值初始化,pytorch提供了丰富的初始化方法,Xavier,Kaiming,均匀分布,正态分布等。好的权重初始化有以下几点优点:

  • 加速网络收敛
  • 解更优
    在这里插入图片描述

PyTorch中的模块

PyTorch主要包括以下16个模块:

torch模块

torch模块本身包含了PyTorch经常使用的一些激活函数,比如Sigmoid(torch.sigmoid)、ReLU(torch.relu)和Tanh(torch.tanh),以及PyTorch张量的一些操作,比如矩阵的乘法(torch.mm)、张量元素的选择(torch.select)。
需要注意的是,这些操作的对象大多数都是张量,因此,传入的参数需要是PyTorch的张量,否则会报错(一般报类型错误,即TypeError)。另外,还有一类函数能够产生一定形状的张量,比如torch.zeros产生元素全为0的张量,torch.randn产生元素服从标准正态分布的张量等。

torch.Tensor模块

torch.sparse模块

torch.sparse模块定义了稀疏张量,其中构造的稀疏张量采用的是COO格式(Coordinate),主要方法是用一个长整型定义非零元素的位置,用浮点数张量定义对应非零元素的值。稀疏张量之间可以做元素加、减、乘、除运算和矩阵乘法

torch.cuda模块

torch.cuda模块定义了与CUDA运算相关的一系列函数,包括但不限于检查系统的CUDA是否可用,当前进程对应的GPU序号(在多GPU情况下),清除GPU上的缓存,设置GPU的计算流(Stream),同步GPU上执行的所有核函数(Kernel)等。

torch.nn模块

torch.nn是一个非常重要的模块,是PyTorch神经网络模块化的核心。这个模块定义了一系列模块,包括卷积层nn.ConvNd(N=1,2,3)和线性层(全连接层)nn.Linear等。
当构建深度学习模型的时候,可以通过继承nn.Module类并重写forward方法来实现一个新的神经网络。
另外,torch.nn中也定义了一系列的损失函数,包括平方损失函数(torch.nn.MSELoss)、交叉熵损失函数(torch.nn.CrossEntropyLoss)等。一般来说,torch.nn里定义的神经网络模块都含有参数,可以对这些参数使用优化器进行训练。

torch.nn.Parameter

张量子类,表示可学习参数,如weight,bias

torch.nn.Module(本文重点)

所有网络层基类,管理网络属性

torch.nn.functional

定义了一些核神经网络相关的函数,包括卷积函数和池化函数等。

torch.nn.init

torch.nn.init模块定义了神经网络权重的初始化。包括均匀初始化torch.nn.init.uniform_和正态分布归一化torch.nn.init.normal_等。
在PyTorch中函数或者方法如果以下画线结尾,则这个方法会直接改变作用张量的值。因此,这些方法会直接改变传入张量的值,同时会返回改变后的张量。

torch.optim模块

torch.optim模块定义了一系列的优化器,如torch.optim.SGD(随机梯度下降算法)、torch.optim.Adagrad(AdaGrad算法)、torch.optim.RMSprop(RMSProp算法)和torch.optim.Adam(Adam算法)等。

这个模块还包含学习率衰减的算法的子模块,即torch.optim.lr_scheduler,这个子模块中包含了诸如学习率阶梯下降算法torch.optim.lr_scheduler.StepLR和余弦退火算法torch.optim.lr_scheduler.CosineAnnealingLR等学习率衰减算法。

torch.autograd模块

torch.autograd模块是PyTorch的自动微分算法模块,定义了一系列的自动微分函数,包括torch.autograd.backward函数,主要用于在求得损失函数之后进行反向梯度传播,torch.autograd.grad函数用于一个标量张量(即只有一个分量的张量)对另一个张量求导,以及在代码中设置不参与求导的部分。
另外,这个模块还内置了数值梯度功能和检查自动微分引擎是否输出正确结果的功能。

torch.distributed模块

torch.distributed是PyTorch的分布式计算模块,主要功能是提供PyTorch并行运行环境,其主要支持的后端有MPI、Gloo和NCCL三种。

PyTorch的分布式工作原理主要是启动多个并行的进程,每个进程都拥有一个模型的备份,然后输入不同的训练数据到多个并行的进程,计算损失函数,每个进程独立地做反向传播,最后对所有进程权重张量的梯度做归约(Reduce)。
用到后端的部分主要是数据的广播(Broadcast)和数据的收集(Gather),其中,前者是把数据从一个节点(进程)传播到另一个节点(进程),比如归约后梯度张量的传播,后者则是把数据从其他节点(进程)转移到当前节点(进程),比如把梯度张量从其他节点转移到某个特定的节点,然后对所有的张量求平均。PyTorch的分布式计算模块不但提供了后端的一个包装,还提供了一些启动方式来启动多个进程,包括但不限于通过网络(TCP)、通过环境变量、通过共享文件等。

torch.distributions模块

torch.distributions模块提供了一系列类,使得PyTorch能够对不同的分布进行采样,并且生成概率采样过程的计算图。

torch.hub模块

提供了一系列预训练的模型供用户使用。比如,可以通过torch.hub.list函数来获取某个模型镜像站点的模型信息。通过torch.hub.load来载入预训练的模型,载入后的模型可以保存到本地,并可以看到这些模型对应类支持的方法。

torch.jit模块

是PyTorch的即时编译器(Just-In-Time Compiler,JIT)模块。这个模块存在的意义是把PyTorch的动态图转换成可以优化和序列化的静态图,其主要工作原理是通过输入预先定义好的张量,追踪整个动态图的构建过程,得到最终构建出来的动态图,然后转换为静态图(通过中间表示,即IntermediateRepresentation,来描述最后得到的图)。通过JIT得到的静态图可以被保存,并且被PyTorch其他的前端(如C++语言的前端)支持。另外,JIT也可以用来生成其他格式的神经网络描述文件,如前文叙述的ONNX。需要注意的一点是,torch.jit支持两种模式,即脚本模式(ScriptModule)和追踪模式(Tracing)。前者和后者都能构建静态图,区别在于前者支持控制流,后者不支持,但是前者支持的神经网络模块比后者少,比如脚本模式不支持torch.nn.GRU

torch.multiprocessing模块

torch.multiprocessing定义了PyTorch中的多进程API。通过使用这个模块,可以启动不同的进程,每个进程运行不同的深度学习模型,并且能够在进程间共享张量(通过共享内存的方式)。共享的张量可以在CPU上,也可以在GPU上,多进程API还提供了与Python原生的多进程API(即multiprocessing库)相同的一系列函数,包括锁(Lock)和队列(Queue)等

torch.random模块

提供了一系列的方法来保存和设置随机数生成器的状态.
神经网络的训练是一个随机的过程,包括数据的输入、权重的初始化都具有一定的随机性。设置一个统一的随机种子可以有效地帮助我们测试不同结构神经网络的表现,有助于调试神经网络的结构
get_rng_state函数获取当前随机数生成器状态,set_rng_state函数设置当前随机数生成器状态,并且可以使用manual_seed函数来设置随机种子,也可使用initial_seed函数来得到程序初始的随机种子

torch.onnx模块

torch.onnx定义了PyTorch导出和载入ONNX格式的深度学习模型描述文件。 ONNX格式的存在是为了方便不同深度学习框架之间交换模型。引入这个模块可以方便PyTorch导出模型给其他深度学习框架使用,或者让PyTorch可以载入其他深度学习框架构建的深度学习模型。

torch.utils

torch.utils.bottleneck模块

用来检查深度学习模型中模块的运行时间,从而可以找到导致性能瓶颈的那些模块,通过优化那些模块的运行时间,从而优化整个深度学习模型的性能

torch.utils.checkpoint模块

用来节约深度学习使用的内存。通过前面的介绍我们知道,因为要进行梯度反向传播,在构建计算图的时候需要保存中间的数据,而这些数据大大增加了深度学习的内存消耗。为了减少内存消耗,让迷你批次的大小得到提高,从而提升深度学习模型的性能和优化时的稳定性,我们可以通过这个模块记录中间数据的计算过程,然后丢弃这些中间数据,等需要用到的时候再重新计算这些数据。这个模块设计的核心思想是以计算时间换内存空间,当使用得当的时候,深度学习模型的性能可以有很大的提升

torch.utils.cpp_extension模块

定义了PyTorch的C++扩展,其主要包含两个类:CppExtension定义了使用C++来编写的扩展模块的源代码相关信息,CUDAExtension则定义了C++/CUDA编写的扩展模块的源代码相关信息。
在某些情况下,用户可能需要使用C++实现某些张量运算和神经网络结构(比如PyTorch没有类似功能的模块或者PyTorch类似功能的模块性能比较低),PyTorch的C++扩展模块就提供了一个方法能够让Python来调用使用C++/CUDA编写的深度学习扩展模块。在底层上,这个扩展模块使用了pybind11,保持了接口的轻量性并使得PyTorch易于被扩展。在后续章节会介绍如何使用C++/CUDA来编写PyTorch的扩展

torch.utils.data模块

引入了数据集(Dataset)和数据载入器(DataLoader)的概念,前者代表包含了所有数据的数据集,通过索引能够得到某一条特定的数据,后者通过对数据集的包装,可以对数据集进行随机排列(Shuffle)和采样(Sample),得到一系列打乱数据顺序的迷你批次

torch.utils.dlpacl模块

定义了PyTorch张量和DLPack张量存储格式之间的转换,用于不同框架之间张量数据的交换。

torch.utils.tensorboard模块

是PyTorch对TensorBoard数据可视化工具的支持。TensorBoard原来是TensorFlow自带的数据可视化工具,能够显示深度学习模型在训练过程中损失函数、张量权重的直方图,以及模型训练过程中输出的文本、图像和视频等。TensorBoard的功能非常强大,而且是基于可交互的动态网页设计的,使用者可以通过预先提供的一系列功能来输出特定的训练过程的细节(如某一神经网络层的权重的直方图,以及训练过程中某一段时间的损失函数等)。PyTorch支持TensorBoard可视化之后,在PyTorch的训练过程中,可以很方便地观察中间输出的张量,也可以方便地调试深度学习模型。

torch.nn.Module

所有的网络层都是继承于这个类。
nn.module中有八个重要的属性用于管理整个类,主要关注其中的parameters和modules两个属性。

  • parameters:管理存储属于nn.Parameter类的属性,例如权值或者偏置参数;
  • modules:用来存储管理nn.Module类,例如在LeNet中会构建子模块,modules就会存储创建的卷积层等;
  • buffers:存储管理缓冲的属性,如训练过程中BN的均值,或者是方差都会存储在buffers
  • *_hooks:存储管理钩子函数(5个,暂时不去了解)
    在这里插入图片描述

nn.Module的属性构建会在module类中进行属性赋值的时候会被setattr()函数拦截,在这个函数当中会判断即将要赋值的数据类型是否是nn.parameters类,如果是的话就会存储到parameters字典中;如果是module类就会存储到modul字典中

nn.Module构建参考代码:
LeNet的构建过程

总结

  • 一个module可以包含多个子module;例如LeNet包含很多个子module,例如卷积层,池化层等。
  • 一个module相当于一个运算,必须实现forward()函数;
  • 每个module都有8个字典管理它的属性;
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch 入坑七:模块与nn.Module学习 的相关文章

随机推荐

  • idea使用lombok插件不能生效的原因

    要成功的使用lombok插件 需要3个步骤 一 需要先在idea中下载Lombok plugin 点击File gt settings gt plugins gt 然后点击以下图中所示 接着 在输入框输入lombok进行搜索 之后点击安装便
  • 粤嵌GEC6818-学习笔记2-屏幕相关及音频播放

    这里写目录标题 LCD屏幕 简介 操作 打开屏幕 映射 如何让plcd指向屏幕首地址 BMP图片的解析 把一张BMP格式的图片显示在我们的开发板上 触摸板的相关操作 练习 获取屏幕坐标 线程进程 练习 创建广告播放的一个线程 音频播放 播放
  • STM32——GPIO输入——按键检测

    硬件介绍 当按键置空时 IO接地 按键按下之后 IO口接通3 3V高电压 电流比较大 为了避免损坏IO 这里需要加装一个限流电阻 可以看到IO口是默认低电平 按键按下后产生一个上升沿 和平常的电路设计不太一样 这是因为PA0还具有一种自动唤
  • centos7网卡配置参数详细

    CentOS 7 中的网卡配置参数通常位于 etc sysconfig network scripts ifcfg
  • Python爬虫从入门到精通:(1)爬虫基础简介_Python涛哥

    第一章 爬虫基础简介 爬虫概述 前戏 你是否在夜深人静的时候 想看一些会让你更睡不着的图片 你是否在考试或者面试前夕 想看一些具有针对性的题目和面试题 你是否想在杂乱的网络世界获取你想要的数据 爬虫的价值 实际应用 就业 什么是爬虫 通过编
  • TensorFlow学习(4) 学习率调度 & 正则化

    1 学习率调度 恒定高学习率训练可能会发散 低学习率会收敛到最优解但是会花费大量时间 1 1 常用的学习率调度及其概念 幂调度 指数调度 分段调度 性能调度 1 2 实现幂调度 在创建优化器时 设置超参数decay 使用示例 optimiz
  • Python 面向对象程序设计类的使用、继承等

    这个实验主要通过了解对象 类 封装 继承 方法 构造函数和析构函数等面向对象的程序设计的基本概念 掌握 Python 类的定义 类的方法 类的继承等 在做实验时要注意 init 应该是4个下划线 前后各两个 也要注意自己的属性条件 并且也可
  • 对 tcp out-of-window 的安全建议

    TCP 收到一个 out of window 报文后会立即回复一个 ack 这是 RFC793 中 SEGMENT ARRIVES 段的要求 但这是为什么 难道不是默默丢弃才对吗 对 oow 报文回复 ack 岂不是把正确的 ack 号回过
  • L2-041 插松枝

    include
  • 复习1: 深度学习优化算法 SGD -> SGDM -> NAG ->AdaGrad -> AdaDelta -> Adam -> Nadam 详细解释 + 如何选择优化算法

    深度学习优化算法经历了 SGD gt SGDM gt NAG gt AdaGrad gt AdaDelta gt Adam gt Nadam 这样的发展历程 优化器其实就是采用何种方式对损失函数进行迭代优化 也就是有一个卷积参数我们初始化了
  • 无向图染色

    无向图染色 给一个无向图染色 可以填红黑两种颜色 必须保证相邻两个节点不能同时为红色 输出有多少种不同的染色方案 输入描述 第 行输入M 图中节点数 N 边数 后续N行格式为 V1V2表示一个V1到V2的边 数据范围 1 lt M lt 1
  • 研发工具链介绍

    本节课程为 研发工具链介绍 我们将主要学习三个工具 项目管理工具 iCafe 代码管理工具 iCode 交付平台 iPipe 此外我们知道 管理实践具有以下三个特点 用 精益 指引产品规划 用 敏捷 加速迭代开发 用 数据 驱动持续改进 而
  • 那些在一个公司死磕5年以上的测试,最后都怎么样了?

    2023年的测试市场是崩溃的 即使是老员工 也要面对裁员 降薪 外包化 没前途 薪资不过20k 没有面试 找不到工作 确实都客观存在 但与此同时 也有不少卷赢同行拿高薪的案例 因为只要互联网存在 测试就是刚需 只是需要更卷一些了 这里我准备
  • MSRA实习申请经验分享

    MSRA实习申请经验分享 自我介绍 简历投递 面试 成败关键点 自我介绍 博主目前大四 因为大四下没啥事想申请到MSRA实习半年 不久前成功申请到了MSRA的实习 这里简单分享一下经验 首先自我介绍一下 本人本科是国内某top10的985高
  • springboot简单整合logback日志框架

    引入依赖 实际上我们只需要引入springboot的的web依赖就可以了 springboot是默认整合logback的依赖的 编写xml文件 xml文件默认叫做logback xml 放在resource目录下就可以
  • python画桃心表白

    python用turtle画简单图案比较方便 大一学python的turtle模块时 记得要画各种图案 如国旗 桃心等等图案 期末课程设计时有可能还会遇到画54张扑克牌 当初室友就被迫选了这道题 下面是程序 import turtle im
  • 基于FREERTOS系统的LWIP协议移植(STM32F1战舰版)

    文章目录 参考文献 前言 源码链接 FREERTOS系统介绍 FREERTOS系统之API函数 1 创建任务函数xTaskCreate 2 删除任务函数xTaskDelete 3 创建二值信号量函数xSemaphoreCreateBinar
  • 找不到BufferedImage这个Class的解决方法

    找不到BufferedImage这个Class的解决方法 环境 1 RedHat AS5 64位 2 WebSphere6 0 32位版本 正文 发现原来在RedHat AS4 32位系统上跑的程序不能在64位RedHat AS5中运行 系
  • 你还在 Docker 中跑 MySQL?恭喜你,好下岗了!

    上一篇 一个90后员工猝死的全过程 0 2T架构师学习资料干货分享 来源 toutiao com i6675622107390411276 容器的定义 容器是为了解决 在切换运行环境时 如何保证软件能够正常运行 这一问题 目前 容器和 Do
  • PyTorch 入坑七:模块与nn.Module学习

    PyTorch 入坑七 模型创建概述 PyTorch中的模块 torch模块 torch Tensor模块 torch sparse模块 torch cuda模块 torch nn模块 torch nn Parameter torch nn