torch.autograd.grad求二阶导数

2023-11-05

1 用法介绍

pytorchtorch.autograd.grad函数主要用于计算并返回输出相对于输入的梯度总和,具体的参数作用如下所示:

torch.tril(input, diagonal=0, *, out=None) ⟶ \longrightarrow Tensor

  • outputs(sequence of Tensor):表示微分函数的输出
  • inputs (sequence of Tensor):表示微分函数的输入
  • grad_outputs (sequence of Tensor):表示“向量-雅克比矩阵”的向量
  • retain_graph (bool, optional):表示是否需要将计算图释放掉,当计算二阶导数时需要设置为True
  • create_graph (bool, optional):表示是否需要将梯度将会加入到计算图中,当计算高阶导数或者其他计算时会将其设置为需要设置为True
  • allow_unused (bool, optional):表示是否只返回输入的梯度,而不返回其他叶子节点的梯度

2 实例讲解

以下给出了具体的二阶导数解析解的数学实例

给定一个向量 x = ( x 1 , x 2 ) ⊤ {\bf{x}}=(x_1,x_2)^{\top} x=(x1,x2),可以得到向量 y = ( y 1 , y 2 ) ⊤ = ( x 1 2 , x 2 2 ) ⊤ {\bf{y}}=(y_1,y_2)^{\top}=(x^2_1,x^2_2)^{\top} y=(y1,y2)=(x12,x22)。对向量 y {\bf{y}} y的元素求平均可以得到损失函数 l o s s 1 \mathrm{loss}_1 loss1为: l o s s 1 ( x ) = m e a n ( y ) = x 1 2 + x 2 2 2 \mathrm{loss}_1({\bf{x}})=\mathrm{mean}({\bf{y}})=\frac{x_1^2+x^2_2}{2} loss1(x)=mean(y)=2x12+x22向量 y {\bf{y}} y元素的分量分别对 x {\bf{x}} x求偏导,然后相加求平均得到损失函数 l o s s 2 \mathrm{loss}_2 loss2 { h 1 ( x ) = ∂ y 1 ∂ x = ( 2 x 1 , 0 ) ⊤ h 2 ( x ) = ∂ y 2 ∂ x = ( 0 , 2 x 2 ) ⊤ , l o s s 2 ( x ) = m e a n ( h 1 ( x 1 ) − h 2 ( x 2 ) ) = x 1 − x 2 \left\{\begin{aligned}h_1({\bf{x}})&=\frac{\partial y_1}{\partial {\bf{x}}}=(2x_1,0)^{\top}\\h_2({\bf{x}})&=\frac{\partial y_2}{\partial {\bf{x}}}=(0,2x_2)^{\top}\end{aligned}\right.,\quad \mathrm{loss}_2({\bf{x}})=\mathrm{mean}(h_1({\bf{x}}_1)-h_2({\bf{x}}_2))=x_1-x_2 h1(x)h2(x)=xy1=(2x1,0)=xy2=(0,2x2),loss2(x)=mean(h1(x1)h2(x2))=x1x2将损失函数 l o s s 1 \mathrm{loss}_1 loss1与损失函数 l o s s 2 \mathrm{loss}_2 loss2相加可以得到 l o s s ( x ) = l o s s 1 ( x ) + l o s s 2 ( x ) = x 1 2 + x 2 2 2 + x 1 − x 2 \mathrm{loss}({\bf{x}})=\mathrm{loss}_1({\bf{x}})+\mathrm{loss}_2({\bf{x}})=\frac{x_1^2+x_2^2}{2}+x_1-x_2 loss(x)=loss1(x)+loss2(x)=2x12+x22+x1x2最终损失函数 l o s s \mathrm{loss} loss对向量 x {\bf{x}} x的偏导数为 ∂ l o s s ∂ x = ( x 1 + 1 , x 2 − 1 ) ⊤ \frac{\partial {\mathrm{loss}}}{\partial{{\bf{x}}}}=(x_1+1,x_2-1)^{\top} xloss=(x1+1,x21)

以下为用pytorch实现二阶导数相对应的代码实例:

import torch

x = torch.tensor([5.0, 7.0], requires_grad=True)
y = x**2

loss1 = torch.mean(y)

h1 = torch.autograd.grad(y[0], x, retain_graph = True, create_graph=True)
h2 = torch.autograd.grad(y[1], x, retain_graph = True, create_graph=True)
loss2 = torch.mean(h1[0] - h2[0])

loss = loss1 + loss2

result = torch.autograd.grad(loss, x)
print(result)

当向量 x {\bf{x}} x取值为 ( 5 , 7 ) ⊤ (5,7)^{\top} (5,7)时,根据数学解析解得到的二阶导数为 ( 6 , 6 ) ⊤ (6,6)^{\top} (6,6),对应的代码运行的实验结果也为 ( 6 , 6 ) (6,6) (6,6)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.autograd.grad求二阶导数 的相关文章

随机推荐

  • 常见几种滤波器的比较

    经典的数字滤波器有巴特沃斯滤波器 切比雪夫滤波器 椭圆滤波器和贝塞尔滤波器等 巴特沃斯滤波器的特点是通频带内的频率响应曲线最大限度平坦 没有起伏 而在阻频带则逐渐下降为零 在振幅的对数对角频率的波特图上 从某一边界角频率开始 振幅随着角频率
  • Linux FTP服务(只允许白名单用户访问FTP)

    目录 一 FTP服务器 二 FTP文化传输协议 FTP的传输模式有两种 三 Vsftpd服务程序 四 实验步骤 1 安装vsftpd软件包 2 备份主配置文件 3 去掉 号开头的行 4 创建黑 白名单的目的 约束 允许某些特定用户登录系统
  • 深入学习java源码之ArrayList.addAll()与ArrayList.retainAll()

    深入学习java源码之ArrayList addAll 与ArrayList retainAll 引入多态 List是接口 所以实现类要把接口中的抽象方法全部重写 在重写的时候父类中的方法的时候 操作的数据类型也是要与父类保持一致的 所以父
  • IPX9K IP69K:ISO 20653:2006

    IPX9K IP69K ISO 20653 2006 ISO 20653 2006 已由 ISO 20653 2013 标准代替 道路车辆 防护等级 IP 代码 电气设备对 外来物 水和接触的防护 参考编号 ISO 20653 2006 版
  • 古老的Solidity智能合约错误代码编写

    任何编程语言都有不完善的地方 而使用语言的过程中也可能产生一些逻辑上的Bug 在Solidity0 4 23版本的时候 有人在GitHub上列举了一些使用Solidity编写智能合约时常见的错误用法 虽然现在大家基本上都不会再写同样的问题代
  • Python布雷森汉姆直线算法RViz可视化ROS激光占位网格映射

    使用对数赔率映射已知姿势算法 ROS 包 布雷森汉姆直线算法 布雷森汉姆直线算法是一种线绘制算法 它确定应选择的 n 维栅格的点 以便形成两点之间的直线的近似值 它通常用于在位图图像中 例如在计算机屏幕上 绘制线条图元 因为它仅使用整数加法
  • 安卓系统培训!五年Android开发者小米、阿里面经,一线互联网公司面经总结

    前言 最近有不少人问我这样一个问题 我刚接触编程 准备学习下Android开发 但是担心现在市场饱和了 Android开发的前景怎么样 想着可能有很多人都有这样的担心 于是就赶紧写篇文章 来跟你们谈下Android开发的前景到底怎么样 一
  • 三个美观的个人博客网站源码

    怎么让源码更适合你 改造 名称 二开版UI漂亮的PHP博客论坛网站源码 介绍 可切换皮肤界面 下载 https wwwf lanzout com ihLNM10bfgnc 二 名称 Emlog Pro博客管理系统源码绿色版下载 介绍 源码说
  • MySQL:创建数据库,数据表,主键和外键

    目录 前言 安装MySQL 打开MySQL 创建数据库 查看已建数据库 查看数据库引擎 创建数据表 主键约束 单字段主键 多字段联合主键 外键约束 前言 MySQL数据库安装了很久 一直也没静下心来学习 因为起步太晚 所以什么都想学点 又感
  • Kafka的中的数据清理你知道多少

    Kafka将数据持久化到了硬盘上 为了控制磁盘容量 需要对过去的消息进行清理 那么 删除策略有哪几种呢 日志压缩和日志删除 其中日志压缩一般用的比较少 log cleanup policy compact 启用压缩策略 按照消息key进行整
  • 全局光照技术解析Global Illumination Explained

    解析全局光照Global Illumination Explained 前言 Global Illumination全局光照技术是实时渲染的必然发展方向 我参考了一些研究成果 琢磨了一下 让更多的人可以理解这项 古老 的技术 Front L
  • (Jquery功能篇) Jquery UI 相关组件(手风琴、tab分页、进度条、 滚动条、 时间控件)

    截图 实例代码
  • python dataframe增加数据_Pandas学习笔记(DataFrame基本操作)

    对于生成的dDataFrame 下一步进行的是对他的基本操作 增 减 改 查 一 数据选取 从已有的DataFrame中取出其中一列或几列 并对其进行操作 Pandas取出DataFrame的列有两种方式 两个方式没有好与坏之分 还是看个人
  • 用java做打字训练测试软件,《打字训练测试软件-Java课程设计》.doc

    PAGE PAGE 3 程序设计实践 题目 打字训练测试软件 学校 陕西工业学院 学院 信息学院 班级 信管12 2 学号 201213156619 姓名 刘克豪 2014 年 11 月 09 日 基础类 IO流与异常处理程序设计 一 实践
  • linux如何查看所有的用户、用户组、密码

    linux如何查看所有的用户和组信息 百度经验https jingyan baidu com article a681b0de159b093b184346a7 html linux添加用户 用户组 密码 百度经验https jingyan
  • 【pandas】(六)增删改查

    文章目录 一 增加数据 1 1 增加一行 1 2 增加一列 1 3 pd concat 拼接数据 1 objs Series DataFrame或Panel对象的 序列或映射 2 axis 0 1 默认为0 纵向拼接 3 join inne
  • IOS技术分享

    前言 最近对 WebRTC iOS 端源码进行了下载和编译 网上针对 WebRTC iOS 端的编译文章基本都是几年前的 有些地方已经不适用于最新版的 WebRTC 的编译 简单记录下载 编译的过程 以 M93 版本为例 编译环境 硬件 M
  • Android购物车效果实现(RecyclerView悬浮头部实现)

    刚开始看购物车效果觉得挺复杂 但是把这个功能拆开来一步一步实现会发现并不难 其实就涉及到 ItemDecoration的绘制 recyclerview的滑动监听 贝塞尔曲线和属性动画相关内容 剩下的就是RecyclerView滑动和点击时左
  • Xshell6和Xftp提示“要继续使用此程序,您必须应用最新的更新或使用新版本“

    Xshell6和Xftp提示 要继续使用此程序 您必须应用最新的更新或使用新版本 使用二进制编辑器修改Xshell和Xftp的nslicense dll文件 如sublime Txt编辑器等 1 分别进入Xshell和Xftp的安装路径下
  • torch.autograd.grad求二阶导数

    1 用法介绍 pytorch中torch autograd grad函数主要用于计算并返回输出相对于输入的梯度总和 具体的参数作用如下所示 torch tril input diagonal 0 out None longrightarro