在 pytorch 上使用 MC Dropout 测量不确定性

2023-11-25

我正在尝试在 Pytorch 上使用 Mc Dropout 实现贝叶斯 CNN, 主要思想是,通过在测试时应用 dropout 并运行多次前向传递,您可以从各种不同的模型中获得预测。 我发现了 Mc Dropout 的应用,但我真的不明白他们是如何应用这种方法的,以及他们到底是如何从预测列表中选择正确的预测的

这是代码


 def mcdropout_test(model):
    model.train()
    test_loss = 0
    correct = 0
    T = 100
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output_list = []
        for i in xrange(T):
            output_list.append(torch.unsqueeze(model(data), 0))
        output_mean = torch.cat(output_list, 0).mean(0)
        test_loss += F.nll_loss(F.log_softmax(output_mean), target, size_average=False).data[0]  # sum up batch loss
        pred = output_mean.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nMC Dropout Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


    train()
    mcdropout_test()

我已经更换了

data, target = Variable(data, volatile=True), Variable(target)

通过增加

with torch.no_grad():一开始

这就是我定义 CNN 的方式

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 192, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(192, 192, 5, padding=2)
        self.fc1 = nn.Linear(192 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(p=0.3)
        
        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.constant_(self.conv1.bias, 0.0)
        nn.init.xavier_uniform_(self.conv2.weight)
        nn.init.constant_(self.conv2.bias, 0.0)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0.0)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0.0)
        nn.init.xavier_uniform_(self.fc3.weight)
        nn.init.constant_(self.fc3.bias, 0.0)


    def forward(self, x):
        x = self.pool(F.relu(self.dropout(self.conv1(x))))  # recommended to add the relu
        x = self.pool(F.relu(self.dropout(self.conv2(x))))  # recommended to add the relu
        x = x.view(-1, 192 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(self.dropout(x)))
        x = self.fc3(self.dropout(x))  # no activation function needed for the last layer
        return x

谁能帮助我在 CNN 上正确实现 Monte Carlo Dropout 方法?


在 Pytorch 中实现 MC Dropout 很容易。所需要做的就是将模型的 dropout 层设置为训练模式。这允许在不同的前向传递期间使用不同的漏失掩模。下面是 Pytorch 中 MC Dropout 的实现,说明了如何将来自各种前向传递的多个预测堆叠在一起并用于计算不同的不确定性度量。

import sys

import numpy as np

import torch
import torch.nn as nn

def enable_dropout(model):
    """ Function to enable the dropout layers during test-time """
    for m in model.modules():
        if m.__class__.__name__.startswith('Dropout'):
            m.train()

def get_monte_carlo_predictions(data_loader,
                                forward_passes,
                                model,
                                n_classes,
                                n_samples):
    """ Function to get the monte-carlo samples and uncertainty estimates
    through multiple forward passes

    Parameters
    ----------
    data_loader : object
        data loader object from the data loader module
    forward_passes : int
        number of monte-carlo samples/forward passes
    model : object
        keras model
    n_classes : int
        number of classes in the dataset
    n_samples : int
        number of samples in the test set
    """

    dropout_predictions = np.empty((0, n_samples, n_classes))
    softmax = nn.Softmax(dim=1)
    for i in range(forward_passes):
        predictions = np.empty((0, n_classes))
        model.eval()
        enable_dropout(model)
        for i, (image, label) in enumerate(data_loader):
            image = image.to(torch.device('cuda'))
            with torch.no_grad():
                output = model(image)
                output = softmax(output)  # shape (n_samples, n_classes)
            predictions = np.vstack((predictions, output.cpu().numpy()))

        dropout_predictions = np.vstack((dropout_predictions,
                                         predictions[np.newaxis, :, :]))
        # dropout predictions - shape (forward_passes, n_samples, n_classes)

    # Calculating mean across multiple MCD forward passes 
    mean = np.mean(dropout_predictions, axis=0)  # shape (n_samples, n_classes)

    # Calculating variance across multiple MCD forward passes 
    variance = np.var(dropout_predictions, axis=0)  # shape (n_samples, n_classes)

    epsilon = sys.float_info.min
    # Calculating entropy across multiple MCD forward passes 
    entropy = -np.sum(mean * np.log(mean + epsilon), axis=-1)  # shape (n_samples,)

    # Calculating mutual information across multiple MCD forward passes 
    mutual_info = entropy - np.mean(np.sum(-dropout_predictions * np.log(dropout_predictions + epsilon),
                                           axis=-1), axis=0)  # shape (n_samples,)

继续讨论上面问题中发布的实现,通过首先将模型设置为训练模式(model.train())。请注意,这是不可取的,因为如果模型中存在除 dropout 之外的层(例如批量标准化),预测中将会引入不需要的随机性。因此,最好的方法是将 dropout 层设置为训练模式,如上面的代码片段所示。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 pytorch 上使用 MC Dropout 测量不确定性 的相关文章

