可以估计不确定性的神经网络:SDE-Net

2023-10-27

作者丨段易通@知乎

来源丨https://zhuanlan.zhihu.com/p/234834189

编辑丨极市平台

随着深度学习技术的不断发展,DNN模型的预测能力变得越来越强,然而在一些情况下这却并不是我们想要的,比如说给模型一个与训练集完全不相关的测试样本,我们希望模型能够承认自己的“无知”,而不是强行给出一个预测结果,这种能力对于自动驾驶或者医疗诊断等重视风险的任务是至关重要的。因此,为了达到这个目的,我们的模型需要具有量化不确定性的能力,对于那些它没有把握的样本,模型应该给出较高的不确定性,这样就能指导我们更好地利用模型的预测结果。

之前我介绍过可以预测概率分布的DeepAR模型,其实这次介绍的SDE-Net与它的目标是一致的,都是令模型在预测的基础上还能够度量预测结果的不确定性,不过SDE-Net的实现这个目标的思路与DeepAR不同,下面就来具体介绍。

段易通:概率自回归预测——DeepAR模型浅析

https://zhuanlan.zhihu.com/p/201030350

不确定性

上文中已经提到,我们的目的是要量化不确定性,那么我们当然要先知道是什么导致了模型的不确定性、并且要了解不确定性产生的来源有哪些,论文中认为模型预测的不确定性来自于两个方面:

  • aleatoric uncertainty:来自于任务本身所固有的自然随机性(比如说label噪声等)

  • epistemic uncertainty:由于缺乏训练数据所导致的,模型对于训练数据分布之外的样本是无知的

对于aleatoric uncertainty,它是由任务本身天然决定的,可以设想一个所有标签都是噪声的训练集,用这样的数据集训练出来的模型,它的预测结果显然是不可信的,即不确定度很大;而对于epistemic uncertainty,它是由于模型的认知不足造成的,在面对训练集分布之外的数据时,模型的预测结果会具有较高的不确定度。

下图是对两种不确定度对模型预测结果影响的示意图,我们这里用到的是概率模型(比如说DeepAR或者后面要说的SDE-Net,输出的是一个随机变量而不是一个定值,因此通过模型得到的其实是一个概率分布)。其中左边的simplex corner代表分类任务(三分类),右边的二维坐标代表回归任务,其中横轴代表predictive mean 、纵轴是predictive variance (不太明白这里 的含义,求解惑~)。

SDE量化不确定性

