通过计算雅可比行列式,有效地使用 PyTorch 的 autograd 和张量

2023-12-21

在我之前的question https://stackoverflow.com/questions/67320792/how-to-use-pytorchs-autograd-efficiently-with-tensors/67334809#67334809我找到了如何将 PyTorch 的 autograd 与张量一起使用:

import torch
from torch.autograd import grad
import torch.nn as nn
import torch.optim as optim

class net_x(nn.Module): 
        def __init__(self):
            super(net_x, self).__init__()
            self.fc1=nn.Linear(1, 20) 
            self.fc2=nn.Linear(20, 20)
            self.out=nn.Linear(20, 4) #a,b,c,d

        def forward(self, x):
            x=torch.tanh(self.fc1(x))
            x=torch.tanh(self.fc2(x))
            x=self.out(x)
            return x

nx = net_x()

#input
t = torch.tensor([1.0, 2.0, 3.2], requires_grad = True) #input vector
t = torch.reshape(t, (3,1)) #reshape for batch

#method 
dx = torch.autograd.functional.jacobian(lambda t_: nx(t_), t)
dx = torch.diagonal(torch.diagonal(dx, 0, -1), 0)[0] #first vector
#dx = torch.diagonal(torch.diagonal(dx, 1, -1), 0)[0] #2nd vector
#dx = torch.diagonal(torch.diagonal(dx, 2, -1), 0)[0] #3rd vector
#dx = torch.diagonal(torch.diagonal(dx, 3, -1), 0)[0] #4th vector
dx 
>>> 
tensor([-0.0142, -0.0517, -0.0634])

问题是grad只知道如何从标量张量传播梯度(我的网络的输出不是),这就是为什么我必须计算雅可比行列式。

然而,这不是很有效并且有点慢,因为我的矩阵很大并且计算整个雅可比矩阵需要一段时间(而且我也没有使用整个雅可比矩阵)。

有没有办法只计算雅可比行列式的对角线(以获得本例中的 4 个向量)?

似乎有一个打开功能请求 https://github.com/pytorch/pytorch/issues/41530但它似乎并没有引起太多关注。

更新1:
我尝试了@iacob所说的设置torch.autograd.functional.jacobian(vectorize=True).
不过,这似乎要慢一些。为了测试这个我改变了我的网络输出4 to 400,以及我的输入t to be:

val = 100
t = torch.rand(val, requires_grad = True) #input vector
t = torch.reshape(t, (val,1)) #reshape for batch

Without vectorized = True:

Wall time: 10.4 s

With:

Wall time: 14.6 s

好的,先看结果:

性能(我的笔记本电脑有 RTX-2070,PyTorch 正在使用它):

# Method 1: Use the jacobian function
CPU times: user 34.6 s, sys: 165 ms, total: 34.7 s
Wall time: 5.8 s

# Method 2: Sample with appropriate vectors
CPU times: user 1.11 ms, sys: 0 ns, total: 1.11 ms
Wall time: 191 µs

速度快了大约 30000 倍。


为什么你应该使用backward代替jacobian(就你而言)

我不是 PyTorch 的专业人士。但是,根据我的经验,如果不需要雅可比矩阵中的所有元素,那么计算雅可比矩阵的效率相当低。

如果只需要对角线元素,可以使用backward计算函数vector- 与一些特定向量的雅可比乘法。如果您设置vector如果正确,您可以从雅可比矩阵中采样/提取特定元素。

一点线性代数:

j = np.array([[1,2],[3,4]]) # 2x2 jacobi you want 
sv = np.array([[1],[0]])     # 2x1 sampling vector

first_diagonal_element = sv.T.dot(j).dot(sv)  # it's j[0, 0]

对于这个简单的例子来说,它的功能并不是那么强大。但是如果 PyTorch 需要计算所有雅可比矩阵(j可能是一长串矩阵-矩阵乘法的结果),它会太慢。相反,如果我们计算向量雅可比乘法序列,计算速度将会非常快。


Solution

雅可比的示例元素:

import torch
from torch.autograd import grad
import torch.nn as nn
import torch.optim as optim

class net_x(nn.Module): 
        def __init__(self):
            super(net_x, self).__init__()
            self.fc1=nn.Linear(1, 20) 
            self.fc2=nn.Linear(20, 20)
            self.out=nn.Linear(20, 400) #a,b,c,d

        def forward(self, x):
            x=torch.tanh(self.fc1(x))
            x=torch.tanh(self.fc2(x))
            x=self.out(x)
            return x

nx = net_x()

#input

val = 100
a = torch.rand(val, requires_grad = True) #input vector
t = torch.reshape(a, (val,1)) #reshape for batch


