《Attention Is All You Need》算法详解

2023-05-16

该篇文章右谷歌大脑团队在17年提出,目的是解决对于NLP中使用RNN不能并行计算(详情参考《【译】理解LSTM(通俗易懂版)》),从而导致算法效率低的问题。该篇文章中的模型就是近几年大家到处可以听到的Transformer模型。

一、算法介绍前的说明

由于该文章提出是解决NLP(Nature Language Processing)中的任务,例如文章实验是在翻译任务上做的。为了CV同学更好的理解,先简单介绍一下NLP任务的一个工作流程,来理解模型的输入和输出是什么。

1.1 CV模型的输入和输出

首先拿CV中的分类任务来说,训练前我们会有以下几个常见步骤:

  1. 获取图片
  2. 定义待分类的类别,用数字标签或者one-hot向量标签表示
  3. 对图片进行类别的标注
  4. 图片预处理(翻转、裁剪、缩放等)
  5. 将预处理后的图片输入到模型中

所以对于分类任务来说,模型的输入为预处理过的图片,输出为图片的类别(一般为预测的向量,然后求argmax获得类别)。

1.2 NLP模型的输入

在介绍NLP任务预处理流程前,先解释两个词,一个是tokenize,一个是embedding。

tokenize是把文本切分成一个字符串序列,可以暂且简单的理解为对输入的文本进行分词操作。对英文来说分词操作输出一个一个的单词,对中文来说分词操作输出一个一个的字。(实际的分词操作多有种方式,会复杂一点,这里说的只是一种分词方式,姑且这么定,方便下面的理解。)

embedding是可以简单理解为通过某种方式将词向量化,即输入一个词输出该词对应的一个向量。(embedding可以采用训练好的模型如GLOVE等进行处理,也可以直接利用深度学习模型直接学习一个embedding层,Transformer模型的embedding方式是第二种,即自己去学习的一个embedding层。)

在NLP中,拿翻译任务(英文翻译为中文)来说,训练模型前存在下面步骤:

  1. 获取英文中文对应的句子
  2. 定义英文词表(常用的英文单词作为一个类别)和中文词表(一个字为一个类别)
  3. 对中英文进行分词
  4. 将分好的词根据步骤2定义好的词表获得句子中每个词的one-hot向量
  5. 对每个词进行embedding(输入one-hot输出与该词对应的embedding向量)
  6. embedding向量输入到模型中去

所以对于翻译任务来说,翻译模型的输入为句子每个词的one-hot向量或者embedding后的向量(取决于embedding是否是翻译模型自己学习的,如果是则输入one-hot就可以了,如果不是那么输入就是通过别的模型获得的embedding向量)组成的序列,输出为当前预测词的类别(一般为词表大小维度的向量)

二、Transformer的结构

知道了Transformer模型的输入和输出后,下面来介绍一下Transformer模型的结构。

先来看看Transformer的整体结构,如下图所示:

1

可以看出它是一个典型的seq2seq结构(encoder-decoder结构),Encoder里面有N个重复的block结构,Decoder里面也有N个重复的block结构。

2

2.1 Embedding

可以注意到这里的embedding操作是与翻译模型一起学习的。所以Transformer模型的输入为对句子分词后,每个词的one-hot向量组成的一个向量序列,输出为预测的每个词的预测向量。

2.2 Positional Encoding

为了更好的利用序列的位置信息,在对embedding后的向量加上位置相关的编码。文章采用的是人工预设的方式计算出来的编码。计算方式如下

P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d m o d e l ) PE_{(pos, 2i)}=sin(pos/10000^{2i/d_{model}}) PE(pos,2i)=sin(pos/100002i/dmodel)

