PyTorch 索引:选择索引的补充

2023-12-01

假设我有一个张量和索引:

x = torch.tensor([1,2,3,4,5])
idx = torch.tensor([0,2,4])

如果我想选择所有元素not在索引中,我可以手动定义布尔掩码像这样:

mask = torch.ones_like(x)
mask[idx] = 0

x[mask]

有没有更优雅的方法来做到这一点?

即我可以直接传递索引而不是创建掩码的语法,例如就像是:

x[~idx]

我找不到令人满意的解决方案来找到 a 的补集多维索引张量并最终实现了我自己的。它可以在cuda上运行并享受快速的并行计算。

def complement_idx(idx, dim):
    """
    Compute the complement: set(range(dim)) - set(idx).
    idx is a multi-dimensional tensor, find the complement for its trailing dimension,
    all other dimension is considered batched.
    Args:
        idx: input index, shape: [N, *, K]
        dim: the max index for complement
    """
    a = torch.arange(dim, device=idx.device)
    ndim = idx.ndim
    dims = idx.shape
    n_idx = dims[-1]
    dims = dims[:-1] + (-1, )
    for i in range(1, ndim):
        a = a.unsqueeze(0)
    a = a.expand(*dims)
    masked = torch.scatter(a, -1, idx, 0)
    compl, _ = torch.sort(masked, dim=-1, descending=False)
    compl = compl.permute(-1, *tuple(range(ndim - 1)))
    compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,)))
    return compl

Example:

>>> import torch
>>> a = torch.rand(3, 4, 5)
>>> a
tensor([[[0.7849, 0.7404, 0.4112, 0.9873, 0.2937],
         [0.2113, 0.9923, 0.6895, 0.1360, 0.2952],
         [0.9644, 0.9577, 0.2021, 0.6050, 0.7143],
         [0.0239, 0.7297, 0.3731, 0.8403, 0.5984]],

        [[0.9089, 0.0945, 0.9573, 0.9475, 0.6485],
         [0.7132, 0.4858, 0.0155, 0.3899, 0.8407],
         [0.2327, 0.8023, 0.6278, 0.0653, 0.2215],
         [0.9597, 0.5524, 0.2327, 0.1864, 0.1028]],

        [[0.2334, 0.9821, 0.4420, 0.1389, 0.2663],
         [0.6905, 0.2956, 0.8669, 0.6926, 0.9757],
         [0.8897, 0.4707, 0.5909, 0.6522, 0.9137],
         [0.6240, 0.1081, 0.6404, 0.1050, 0.6413]]])
>>> b, c = torch.topk(a, 2, dim=-1)
>>> b
tensor([[[0.9873, 0.7849],
         [0.9923, 0.6895],
         [0.9644, 0.9577],
         [0.8403, 0.7297]],

        [[0.9573, 0.9475],
         [0.8407, 0.7132],
         [0.8023, 0.6278],
         [0.9597, 0.5524]],

        [[0.9821, 0.4420],
         [0.9757, 0.8669],
         [0.9137, 0.8897],
         [0.6413, 0.6404]]])
>>> c
tensor([[[3, 0],
         [1, 2],
         [0, 1],
         [3, 1]],

        [[2, 3],
         [4, 0],
         [1, 2],
         [0, 1]],

        [[1, 2],
         [4, 2],
         [4, 0],
         [4, 2]]])
>>> compl = complement_idx(c, 5)
>>> compl
tensor([[[1, 2, 4],
         [0, 3, 4],
         [2, 3, 4],
         [0, 2, 4]],

        [[0, 1, 4],
         [1, 2, 3],
         [0, 3, 4],
         [2, 3, 4]],

        [[0, 3, 4],
         [0, 1, 3],
         [1, 2, 3],
         [0, 1, 3]]])
>>> al = torch.cat([c, compl], dim=-1)
>>> al
tensor([[[3, 0, 1, 2, 4],
         [1, 2, 0, 3, 4],
         [0, 1, 2, 3, 4],
         [3, 1, 0, 2, 4]],

        [[2, 3, 0, 1, 4],
         [4, 0, 1, 2, 3],
         [1, 2, 0, 3, 4],
         [0, 1, 2, 3, 4]],

        [[1, 2, 0, 3, 4],
         [4, 2, 0, 1, 3],
         [4, 0, 1, 2, 3],
         [4, 2, 0, 1, 3]]])
>>> al, _ = al.sort(dim=-1)
>>> al
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch 索引:选择索引的补充 的相关文章

