张量流在梯度计算过程中如何处理不可微节点?

2024-04-14

我理解自动微分的概念,但找不到任何解释张量流如何计算不可微函数的误差梯度,例如tf.where在我的损失函数中或tf.cond在我的图表中。它工作得很好,但我想了解张量流如何通过这些节点反向传播误差,因为没有公式可以计算它们的梯度。


如果是tf.where,你有一个具有三个输入的函数,条件C, 真实值T和 false 的值F,和一个输出Out。梯度接收一个值并且必须返回三个值。目前,没有为该条件计算梯度(这几乎没有意义),因此您只需要计算梯度T and F。假设输入和输出是向量,想象一下C[0] is True. Then Out[0]来自T[0],并且它的梯度应该传播回来。另一方面,F[0]会被丢弃,所以它的梯度应该为零。如果Out[1] were False,那么梯度为F[1]应该传播但不适合T[1]。所以,简而言之,对于T你应该传播给定的梯度,其中C is True并使其为零False,则相反F。如果你看梯度的实现tf.where (Select手术) https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/python/ops/math_grad.py#L1112-L1118,它正是这样做的:

@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))

请注意,输入值本身并不用于计算,这将通过生成这些输入的操作的梯度来完成。为了tf.cond, 代码有点复杂 https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/python/ops/control_flow_grad.py#L95-L138,因为相同的操作(Merge)在不同的上下文中使用,并且tf.cond还使用Switch里面的操作。然而,想法是一样的。本质上,Switch操作用于每个输入,因此被激活的输入(如果条件是第一个输入)True否则第二个)获得接收到的梯度,另一个输入获得“关闭”梯度(例如None),并且不会进一步传播回来。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

张量流在梯度计算过程中如何处理不可微节点? 的相关文章

随机推荐