我有一个张量logits
与尺寸[batch_size, num_rows, num_coordinates]
(即批次中的每个 logit 都是一个矩阵)。在我的例子中,批量大小为 2,有 4 行和 4 个坐标。
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0],
[12.0, 10.0, 10.0, 20.0],
[13.0, 10.0, 10.0, 20.0]],
[[14.0, 11.0, 21.0, 31.0],
[15.0, 11.0, 11.0, 21.0],
[16.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
我想选择第一批的第一行和第二行以及第二批的第二行和第四行。
indices = tf.constant([[0, 1], [1, 3]])
所以期望的输出是
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0]],
[[15.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
如何使用 TensorFlow 做到这一点?我尝试使用tf.gather(logits, indices)
但它没有返回我所期望的。谢谢!