pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

2024-04-24

我有一层layer in an nn.Module并在一次中使用两次或多次forward步。这个的输出layer稍后输入到相同的layer。 pytorch可以吗autograd正确计算该层权重的梯度?

def forward(x):
    x = self.layer(x)
    x = self.layer(x)
    return x

完整示例:

import torch
import torch.nn as nn
import torch.nn.functional as F

class net(nn.Module):
    def __init__(self,in_dim,out_dim):
        super(net,self).__init__()
        self.layer = nn.Linear(in_dim,out_dim,bias=False)

    def forward(self,x):
        x = self.layer(x)
        x = self.layer(x)
        return x

input_x = torch.tensor([10.])
label = torch.tensor([5.])
n = net(1,1)
loss_fn = nn.MSELoss()

out = n(input_x)
loss = loss_fn(out,label)
n.zero_grad()
loss.backward()

for param in n.parameters():
    w = param.item()
    g = param.grad

print('Input = %.4f; label = %.4f'%(input_x,label))
print('Weight = %.4f; output = %.4f'%(w,out))
print('Gradient w.r.t. the weight is %.4f'%(g))
print('And it should be %.4f'%(4*(w**2*input_x-label)*w*input_x))

Output:

Input = 10.0000; label = 5.0000
Weight = 0.9472; output = 8.9717
Gradient w.r.t. the weight is 150.4767
And it should be 150.4766

在这个例子中,我定义了一个只有一个线性层的模块(in_dim=out_dim=1并且没有偏见)。w是该层的权重;input_x是输入值;label是期望值。由于损失选择为 MSE,因此损失的公式为

((w^2)*input_x-label)^2

手工计算,我们有

dw/dx = 2*((w^2)*input_x-label)*(2*w*input_x)

我上面的示例的输出表明autograd给出了与手工计算相同的结果,这让我有理由相信它可以在这种情况下工作。但在实际应用中,该层可能具有更高维度的输入和输出,后面有一个非线性激活函数,并且神经网络可以有多个层。

我想问的是:我可以信任吗autograd处理这种情况,但比我的例子中复杂得多?当一个层被迭代调用时它是如何工作的?


这会工作得很好。从 autograd 引擎的角度来看,这不是循环应用程序,因为生成的计算图会将重复计算展开为线性序列。为了说明这一点,对于单个层,您可能有:

x -----> layer --------+
           ^           |
           |  2 times  |
           +-----------+

从 autograd 的角度来看,这看起来像:

x ---> layer ---> layer ---> layer

Here layer是同一层在图表上复制 3 次。这意味着在计算层权重的梯度时,它们将从所有三个阶段进行累积。所以使用时backward:

x ---> layer ---> layer ---> layer ---> loss_func
                                            |
       lback <--- lback <--- lback <--------+
         |          |          |
         |          v          |
         +------> weights <----+
                   _grad

Here lback表示的局部导数layer使用上游梯度作为输入的正向变换。每一个都会添加到该层的weights_grad.

循环神经网络在其基础上使用这种层(单元)的重复应用。例如,请参阅本教程使用字符级 RNN 对名称进行分类 https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 中的 autograd 可以处理同一模块中层的重复使用吗? 的相关文章

