PyTorch - 参数不变

2024-03-21

为了了解 pytorch 的工作原理,我尝试对多元正态分布中的一些参数进行最大似然估计。然而,它似乎不适用于任何协方差相关的参数。

所以我的问题是:为什么这段代码不起作用?

import torch


def make_covariance_matrix(sigma, rho):
    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],
                         [rho * torch.prod(sigma), sigma[1]**2]])


mu_true = torch.randn(2)
rho_true = torch.rand(1)
sigma_true = torch.exp(torch.rand(2))

cov_true = make_covariance_matrix(sigma_true, rho_true)
dist_true = torch.distributions.MultivariateNormal(mu_true, cov_true)

samples = dist_true.sample((1_000,))

mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)
atanh_rho = torch.zeros(1, requires_grad=True)

lbfgs = torch.optim.LBFGS([mu, log_sigma, atanh_rho])


def closure():
    lbfgs.zero_grad()
    sigma = torch.exp(log_sigma)
    rho = torch.tanh(atanh_rho)
    cov = make_covariance_matrix(sigma, rho)
    dist = torch.distributions.MultivariateNormal(mu, cov)
    loss = -torch.mean(dist.log_prob(samples))
    loss.backward()
    return loss


lbfgs.step(closure)

print("mu: {}, mu_hat: {}".format(mu_true, mu))
print("sigma: {}, sigma_hat: {}".format(sigma_true, torch.exp(log_sigma)))
print("rho: {}, rho_hat: {}".format(rho_true, torch.tanh(atanh_rho)))

output:

mu: tensor([0.4168, 0.1580]), mu_hat: tensor([0.4127, 0.1454], requires_grad=True)
sigma: tensor([1.1917, 1.7290]), sigma_hat: tensor([1., 1.], grad_fn=<ExpBackward>)
rho: tensor([0.3589]), rho_hat: tensor([0.], grad_fn=<TanhBackward>)

>>> torch.__version__
'1.0.0.dev20181127'

换句话说,为什么有这样的估计log_sigma and atanh_rho没有改变它们的初始值?


创建协方差矩阵的方式不是后验概率:

def make_covariance_matrix(sigma, rho):
    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],
                         [rho * torch.prod(sigma), sigma[1]**2]])

从(多个)张量创建新张量时,仅保留输入张量的值。来自输入张量的所有附加信息都被剥离,因此所有图连接您的参数从此时开始被切断,因此反向传播无法通过。

这是一个简短的例子来说明这一点:

import torch

param1 = torch.rand(1, requires_grad=True)
param2 = torch.rand(1, requires_grad=True)
tensor_from_params = torch.tensor([param1, param2])

print('Original parameter 1:')
print(param1, param1.requires_grad)
print('Original parameter 2:')
print(param2, param2.requires_grad)
print('New tensor form params:')
print(tensor_from_params, tensor_from_params.requires_grad)

Output:

Original parameter 1:
tensor([ 0.8913]) True
Original parameter 2:
tensor([ 0.4785]) True
New tensor form params:
tensor([ 0.8913,  0.4785]) False

正如您所看到的,根据参数创建的张量param1 and param2,不跟踪梯度param1 and param2.

因此,您可以使用此代码来保留图形连接 and is 后验概率:

def make_covariance_matrix(sigma, rho):
    conv = torch.cat([(sigma[0]**2).view(-1), rho * torch.prod(sigma), rho * torch.prod(sigma), (sigma[1]**2).view(-1)])
    return conv.view(2, 2)

使用以下方法将这些值连接到一个平面张量torch.cat。然后使用将它们调整为正确的形状view().
这会产生与函数中相同的矩阵输出,但它保持与参数的连接log_sigma and atanh_rho.

这是更改后的步骤之前和之后的输出make_covariance_matrix。如您所见,现在您可以优化参数,并且值确实会发生变化:

Before:
mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.,  0.])
sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.,  1.])
rho: tensor([ 0.2558]), rho_hat: tensor([ 0.])

After:
mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.0712,  0.7781])
sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.4410,  1.0807])
rho: tensor([ 0.2558]), rho_hat: tensor([ 0.2235])