#method 
%time dx = torch.autograd.functional.jacobian(lambda t_: nx(t_), t)
dx = torch.diagonal(torch.diagonal(dx, 0, -1), 0)[0] #first vector
#dx = torch.diagonal(torch.diagonal(dx, 1, -1), 0)[0] #2nd vector
#dx = torch.diagonal(torch.diagonal(dx, 2, -1), 0)[0] #3rd vector
#dx = torch.diagonal(torch.diagonal(dx, 3, -1), 0)[0] #4th vector
print(dx)


out = nx(t)
m = torch.zeros((val,400))
m[:, 0] = 1
%time out.backward(m)
print(a.grad)

a.grad应等于第一个张量dx. And, m是与代码中所谓的“第一个向量”相对应的采样向量。


  1. 但如果我再次运行它,答案就会改变。

是的,你已经明白了。每次调用时梯度都会累积backward。所以你必须设置a.grad如果您必须多次运行该单元,请先为零。

  1. 你能解释一下背后的想法吗m方法?两者都使用torch.zeros并将该列设置为1。还有,怎么毕业了a而不是在t?
  • 背后的想法m方法是:功能是什么backward计算实际上是一个向量雅可比矩阵乘法,其中向量代表所谓的“上游梯度”,雅可比矩阵是“局部梯度”(这个雅可比矩阵也是你用jacobian函数,因为你的lambda可以被视为单个“本地”操作)。如果您需要来自雅可比的一些元素,您可以伪造(或者更准确地说,构造)一些“上游梯度”,用它您可以从雅可比中提取特定的条目。然而,有时如果涉及复杂的张量运算,这些上游梯度可能很难找到(至少对我来说)。
  • PyTorch 在计算图的叶节点上累积梯度。而且,你原来的代码行t = torch.reshape(t, (3,1))失去叶节点的句柄,并且t现在指的是中间节点而不是叶节点。为了访问叶节点,我创建了张量a.
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