P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i / d m o d e l ) PE_{(pos, 2i+1)}=cos(pos/10000^{2i/d_{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

上式中,pos表示当前词在句子中的位置,例如输入的序列长L=5,那么pos取值分别为0-4,i表示维度的位置,偶数位置用 P E ( p o s , 2 i ) PE(pos, 2i) PE(pos,2i)公式计算, 奇数位置用 P E ( p o s , 2 i + 1 ) PE(pos, 2i+1) PE(pos,2i+1)公式计算。

文章也采用了加入模型训练来自动学习位置编码的方式,发现效果与人工预设方式差不多。

2.3 Encoder结构

Encoder包含了N个重复的block结构,文章N=6。下面来拆解一个每个块的具体结构。

6
2.3.1 Multi-Head Attention(encoder)

为了便于理解,介绍Multi-Head Attention结构前,先介绍一下基础的Scaled Dot-Product Attention结构,该结构是Transformer的核心结构。

Scaled Dot-Product Attention结构如下图所示

3

Scaled Dot-Product Attention模块用公式表示如下

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

上式中,可以假设Q\K的维度皆为 ( L , d k ) {(L, d_k)} (L,dk),V的维度为 ( L , d v ) (L, d_v) (L,dv),L为输入的句子长度, d k , d v d_k,d_v dkdv为特征维度。

s o f t m a x ( Q K T ) softmax(QK^T) softmax(QKT)得到的维度为 ( L , L ) (L, L) (L,L),该张量可以理解为计算Q与K中向量两两间的相似度或者说是模型应该着重关注(attention)的地方。这里还除了 d k \sqrt{d_k} dk ,文章解释是防止维度 d k d_k dk太大得到的值就会太大,导致后续的导数会太小。(这里为什么一定要除 d k \sqrt{d_k} dk 而不是 d k {d_k} dk或者其它数值,文章没有给出解释。)

经过 s o f t m a x ( Q K T d k ) softmax(\frac{QK^T}{\sqrt{d_k}}) softmax(dk QKT)获得attention权重后,与V相乘,既可以得到attention后的张量信息。最终的 A t t e n t i o n ( Q , K , V ) Attention(Q, K, V) Attention(Q,K,V)输出维度为 ( L , d v ) (L, d_v) (L,dv)

这里还可以看到在Scaled Dot-Product Attention模块中还存在一个可选的Mask模块(Mask(opt.)),后续会介绍它的作用。

文章认为采用多头(Multi-Head)机制有利于模型的性能提高,所以文章引入了Multi-Head Attention结构。

Multi-Head Attention结构如下图所示

4

Multi-Head Attention结构用公式表示如下

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d n ) W O w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) MultiHead(Q, K, V) = Concat(head_1, ..., head_n)W^O\\ where head_i = Attention(QW^Q_i, KW^K_i, VW^V_i) MultiHead(Q,K,V)=Concat(head1,...,headn)WOwhereheadi=Attention(QWiQ,KWiK,VWiV)

上述参数矩阵为 W i Q ∈ R d m o d e l × d k W^Q_i\in R^{d_{model} \times d_k} WiQRdmodel×dk, W i K ∈ R d m o d e l × d k W^K_i\in R^{d_{model} \times d_k} WiKRdmodel×dk, W i V ∈ R d m o d e l × d v W^V_i\in R^{d_{model} \times d_v} WiVRdmodel×dv, W i O ∈ R h d v × d m o d e l W^O_i\in R^{hd_v \times d_{model}} WiORhdv×dmodel d m o d e l d_{model} dmodel为multi-head attention模块输入与输出张量的通道维度,h为head个数。文中h=8, d k = d v = d m o d e l / h = 64 d_k=d_v=d_{model}/h=64 dk=dv=dmodel/h=64 d m o d e l = 512 d_{model}=512 dmodel=512

关于multi-head机制为什么可以提高模型性能

文章末尾给出了多头中其中两个头的attention可视化结果,如下所示

5

图中,线条越粗表示attention的权重越大,可以看出,两个头关注的地方不一样,绿色图说明该头更关注全局信息,红色图说明该头更关注局部信息。

2.3.2 Add&Norm结构

从结构图不难看出网络加入了residual结构,所以add很好理解,就是输入张量与输出张量相加的操作。

Norm操作与CV常用的BN不太一样,这里采用NLP领域较常用的LN(Layer Norm)。(关于BN、LN、IN、GN的计算方式可以参考《GN-Group Normalization》)

还要多说一下的是,文章中共Add&Norm结构是先相加再进行Norm操作。

2.3.3 Feed Forward结构

该结构很简单,由两个全连接(或者kernel size为1的卷积)和一个ReLU激活单元组成。

Feed Forward结构用公式表示如下

F F N ( x ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 FFN(x)=max(0, xW_1 + b_1)W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2

2.4 Decoder结构

Decoder同样也包含了N个重复的block结构,文章N=6。下面来拆解一个每个块的具体结构。

6
2.4.1 Masked Multi-Head Attention

从名字可以看出它比2.3.1部分介绍的Multi-Head Attention结构多一个masked,其实它的基本结构如下图所示

3

可以看出这就是Scaled Dot-Product Attention,只是这里mask是启用的状态。

这里先从维度角度考虑mask是怎么工作的,然后再解释为什么要加这个mask操作。

mask工作方式

为了方便解释,先不考虑多batch和多head情况。

可以假设Q\K的维度皆为 ( L , d k ) {(L, d_k)} (L,dk),V的维度为 ( L , d v ) (L, d_v) (L,dv)

那么在进行mask操作前,经过MatMul和Scale后得到的张量维度为 Q K T d k ∈ R ( L , L ) \frac{QK^T}{\sqrt{d_k}}\in R^{(L, L)} dk QKTR(L,L)

现在有一个提前计算好的mask为 M ∈ R ( L , L ) M\in R^{(L, L)} MR(L,L),M是一个上三角为-inf,下三角为0的方阵。如下图所示(图中假设L=5)。

8

s o f t m a x ( Q K T d k + M ) softmax(\frac{QK^T}{\sqrt{d_k}}+M) softmax(dk QKT+M)的结果如下图所示(图中假设L=5)

注意:下图中的非0区域的值不一定是一样的,这里为了方便显示画成了一样的颜色

9

现在Scaled Dot-Product Attention的公式如下所示

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k + M ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}}+M)V Attention(Q,K,V)=softmax(dk QKT+M)V

