Pytorch中的optimizer

2023-05-16

与优化函数相关的部分在torch.optim模块中,其中包含了大部分现在已有的流行的优化方法。

如何使用Optimizer

要想使用optimizer,需要创建一个optimizer 对象,这个对象会保存当前状态,并根据梯度更新参数。

怎样构造Optimizer

要构造一个Optimizer,需要使用一个用来包含所有参数(Tensor形式)的iterable,把相关参数(如learning rate、weight decay等)装进去。

注意,如果想要使用.cuda()方法来将model移到GPU中,一定要确保这一步在构造Optimizer之前。因为调用.cuda()之后,model里面的参数已经不是之前的参数了。

示例代码如下:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

常用参数

last_epoch代表上一次的epoch的值,初始值为-1。

单独指定参数

也可以用一个dict的iterable指定参数。这里的每个dict都必须要params这个key,params包含它所属的参数列表。除此之外的key必须它的Optimizer(如SGD)里面有的参数。

You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn’t override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

这在针对特定部分进行操作时很有用。比如只希望给指定的几个层单独设置学习率:

optim.SGD([
			{'params': model.base.parameters()},
			{'params': model.classifier.parameters(), 'lr': 0.001}
			],
			 
			lr = 0.01, momentum = 0.9)

在上面这段代码中model.base将会使用默认学习率0.01,而model.classifier的参数蒋欢使用0.001的学习率。

怎样进行单次优化

所有optimizer都实现了step()方法,调用这个方法可以更新参数,这个方法有以下两种使用方法:

optimizer.step()

多数optimizer里都可以这么做,每次用backward()这类的方法计算出了梯度后,就可以调用一次这个方法来更新参数。

示例程序:

for input, target in dataset:
	optimizer.zero_grad()
	ouput = model(input)
	loss = loss_fn(output, target)
	loss.backward()
	optimizer.step()

optimizer.step(closure)

有些优化算法会多次重新计算函数(比如Conjugate Gradient、LBFGS),这样的话你就要使用一个闭包(closure)来支持多次计算model的操作。这个closure的运行过程是,清除梯度,计算loss,返回loss。
(这个我不太理解,因为这些优化算法不熟悉)

示例程序:

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

优化算法

这里就不完整介绍documentation中的内容了,只介绍基类。具体的算法的参数需要理解它们的原理才能明白,这个改天单独来一篇文章介绍。

Optimizer

 class torch.optim.Optimizer(params, defaults)

这是所有optimizer的基类。

注意,各参数的顺序必须保证每次运行都一致。有些数据结构就不满足这个条件,比如dictionary的iterator和set。

参数
  1. params(iterable)是torch.Tensor或者dict的iterable。这个参数指定了需要更新的Tensor。
  2. defaults(dict)是一个dict,它包含了默认的的优化选项。
方法
add_param_group(param_group)

这个方法的作用是增加一个参数组,在fine tuning一个预训练的网络时有用。

load_state_dict(state_dict)

这个方法的作用是加载optimizer的状态。

state_dict()

获取一个optimizer的状态(一个dict)。

zero_grad()方法用于清空梯度。

step(closure)用于进行单次更新。

