Transformer中的position encoding(位置编码一)

2023-10-30

本文主要讲解Transformer 中的 position encoding,在当今CV的目标检测最前沿,都离不开position encoding,在DETR,VIT,MAE框架中应用广泛,下面谈谈我的理解。

一般position encoding 分为 正余弦编码和可学习编码。

正余弦编码

 以下为DETR中的position encoding过程,本文将以简单的数据帮助大家理解。以下过程是按照DETR走的,为了更好理解,对数据进行简化,针对不同的图像,产生不同的数据大小。

1.创建mask 

假设mask为4×4大小,输入图像大小为3×3。

下图为mask生成的4*4维度的矩阵,根据对应与输入图像大小3*3生成以下的mask编码tensor,下右图为反mask编码tensor,这一步就得到了图像的大小及对应与mask下的位置。

 

2.生成Y_embed和X_embed的tensor

y_embed = not_mask.cumsum(1, dtype=torch.float32)#在行方向累加#(b , h , w)
x_embed = not_mask.cumsum(2, dtype=torch.float32)#在列方向累加#(b , h , w)

    DETR中运用两行编码实现Y_embed和X_embed,生成大小为(bitch_size , h , w)的tensor。

    根据在1中我们产生的反mask编码,生成的Y_embed和X_embed如下。

     Y_embed对为mask编码True的进行行方向累加1,X_embed对为mask编码True的进行列方向累加1。下图所示:

3. 运用10维(自己可以延申为1024维)position进行编码

num_pos_feats = 10
temperature = 10000
dim_t = torch.arange(num_pos_feats, dtype=torch.float32,device=a.device)#生成10维数
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) #i=dim_t // 2#对10维数进行计算

 第三行代码生成了10个tensor数据,第四行代码相当于dim_t=10000^{2*(dimt1//2)/10},对10个生成的tensor进行计算得到位置编码公式中的分母10000^{2i/d},结果如下。

 4.生成pos_x以及pos_y

pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)#不降维
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)#不降维

 

 

第四步以后的直观效果如上图所示,可以对照第二步的X_embed和Y_embed,会发现pos_x,y的tensor分母和X,Y_embed对应 ,很好理解,其中i对应的是10维position的不同维度的数,d代表的是position编码维度。

5.组合Pos_x和Pos_y

 因为上述位置编码的生成是行列方向分开的,这一步需要进行组合。

pos = torch.cat((pos_y, pos_x), dim=2)

  

 组合以后直观图的样子如上,这时会发现16个位置的分母已经根据pos的不同,达到了位置编码的不同,因为本文采用的是10维的position,分子i的范围为0-10,每个位置就形成了1*20的tensor数据。

 上述两个位置的编码就可以理解为1*20的tensor数据,因为比较长,分开写了,不是4*5的,而是1*20的tensor数据,通过上图可以很直观的理解position encoding。

程序结果如下,类似于此。下面将自己改写的简单的position encoding 程序段放在下面,大家可以复制自己跑一下,看看输出结果,加强理解。

import torch
import numpy as np
import math

#正余弦位置编码
num_pos_feats = 10
temperature = 10000
normalize = False
scale = 2 * math.pi

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
a = torch.tensor(a)
mask = [[False,False,False,True],[False,False,False,True],[False,False,False,True],[True,True,True,True]]
mask = torch.tensor(mask)
print(mask)
assert mask is not None
not_mask = ~mask
print(not_mask)
y_embed = not_mask.cumsum(0, dtype=torch.float32)
x_embed = not_mask.cumsum(1, dtype=torch.float32)
print(y_embed)
print(x_embed)

if normalize:
    eps = 1e-6
    # b = a[i:j:s]表示:i,j与上面的一样,但s表示步进,缺省为1.
    # 所以a[i:j:1]相当于a[i:j]
    # 当s<0时,i缺省时,默认为-1. j缺省时,默认为-len(a)-1
    # 所以a[::-1]相当于 a[-1:-len(a)-1:-1],也就是从最后一个元素到第一个元素复制一遍,即倒序。
    # 对于X[:,:,m:n]是取三维矩阵中第m维到第n-1维的所有数据
    # 归一化
    y_embed = y_embed / (y_embed[-1:, :] + eps) * scale  # y_embed[:, -1:, :]代表取三维数据中的最后一行数据
    x_embed = x_embed / (x_embed[:, -1:] + eps) * scale  # x_embed[:, :, -1:]代表取三维数据中的最后一列数据
    print(y_embed)
    print(x_embed)