可以看出经过M后,softmax在-inf处输出结果为0,其它地方为非0,所以softmax的输出为 s o f t m a x ( Q K T d k + M ) ∈ R ( L , L ) softmax(\frac{QK^T}{\sqrt{d_k}}+M)\in R^{(L, L)} softmax(dk QKT+M)R(L,L),该结果为上三角为0的方阵。与 V ∈ R ( L , d v ) V\in R^{(L, d_v)} VR(L,dv)进行相乘得到结果为 A t t e n t i o n ( Q , K , V ) ∈ R ( L , d v ) Attention(Q, K, V) \in R^{(L, d_v)} Attention(Q,K,V)R(L,dv)

从上述运算可以看出mask的目的是为了让V与attention权重计算attention操作时只考虑当前元素以前的所有元素,而忽略之后元素的影响。即V的维度为 ( L , d v ) (L, d_v) (L,dv),那么第i个元素只考虑0-i元素来得出attention的结果。

mask操作的作用

在解释mask作用之前,我们先解释一个概念叫teacher forcing

teacher forcing这个操作方式经常在训练序列任务时被用到,它的含义是在训练一个序列预测模型时,模型的输入是ground truth。

举例来说,对于"I Love China -> 我爱中国"这个翻译任务来说,测试阶段,Encoder会将输入英文编译为feature,Decoder解码时首先会收到一个BOS(Begin Of Sentence)标识,模型输出"我",然后将"我"作为decoder的输入,输出"爱",重复这个步骤直到输出EOS(End Of Sentence)标志。

但是为了能快速的训练一个效果好的网络,在训练时,不管decoder输出是什么,它的输入都是ground truth。例如,网络在收到BOS后,输出的是"你",那么下一步的网络输入依然还是使用gt中的"我"。这种训练方式称为teacher forcing。如下图所示

12

我们看下面两张图,第一张是没有mask操作时的示例图,第二张是有mask操作时的示例图。可以看到,按照teacher forcing的训练方式来训练Transformer,如果没有mask操作,模型在预测"我"这个词时,就会利用到"我爱中国"所有文字的信息,这不合理。所以需要加入mask,使得网络只能利用部分已知的信息来模拟推断阶段的流程。

13 11
2.4.2 Multi-Head Attention(decoder)

decoder中的Multi-Head Attention内部结构与encoder是一模一样的,只是输入中的Q为2.4.1部分提到的Masked Multi-Head Attention的输出,输入中的K与V则都是encoder模块的输出。

下面用一张图来展示encoder和decoder之间的信息传递关系

15

decoder中Add&Norm和Feed Forward结构都与encoder一模一样了。

2.5 其它说明

1. 从图中看出encoder和decoder中每个block的输入都是一个张量,但是输入给attention确实Q\K\V三个张量?

对于block来说,Q=K=V=输入张量

2. 推断阶段,解码可以并行吗?

不可以,上面说的并行是采用了teacher forcing+mask的操作,在训练时可以并行计算。但是推断时的解码过程同RNN,都是通过auto-regression方式获得结果的。(当然也有non auto-regression方面的研究,就是一次估计出最终结果)

参考:

  1. https://arxiv.org/abs/1706.03762
  2. https://www.youtube.com/watch?v=ugWDIIOHtPA&t=1697s
  3. http://nlp.seas.harvard.edu/2018/04/03/attention.html
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

《Attention Is All You Need》算法详解 的相关文章

