变分推断

2023-11-18

一、概述

对于概率模型来说,如果从频率派角度来看就会是一个优化问题,从贝叶斯角度来看就会是一个积分问题

从贝叶斯角度来看,如果已有数据 x x x,对于新的样本 x ^ \hat{x} x^,需要得到:

p ( x ^ ∣ x ) = ∫ θ p ( x ^ , θ ∣ x ) d θ = ∫ θ p ( x ^ ∣ θ , x ) p ( θ ∣ x ) d θ = x ^ 与 x 独 立 ∫ θ p ( x ^ ∣ θ ) p ( θ ∣ x ) d θ = E θ ∣ x [ p ( x ^ ∣ θ ) ] p(\hat{x}|x)=\int _{\theta }p(\hat{x},\theta |x)\mathrm{d}\theta =\int _{\theta }p(\hat{x}|\theta ,x)p(\theta |x)\mathrm{d}\theta \\ \overset{\hat{x}与x独立}{=}\int _{\theta }p(\hat{x}|\theta)p(\theta |x)\mathrm{d}\theta =E_{\theta |x}[p(\hat{x}|\theta )] p(x^x)=θp(x^,θx)dθ=θp(x^θ,x)p(θx)dθ=x^xθp(x^θ)p(θx)dθ=Eθx[p(x^θ)]

如果新样本和数据集独立,那么推断就是概率分布依参数后验分布的期望。推断问题的中⼼是参数后验分布的求解,推断分为:

  1. 精确推断
  2. 近似推断-参数空间无法精确求解
    ①确定性近似-如变分推断
    ②随机近似-如 MCMC,MH,Gibbs

二、公式导出

有以下数据:

x x x:observed variable → X : { x i } i = 1 N \rightarrow X:\left \{x_{i}\right \}_{i=1}^{N} X:{xi}i=1N
z z z:latent variable + parameter → Z : { z i } i = 1 N \rightarrow Z:\left \{z_{i}\right \}_{i=1}^{N} Z:{zi}i=1N
( X , Z ) (X,Z) (X,Z):complete data

我们记 z z z为隐变量和参数的集合。接着我们变换概率 p ( x ) p(x) p(x)的形式然后引入分布 q ( z ) q(z) q(z)

l o g    p ( x ) = l o g    p ( x , z ) − l o g    p ( z ∣ x ) = l o g    p ( x , z ) q ( z ) − l o g    p ( z ∣ x ) q ( z ) log\; p(x)=log\; p(x,z)-log\; p(z|x)=log\; \frac{p(x,z)}{q(z)}-log\; \frac{p(z|x)}{q(z)} logp(x)=logp(x,z)logp(zx)=logq(z)p(x,z)logq(z)p(zx)

式子两边同时对 q ( z ) q(z) q(z)求积分:

左 边 = ∫ z q ( z ) ⋅ l o g    p ( x ∣ θ ) d z = l o g    p ( x ∣ θ ) ∫ z q ( z ) d z = l o g    p ( x ∣ θ ) 右 边 = ∫ z q ( z ) l o g    p ( x , z ∣ θ ) q ( z ) d z ⏟ E L B O ( e v i d e n c e    l o w e r    b o u n d ) − ∫ z q ( z ) l o g    p ( z ∣ x , θ ) q ( z ) d z ⏟ K L ( q ( z ) ∣ ∣ p ( z ∣ x , θ ) ) = L ( q ) ⏟ 变 分 + K L ( q ∣ ∣ p ) ⏟ ≥ 0 左边=\int _{z}q(z)\cdot log\; p(x |\theta )\mathrm{d}z=log\; p(x|\theta )\int _{z}q(z )\mathrm{d}z=log\; p(x|\theta )\\ 右边=\underset{ELBO(evidence\; lower\; bound)}{\underbrace{\int _{z}q(z)log\; \frac{p(x,z|\theta )}{q(z)}\mathrm{d}z}}\underset{KL(q(z)||p(z|x,\theta ))}{\underbrace{-\int _{z}q(z)log\; \frac{p(z|x,\theta )}{q(z)}\mathrm{d}z}}\\ =\underset{变分}{\underbrace{L(q)}} + \underset{\geq 0}{\underbrace{KL(q||p)}} =zq(z)logp(xθ)dz=logp(xθ)zq(z)dz=logp(xθ)=ELBO(evidencelowerbound) zq(z)logq(z)p(x,zθ)dzKL(q(z)p(zx,θ)) zq(z)logq(z)p(zx,θ)dz= L(q)+0 KL(qp)

