Transformer论文及源码笔记——Attention Is All You Need

2023-11-05

Transformer论文及源码笔记——Attention Is All You Need

综述

论文题目:《Attention Is All You Need》

时间会议:Advances in Neural Information Processing Systems, 2017 (NIPS, 2017)

论文链接:http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf

介绍

  transformer结构的优点:

  • 长程依赖性处理能力强:自注意力机制可以实现对整张图片进行全局信息的建模
  • 并行化能力强:可以并行计算输入序列中的所有位置;

  网络结构图如下所示:

在这里插入图片描述

其中,多头注意力网络结构为:

在这里插入图片描述

注意力公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V\\ MultiHead(Q,K,V)=Concat(head_1,\dots,head_h)W^O\\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) Attention(Q,K,V)=softmax(dk QKT)VMultiHead(Q,K,V)=Concat(head1,,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)
其中:

  • Q代表query,后续会去和每一个k进行匹配;

  • K代表key,后续会被每个q匹配;

  • V代表输入数据,又称为编码数据Embedding;

  • 后续Q和K匹配的过程可以理解成计算两者的相关性,相关性越大对应v的权重也就越大,主要就是给V生成一组权重;

多头注意力K、Q、V解释:

  目前有多组键值匹配对k、v,每个k对应一个v,计算q所对应的值。思路:计算q与每个k的相似度,得到v的权重,之后对v做加权求和,得到q对应的数值。因此在解码过程中,第二个多头注意力的输入中,k、v传入编码特征(是已知的特征匹配对),q传入解码特征(可迭代传入),求解码对应的特征(根据编码特征之间的相似度求解码的注意力加权特征)。

  注:kqv的关系用一句话来说就是根据kv的键值匹配关系,预测q对应的数值,根据kq的相似度对v做加权求和

视频参考:https://www.bilibili.com/video/BV1dt4y1J7ov/?spm_id_from=333.788.recommend_more_video.2&vd_source=b1b1710a3f74753e8bfc47c5c2e4d49e

多头注意力计算过程:

  先让kqv做线性映射,之后沿特征向量的方向拆分成不同的“头”,之后利用拆分的向量做运算→q和k做矩阵乘法,得到注意力权重→注意力权重除以缩放因子 d k \sqrt{d_k} dk d k d_k dk表示每个头的维度,再做Softmax运算→经过一次Dropout运算(可选)→所得的权重与v做矩阵乘法→合并所有“头”,最后经过一次线性映射;

代码实现

编码模块

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)

        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self,
                     src,
                     pos: Optional[Tensor] = None):
        # 特征先与位置编码相加
        v = q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, v)[0]
        src = src + src2
        src = self.norm1(src)
        src2 = self.linear2(self.activation(self.linear1(src)))
        src = src + src2
        src = self.norm2(src)
        return src

解码模块

class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, encoding, decoding, pos: Optional[Tensor] = None):
        q = k = v = self.with_pos_embed(decoding, pos)
        # 解码特征先做一次自注意力运算
        decoding2 = self.self_attn(q, k, v)[0]
        decoding = decoding + decoding2
        decoding = self.norm1(decoding)

        decoding2 = self.multihead_attn(query=decoding, key=encoding, value=encoding)[0]
        decoding = decoding + decoding2
        decoding = self.norm2(decoding)
        decoding2 = self.linear2(self.activation(self.linear1(decoding)))
        decoding = decoding + decoding2
        decoding = self.norm3(decoding)
        return decoding

注:以上仅是笔者个人见解,若有问题,欢迎指正。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Transformer论文及源码笔记——Attention Is All You Need 的相关文章

  • JAVA学习笔记(二)浮点数与精确计算

    浮点数分为float 单精度浮点数 和double 双精度浮点数 float取值范围是4字节32位 精度是7到8位 double取值范围是8字节64位 精度是16到17位 数据转化时会有精度的损失 所以通过BigDecimal类将浮点数转化
  • hive 写入mysql 覆盖_一文搞定hive之insert into 和 insert overwrite与数据分区

    版权声明 本文为博主原创文章 未经博主允许不得转载 数据分区 数据库分区的主要目的是为了在特定的SQL操作中减少数据读写的总量以缩减响应时间 主要包括两种分区形式 水平分区与垂直分区 水平分区是对表进行行分区 而垂直分区是对列进行分区 一般
  • Android的ListView控件的常用适配器

    ListView的常用适配器 一 ArrayAdapter适配器 1 创建ListView 2 创建用于加载数据的布局 3 java的逻辑代码 使用GridView 以多列的方式排列 处理GridView的逻辑代码 二 SimpleAdap
  • pfx证书转pem、crt、key

    今天测试端的服务器突然不能下载苹果APP了 经查看 发现原来是测试环境的https证书过期了 需要更换证书 于是赶紧从阿里云更新我们的最新证书 我们程序部署在tomcat上 于是下载tomcat版本 下载完成后如下 我们的程序部署在天翼云上

