在 Pytorch 中估计高斯模型的混合

2024-05-09

我实际上想估计一个以高斯混合作为基本分布的归一化流,所以我有点被火炬困住了。但是,您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误。我的代码如下:

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim
import torch.distributions as D

num_layers = 8
weights = torch.ones(8,requires_grad=True).to(device)
means = torch.tensor(np.random.randn(8,2),requires_grad=True).to(device)#torch.randn(8,2,requires_grad=True).to(device)
stdevs = torch.tensor(np.abs(np.random.randn(8,2)),requires_grad=True).to(device)
mix = D.Categorical(weights)
comp = D.Independent(D.Normal(means,stdevs), 1)
gmm = D.MixtureSameFamily(mix, comp)

num_iter = 10001#30001
num_iter2 = 200001
loss_max1 = 100
for i in range(num_iter):
    x = torch.randn(5000,2)#this can be an arbitrary x samples
    loss2 = -gmm.log_prob(x).mean()#-densityflow.log_prob(inputs=x).mean()
    optimizer1.zero_grad()
    loss2.backward()
    optimizer1.step()

我得到的错误是:

0
8.089411823514835
Traceback (most recent call last):

  File "/home/cameron/AnacondaProjects/gmm.py", line 183, in <module>
    loss2.backward()

  File "/home/cameron/anaconda3/envs/torch/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)

  File "/home/cameron/anaconda3/envs/torch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

正如您所看到的,模型运行了 1 次迭代后。


您的代码中存在排序问题,因为您在训练循环之外创建高斯混合模型,那么在计算损失时,高斯混合模型将尝试使用您在定义模型时设置的参数的初始值,但是optimizer1.step()已经修改了该值,所以即使您设置了loss2.backward(retain_graph=True)还是会出现这样的错误:RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

解决这个问题的方法很简单,只要更新参数就创建新的高斯混合模型,示例代码按预期运行:

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim
import torch.distributions as D

num_layers = 8
weights = torch.ones(8,requires_grad=True)
means = torch.tensor(np.random.randn(8,2),requires_grad=True)
stdevs = torch.tensor(np.abs(np.random.randn(8,2)),requires_grad=True)

parameters = [weights, means, stdevs]
optimizer1 = optim.SGD(parameters, lr=0.001, momentum=0.9)