我们的目的是找到一个使得 q q q p p p更接近,也就是使 K L ( q ∣ ∣ p ) KL(q||p) KL(qp)越小越好,也就是要使 L ( q ) L(q) L(q)越大越好:

q ~ ( z ) = a r g m a x q ( z )    L ( q ) ⇒ q ~ ( z ) ≈ p ( z ∣ x ) \tilde{q}(z)=\underset{q(z)}{argmax}\; L(q)\Rightarrow \tilde{q}(z)\approx p(z|x) q~(z)=q(z)argmaxL(q)q~(z)p(zx)

在变分推断中我们对 q ( z ) q(z) q(z)做以下假设(基于平均场假设的变分推断),也就是说我们把多维变量的不同维度分为M组,组与组之间是相互独立的:

q ( z ) = ∏ i = 1 M q i ( z i ) q(z)=\prod_{i=1}^{M}q_{i}(z_{i}) q(z)=i=1Mqi(zi)

求解时我们固定 q i ( z i ) , i ≠ j q_{i}(z_{i}),i\neq j qi(zi),i=j来求 q j ( z j ) q_{j}(z_{j}) qj(zj),接下来将 L ( q ) L(q) L(q)写作两部分:

L ( q ) = ∫ z q ( z ) l o g    p ( x , z ) d z ⏟ ① − ∫ z q ( z ) l o g    q ( z ) d z ⏟ ② L(q)=\underset{①}{\underbrace{\int _{z}q(z)log\; p(x,z)\mathrm{d}z}}-\underset{②}{\underbrace{\int _{z}q(z)log\; q(z)\mathrm{d}z}} L(q)= zq(z)logp(x,z)dz zq(z)logq(z)dz

对于①:

① = ∫ z ∏ i = 1 M q i ( z i ) l o g    p ( x , z ) d z 1 d z 2 ⋯ d z M = ∫ z j q j ( z j ) ( ∫ z − z j ∏ i ≠ j M q i ( z i ) l o g    p ( x , z ) d z 1 d z 2 ⋯ d z M ( i ≠ j ) ) ⏟ ∫ z − z j l o g    p ( x , z ) ∏ i ≠ j M q i ( z i ) d z i d z j = ∫ z j q j ( z j ) ⋅ E ∏ i ≠ j M q i ( z i ) [ l o g    p ( x , z ) ] ⋅ d z j ①=\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; p(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ =\int _{z_{j}}q_{j}(z_{j})\underset{\int _{z-z_{j}}log\; p(x,z)\prod_{i\neq j}^{M}q_{i}(z_{i})\mathrm{d}z_{i}}{\underbrace{\left (\int _{z-z_{j}}\prod_{i\neq j}^{M}q_{i}(z_{i})log\; p(x,z)\underset{(i\neq j)}{\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}}\right )}}\mathrm{d}z_{j}\\ =\int _{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]\cdot \mathrm{d}z_{j} =zi=1Mqi(zi)logp(x,z)dz1dz2dzM=zjqj(zj)zzjlogp(x,z)i=jMqi(zi)dzi zzji=jMqi(zi)logp(x,z)(i=j)dz1dz2dzMdzj=zjqj(zj)Ei=jMqi(zi)[logp(x,z)]dzj

对于②:

