【机器学习系列】变分推断第三讲:基于随机梯度上升法SGD的变分推断解法

2023-11-15


作者:CHEONG

公众号:AI机器学习与知识图谱

研究方向:自然语言处理与知识图谱

阅读本文之前,首先注意以下两点:

1. 机器学习系列文章常含有大量公式推导证明,为了更好理解,文章在最开始会给出本文的重要结论,方便最快速度理解本文核心。需要进一步了解推导细节可继续往后看。

2. 文中含有大量公式,若读者需要获取含公式原稿Word文档,可关注公众号【AI机器学习与知识图谱】后回复:变分推断第三讲,可添加微信号【17865190919】进学习交流群,加好友时备注来自CSDN。原创不易,转载请告知并注明出处!

本文将先对变分推断所要解决的问题进行分析,然后给出基于随机梯度上升法的变分推断解法。


一、本文结论

结论1: 变分推断的主要思想:在给定数据集 X X X下,问题是求后验概率 p p p,简单情况下后验概率 p p p可直接通过贝叶斯公式推导求出,但有些情况无法直接求解。因此变分推断想法是先假设另一个简单的概率分布 q q q,如高斯分布,通过优化 p p p q q q之间距离最小化,让概率分布 q q q逼近 p p p,这样就可以用概率分布 q q q近似表示后验概率 p p p

结论2: 基于随机梯度上升法主要思路就是对优化的目标函数 q ∗ = a r g m a x q E L B O q^*=argmax_qELBO q=argmaxqELBO求梯度的过程。最后使用MCMC采样的方式近似求出梯度,并且考虑到求解出梯度近似值的稳定性,使用了重参数化技巧Reparameterization Trick。在梯度求出之后便可使用迭代方式求出参数。


二、问题分析

在上一节详细介绍了变分推断所要解决的问题,下面我们首先重新明确优化的目标函数

在这里插入图片描述

其中:

在这里插入图片描述

为了表示方便,这里假设 q ( z ) q(z) q(z) z z z是关于参数 ϕ \phi ϕ的函数,这样优化函数就变成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RYM5IxiA-1617961098543)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image024.png)]

在明确了优化函数后,接下来就通过随机梯度上升法求解,因此下面通过公式推导求求梯度。


三、公式推导

下面是 L ( ϕ ) L(\phi) L(ϕ)关于 ϕ \phi ϕ求梯度的过程:

在这里插入图片描述

这里为了方便表示,做以下赋值操作,用 A A A表示公式前半部分,用 B B B表示公式后半部分:

在这里插入图片描述

先看 B B B项,其中 l o g p θ ( x , z ) logp_\theta(x,z) logpθ(x,z) L ( ϕ ) L(\phi) L(ϕ)无关,所以有:

在这里插入图片描述

所以最终化简可得 B B B项为0,所以原始公式就只剩下 A A A项:

在这里插入图片描述

所以可以将上述式子写成 q ϕ q_\phi qϕ期望的形式如下:

在这里插入图片描述

这样我们就将 L ( ϕ ) L(\phi) L(ϕ)关于 ϕ \phi ϕ的梯度求出来了,是一个关于 q ϕ q_\phi qϕ的期望,就可以通过MCMC采样的方式把梯度具体表示出来,知道了梯度便可以利用梯度上升法进行求解了。首先通过MCMC采样法对 z z z进行采样, z l ∼ q ϕ , l = 1 , 2 , . . . , L z^l \sim q_{\phi}, l=1,2,...,L zlqϕ,l=1,2,...,L,得到 L ( ϕ ) L(\phi) L(ϕ)关于 ϕ \phi ϕ的梯度为:

在这里插入图片描述

知道梯度后便可以通过随机梯度上升法求解参数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pbvehGlq-1617961098615)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image064.png)]

但这里存在一个问题,问题就出在:

在这里插入图片描述

