在 TensorFlow 图中使用 if 条件

2023-12-07

在张量流 CIFAR-10 中tutorial in cifar10_inputs.py第 174 行据说您应该随机化 random_contrast 和 random_brightness 操作的顺序,以获得更好的数据增强。

为此,我首先想到的是从 0 到 1 之间的均匀分布中抽取一个随机变量:p_order。并做:

if p_order>0.5:
  distorted_image=tf.image.random_contrast(image)
  distorted_image=tf.image.random_brightness(distorted_image)
else:
  distorted_image=tf.image.random_brightness(image)
  distorted_image=tf.image.random_contrast(distorted_image)

然而,有两种可能的选项来获取 p_order:

1)使用numpy,这让我不满意,因为我想要纯TF,而TF不鼓励用户混合numpy和tensorflow

2) 使用 TF,但是 p_order 只能在 tf.Session() 中计算 我真的不知道我是否应该这样做:

with tf.Session() as sess2:
  p_order_tensor=tf.random_uniform([1,],0.,1.)
  p_order=float(p_order_tensor.eval())

所有这些操作都在函数体内,并从具有不同会话/图形的另一个脚本运行。或者我可以将其他脚本中的图表作为参数传递给该函数,但我很困惑。 即使像这样的张量流函数或推理似乎以全局方式定义图形,而没有明确地将其作为输出返回,这一事实对我来说有点难以理解。


您可以使用tf.cond(pred, fn1, fn2, name=None) (see doc)。 该函数允许您使用布尔值pred在 TensorFlow 图中(无需调用self.eval() or sess.run(), hence 不需要会话).

以下是如何使用它的示例:

def fn1():
    distorted_image=tf.image.random_contrast(image)
    distorted_image=tf.image.random_brightness(distorted_image)
    return distorted_image
def fn2():
    distorted_image=tf.image.random_brightness(image)
    distorted_image=tf.image.random_contrast(distorted_image)
    return distorted_image

# Uniform variable in [0,1)
p_order = tf.random_uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
pred = tf.less(p_order, 0.5)

distorted_image = tf.cond(pred, fn1, fn2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 TensorFlow 图中使用 if 条件 的相关文章

随机推荐