PyTorch学习系统之 scatter() 函数详解 one hot 编码

2023-05-16

torch.Tensor.scatter_

scatter()scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会

torch.Tensor.scatter_()torch.gather()函数的方向反向操作。两个函数可以看成一对兄弟函数。gather用来解码one hot,scatter_用来编码one hot。

PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改

scatter_(dimindexsrc) → Tensor

参数:

  • dim:沿着哪个维度进行索引
  • index:用来 scatter 的元素索引
  • src:用来 scatter 的源元素,可以是一个标量或一个张量

这个 scatter  可以理解成放置元素或者修改元素

简单说就是通过一个张量 src  来修改另一个张量,哪个元素需要修改、用 src 中的哪个元素来修改由 dim 和 index 决定

官方文档给出了 3维张量 的具体操作说明,如下所示

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
x = torch.rand(2, 5)

#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

具体地说,我们的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),一个二维张量,下面用图简单说明

我们是 2维 张量,一开始进行 self[index[0][0]][0]self[index[0][0]][0],其中 index[0][0]index[0][0] 的值是0,所以执行 self[0][0]=x[0][0]=0.1940self[0][0]=x[0][0]=0.1940 

self[index[i][j]][j]=src[i][j]

 

再比如self[index[1][0]][0]self[index[1][0]][0],其中 index[1][0]index[1][0] 的值是2,所以执行 self[2][0]=x[1][0]=0.2078self[2][0]=x[1][0]=0.2078 

计算过程:index[0,0]=0→self[0,0]→src[0,0] =0.1940

index[0,1]=1→self[1,1]→src[0,1] =0.3340

index[0,2]=2→self[2,2]→src[0,2] =0.8184

 

example:

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)

#tensor([[7., 7., 7., 7., 7.],
#        [0., 7., 0., 7., 0.],
#        [7., 0., 7., 0., 7.]]

计算过程:index[0,0]=0→self[0,0]→src[0,0] =7

index[0,1]=1→self[1,1]→src[0,1] =7

index[0,2]=2→self[2,2]→src[0,2] =7

scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子

用于产生one hot编码的向量

example:

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
#        [0],
#        [3],
#        [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
#        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
indices = torch.tensor(list(range(5))).view(5,1)  
indices
result = torch.zeros(5, 5)
result.scatter_(1, indices, 1)        

 

当没有src值时,则所有用于填充的值均为value值。

需要注意的时候,这个时候index.shape[dim]必须与result.shape[dim]相等,否则会报错。

result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0], 
                        [2, 0, 3, 1, 2],
                        [2, 1, 3, 1, 4]])
result.scatter_(1, indices, value=1)        

输出为

tensor([[1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1.]])

参考资料

https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_

https://www.cnblogs.com/dogecheng/p/11938009.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch学习系统之 scatter() 函数详解 one hot 编码 的相关文章