我们知道,神经网络尤其是ResNet可以看做是由常微分方程(ODE)控制的一个动力学系统(具体可以看ResNet的相关资料,或者我的这篇文章,https://zhuanlan.zhihu.com/p/92254686),相邻层之间的输入输出关系为:

其中 是第t层的隐藏状态,如果我们令 ,上式可以写做: ,如果我们让 ,那么就有:

ResNet其实可以看做是离散化的动力学系统,不过控制方程是一个ODE,所以神经网络得到是只是一个确定性结果。为了让模型可以估计不确定性,我们自然就想到是不是可以改用一个随机微分方程(SDE)来控制dynamic呢?这就是论文的核心思想,其实也很简单,就是在原有ODE的基础上再加上一个随机项,这里采用的是标准布朗运动,那么dynamic的形式就变为:

这里 就是标准布朗运动,可以看出函数 控制的就是dynamic的波动,我们就用它来代表模型对于epistemic uncertainty的估计,下图是 的大小对于dynamic的影响。

模型构造

这样一来,我们就利用SDE来描述了隐层状态的dynamic,并通过随机过程的方差来量化估计epistemic uncertainty。为了使模型具有良好的预测精度和可靠的不确定性估计,论文的SDE-Net模型用了两个单独的神经网络来表示分别dynamic的漂移扩散,如下图所示:

可以看出,对于分布内的测试样本,diffusion net计算出的不确定度很小,因此drift net占主导地位,我们可以获得置信度很高的预测结果;但是对于分布外的样本,计算出的不确定度很大,因此diffusion net占主导地位,得到的结果几乎就是随机分布的结果。

对于SDE-Net的两个神经网络,论文采用了如下的目标函数来进行训练

其中 是任务的损失函数,T是随机过程的末时刻(即网络的输出层), 分别是训练集分布内与分布外(OOD)的数据,OOD数据可以通过给原数据加噪声做变换的方式获得,也可以直接用另一个任务目标不相关的数据集。

可以看出,目标函数一共分为三部分,前两项是关于分布内样本的目标函数,其目的是保证在常规的损失最小化的基础上,还要使得这些样本的不确定度估计较小,后一项是关于OOD数据的,对于这些样本,我们不关注其loss的大小,而是只令模型对于OOD样本的不确定度增加

需要注意的是,这里SDE-Net中每一层的参数都是共享的,而且扩散项的方差仅由起始点x0决定,这样可以使模型训练起来更容易。

训练好模型之后,我们可以通过多次采样的方法,来得到多个输出 ,这种采样计算的思路与传统的集成方法具有相似之处,但是传统方法需要训练多个模型,而SDE-Net只需训练一次即可通过布朗运动的随机性得到多个输出样本,从而大大减小了训练成本。

理论分析

论文还对模型做了一些理论分析,内容不多,就直接放原文了

模型训练

考虑到模型的层数是有限的,因此我们需要将SDE离散化,形式如下:

其中时间区间为 ,模型一共有N层,因此

总的来看,SDE-Net的训练算法如下:

简单概括一下这个算法,首先我们从分布内采样出一批训练数据,然后通过一个降采样层得到输入 ,接着就根据SDE-Net来控制隐层状态的dynamic,并在最后接一个全连接层得到模型的输出 ,这样我们就可以通过计算loss的梯度来更新降采样层、drift net以及全连接层的参数;另外,我们还要从分布外采样出一批数据(OOD数据),然后根据分布内外的数据分别对diffusion net的参数做梯度下降和梯度上升。

实验

论文的实验研究了不确定性估计在model robustness和label efficiency中的作用,实验采用的对比模型有:Threshold、MC-dropout、DeepEnsemble、Prior network(PN)、Bayes by Backpropagation (BBP)、preconditioned Stochastic gradient Langevin dynamics(p-SGLD);其中PN和SDE-Net需要额外的OOD数据,这里通过对原有的数据样本上加上高斯噪声来进行构造、或者直接采用另一个数据集,至于其它的一些具体设定可以看论文的实验和补充材料部分。

1.OOD检测

就像在文章开头提到的,让模型有“自知之明”是非常重要的,因此第一个任务就是评估模型识别OOD样本的能力,实验中使用的metric如下所示,这些metrics都是值越大越好:

  1. True negative rate (TNR) at 95% true positive rate (TPR)

  2. Area under the receiver operating characteristic curve (AUROC)

  3. Area under the precision-recall curve (AUPR)

  4. Detection accuracy

实验结果如下所示:

分类任务

回归任务

从表中可以看出,SDE-Net的性能基本超越了其它所有模型。另外,下图是提高模型层数或者ensemble数量对OOD检测的影响,可以看出SDE-Net不需要像一些其它模型那样必须大量堆叠才能达到最优性能。

2.误分类检测

如果模型预测的不确定性很大,那么就说明模型对预测结果是没有把握的,样本可能被分类错误。因此这个任务的目的是利用预测的不确定性来找出模型分类错误的样本,其结果如下:

虽然P-SGLD的效果也不错,不过它的计算成本很高,因此在实际情况中SDE-Net可能会是一个更好的选择。

3.对抗样本检测

我们知道,在样本中加入一些很小的对抗扰动后,正常的DNNs会变得非常容易出错,因此这个任务的目标就是从样本集中找出对抗样本,这里采用了两种对抗攻击方式Fast Gradient-Sign Method (FGSM)和Projected Gradient Descent (PGD)来产生对抗样本,实验结果如下:

4.主动学习

假设一开始样本集里有标注的样本很少,模型需要自己挑一些信息量大的样本出来让专家进行标注,这就是主动学习的思想。直观上来看,挑选信息量大的样本可以显著减少用于模型训练的数据量,而信息量小的样本会增加训练成本、甚至会导致过拟合。最后一个任务就是关于主动学习的,论文设定acquisition function(不了解的同学可以学习一下相关知识点)的形式为:

该式的意思就是让模型选择那些具有较高的epistemic uncertainty但数据具有较低的low aleatoric noise的样本,结果如下,可以看出SDE-Net选择的样本都是信息量比较大样本,因此RMSE下降的更快。

总结

ResNet可以对应为一个离散ODE,这篇文章受到该思路的启发,构建了一个可以被看做离散SDE的SDE-Net模型,模型由两个神经网络drift net和diffusion net构成,其中drift net与传统模型类似,是为了预测模型的输出结果,而diffusion net则用来估计预测的不确定性,估计出的不确定性可以应用于OOD样本检测、误分类检测、主动学习等多个任务,而可以估计不确定性的SDE-Net也更加适合于一些关注风险的实际应用领域。

个人感觉,论文中通过SDE来评估不确定性的想法很有意思,确实有一定的可取之处;不过模型为了训练diffusion net网络,专门构建了用于梯度上升的OOD数据集,这样的数据集无论怎么构建,都很难代表整个训练集以外的分布,因此不可避免地会引入一些bias,而这就可能会影响模型对于不确定度的估计。

参考文献

[1] SDE-Net: Equipping Deep Neural Networks with Uncertainty Estimates

https://arxiv.org/abs/2008.10546

觉得有用麻烦给个在看啦~  

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

可以估计不确定性的神经网络:SDE-Net 的相关文章

  • Matplotlib 标准化颜色条 (Python)

    我正在尝试使用 matplotlib 当然还有 numpy 绘制轮廓图 它有效 它绘制了它应该绘制的内容 但不幸的是我无法设置颜色条范围 问题是我有很多图 并且需要所有图都具有相同的颜色条 相同的最小值和最大值 相同的颜色 我复制并粘贴了在
  • 如何屏蔽 PyTorch 权重参数中的权重?

    我正在尝试在 PyTorch 中屏蔽 强制为零 特定权重值 我试图掩盖的权重是这样定义的def init class LSTM MASK nn Module def init self options inp dim super LSTM
  • 打印 scrapy 请求的“响应”

    我正在尝试学习 scrapy 在遵循教程的同时 我正在尝试进行细微的调整 我想简单地从请求中获取响应内容 然后我会将响应传递到教程代码中 但我无法发出请求并获取响应内容 建议就好 from scrapy http import Respon
  • 为什么我不能导入 geopandas?

    我唯一的代码行是 import geopandas 它给了我错误 OSError Could not find libspatialindex c library file 以前有人遇到过这个吗 我的脚本运行得很好 直到出现此错误 请注意
  • 如何在 Ubuntu 上安装 Python 模块

    我刚刚用Python写了一个函数 然后 我想将其做成模块并安装在我的 Ubuntu 11 04 上 这就是我所做的 创建 setup py 和 function py 文件 使用 Python2 7 setup py sdist 构建分发文
  • 如何更改充当按钮的范围的文本

    我正在为自定义 Web 应用程序编写自动化测试 我遇到了无法更改跨度文本的问题 我尝试过使用 driver execute script 但没有运气 如果我更好地了解 javascript 这确实会有帮助 据我所知 您无法单击跨度 并且列表
  • 如何用 python 和 sympy 解决多元不等式?

    我对使用 python 和 Sympy 还很陌生 并且遇到了使用 sympy 解决多元不等式的问题 假设我的文件中有很多函数 如下所示 cst sqrt x 2 cst exp sqrt cst x 1 4 log log sqrt cst
  • Dask DataFrame 的逐行处理

    我需要处理一个大文件并更改一些值 我想做这样的事情 for index row in dataFrame iterrows foo doSomeStuffWith row lol doOtherStuffWith row dataFrame
  • 类属性在功能上依赖于其他类属性

    我正在尝试使用静态类属性来定义另一个静态类属性 我认为可以通过以下代码来实现 f lambda s s 1 class A foo foo bar f A foo 然而 这导致NameError name A is not defined
  • 当x轴不连续时如何删除冗余日期时间 pandas DatetimeIndex

    我想绘制一个 pandas 系列 其索引是无数的 DatatimeIndex 我的代码如下 import matplotlib dates as mdates index pd DatetimeIndex 2000 01 01 00 00
  • VSCode pytest 测试发现失败

    Pytest 测试发现失败 用户界面指出 Test discovery error please check the configuration settings for the tests 输出窗口显示 Test Discovery fa
  • 如何在 Windows 上使用 Python 3.6 来安装 Python 2.7

    我想问一下如何使用pip install对于 Python 2 7 当我之前安装并使用 Python 3 6 时 我现在必须使用 Windows 上的 Python 版本 pip install 继续安装 Python 3 6 我需要使用以
  • 反加入熊猫

    我有两个表 我想附加它们 以便仅保留表 A 中的所有数据 并且仅在其键唯一时添加表 B 中的数据 键值在表 A 和 B 中是唯一的 但在某些情况下键将出现在表 A 和 B 中 我认为执行此操作的方法将涉及某种过滤联接 反联接 以获取表 B
  • Python While 循环,and (&) 运算符不起作用

    我正在努力寻找最大公因数 我写了一个糟糕的 运算密集型 算法 它将较低的值减一 使用 检查它是否均匀地划分了分子和分母 如果是 则退出程序 但是 我的 while 循环没有使用 and 运算符 因此一旦分子可整除 它就会停止 即使它不是正确
  • 我可以使用 dask 创建 multivariate_normal 矩阵吗?

    有点相关这个帖子 https stackoverflow com questions 52337612 random multivariate normal on a dask array 我正在尝试复制multivariate norma
  • Python int 太大,无法放入 SQLite

    我收到错误 OverflowError Python int 太大 无法转换为 SQLite INTEGER 来自以下代码块 该文件约25GB 因此必须分部分读取 length 6128765 Works on partitions of
  • WindowsError:[错误 5] 访问被拒绝

    我一直在尝试终止一个进程 但我的所有选项都给出了 Windows 访问被拒绝错误 我通过以下方式打开进程 一个python脚本 test subprocess Popen sys executable testsc py 我想杀死那个进程
  • 如何在单独的文件中使用 FastAPI Depends 作为端点/路由?

    我在单独的文件中定义了一个 Websocket 端点 例如 from starlette endpoints import WebSocketEndpoint from connection service import Connectio
  • 使用 Keras 和 fit_generator 绘制 TensorBoard 分布和直方图

    我正在使用 Keras 使用 fit generator 函数训练 CNN 这似乎是一个已知问题 https github com fchollet keras issues 3358TensorBoard 在此设置中不显示直方图和分布 有
  • 从时间序列生成日期特征

    我有一个数据框 其中包含如下列 Date temp data holiday day 01 01 2000 10000 0 1 02 01 2000 0 1 2 03 01 2000 2000 0 3 30 01 2000 200 0 30

随机推荐

  • 登录文件服务器 换用户,win7切换用户访问共享、切换用户账户访问共享、共享文件夹切换用户的方法...

    现在 很多单位都有文件服务器 通常会对局域网设置共享 并且为了方便访问 通常就会设置 记住访问密码 这样以后访问共享文件时就不需要每次都输入密码了 虽然方便了共享文件访问 但是 当用户想切换访问共享文件的用户时 就比较麻烦 具体如何操作呢
  • 如何防御Java中的SQL注入

    SQL注入是应用程序遭受的最常见的攻击类型之一 鉴于其常见性及潜在的破坏性 需要在了解原理的基础上探讨如何保护应用程序免受其害 什么是SQL注入 SQL注入 也称为SQLi 是指攻击者成功篡改Web应用输入 并在该应用上执行任意SQL查询
  • C、C++、C#、python、java编程—程序结构

    C资料 菜鸟教程 C语言中文网 C community C 资料 菜鸟教程 cplusplus C community C 资料 菜鸟教程 microsoftC 文档 python资料 菜鸟教程 python标准库 Java资料 菜鸟教程
  • SpringMVC 接口版本管理/IP访问控制/ANT打包发布到LINUX

    前言 最近懒了很多也忙了很多 好多东西没办法分享到blog 因为知识点比较杂 没有时间整理 写这篇文章主要原因是 因为遇到了同样的问题 但是网上没有很好的解决方案于是自己解决后 分享给大家 源码在csdn download 文章尾部可以下载
  • 文字识别方法全面整理

    来源 https zhuanlan zhihu com p 65707543 作者 白裳 本文来自知乎专栏 仅供学习参考使用 著作权归作者所有 如有侵权 请私信删除 文字识别也是目前CV的主要研究方向之一 本文主要总结目前文字识别方向相关内
  • ubuntu 18.04 server安装(详细安装教程)

    前期准备 准备一个创建一个空文件夹 目的用于装虚拟机 个人习惯 2 准备好ubuntu 18 04 iso 服务版本镜像文件 接下来开始安装叭 1 打开虚拟机VMware workstations 这里用的是16pro 点击 主页 创建新的
  • 因果推断——图的三种基本结构

    因果推断入门笔记 V Structure Chain链状 Fork叉状 Collider碰撞 1 Chain 链状结构 X gt Y gt Z X和Y相关 Y和Z相关 X和Z相关 但是 如果condition在Y上 则X和Z是统计独立的 这
  • 三相桥式全控整流电路matlab仿真实验,三相全控桥式整流电路仿真实验

    三相全控桥式整流电路仿真实验 7页 本资源提供全文预览 点击全文预览即可全文预览 如果喜欢文档就下载吧 查找使用更方便哦 19 90 积分 实验九 电力电子电路的仿真实验 三相全控桥式整流电路仿真实验 一 实验目的 1 掌握MATLAB仿真
  • 故事分享

    一 Java是兴趣所在 L同学坦言说自己喜欢python这门语言 觉得它很有魅力 他说自己对互联网感兴趣 平时接触很多 自己也有尝试自学 看了很多教学视频和资料 然后他更加确定了自己对python的喜欢 他还给自己设置了一个小目标 独自搭建
  • 域名绑定Github个人博客

    首先自吹一波 我个人的博客网址 我的个人博客 1 个人博客搭建 基础的建站工作以下一套视频足以KO 底部音乐栏可以研究一下帮助文档 帮助文档其实非常的重要 很多问题全都在最新版本的帮助文档里面 之前查了网上很多答案都不太对 最后研究了一下帮
  • 最详细的MySQL安装、卸载

    MySQL是想在最主流的关系型数据库 所以作为一名 伟大 的程序猿 你的电脑上是必须要有的 相比较而言安装MySQL数据库还是很简单的 类似于傻瓜式安装 卸载 相对麻烦一些 需要手动删除一些文件 当然要仔细一些 演示版本MySQL 5 7
  • uniapp组件库总结笔记

    uView ui uView 2 0 全面兼容 nvue 的 uni app 生态框架 uni app UI 框架 优点 整体样式风格不错 缺点 不支持vue3 可以使用社区维护的uview plus uview plus 3 0 全面兼容
  • 蓝桥杯2015年第六届真题-机器人塔

    题目 题目链接 题解 DFS 二进制枚举 经典dfs之一 好像比较经典的那个同型dfs题叫 符号三角形 可以看出上面一行的安排方式均由下面一行的安排方式决定 因此我们只要定好最后一行 那么上面的安排方式均可以由下行推出 且最后一行固定则整个
  • 黑马实战项目瑞吉外卖的总结

    文章目录 一 瑞吉外卖项目总结 1 后端Controller层返回结果统一封装的R对象 2 定义静态资源映射关系 3 配置消息资源转换器 3 1 Reggie项目中遇到的问题 3 2 原理 3 3 解决方案 3 4 示例 4 Mybatis
  • python3的xpath_python3爬虫之xpath

    一 简介 XPath 是一门在 XML 文档中查找信息的语言 XPath 可用来在 XML 文档中对元素和属性进行遍历 XPath 是 W3C XSLT 标准的主要元素 并且 XQuery 和 XPointer 都构建于 XPath 表达之
  • 华为OD机试 - 最长连续子序列(Java)

    题目描述 有N个正整数组成的一个序列 给定整数sum 求长度最长的连续子序列 使他们的和等于sum 返回此子序列的长度 如果没有满足要求的序列 返回 1 输入描述 第一行输入是 N个正整数组成的一个序列 第二行输入是 给定整数sum 输出描
  • SpringCloud Alibaba Seata处理分布式事务

    文章目录 第一章 分布式事务问题 第二章 Seata简介 2 1 Seata是什么 2 2 Seata 整体工作流程 2 3 Seata AT 模式 2 3 1 AT 模式的前提 2 3 2 AT 模式的工作机制 2 4 下载 第三章 Se
  • react 事件绑定this指向

    一 使用class的实例方法 class Hello extends React Component onIncrement gt this setState count this state count 1 二 箭头函数
  • 如何透过上层div点击下层的元素解决方法

    如何透过上层div点击下层的元素解决方法 参考文章 1 如何透过上层div点击下层的元素解决方法 2 https www cnblogs com wei dong p 9928566 html 备忘一下
  • 可以估计不确定性的神经网络:SDE-Net

    作者丨段易通 知乎 来源丨https zhuanlan zhihu com p 234834189 编辑丨极市平台 随着深度学习技术的不断发展 DNN模型的预测能力变得越来越强 然而在一些情况下这却并不是我们想要的 比如说给模型一个与训练集