q ϕ q_\phi qϕ很小时,如在0-1之间时,log函数的结果就会有很大的波动,会导致求出来的梯度值有很大的波动,这样MCMC采样时只有让 L L L取非常大时才能避免这种波动带来的高方差High Variance的问题,所以在实际使用时存在工程上的问题。解决的方案就是使用重参数化技巧来避免。


四、重参数化技巧

Reparameterization Trick,假设:

在这里插入图片描述

其中

在这里插入图片描述

则:

在这里插入图片描述

在使用重参数化技巧之后,我们再来求目标函数的梯度值:

在这里插入图片描述

这里将 q ϕ q_\phi qϕ可以利用重参数化技巧可以等价替换成 p ( ε ) p(\varepsilon) p(ε)

在这里插入图片描述

这里就是关于 p ( ε ) p(\varepsilon) p(ε)的期望了,所以对 ϕ \phi ϕ求梯度时就不会那么复杂

在这里插入图片描述

这里我们再使用MCMC采样法对 ε \varepsilon ε进行采样, ε l ∼ p ( ε ) , l = 1 , 2 , . . . , L \varepsilon^l \sim p(\varepsilon), l=1,2,...,L εlp(ε),l=1,2,...,L,最终可以得出目标函数的梯度值为:

在这里插入图片描述

得知梯度值之后,便可以使用随机梯度上升法对参数进行迭代求解:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aMNGRNor-1617961098694)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image064.png)]

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【机器学习系列】变分推断第三讲:基于随机梯度上升法SGD的变分推断解法 的相关文章

  • 【Idea】创建包自动分层

    Idea 创建包自动分层 创建Maven 项目时 新建包使得Tomcat查找访问路径时更准确 但是有时包会不分层 如图1 然后我们使用图3的方法取消勾选 使得新建包时自动分层 如图2

