Pytorch训练时重要步骤

2023-11-01

目录

optimizer.zero_grad()

代码解读

调用频率和结果

loss.backward()

optimizer.step()


在学习pytorch时,官方文档 有一段示例代码

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

有三个关键步骤:

1. 梯度归零

2. 反向传播,得到每个参数的梯度

3. 通过梯度下降执行一步参数更新

下面详细了解一下这几个步骤

optimizer.zero_grad()

代码解读

d​​代码在这,

def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        foreach = self.defaults.get('foreach', False)

        if not hasattr(self, "_zero_grad_profile_name"):
            self._hook_for_profile()
        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            if (not foreach or p.grad.is_sparse):
                                p.grad.zero_()
                            else:
                                per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
            if foreach:
                for _, per_dtype_grads in per_device_and_dtype_grads.items():
                    for grads in per_dtype_grads.values():
                        torch._foreach_zero_(grads)

遍历模型所有参数,

如果选择将梯度清空,则通过p.grad.detach_()方法截断反向传播的梯度流,再通过p.grad.zero_()函数将每个参数的梯度值设为0.即不会叠加上次梯度。

如果选择将梯度置None,那么以后也不能改变梯度了,会影响后续的计算。

因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。

调用频率和结果

一般每个batch需要调用一次gradient.zero_grad()函数,把参数梯度清零;

也可以多个batch调用一次,相当于增大了batch_size

loss.backward()

pyroch的反向传播时通过autograd实现的,官方教程

在创建tensor时,可以设置requires_grad =True,这样后面就可以在执行backward()后拿到参数梯度。

optimizer.step()