随机推荐

  • 查找 numpy 数组中最接近某个值的所有索引

    在 numpy 数组中 需要最接近给定常量的所有值的索引 背景是数字信号处理 该数组保存滤波器的幅度函数 np abs np fft rfft h 和某些频率 索引 被搜索 其中幅度为例如0 5 或在其他情况下为 0 大多数情况下 所讨论的
  • onActivityResult 执行两次

    From 主页活动我想得到一个结果创建配置文件活动 这是我开始活动的方法 Intent createProfile new Intent this CreatePreacherActivity class startActivityForR
  • 剥离 ASCII 模板意味着什么?

    我正在做练习考试题 问题执行该程序时 用户输入两个数字 xGuess 的值是多少 以便我们可以去掉 ASCII 模板 解释 ORIG x3000 TRAP x23 LD R2 ASCII ADD R1 R2 R0 TRAP x23 ADD
  • vbscript中下标超出范围错误

    有人可以看看下面的脚本并告诉我为什么它在 vbscript 中抛出此错误下标超出范围错误 在文本文件中 它有两个条目正确写入文件 但随后在退出循环时抛出错误 因此它从不调用其他函数 我认为它尝试运行 3 次 但文本文件中只有 2 个条目 T
  • Datagridview,仅显示唯一值是重复的单元格值 C# 2005

    我在显示值时遇到一些问题 但每次它重复 datagridview 中的值时 我都使用 Microsoft Visual C 2005 和框架 2 0 当我编程时 我发现在循环内我需要检查重复的值并对它们进行计数 如果出现新值则显示该值并发送
  • 配置管理器类。锁定时无法编辑 ConfigurationSection 属性

    这是代码 These is works Console WriteLine Properties Settings Default name Configuration configFile ConfigurationManager Ope
  • 从字符数组中删除所有空元素

    我有一个字符数组 一次最多可以容纳 50000 个字符 该数组的内容是通过套接字连接来的 但是 不能保证该字符缓冲区不会有任何空元素 然后我需要将此字符数组转换为字符串 例如 new String buffer 我的问题是 每当我从套接字收
  • 如何按字段之一对结构链接列表进行排序?

    哇现在我知道我不知道 哈哈 我的结构是这样的 struct Medico int Id Doctor int Estado char Nombre 60 focus on this part of the structure this is
  • LINQ 表达式树最多可以做什么?

    LINQ 表达式树的最大功能是多少 它可以定义一个类吗 一个具有所有声明名称 修饰符 参数类型和返回类型的方法怎么样 程序必须始终定义树本身吗 是否可以从给定的 C 文件生成树 在C 3中 表达式树可以表示表达式 由此得名 而且它们还被进一
  • 在 3.0 之前的设备上参考 android.R.id.home

    有没有一种简单的方法可以在运行低于 3 0 的设备上获取主页按钮的引用 我可以在 3 0 及更高版本上执行 findViewById android R id home 操作 但在旧设备上无法正常工作 我不需要监听点击 我只需要它的位置来定
  • 使用mongoose节点插入大数据

    我正在尝试使用 mongoose 将大型数据集插入到 mongodb 中 但在此之前我需要确保我的 for 循环正常工作 basic schema settings var mongoose require mongoose var Sch
  • 如何在 C# MVC 中使用 Azure AD 实现注销后

    我正在开发一个使用 Azure AD 进行用户身份验证的 Web 应用程序 但是 在用户成功注销后 我很难将用户重定向到主页 我试过遵循这个文档 但这不是我正在寻找的解决方案 In SignOut 在 HomeController cs 中
  • 字符串生成器-删除重复值

    我编写了这段小代码 从 lstmodel2 组件中获取值 StringBuilder sb new StringBuilder for int i 0 i lt lstmodel2 getSize i String exsplt lstmo
  • 如何访问列表框中的复选框?

    我有一个列表框 并且设置了项目模板 如下所示 XAML
  • OpenSSL 代码可以在 XP 上运行,但在 Vista 及更高版本中永远挂起

    这段代码开始一个最小的 SSL 服务器 WSAStartup MakeWord 1 1 WData SSL library init SSL load error strings ctx SSL CTX new SSLv23 server
  • MySQL加载数据:准备好的语句协议尚不支持该命令

    我正在尝试编写一个 MySQL 脚本来将数据导入到我的 Linux 服务器的表中 这是名为update sql SET query CONCAT LOAD DATA LOCAL INFILE spaceName INTO TABLE tmp
  • 从 SD 卡删除音频

    我尝试删除 SD 卡中的音频文件 但没有成功 public boolean delete String path return new File path delete 在浏览帖子时我遇到了Storage Access Framework但
  • 嵌套 ng-repeat AngularJS

    我有一个像这样的值 我的问题是我不知道该虚拟数组对象嵌套了多少 所以如何使用 ng repeat 打印所有虚拟数组对象 demo id 1 dummy id 1 dummy id 1 dummy id 1
  • 为什么这不是 C++ 中的内存泄漏?

    几个月前我问过this我问为什么会出现内存泄漏的问题 显然 我忘记了虚拟析构函数 现在我很难理解为什么这不是内存泄漏 include
  • PyTorch 索引:选择索引的补充

    假设我有一个张量和索引 x torch tensor 1 2 3 4 5 idx torch tensor 0 2 4 如果我想选择所有元素not在索引中 我可以手动定义布尔掩码像这样 mask torch ones like x mask