【Pytorch】常用函数功能介绍和注意事项

2023-11-19

【持续更新中…】

数据预处理

Variable

from torch.autograd import Variable 

作用:自动微分变量,用于构建计算图

网络层定义

torch.nn.BatchNorm2d()

设尺寸为N*C*H*W,其中N代表batchsize,C表示通道数(例如RGB三通道),H,W分别表示feature map的宽高。

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
  • num_features:通道数,例如RGB为3
  • eps:一个加至分母的参数,为提高计算稳定性
  • momentum:运行中调整均值、方差的估计参数
  • affine:当设为true时,给定可以学习的系数矩阵\gammaγ和 \beta

torch.nn.Linear()

torch.nn.Linear(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None)
  • in_features:输入的二维张量的大小
  • out_features:输出的二维张量的大小

torch.nn.Sequential()

 self.conv4 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, 2, 2, 0),  
            torch.nn.BatchNorm2d(64), 
            torch.nn.ReLU()#ReLU激活函数) 
        self.mlp1 = torch.nn.Linear(2 * 2 * 64, 100)#torch.nn.Linear定义全连接层,conv4为2*2*64
        self.mlp2 = torch.nn.Linear(100, 10) 

使用损失函数和优化器的步骤

  1. 获取损失:loss=loss_fuction(out,batch_y)
  2. 清空上一步残余更新参数:opt.zero_grad()
  3. 误差反向传播:loss.backward()
  4. 将参数更新值施加到net的parmeter上:opt.step()

模型相关参数配置(使用argparse.ArgumentParser)

简介

argparse是一个Python模块:命令行选项、参数和子命令解析器。

使用方法

  1. 创建解析器
parser = argparse.ArgumentParser(description='Process some integers.')
  1. 添加参数
parser.add_argument('integers', metavar='N', type=int, nargs='+', help='an integer for the accumulator')

例如在神经网络训练过程中,我们需要定义训练的初始学习率,并设置默认值为0.00001,可通过以下代码实现:

parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
  1. 解析参数
args = parser.parse_args()

模型输出

torch.max(out,1)[1]
output = torch.max(input, dim)

  • input是softmax函数输出的一个tensor
  • dim是max函数索引的维度0/1,0是每列最大值,1是每行最大值