dim_t1 = torch.arange(num_pos_feats, dtype=torch.float32,device=a.device)
print(dim_t1)
dim_t = temperature ** (2 * (dim_t1 // 2) / num_pos_feats) #i=dim_t1 // 2
print(dim_t)
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
print(pos_x)
print(pos_y)
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)#不降维
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)#不降维
print(pos_x)
print(pos_y)
pos = torch.cat((pos_y, pos_x), dim=2)
print(pos)

 以上是我的理解,欢迎大家批评指正,互相交流!Transformer中的position encoding(位置编码二)_zuoyou-HPU的博客-CSDN博客本文依旧采用4*4大小的词嵌入模型,和模仿3*3大小的特征图进行解读——可学习编码1.根据自己模型中的定义的最大特征图大小进而定义词嵌入模型大小。假设模型中的特征图大小不超过4*4,那么我定义的词嵌入模型大小就为4*4,同正余弦编码一样,采用10维数据进行编码。生成行方向的词嵌入模型(4 ,10),及生成列方向的词嵌入模型(4 , 10),进而生成4*10的随机权重值并均匀分布在0-1之间。row_embed = nn.Embedding(4, 10)#生成行方向词嵌入模型col_embe.https://blog.csdn.net/weixin_42715977/article/details/122139883?spm=1001.2014.3001.5501

Swin Transformer 中的 shift window attention_zuoyou-HPU的博客-CSDN博客https://blog.csdn.net/weixin_42715977/article/details/124151870?spm=1001.2014.3001.5502

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Transformer中的position encoding(位置编码一) 的相关文章