随机推荐

  • SceneKit - SCNText 居中不正确

    我尝试在下面的代码中将文本字符串 SCNText 放入框 SCNBox 中 框的大小看起来正确 但文本不在框的右中心 有什么想法或解决方案吗 谢谢 let geoText SCNText string Hello extrusionDept
  • 如何继承泛型虚方法?

    我有以下代码 我想重写基础列表的 Notify 方法 以便能够在列表的修改上创建事件 TDescendantList class TObjectList
  • 为 OAuth 创建签名和随机数 (Ruby)

    我希望从我的应用程序访问 SmugMug 的 API 以获取用户的相册和图像 用户已通过 ruby 的 OmniAuth 进行身份验证 根据SmugMug 的 OAuth API OAuth需要六个参数 我可以使用 OmniAuth 获取令
  • 添加最少数量的字符来形成回文

    问题 给定任何字符串 添加尽可能少的字符 使其在线性时间内成为回文 I m only able to come up with a O N2 solution 有人可以帮我解决 O N 问题吗 恢复字符串 使用修改过的高德莫里斯普拉特查找最
  • 零长度正则表达式和无限匹配?

    在试图详细阐述答案时this问题 我现在正在尝试了解零长度正则表达式的行为 含义 我经常使用 www regexr com 作为测试 调试 了解正则表达式中发生的情况的游乐场 所以我们有这个最平庸的场景 正则表达式是a 输入字符串是dgwa
  • 如何将动态对象序列化为xml C#

    我有一个object System Collections Generic List 其中包含 1000object DynamicData 在它的内部 每个都有 4 个键和值 还有一个List里面有 2 个键和值 我需要将此对象序列化为
  • 为什么使用匿名类型可以工作,而使用显式类型却不能在 GroupBy 中使用?

    我有一个问题 我希望组类型是强类型的 但如果我这样做 它就不能正确分组 请参阅下面的代码 using System using System Collections Generic using System Linq namespace C
  • Rails & Devise:设计未显示在 Rails 控制台中的特定列

    我正在尝试在我的用户模型上使用 Devise 但是当我进入 Rails 控制台并尝试时User new我只得到 irb main 002 0 gt User new gt
  • 反思获取代表信息

    通过执行以下命令我可以获得有关方法的信息 Type t typeof someType MemberInfo mInfo t GetMethods 如何获取有关类型内声明的委托的信息 Call Type GetNestedTypes获取嵌套
  • 升级 haskell 堆栈使用的 ghc 版本

    我正在尝试将我为 haskell stack 安装的 ghc 版本从版本 8 0 2 更新到 8 2 1 但似乎我丢失了一些东西 user localhost stack resolver ghc 8 2 1 setup stack wil
  • 计算矩阵中一点与所有其他点之间的距离

    我是Python新手 我需要实现一个聚类算法 为此 我需要计算给定输入数据之间的距离 考虑以下输入数据 1 2 8 7 4 2 9 1 7 0 1 5 6 4 3 我希望在这里实现的是 我想计算 1 2 8 与所有其他点的距离 并找到距离最
  • 线程安全对象 - 静态还是非静态?

    我最近在接受采访 技术人员问我如何使应用程序线程安全 嗯 解释完之后lock 正确的是 他说让物体保持静态并不是一个好主意 private static readonly object syncLock new object 他声称原因是静
  • 使用 HTML5 数据属性的 CSS 值[重复]

    这个问题在这里已经有答案了 width attr data width 我想知道是否有任何方法可以使用 HTML5 设置 css 值data 属性的设置方式与设置 css 的方式相同content 目前它不起作用 HTML div div
  • JAXB 无法生成 XBRL 的 Java 类

    我正在尝试为 XBRL 中定义的类型生成 Java 类 我的构建过程基于 Maven 2 以下是我的试验 我只粘贴build部分 它依赖于一些属性 package是我的目标包的名称 catalog是目录的路径和文件名 因为我没有互联网连接
  • 以编程方式设置网页的默认缩放?

    是否可以在网站上设置默认缩放级别 例如 我可以编码吗my site比如当用户打开它时它会缩放到 125 我的网站主体有这个代码 如何把这个缩放代码放进去 Add zoom 125 到身体风格 body color 536482 backgr
  • 为什么人们将他们的文件命名为index.html?

    我看到很多人在他们的 HTML 文件中使用这个文件名 我想知道为什么 我对 HTML 有点陌生 我还没有学到太多东西 但是当我命名我的 HTML 文件时 我可以随意命名它们 当我搜索 HTML 示例时 我发现它们将其命名为index htm
  • paypal自适应支付IPN中的自定义字段

    我在我的网站中实施了自适应支付 首先 这是标准付款 所以我通过了custom表单中的参数 我在 IPN 中获取它 但我无法找到如何在自适应支付中传递此参数 Thanks Edit 根据 jackvsworld 在 PayPal Adapti
  • 亚音速快死了吗

    我对使用 SubSonic 很感兴趣 我已经下载了它并且到目前为止我很喜欢它 但是看看 github 和 googlegroups 上的活动 它似乎不是很活跃 看起来很像一个即将消亡的项目 tekpub 上没有关于它的视频 而且 Rob 这
  • 在 Kivy 中显示 numpy 数组

    首先 我对 kivy 完全陌生 所以我有点挣扎 我正在尝试在 kivy 窗口中显示 numpy 数组 到目前为止 我发现这应该使用纹理类 http kivy org docs api kivy graphics texture html 由
  • 在 pytorch 上使用 MC Dropout 测量不确定性

    我正在尝试在 Pytorch 上使用 Mc Dropout 实现贝叶斯 CNN 主要思想是 通过在测试时应用 dropout 并运行多次前向传递 您可以从各种不同的模型中获得预测 我发现了 Mc Dropout 的应用 但我真的不明白他们是