如何从 TensorFlow 中的 3-D 张量中选择行?

2024-01-11

我有一个张量logits与尺寸[batch_size, num_rows, num_coordinates](即批次中的每个 logit 都是一个矩阵)。在我的例子中,批量大小为 2,有 4 行和 4 个坐标。

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0],
                      [12.0, 10.0, 10.0, 20.0],
                      [13.0, 10.0, 10.0, 20.0]],
                     [[14.0, 11.0, 21.0, 31.0],
                      [15.0, 11.0, 11.0, 21.0],
                      [16.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

我想选择第一批的第一行和第二行以及第二批的第二行和第四行。

indices = tf.constant([[0, 1], [1, 3]])

所以期望的输出是

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0]],
                     [[15.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

如何使用 TensorFlow 做到这一点?我尝试使用tf.gather(logits, indices)但它没有返回我所期望的。谢谢!


这在 TensorFlow 中是可能的,但稍微不方便,因为tf.gather() https://www.tensorflow.org/versions/r0.7/api_docs/python/tf/gather目前仅适用于一维索引,并且仅从张量的第 0 维选择切片。但是,仍然可以通过转换参数以便将它们传递给tf.gather():

logits = ... # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])

# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]]))

selected_rows = tf.gather(flattened_logits, flattened_indices)

result = tf.reshape(selected_rows,
                    tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
                                  tf.shape(logits)[2:]]))

请注意,由于这使用tf.reshape() https://www.tensorflow.org/versions/r0.7/api_docs/python/tf/reshape并不是tf.transpose() https://www.tensorflow.org/versions/r0.7/api_docs/python/tf/transpose,它不需要修改中的(可能很大的)数据logits张量,所以它应该相当有效。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从 TensorFlow 中的 3-D 张量中选择行? 的相关文章

随机推荐