num_iter = 10001
for i in range(num_iter):
    mix = D.Categorical(weights)
    comp = D.Independent(D.Normal(means,stdevs), 1)
    gmm = D.MixtureSameFamily(mix, comp)

    optimizer1.zero_grad()
    x = torch.randn(5000,2)#this can be an arbitrary x samples
    loss2 = -gmm.log_prob(x).mean()#-densityflow.log_prob(inputs=x).mean()
    loss2.backward()
    optimizer1.step()

    print(i, loss2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Pytorch 中估计高斯模型的混合 的相关文章

随机推荐

  • 如何根据 Kubernetes / Docker 事件发送警报?

    是否可以根据 Kubernetes 集群内发生的事件以某种方式发送警报 到电子邮件 slack 特别是 如果 Pod 意外重新启动或 Pod 无法启动 那么获取警报将非常有用 同样 了解 Pod 的 CPU 使用率是否超过特定阈值并获取警报
  • 在 r 中的 unique() 函数中使用管道不起作用

    我在使用管道运算符 gt 和 unique 函数时遇到一些麻烦 df data frame a c 1 2 3 1 b a unique df a no problem here df gt unique a not working her
  • Javascript:自动点击按钮?

    我正在学习如何编写 chrome 扩展 而且我对 javascript 还很陌生 这是一些 html div class button data a class button1 whiteColor href http link1 com
  • 在 Visual Studio 2017 中变量模板中的除法返回零

    这大概是一个视觉工作室2017 questions tagged visual studio 2017与此问题相关的错误 Visual Studio 中 Lambda 的模板变量错误 https stackoverflow com q 49
  • p:remoteCommand 无法在异步模式下工作

    如果有人可以在这里给我帮助 我将不胜感激 我在页面上有一个选项卡式布局 通过单击选项卡 p commandLink 我想初始化该选项卡的适当数据并更新显示内容的区域 由于我希望初始化能够延迟发生 当呈现选项卡内容时 因此我使用 Primef
  • Open XML SDK:尝试填充超过 25 列时出现“不可读内容”错误

    我使用 C 中的 Open XML SDK 创建了一个电子表格 并成功填充了两个工作表 当尝试填充第三个时 我得到了 内容不可读 打开已完成的文档时出错 并且当我尝试在第三个文档中连续填充超过 25 个单元格时 似乎会发生此错误 我使用的代
  • Jqueryui:如何在对话框周围制作阴影?

    我正在尝试在 jqueryui 对话框周围放置阴影 就像是 div class ui widget shadow ui corner all Some stuff in the box with a shadow around it div
  • 在Java中将浮点数组写入文件

    我正在读取 NetCDF 文件 我想将每个数组作为浮点数组读取 然后将浮点数组写入新文件 如果我读取浮点数组 然后迭代数组中的每个元素 使用 DataOutputStream 我可以使其工作 但这非常非常慢 我的 NetCDF 文件超过 1
  • 从 pandas udf 记录

    我正在尝试从 python 转换中调用的 pandas udf 进行日志记录 因为在执行器上调用的代码不会显示在驱动程序的日志中 我一直在寻找一些选项 但到目前为止最接近的选项是这个one https stackoverflow com q
  • 如何从 Perl 中的字符串中去除无效的 XML 字符?

    我正在寻找一种标准的 经过批准的 可靠的方法 可以在将字符串写入 XML 文件之前从字符串中删除无效字符 我在这里讨论的是包含退格键 H 和换页符等的文本块 There has成为执行此操作的标准库 模块函数 但我找不到它 我在用着XML
  • PHP7 返回类型为 JSON

    PHP 7 有一个新功能 即返回类型声明 我们可以返回一个 字符串 类型 例如 function myFunction a string 我们还可以返回一个 数组 类型 例如 function myFunction a array 但是我们
  • 带有 React 的 Google Analytics 无法正常工作

    我在我的反应项目中使用谷歌分析 即使我在线 它也不会显示任何活跃用户 我尝试过在网上找到的不同方法 但似乎都不起作用 我只在本地主机上尝试过 而不是在已部署的网站上尝试过 但我认为它应该仍然有效 这是我的代码 我的应用程序 js impor
  • php版本升级到8后,出现此错误

    我正在将 php 7 升级到 php 8 0 在以前的 php 版本 7 中 这段代码工作正常 child parent parent resultData gt parent id gt child Yes 上面的代码在 php 7 中工
  • PostgreSQL 中的逆透视表

    我有下表作为 SUM Case End 的结果 Account Product A Product B Product C 101 1000 2000 3000 102 2000 1000 0 103 2000 1000 0 104 200
  • .Net 将 NULL 值从变量值插入 SQL Server 数据库

    也有类似的问题 但答案不是我想要的 如果引用为 NULL 或尚未分配值 我想将 NULL 值插入 SQL Server 数据库 目前我正在测试 null 它看起来像 String testString null if testString
  • 如何使用 System.out.println 以十六进制打印字节?

    我已经声明了一个字节数组 我使用的是 Java byte test new byte 3 test 0 0x0A test 1 0xFF test 2 0x01 如何打印数组中存储的不同值 如果我使用 System out println
  • struts2 date无法通过jquery datetimepicker获取时间

    我是struts2的新手 创建了一个小型Web应用程序 我想要一个帖子是计时器 我选择jquery datetimpicker 在用户选择时间和日期后 它将显示用户选择的时间和日期 我用这个jquery http www javascrip
  • Maven 配置文件 - 如何为父级运行插件一次,为模块运行多次?

    我对詹金斯的输出有点困惑 Jenkins 上的工作 底部缩短了 pom xml mvn deploy Pprofile1 我的所有插件都会运行 4 次 父 pom xml 父 module1 pom xml 父 module2 pom xm
  • 我应该等待 Flash Player 10.1 还是使用 Flash Lite 3 来为手机和设备开发 Flash 内容

    Adobe 将在 2010 年第一季度推出 Flash Player 10 1 这将在桌面和移动设备上提供一致的运行时 因此我假设如果它是为 Web 构建的 那么它也可以在移动设备上运行 我即将开始为手机开发基于 Flash 的应用程序 我
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p