Adam

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch中的optimizer 的相关文章

  • PyTorch:如何使用 DataLoaders 自定义数据集

    如何利用torch utils data Dataset and torch utils data DataLoader根据您自己的数据 不仅仅是torchvision datasets 有没有办法使用内置的DataLoaders他们使用的
  • 用我自己的值初始化pytorch卷积层

    我想知道是否有办法用我自己的值初始化 pytorch 卷积过滤器 例如 我有一个元组 0 8423 0 3778 3 1070 2 6518 我想用这些值初始化 2X2 过滤器 我该怎么做 我查找了一些答案 但他们大多使用火炬正态分布和其他
  • Pytorch 分析器显示两个不同网络的卷积平均执行时间不同

    我有两个网络 我正在对它们进行分析以查看哪些操作占用了大部分时间 我注意到CUDA time avg为了aten conv2d不同网络的操作有所不同 这也增加了一个数量级 在我的第一个网络中 它是22us 而对于第二个网络则是3ms 我的第
  • PyTorch - 参数不变

    为了了解 pytorch 的工作原理 我尝试对多元正态分布中的一些参数进行最大似然估计 然而 它似乎不适用于任何协方差相关的参数 所以我的问题是 为什么这段代码不起作用 import torch def make covariance ma
  • 如何使用 torch.stack?

    我该如何使用torch stack将两个张量与形状堆叠a shape 2 3 4 and b shape 2 3 没有就地操作 堆叠需要相同数量的维度 一种方法是松开并堆叠 例如 a size 2 3 4 b size 2 3 b torc
  • Pytorch:了解 nn.Module 类内部如何工作

    一般来说 一个nn Module可以由子类继承 如下所示 def init weights m if type m nn Linear torch nn init xavier uniform m weight class LinearRe
  • 使用 pytorch 获取可用 GPU 内存总量

    我正在使用 google colab 免费 Gpu 进行实验 并想知道有多少 GPU 内存可供使用 torch cuda memory allocated 返回当前占用的 GPU 内存 但我们如何使用 PyTorch 确定总可用内存 PyT
  • 在pytorch中使用tensorboard,但得到空白页面?

    我在pytorch 1 3 1中使用tensorboard 并且我在张量板的 pytorch 文档 https pytorch org docs stable tensorboard html 运行后tensorboard logdir r
  • 删除 Torch 张量中的行

    我有一个火炬张量如下 a tensor 0 2215 0 5859 0 4782 0 7411 0 3078 0 3854 0 3981 0 5200 0 1363 0 4060 0 2030 0 4940 0 1640 0 6025 0
  • 为什么我在这里遇到被零除的错误?

    所以我正在关注这个文档中的教程 https pytorch org tutorials beginner data loading tutorial html在自定义数据集上 我使用的是 MNIST 数据集 而不是教程中的奇特数据集 这是D
  • 在 PyTorch 中原生测量多类分类的 F1 分数

    我正在尝试在 PyTorch 中本地实现宏 F1 分数 F measure 而不是使用已经广泛使用的sklearn metrics f1 score https scikit learn org stable modules generat
  • 为什么 RNN 需要两个偏置向量?

    In Pytorch RNN 实现 http pytorch org docs master nn html highlight rnn torch nn RNN 有两个偏差 b ih and b hh 为什么是这样 它与使用一种偏差有什么
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • 从打包序列中获取每个序列的最后一项

    我试图通过 GRU 放置打包和填充的序列 并检索每个序列最后一项的输出 当然我的意思不是 1项目 但实际上是最后一个 未填充的项目 我们预先知道序列的长度 因此应该很容易为每个序列提取length 1 item 我尝试了以下方法 impor
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

    我有一层layer in an nn Module并在一次中使用两次或多次forward步 这个的输出layer稍后输入到相同的layer pytorch可以吗autograd正确计算该层权重的梯度 def forward x x self
  • 如何更新 PyTorch 中神经网络的参数?

    假设我想将神经网络的所有参数相乘PyTorch 继承自的类的实例torch nn Module http pytorch org docs master nn html torch nn Module by 0 9 我该怎么做呢 Let n
  • Pytorch ValueError:优化器得到一个空参数列表

    当尝试创建神经网络并使用 Pytorch 对其进行优化时 我得到了 ValueError 优化器得到一个空参数列表 这是代码 import torch nn as nn import torch nn functional as F fro
  • 预期设备类型为 cuda 的对象,但在 Pytorch 中获得了设备类型 cpu

    我有以下计算损失函数的代码 class MSE loss nn Module metric L1 L2 norms or cosine similarity mode training or evaluation mode def init

