nn.Module模块

2023-11-15

1. 模块化接口nn


torch.nn是pytorch中专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。

2. nn.Module


nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法。

定义自已的网络时,需要继承nn.Module类,并实现forward方法

一般把网络中具有可学习参数的层放在构造函数__init__()中,不具有可学习参数的层(如ReLU)可放在构造函数中,也可不放在构造函数中(而在forward中使用nn.functional来代替)
    
只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)。

在forward函数中可以使用任何Variable支持的函数,毕竟在整个pytorch构建的图中,是Variable在流动。还可以使用if,for,print,log等python语法.
    


注:Pytorch基于nn.Module构建的模型中,只支持mini-batch的Variable输入方式,比如,只有一张输入图片,也需要变成 N x C x H x W 的形式:
 
    input_image = torch.FloatTensor(1, 28, 28)
    input_image = Variable(input_image)
    input_image = input_image.unsqueeze(0)   # 1 x 1 x 28 x 28
 
 

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
class LeNet(nn.Module):
    def __init__(self):
        # nn.Module的子类函数必须在构造函数中执行父类的构造函数
        super(LeNet, self).__init__()   # 等价与nn.Module.__init__()
 
        # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
        # 当调用self.conv1(input)的时候,就会调用该类的forward函数
        self.conv1 = nn.Conv2d(1, 6, (5, 5))   # output (N, C_{out}, H_{out}, W_{out})`
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  # F.max_pool2d的返回值是一个Variable
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
 
        # 返回值也是一个Variable对象
        return x
 
 
def output_name_and_params(net):
    for name, parameters in net.named_parameters():
        print('name: {}, param: {}'.format(name, parameters))
 
 
if __name__ == '__main__':
    net = LeNet()
    print('net: {}'.format(net))
    params = net.parameters()   # generator object
    print('params: {}'.format(params))
    output_name_and_params(net)
 
    input_image = torch.FloatTensor(10, 1, 28, 28)
 
    # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
    # 这可以从forward中每一步的执行结果可以看出
    input_image = Variable(input_image)
 
    output = net(input_image)
    print('output: {}'.format(output))
    print('output.size: {}'.format(output.size()))

个人学习记录,方便查看,侵删。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

