y = x / sum(x, dim=0) 的反向传播,其中张量 x 的大小为 (H,W)

2024-01-14

Q1.

我正在尝试使用 pytorch 制作自定义 autograd 函数。

但我在使用 y = x / sum(x, dim=0) 进行分析反向传播时遇到了问题

其中张量 x 的大小为(高度,宽度)(x 是二维)。

这是我的代码

class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
  ctx.save_for_backward(input)
  input = input / torch.sum(input, dim=0)

  return input

@staticmethod
def backward(ctx, grad_output):
  input = ctx.saved_tensors[0]
  H, W = input.size()
  sum = torch.sum(input, dim=0)
  grad_input = grad_output * (1/sum - input*1/sum**2)

  return grad_input

我使用 (torch.autograd import) gradcheck 来比较雅可比矩阵,

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)

结果是

请有人帮助我获得正确的反向传播结果

Thanks!


Q2.

感谢您的解答!

由于您的帮助,我可以在 (H,W) 张量的情况下实现反向传播。

然而,当我在 (N,H,W) 张量的情况下实现反向传播时,我遇到了问题。 我认为问题在于初始化新的张量。

这是我的新代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    
    N = input.size(0)
    for n in range(N):
      input[n] /= torch.sum(input[n], dim=0)

    return input

  @staticmethod
  def backward(ctx, grad_output):
    input = ctx.saved_tensors[0]
    N, H, W = input.size()
    I = torch.eye(H).unsqueeze(-1)
    sum = input.sum(1)

    grad_input = torch.zeros((N,H,W), dtype = torch.double, requires_grad=True)
    for n in range(N):
      grad_input[n] = ((sum[n] * I - input[n]) * grad_output[n] / sum[n]**2).sum(1)

    return grad_input

梯度检查代码是

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.rand(2,2,2,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)
print(test)

结果是在此输入图像描述 https://i.stack.imgur.com/KuY6c.png

我不知道为什么会出现错误...

您的帮助将对我实现自己的卷积网络非常有帮助。

谢谢!祝你今天过得愉快。


让我们看一个包含单列的示例,例如:[[x1], [x2], [x3]].

Let sum be x1 + x2 + x3,然后标准化x会给y = [[y1], [y2], [y3]] = [[x1/sum], [x2/sum], [x3/sum]]。您正在寻找dL/dx1, dL/x2, and dL/x3- 我们将它们写为:dx1, dx2, and dx3。对所有人都一样dL/dyi.

So dx1等于dL/dy1*dy1/dx1 + dL/dy2*dy2/dx1 + dL/dy3*dy3/dx1。那是因为x1贡献于相应列上的所有输出元素:y1, y2, and y3.

We have:

  • dy1/dx1 = d(x1/sum)/dx1 = (sum - x1)/sum²

  • dy2/dx1 = d(x2/sum)/dx1 = -x2/sum²

  • 相似地,dy3/dx1 = d(x3/sum)/dx1 = -x3/sum²

所以dx1 = (sum - x1)/sum²*dy1 - x2/sum²*dy2 - x3/sum²*dy3。同样适用于dx2 and dx3。因此,雅可比行列式是[dxi]_i = (sum - xi)/sum² and [dxi]_j = -xj/sum²(对全部j不同于i).

在您的实现中,您似乎缺少所有非对角线组件。

保持相同的一列示例,x1=2, x2=3, and x3=5:

>>> x = torch.tensor([[2.], [3.], [5.]])

>>> sum = input.sum(0)
tensor([10])

雅可比行列式将是:

>>> J = (sum*torch.eye(input.size(0)) - input)/sum**2
tensor([[ 0.0800, -0.0200, -0.0200],
        [-0.0300,  0.0700, -0.0300],
        [-0.0500, -0.0500,  0.0500]])

对于具有多列的实现,这有点棘手,更具体地说对于对角矩阵的形状。更容易保留column轴放在最后,这样我们就不必为广播而烦恼:

>>> x = torch.tensor([[2., 1], [3., 3], [5., 5]])
>>> sum = x.sum(0)
tensor([10.,  9.])