随机推荐

  • adb命令

    adb screenshot2 aapt monitor uiautomatorviewer等命令都在Android SDK的tools platform tools build tools下面 如果希望直接运行命令 不写绝对路径 请把相应
  • TypeScript中的模块与命名空间

    一 模块 在TypeScript中 模块是一种组织和封装代码的方式 模块使得代码可以按照特定的规则划分为不同的文件 并且可以在这些文件之间进行导入和导出 从而实现代码的重用和组织 1 默认导入导出 默认模块导出是一种特殊的导出语法 在一个模
  • 读取串口 :javax.comm 2.0 windows下Eclipse的配置

    javax comm 2 0 windows下Eclipse的配置 要在Windows下 对计算机的串口或并口等进行编程 可以选择使用Java Communication API javax comm 包 现在最新的版本是3 0版本 但是3
  • B站评论采集

    B站评论采集 打开目标网址 哔哩哔哩 干杯 bilibili 找到爬取得剧的评论 打开浏览器抓包工具进行抓包分析 这里爬取鬼灭之刃第一季的评论数据 分析网页 打开评论页面 可以看到分为短评 128702 和长评 639 条 常规操作直接F1
  • fusion360界面字体模糊处理方法

    fusion360界面字体模糊处理方法 1 右键点击桌面fusion360图标 2 选择兼容性 3 选择更改高DPI设置 4 设置为如下界面 5 重新启动程序 over
  • C++指向类成员的指针

    指向类成员 以前C定义指针 int a int p a void func void pf func 而在这里本质也是相同 去掉类名 就是上面的形式 定义如下 成员类型 类名 指针名 类名 成员名 函数返回类型 类名 函数指针名 参数列表
  • 数学建模笔记(六):常微分方程及其应用

    文章目录 一 常微分方程概述 1 什么是常微分方程 2 以微分方程解决实际问题的一般思维 3 微分方程求解 4 微分方程适用问题 5 建立微分方程模型的方法 二 物体的冷却过程 1 问题背景 2 问题分析 3 模型建立与求解 三 水桶的放水
  • MySQL数据库解读之-内置数据库:mysql

    数据字典表 不可见 不能用 SELECT 读取 不会出现在 SHOW TABLES 的输出中 不会列在 information schema TABLES 表中 从概念上讲 information schema 提供了一个视图 MySQL
  • centos7.5安装zabbix5.0(亲测有效)

    配置环境 操作系统 centos7 5 必须要是Centos7以上的系统 zabbix版本 5 0 Zabbix 特性 1 数据收集 2 灵活的阀值定义 3 高级告警配置 4 实时绘图 5 扩展的图形化显示 6 历史数据存储 7 配置简单
  • 数据结构笔记(C语言版)

    一 绪论 程序 数据结构 算法 1 基本的数据结构 线性结构 线性表 栈和队列 串 数组和广义表 非线性结构 树 图 用计算机解题一个问题的步骤 具体问题抽象为数学模型 设计算法 编程 调试 运行 数据结构是一门研究非数值计算的程序设计中计
  • [Hive SQL] 实现分组排序、分组topN

    举个场景例子 我们要计算app内在每小时区间内访问量前2的服务 根据访问日志处理完后的数据集如下所示 visit hour service name visit cnt 2021062401 A 421 2021062401 B 710 2
  • python实现简易万年历_Python编程——万年历

    2017年五月份日历 万年历这个题目几乎是不论学哪种编程语言必要尝试的一个小知识 综合了循环 逻辑关系判断等各编程语言的基础知识 今天我们一起用Python实现简单的万年历功能 查看某年各个月份日历和查看确定月份日历 网上大概浏览了一部分代
  • C++从0到1(5):循环结构

    目录 1 while循环 2 do while循环 3 for循环 4 嵌套循环 1 while循环 作用 满足循环条件 执行循环语句 语法 while 循环条件 循环语句 循环条件为真 就执行循环语句 include
  • Maven一定要会的这几个知识!

    一 Maven概念 Maven是一个项目管理和整合工具 Maven为开发者提供了一套完整的构建生命周期框架 开发团队几乎不用花多少时间就能够自动完成工程的基础构建配置 因为Maven使用了一个标准的目录结构和一个默认的构建生命周期 若有多个
  • 初识Nacos

    目录 1 Nacos介绍 1 1四大功能 1 2微服务中配置文件的问题 1 3配置中心解决了什么 1 4业界常见的配置中心 1 5解决不同环境相同配置的问题 1 6不同微服务之间相同配置的共享 2 Nacos Config 动态刷新原理 2
  • 使用ImageMagick批量转换图片格式

    需求 需要将1000张 DCM后缀结尾的图片文件转换为常见的jpg格式 解决 windows下载安装 http www imagemagick org script download php 将ImageMagick安装完成后 确保在命令行
  • QT下assimp库的模型加载

    Assimp库概述 一个非常流行的模型导入库是Assimp 它是Open Asset Import Library 开放的资产导入库 的缩写 Assimp能够导入很多种不同的模型文件格式 并也能够导出部分的格式 它会将所有的模型数据加载至A
  • 云服务器:Linux宝塔面板如何部署node服务

    前情 有自己的服务器 已经安装了宝塔面板 也安装了node js 在本地编写了一个node程序 如何要挂载到阿里云服务器中运行 解决 将本地node文件上传至服务器中www目录下 node modules可以不用上传 运行node程序 np
  • 网络基础之OSI七层模型与TCP/IP五层模型

    OSI七层模型及功能概述 一 OSI七层模型 二 七层模型的功能概述 1物理层 2数据链路层 3网络层 4传输层 5会话层 6表示层 7应用层 三 TCP IP五层模型的组成 四 五层模型中的协议族组成 五 数据封装与解封装过程 六 设备与
  • Transformer论文及源码笔记——Attention Is All You Need

    Transformer论文及源码笔记 Attention Is All You Need 综述 介绍 代码实现 编码模块 解码模块 综述 论文题目 Attention Is All You Need 时间会议 Advances in Neu