它妙就妙在它为每个输入
x
x
x, 生成了一个潜在概率分布
p
(
z
∣
x
)
p(z|x)
p(z∣x),然后再从分布中进行随机采样,从而得到了连续完整的潜在空间,解决了AE中无法用于生成的问题。
VAE除了能让我们能够自己产生随机的潜在变量,这种约束也能提高网络的产生图片的能力。
但是,VAE的一个劣势就是没有使用对抗网络,所以VAE会更趋向于产生模糊的图片。
4.1 变分推断
变分自编码器(VAE)的想法和名字的由来便是变分推断了,那么什么是变分推断呢?
变分推断是MCMC搞不定场景的一种替代算法,它考虑一个贝叶斯推断问题,给定观测变量
x
∈
R
k
x \in \mathbb{R}^k
x∈Rk 和潜变量
z
∈
R
d
z \in \mathbb{R}^d
z∈Rd,其联合概率分布为
p
(
z
,
x
)
=
p
(
z
)
p
(
x
∣
z
)
p(z, x) = p(z)p(x|z)
p(z,x)=p(z)p(x∣z) , 目标是计算后验分布
p
(
z
∣
x
)
p(z|x)
p(z∣x)。然后我们可以假设一个变分分布
q
(
z
)
q(z)
q(z) 来自分布族
Q
Q
Q,通过最小化KL散度来近似后验分布
p
(
z
∣
x
)
p(z|x)
p(z∣x) :
q
∗
=
arg min
q
(
z
)
∈
Q
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
\begin{aligned}q^* = \argmin_{q(z) \in Q} KL(q(z)||p(z|x))\end{aligned}
q∗=q(z)∈QargminKL(q(z)∣∣p(z∣x))
这么一来,就成功的将一个贝叶斯推断问题转化为了一个优化问题~
4.2 变分推导过程
有了变分推断的认知,我们再回过头去看一下VAE模型的整体框架,VAE就是将AE的编码和解码过程转化为了一个贝叶斯概率模型:我们的训练数据即为观测变量
x
x
x, 假设它由不能直接观测到的潜变量
z
z
z 生成, 于是,生成观测变量过程便是似然分布:
p
(
x
∣
z
)
p(x|z)
p(x∣z) ,也就是解码器,因而编码器自然就是后验分布:
p
(
z
∣
x
)
p(z|x)
p(z∣x) .
根据贝叶斯公式,建立先验、后验和似然的关系:
p
(
z
∣
x
)
=
p
(
x
∣
z
)
p
(
z
)
p
(
x
)
=
∫
z
p
(
x
∣
z
)
p
(
z
)
p
(
x
)
d
z
p(z|x) = \frac{p(x|z)p(z)}{p(x)} = \int_z \frac{p(x|z)p(z)}{p(x)}dz
p(z∣x)=p(x)p(x∣z)p(z)=∫zp(x)p(x∣z)p(z)dz
接下来,基于上面变分推断的思想,我们假设变分分布
q
x
(
z
)
q_x(z)
qx(z) , 通过最小化KL散度来近似后验分布
p
(
z
∣
x
)
p(z|x)
p(z∣x) ,于是,最佳的
q
x
∗
q_x^*
qx∗便是:
q
x
∗
=
a
r
g
m
i
n
(
K
L
(
q
x
(
z
)
∣
∣
p
(
z
∣
x
)
)
=
a
r
g
m
i
n
(
E
q
x
(
z
)
[
l
o
g
q
x
(
z
)
−
l
o
g
p
(
x
∣
z
)
−
l
o
g
p
(
z
)
]
+
l
o
g
p
(
x
)
)
\begin{aligned} q_x^* &= argmin (KL(q_x(z)||p(z|x)) \\ &= argmin (E_{q_x(z)}[log~q_x(z)- log~p(x|z) -log~p(z)]+log~p(x)) \\ \end{aligned}
qx∗=argmin(KL(qx(z)∣∣p(z∣x))=argmin(Eqx(z)[logqx(z)−logp(x∣z)−logp(z)]+logp(x))
因为训练数据
x
x
x 是确定的,因此
l
o
g
p
(
x
)
log~p(x)
logp(x) 是一个常数,于是上面的优化问题等价于:
q
x
∗
=
a
r
g
m
i
n
(
E
q
x
(
z
)
[
l
o
g
q
x
(
z
)
−
l
o
g
p
(
x
∣
z
)
−
l
o
g
p
(
z
)
]
=
a
r
g
m
i
n
(
E
q
x
(
z
)
[
−
l
o
g
p
(
x
∣
z
)
+
(
l
o
g
p
(
z
)
−
l
o
g
q
x
(
z
)
)
]
)
=
a
r
g
m
i
n
(
E
q
x
(
z
)
[
−
l
o
g
p
(
x
∣
z
)
+
K
L
(
q
x
(
z
)
∣
∣
p
(
z
)
)
]
)
\begin{aligned} q_x^* &= argmin (E_{q_x(z)}[log~q_x(z)- log~p(x|z) -log~p(z)] \\ &= argmin( E_{q_x(z)}[-log~p(x|z) + (log~p(z) -log~q_x(z))]) \\ &= argmin (E_{q_x(z)}[-log~p(x|z) + KL(q_x(z)||p(z))]) \\ \end{aligned}
qx∗=argmin(Eqx(z)[logqx(z)−logp(x∣z)−logp(z)]=argmin(Eqx(z)[−logp(x∣z)+(logp(z)−logqx(z))])=argmin(Eqx(z)[−logp(x∣z)+KL(qx(z)∣∣p(z))])
此时,观察一下优化方程的形式…已经是我们前面所说的VAE的损失函数了~~
显然,跟我们希望解码准确的目标是一致的。要解码的准,则
p
(
x
∣
z
)
p(x|z)
p(x∣z) 应该尽可能的小,编码特征
z
z
z 的分布
q
x
(
z
)
q_x(z)
qx(z) 同
p
(
z
)
p(z)
p(z) 尽可能的接近,此时恰好
−
l
o
g
p
(
x
∣
z
)
-log~p(x|z)
−logp(x∣z) 和
K
L
(
q
x
(
z
)
∣
∣
p
(
z
)
)
KL(q_x(z)||p(z))
KL(qx(z)∣∣p(z)) 都尽可能的小,与损失的优化的目标也一致。
4.3 如何计算极值
正如前面所提到的AE潜变量的局限性,我们希望VAE的潜变量分布
p
(
z
)
p(z)
p(z) 应该能满足海量的输入数据
x
x
x 并且相互独立,基于中心极限定理,以及为了方便采样,我们有理由直接假设
p
(
z
)
p(z)
p(z) 是一个标准的高斯分布
N
(
0
,
1
)
\mathcal{N}(0,1)
N(0,1) .
4.4 编码部分
我们先来看一下编码部分,我们希望拟合一个分布
q
x
(
z
)
=
N
(
μ
,
σ
)
q_x(z)=\mathcal{N}(\mu,\sigma)
qx(z)=N(μ,σ) 尽可能接近
p
(
z
)
=
N
(
0
,
1
)
p(z) =\mathcal{N}(0,1)
p(z)=N(0,1), 关键就在于基于输入
x
x
x 计算
μ
\mu
μ 和
σ
\sigma
σ, 直接算有点困难,于是就使用两个神经网络
f
(
x
)
f(x)
f(x) 和
g
(
x
)
g(x)
g(x) 来无脑拟合
μ
\mu
μ 和
σ
\sigma
σ。
值得一提的是,很多地方实际使用的
f
(
x
)
f(x)
f(x)、
g
(
x
)
g(x)
g(x) 两部分神经网络并不是独立的,而是有一部分交集,即他们都先通过一个
h
(
x
)
h(x)
h(x) 映射到一个中间层
h
h
h, 然后分别对
h
h
h 计算
f
(
h
)
f(h)
f(h) 和
g
(
h
)
g(h)
g(h). 这样错的好处的话一方面是可以减少参数数量,另外这样算应该会导致拟合的效果差一些,算是防止过拟合吧。
4.5 解码部分
解码,即从潜变量
z
z
z 生成数据
x
x
x 的过程,在于最大化似然
p
(
x
∣
z
)
p(x|z)
p(x∣z) ,那这应该是个什么分布呢?通常我们假设它是一个伯努利分布或是高斯分布。
知道了分布类型,那计算
−
l
o
g
p
(
x
∣
z
)
-log~p(x|z)
−logp(x∣z) 最小值其实只要把分布公式带进去算就可以了.
4.5.1 高斯分布
a
r
g
m
i
n
(
−
log
q
(
x
∣
z
)
)
=
a
r
g
m
i
n
1
2
∥
x
−
μ
~
(
z
)
σ
~
(
z
)
∥
2
+
c
2
log
2
π
+
1
2
=
a
r
g
m
i
n
1
2
∥
x
−
μ
~
(
z
)
σ
~
(
z
)
∥
2
\begin{aligned} arg min( -\log~q(x|z)) &=argmin \frac{1}{2}\left\Vert\frac{x-\tilde{\mu}(z)}{\tilde{\sigma}(z)}\right\Vert^2 + \frac{c}{2}\log~2\pi + \frac{1}{2} \\ &= argmin \frac{1}{2}\left\Vert\frac{x-\tilde{\mu}(z)}{\tilde{\sigma}(z)}\right\Vert^2 \end{aligned}
argmin(−logq(x∣z))=argmin21∥∥σ~(z)x−μ~(z)∥∥2+2clog2π+21=argmin21∥∥σ~(z)x−μ~(z)∥∥2
和预期一样,演变为了均方误差。
4.5.2 伯努利分布
假设伯努利的二元分布是
P
P
P 和
1
−
P
1-P
1−P (注意这里是输出没一维组成的向量)
a
r
g
m
i
n
(
−
log
q
(
x
∣
z
)
)
=
a
r
g
m
i
n
(
−
x
log
P
−
(
1
−
x
)
log
(
1
−
P
)
)
argmin ( -\log q(x|z)) = argmin( - x \log P - (1-x) \log (1 -P))
argmin(−logq(x∣z))=argmin(−xlogP−(1−x)log(1−P))
正好就是交叉熵的损失。然后,将编码和解码部分组合到一起,就形成了完整的VAE网络。
4.6 Reparameterization trick
训练的时候似乎出了点问题。从编码得到的分布
N
(
μ
,
σ
)
\mathcal{N}(\mu,\sigma)
N(μ,σ) 随机采样
z
z
z 的这个过程没法求导,没法进行误差反向传播。
Reparameterization trick :解决VAE 的 Encoder 产生的分布函数无梯度可优化问题
z
=
μ
+
ε
⋅
σ
=
f
(
x
)
+
ε
⋅
g
(
x
)
,
ε
∼
N
(
0
,
1
)
z = μ+ε·σ=f(x) +ε·g(x) , ε\sim N(0,1)
z=μ+ε⋅σ=f(x)+ε⋅g(x),ε∼N(0,1)