梯度消失和梯度爆炸的定义
- 梯度消失:又叫梯度弥散。是指模型梯度在反向传播时,梯度值接近零,导致模型权重不能正常更新,使模型无法正常收敛的现象;
- 梯度爆炸:是指模型梯度在反向传播使,梯度值无限扩大,导致模型权重趋于无穷,使模型无法正常收敛的现象。常常伴随着loss为nan的现象。
数学公式解释
梯度消失和梯度爆炸都可以用
y
=
(
x
)
n
y={\left( {\rm{x}} \right)^{\rm{n}}}
y=(x)n 来解释,其中n表示模型层数,当n很大且x大于1时,y会趋于无穷大;而当x小于1时,随着n增大,y会趋于无穷小。从此处我们也可以看出,无论是梯度爆炸还是梯度消失,都是发生在远离输出的底层网络。
那现在的问题是,x在梯度反向传播时表示什么?什么情况会导致x>0; 或者x < 0 呢
链路法则下的x
链路法则下的求导可以分为两步,第一步是损失函数对logit的导数;第二步是当前层输出(logit是最后一层的输出)对前一层输出的导数;
损失函数对logit的导数
我们以交叉熵损失函数为例:设分类模型共
m
m
m个类别,其中计算梯度的样本标签为
k
k
k,损失函数为
L
o
s
s
Loss
Loss,最后一层模型输出logit为
x
x
x, 则损失函数可以表示为:
y
k
=
e
x
k
∑
i
=
1
m
e
x
i
{y_k} = \frac{{{e^{{x_k}}}}}{{\sum\limits_{i = 1}^m {{e^{{x_i}}}} }}
yk=i=1∑mexiexk
L
o
s
s
=
−
log
(
y
k
)
Loss = - \log ({y_k})
Loss=−log(yk)
对其求导可得:
∂
L
o
s
s
∂
x
=
[
y
1
,
y
2
,
.
.
.
y
k
−
1
,
.
.
.
y
m
]
\frac{{\partial Loss}}{{\partial x}} = [{y_1},{y_2},...{y_{k}-1},...{y_m}]
∂x∂Loss=[y1,y2,...yk−1,...ym]
当前层输出和前一层输出的导数
我们假设,每一层由一个激活函数
f
f
f和一个全连接层
W
x
Wx
Wx构成, 则
x
n
=
f
(
W
x
n
−
1
)
{{\text{x}}_n} = f(W{x_{n - 1}})
xn=f(Wxn−1), 对其求导:
∂
x
n
∂
x
n
−
1
=
f
−
1
∗
W
\frac{{\partial {x_n}}}{{\partial {x_{n - 1}}}} = {f^{ - 1}}*W
∂xn−1∂xn=f−1∗W,则通过递归法则可知,
∂
x
n
∂
x
n
−
k
=
(
f
−
1
)
k
∗
W
k
\frac{{\partial {x_n}}}{{\partial {x_{n - k}}}} = ({f^{ - 1}})^k*W^k
∂xn−k∂xn=(f−1)k∗Wk, 可知,此处的
∂
x
n
∂
x
n
−
1
\frac{{\partial {x_n}}}{{\partial {x_{n - 1}}}}
∂xn−1∂xn就近似等于前面提到的链路法则下的
x
x
x,当激活函数的导数
f
−
1
f^{ - 1}
f−1小于1时,
∂
x
n
∂
x
n
−
k
\frac{{\partial {x_n}}}{{\partial {x_{n - k}}}}
∂xn−k∂xn倾向于趋近零(对应梯度消失),而当
W
W
W大于1时,
∂
x
n
∂
x
n
−
k
\frac{{\partial {x_n}}}{{\partial {x_{n - k}}}}
∂xn−k∂xn倾向于趋近无穷大(对应梯度爆炸)
如何解决梯度消失和梯度爆炸
解决梯度消失
- 修改激活函数,使
f
−
1
f^{-1}
f−1不小于1,比如,采用relu替换sigmoid激活函数
- 残差连接,将
x
n
=
f
(
W
x
n
−
1
)
{{\text{x}}_n} = f(W{x_{n - 1}})
xn=f(Wxn−1)修改为
x
n
=
f
(
W
x
n
−
1
+
x
n
−
1
)
{{\text{x}}_n} = f(W{x_{n - 1}}+x_{n - 1})
xn=f(Wxn−1+xn−1)
- 层连接输入前进行标准化
解决梯度爆炸
- 梯度裁剪。使每一层的梯度都小于1
- 权重正则化或者权重衰减。权重每一步迭代更新时,都乘以一个小于1的因子
- 层连接输入前进行标准化
- 预训练+模型微调(微调时,学习率一般设置较低)