step()函数的作用是执行一次优化步骤.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch训练时重要步骤 的相关文章

  • 使用请求验证 SSL 证书

    我正在尝试验证 SSL 但它不起作用 我在浏览器上访问了我想要访问的机密网站 在 Chrome 上 我单击了储物柜 gt 证书 gt 详细信息 gt 复制到文件 gt base64 gt cert cer 我的代码是 test reques
  • 没有任何元数据的 zip 文件

    我想找到一种简单的方法来压缩一堆文件 而无需任何文件元数据 例如时间戳 这zip命令似乎总是保留元数据 我没有找到禁用元数据的方法 我希望解决方案是一个命令或最多一个 python 脚本 谢谢 正如一些帖子已经指出的那样 zip 标头中的大
  • 在Python3.6中调用C#代码

    由于完全不了解 C 编码 我希望在我的 python 代码中调用 C 函数 我知道有很多关于同一问题的问答 但由于一些奇怪的原因 我无法从示例 python 模块导入简单的 c 类库 以下是我所做的事情 C 类库设置 我使用的是 VS 20
  • 如何确定非阻塞套接字是否真正连接?

    这个问题不仅限于Python 这是一个一般的套接字问题 我有一个非阻塞套接字 想要连接到一台可访问的机器 在另一端 该端口不存在 为什么 select 仍然成功 我预计会超时 sock send 因管道损坏而失败 select 之后如何确定
  • 如何使用Python将WebP图像转换为Gif?

    我已经尝试过这个 from PIL import Image im Image open this webp im save that gif gif save all True 这给了我这个错误 类型错误 不支持的操作数类型 tuple
  • 带图像的简单 GUI [关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我试图在简单的 GUI 上显示一些卡
  • Python Kivy - 在本机网络浏览器中打开 url 的应用程序

    我尝试制作一个简单的应用程序 在单击 Screen One 上的按钮后 在 Kivy 中打开一个网页 我使用了这个主题 Python 在应用程序中直接显示网络浏览器 iframe https stackoverflow com questi
  • 如何在python中检索aws批处理参数值?

    流程 Dynamo DB gt Lambda gt 批处理 如果将角色 arn 插入动态数据库 它是从 lambda 事件中检索的 然后使用submit job角色 arn 的 API 被传递为 parameters role arn ar
  • 具有多个元素的数组的真值是二义性错误吗? Python

    from numpy import from pylab import from math import def TentMap a x if x gt 0 and x lt 0 5 return 2 a x elif x gt 0 5 a
  • Pandas Dataframe:将包含列表的行扩展到多行,并为所有列提供所需的索引

    我在 pandas 数据框中有时间序列数据 索引为测量开始时的时间 列中包含以固定采样率记录的值列表 连续索引 列表中元素数量的差异 这是它的样子 Time A B Z 0 1 2 3 4 1 2 3 4 2 5 6 7 8 5 6 7 8
  • 将一个列表的元素除以另一个列表的元素

    我有两个清单 比如说 a 10 20 30 40 50 60 b 30 70 110 正如你所看到的 列表 b 由一个列表的元素总和组成 其中 window 2 b 0 a 0 a 1 10 20 30 etc 如何获得另一个列表 该列表由
  • 如何有效地从 loadmat 函数生成的嵌套 numpy 数组中提取值?

    python中是否有更有效的方法从嵌套的python列表中提取数据 例如A array array 12000000 dtype object 我一直在使用A 0 0 0 0 当你有很多像 A 这样的数据时 这似乎不是一个有效的方法 我也用
  • 在 MacO 和 Linux 上安装 win32com [重复]

    这个问题在这里已经有答案了 我的问题很简单 我可以安装吗win32com蟒蛇API pywin32特别是 在非 Windows 操作系统上 我一直在Mac上尝试多个版本pip install pywin32 都失败了 下面是一个例子 如果你
  • Airflow Python 单元测试?

    我想为我们的 DAG 添加一些单元测试 但找不到任何单元测试 有 DAG 单元测试框架吗 有一个端到端的测试框架存在 但我猜它已经死了 https issues apache org jira browse AIRFLOW 79 https
  • 导入错误:没有名为 google.auth 的模块

    当我尝试导入时firebase admin in python 2 7我收到错误 导入错误 没有名为 google auth 的模块 这是Docker文件 https github com ammaratef45 Attendance bl
  • Scipy 稀疏 Cumsum

    假设我有一个scipy sparse csr matrix代表下面的值 0 0 1 2 0 3 0 4 1 0 0 2 0 3 4 0 我想就地计算非零值的累积和 这会将数组更改为 0 0 1 3 0 6 0 10 1 0 0 3 0 6
  • Python - 如何查询定义方法的类?

    我的问题有点类似于this one https stackoverflow com questions 5520580 how do you get all classes defined in a module but not impor
  • Python组合目录中的所有csv文件并按日期时间排序

    我有 2 年的每日数据分成每月文件 我想将所有这些数据合并到一个按日期和时间排序的文件中 我正在使用的代码组合了所有文件 但不按顺序 我正在使用的代码 import pandas as pd import glob os import cs
  • 全局变量是 None 而不是实例 - Python

    我正在处理Python 中的全局变量 代码应该可以正常工作 但是有一个问题 我必须使用全局变量作为类的实例Back 当我运行应用程序时 它说 back is None 这应该不是真的 因为第二行setup 功能 back Back Back
  • 检查字符串是否只有字母和空格 - Python

    试图让 python 返回一个字符串仅包含字母和空格 string input Enter a string if all x isalpha and x isspace for x in string print Only alphabe

随机推荐

  • 《机器学习》Chapter 5 神经网络笔记

    机器学习 Chapter 5 神经网络 1 神经元模型 神经元接收到来自n个其它神经元传递过来的输入信号 这些输入信号通过带权重的连接进行传递 神经元接收到的总输入值将与神经元的阈值进行比较 然后通过激活函数处理以产生神经元的输出 2 感知
  • 使用css去除input边框

  • 不同CUDA版本对应的最小GPU运算能力和最低兼容驱动

    The minimum compute capability for various CUDA versions CUDA Version Minimum Compute Capability Default Compute Capabil
  • 使用 Python 进行游戏脚本编程 [翻译]

    翻译自 GDC 2002 上 Bruce Dawson 的 Game Scripting in Python 简单介绍 Python 作为游戏脚本语言的优势 原文 Game Scripting in Python 作者 Bruce Daws
  • 机器学习的环境搭建流程

    一 需要 python解释器 pycharm anaconda 机器学习需要的第三方包 二 流程 1 先确定进行机器学习需要的主要包之间的依赖关系及对应的python版本 建议python版本不要太高 3 6或者3 7比较好 因为许多第三方
  • 后端进阶之路——综述Spring Security认证,授权(一)

    前言 作者主页 雪碧有白泡泡 个人网站 雪碧的个人网站 推荐专栏 java一站式服务 前端炫酷代码分享 uniapp 从构建到提升 从0到英雄 vue成神之路 解决算法 一个专栏就够了 架构咱们从0说 数据流通的精妙之道 后端进阶之路 文章
  • 【华为OD机试真题 Python】最左侧冗余覆盖子串

    前言 本专栏将持续更新华为OD机试题目 并进行详细的分析与解答 包含完整的代码实现 希望可以帮助到正在努力的你 关于OD机试流程 面经 面试指导等 如有任何疑问 欢迎联系我 wechat steven moda email nansun09
  • 简单的编写一个通讯录并可进行增删改查功能

    改通讯录分为三个模块 test c contact c contact h 下面依次给我相应的代码 有想问的或者觉得有帮助的帮忙点个赞和关注一下哈 蟹蟹 主要用到了结构体指针来进行对结构体的修改查找之类的算法 test c define C
  • iOS 底层解析weak的实现原理

    参考地址 http www cocoachina com ios 20170328 18962 html weak 实现原理的概括 Runtime维护了一个weak表 用于存储指向某个对象的所有weak指针 weak表其实是一个hash 哈
  • 深度学习之---yolov1,v2,v3详解

    写在前面 如果你想 run 起来 立马想看看效果 那就直接跳转到最后一张 动手实践 看了结果再来往前看吧 开始吧 一 YOLOv1 简介 这里不再赘述 之前的我的一个 GitChat 详尽的讲述了整个代码段的含义 以及如何一步步的去实现它
  • 【数字转换为汉字】

    项目场景 通常这种情况下 后端返回是数组 如果想要把数组这样显示出来 就需要把数组的索引值转换为汉字显示 如 11显示为十一 21显示为二十一 实现代码讲解 NoToChinese num 如果传递过来的值不是数据类型 则直接报错 if d
  • 蓝桥杯C/C++省赛:振兴中华

    目录 题目描述 思路分析 AC代码 题目描述 小明参加了学校的趣味运动会 其中的一个项目是 跳格子 地上画着一些格子 每个格子里写一个字 如下所示 从我做起振 我做起振兴 做起振兴中 起振兴中华 比赛时 先站在左上角的写着 从 字的格子里
  • Vue_shop(Element-UI)优化打包上线

    自己从头到尾打的源码链接 https gitee com bai xianger vue shop 一 项目的打包优化 1 网页顶部添加进度条效果 安装运行依赖nprogresst 链接 https github com rstacruz
  • 1053. 交换一次的先前排列

    转载请声明地址四元君 1053 交换一次的先前排列 题目难度 Medium 给你一个正整数的数组 A 其中的元素不一定完全不同 请你返回可在 一次交换 交换两数字 A i 和 A j 的位置 后得到的 按字典序排列小于 A 的最大可能排列
  • 无向图的广度优先遍历和深度优先遍历

    public class MGraph private char vexs 顶点 private int edge 存储边的二维数组 private int arcNum 边的数目 private boolean visited 访问标志数
  • Java学习笔记 类的特性:拓展知识与案例

    1 this的使用 使用场合 从一个构造方法调用到另一个构造方法 作用 缩短程序代码 减少开发程序时间 规则 通过关键字this来调用 this必须写在构造方法的第一行位置 在圆柱体类Cylinder里 用一个构造方法调用另一个构造方法 c
  • 浮点的加减计算方法

    浮点的计算方法 1 计算步骤 2 基本要素 2 1 浮点数 2 2 规格化浮点数 2 3 偏置指数 2 4 IEEE浮点数 2 5 特点 3 计算实例 4 舍入机制 扩展 乘除计算步骤 1 计算步骤 浮点数格式 单精度 符号位1位 阶码8位
  • MySQL权限管理

    如果声明了 WITH GRANT OPTION 那么权限的接收者也可以将此权限赋予他人 否则 就不能授权他人 MySQL权限管理 mysql gt SHOW GRANTS G 1 row Grants for root localhost
  • 使用SQL Server 获取插入记录后的ID(自动编号)

    要获取此ID 最简单的方法就是在查询之后select indentity SQL语句创建数据库和表 复制代码代码如下 create database dbdemo go use dbdemo go create table tbldemo
  • Pytorch训练时重要步骤

    目录 optimizer zero grad 代码解读 调用频率和结果 loss backward optimizer step 在学习pytorch时 官方文档 有一段示例代码 def train loop dataloader mode