② = ∫ z q ( z ) l o g    q ( z ) d z = ∫ z ∏ i = 1 M q i ( z i ) ∑ i = 1 M l o g    q i ( z i ) d z = ∫ z ∏ i = 1 M q i ( z i ) [ l o g    q 1 ( z 1 ) + l o g    q 2 ( z 2 ) + ⋯ + l o g    q M ( z M ) ] d z 其 中 ∫ z ∏ i = 1 M q i ( z i ) l o g    q 1 ( z 1 ) d z = ∫ z 1 z 2 ⋯ z M q 1 ( z 1 ) q 2 ( z 2 ) ⋯ q M ( z M ) ⋅ l o g    q 1 ( z 1 ) d z 1 d z 2 ⋯ d z M = ∫ z 1 q 1 ( z 1 ) l o g    q 1 ( z 1 ) d z 1 ⋅ ∫ z 2 q 2 ( z 2 ) d z 2 ⏟ = 1 ⋅ ∫ z 3 q 3 ( z 3 ) d z 3 ⏟ = 1 ⋯ ∫ z M q M ( z M ) d z M ⏟ = 1 = ∫ z 1 q 1 ( z 1 ) l o g    q 1 ( z 1 ) d z 1 也 就 是 说 ∫ z ∏ i = 1 M q i ( z i ) l o g    q k ( z k ) d z = ∫ z k q k ( z k ) l o g    q k ( z k ) d z k 则 ② = ∑ i = 1 M ∫ z i q i ( z i ) l o g    q i ( z i ) d z i = ∫ z j q j ( z j ) l o g    q j ( z j ) d z j + C ②=\int _{z}q(z)log\; q(z)\mathrm{d}z\\ =\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}log\; q_{i}(z_{i})\mathrm{d}z\\ =\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})[log\; q_{1}(z_{1})+log\; q_{2}(z_{2})+\cdots +log\; q_{M}(z_{M})]\mathrm{d}z\\ 其中\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{1}(z_{1})\mathrm{d}z\\ =\int _{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot log\; q_{1}(z_{1})\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ =\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\cdot \underset{=1}{\underbrace{\int _{z_{2}}q_{2}(z_{2})\mathrm{d}z_{2}}}\cdot \underset{=1}{\underbrace{\int _{z_{3}}q_{3}(z_{3})\mathrm{d}z_{3}}}\cdots \underset{=1}{\underbrace{\int _{z_{M}}q_{M}(z_{M})\mathrm{d}z_{M}}}\\ =\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\\ 也就是说\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{k}(z_{k})\mathrm{d}z=\int _{z_{k}}q_{k}(z_{k})log\; q_{k}(z_{k})\mathrm{d}z_{k}\\ 则②=\sum_{i=1}^{M}\int _{z_{i}}q_{i}(z_{i})log\; q_{i}(z_{i})\mathrm{d}z_{i}\\ =\int _{z_{j}}q_{j}(z_{j})log\; q_{j}(z_{j})\mathrm{d}z_{j}+C =zq(z)logq(z)dz=zi=1Mqi(zi)i=1Mlogqi(zi)dz=zi=1Mqi(zi)[logq1(z1)+logq2(z2)++logqM(zM)]dzzi=1Mqi(zi)logq1(z1)dz=z1z2zMq1(z1)q2(z2)qM(zM)logq1(z1)dz1dz2dzM=z1q1(z1)logq1(z1)dz1=1 z2q2(z2)dz2=1 z3q3(z3)dz3=1 zMqM(zM)dzM=z1q1(z1)logq1(z1)dz1zi=1Mqi(zi)logqk(zk)dz=zkqk(zk)logqk(zk)dzk=i=1Mziqi(zi)logqi(zi)dzi=zjqj(zj)logqj(zj)dzj+C

然后我们可以得到 ① − ② ①-②

首 先 ① = ∫ z j q j ( z j ) ⋅ E ∏ i ≠ j M q i ( z i ) [ l o g    p ( x , z ) ] ⏟ 写 作 l o g    p ^ ( x , z j ) ⋅ d z j 然 后 ① − ② = ∫ z j q j ( z j ) ⋅ l o g p ^ ( x , z j ) q j ( z j ) d z j + C ∫ z j q j ( z j ) ⋅ l o g p ^ ( x , z j ) q j ( z j ) d z j = − K L ( q j ( z j ) ∣ ∣ p ^ ( x , z j ) ) ≤ 0 首先①=\int _{z_{j}}q_{j}(z_{j})\cdot\underset{写作log\; \hat{p}(x,z_{j})}{ \underbrace{E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]}}\cdot \mathrm{d}z_{j}\\ 然后①-②=\int _{z_{j}}q_{j}(z_{j})\cdot log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\mathrm{d}z_{j}+C\\ \int _{z_{j}}q_{j}(z_{j})\cdot log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\mathrm{d}z_{j}=-KL(q_{j}(z_{j})||\hat{p}(x,z_{j}))\leq 0 =zjqj(zj)logp^(x,zj) Ei=jMqi(zi)[logp(x,z)]dzj=zjqj(zj)logqj(zj)p^(x,zj)dzj+Czjqj(zj)logqj(zj)p^(x,zj)dzj=KL(qj(zj)p^(x,zj))0

