计算测试集每个类别的熵以测量 pytorch 上的不确定性

2024-01-11

我正在尝试使用 MC Dropout 方法和此链接中提出的解决方案来计算图像分类任务的数据集的每一类的熵,以测量 pytorch 上的不确定性
在 pytorch 上使用 MC Dropout 测量不确定性 https://stackoverflow.com/questions/63285197/measuring-uncertainty-using-mc-dropout-on-pytorch

首先,我计算了不同前向传递中每批每个类的平均值 (class_mean_batch),然后计算了所有测试加载程序 (classes_mean),然后进行了一些转换以获取 (total_mean) 以使用它来计算熵,如下面的代码所示

def mcdropout_test(batch_size,n_classes,model,T):

    #set non-dropout layers to eval mode
    model.eval()

    #set dropout layers to train mode
    enable_dropout(model)
    
    softmax = nn.Softmax(dim=1)
    classes_mean = []
       
    for images,labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        classes_mean_batch = []
            
        with torch.no_grad():
          output_list = []
          
          #getting outputs for T forward passes
          for i in range(T):
            output = model(images)
            output = softmax(output)
            output_list.append(torch.unsqueeze(output, 0))
            
        
        concat_output = torch.cat(output_list,0)
        
        # getting mean of each class per batch across multiple MCD forward passes
        for i in range (n_classes):
          mean = torch.mean(concat_output[:, : , i])
          classes_mean_batch.append(mean)
        
        # getting mean of each class for the testloader
        classes_mean.append(torch.stack(classes_mean_batch))
        

    total_mean = []
    concat_classes_mean = torch.stack(classes_mean)

    for i in range (n_classes):
      concat_classes = concat_classes_mean[: , i]
      total_mean.append(concat_classes)


    total_mean = torch.stack(total_mean)
    total_mean = np.asarray(total_mean.cpu())
 
    epsilon = sys.float_info.min
    # Calculating entropy across multiple MCD forward passes 
    entropy = (- np.sum(total_mean*np.log(total_mean + epsilon), axis=-1)).tolist()
    for i in range(n_classes):
      print(f'The uncertainty of class {i+1} is {entropy[i]:.4f}')
    
    

任何人都可以纠正或确认我用来计算每个类的熵的实现。


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