希望这可以帮助!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch - 参数不变 的相关文章

  • 使用Python的工业视觉相机[关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 在 python 2 和 3 的spyder之间切换

    根据我在文档中了解到的内容 它指出您只需使用命令提示符创建一个新变量即可轻松在 2 个 python 环境之间切换 如果我已经安装了 python 2 7 则 conda create n python34 python 3 4 anaco
  • OpenCV 错误:使用 COLOR_BGR2GRAY 函数时断言失败

    我在使用 opencv 时遇到了一个奇怪的问题 我在 jupyter 笔记本中工作时没有任何问题 但在尝试运行此 Sublime 时却出现问题 错误是 OpenCV错误 cvtColor中断言失败 深度 CV 8U 深度 CV 16U 深度
  • 根据 pandas 中的条件交换列值

    我想按条件重新定位列 如果国家 地区是 日本 我需要将姓氏和名字反向重新定位 df pd DataFrame France Kylian Mbappe Japan Hiroyuki Tajima Japan Shiji Kagawa Eng
  • 无法将 datetime.datetime 与 datetime.date 进行比较

    我有以下代码并收到上述错误 由于我是 python 新手 我无法理解这里的语法以及如何修复错误 if not start or date lt start start date 有一个datetime date 从日期时间转换为日期的方法
  • DataFrame 中的字符串,但 dtype 是对象

    为什么 Pandas 告诉我我有对象 尽管所选列中的每个项目都是一个字符串 即使在显式转换之后也是如此 这是我的数据框
  • 对打开文件的脚本进行单元测试

    我编写了一个脚本 它打开一个文件 读取内容并进行一些操作和计算 并将它们存储在集合和字典中 我该如何为这样的事情编写单元测试 我的问题具体是 我会测试文件是否打开 文件很大 这是unix字典文件 我如何对计算进行单元测试 我真的必须手动计算
  • 如何在“python setup.py test”中运行 py.test 和 linter

    我有一个项目setup py文件 我用pytest作为测试框架 我还在我的代码上运行各种 linter pep8 pylint pydocstyle pyflakes ETC 我用tox在多个 Python 版本中运行它们 并使用以下命令构
  • 远程控制或脚本打开 Office 从 Python 编辑 Word 文档

    我想 最好在 Windows 上 在特定文档上启动 Open Office 搜索固定字符串并将其替换为我的程序选择的另一个字符串 我该如何从外部 Python 程序中做到这一点 OLE 什么 原生 Python 脚本解决方案 The doc
  • 如何在 openpyxl 中设置或更改表格的默认高度

    我想通过openpyxl更改表格高度 并且我希望首先默认一个更大的高度值 然后我可以设置自动换行以使我的表格更漂亮 但我不知道如何更改默认高度 唯一的到目前为止 我知道更改表格高度的方法是设置 row dimension idx heigh
  • 使用 Python 抓取维基百科数据

    我正在尝试从以下内容中检索 3 列 NFL 球队 球员姓名 大学球队 维基百科页面 http en wikipedia org wiki 2008 NFL draft 我是 python 新手 一直在尝试使用 beautifulsoup 来
  • 使用reduce方法的斐波那契数列

    于是 我看到有人用reduce方法来计算斐波那契数列 这是他的想法 1 0 1 1 2 1 3 2 5 3 对应于 1 1 2 3 5 8 13 21 代码如下所示 def fib reduce n initial 1 0 dummy ra
  • 在 Windows 上将 Word2vec 与 Tensorflow 结合使用

    In 本教程文件 https github com tensorflow models blob master tutorials embedding word2vec py L45通过 Tensorflow 找到以下行 第 45 行 来加
  • 将具有不同大小的行的数据加载到 Numpy 数组中

    假设我有一个包含如下数据的文本文件 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 如何将它加载到 numpy 数组中 使其看起来像这样 1 2 3 4 5 0 6 7 8 0 0 0 9 1
  • 从 Apache 运行 python 脚本的最简单方法

    我花了很长时间试图弄清楚这一点 我基本上正在尝试开发一个网站 当用户单击特定按钮时 我必须在其中执行 python 脚本 在研究了 Stack Overflow 和 Google 之后 我需要配置 Apache 以便能够运行 CGI 脚本
  • PermanentTaskFailure:“模块”对象没有属性“迁移”

    我在 google appengine 上使用 Nick Johnson 的批量更新库 http blog notdot net 2010 03 Announcing a robust datastore bulk update utili
  • python 中的基本矩阵转置

    我尝试了 python 中矩阵转置的最基本方法 但是 我没有得到所需的结果 接下来是代码 A 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 print A def TS A B A for i in range len A
  • Scikit Learn - K-Means - 肘部 - 标准

    今天我想学习一些关于 K means 的知识 我已经了解该算法并且知道它是如何工作的 现在我正在寻找正确的 k 我发现肘部准则作为检测正确的 k 的方法 但我不明白如何将它与 scikit learn 一起使用 在 scikit learn
  • Django - 缺少 1 个必需的位置参数:'request'

    我收到错误 get indiceComercioVarejista 缺少 1 个必需的位置参数 要求 当尝试访问 get indiceComercioVarejista 方法时 我不知道这是怎么回事 views from django ht
  • bool() 和operator.truth() 有什么区别?

    bool https docs python org 3 library functions html bool and operator truth https docs python org 3 library operator htm

随机推荐

  • Entity Framework 4.1 - Code First:多对多关系

    我想建立这样的关系 一个区域位于 x 个其他区域的附近 public class Zone public string Id get set public string Name get set public virtual ICollec
  • 在 Java 中使用 ENUMS 验证值组合的最佳方法是什么?

    我通过如下定义 ENUM 来验证从数据库检索的记录的状态 public enum RecordStatusEnum CREATED CREATED INSERTED INSERTED FAILED FAILED private String
  • 在Linux中使用自定义规则在多个端口上运行的SSH服务[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我正在努力设置一台在多个端口上运行 SSH 服务的服务器 例如端口 22 和 5522 这些端口应该具有一组不同的规则 即 我们为端口 2
  • 在 C# 中如何将字符串转换为 ascii 二进制?

    不久前 高中一年级 我请一位非常优秀的大三 C 程序员制作一个简单的应用程序 将字符串转换为二进制 他给了我以下代码示例 void ToBinary char str char tempstr int k 0 tempstr new cha
  • 列表未添加 C# 中的所有值

    我尝试了下面的代码来创建 json 代码 代码工作正常 我从数据库加载值 但只有最后一个值我得到了输出 剩余值未添加 DataTable dt new DataTable var objectToSerialize new RootObje
  • 解除PDF密码保护,知道密码[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我有一堆 pdf 文件 我想从中删除密码 请注意 我知道密码 因此无需暴力破解 我正在 Mac 上工作 所以我想制作一个应用程序来删除这些
  • Git合并分支到master

    我有一个主分支和一个工作分支branch 1 我想 动 一下branch 1正是如此master 所以我想要这样的东西 git checkout master git merge branch 1 I don t know what is
  • symfony2 - twig - 如何从树枝模板内部渲染树枝模板

    我有一个 xxx html twig 文件 它显示一个页面 但是当我想用不同的数据刷新页面并用新数据更新它时 我有一个选择和一个提交按钮 问题是我不知道如何在控制器中调用一个动作 我从我的树枝传递参数并调用新数据 然后我用新参数再次渲染相同
  • Python:单击按钮[重复]

    这个问题在这里已经有答案了 我在单击此按钮时遇到问题 该按钮的 HTML 代码如下所示
  • Eventbug 的实际工作原理

    Eventbug http getfirebug com wiki index php Firebug Extensions Eventbug是 Firebug 的一个附加组件 是的 附加组件的附加组件 其目的是跟踪分配给 DOM 元素的所
  • ld:架构armv7的871个重复符号,clang:错误:链接器命令失败,退出代码1(使用-v查看调用)

    我在 iPhone 应用程序中使用 FastPDFKit 来显示 PDF 当我在模拟器上运行该项目时 它工作正常 但是 当我在 iPhone 上运行该项目时 出现以下错误 duplicate symbol value map in User
  • 如何多次查询并最后关闭连接?

    我想打开与 mysql 数据库的连接并使用不同的查询检索数据 我是否需要在每次获取数据时关闭连接 或者是否有更好的方法可以多次查询并仅在最后关闭连接 目前我这样做 db dbConnect MySQL user root password
  • 我们可以导出 Kibana 中的所有搜索结果数据吗?

    我正在尝试导出 Kibana 5 中的所有搜索结果数据 但它仅导出结果的计数 有没有办法将所有数据导出为 CSV 格式 在基巴纳 到目前为止尝试过 单击搜索结果底部的符号 可视化 尝试使用 原始 和 格式化 选项 数据以 CSV 格式导出
  • symfony:如何设置不同环境的配置参数文件?

    如何为每个环境设置不同的配置参数文件 目前参数在parameters yml两者都使用dev and prod环境 但我需要不同的参数才能在产品中部署我的应用程序 您可以将所有使用的参数放入dev环境在一个app config parame
  • Postgresql计数+排序性能

    我使用 postgresql 和 psycopg2 构建了一个小型库存系统 一切都很好 除了当我想创建内容的聚合摘要 报告时 由于 count 和排序 我的性能非常糟糕 数据库架构如下 CREATE TABLE hosts id SERIA
  • 如何更新 Kubernetes 中的 api 版本列表

    我尝试在我的配置中使用 autoscaling v2beta2 apiVersion 如下本教程 https kubernetes io docs tasks run application horizontal pod autoscale
  • Perl 中的简单并行处理

    我在某个对象的函数内有一些代码块 它们可以并行运行并加快速度 我尝试使用subs parallel通过以下方式 所有这些都在函数体内 my is a done parallelize block a do some work return
  • 意外的 T_ENCAPSED_AND_WHITESPACE,期待 T_STRING 或 T_VARIABLE 或 T_NUM_STRING 错误 [重复]

    这个问题在这里已经有答案了 我对这个错误一直茫然 似乎不知道问题是什么 当我运行查询时 我收到此错误 意外的 T ENCAPSED AND WHITESPACE 需要 T STRING 或 T VARIABLE 或 T NUM STRING
  • 带 Bootstrap 的 Google 地图没有响应

    我正在使用 bootstrap 并嵌入了 Google Maps API 3 map canvas没有反应 它是固定宽度 另外 如果我使用height auto and width auto地图未显示在页面中 Why div class c
  • PyTorch - 参数不变

    为了了解 pytorch 的工作原理 我尝试对多元正态分布中的一些参数进行最大似然估计 然而 它似乎不适用于任何协方差相关的参数 所以我的问题是 为什么这段代码不起作用 import torch def make covariance ma