nn.Module模块 的相关文章

  • 如何在 PyTorch 数据加载器中将 RGB 图像转换为灰度图像?

    我已经从 MNIST 数据集中下载了一些示例图像 jpg格式 现在我正在加载这些图像来测试我的预训练模型 transforms to apply to the data trans transforms Compose transforms
  • 当我有另一个具有该版本的 conda 环境时,为什么 pip 不允许我在新的 conda 环境中安装 torch==1.9.1+cu111 ?

    当我在新的 conda 环境中运行 pip install 时 base brando9 pip install torch 1 9 1 cu111 torchvision 0 10 1 cu111 torchaudio 0 9 1 f h
  • Pytorch 说 CUDA 不可用(在 Ubuntu 上)

    我正在尝试在我拥有的笔记本电脑上运行 Pytorch 这是一个较旧的型号 但它确实有 Nvidia 显卡 我意识到这可能不足以实现真正的机器学习 但我正在尝试这样做 以便我可以了解安装 CUDA 的过程 我已按照上面的步骤操作安装指南 ht
  • 检查 PyTorch 张量在 epsilon 内是否相等

    如何检查两个 PyTorch 张量在语义上是否相等 考虑到浮点错误 我想知道元素是否仅相差一个小的 epsilon 值 在撰写本文时 这是最新稳定版本 0 4 1 中的一个未记录的函数 但文档位于master unstable branch
  • PyTorch 中的截断反向传播(代码检查)

    我正在尝试在 PyTorch 中实现随时间截断的反向传播 对于以下简单情况K1 K2 我下面有一个实现可以产生合理的输出 但我只是想确保它是正确的 当我在网上查找 TBTT 的 PyTorch 示例时 它们在分离隐藏状态 将梯度归零以及这些
  • Win10 64位上CUDA 12的PyTorch安装

    我需要在我的 PC 上安装 PyTorch 其 CUDA 版本 12 0 pytorch 2 的表 https i stack imgur com X13oS png in In 火炬网站 https pytorch org get sta
  • Pytorch:了解 nn.Module 类内部如何工作

    一般来说 一个nn Module可以由子类继承 如下所示 def init weights m if type m nn Linear torch nn init xavier uniform m weight class LinearRe
  • 如何避免 PyTorch 中的“CUDA 内存不足”

    我认为对于 GPU 内存较低的 PyTorch 用户来说 这是一个非常常见的消息 RuntimeError CUDA out of memory Tried to allocate X MiB GPU X X GiB total capac
  • torchvision.transforms.Normalize 是如何操作的?

    我不明白如何标准化Pytorch works 我想将平均值设置为0和标准差1跨越张量中的所有列x形状的 2 2 3 一个简单的例子 gt gt gt x torch tensor 1 2 3 4 5 6 7 8 9 10 11 12 gt
  • 在 PyTorch 中原生测量多类分类的 F1 分数

    我正在尝试在 PyTorch 中本地实现宏 F1 分数 F measure 而不是使用已经广泛使用的sklearn metrics f1 score https scikit learn org stable modules generat
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

    我有一层layer in an nn Module并在一次中使用两次或多次forward步 这个的输出layer稍后输入到相同的layer pytorch可以吗autograd正确计算该层权重的梯度 def forward x x self
  • Pytorch Tensor 如何获取元素索引? [复制]

    这个问题在这里已经有答案了 我有 2 个名为x and list它们的定义如下 x torch tensor 3 list torch tensor 1 2 3 4 5 现在我想获取元素的索引x from list 预期输出是一个整数 2
  • Pytorch 损失为 nan

    我正在尝试用 pytorch 编写我的第一个神经网络 不幸的是 当我想要得到损失时遇到了问题 出现以下错误信息 RuntimeError Function LogSoftmaxBackward0 returned nan values in
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • Pytorch 与 joblib 的 autograd 问题

    将 pytorch 的 autograd 与 joblib 混合似乎存在问题 我需要并行获取大量样本的梯度 Joblib 与 pytorch 的其他方面配合良好 但是 与 autograd 混合时会出现错误 我做了一个非常小的例子 显示串行
  • 样本()和r样本()有什么区别?

    当我从 PyTorch 中的发行版中采样时 两者sample and rsample似乎给出了类似的结果 import torch seaborn as sns x torch distributions Normal torch tens

