对“tf.cond”的行为感到困惑

2024-04-16

我的图中需要一个条件控制流。如果pred is True,图表应该调用一个更新变量然后返回它的操作,否则它会返回不变的变量。一个简化版本是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

然而我发现两者pred=True and pred=False导致相同的结果y=[2],这意味着当以下情况时也会调用分配操作update_x_2未被选择tf.cond。这怎么解释呢?以及如何解决这个问题呢?


TL;DR:如果你想tf.cond() https://www.tensorflow.org/versions/r0.8/api_docs/python/control_flow_ops.html#cond要在其中一个分支中执行副作用(如赋值),您必须创建执行副作用的操作inside您传递给的函数tf.cond().

的行为tf.cond()有点不直观。由于 TensorFlow 图中的执行在图中向前流动,因此您在either分支必须在评估条件之前执行。这意味着 true 和 false 分支都接收对tf.assign()欧普,等等y总是被设置为2,即使 pred 是False.

解决方案是创建tf.assign()op 位于定义 true 分支的函数内。例如,您可以按如下方式构建代码:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

对“tf.cond”的行为感到困惑 的相关文章

随机推荐