计算测试集每个类别的熵以测量 pytorch 上的不确定性 的相关文章

  • 如何平衡 GAN 中生成器和判别器的性能?

    这是我第一次使用 GAN 我面临着判别器多次优于生成器的问题 我正在尝试重现PA模型来自本文 http openaccess thecvf com content ICCV 2017 papers Sajjadi EnhanceNet Si
  • 为什么 RNN 需要两个偏置向量?

    In Pytorch RNN 实现 http pytorch org docs master nn html highlight rnn torch nn RNN 有两个偏差 b ih and b hh 为什么是这样 它与使用一种偏差有什么
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • Scikit-learn Ridge 分类器:提取类概率

    我目前正在使用 sklearn 的 Ridge 分类器 并且希望将此分类器与 sklearn 和其他库中的分类器集成 为了做到这一点 理想的做法是提取给定输入属于类列表中每个类的概率 目前 我正在使用 model decision func
  • LSTM 错误:AttributeError:“tuple”对象没有属性“dim”

    我有以下代码 import torch import torch nn as nn model nn Sequential nn LSTM 300 300 nn Linear 300 100 nn ReLU nn Linear 300 7
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • pytorch 中的 keras.layers.Masking 相当于什么?

    我有时间序列序列 我需要通过将零填充到矩阵中并在 keras 中使用 keras layers Masking 来将序列的长度固定为一个数字 我可以忽略这些填充的零以进行进一步的计算 我想知道它怎么可能在 Pytorch 中完成 要么我需要
  • 仅正样本和未标记数据集的二元半监督分类

    我的数据由评论组成 保存在文件中 其中很少被标记为正面 我想使用半监督和PU http www cs uic edu liub publications ICDM 03 pdf分类将这些评论分为正面和负面类别 我想知道 python sci
  • SPMD 与 Parfor

    我对 matlab 中的并行计算很陌生 我有一个创建分类器 SVM 的函数 我想用几个数据集来测试它 我有一个 2 核工作站 所以我想并行运行测试 有人可以向我解释一下以下之间的区别 dataset array dataset1 datas
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出
  • Pytorch Tensor 如何获取元素索引? [复制]

    这个问题在这里已经有答案了 我有 2 个名为x and list它们的定义如下 x torch tensor 3 list torch tensor 1 2 3 4 5 现在我想获取元素的索引x from list 预期输出是一个整数 2
  • 如何解释R中SVM的预测结果?

    我是 R 新手 我正在使用e1071R 中的 SVM 分类包 我使用了以下代码 data lt loadNumerical model lt svm data ncol data data ncol data gamma 10 print
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1
  • Fine-Tuning DistilBertForSequenceClassification:不是学习,为什么loss没有变化?权重没有更新?

    我对 PyTorch 和 Huggingface transformers 比较陌生 并对此尝试了 DistillBertForSequenceClassificationKaggle 数据集 https www kaggle com c
  • TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

    我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型 我一直在遵循PyTorch 实现 https github com leoxiaobin deep high resolution net pytorch blob 1ee
  • ValueError:使用火炬张量时需要解压的值太多

    对于神经网络项目 我使用 Pytorch 并使用 EMNIST 数据集 已经给出的代码加载到数据集中 train dataset dsets MNIST root data train True transform transforms T
  • Weka J48 分类器:无法处理数字类?

    我现在尝试使用 Weka 在我的训练数据上构建 J48 C4 5 分类器模型 首先我这样做 这似乎很顺利 java Xmx10G cp weka weka jar weka core converters TextDirectoryLoad
  • 在requirements.txt中包含.whl安装

    如何将其包含在requirements txt 文件中 对于Linux pip install http download pytorch org whl cu75 torch 0 1 12 post2 cp27 none linux x8
  • Weka - 探索者和实验者结果之间的差异

    我只是想知道为什么正确分类的百分比与 Weka 的探索者和实验者方面不同 我已检查以确保使用 10 交叉折叠验证以及所有其他参数 有人有主意吗 Thanks 当我在 Weka 邮件列表上给马克 霍尔 Mark Hall 发送电子邮件时 我已