>>> diag = sum*torch.eye(3).unsqueeze(-1).repeat(1, 1, len(sum))
tensor([[[10.,  9.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [10.,  9.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [10.,  9.]]])

Above diag形状为(3, 3, 2)哪里两个columns位于最后一个轴上。请注意我们如何不需要广播sum.

What I wouldn't所做的是:torch.eye(3).unsqueeze(0).repeat(len(sum), 1, 1)。既然是这样的形状——(2, 3, 3)- 你将不得不使用sum[:, None, None],并且需要进一步广播......

雅可比行列式很简单:

>>> J = (diag - x)/sum**2
tensor([[[ 0.0800,  0.0988],
         [-0.0300, -0.0370],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [ 0.0700,  0.0741],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [-0.0300, -0.0370],
         [ 0.0500,  0.0494]]])

您可以通过使用任意的操作反向传播来检查结果dy矢量(不与torch.ones不过,你会得到0是因为J!)。经过反向传播后,x.grad应等于torch.einsum('abc,bc->ac', J, dy).

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

y = x / sum(x, dim=0) 的反向传播,其中张量 x 的大小为 (H,W) 的相关文章

随机推荐

  • 无需 jQuery 即可获得 innerWidth() 等效项

    我目前正在努力从我编写的一些代码中消除 jQuery 并且我有一部分代码 其中我正在计算某些代码的内部和外部宽度span元素 这好像是 getBoundingClientRect 对于获取元素的外部宽度效果很好 但我在获取内部宽度方面有点困
  • 如果参数有成员变量则特殊化函数

    我有一个模板化的错误报告函数 因为它可以报告许多不同消息类别的错误 template
  • 将希腊字符和星号 (*) 添加到轴标题

    我想在 R 中的直方图的 x 轴上添加一个希腊字符 我可以单独写希腊字符或与帽子一起写 但问题是我需要这个字符带有帽子和星号 一起 更具体地说 我想要像 hat phi 这是我所做的 x rnorm 1000 hist x nclass 1
  • sublime text-“列出包含“查找”字符串的行

    如何列出包含 find 命令中的匹配项的行 即 我想在单独的窗口中列出所有匹配的行 目前只能转到下一个 上一个 查找 Try Find in Files Cmd Shift F on a Mac presumably Ctrl Shift
  • 多个线程同时使用同一个 JDBC 连接

    我试图更好地理解如果多个线程尝试使用相同的 JDBC 连接同时执行不同的 sql 查询会发生什么 结果在功能上是否正确 对性能有何影响 会穿线A必须等待线程B完全完成其查询 或者会穿线A能够在线程之后立即发送其查询B已发送查询 之后数据库将
  • 如何使用 Hibernate Criteria 选择嵌套属性

    我有三个实体 例如注册 用户和国家 地区 基本上 一个注册属于一个用户 一个用户属于一个国家 现在我尝试使用以下内容从注册中选择国家 地区名称 Criteria criteria getSession createCriteria Regi
  • 调试和发布版本之间可能会出现差异? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 为什么R包lubridate无法解析多种格式的向量?

    我正在使用包lubridate解析异构格式日期的向量并将它们转换为字符串 如下所示 parse date time c 12 17 1996 04 00 00 PM 4 18 1950 0130 c m d Y I M S p m d Y
  • 异步图像下载器

    我编写了一个小类来使用 NSURLConnection 执行图像的下载 其想法是将下载委托给此类以避免阻止执行 因此 我将目标 UIImageView 通过引用 和 url 传递给函数并开始下载 void getImage UIImageV
  • 如何将图像转换为灰度?

    我想通过代码 android 做与上图完全相同的事情 但我对执行此操作的算法感到困惑 我所知道的是 对于每个像素 将 RGB 转换为 HSL 将 HSL 转换回 RGB 谁能帮我解释一下第 2 步要做什么 非常感谢 ps 我可以通过 Col
  • org.openqa.selenium.WebDriverException:未知错误:无法确定加载状态

    我是 Selenium 的新手 需要一些线索来查找以下错误的根本原因 硒版本 3 5 3 ChromeDriver 版本 2 29 4 Chrome 版本 63 org openqa selenium WebDriverException
  • SAP HANA 交叉应用替代方案

    HANA sql 中是否有 MsSql 中可用的 交叉应用 运算符的替代方案 或者有没有办法对表中的值应用函数 就像是 select T F from T cross join someFunction T Value F 您可以在 SAP
  • 为什么我的 ExtJS 网格中的按钮显示为“[object Object]”?

    In an ExtJS网格我有一列 其中当单元格的内容为某个值时 button应该显示 我定义了将包含这样的按钮的列 该按钮调用渲染函数 header Payment Type width 120 sortable true rendere
  • 在nodejs中将文件附加到zip

    我正在制作一个应用程序 您可以在其中编辑文件 它应该将编辑后的文件附加到 zip 存档中并使其可下载 它应该是跨平台的 Windows 和 Linux 我的目标是以编程方式生成编辑后的文件并将其附加到静态存档 始终相同 大约 3 4MB 但
  • pandas 中的新列 - 通过应用列表 groupby 将系列添加到数据框

    给出以下内容df Id other concat 0 A z 1 1 A y 2 2 B x 3 3 B w 4 4 B v 5 5 B u 6 我想要的结果是new包含分组值作为列表的列 Id other concat new 0 A z
  • 从 ASP.NET 3.5 应用程序在 IFRAME 内运行 GWT 应用程序(包括 Applet)?

    我们正在考虑将成熟的 GWT Google Web Toolkit 2 0 应用程序与现有的 ASP NET 3 5 应用程序集成 我的第一直觉反应是这是一个可怕的弗兰肯斯坦想法 然而 客户坚持要求我们使用第三方开发的应用程序 我几乎无法控
  • 使用 HTML5 输入的 Ionic 文件上传

    我正在使用 ionic 构建一个包含文件上传的移动应用程序 Versions 离子 离子 CLI 4 10 3 离子框架 离子角3 9 3 ionic app scripts 3 2 3 支持的文件有 images Word 文档 doc
  • log4j:错误尝试附加到名为的已关闭附加程序

    我们的 weblogic 8 1 服务器中有大约 19 个应用程序 每个应用程序都是一个带有一些 ejb mdb 等的 Ear 应用程序 每个应用程序都有一个在文件系统中某处的 properties 文件中定义的 log4j 属性 我们不断
  • 如何在oracle中执行包含两个insert语句的存储过程?

    我正在尝试执行一个包含两个插入语句的存储过程 但是我不确定如何执行此操作或是否有更好的方法 该存储过程有两个插入语句 将数据插入两个不同的表中 我努力了execute new order 请参阅下面 然后传入值 但出现此错误ERROR at
  • y = x / sum(x, dim=0) 的反向传播,其中张量 x 的大小为 (H,W)

    Q1 我正在尝试使用 pytorch 制作自定义 autograd 函数 但我在使用 y x sum x dim 0 进行分析反向传播时遇到了问题 其中张量 x 的大小为 高度 宽度 x 是二维 这是我的代码 class MyFunc to