q j ( z j ) = p ^ ( x , z j ) q_{j}(z_{j})=\hat{p}(x,z_{j}) qj(zj)=p^(x,zj)才能得到最⼤值。

三、回顾EM算法

回想一下广义EM算法中,我们需要固定 θ \theta θ然后求解与 p p p最接近的 q q q,这里就可以使用变分推断的方法,我们有如下式子:

l o g    p θ ( x ) = E L B O ⏟ L ( q ) + K L ( q ∣ ∣ p ) ⏟ ≥ 0 ≥ L ( q ) log\; p_{\theta }(x)=\underset{L(q)}{\underbrace{ELBO}}+\underset{\geq 0}{\underbrace{KL(q||p)}}\geq L(q) logpθ(x)=L(q) ELBO+0 KL(qp)L(q)

然后求解 q q q

q ^ = a r g m i n q    K L ( q ∣ ∣ p ) = a r g m a x q    L ( q ) \hat{q}=\underset{q}{argmin}\; KL(q||p)=\underset{q}{argmax}\; L(q) q^=qargminKL(qp)=qargmaxL(q)

使用上述平均场变分推断的话,我们就可以得出以下结果(注意这里 z i z_i zi不是代表 z z z的第 i i i个维度):

l o g    q j ( z j ) = E ∏ i ≠ j M q i ( z i ) [ l o g    p θ ( x , z ) ] = ∫ z 1 ∫ z 2 ⋯ ∫ z j − 1 ∫ z j + 1 ⋯ ∫ z M q 1 q 2 ⋯ q j − 1 q j + 1 ⋯ q M ⋅ l o g    p θ ( x , z ) d z 1 d z 2 ⋯ d z j − 1 d z j + 1 ⋯ d z M log\; q_{j}(z_{j})=E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p_{\theta }(x,z)]\\ =\int _{z_{1}}\int _{z_{2}}\cdots \int _{z_{j-1}}\int _{z_{j+1}}\cdots \int _{z_{M}}q_{1}q_{2}\cdots q_{j-1}q_{j+1}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots \mathrm{d}z_{M} logqj(zj)=Ei=jMqi(zi)[logpθ(x,z)]=z1z2zj1zj+1zMq1q2qj1qj+1qMlogpθ(x,z)dz1dz2dzj1dzj+1dzM

一次迭代求解的过程如下:

