Here we only give the mathematical derivation of variational inference (variational methods are a fairly deep topic; this is just a brief introduction to the basic concepts. For more details, see: https://blog.csdn.net/weixin_40255337/article/details/83088786):
The goal of variational inference is to construct a distribution $q(w | \theta)$ and, by optimizing over $\theta$, find the best $\theta^*$ so that $q(w | \theta)$ approximates the unknown posterior distribution $P(w | X)$.
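In other words (anticipating the KL divergence derived below), the optimization problem can be written as:

$$\theta^* = \arg \min_\theta KL\{ q(w | \theta) \,||\, P(w | X) \}$$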
By Bayes' theorem:

$$P(X) = \frac{P(X, w)}{P(w | X)}$$
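Spelling out this rearrangement: Bayes' theorem gives $P(w | X) = \frac{P(X | w) P(w)}{P(X)} = \frac{P(X, w)}{P(X)}$, and solving for $P(X)$ yields the identity above.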
Taking the logarithm of both sides:

$$\log P(X) = \log P(X, w) - \log P(w | X)$$
Adding $\log q(w | \theta)$ to and then subtracting $\log q(w | \theta)$ from the right-hand side:
$$\log P(X) = \log \frac{P(X, w)}{q(w | \theta)} - \log \frac{P(w | X)}{q(w | \theta)}$$
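This is simply a regrouping of the added and subtracted terms:

$$\log P(X) = \big( \log P(X, w) - \log q(w | \theta) \big) - \big( \log P(w | X) - \log q(w | \theta) \big)$$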
Taking the expectation of both sides with respect to $w$ (which follows the distribution $q(w | \theta)$), and noting that the left-hand side does not depend on $q(w | \theta)$, we have:
$$\log P(X) = \mathbb{E} \big[ \log P(X | w) + \log P(w) - \log q(w | \theta) \big] + \mathbb{E} \left[ \log \frac{q(w | \theta)}{P(w | X)} \right]$$
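Spelling out this step (all expectations are over $w \sim q(w | \theta)$): the first term uses the factorization $P(X, w) = P(X | w) P(w)$, and the minus sign flips the second ratio:

$$\mathbb{E} \left[ \log \frac{P(X, w)}{q(w | \theta)} \right] = \mathbb{E} \big[ \log P(X | w) + \log P(w) - \log q(w | \theta) \big], \qquad -\mathbb{E} \left[ \log \frac{P(w | X)}{q(w | \theta)} \right] = \mathbb{E} \left[ \log \frac{q(w | \theta)}{P(w | X)} \right]$$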
The left-hand side is a constant. On the right-hand side, the first term is defined as the ELBO (evidence lower bound), and the second term is $KL\{ q(w | \theta) \,||\, P(w | X) \}$, i.e.:
$$ELBO + KL\{ q(w | \theta) \,||\, P(w | X) \} = Constant$$
Therefore:

$$\arg \min_\theta KL\{ q(w | \theta) \,||\, P(w | X) \} = \arg \max_\theta ELBO$$
The ELBO in the original paper also takes another form:
$$\begin{aligned} ELBO &= \mathbb{E} \big[ \log P(X | w) + \log P(w) - \log q(w | \theta) \big] \\ &= \mathbb{E} \big[ \log P(X | w) \big] - \mathbb{E} \left[ \log \frac{q(w | \theta)}{P(w)} \right] \\ &= \int q(w | \theta) \log P(X | w) \, dw - KL\{ q(w | \theta) \,||\, P(w) \} \end{aligned}$$
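In practice the expected log-likelihood term is intractable for a neural network, so it is usually approximated by Monte Carlo sampling from $q(w | \theta)$; this estimator is the standard choice (not spelled out above):

$$ELBO \approx \frac{1}{N} \sum_{i=1}^{N} \Big[ \log P(X | w^{(i)}) + \log P(w^{(i)}) - \log q(w^{(i)} | \theta) \Big], \qquad w^{(i)} \sim q(w | \theta)$$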
For a deep learning problem, given a dataset $D$, the parameters of the neural network are $w$, and its output gives $P(D | w)$.
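To make the ELBO concrete, here is a minimal sketch (not from the original post): it estimates the ELBO by Monte Carlo for a toy one-dimensional model, assuming a standard normal prior $P(w)$, a unit-variance Gaussian likelihood $P(X | w)$, and a Gaussian variational posterior $q(w | \theta)$ with $\theta = (\mu, \sigma)$. All names such as `elbo_estimate` are illustrative only.

```python
# Minimal sketch: Monte Carlo estimate of the ELBO for a toy 1-D model.
# Assumptions (not from the original post): prior P(w) = N(0, 1),
# likelihood P(x | w) = N(x; w, 1), variational posterior q(w | theta) = N(mu, sigma^2).
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=2.0, scale=1.0, size=50)   # observed data

def log_likelihood(X, w):
    # log P(X | w) with unit observation noise
    return np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * (X - w) ** 2)

def kl_gaussian(mu, sigma):
    # closed-form KL{ N(mu, sigma^2) || N(0, 1) }
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - 2.0 * np.log(sigma))

def elbo_estimate(X, mu, sigma, n_samples=1000):
    # ELBO = E_q[ log P(X | w) ] - KL{ q(w | theta) || P(w) },
    # with the expectation approximated by sampling w ~ q(w | theta)
    w_samples = mu + sigma * rng.normal(size=n_samples)
    expected_ll = np.mean([log_likelihood(X, w) for w in w_samples])
    return expected_ll - kl_gaussian(mu, sigma)

# The ELBO is larger when q(w | theta) is closer to the true posterior:
print(elbo_estimate(X, mu=2.0, sigma=0.2))    # near the data mean -> higher ELBO
print(elbo_estimate(X, mu=-1.0, sigma=0.2))   # far from the data mean -> lower ELBO
```

Maximizing this estimate over $\mu$ and $\sigma$ (e.g. with gradients via the reparameterization $w = \mu + \sigma \epsilon$) is exactly the $\arg \max_\theta ELBO$ problem derived above.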