随机推荐

  • 1-11 实验9 网络管理实验1 获取自身的和父节点网络地址、MAC地址

    p p p style color rgb 51 51 51 font family Arial font size 14px line height 26px 获取自身的和父节点网络地址 MAC地址 p p style color rgb
  • 1-14 实验11 获取网络拓扑

    获取网络拓扑 1 实验内容 xff1a PC端串口调试助手向协调器发送命名 topology 协调器接受到命令后 xff0c 将网络拓扑信息发送到PC机串口调试助手上 2 知识点 xff1a 在1 11 实验9 网络管理实验1 获取自身和父
  • S 串口编程 详解3 串口的初始化、打开/关闭

    串口编程 详解3 串口的初始化 程序打开串口 xff0c 采用两种方法 xff1a 1 程序启动 xff0c 调用OnInitDialog 函数 xff0c 打开串口 xff0c 缺省串口号为COM1 xff0c 如果COM1不存在或被占用
  • 求关键路径(包含邻接表的建立、拓扑排序)

    include lt stdio h gt include lt stdlib h gt typedef struct node int adjvex 邻接点域 int info 边上的信息 struct node next 指向下一个邻接
  • FPGA串口回环实验

    本文将从个人理解的角度 xff0c 解释FPGA串口通信的原理 xff0c 并进行实战演示 1 写在前面的话 串口通信是初学FPGA必过的一道坎 xff0c 如果能够在不参考任何资料的情况下自己手搓一套串口回环的代码 xff0c Verio
  • Debug Assertion Failed!解决方法详解

    1 野指针 2 内存泄露 解决方法 1 看一看你的程序里是不是有 ASSERT xff08 xff09 或 VERIFY xff08 xff09 语句 这两个宏是用来测试它的参数是否为真的 出现你说的 xff0c 这说明你的指针或表达试有问
  • 用tftp的方式在u_boot下 烧写uImage内核

    用 u boot 进行下载 uImage 一种 kernel 镜像文件 首先 把编译好的 uImage 文件放在 tftpboot 目录下 用网线把开发板和电脑连上 但PC上的网卡显示是没连接的 xff0c 这一点是没有关系的 xff0c
  • 利用NFS服务挂载NFS根文件系统

    嵌入式Linux根文件系统 xff0c 简单地说 xff0c 根文件系统就是一种目录结构 注意根文件系统和普通的文件系统的区别 常见的Linux根文件系统有 xff1a xff08 1 xff09 NFS xff08 网络根文件系统 xff
  • 数据校验之Checksum算法

    校验和 xff08 Checksum xff09 是网络协议使用的数据错误检测方法 xff0c 并且被认为比LRC xff08 纵向冗余校验 xff0c Longitudinal Redundancy Check xff0c LRC xff
  • 位序转字符串的一种高效方法

    include lt stdio h gt include lt stdlib h gt include lt malloc h gt include lt string h gt include lt arpa inet h gt def
  • OpenSIPS实战(一):OpenSIPS使用简介

    1 OpenSIPS是什么 OpenSIPS xff08 Open SIP Server xff09 是一个成熟的开源SIP服务器实现 可以作为SIP代理 路由器 但OpenSIPS不仅仅是一个SIP代理 路由器 xff0c 因为它包含了应
  • Floyd判圈算法(龟兔赛跑算法, Floyd's cycle detection)及其证明

    问题 xff1a 如何检测一个链表是否有环 xff08 循环节 xff09 xff0c 如果有 xff0c 那么如何确定环的起点以及环的长度 空间要求 xff1a 不能存储所经过的的每一个点 举例 xff1a x 0 61 1 x 0 61
  • Ubuntu配置GPU版本pytorch环境(含NVIDIA驱动+Cuda+Cudnn)

    本文更新于2018年8月底 概述 步骤如下 xff1a 1 安装Ubuntu 2 安装NVIDIA 显卡驱动 2 安装NVIDIA Cuda 3 安装NVIDIA CuDNN 4 安装GPU版本的PyTorch 安装Ubuntu 系统版本选
  • PyTorch中的Dataset、Dataloader和_DataloaderIter

    Dataset Pytorch中数据集被抽象为一个抽象类torch utils data Dataset xff0c 所有的数据集都应该继承这个类 xff0c 并override以下两项 xff1a len xff1a 代表样本数量 len
  • 彻底搞懂Lab 颜色空间

    本文参考wikipedia xff0c 并加入了自己的理解 xff0c 有不对的地方多多指教 名称 在开始之前 xff0c 先明确一下Lab颜色空间 xff08 Lab color space xff09 的名字 xff1a Lab的全称是
  • MiniFly微型四轴学习与开发日志(五)——遥控器任务详解

    文章目录 radiolinkTask无线连接任务usblinkTxTask usb发送任务usblinkRxTask usb接收任务commanderTask飞控指令发送任务keyTask按键扫描任务displayTask显示任务confi
  • .与::的使用区别

    今天尝试编写了一个小的Windows应用程序 xff0c 在编写的过程中用到MessageBox函数 但是一直不正确 我当时尝试MessageBox 34 NULL 34 34 Alert 34 34 ERROR 34 MB OK xff0
  • Pytorch中的contiguous理解

    最近遇到这个函数 xff0c 但查的中文博客里的解释貌似不是很到位 xff0c 这里翻译一下stackoverflow上的回答并加上自己的理解 在pytorch中 xff0c 只有很少几个操作是不改变tensor的内容本身 xff0c 而只
  • 一文读懂GAN, pix2pix, CycleGAN和pix2pixHD

    本文翻译 总结自朱俊彦的线上报告 xff0c 主要讲了如何用机器学习生成图片 来源 xff1a Games2018 Webinar 64期 xff1a Siggraph 2018优秀博士论文报告 人员信息 主讲嘉宾 姓名 xff1a 朱俊彦
  • Pytorch中的optimizer

    与优化函数相关的部分在torch optim模块中 xff0c 其中包含了大部分现在已有的流行的优化方法 如何使用Optimizer 要想使用optimizer xff0c 需要创建一个optimizer 对象 xff0c 这个对象会保存当