随机推荐

  • QT实现弹窗

    第一行申请的栈空间 函数运行结束后内存释放 弹窗会闪退 换用第二行申请堆空间可解决 但是窗口弹出后可以对其他窗口进行操作 不符合要求 将第四行换用dialog gt exec 即可解决 QDialog exec 模态 应用程序级 窗口显示
  • C语言中字符数组的初始化问题

    1 参考博客 https blog csdn net cherrydreamsover article details 81741459 1 char a Hello 按字符串初始化 大小为6 2 char b H e l l 按字符初始化
  • 网络编程_bind函数返回值

    define WINSOCK DEPRECATED NO WARNINGS include
  • Shell 脚本中 '$' 符号的多种用法

    来源 JackTian 杰哥的IT之旅 https mp weixin qq com s XBu7G UxPs2dv6fsPXGq4w 通常情况下 在工作中用的最多的有如下几项 0 Shell 的命令本身 1 到 9 表示 Shell 的第
  • http请求头中的Accept的用处和常用的值

    1 Accept属于http请求头 描述客户端希望接收的响应body 数据类型 就是希望服务器返回什么类型的数据 2 常见的媒体格式类型如下 text html HTML格式 text plain 纯文本格式 text xml XML格式
  • Linux Ubuntu下各种TensorFlow版本所对应(匹配)的Python、GCC编译器、Build tools、cuDNN、CUDA版本

    参考TensorFlow官网 https www tensorflow org install source common installation problems
  • element ui 多张图片上传、回显、删除

    element ui 多张图片上传 回显 删除 前端文件上传 1 展示部分
  • 计算机为什么负数不用减一,计算机的加减乘除(原码反码补码)

    计算机对数的操作 以二进制为基 因为电子原件只能表达0 1 开或关这两种状态 如果学过模电和数电 对此的理解会更深 比如说十进制9 在计算机里不可能单独记个9 而是记录成0000 1001 第一位符号位 0表示正数 但是 9 在计算机里记得
  • TensorRT部署神经网络

    TensorRT部署神经网络 大佬的讲解记录一下 基础知识 TensorRT使用例子 TensorRT加速模型 示例代码 这个脚本向你展示了如何使用 torch2trt 加速 pytorch 推理 截止目前为止 torch2trt 的适配能
  • swagger2 注解说明

    Api 用在请求的类上 表示对类的说明 tags 说明该类的作用 可以在UI界面上看到的注解 value 该参数没什么意义 在UI界面上也看到 所以不需要配置 ApiOperation 用在请求的方法上 说明方法的用途 作用 value 说
  • 如何用硬币模拟1/3的概率,以及任意概率?

    突然想起一个挺有意思的事 如何用硬币模拟1 3的概率 甚至任意概率 之前和朋友偶然间谈到如何用硬币模拟任何概率 当时以为是不可能的 因为硬币有两面 模拟的结果底数一定是2 n 今天又回顾了某个经典的条件概率问题 突然想到用硬币模拟任意概率是
  • IT职业发展路线

    网上找的
  • 第九课移动与相机

    讲的是shift 物体的移动轴 则摄像机与物体一起运动 设置了个聚光灯 本来要把聚光灯和摄像机锁定 但是不知为何 视频教程上的lock选项 在UE4编辑器没有 应该是版本不同的缘故
  • JS中document.createElement()用法及注意事项

    今天处理了一个日期选择器的ie和ff的兼容问题 本来这种情况就很难找错误 找了好久才把错误定位到js中创建元素的方法document createElement 这个方法在ie下支持这样创建元素 var inputObj document
  • Windows下开启Astra 摄像头的三种方式

    Windows下开启Astra摄像头有三种方式 第一种 使用官方提供的Orbbec Viewer软件 在此可以修改设备分辨率并且支持多台设备同时使用 非常方便 具体效果如下 该程序直接去奥比中光官网下载即可 官网也有具体的使用的手册 答主在
  • gcc compiler error messages

    Summarizing the gcc errors I encountered to be continued 1 dereferencing pointer to incomplete type You have written som
  • IP包流量分析程序

    使用套接字编程实现捕获一段时间内以本机为源地址或目的地址的IP数据包 不包括以广播形式发出的数据包 统计IP数据包的信息 列出本机与其他主机之间不同协议类型IP数据包的数量 及流量 以源地址 目的地址 协议类型 数据包数量 流量的格式输出统
  • failed to load response data出现的问题

    分片上传的时候 状态码请求是200的状态 但是 出现了 failed to load response data 没有response的返回 原因是 我分片的 每片大小太大了 分成了10M 所以出现了这个问题 const chunkSize
  • 【2-3】《Java基础语法》——二进制、变量、数据类型、标识符、数据类型转换、特殊变量定义、方法、运算符、变量作用域、编程规范、转义字符

    文章目录 基础语法 一 二进制 1 补码 2 二进制与十进制的转换 二 变量概述 三 数据类型 1 分类 2 范围 四 标识符 1 命名规则 2 Java中的关键字 3 定义变量 4 变量练习 五 数据类型转换 六 特殊变量定义 1 flo
  • nn.Module模块

    1 模块化接口nn torch nn是pytorch中专门为神经网络设计的模块化接口 nn构建于autograd之上 可以用来定义和运行神经网络 2 nn Module nn Module是nn中十分重要的类 包含网络各层的定义及forwa