随机推荐

  • 华为机试--简单题(一)

    HJ14 字符串排序 知识点 字符串 排序 描述 给定 n 个字符串 请对 n 个字符串按照字典序排列 数据范围 1 n 1000 字符串长度满足1 len 100 输入描述 输入第一行为一个正整数n 1 n 1000 下面n行为n个字符串
  • JAVA基于Slack实现异常日志报警

    一 功能介绍 在我们日常开发中 如果系统在线上环境上 发生异常 开发人员不能及时知晓来修复 可能会造成重大的损失 因此后端服务中加入异常报警的功能是十分必要的 而开发一个功能全面的异常报警服务 可能会花费较长的周期 今天给大家带来一种基于S
  • STM32F407基于RT-Thread连接ESP8266WiFi模块

    1 连接规则 STM32F4连接ESP8266无线通信 串口通信 首先 本次用到两个串口 我使用的是普中STM32F407 第一个串口为USART1 PA2 PA3 串口一 就是数据线连接单片机和电脑时用的口 串口三USART3 PB10T
  • 当面试官问你离职原因的时候怎么回答比较好?

    所有的前提都是建立在有一定的物质基础 当你的一日三餐都成了问题 都需要家庭支持的时候我希望你可以找一份工作 靠自己的本事养活自己从来不丢人 我觉得死要面子活受罪才是真的让你看不起 所有的建议都是建立在我们是普通打工人的前提 大佬是不需要建议
  • C++:二维数组--输出斐波那契数列的前20项

    大家都知道 在数学世界中有很多神奇的数列 斐波那契数列正是众多有规律的数列中的一种 该数列是意大利数学家列昂纳多 斐波那契发现的 他的基本规律是从第三项开始 每一项都等于前两项之和 第一项和第二项都是1 斐波那契数列如下图所示 1 1 2
  • http概述

    目录 概述 Web客户端和服务器 资源 http如何通信 媒体类型 URI 事务 方法 状态码 报文 连接 版本历程 Web的结构组件 代理 缓存 网关 隧道 Agent代理 爬虫 概述 HTTP是现代全球因特网中使用的公共语言 web浏览
  • 11个强大的Visual Studio调试小技巧

    伯乐在线注 我们在 程序员的那些事 微博上推荐了英文原文 感谢 halftone 被禁用了 的热心翻译 简介 调试是软件开发周期中很重要的一部分 它具有挑战性 同时也很让人疑惑和烦恼 总的来说 对于稍大一点的程序 调试是不可避免的 最近几年
  • 人工智能技术在软件开发中的应用

    人工智能技术的不断发展和成熟 使得它在软件开发中的应用越来越广泛 人工智能技术的应用可以帮助软件开发人员提高效率 降低成本 增强软件的功能性和可靠性 在本文中 我们将探讨人工智能技术在软件开发中的应用 并且提供一些实际案例 以帮助读者更好地
  • PHP 两个页面跳转,session会失效?

    两个页面都包含以下信息 可是 在A php中设置 SESSION go go 在B php中读出来的 SESSION
  • Pycharm远程连接服务器(实践笔记)

    Pycharm远程连接服务器 实践笔记 1 远程连接服务器 2 配置服务器上的环境 记录一下过程 防止自己隔一段时间又忘了 只有pycharm专业版才能远程连接 搞错了步骤1和2的顺序 然后代码一直不能实现同步 一下午配置了n次都不成功 不
  • java计算算术表达式

    直接上代码 String str 1 0 3 2 1 2 ScriptEngineManager manager new ScriptEngineManager ScriptEngine engine manager getEngineBy
  • Android 将布局文件放在服务器上,动态改变布局。

    转自 https blog csdn net chan1116 article details 44200405 目前在做项目时候有这样的需求 布局文件的控件类型大致相同 例如某布局文件由GridView ScrollView TextVi
  • 网银木马TrickBot的分析调试笔记

    Trickbot描述 Trickbot是2016年出现的一种网银木马 它以大银行的客户为目标 窃取他们的信息 自出现以来 新的变体不断出现 每次都有新的技巧和模块更新 Trickbot是一种模块化恶意软件 包括针对其恶意活动的不同模块 主要
  • Elasticsearch使用教程

    下载ES elasticsearch的下载地址 https www elastic co cn downloads elasticsearch ik分词器的下载地址 https github com medcl elasticsearch
  • csharp:百度翻译

    参考 http api fanyi baidu com api trans product index http developer baidu com wiki index php title E5 B8 AE E5 8A A9 E6 9
  • 如何在 Hive 中使用最近的值填补到缺失的日期中

    我花了几天的时间试图弄清楚如何在 Hive 中使用最近的值填补到缺失的日期中 但没有成功 原始表目前看起来像下表 account name available balance Date of balance Peter 50000 2021
  • NVIDIA GTC主题演讲内容学习<4>

    AI的进步为自动化 以前无法想象的任务开辟了新的机会 用子计算机行业的说法 边缘就是计算机接触世界的地方 如今 大量边缘应用可以在云中处理 例如 人们使用收集连接到云服务 对于许多边缘应用 由于响应时间 数据安全性或可靠性原因 或不间断高速
  • UE4 UE4 C++ Gameplay Abilities的GameplayCue

    UE4 UE4 C Gameplay Abilities的GameplayCue GAS参考文档 用GameplayCue 做一个玩家加血 buff效果 初始化 加血 加buff buff消失 加血的播放一个粒子特效 这个是用Gamepla
  • arm32上uImage镜像的生成过程

    arm32上uImage镜像的生成过程 arch arm boot Image cmd cmd arch arm boot Image arm himix200 linux objcopy O binary R comment S vmli
  • 【机器学习系列】变分推断第三讲:基于随机梯度上升法SGD的变分推断解法

    作者 CHEONG 公众号 AI机器学习与知识图谱 研究方向 自然语言处理与知识图谱 阅读本文之前 首先注意以下两点 1 机器学习系列文章常含有大量公式推导证明 为了更好理解 文章在最开始会给出本文的重要结论 方便最快速度理解本文核心 需要