这可以通过使用来实现tf.while_loop()
和标准tuples https://docs.python.org/3/tutorial/datastructures.html#tuples-and-sequences按照第二个例子文档 https://www.tensorflow.org/api_docs/python/tf/while_loop.
def rosenbrock(data_tensor):
columns = tf.unstack(data_tensor)
# Track both the loop index and summation in a tuple in the form (index, summation)
index_summation = (tf.constant(1), tf.constant(0.0))
# The loop condition, note the loop condition is 'i < n-1'
def condition(index, summation):
return tf.less(index, tf.subtract(tf.shape(columns)[0], 1))
# The loop body, this will return a result tuple in the same form (index, summation)
def body(index, summation):
x_i = tf.gather(columns, index)
x_ip1 = tf.gather(columns, tf.add(index, 1))
first_term = tf.square(tf.subtract(x_ip1, tf.square(x_i)))
second_term = tf.square(tf.subtract(x_i, 1.0))
summand = tf.add(tf.multiply(100.0, first_term), second_term)
return tf.add(index, 1), tf.add(summation, summand)
# We do not care about the index value here, return only the summation
return tf.while_loop(condition, body, index_summation)[1]
值得注意的是,索引增量应该发生在类似于标准 while 循环的循环体中。在给出的解决方案中,它是由返回的元组中的第一项body()
功能。
此外,循环条件函数必须为求和分配一个参数,尽管在此特定示例中未使用它。