随机推荐

  • C51实现流水灯

    文章目录 一 实验要求 二 实验代码和原理图 1 代码 2 原理图 总结 一 实验要求 1 先八盏灯从左至右依次点亮 同一时刻仅有一盏灯处于被点亮状态 每盏灯亮0 5s 然后八盏灯从右至左依次点亮 同一时刻仅有一盏灯处于被点亮状态 每盏灯亮
  • hdu 1074 Doing Homework

    Problem acm hdu edu cn showproblem php pid 1074 题意 n 份作业 分别给出名字 完成所需时间 cost 最迟上交时间 deadline 作业每迟交一天扣一分 问最少的扣分数 Analysis
  • 关于.sln和.suo文件

    sln 和 suo都是是解决方案文件 sln Visual Studio Solution 它通过为环境提供对项目 项目项和解决方案项在磁盘上位置的引用 可将它们组织到解决方案中 包含了较为通用的信息 包括解决方案所包含项目的列表 解决方案
  • TCL foreach的用法

    总结放于前 foreach var list body是foreach的的常见用法 foreach为关键字 var为形参 list为数据容器 数组等 body为函数块 程序每次在程序执行时从list中取到值并赋给形参var 函数块利用var
  • sql外连接内连接

    内连接 两表的交集 符合要求的数据列出来 外连接 左外连接就是查询 join左边表中的所有数据 并且把join右边表中对应的数据查询出来 主表的数据去跟从表一一比较 有就全部列出来 没有就也要列出一条 主表数据全要 他的从表数据变成Null
  • springboot集成Redis

    springboot集成Redis 1 windows平台安装Redis 2 引入依赖 3 修改配置文件 4 启动类添加注解 5 指定缓存哪个方法 6 配置Redis的超时时间 小BUG 测试 对于项目中一些访问量较大的接口 配置上Redi
  • python连接mysql数据库报错pymysql.err.OperationalError

    一 报错信息 pymysql err OperationalError 1045 Access denied for user root localhost using password YES Traceback most recent
  • docker基础:docker stats监控容器资源消耗

    docker stats docker stats 命令用来显示容器使用的系统资源 默认情况下 stats 命令会每隔 1 秒钟刷新一次输出的内容直到你按下 ctrl c 下面是输出的主要内容 CONTAINER 以短格式显示容器的 ID
  • 基于反事实因果推断的度小满额度模型

    本文约4400字 建议阅读9分钟 本文从三个角度与你分享基于反事实因果推断的度小满额度模型 1 因果推断的研究范式 1 相关性与因果性 2 三大基本假设 2 因果推断的框架演进 1 从随机数据到观测数据 2 反事实表示学习 3 反事实额度模
  • 四川百幕晟科技有限公司:抖音没有视频怎么开店铺?

    抖音是中国最受欢迎的短视频平台之一 吸引了数亿用户 很多电商卖家希望利用抖音平台开展业务 但他们可能没有视频资源 幸运的是 抖音还提供了非视频商店功能 允许卖家开设自己的商店并在抖音上推广产品 本文将详细介绍在抖音上开店的步骤 并探讨如何在
  • 破解windows明文密码

    之前看了法国人写的一个软件 mimikatz 可以直接获取windows下的明文密码 简直是丧心病狂 作者已经开源 大家可以去谷歌一下 用SVN下载了源码 是vs2010的工程 然后按照下面命令开始看代码 privilege debug i
  • SQLAlchemy映射已有数据表

    方法一 手动创建数据表模型类进行映射 映射的表必须要有主键 配置数据库连接参数 class Config SQLALCHEMY DATABASE URI mysql pymysql root 123456 localhost 3306 te
  • mysql5.7 免安装版的配置过程

    1 去官网下载mysql 5 7 2 解压压缩包 首先给压缩包重命名一下 修改为你自己想要的 将解压目录下默认文件 my default ini 拷贝一份 改名 my ini 3 修改一下my ini 文件里的内容 client port
  • 基于卷积神经网络结合注意力机制长短记忆网络CNN-LSTM-Attention实现风电功率多输入单输出回归预测附matlab代码

    作者简介 热爱科研的Matlab仿真开发者 修心和技术同步精进 matlab项目合作可私信 个人主页 Matlab科研工作室 个人信条 格物致知 更多Matlab仿真内容点击 智能优化算法 神经网络预测 雷达通信 无线传感器 电力系统 信号
  • Kafka安装及测试

    系统环境 Linux Ubuntu 16 04 jdk 7u75 linux x64 相关知识 Kafka是由LinkedIn开发的一个分布式的消息系统 使用Scala编写 它因可以水平扩展和高吞吐率而被广泛使用 目前越来越多的开源分布式处
  • WPF编程,通过Path类型制作沿路径运动的动画另一种方法。

    上一篇文章给了一个这方面的例子 那个文章里是通过后台按钮事件进行动画的开始 停止 继续等 这里给出的是通过前台XAML来实现 1 前台 定义路径 定义运动的主体 这里是一圆
  • IEEE 754 round-to-nearest-even

    IEEE 754 二进制的向偶舍入 舍入的值保证最靠近原浮点数值 如果舍入为中间值 即舍还是入距离相等 那么按其最末尾一位是奇数 则入 如果为偶数 则舍 下面例子说明 xxx yyyyy10000 x为实数任意值 y为任意值 最末尾y为需要
  • 用C++实现简单的小游戏

    采用面向对象的编程思想 在头文件中引入acllic图形库 实现c 控制图片以及生成可视化窗口 所需工具 acllib图形库下载地址 acl图形库下载地址 win32位项目的创建 通过visual studio创建win32项目 三张图片 t
  • python 数据分析--数据处理工具Pandas(2)

    数据处理模块 Pandas 4 Pandas处理字符串和日期数据 5 Pandas 数据清洗 5 1 重复观测处理 5 2 缺失值处理 5 2 1 删除法 5 2 2 替换法 5 3 异常值处理 6 获取数据子集 7 透视表 合并与连接 分
  • Transformer中的position encoding(位置编码一)

    本文主要讲解Transformer 中的 position encoding 在当今CV的目标检测最前沿 都离不开position encoding 在DETR VIT MAE框架中应用广泛 下面谈谈我的理解 一般position enco