随机推荐

  • 从矩阵中删除零行(优雅的方式)

    我有一个包含一些零行的矩阵 我想删除零行 矩阵是Nx3 我所做的很简单 我创造std vector其中每三个元素代表一行 然后我将其转换为Eigen MatrixXd 有没有一种优雅的方法来删除零行 include
  • 在 ncurses 中的指定位置添加相同符号的快捷方式是什么?

    我想添加str in ncurse屏幕 带坐标x 5 to 24 y 23 to 42 这是一个正方形 但我想不出一个简单的方法来做到这一点 我试过了 stdscr addstr range 23 42 range 5 24 但这行不通 它
  • 使用敏捷方法建造飞机? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 开发者可以从其他行业学到很多东西 作为一个思维练习 是否有可能使用敏捷技术建造一架客机 暂时忘记成本 对硬件 机身 机翼等 和软件进行迭代和增量
  • DYMOLA:opc 服务器如何使用 MATLAB 使用 dsin.txt 或 mat 文件进行初始化

    我在 DYMOLA 中创建了一个 OPC 服务器 现在我在 DYMOSIM 中有这个可以单击并初始化 使用 dsin txt 的 MAT 文件 现在我在 MATLAB 中创建了一个 GUI 文件 并获取变量的输入并创建了一个 mat 文件
  • 无法构建轮子 - 错误:无效命令“bdist_wheel”

    我已经尝试了这个非常相关的问题中的所有内容 为什么我无法在 python 中创建轮子 https stackoverflow com questions 26664102 why can i not create a wheel in py
  • postgresql自连接

    假设我有一张这样的桌子 id device cmd value id unique row ID device device identifier mac address cmd some arbitrary command value v
  • Rails:创建删除表级联迁移

    如何在 Rails 3 2 迁移中强制执行 DROP TABLE CASCADE 是否有一个选项可以传递给 drop table table name 在 Rails 4 中 您可以执行以下操作 drop table accounts fo
  • 如何使用在单击按钮上创建的用户触发图表中的放大和缩小?

    我正在构建一个角度应用程序 其中我们需要创建用于放大和缩小图表的单击按钮 我们可以使用可悬停模式栏上的按钮放大缩小图表 但这对于我们的应用程序来说不是必需的 我们希望使用通过单击按钮创建的用户来放大和缩小图表 有没有办法使用单击按钮触发可悬
  • Electron如何拦截http响应体

    有什么办法可以拦截BrowserWindow主进程中的http响应主体没有调试器 是否无法使用WebRequest类和onCompleted method 我可以使用调试器做到这一点 但由于某种原因我不能使用它 await w webCon
  • 在 Eclipse (Spring Source) 中,Grails 始终以生产模式构建

    当在 Grails 项目中使用 Eclipse 时 战争的构建似乎陷入了生产模式 如果您想部署到附加的 tcServer 您只需右键单击您的项目 然后选择 运行方式 gt 在服务器上运行 如果您将 grails 项目设置为 dev 右键单击
  • 气流:Dag 每隔几秒安排两次

    我尝试每天仅运行一次 DAG00 15 00 午夜 15 分钟 然而 它被安排了两次 间隔几秒钟 dag DAG my dag default args default args start date airflow utils dates
  • 显式语义分析

    我遇到了这个术语 显式语义分析 它使用维基百科作为参考 找到文档中的相似性并将它们分类 如果我错了 请纠正我 我遇到的链接是here http www cs technion ac il gabr resources code esa es
  • 十进制铸造

    我有一个这样的十进制数 62 000 0000000 我需要将该小数转换为 int 它的小数总是为零 所以我不会失去任何精度 我想要的是这样的 62 000 存储在 C 中的 int 变量中 我尝试了很多方法 但它总是给我一个错误 字符串的
  • Python列表来存储类实例?

    给定一个 python 类class Student 和一个清单names 然后我想创建几个实例Student 并将它们添加到列表中names names For storing the student instances class St
  • 如何将html页面的动态内容转换为pdf

    在 html 页面中 一些标签是使用 jquery 动态创建的 内容是使用 jquery 和 php 从 msql 数据库加载的 我想将这个动态页面转换为pdf 我尝试过以下代码 但它生成 html 页面静态部分的 pdf html cod
  • 在 C++ 中正确地将 `void*` 转换为整数

    我正在处理一些使用外部库的代码 您可以在其中通过void value 不幸的是 前一个处理此代码的人决定通过将整数转换为 void 指针来将整数传递给这些回调 void val 我现在正在努力清理这个混乱 并且我正在尝试确定将整数转换为整数
  • 估计命令如何查找 R 公式中的变量名称?

    我想使用 R 来估计大量模型nls 函数作用于用户定义的函数 由于许多变量在我的规范中是固定的 我想要一种在我的函数中预先设置它们的方法 但我没有正确理解 R 如何在公式中包含的函数中查找变量 我看过 Hadley Wickham 的高级
  • 我无法获取 servlet 页面中的 POST 值?

    我无法在 servlet 页面中获取 POST 值 我之前的问题与这个问题相关 如何从servlet页面中的ajax请求获取数据 https stackoverflow com questions 6042177 how to get th
  • 如何找到已安装的pandas版本

    我在使用 Pandas 的某些功能时遇到问题 如何查看我的安装版本是什么 Check pandas version In 76 import pandas as pd In 77 pd version Out 77 0 12 0 933 g
  • pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

    我有一层layer in an nn Module并在一次中使用两次或多次forward步 这个的输出layer稍后输入到相同的layer pytorch可以吗autograd正确计算该层权重的梯度 def forward x x self