l o g    q ^ 1 ( z 1 ) = ∫ z 2 ⋯ ∫ z M q 2 ⋯ q M ⋅ l o g    p θ ( x , z ) d z 2 ⋯ d z M l o g    q ^ 2 ( z 2 ) = ∫ z 1 ∫ z 3 ⋯ ∫ z M q ^ 1 q 3 ⋯ q M ⋅ l o g    p θ ( x , z ) d z 1 d z 3 ⋯ d z M ⋮ l o g    q ^ M ( z M ) = ∫ z 1 ⋯ ∫ z M − 1 q ^ 1 ⋯ q ^ M − 1 ⋅ l o g    p θ ( x , z ) d z 1 ⋯ d z M − 1 log\; \hat{q}_{1}(z_{1})=\int _{z_{2}}\cdots \int _{z_{M}}q_{2}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ log\; \hat{q}_{2}(z_{2})=\int _{z_{1}}\int _{z_{3}}\cdots \int _{z_{M}}\hat{q}_{1}q_{3}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\mathrm{d}z_{3}\cdots \mathrm{d}z_{M}\\ \vdots \\ log\; \hat{q}_{M}(z_{M})=\int _{z_{1}}\cdots \int _{z_{M-1}}\hat{q}_{1}\cdots \hat{q}_{M-1}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\cdots \mathrm{d}z_{M-1} logq^1(z1)=z2zMq2qMlogpθ(x,z)dz2dzMlogq^2(z2)=z1z3zMq^1q3qMlogpθ(x,z)dz1dz3dzMlogq^M(zM)=z1zM1q^1q^M1logpθ(x,z)dz1dzM1

我们看到,对每⼀个 q j ( z j ) q_{j}(z_{j}) qj(zj),都是固定其余的 q i ( z i ) q_{i}(z_{i}) qi(zi),求这个值,于是可以使⽤坐标上升的⽅法进⾏迭代求解,上⾯的推导针对单个样本,但是对数据集也是适⽤的。

基于平均场假设的变分推断存在⼀些问题:
①假设太强,⾮常复杂的情况下,假设不适⽤;
②期望中的积分,可能⽆法计算。

四、随机梯度变分推断(SGVI)

  1. 直接求导数的方法

Z Z Z X X X的过程叫做⽣成过程或译码,从 X X X Z Z Z过程叫推断过程或编码过程,基于平均场的变分推断可以导出坐标上升的算法,但是这个假设在⼀些情况下假设太强,同时积分也不⼀定能算。我们知道,优化⽅法除了坐标上升,还有梯度上升的⽅式,我们希望通过梯度上升来得到变分推断的另⼀种算法。

假定 q ( Z ) = q ϕ ( Z ) q(Z)=q_{\phi }(Z) q(Z)=qϕ(Z),是和 ϕ \phi ϕ这个参数相连的概率分布。于是 a r g m a x q ( Z )    L ( q ) = a r g m a x ϕ    L ( ϕ ) \underset{q(Z)}{argmax}\; L(q)=\underset{\phi }{argmax}\; L(\phi ) q(Z)argmaxL(q)=ϕargmaxL(ϕ),其中 L ( ϕ ) = E q ϕ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] L(\phi )=E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)] L(ϕ)=Eqϕ[logpθ(x,z)logqϕ(z)],这里的 x x x表示的是样本。