随机推荐

  • ASP.NET MVC 中部分视图的正确位置是什么?

    有人会确认 ASP NET MVC 中部分视图的最佳位置吗 我的想法是 如果这是一个将在许多地方使用的全球视图 那么就可以共享 如果它是视图的一部分 并被包装到部分视图中以使代码阅读更容易 那么它应该进入 Views Controller
  • 理解从先序遍历构造树的伪代码

    我需要做一些类似于这个问题中描述的任务 根据给定的前序遍历构造树 https stackoverflow com questions 4908545 construct tree with pre order traversal given
  • 如何使用 WebGL 和 GLSL 在 J/s 文件中运行 Shadertoy 中的着色器?

    我是着色器编程新手 我想使用 WebGL 和 GLSL 创建一个着色器 为了了解它的实际工作原理 我想测试 Shadertoy 的着色器 但是如何从 Shadertoy 获取代码并实际在 J S 文件中运行它呢 您是否只需将 Shadert
  • 以编程方式从“p”和“q”生成“d”(RSA)

    我有两个号码 p and q 我知道我能得到phi p 1 q 1 然后ed 1 mod phi 但我不确定我明白这意味着什么 我写了一些Python p NUM q NUM e NUM phi p 1 q 1 d 1 phi float
  • 回显所有 json_encoded 行

    我正在尝试循环访问数据库并输出与连接表匹配的所有行 我有以下两个表 任务项目存储与项目相关的所有数据 加入任务项存储玩家 ID 和玩家拥有的物品之间的关联 JS 传入查询表所需的所有信息 getJSON phpscripts php pla
  • 尝试使用 Protocol Buffers - Google 的数据交换格式时,goog 未定义错误

    我正在尝试使用 Protocol Buffers Google 的数据交换格式https github com google protobuf tree master js https github com google protobuf
  • plpgsql For循环中的Select语句创建多个CSV文件

    我想重复以下查询 8760 次 将一年中每个小时的 2 替换为 1 到 8760 我们的想法是每小时创建一个单独的 CSV 文件以进行进一步处理 COPY SELECT FROM public completedsolarirad2012
  • ZF2 toRoute 与 https

    我们正在使用 Zend Framework 2 并使用toRoute在我们的控制器中重定向到不同的位置 例如 this gt redirect gt toRoute home 无论如何 是否可以使用此方法或替代方法将其重定向到 https
  • 如何嵌入文件以供以后解析执行使用

    我本质上是想浏览一个 html 文件的文件夹 我想将它们嵌入到二进制文件中 并能够根据请求解析它们以用于模板执行目的 如果我措辞不当 请原谅 任何想法 提示 技巧或更好的方法来实现这一点都非常感谢 Template Files type T
  • Base64 java 中的文件编码失败

    我有这个类来编码和解码文件 当我使用 txt 文件运行该类时 结果成功 但是 当我使用 jpg 或 doc 运行代码时 我无法打开该文件 或者它不等于原始文件 我不知道为什么会发生这种情况 我修改了这个类http myjeeva com c
  • 在 Node 中通过“_id”搜索 MongoDB 条目的正确方法

    我在用着MongoDb 作为 的一部分MongoJS in Node 这是 MongoJS 的文档 https github com gett mongojs 我正在尝试根据条目在 Node 内进行调用 id场地 使用香草时MongoDB从
  • 如何改变gvim中的左边距

    我在 XP 上有 gvim 7 3 我的问题是 当我编辑文件并关闭行号时 文本距离左窗口边距太近 我不想添加前导空白 我想增加边距 当我有行号时 我不喜欢 左窗口边框和行号之间有足够的空间 行号和文本之间有足够的空间 但是当行号关闭时就没有
  • 如何获取隐藏数据库的数据库模式?

    我的客户是一家牙科诊所 购买了一款诊所管理软件 该软件安装在他们的本地服务器上 包括患者数据库 时间表和各种医疗记录 现在他们希望我为他们编写一些他们的软件包中未提供的实用程序 为此我需要能够查询该数据库 我尝试致电软件制造商的技术支持 帕
  • Azure AD - 仅应用程序令牌中缺少角色声明

    当我尝试从 Nodejs 后端服务器获取仅应用程序令牌时 如下所述here https learn microsoft com en us graph auth v2 service 4 get an access token 有时role
  • 如何在 Vim 中创建文件夹(优先使用 NERDTree)?

    我知道如何创建重命名 删除和移动文件NERDTree 只需按m then either a d or m 但我不知道如何创建文件夹 有谁知道如何做到这一点NERDTree 或者只是以 vim 的原生 方式 You use m a并放置一个尾
  • ##+#. 是什么意思?是什么意思?

    谷歌几乎是不可能的 因此我的理解仅限于阅读 slime 源代码的上下文线索 也许它是 common lisp 中对象系统的一部分 类似 自己 的东西 片段 cond swank backend sbcl with new stepper p
  • 基于列子集修剪 NA - 更优雅的解决方案?

    stackoverflow 社区的新年难题 通过阅读过去的帖子和答案很有帮助 这是我的第一个问题 我找到了解决方法 但我想知道是否可以建议其他方法 解决方案 我正在尝试从大型文件中删除尾随的 NAdata frame 但这些 NA 只出现在
  • jQuery UI DatePicker - 禁用除每月最后一天之外的所有日期

    我正在尝试使用 jquery UI 日期选择器来显示仅可选择该月最后一天的日历 我已成功使用 beforeShowDay 事件禁用一周中的几天 但不确定如何使用它来禁用除该月最后一天之外的所有内容 beforeShowDay 会为日历上显示
  • Android - 仅垂直布局

    如何确保我的应用程序仅适用于垂直布局 我努力了android screenOrientation portrait 但这似乎并不能解决问题 您需要添加到所有活动中 而不仅仅是一项活动 我认为您了解设置是每个应用程序范围内的 但事实并非如此
  • 计算测试集每个类别的熵以测量 pytorch 上的不确定性

    我正在尝试使用 MC Dropout 方法和此链接中提出的解决方案来计算图像分类任务的数据集的每一类的熵 以测量 pytorch 上的不确定性 在 pytorch 上使用 MC Dropout 测量不确定性 https stackoverf