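For probabilistic models, the frequentist view turns learning into an optimization problem, while the Bayesian view turns it into an integration problem.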
From the Bayesian perspective, given observed data $x$, the predictive distribution of a new sample $\hat{x}$ is:

$$p(\hat{x}\mid x)=\int_{\theta}p(\hat{x},\theta\mid x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}\mid\theta,x)\,p(\theta\mid x)\,\mathrm{d}\theta\overset{\hat{x}\perp x\mid\theta}{=}\int_{\theta}p(\hat{x}\mid\theta)\,p(\theta\mid x)\,\mathrm{d}\theta=E_{\theta\mid x}[p(\hat{x}\mid\theta)]$$

where the third equality assumes that $\hat{x}$ and $x$ are independent given $\theta$.
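If the new sample is independent of the data set given $\theta$, prediction is the expectation of the model density $p(\hat{x}\mid\theta)$ under the parameter posterior $p(\theta\mid x)$, so the central task of inference is computing this posterior. Inference methods fall into exact inference and approximate inference; approximate inference is further divided into deterministic approximation (such as variational inference) and stochastic approximation (such as MCMC).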
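To make the predictive formula concrete, here is a minimal sketch under an assumed conjugate model that is not part of the text: a Bernoulli likelihood with a Beta prior, for which $p(\hat{x}=1\mid x)=E_{\theta\mid x}[\theta]$ has a closed form. The hypothetical functions `predictive_exact` and `predictive_monte_carlo` compare that closed form with a Monte Carlo average over posterior samples, which is exactly the expectation view of $p(\hat{x}\mid x)$.

```python
import random

# Hypothetical toy model (assumption, not from the text): Bernoulli likelihood
# with a Beta(a, b) prior on theta, so p(theta | x) = Beta(a + k, b + n - k).


def predictive_exact(data, a=1.0, b=1.0):
    """Closed form of p(x_hat = 1 | x) = E_{theta|x}[theta]."""
    k, n = sum(data), len(data)
    return (a + k) / (a + b + n)


def predictive_monte_carlo(data, a=1.0, b=1.0, n_samples=100_000, seed=0):
    """Estimate E_{theta|x}[p(x_hat = 1 | theta)] by sampling the posterior."""
    rng = random.Random(seed)
    k, n = sum(data), len(data)
    total = 0.0
    for _ in range(n_samples):
        theta = rng.betavariate(a + k, b + n - k)  # theta ~ p(theta | x)
        total += theta                             # p(x_hat = 1 | theta) = theta
    return total / n_samples


if __name__ == "__main__":
    data = [1, 1, 0, 1, 0, 1, 1, 1]
    print(predictive_exact(data))        # 0.7
    print(predictive_monte_carlo(data))  # close to 0.7
```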
Consider the following quantities:
$x$: observed variable $\rightarrow X=\{x_i\}_{i=1}^{N}$
$z$: latent variables + parameters $\rightarrow Z=\{z_i\}_{i=1}^{N}$
$(X,Z)$: complete data
We denote by $z$ the collection of latent variables and parameters. We then rewrite $p(x)$ and introduce a distribution $q(z)$:
$$\log p(x)=\log p(x,z)-\log p(z\mid x)=\log\frac{p(x,z)}{q(z)}-\log\frac{p(z\mid x)}{q(z)}$$
Integrate both sides against $q(z)$, i.e. take the expectation of both sides with respect to $q(z)$:
$$\text{LHS}=\int_{z}q(z)\log p(x\mid\theta)\,\mathrm{d}z=\log p(x\mid\theta)\int_{z}q(z)\,\mathrm{d}z=\log p(x\mid\theta)$$

$$\begin{aligned}\text{RHS}&=\underbrace{\int_{z}q(z)\log\frac{p(x,z\mid\theta)}{q(z)}\,\mathrm{d}z}_{\text{ELBO (evidence lower bound)}}\;\underbrace{-\int_{z}q(z)\log\frac{p(z\mid x,\theta)}{q(z)}\,\mathrm{d}z}_{KL(q(z)\,\|\,p(z\mid x,\theta))}\\&=\underbrace{L(q)}_{\text{variational objective}}+\underbrace{KL(q\,\|\,p)}_{\geq 0}\end{aligned}$$
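The identity $\log p(x\mid\theta)=L(q)+KL(q\,\|\,p)$ can be checked numerically when $z$ is discrete. The sketch below assumes a hypothetical latent variable with three states and made-up values of $p(x,z)$ for one fixed observation; the function name `elbo_and_kl` is illustrative.

```python
import math


def elbo_and_kl(joint, q):
    """joint[k] = p(x, z=k) for the fixed observed x; q[k] = q(z=k).
    Returns (log p(x), ELBO, KL(q || p(z|x)))."""
    px = sum(joint)                    # p(x) = sum_z p(x, z)
    post = [j / px for j in joint]     # p(z | x)
    elbo = sum(qk * math.log(jk / qk) for qk, jk in zip(q, joint) if qk > 0)
    kl = sum(qk * math.log(qk / pk) for qk, pk in zip(q, post) if qk > 0)
    return math.log(px), elbo, kl


if __name__ == "__main__":
    joint = [0.1, 0.3, 0.2]            # hypothetical p(x, z) for z in {0, 1, 2}
    q = [0.5, 0.25, 0.25]              # an arbitrary variational distribution
    log_px, elbo, kl = elbo_and_kl(joint, q)
    print(log_px, elbo + kl)           # equal up to rounding
    print(elbo <= log_px, kl >= 0)     # the ELBO is a lower bound
```

Setting `q` to the exact posterior drives the KL term to zero, so the ELBO reaches $\log p(x)$; any other `q` leaves a positive gap.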
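Our goal is to find a $q$ that is as close as possible to $p$, i.e. to make $KL(q\,\|\,p)$ as small as possible. Because the left-hand side $\log p(x\mid\theta)$ does not depend on $q$, this is equivalent to making $L(q)$ as large as possible: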
$$\tilde{q}(z)=\underset{q(z)}{\arg\max}\;L(q)\;\Rightarrow\;\tilde{q}(z)\approx p(z\mid x)$$
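In variational inference we make the following assumption about $q(z)$ (variational inference under the mean-field assumption): the dimensions of the multivariate variable are partitioned into $M$ groups that are mutually independent: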
$$q(z)=\prod_{i=1}^{M}q_i(z_i)$$
To solve for $q_j(z_j)$, we hold all $q_i(z_i)$ with $i\neq j$ fixed. We then split $L(q)$ into two parts:
$$L(q)=\underbrace{\int_{z}q(z)\log p(x,z)\,\mathrm{d}z}_{\text{①}}-\underbrace{\int_{z}q(z)\log q(z)\,\mathrm{d}z}_{\text{②}}$$
For ①:

$$\begin{aligned}\text{①}&=\int_{z}\prod_{i=1}^{M}q_i(z_i)\log p(x,z)\,\mathrm{d}z_1\mathrm{d}z_2\cdots\mathrm{d}z_M\\&=\int_{z_j}q_j(z_j)\left(\int_{z-z_j}\prod_{i\neq j}^{M}q_i(z_i)\log p(x,z)\prod_{i\neq j}\mathrm{d}z_i\right)\mathrm{d}z_j\\&=\int_{z_j}q_j(z_j)\,E_{\prod_{i\neq j}^{M}q_i(z_i)}[\log p(x,z)]\,\mathrm{d}z_j\end{aligned}$$
For ②:

$$\begin{aligned}\text{②}&=\int_{z}q(z)\log q(z)\,\mathrm{d}z=\int_{z}\prod_{i=1}^{M}q_i(z_i)\sum_{i=1}^{M}\log q_i(z_i)\,\mathrm{d}z\\&=\int_{z}\prod_{i=1}^{M}q_i(z_i)\left[\log q_1(z_1)+\log q_2(z_2)+\cdots+\log q_M(z_M)\right]\mathrm{d}z\end{aligned}$$

where

$$\begin{aligned}\int_{z}\prod_{i=1}^{M}q_i(z_i)\log q_1(z_1)\,\mathrm{d}z&=\int_{z_1z_2\cdots z_M}q_1(z_1)q_2(z_2)\cdots q_M(z_M)\log q_1(z_1)\,\mathrm{d}z_1\mathrm{d}z_2\cdots\mathrm{d}z_M\\&=\int_{z_1}q_1(z_1)\log q_1(z_1)\,\mathrm{d}z_1\cdot\underbrace{\int_{z_2}q_2(z_2)\,\mathrm{d}z_2}_{=1}\cdot\underbrace{\int_{z_3}q_3(z_3)\,\mathrm{d}z_3}_{=1}\cdots\underbrace{\int_{z_M}q_M(z_M)\,\mathrm{d}z_M}_{=1}\\&=\int_{z_1}q_1(z_1)\log q_1(z_1)\,\mathrm{d}z_1\end{aligned}$$

That is, for every $k$,

$$\int_{z}\prod_{i=1}^{M}q_i(z_i)\log q_k(z_k)\,\mathrm{d}z=\int_{z_k}q_k(z_k)\log q_k(z_k)\,\mathrm{d}z_k$$

so

$$\text{②}=\sum_{i=1}^{M}\int_{z_i}q_i(z_i)\log q_i(z_i)\,\mathrm{d}z_i=\int_{z_j}q_j(z_j)\log q_j(z_j)\,\mathrm{d}z_j+C$$

where $C$ collects the terms with $i\neq j$, which are constant because those $q_i$ are held fixed.
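The key step for ② is that, under the mean-field factorization, $\sum_z q(z)\log q(z)$ splits into one term per factor. A brute-force check over small discrete factors (hypothetical values; every probability must be positive so that the logarithms are defined):

```python
import itertools
import math


def neg_entropy_joint(factors):
    """② = sum_z q(z) log q(z) for q(z) = prod_i q_i(z_i), by full enumeration."""
    total = 0.0
    for combo in itertools.product(*[range(len(f)) for f in factors]):
        qz = math.prod(f[k] for f, k in zip(factors, combo))  # q(z) for this z
        total += qz * math.log(qz)
    return total


def neg_entropy_factorized(factors):
    """Same quantity via ② = sum_i sum_{z_i} q_i(z_i) log q_i(z_i)."""
    return sum(sum(p * math.log(p) for p in f) for f in factors)


if __name__ == "__main__":
    factors = [[0.2, 0.8], [0.5, 0.3, 0.2], [0.1, 0.6, 0.3]]  # q_1, q_2, q_3
    print(neg_entropy_joint(factors), neg_entropy_factorized(factors))
```

The enumeration costs the product of the factor sizes, while the factorized form costs only their sum, which is why the decomposition matters in practice.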
Now we can compute $\text{①}-\text{②}$. First, denote the inner expectation in ① by $\log\hat{p}(x,z_j)$:

$$\text{①}=\int_{z_j}q_j(z_j)\,\underbrace{E_{\prod_{i\neq j}^{M}q_i(z_i)}[\log p(x,z)]}_{\text{written as }\log\hat{p}(x,z_j)}\,\mathrm{d}z_j$$

Then

$$\text{①}-\text{②}=\int_{z_j}q_j(z_j)\log\frac{\hat{p}(x,z_j)}{q_j(z_j)}\,\mathrm{d}z_j-C$$

$$\int_{z_j}q_j(z_j)\log\frac{\hat{p}(x,z_j)}{q_j(z_j)}\,\mathrm{d}z_j=-KL(q_j(z_j)\,\|\,\hat{p}(x,z_j))\leq 0$$

Here $\hat{p}(x,z_j)$ is taken to be normalized over $z_j$; normalizing it only adds a constant that does not depend on $q_j$, which can be absorbed into $C$.
The maximum is attained exactly when $q_j(z_j)=\hat{p}(x,z_j)$, i.e. when $\log q_j(z_j)=E_{\prod_{i\neq j}^{M}q_i(z_i)}[\log p(x,z)]$ up to a normalizing constant.
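This optimality claim can be checked on a small discrete model. In the sketch below (a hypothetical joint table with two groups, $z_1\in\{0,1\}$ and $z_2\in\{0,1,2\}$), `optimal_q1` builds $\log\hat{p}(x,z_1)=E_{q_2}[\log p(x,z_1,z_2)]$ and normalizes $\hat{p}$ over $z_1$; with $q_2$ held fixed, no other $q_1$ should give a larger ELBO.

```python
import math

# Hypothetical joint p(x, z1, z2) for one fixed observation x
# (rows: z1 in {0, 1}; columns: z2 in {0, 1, 2}).
JOINT = [[0.05, 0.10, 0.02],
         [0.08, 0.01, 0.20]]


def elbo(q1, q2, joint=JOINT):
    """L(q) = sum_{z1,z2} q1 q2 [log p(x, z1, z2) - log q1(z1) - log q2(z2)]."""
    total = 0.0
    for a, qa in enumerate(q1):
        for b, qb in enumerate(q2):
            if qa > 0 and qb > 0:
                total += qa * qb * (math.log(joint[a][b]) - math.log(qa) - math.log(qb))
    return total


def optimal_q1(q2, joint=JOINT):
    """q1*(z1) proportional to exp(E_{q2}[log p(x, z1, z2)]), i.e. normalized p_hat."""
    log_phat = [sum(qb * math.log(joint[a][b]) for b, qb in enumerate(q2))
                for a in range(len(joint))]
    m = max(log_phat)                      # subtract the max for numerical stability
    w = [math.exp(v - m) for v in log_phat]
    s = sum(w)
    return [v / s for v in w]


if __name__ == "__main__":
    q2 = [0.3, 0.3, 0.4]
    best = optimal_q1(q2)
    print(best, elbo(best, q2), elbo([0.5, 0.5], q2))
```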
Recall that in the generalized EM algorithm we fix $\theta$ and then look for the $q$ closest to $p$; variational inference can be used for this step. We have:
$$\log p_\theta(x)=\underbrace{\text{ELBO}}_{L(q)}+\underbrace{KL(q\,\|\,p)}_{\geq 0}\geq L(q)$$
Then we solve for $q$:
$$\hat{q}=\underset{q}{\arg\min}\;KL(q\,\|\,p)=\underset{q}{\arg\max}\;L(q)$$
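Applying the mean-field variational inference above, we obtain the following result, where equality holds up to an additive normalizing constant (note that $z_i$ here denotes the $i$-th group of variables, not the $i$-th dimension of $z$):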
$$\begin{aligned}\log q_j(z_j)&=E_{\prod_{i\neq j}^{M}q_i(z_i)}[\log p_\theta(x,z)]\\&=\int_{z_1}\int_{z_2}\cdots\int_{z_{j-1}}\int_{z_{j+1}}\cdots\int_{z_M}q_1q_2\cdots q_{j-1}q_{j+1}\cdots q_M\,\log p_\theta(x,z)\,\mathrm{d}z_1\mathrm{d}z_2\cdots\mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots\mathrm{d}z_M\end{aligned}$$
One iteration proceeds as follows:
$$\begin{aligned}\log\hat{q}_1(z_1)&=\int_{z_2}\cdots\int_{z_M}q_2\cdots q_M\,\log p_\theta(x,z)\,\mathrm{d}z_2\cdots\mathrm{d}z_M\\ \log\hat{q}_2(z_2)&=\int_{z_1}\int_{z_3}\cdots\int_{z_M}\hat{q}_1q_3\cdots q_M\,\log p_\theta(x,z)\,\mathrm{d}z_1\mathrm{d}z_3\cdots\mathrm{d}z_M\\&\;\;\vdots\\ \log\hat{q}_M(z_M)&=\int_{z_1}\cdots\int_{z_{M-1}}\hat{q}_1\cdots\hat{q}_{M-1}\,\log p_\theta(x,z)\,\mathrm{d}z_1\cdots\mathrm{d}z_{M-1}\end{aligned}$$
Each $q_j(z_j)$ is thus computed with all other $q_i(z_i)$ held fixed, so the solution can be found iteratively by coordinate ascent. The derivation above is for a single sample, but it applies equally to a whole data set.
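As a sketch of the coordinate ascent loop, assume a toy target that is not in the text: $p(z_1,z_2)$ is a 2-D Gaussian with mean $\mu$ and precision matrix $\Lambda$. For this case the update $\log q_j(z_j)=E_{q_i}[\log p]+\text{const}$ gives Gaussian factors with precision $\Lambda_{jj}$ and mean $m_j=\mu_j-\Lambda_{jj}^{-1}\Lambda_{ji}(m_i-\mu_i)$ (a standard textbook result, taken as given here rather than derived above). Each pass updates $q_1$ with $q_2$ fixed, then $q_2$ with the new $q_1$; the name `cavi_gaussian` is illustrative.

```python
# Toy target (assumption, not from the text): p(z1, z2) = N(mu, inv(lam)).
# Mean-field: q(z) = q1(z1) q2(z2). For a Gaussian target the update
# log q_j = E_{q_i}[log p] + const gives q_j = N(m_j, 1 / lam[j][j]) with
# m_j = mu_j - lam[j][i] / lam[j][j] * (m_i - mu_i).


def cavi_gaussian(mu, lam, m_init=(0.0, 0.0), n_iters=50):
    """Coordinate ascent over the two factors; returns ((m1, var1), (m2, var2))."""
    m1, m2 = m_init
    for _ in range(n_iters):
        # update q1 with q2 held fixed
        m1 = mu[0] - lam[0][1] / lam[0][0] * (m2 - mu[1])
        # update q2 with the freshly updated q1 held fixed
        m2 = mu[1] - lam[1][0] / lam[1][1] * (m1 - mu[0])
    # the factor variances do not depend on the other factor's mean
    return (m1, 1.0 / lam[0][0]), (m2, 1.0 / lam[1][1])


if __name__ == "__main__":
    mu = (1.0, -1.0)
    lam = ((2.0, 0.8), (0.8, 1.0))  # precision matrix, must be positive definite
    print(cavi_gaussian(mu, lam))
```

The means converge to the true mean, but the factor variances $1/\Lambda_{jj}$ (0.5 and 1.0 here) are smaller than the true marginal variances (about 0.74 and 1.47), a known effect of the mean-field assumption when the variables are correlated.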
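Variational inference under the mean-field assumption has some problems:
① The assumption is too strong; for very complex models it does not hold.
② The integrals inside the expectations may not be computable.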
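The process from $Z$ to $X$ is called the generative process, or decoding; the process from $X$ to $Z$ is called the inference process, or encoding. Mean-field variational inference leads to a coordinate ascent algorithm, but, as noted above, its assumption is too strong in some cases and the integrals may not be computable. Besides coordinate ascent, optimization can also use gradient ascent, so we would like to derive another variational inference algorithm based on gradient ascent.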
Assume $q(Z)=q_\phi(Z)$ is a probability distribution parameterized by $\phi$. Then $\underset{q(Z)}{\arg\max}\;L(q)=\underset{\phi}{\arg\max}\;L(\phi)$, where $L(\phi)=E_{q_\phi}[\log p_\theta(x,z)-\log q_\phi(z)]$ and $x$ denotes a sample. Its gradient is:
$$\begin{aligned}\nabla_\phi L(\phi)&=\nabla_\phi E_{q_\phi}[\log p_\theta(x,z)-\log q_\phi(z)]\\&=\nabla_\phi\int q_\phi(z)[\log p_\theta(x,z)-\log q_\phi(z)]\,\mathrm{d}z\\&=\underbrace{\int\nabla_\phi q_\phi(z)\,[\log p_\theta(x,z)-\log q_\phi(z)]\,\mathrm{d}z}_{\text{①}}+\underbrace{\int q_\phi(z)\,\nabla_\phi[\log p_\theta(x,z)-\log q_\phi(z)]\,\mathrm{d}z}_{\text{②}}\end{aligned}$$

where

$$\begin{aligned}\text{②}&=\int q_\phi(z)\,\nabla_\phi[\underbrace{\log p_\theta(x,z)}_{\text{independent of }\phi}-\log q_\phi(z)]\,\mathrm{d}z=-\int q_\phi(z)\,\nabla_\phi\log q_\phi(z)\,\mathrm{d}z\\&=-\int q_\phi(z)\frac{1}{q_\phi(z)}\nabla_\phi q_\phi(z)\,\mathrm{d}z=-\int\nabla_\phi q_\phi(z)\,\mathrm{d}z=-\nabla_\phi\int q_\phi(z)\,\mathrm{d}z=-\nabla_\phi 1=0\end{aligned}$$

Therefore, using $\nabla_\phi q_\phi(z)=q_\phi(z)\nabla_\phi\log q_\phi(z)$:

$$\begin{aligned}\nabla_\phi L(\phi)=\text{①}&=\int{\color{Red}{\nabla_\phi q_\phi(z)}}\,[\log p_\theta(x,z)-\log q_\phi(z)]\,\mathrm{d}z\\&=\int{\color{Red}{q_\phi(z)\nabla_\phi\log q_\phi(z)}}\,[\log p_\theta(x,z)-\log q_\phi(z)]\,\mathrm{d}z\\&=E_{q_\phi}[(\nabla_\phi\log q_\phi(z))(\log p_\theta(x,z)-\log q_\phi(z))]\end{aligned}$$
This expectation can be approximated by Monte Carlo sampling, which gives a gradient estimate; gradient ascent then updates the parameters:
$$z^{(l)}\sim q_\phi(z),\quad l=1,\dots,L$$

$$E_{q_\phi}[(\nabla_\phi\log q_\phi(z))(\log p_\theta(x,z)-\log q_\phi(z))]\approx\frac{1}{L}\sum_{l=1}^{L}(\nabla_\phi\log q_\phi(z^{(l)}))(\log p_\theta(x,z^{(l)})-\log q_\phi(z^{(l)}))$$
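A minimal sketch of this score-function estimator, assuming a hypothetical conjugate toy model (prior $z\sim N(0,1)$, likelihood $x_i\sim N(z,1)$) and a Gaussian variational family $q_\phi(z)=N(\mu,\sigma^2)$ with $\phi=(\mu,\sigma)$. Because everything is Gaussian, the exact ELBO gradient has a closed form (`exact_grad`) for comparison; the data values and function names are illustrative.

```python
import math
import random

# Hypothetical toy model (assumption, not from the text):
# prior z ~ N(0, 1), likelihood x_i ~ N(z, 1); q_phi(z) = N(mu, sigma^2).
DATA = [0.5, 1.5, 1.0]


def log_joint(z, data=DATA):
    """log p_theta(x, z) up to an additive constant."""
    return -0.5 * z * z - 0.5 * sum((xi - z) ** 2 for xi in data)


def log_q(z, mu, sigma):
    """Log density of N(mu, sigma^2) at z."""
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * ((z - mu) / sigma) ** 2


def score_function_grad(mu, sigma, n_samples=200_000, seed=0):
    """Estimate E_q[(grad_phi log q)(log p - log q)] from samples z^(l) ~ q_phi."""
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mu, sigma)                               # z^(l) ~ q_phi(z)
        f = log_joint(z) - log_q(z, mu, sigma)                 # log p - log q
        d_mu = (z - mu) / sigma ** 2                           # d/dmu log q_phi(z)
        d_sigma = -1.0 / sigma + (z - mu) ** 2 / sigma ** 3    # d/dsigma log q_phi(z)
        g_mu += d_mu * f
        g_sigma += d_sigma * f
    return g_mu / n_samples, g_sigma / n_samples


def exact_grad(mu, sigma, data=DATA):
    """Closed-form ELBO gradient for this Gaussian toy model, for checking only."""
    n, s = len(data), sum(data)
    return s - (n + 1) * mu, 1.0 / sigma - (n + 1) * sigma


if __name__ == "__main__":
    print(score_function_grad(0.0, 1.0))  # noisy estimate
    print(exact_grad(0.0, 1.0))           # (3.0, -3.0)
```

A large number of samples is used here because the individual terms are widely spread around their mean, which is the variance problem described next.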
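However, each summand multiplies the score $\nabla_\phi\log q_\phi(z)$ by log terms such as $\log q_\phi(z)$, which become very large in magnitude where $q_\phi(z)$ is close to zero. As a result, the direct Monte Carlo estimate has high variance and needs a very large number of samples. To reduce the variance we use the reparameterization trick.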
We write $z=g_\phi(\varepsilon,x)$ with $\varepsilon\sim p(\varepsilon)$, where $g_\phi$ is a deterministic function that is differentiable in $\phi$ and is chosen so that $z\sim q_\phi(z\mid x)$. Then $|q_\phi(z\mid x)\,\mathrm{d}z|=|p(\varepsilon)\,\mathrm{d}\varepsilon|$. Substituting this into the gradient above:
$$\begin{aligned}\nabla_\phi L(\phi)&=\nabla_\phi E_{q_\phi}[\log p_\theta(x,z)-\log q_\phi(z)]\\&=\nabla_\phi\int[\log p_\theta(x,z)-\log q_\phi(z)]\,q_\phi(z)\,\mathrm{d}z\\&=\nabla_\phi\int[\log p_\theta(x,z)-\log q_\phi(z)]\,p(\varepsilon)\,\mathrm{d}\varepsilon\\&=\nabla_\phi E_{p(\varepsilon)}[\log p_\theta(x,z)-\log q_\phi(z)]\\&=E_{p(\varepsilon)}[\nabla_\phi(\log p_\theta(x,z)-\log q_\phi(z))]\\&=E_{p(\varepsilon)}[\nabla_z(\log p_\theta(x,z)-\log q_\phi(z))\,\nabla_\phi z]\\&=E_{p(\varepsilon)}[\nabla_z(\log p_\theta(x,z)-\log q_\phi(z))\,\nabla_\phi g_\phi(\varepsilon,x)]\end{aligned}$$
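In the chain-rule step, $\log q_\phi(z)$ also depends on $\phi$ directly, not only through $z$; that extra term is $-E_{q_\phi}[\nabla_\phi\log q_\phi(z)]=0$ by the same argument used for ② above, so it can be dropped. Drawing Monte Carlo samples $\varepsilon^{(l)}\sim p(\varepsilon)$ and averaging the expression inside the expectation then gives the gradient.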
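The reparameterized estimator can be sketched with an assumed toy model of the same kind (prior $z\sim N(0,1)$, likelihood $x_i\sim N(z,1)$, hypothetical data) and $z=g_\phi(\varepsilon)=\mu+\sigma\varepsilon$ with $\varepsilon\sim N(0,1)$, so $\nabla_\mu g=1$ and $\nabla_\sigma g=\varepsilon$. The code follows the last line of the derivation directly; `grad_log_joint`, `reparam_grad`, and `exact_grad` are illustrative names.

```python
import random

# Hypothetical toy model (assumption, not from the text):
# prior z ~ N(0, 1), likelihood x_i ~ N(z, 1); q_phi(z) = N(mu, sigma^2).
DATA = [0.5, 1.5, 1.0]


def grad_log_joint(z, data=DATA):
    """d/dz log p_theta(x, z) = -z + sum_i (x_i - z)."""
    return -z + sum(xi - z for xi in data)


def reparam_grad(mu, sigma, n_samples=10_000, seed=0):
    """Estimate E_eps[grad_z(log p - log q) * grad_phi g_phi(eps)],
    with z = g_phi(eps) = mu + sigma * eps and eps ~ N(0, 1)."""
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)              # eps^(l) ~ p(eps)
        z = mu + sigma * eps                   # z = g_phi(eps)
        # grad_z log q_phi(z) = -(z - mu) / sigma^2 = -eps / sigma
        dz = grad_log_joint(z) + eps / sigma   # grad_z (log p - log q)
        g_mu += dz * 1.0                       # grad_mu g = 1
        g_sigma += dz * eps                    # grad_sigma g = eps
    return g_mu / n_samples, g_sigma / n_samples


def exact_grad(mu, sigma, data=DATA):
    """Closed-form ELBO gradient for this Gaussian toy model, for checking only."""
    n, s = len(data), sum(data)
    return s - (n + 1) * mu, 1.0 / sigma - (n + 1) * sigma


if __name__ == "__main__":
    print(reparam_grad(0.0, 1.0), exact_grad(0.0, 1.0))
    print(reparam_grad(0.75, 0.5))  # exact posterior of the toy model
```

At $\mu=0.75,\ \sigma=0.5$, which is the exact posterior of this toy model, $\log p_\theta(x,z)-\log q_\phi(z)$ is constant in $z$, so every sample contributes a zero gradient (up to rounding).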
The SGVI (stochastic gradient variational inference) iteration, with step size $\lambda^t$, is:
$$\phi^{t+1}\leftarrow\phi^{t}+\lambda^{t}\cdot\nabla_\phi L(\phi)$$
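Putting the pieces together, here is a sketch of the SGVI loop on the same kind of assumed Gaussian toy model. Two choices are assumptions, not from the text: $\sigma$ is parameterized as $e^{\omega}$ so that gradient ascent cannot make it negative (then $\nabla_\omega g=\sigma\varepsilon$), and a constant step size stands in for the schedule $\lambda^t$. The function name `sgvi` is illustrative.

```python
import math
import random

# Hypothetical toy model (assumption, not from the text):
# prior z ~ N(0, 1), likelihood x_i ~ N(z, 1).
DATA = [0.5, 1.5, 1.0]


def grad_log_joint(z, data=DATA):
    """d/dz log p_theta(x, z) = -z + sum_i (x_i - z)."""
    return -z + sum(xi - z for xi in data)


def sgvi(n_steps=2000, n_samples=10, step_size=0.05, seed=0):
    """Gradient ascent on the ELBO with reparameterized Monte Carlo gradients.
    q_phi(z) = N(mu, sigma^2), phi = (mu, omega), sigma = exp(omega)."""
    rng = random.Random(seed)
    mu, omega = 0.0, 0.0
    for _step in range(n_steps):
        sigma = math.exp(omega)
        g_mu = g_omega = 0.0
        for _ in range(n_samples):
            eps = rng.gauss(0.0, 1.0)               # eps ~ p(eps)
            z = mu + sigma * eps                    # z = g_phi(eps)
            dz = grad_log_joint(z) + eps / sigma    # grad_z (log p - log q)
            g_mu += dz                              # grad_mu g = 1
            g_omega += dz * sigma * eps             # grad_omega g = sigma * eps
        # phi^{t+1} <- phi^t + lambda * grad (constant step size assumed)
        mu += step_size * g_mu / n_samples
        omega += step_size * g_omega / n_samples
    return mu, math.exp(omega)


if __name__ == "__main__":
    mu, sigma = sgvi()
    print(mu, sigma)  # should approach the exact posterior N(0.75, 0.5^2)
```

Since the Gaussian family contains the exact posterior $N(0.75,\,0.5^2)$ of this toy model, the iterates should approach $\mu=0.75$ and $\sigma=0.5$; for a non-conjugate model the loop is unchanged, but it converges to the best Gaussian approximation instead.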