∇ ϕ L ( ϕ ) = ∇ ϕ E q ϕ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] = ∇ ϕ ∫ q ϕ ( z ) [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] d z = ∫ ∇ ϕ q ϕ ( z ) ⋅ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] d z ⏟ ① + ∫ q ϕ ( z ) ∇ ϕ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] d z ⏟ ② 其 中 ② = ∫ q ϕ ( z ) ∇ ϕ [ l o g    p θ ( x , z ) ⏟ 与 ϕ 无 关 − l o g    q ϕ ( z ) ] d z = − ∫ q ϕ ( z ) ∇ ϕ l o g    q ϕ ( z ) d z = − ∫ q ϕ ( z ) 1 q ϕ ( z ) ∇ ϕ q ϕ ( z ) d z = − ∫ ∇ ϕ q ϕ ( z ) d z = − ∇ ϕ ∫ q ϕ ( z ) d z = − ∇ ϕ 1 = 0 因 此 ∇ ϕ L ( ϕ ) = ① = ∫ ∇ ϕ q ϕ ( z ) ⋅ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] d z = ∫ q ϕ ( z ) ∇ ϕ l o g    q ϕ ( z ) ⋅ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] d z = E q ϕ [ ( ∇ ϕ l o g    q ϕ ( z ) ) ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ) ] \nabla_{\phi }L(\phi )=\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =\nabla_{\phi }\int q_{\phi }(z)[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\underset{①}{\underbrace{\int \nabla_{\phi }q_{\phi }(z)\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}+\underset{②}{\underbrace{\int q_{\phi }(z)\nabla_{\phi }[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}\\ 其中②=\int q_{\phi }(z)\nabla_{\phi }[\underset{与\phi 无关}{\underbrace{log\; p_{\theta }(x,z)}}-log\; q_{\phi }(z)]\mathrm{d}z\\ =-\int q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)\mathrm{d}z\\ =-\int q_{\phi }(z)\frac{1}{q_{\phi }(z)}\nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ =-\int \nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ =-\nabla_{\phi }\int q_{\phi }(z)\mathrm{d}z\\ =-\nabla_{\phi }1\\ =0\\ 因此\nabla_{\phi }L(\phi )=①\\ =\int {\color{Red}{\nabla_{\phi }q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\int {\color{Red}{q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))] ϕL(ϕ)=ϕEqϕ[logpθ(x,z)logqϕ(z)]=ϕqϕ(z)[logpθ(x,z)logqϕ(z)]dz= ϕqϕ(z)[logpθ(x,z)logqϕ(z)]dz+ qϕ(z)ϕ[logpθ(x,z)logqϕ(z)]dz=qϕ(z)ϕ[ϕ logpθ(x,z)logqϕ(z)]dz=qϕ(z)ϕlogqϕ(z)dz=qϕ(z)qϕ(z)1ϕqϕ(z)dz=ϕqϕ(z)dz=ϕqϕ(z)dz=ϕ1=0ϕL(ϕ)==ϕqϕ(z)[logpθ(x,z)logqϕ(z)]dz=qϕ(z)ϕlogqϕ(z)[logpθ(x,z)logqϕ(z)]dz=Eqϕ[(ϕlogqϕ(z))(logpθ(x,z)logqϕ(z))]

这个期望可以通过蒙特卡洛采样来近似,从⽽得到梯度,然后利⽤梯度上升的⽅法来得到参数:

z l ∼ q ϕ ( z ) E q ϕ [ ( ∇ ϕ l o g    q ϕ ( z ) ) ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ) ] ∼ 1 L ∑ i = 1 L ( ∇ ϕ l o g    q ϕ ( z ) ) ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ) z^{l}\sim q_{\phi }(z)\\ E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\sim \frac{1}{L}\sum_{i=1}^{L}(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z)) zlqϕ(z)Eqϕ[(ϕlogqϕ(z))(logpθ(x,z)logqϕ(z))]L1i=1L(ϕlogqϕ(z))(logpθ(x,z)logqϕ(z))

但是由于求和符号中存在⼀个对数项,于是直接采样的⽅差很⼤,需要采样的样本⾮常多。为了解决⽅差太⼤的问题,我们采⽤重参数化技巧(Reparameterization)。

  1. 重参数化技巧

我们取 z = g ϕ ( ε , x ) , ε ∼ p ( ε ) z=g_{\phi }(\varepsilon ,x),\varepsilon \sim p(\varepsilon ) z=gϕ(ε,x),εp(ε),对于 z ∼ q ϕ ( z ∣ x ) z\sim q_{\phi }(z|x) zqϕ(zx),我们有 ∣ q ϕ ( z ∣ x ) d z ∣ = ∣ p ( ε ) d ε ∣ \left | q_{\phi }(z|x)\mathrm{d}z \right |=\left | p(\varepsilon )\mathrm{d}\varepsilon \right | qϕ(zx)dz=p(ε)dε。代入上面的梯度中:

∇ ϕ L ( ϕ ) = ∇ ϕ E q ϕ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] = ∇ ϕ ∫ q ϕ ( z ) [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] d z = ∇ ϕ ∫ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] q ϕ ( z ) d z = ∇ ϕ ∫ [ l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] p ( ε ) d ε = ∇ ϕ E p ( ε ) ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ] = E p ( ε ) [ ∇ ϕ ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ) ] = E p ( ε ) [ ∇ z ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ) ∇ ϕ z ] = E p ( ε ) [ ∇ z ( l o g    p θ ( x , z ) − l o g    q ϕ ( z ) ) ∇ ϕ g ϕ ( ε , x ) ] \nabla_{\phi }L(\phi )=\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =\nabla_{\phi }\int q_{\phi }(z)[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\nabla_{\phi }\int [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]q_{\phi }(z)\mathrm{d}z\\ =\nabla_{\phi }\int [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]p(\varepsilon )\mathrm{d}\varepsilon \\ =\nabla_{\phi }E_{p(\varepsilon )}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =E_{p(\varepsilon )}[\nabla_{\phi }(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\\ =E_{p(\varepsilon )}[\nabla_{z}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))\nabla_{\phi }z]\\ =E_{p(\varepsilon )}[\nabla_{z}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))\nabla_{\phi }g_{\phi }(\varepsilon ,x)] ϕL(ϕ)=ϕEqϕ[logpθ(x,z)logqϕ(z)]=ϕqϕ(z)[logpθ(x,z)logqϕ(z)]dz=ϕ[logpθ(x,z)logqϕ(z)]qϕ(z)dz=ϕ[logpθ(x,z)logqϕ(z)]p(ε)dε=ϕEp(ε)(logpθ(x,z)logqϕ(z)]=Ep(ε)[ϕ(logpθ(x,z)logqϕ(z))]=Ep(ε)[z(logpθ(x,z)logqϕ(z))ϕz]=Ep(ε)[z(logpθ(x,z)logqϕ(z))ϕgϕ(ε,x)]