通过计算雅可比行列式,有效地使用 PyTorch 的 autograd 和张量 的相关文章

  • Python中Decimal类型的澄清

    每个人都知道 或者至少 每个程序员都应该知道 http docs oracle com cd E19957 01 806 3568 ncg goldberg html 即使用float类型可能会导致精度错误 然而 在某些情况下 精确的解决方
  • 如何正确地将 MIDI 刻度转换为毫秒?

    我正在尝试将 MIDI 刻度 增量时间转换为毫秒 并且已经找到了一些有用的资源 MIDI Delta 时间刻度到秒 http www lastrayofhope co uk 2009 12 23 midi delta time ticks
  • Python逻辑运算符优先级[重复]

    这个问题在这里已经有答案了 哪个运算符优先4 gt 5 or 3 lt 4 and 9 gt 8 这会被评估为真还是假 我知道该声明3 gt 4 or 2 lt 3 and 9 gt 10 显然应该评估为 false 但我不太确定 pyth
  • if 语句未命中中的 continue 断点

    在下面的代码中 两者a and b是生成器函数的输出 并且可以评估为None或者有一个值 def testBehaviour self a None b 5 while True if not a or not b continue pri
  • Argparse nargs="+" 正在吃位置参数

    这是我的解析器配置的一小部分 parser add argument infile help The file to be imported type argparse FileType r default sys stdin parser
  • 忽略 Mercurial hook 中的某些 Mercurial 命令

    我有一个像这样的善变钩子 hooks pretxncommit myhook python path to file myhook 代码如下所示 def myhook ui repo kwargs do some stuff 但在我的例子中
  • Pandas 数据帧到 numpy 数组 [重复]

    这个问题在这里已经有答案了 我对 Python 很陌生 经验也很少 我已经设法通过复制 粘贴和替换我拥有的数据来使一些代码正常工作 但是我一直在寻找如何从数据框中选择数据 但无法理解这些示例并替换我自己的数据 总体目标 如果有人真的可以帮助
  • 以同步方式使用 FastAPI,如何获取 POST 请求的原始正文?

    在中使用 FastAPIsync not async模式 我希望能够接收 POST 请求的原始 未更改的正文 我能找到的所有例子都显示async代码 当我以正常同步方式尝试时 request body 显示为协程对象 当我通过发布一些内容来
  • 在Python中调整图像大小

    我有一张尺寸为 288 352 的图像 我想将其大小调整为 160 240 我尝试了以下代码 im imread abc png img im resize 160 240 Image ANTIALIAS 但它给出了一个错误TypeErro
  • 为什么在 Python 2.4 中使用 Unicode 数据会出现 ASCII 编码错误,而在 2.7 中却不会?

    我有一个程序 当在 Python 2 7 中运行时 会生成正确的 Unicode 输出到标准输出 当在 Python 2 4 中运行时 我得到UnicodeEncodeError ascii codec can t encode chara
  • Seaborn Pairplot 图例不显示颜色

    我一直在学习如何在Python中使用seaborn和pairplot 这里的一切似乎都工作正常 但由于某种原因 图例不会显示相关的颜色 我无法找到解决方案 因此如果有人有任何建议 请告诉我 x sns pairplot stats2 hue
  • 在 pytube3 中获取 youtube 视频的标题?

    我正在尝试构建一个应用程序来使用 python 下载 YouTube 视频pytube3 但我无法检索视频的标题 这是我的代码 from pytube import YouTube yt YouTube link print yt titl
  • 如何使用列表作为pandas数据框中的值?

    我有一个数据框 需要列的子集包含具有多个值的条目 下面是一个带有 运行时 列的数据框 其中包含程序在各种条件下的运行时 df condition a runtimes 1 1 5 2 condition b runtimes 0 5 0 7
  • Python 将日志滚动到变量

    我有一个使用多线程并在服务器后台运行的应用程序 为了无需登录服务器即可监控应用程序 我决定包括Bottle http bottlepy org为了响应一些HTTP端点并报告状态 执行远程关闭等 我还想添加一种查阅日志文件的方法 我可以使用以
  • 如何为每个屏幕添加自己的 .py 和 .kv 文件?

    我想为每个屏幕都有一个单独的 py 和 kv 文件 应通过 main py main kv 中的 ScreenManager 选择屏幕 设计应从文件 screen X kv 加载 类等应从文件 screen X py 加载 Screens
  • 当鼠标悬停在上面时,intellisense vscode 不显示参数或文档

    我正在尝试将整个工作流程从 Eclipse 和 Jupyter Notebook 迁移到 VS Code 我安装了 python 扩展 它应该带有 Intellisense 但它只是部分更糟糕 我在输入句点后收到建议 但当将鼠标悬停在其上方
  • 如何读取Python字节码?

    我很难理解 Python 的字节码及其dis module import dis def func x 1 dis dis func 上述代码在解释器中输入时会产生以下输出 0 LOAD CONST 1 1 3 STORE FAST 0 x
  • Elastic Beanstalk 中的 enum34 问题

    我正在尝试在 Elastic Beanstalk 中设置 django 环境 当我尝试通过requirements txt 文件安装时 我遇到了python3 6 问题 File opt python run venv bin pip li
  • 迭代 pandas 数据框的最快方法?

    如何运行数据框并仅返回满足特定条件的行 必须在之前的行和列上测试此条件 例如 1 2 3 4 1 1 1999 4 2 4 5 1 2 1999 5 2 3 3 1 3 1999 5 2 3 8 1 4 1999 6 4 2 6 1 5 1
  • Scrapy Spider不存储状态(持久状态)

    您好 有一个基本的蜘蛛 可以运行以获取给定域上的所有链接 我想确保它保持其状态 以便它可以从离开的位置恢复 我已按照给定的网址进行操作http doc scrapy org en latest topics jobs html http d