函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。
(PS:第一个tensor value是不需要的,我们仅提取第二个tensor并将数据转换为array格式
torch.max(out,1)[1].numpy()

网络参数查看

model_dict=torch.load('D:\JetBrains\pythonProject\CMRI\CNN\CNN_Log')
parameters=list(model_dict.named_parameters())

补充知识点

downsampling(向下采样) & upsampling

down-sampling通过舍弃一些元素,实现图像的缩放

在CNN中,汇合层(Pooling layer)通过max poolingaverage pooling等操作,使汇合后结果中一个元素对应于原输入数据的一个子区域,因此汇合操作实际上就是一种”降采样“操作

up-sampling 可实现图像的放大或分辨率的优化等

常用方法:

  • Bilinear(双线性插值法):只需要设置好固定的参数值即可,设置的参数就是中心值需要乘以的系数。
  • Deconvolution(反卷积):参考https://github.com/vdumoulin/conv_arithmetic
  • Unpooling(反池化):在反池化过程中,将一个元素根据kernel进行放大,根据之前的坐标将元素填写进去,其他位置补0
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Pytorch】常用函数功能介绍和注意事项 的相关文章

随机推荐

  • 为什么要选择云原生数据库

    为什么要选择云原生数据库 前言 1 传统数据库 1 1 传统数据库概念 1 2 传统数据库优缺点 1 2 1 优点 1 2 2 缺点 2 云原生数据库 2 1 云原生数据库概念 2 2 云化代表未来 2 3 云原生数据库的优势 2 3 1
  • 【MATLAB】字符串的处理及矩阵的初步学习

    欢迎访问我的个人网站 reality2ideal xyz 内容在CSDN和个人网站上同步更新 字符串处理 字符串矩阵 gt gt ch 123456 qwerty ch 2 6 char 数组 123456 qwerty 字符串矩阵的列数要
  • 转载:R语言绘图—图形标题、坐标轴设置

    R语言绘图是通过函数命令及相应参数设置实现的 如plot x y plot为绘图函数命令 x y则是绘图参数 指定了绘图的数据向量 但这种最基本的绘图设置很难满足个性化绘图的要求 我们需要根据需要对图形元素进行设置 图形元素是各类图形的基本
  • 生成带干扰线的验证码

    import java awt Color import java awt Font import java awt Graphics2D import java awt Transparency import java awt image
  • vue吸顶导航栏_vue2组件系列第四十二节:NavBar 导航栏

    NavBar就是程序顶部的内容 相当于网站里的面包屑的作用 准备工作 创建一个页面 NavBar vue 在router js里配置NavBar页面的路由 path navbar name navbar component gt impor
  • 左右手坐标系区别和联系

    本文是分析 所谓的右手坐标系转换为左手坐标系需要的 z轴取反 x轴取反 或者改变摄像机位置 渲染绕序改变 其中的进一步的原因 参考文章 https msdn microsoft com en us library bb204853 28VS
  • 真伪定时器

    首先观察一下下面两组代码区别在哪里 第一组代码 setInterval gt 1 5s 的同步逻辑 1000 第二组代码 function fn setTimeout gt 1 5s 的同步逻辑 fn 1000 fn 两组代码都有定时功能
  • Java实体类详解及使用方法

    在Java编程中 实体类 Entity Class 是一种经常使用的类类型 实体类用于表示真实世界中的对象 通常与数据库中的表格相对应 本文将详细介绍Java实体类的概念 特点以及使用方法 什么是实体类 实体类是指用于表示和存储真实世界中的
  • 【论文精读】A view-free image stitching network based on global homography-基于全局单应的无视图图像拼接网络

    论文链接地址 代码链接地址 关于本文的代码 我已经调试过了 在调试过程中遇到的错误 我也做了一些总结 有需要的可以参考这篇博文 A view free image stitching network based on global homo
  • Spring Boot集成控制反转

    Most of the time dependency injection is the first thing that comes to mind whenever the concept of inversion of control
  • idea 2021.1安装 与 常用配置

    前置说明 该文档是基于idea 2021 1版本编写的 一 下载安装 官方下载地址 https www jetbrains com idea download other html 二 常用的设置 显示工具栏 设置tab选项卡换行 设置代码
  • Unity 打开时一直busy怎么办

    查看网络连接 比如360流量球或者任务管理器内的网络 如果能看到unity在下载东西或网络占用高 则表明可能是unity在下载在线资源 查看 工程目录 Package manifest json 文件是否存在国外地址 可能是由于网络原因连不
  • RabbitMq——发布确认高级和消息回退

    发布确认高级 消息在传递过程中 我们需要确定消息状态信息 开启发布确认高级模式 消息传递结束后会返回传递结果信息 若发送失败的消息 该消息会被存入缓存中 定时任务发送失败消息 交换机收到消息后 缓存会删除该信息 如果只开启发布确认模式的话
  • java多线程的意义

    https www zhihu com question 332042250
  • 前缀和与差分(分析与模板)

    前缀和 处理数组公式 s i s i 1 num i 输出区间和公式 s r s l 1 模板 include
  • kMeans算法(K均值聚类算法)

    机器学习中有两类的大问题 一个是分类 一个是聚类 分类是根据一些给定的已知类别标号的样本 训练某种学习机器 使它能够对未知类别的样本进行分类 这属于supervised learning 监督学习 而聚类指事先并不知道任何样本的类别标号 希
  • 【100%通过率 】【华为OD机试真题 c++ 】最大数字【 2023 Q1 A卷

    华为OD机试 题目列表 2023Q1 点这里 2023华为OD机试 刷题指南 点这里 题目描述 给定一个由纯数字组成以字符串表示的数值 现要求字符串中的每个数字最多只能出现2次 超过的需要进行删除 删除某个重复的数字后 其它数字相对位置保持
  • Android 模拟器 Genymotion 安装配置与 ARM 支持

    简介 Genymotion是一款基于x86架构的Android模拟器 由于系统启动速度 应用运行速度远远快于Android SDK自带模拟器而受到广泛应用 优缺点 优点 1 模拟器启动速度快 比AVD快很多 2 应用运行速度快 3 跨平台
  • Python面向对象类继承中发生的私有属性访问错误问题

    按照Python100days项目中的该方法来访问私有属性 可正常访问到 class Test def init self foo self foo foo def bar self print self foo print bar def
  • 【Pytorch】常用函数功能介绍和注意事项

    持续更新中 数据预处理 Variable from torch autograd import Variable 作用 自动微分变量 用于构建计算图 网络层定义 torch nn BatchNorm2d 设尺寸为N C H W 其中N代表b