对这个式⼦进⾏蒙特卡洛采样,然后计算期望,得到梯度。

SGVI的迭代过程为:

ϕ t + 1 ← ϕ t + λ t ⋅ ∇ ϕ L ( ϕ ) \phi ^{t+1}\leftarrow \phi ^{t}+\lambda ^{t}\cdot \nabla_{\phi }L(\phi ) ϕt+1ϕt+λtϕL(ϕ)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

变分推断 的相关文章

随机推荐

  • git报错:warning: unable to access

    git操作的时候出现该错误 warning unable to access Users a10 12 config git ignore Permission denied warning unable to access Users a
  • 一个女孩的就业之路(同济大学BBS上两年不沉的帖子)

    文章很长 有机会见到这篇文章的童鞋 希望能耐心看完 其他不多说 我是2005年毕业的 偶尔来这里看看 不常灌水 今天来随意写下一些 如果对各位有任何的帮助 是我衷心所愿 1 考研与就业 2004年的暑假 我和大多数人一样 艰难的抉择 究竟是
  • NacosValue 注解

    NacosValue 定义在 nacos api 工程中 com alibaba nacos api config annotation NacosValue 注解解析在 nacos spring project 工程中 com aliba
  • 阻塞队列java实现

    阻塞队列 目前队列存在的问题 1 很多场景要求分离生产者和消费者两个角色 它们得由不同的线程来担当 而之前的实现根本没有考虑线程安全问题 2 队列为空 那么在之前的实现里会返回null 如果硬拿到一个元素 只能不断循环尝试 3 队列为满 那
  • PHP魔术方(2)

    PHP魔术方 2 文章目录 PHP魔术方 2 1 toString 和 invoke tostring 和 invoke 两者的触发形式接近 2 call 用来检测所调用的成员方法是否存在 3 callStatic 4 get 5 set
  • 在Linux系统上用C++将主机名称转换为IPv4、IPv6地址

    在Linux系统上用C 将主机名称转换为IPv4 IPv6地址 功能 指定一个std string类型的主机名称 函数解析主机名称为IP地址 含IPv4和IPv6 解析结果以std vector
  • vue div高度自适应

    1 在 js文件中编写自定义指令 export default install Vue 在组件标签上绑定 v resizable 指令 并使用对象的形式通过绑定值传递宽度和高度以及最大 最小高度的值 在 bind 函数中 获取传递的值 并根
  • 走进区块链企业 I 用实践赋能实体产业,坚持提供价值服务的旺链科技

    作为华东师范大学MBA高材生 他在高科技制造 金融行业有着超过16年的业务咨询管理和技术架构经验 他是中国云体系产业创新联盟理事会常务理事 边缘计算产业联盟专家委员 也是原 Accenture资深总监 集成技术专家 而在如今话题正盛的 区块
  • linux创建,恢复和删除screen

    学习记录 侵删 目录 1 创建 2 恢复 3 删除 使用服务器训练模型时 如果服务器断开 之前的训练结果显示的终端就不好找到了 貌似可以通过线程去恢复 没试过 可以使用screen 训练前先打开一个screen 如果服务器断开 重连后可以恢
  • 最新免费版 Office 全家桶Copilot,Gamma+MindShow 两大ChatGPT AI创意工具GPT-4神器助力高效智能制作 PPT,一键生成,与AI智能对话修改PPT(免安装)

    目录 前言 ChatGPT MindShow 1 使用ChatGPT工具生成PPT内容 2 使用MindShow工具一键智能制作PPT MindShow简介 使用网页版制作 pdf转ppt GAMMA AI神器 GAMMA app介绍 注册
  • MySQL基础篇:sql_mode配置

    文章目录 零 简介 一 sql mode常用来解决的几类问题 二 sql mode包含的模式 三 sql mode各个选项作用示例 3 1 sql mode为空 对于不符合定义的值 会截断到符合定义类型 3 2 sql mode为ANSI
  • 编程语言用 Java 开发一个打飞机小游戏(附完整源码)

    编程语言用 Java 开发一个打飞机小游戏 附完整源码 上图 写在前面 技术源于分享 所以今天抽空把自己之前用java做过的小游戏整理贴出来给大家参考学习 java确实不适合写桌面应用 这里只是通过这个游戏让大家理解oop面向对象编程的过程
  • 【React】路由(详解)

    目录 单页应用程序 SPA 路由 前端路由 后端路由 路由的基本使用 使用步骤 常用组件说明 BrowserRouter和HashRouter的区别 路由的执行过程 默认路由 精确匹配 Switch的使用 重定向路由 嵌套路由 向路由组件传
  • 计算机网络体系结构 - 运输层

    一 运输层协议概述 运输层为应用进程之间提供端到端的逻辑通信 二 运输层的端口 端口 port 也称为协议端口号 protocol port number 对上层的应用进程进行标识 端口用一个16位端口号进行标志 端口号只具有本地意义 端口
  • 剑指offer-输出字符串所有种类的排列组合

    常规题 先校验长度 不符合则直接输出 符合则判断是否为最后一个字符 是则直接new对象输出 不是则交换begin和i位置的数字 再用递归输出 public class Test28 先校验 public static void permut
  • 笔试

    文章目录 前言 40 复位电路设计 1 recovery time和removal time 2 同步复位和异步复位 3 异步复位同步释放 本文参考 往期精彩 前言 嗨 今天来学习复位电路设计相关问题 微信关注 FPGA学习者 获取更多精彩
  • cec2017(python):红狐优化算法(Red fox optimization,RFO)求解cec2017

    一 红狐优化算法 红狐优化算法 Red fox optimization RFO 由Dawid Po ap和 Marcin Wo niak于2021年提出 该算法模拟了红狐的狩猎行为 具有收敛速度快 寻优精度高等优势 参考文献 Poap D
  • easyexcel读取excel将数据存到mysql【一个简单的例子】

    读取excel 1 xml里面增加maven
  • 使用Java程序向手机发送短信

    JAVA发送手机短信 有几种方法 1 使用webservice接口发送手机短信 这个可以使用sina提供的webservice进行发送 需要进行注册 2 使用短信mao的方式进行短信的发送 这种方式应该是比较的常用 前提是需要购买硬件设备
  • 变分推断

    一 概述 对于概率模型来说 如果从频率派角度来看就会是一个优化问题 从贝叶斯角度来看就会是一个积分问题 从贝叶斯角度来看 如果已有数据 x x x 对于新的样本 x hat x