随机推荐

  • github开源项目Bringing-Old-Photos-Back-to-Life实战

    Github项目 xff1a https github com microsoft Bringing Old Photos Back to Life 1 环境 win11 python 3 8 8 torch 1 9 1 torchvisi
  • sapjco3.jar在maven项目中的打包问题

    sapjco3 jar的打包问题 前几天做一个跟SAP系统有关的小功能时 xff0c 项目用到了sapjco3 jar的jar包 xff0c 项目打包部署后报错 com hand sapjco 3 system basedir src ma
  • 记录:c#中使用Selenium之一 使用chrome驱动手机模式浏览

    1 最近研究c 使用各种模拟浏览器的工具 xff0c 下面是Selenium的使用记录 1 xff09 首先使用Nuget搜索以下依赖库 2 xff09 我使用的是手机模式的浏览方式 下面为手机浏览模式设置的代码 xff0c 以移动端百度搜
  • 解决VS平台迁移时报错error MSB8020:The build tools for v141

    将在VS2017上编译的程序放到VS2013中 xff0c 报错 xff1a error MSB8020 The build tools for v141 Platform Toolset 61 39 v141 39 cannot be f
  • C++ 中的char型变量

    最简单的字符数据类型是 char 数据类型 该类型的变量只能容纳一个字符 xff0c 而且在大多数系统上 xff0c 只使用一个字节的内存 以下示例即声明了一个名为 letter 的 char 变量 请注意 xff0c 这里的字符常数就是赋
  • linux下DISPLAY和xhost + 作用

    在Linux Unix类操作系统上 DISPLAY用来设置将图形显示到何处 直接登陆图形界面或者登陆命令行界面后使用startx启动图形 DISPLAY环境变量将自动设置为 0 0 此时可以打开终端 输出图形程序的名称 比如xclock 来
  • 配置 maven 编译的 JDK 版本

    两种方式 xff1a 一 可以修改 MAVEN 的 setting xml 文件 xff0c 统一修改 lt profiles gt lt profile gt lt id gt jdk 1 6 lt id gt lt activation
  • 利用redis的setIfAbsent()方法实现分布式锁

    再集群环境中 xff0c 存在定时任务多次执行 xff0c 浪费资源 xff0c 那么如何避免这种情况呢 xff0c 下面就说明一下如何利用一个注解解决问题 xff0c 利用切面配合redis可以简单实现分布式锁 xff0c 解决定时任务重
  • Virtualbox主机和虚拟机之间文件夹共享及双向拷贝(win7——centos7)

    一 双向拷贝 xff1a 然后 xff0c 还需要通过virtualbox上安装一个增强的工具 此时 xff0c 会在centos上安装一些工具 xff1a 鼠标自动在宿主机 虚拟机之间移出 同时 xff0c 在centos上会出现一个安装
  • Record something about DL

    这篇文章算是DL实践杂谈吧 xff0c 主要是想把自己模型调优和复现算法遇到的一些坑总结一下 xff08 里面的一行字可能是我当时花费了一周甚至更长时间得到的总结 xff09 xff0c 希望能对读者有所帮助 一 熟悉数据 模型是数据的浓缩
  • Image captioning任务常用的评价指标计算

    BLEU ACL 2002Meteor AMTA 2004ROUGE L ACL 2004CIDEr CVPR 2015SPICE ECCV 2016
  • Image captioning评价方法之BLEU (bilingual evaluation understudy)

    文章地址 xff1a BLEU a Method for Automatic Evaluation of Machine Translation 代码地址 非官方 xff1a https github com tylin coco capt
  • Image captioning评价方法之Meteor

    项目地址 xff1a http www cs cmu edu alavie METEOR 代码地址 xff08 非官方实现 xff0c 实现的是项目地址中的1 5版本 xff09 xff1a https github com tylin c
  • Image captioning评价方法之ROUGE-L

    文章地址 xff1a ROUGE A Package for Automatic Evaluation of Summaries 代码地址 非官方 xff1a https github com tylin coco caption 文章由U
  • Image captioning评价方法之CIDEr

    文章地址 xff1a CIDEr Consensus based Image Description Evaluation 代码地址 xff08 非官方 xff0c 且代码实现的是CIDEr D xff09 xff1a https gith
  • Image captioning评价方法之SPICE

    项目地址 xff1a https panderson me spice 上述的项目地址包含了论文地址和代码地址 该方法是由The Australian National University和Macquarie University联合发表
  • R3DS Wrap基本使用方法

    中文的R3DS Wrap软件的教程较少 xff0c 最近刚好实操了一遍 xff0c 特此记录下来 为了描述方便 xff0c 下面将R3DS Wrap简称Wrap 软件官网 xff1a https www russian3dscanner c
  • docker使用入门简介

    一 什么是docker xff1f https www docker com resources what container 使用docker时有两个重要概念 xff0c 一个是镜像 xff08 images xff09 xff0c 一个
  • SpringBoot整合Quartz 实现分布式定时任务调度

    一 Quartz 集群架构 Quartz 是 Java 领域最著名的开源任务调度工具 在上篇文章中 xff0c 我们详细的介绍了 Quartz 的单体应用实践 xff0c 如果只在单体环境中应用 xff0c Quartz 未必是最好的选择
  • 《Attention Is All You Need》算法详解

    该篇文章右谷歌大脑团队在17年提出 xff0c 目的是解决对于NLP中使用RNN不能并行计算 xff08 详情参考 译 理解LSTM xff08 通俗易懂版 xff09 xff09 xff0c 从而导致算法效率低的问题 该篇文章中的模型就是