我没有找到令人满意的现成方案来计算多维索引张量的补集,所以最终自己实现了一个。它可以在 CUDA 上运行,并能利用快速的并行计算。
def complement_idx(idx, dim):
    """
    Compute, per batch entry, the indices of ``range(dim)`` missing from ``idx``.

    The complement ``set(range(dim)) - set(idx)`` is taken along the trailing
    dimension of ``idx``; every leading dimension is treated as a batch
    dimension.

    Args:
        idx: index tensor of shape ``[N, *, K]``. Entries along the last
            dimension must be distinct: each listed position is overwritten
            with 0 and exactly ``K`` leading entries are dropped after
            sorting, so a repeated index would cause a genuine complement
            member to be discarded instead.
        dim: size of the index universe, i.e. the complement is taken with
            respect to ``{0, ..., dim - 1}``.

    Returns:
        Tensor of shape ``[N, *, dim - K]``, created on ``idx.device``,
        holding the missing indices in ascending order along the last
        dimension.
    """
    n_idx = idx.shape[-1]
    # Broadcast 0..dim-1 over the batch dimensions; expand() prepends the
    # missing leading dimensions on its own, so no unsqueeze loop is needed.
    full = torch.arange(dim, device=idx.device).expand(*idx.shape[:-1], dim)
    # Overwrite every selected position with 0 (out-of-place, so the expanded
    # view is never written to).
    masked = torch.scatter(full, -1, idx, 0)
    # After an ascending sort the K zeroed entries come first; if 0 itself was
    # not selected, its own zero sorts right after them and is kept.
    compl, _ = torch.sort(masked, dim=-1, descending=False)
    # Drop the K placeholders along the trailing dimension.
    return compl[..., n_idx:]
Example:
>>> import torch
>>> a = torch.rand(3, 4, 5)
>>> a
tensor([[[0.7849, 0.7404, 0.4112, 0.9873, 0.2937],
[0.2113, 0.9923, 0.6895, 0.1360, 0.2952],
[0.9644, 0.9577, 0.2021, 0.6050, 0.7143],
[0.0239, 0.7297, 0.3731, 0.8403, 0.5984]],
[[0.9089, 0.0945, 0.9573, 0.9475, 0.6485],
[0.7132, 0.4858, 0.0155, 0.3899, 0.8407],
[0.2327, 0.8023, 0.6278, 0.0653, 0.2215],
[0.9597, 0.5524, 0.2327, 0.1864, 0.1028]],
[[0.2334, 0.9821, 0.4420, 0.1389, 0.2663],
[0.6905, 0.2956, 0.8669, 0.6926, 0.9757],
[0.8897, 0.4707, 0.5909, 0.6522, 0.9137],
[0.6240, 0.1081, 0.6404, 0.1050, 0.6413]]])
>>> b, c = torch.topk(a, 2, dim=-1)
>>> b
tensor([[[0.9873, 0.7849],
[0.9923, 0.6895],
[0.9644, 0.9577],
[0.8403, 0.7297]],
[[0.9573, 0.9475],
[0.8407, 0.7132],
[0.8023, 0.6278],
[0.9597, 0.5524]],
[[0.9821, 0.4420],
[0.9757, 0.8669],
[0.9137, 0.8897],
[0.6413, 0.6404]]])
>>> c
tensor([[[3, 0],
[1, 2],
[0, 1],
[3, 1]],
[[2, 3],
[4, 0],
[1, 2],
[0, 1]],
[[1, 2],
[4, 2],
[4, 0],
[4, 2]]])
>>> compl = complement_idx(c, 5)
>>> compl
tensor([[[1, 2, 4],
[0, 3, 4],
[2, 3, 4],
[0, 2, 4]],
[[0, 1, 4],
[1, 2, 3],
[0, 3, 4],
[2, 3, 4]],
[[0, 3, 4],
[0, 1, 3],
[1, 2, 3],
[0, 1, 3]]])
>>> al = torch.cat([c, compl], dim=-1)
>>> al
tensor([[[3, 0, 1, 2, 4],
[1, 2, 0, 3, 4],
[0, 1, 2, 3, 4],
[3, 1, 0, 2, 4]],
[[2, 3, 0, 1, 4],
[4, 0, 1, 2, 3],
[1, 2, 0, 3, 4],
[0, 1, 2, 3, 4]],
[[1, 2, 0, 3, 4],
[4, 2, 0, 1, 3],
[4, 0, 1, 2, 3],
[4, 2, 0, 1, 3]]])
>>> al, _ = al.sort(dim=-1)
>>> al
tensor([[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]])