随机推荐

  • 如何确定一个 3D 对象是否适合另一个 3D 对象(容器)?

    给定两个 3D 对象 我如何找到一个是否适合第二个对象 并找到该对象在容器中的位置 应平移和旋转对象以适合容器 但不得进行其他修改 其他并发症 相同的情况 但寻找最适合的解决方案 即使它不是正确的匹配 最小化不适合容器的物体的体积 支持弹性
  • Puppeteer:如何监听对象事件

    是否可以监听页内对象调度的事件 假设我访问的页面中有以下代码 var event new CustomEvent status detail ok window addEventListener status function e cons
  • 强制 git push + pull 超时

    我发现的所有问题都想避免 git 推 拉超时 就我而言 我想强迫他们 我的推 拉都是通过 ssh 传输到在某个时间点可能不可用的远程计算机 例如 我有一个脚本可以推送到两个远程公共存储库 我不希望这个脚本在推送到第一个存储库并且该机器不可用
  • Flexslider 和从右到左的语言支持

    我在 WordPress 上安装了一个包含 Flexslider 的模板 我的语言是从右到左 RTL 书写的 当页面为 RTL 时 Flexslider 停止并且不显示图像 我该如何解决这个问题 Flex 滑块不支持 RTL 语言 解决这个
  • C++:崩溃时不显示 glibc 的回溯和内存映射

    我正在使用 Python 进行自动 C 代码测试 所以我有一个编译和执行 C 代码的 Python 脚本 当 C 代码崩溃时 即使我重定向 libc 输出也可以从我的 Python 脚本输出中看到cout and cerr正在执行的 C 程
  • 从数据库更新模型时出现实体框架错误,反之亦然

    当我尝试使用 VS Express 2013 for web EF6 1 1 和 NET Framework 4 5 从数据库更新模型时 会发生以下情况 在本例中 我只是在表定义中向表中添加了一个字段并更新了数据库 之后 我在 EDMX 模
  • 使用GridSearchCV时出现值错误

    我正在使用 GridSearchCV 进行分类 我的代码是 parameter grid SVM dual True False loss squared hinge hinge penalty l1 l2 clf GridSearchCV
  • Autofac PropertiesAutowired - 是否可以忽略一个或多个属性?

    尽管建议通过构造函数传递依赖项 但我发现使用无参数构造函数然后自动装配所有属性的开发成本显着减少 并使应用程序更易于开发和维护 然而有时 例如在视图模型上 我有一个在容器中注册的属性 但我不想在构造时填充该属性 例如绑定到容器的所选项目 有
  • 谷歌移动视觉库无法下载

    我正在尝试将 Google Mobile Vision TextRecogniser API 实现到我的应用程序中 以读取给定图像中的文本 当我尝试使用该功能时 出现以下错误 W DynamiteModule Local module de
  • 为什么 Julia 不鼓励对 UTF8 字符串建立索引?

    Julia 的入门指南 在 Y 分钟内学习 Julia https learnxinyminutes com docs julia 阻止用户索引 UTF8 字符串 Some strings can be indexed like an ar
  • 如何调整表单大小以自动适应其内容?

    我正在尝试实现以下行为 表单上有一个选项卡控件 在该选项卡控件上有一个树视图 为了防止出现滚动条 我希望表单在第一次显示时根据树视图的内容更改其大小 如果树视图有太多节点无法在窗体的默认大小上显示 则窗体应更改其大小 以便树视图上没有垂直滚
  • 无法在列表框中绑定命令

    我的 WPF 使用 MVVM 方法 我正在尝试在列表控件中绑定 2 个控件
  • 我自己的 R 中的 K 均值算法

    我是 R 编程的初学者 我正在 R 中进行此练习作为编程入门 我已经在 R 中实现了自己的 K 均值实现 但在某一点上卡住了一段时间 我需要达成共识 算法迭代直到找到每个簇的最佳中心 这是没有迭代的原始算法 它只是从整个数据中随机选取一个数
  • 在ColdFusion中,有没有办法确定代码在哪个服务器上运行?

    ColdFusion 代码中是否有任何方法可以确定代码在哪个服务器上执行 我有一些负载平衡的 ColdFusion 服务器 当我捕获异常时 我希望能够知道代码正在哪个服务器上运行 因此我可以将该信息包含在日志记录 报告代码中 服务器是 Wi
  • 当您无法提供色彩美感时手动创建图例

    在试图回答时这个问题 https stackoverflow com questions 34066131 can data points be labeled in stripcharts 34068263 创建所需绘图的一种方法是使用g
  • 为什么在JPA Hibernate中更新查询;所有属性都在 SQL 中更新

    我将 JPA 与 Hibernate 一起使用 当我修改对象的一个 属性并更新它时 生成的 SQL 显示所有列都已更新 为什么它不只更新修改的列 有没有办法实现这一点 因为我觉得这样会更加优化 默认情况下 hibernate 包含更新查询中
  • 在 NetBeans 中找不到主类

    我一直在为我的编程课做作业 我正在使用 NetBeans 我完成了我的项目并且运行良好 当我尝试运行它时 收到一条消息 未找到主类 这是主要的一些代码 package luisrp3 import java io FileNotFoundE
  • 如何使用 Seaborn 在 hexbins 上绘制回归线?

    我终于成功地将我的 hexbin 分布图整理成几乎漂亮的东西 import seaborn as sns x req apply clicks y req reqs wordcount sns jointplot x y kind hex
  • 将 PySpark DenseVector 转换为数组

    我正在尝试将 DenseVector 的 pyspark 数据帧列转换为数组 但总是出现错误 data Vectors dense 8 0 1 0 3 0 2 0 5 0 Vectors dense 2 0 0 0 3 0 4 0 5 0
  • 通过计算雅可比行列式,有效地使用 PyTorch 的 autograd 和张量

    在我之前的question https stackoverflow com questions 67320792 how to use pytorchs autograd efficiently with tensors 67334809