随机推荐

  • MATLAB入门学习系列之基本绘图函数

    目录 创建绘图 在一幅图形中绘制多个数据集 指定线型和颜色 绘制线条和标记 将绘图添加到现有图形中 图窗窗口 在一幅图窗中显示多个绘图 控制轴 保存图窗 保存工作区数据 创建绘图 plot 函数具有不同的形式 xff0c 具体取决于输入参数
  • MATLAB入门学习系列之图像

    显示图像 图像数据 您可以将二维数值数组显示为图像 在图像中 xff0c 数组元素决定了图像的亮度或颜色 例如 xff0c 加载一个图像数组及其颜色图 xff1a gt gt load durer gt gt whos Name Size
  • 知识跟踪的深度知识跟踪和动态学生分类 Deep Knowledge Tracing and Dynamic Student Classification for Knowledge Tracing

    Deep Knowledge Tracing and Dynamic Student Classification for Knowledge Tracing xff08 译 xff09 知识跟踪的深度知识跟踪和动态学生分类 摘要 在智能辅
  • 知识追踪常见建模方法之IRT项目反应理论

    目录 A 项目反应理论 xff08 IRT item response theory xff09 概述 历史发展 特点 模型 A 项目反应理论 xff08 IRT item response theory xff09 概述 IRT理论即项目
  • MATLAB图像处理基本操作(1)

    matlib软件巨大 xff0c 没有安装 找了几个在线网址 http www compileonline com execute matlab online php https octave online net 从文件读取图像 a xf
  • Python学习系列之类的定义、构造函数 def __init__

    python def init self name等多参数 def init self 常见的两种类的定义方式如下 第一种 class Student def init self 两者之间的区别 self name 61 None self
  • ChatGPT,爆了!

    这段时间真是太刺激了 xff0c AI领域几乎每天都会爆出一个超震撼的产品 xff0c 有一种科幻马上要成现实的感觉 不知道大家朋友圈是什么样 xff0c 在整个创业的圈子里面 xff0c 几乎全是 AI 相关 就连 N 多年 xff0c
  • 分类回归模型评估常见方法及ROC AUC

    目录 模型评估常见方法 ROC和AUC定义 sklearn计算ROC具体实现 计算ROC需要知道的关键概念 1 分析数据 2 针对score xff0c 将数据排序 3 将截断点依次取为score值 3 1 截断点为0 1 sklearn
  • Coursera 吴恩达《Machine Learning》课堂笔记 + 作业

    记录一下最近学习的资源 xff0c 方便寻找 xff1a Github 上已经有人把作业整理成为 Python 的形式了 有 py 和 ipynb 两种格式 https github com nsoojin coursera ml py h
  • tensflow学习小知识tf.train.exponential_decay

    tf train exponential decay是tensflow1 X版本的2 版本使用以下语句 tf compat v1 train exponential decay 将指数衰减应用于学习率 tf compat v1 train
  • PyTorch学习系列之PyTorch:nn和PyTorch:optim优化

    PyTorch xff1a nn 在构建神经网络时 xff0c 我们经常考虑将计算分为几层 xff0c 其中一些层具有可学习的参数 xff0c 这些参数将在学习过程中进行优化 在TensorFlow xff0c 像包 Keras xff0c
  • tf.gather()用法详解

    tf gather params indices validate indices 61 None axis 61 None batch dims 61 0 name 61 None 请注意 xff0c 在CPU上 xff0c 如果找到超出
  • 代码学习之Python冒号详解

    最近看代码发现对冒号用法理解不够透彻 xff0c 记录学习一下 xff1a 1 冒号的用法 1 1 一个冒号 a i j 这里的i指起始位置 xff0c 默认为0 xff1b j是终止位置 xff0c 默认为len a xff0c 在取出数
  • Jupyter Notebook导入和删除虚拟环境 超详细

    记录一下Jupyter Notebook导入和删除虚拟环境的步骤 xff0c 网上博客参差不齐 xff0c 每次找好几个才看到简明容易理解的 方法一步骤 为不同的环境配置kernel 有时候使用conda命令创建了新的python环境 xf
  • tf.expand_dims用法详解

    看官方讲解一些博客感觉一直不是很懂 xff0c 下面是我的个人理解结合官方文档 xff0c 有问题欢迎指出 tf expand dims tf expand dims input axis 61 None name 61 None dim
  • argparse 命令行选项、参数和子命令解析器

    最近看到很多论文代码都是用解析器写的 argparse 命令行选项 参数和子命令解析器 argparse 模块可以让人轻松编写用户友好的命令行接口 程序定义它需要的参数 xff0c 然后 argparse 将弄清如何从 sys argv 解
  • torch.unsqueeze和 torch.squeeze() 详解

    1 torch unsqueeze 详解 torch unsqueeze input dim out 61 None 作用 xff1a 扩展维度 返回一个新的张量 xff0c 对输入的既定位置插入维度 1 注意 xff1a 返回张量与输入张
  • Android中获取唯一的id

    文章目录 Android唯一设备ID现状IMEIMAC地址唯一Id实现方案那些硬件适合硬件标识工具类 Android唯一设备ID现状 设备ID xff0c 简单来说就是一串符号 xff08 或者数字 xff09 xff0c 映射现实中硬件设
  • debian虚拟机下如何安装增强功能

    1 安装gcc和kernel headers gcc有可能默认安装的有 xff08 如果没有还需要安装gcc xff09 xff0c 但是还需要安装build essential sudo apt get install build ess
  • PyTorch学习系统之 scatter() 函数详解 one hot 编码

    torch Tensor scatter scatter 和 scatter 的作用是一样的 xff0c 只不过 scatter 不会直接修改原来的 Tensor xff0c 而 scatter 会 torch Tensor scatter