encoder decoder模型_Transformer 模型的 PyTorch 实现

2023-11-20

Google 2017年的论文 Attention is all you need 阐释了什么叫做大道至简!该论文提出了Transformer模型,完全基于Attention mechanism,抛弃了传统的RNNCNN

我们根据论文的结构图,一步一步使用 PyTorch 实现这个Transformer模型。

Transformer架构

首先看一下transformer的结构图:

1d1c745fc11a64706afeb95c2b412101.png

解释一下这个结构图。首先,Transformer模型也是使用经典的encoer-decoder架构,由encoder和decoder两部分组成。

上图的左半边用Nx框出来的,就是我们的encoder的一层。encoder一共有6层这样的结构。

上图的右半边用Nx框出来的,就是我们的decoder的一层。decoder一共有6层这样的结构。

输入序列经过word embeddingpositional encoding相加后,输入到encoder。

输出序列经过word embeddingpositional encoding相加后,输入到decoder。

最后,decoder输出的结果,经过一个线性层,然后计算softmax。

word embeddingpositional encoding我后面会解释。我们首先详细地分析一下encoder和decoder的每一层是怎么样的。

Encoder

encoder由6层相同的层组成,每一层分别由两部分组成:

  • 第一部分是一个multi-head self-attention mechanism

  • 第二部分是一个position-wise feed-forward network,是一个全连接层

两个部分,都有一个 残差连接(residual connection),然后接着一个Layer Normalization

如果你是一个新手,你可能会问:

  • multi-head self-attention 是什么呢?

  • 参差结构是什么呢?

  • Layer Normalization又是什么?

这些问题我们在后面会一一解答。

Decoder

和encoder类似,decoder由6个相同的层组成,每一个层包括以下3个部分:

  • 第一个部分是multi-head self-attention mechanism

  • 第二部分是multi-head context-attention mechanism

  • 第三部分是一个position-wise feed-forward network

还是和encoder类似,上面三个部分的每一个部分,都有一个残差连接,后接一个Layer Normalization

但是,decoder出现了一个新的东西multi-head context-attention mechanism。这个东西其实也不复杂,理解了multi-head self-attention你就可以理解multi-head context-attention。这个我们后面会讲解。

Attention机制

在讲清楚各种attention之前,我们得先把attention机制说清楚。

通俗来说,attention是指,对于某个时刻的输出y,它在输入x上各个部分的注意力。这个注意力实际上可以理解为权重

attention机制也可以分成很多种。Attention? Attention! 一问有一张比较全面的表格:

2d42d176e97c54dc99bf6749b120f422.png

Figure 2. a summary table of several popular attention mechanisms.

上面第一种additive attention你可能听过。以前我们的seq2seq模型里面,使用attention机制,这种**加性注意力(additive attention)**用的很多。Google的项目 tensorflow/nmt 里面使用的attention就是这种。

为什么这种attention叫做additive attention呢?很简单,对于输入序列隐状态d33f851d-c02d-eb11-8da9-e4434bdf6706.svg和输出序列的隐状态d53f851d-c02d-eb11-8da9-e4434bdf6706.svg,它的处理方式很简单,直接合并,变成d73f851d-c02d-eb11-8da9-e4434bdf6706.svg

但是我们的transformer模型使用的不是这种attention机制,使用的是另一种,叫做乘性注意力(multiplicative attention)

那么这种乘性注意力机制是怎么样的呢?从上表中的公式也可以看出来:两个隐状态进行点积

Self-attention是什么?

到这里就可以解释什么是self-attention了。

上面我们说attention机制的时候,都会说到两个隐状态,分别是d33f851d-c02d-eb11-8da9-e4434bdf6706.svgd53f851d-c02d-eb11-8da9-e4434bdf6706.svg,前者是输入序列第i个位置产生的隐状态,后者是输出序列在第t个位置产生的隐状态。

所谓self-attention实际上就是,输出序列就是输入序列!因此,计算自己的attention得分,就叫做self-attention

Context-attention是什么?

知道了self-attention,那你肯定猜到了context-attention是什么了:它是encoder和decoder之间的attention!所以,你也可以称之为encoder-decoder attention!

context-attention一词并不是本人原创,有些文章或者代码会这样描述,我觉得挺形象的,所以在此沿用这个称呼。其他文章可能会有其他名称,但是不要紧,我们抓住了重点即可,那就是两个不同序列之间的attention,与self-attention相区别。

不管是self-attention还是context-attention,它们计算attention分数的时候,可以选择很多方式,比如上面表中提到的:

  • additive attention

  • local-base

  • general

  • dot-product

  • scaled dot-product

那么我们的Transformer模型,采用的是哪种呢?答案是:scaled dot-product attention

Scaled dot-product attention是什么?

论文Attention is all you need里面对于attention机制的描述是这样的:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

这句话描述得很清楚了。翻译过来就是:通过确定Q和K之间的相似程度来选择V!

用公式来描述更加清晰:

b59477418143557fbf8c4ecd5295a44b.png

上面公式中的dc3f851d-c02d-eb11-8da9-e4434bdf6706.svg表示的是K的维度,在论文里面,默认是64

那么为什么需要加上这个缩放因子呢?论文里给出了解释:对于dc3f851d-c02d-eb11-8da9-e4434bdf6706.svg很大的时候,点积得到的结果维度很大,使得结果处于softmax函数梯度很小的区域。

我们知道,梯度很小的情况,这对反向传播不利。为了克服这个负面影响,除以一个缩放因子,可以一定程度上减缓这种情况。

为什么是e43f851d-c02d-eb11-8da9-e4434bdf6706.svg呢?论文没有进一步说明。个人觉得你可以使用其他缩放因子,看看模型效果有没有提升。

论文也提供了一张很清晰的结构图,供大家参考:

b6a8d06c617095684d49da3bdb111521.png

Figure 3. Scaled dot-product attention architecture.

首先说明一下我们的K、Q、V是什么:

  • 在encoder的self-attention中,Q、K、V都来自同一个地方(相等),他们是上一层encoder的输出。对于第一层encoder,它们就是word embedding和positional encoding相加得到的输入。

  • 在decoder的self-attention中,Q、K、V都来自于同一个地方(相等),它们是上一层decoder的输出。对于第一层decoder,它们就是word embedding和positional encoding相加得到的输入。但是对于decoder,我们不希望它能获得下一个time step(即将来的信息),因此我们需要进行sequence masking

  • 在encoder-decoder attention中,Q来自于decoder的上一层的输出,K和V来自于encoder的输出,K和V是一样的。

  • Q、K、V三者的维度一样,即 ea3f851d-c02d-eb11-8da9-e4434bdf6706.svg

上面scaled dot-product attention和decoder的self-attention都出现了masking这样一个东西。那么这个mask到底是什么呢?这两处的mask操作是一样的吗?这个问题在后面会有详细解释。

Scaled dot-product attention的实现

咱们先把scaled dot-product attention实现了吧。代码如下:

import torchimport torch.nn as nnclass ScaledDotProductAttention(nn.Module):    """Scaled dot-product attention mechanism."""    def __init__(self, attention_dropout=0.0):        super(ScaledDotProductAttention, self).__init__()        self.dropout = nn.Dropout(attention_dropout)        self.softmax = nn.Softmax(dim=2)    def forward(self, q, k, v, scale=None, attn_mask=None):        """前向传播.        Args:          q: Queries张量,形状为[B, L_q, D_q]          k: Keys张量,形状为[B, L_k, D_k]          v: Values张量,形状为[B, L_v, D_v],一般来说就是k          scale: 缩放因子,一个浮点标量          attn_mask: Masking张量,形状为[B, L_q, L_k]        Returns:          上下文张量和attetention张量        """        attention = torch.bmm(q, k.transpose(1, 2))        if scale:          attention = attention * scale        if attn_mask:          # 给需要mask的地方设置一个负无穷          attention = attention.masked_fill_(attn_mask, -np.inf)    # 计算softmax        attention = self.softmax(attention)    # 添加dropout        attention = self.dropout(attention)    # 和V做点积        context = torch.bmm(attention, v)        return context, attention

Multi-head attention又是什么呢?

理解了Scaled dot-product attention,Multi-head attention也很简单了。论文提到,他们发现将Q、K、V通过一个线性映射之后,分成 ec3f851d-c02d-eb11-8da9-e4434bdf6706.svg 份,对每一份进行scaled dot-product attention效果更好。然后,把各个部分的结果合并起来,再次经过线性映射,得到最终的输出。这就是所谓的multi-head attention。上面的超参数 ec3f851d-c02d-eb11-8da9-e4434bdf6706.svg 就是heads数量。论文默认是8

下面是multi-head attention的结构图:

c8f996521c85bd30a54c86869d257945.png

Figure 4: Multi-head attention architecture.

值得注意的是,上面所说的分成 ec3f851d-c02d-eb11-8da9-e4434bdf6706.svg是在 f53f851d-c02d-eb11-8da9-e4434bdf6706.svg 维度上面进行切分的。因此,进入到scaled dot-product attention的 dc3f851d-c02d-eb11-8da9-e4434bdf6706.svg 实际上等于未进入之前的 fe3f851d-c02d-eb11-8da9-e4434bdf6706.svg

Multi-head attention允许模型加入不同位置的表示子空间的信息。

Multi-head attention的公式如下:

fc31d7ba122142831bad7bb0977ec8cd.png

Multi-head attention的实现

相信大家已经理清楚了multi-head attention,那么我们来实现它吧。代码如下:

import torchimport torch.nn as nnclass MultiHeadAttention(nn.Module):    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):        super(MultiHeadAttention, self).__init__()        self.dim_per_head = model_dim // num_heads        self.num_heads = num_heads        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)        self.dot_product_attention = ScaledDotProductAttention(dropout)        self.linear_final = nn.Linear(model_dim, model_dim)        self.dropout = nn.Dropout(dropout)    # multi-head attention之后需要做layer norm        self.layer_norm = nn.LayerNorm(model_dim)    def forward(self, key, value, query, attn_mask=None):    # 残差连接        residual = query        dim_per_head = self.dim_per_head        num_heads = self.num_heads        batch_size = key.size(0)        # linear projection        key = self.linear_k(key)        value = self.linear_v(value)        query = self.linear_q(query)        # split by heads        key = key.view(batch_size * num_heads, -1, dim_per_head)        value = value.view(batch_size * num_heads, -1, dim_per_head)        query = query.view(batch_size * num_heads, -1, dim_per_head)        if attn_mask:            attn_mask = attn_mask.repeat(num_heads, 1, 1)        # scaled dot product attention        scale = (key.size(-1) // num_heads) ** -0.5        context, attention = self.dot_product_attention(          query, key, value, scale, attn_mask)        # concat heads        context = context.view(batch_size, -1, dim_per_head * num_heads)        # final linear projection        output = self.linear_final(context)        # dropout        output = self.dropout(output)        # add residual and norm layer        output = self.layer_norm(residual + output)        return output, attention

上面的代码终于出现了Residual connection和Layer normalization。我们现在来解释它们。

Residual connection是什么?

残差连接其实很简单!给你看一张示意图你就明白了:

f01d3e2fc02740c1aa9704e1b7d46531.png

Figure 5. Residual connection.

假设网络中某个层对输入x作用后的输出是0640851d-c02d-eb11-8da9-e4434bdf6706.svg,那么增加residual connection之后,就变成了:

0d3bb9a3b91851e451677ca2d87b26c1.png

这个+x操作就是一个shortcut

那么残差结构有什么好处呢?显而易见:因为增加了一项0c40851d-c02d-eb11-8da9-e4434bdf6706.svg,那么该层网络对x求偏导的时候,多了一个常数项0e40851d-c02d-eb11-8da9-e4434bdf6706.svg!所以在反向传播过程中,梯度连乘,也不会造成梯度消失

所以,代码实现residual connection很非常简单:

def residual(sublayer_fn,x):  return sublayer_fn(x)+x

文章开始的transformer架构图中的Add & Norm中的Add也就是指的这个shortcut。

至此,residual connection的问题理清楚了。更多关于残差网络的介绍可以看文末的参考文献。

Layer normalization是什么?

GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION一文对normalization有很好的解释:

Normalization有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为0方差为1的数据。我们在把数据送入激活函数之前进行normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区。

说到normalization,那就肯定得提到Batch Normalization。BN在CNN等地方用得很多。

BN的主要思想就是:在每一层的每一批数据上进行归一化。

我们可能会对输入数据进行归一化,但是经过该网络层的作用后,我们的的数据已经不再是归一化的了。随着这种情况的发展,数据的偏差越来越大,我的反向传播需要考虑到这些大的偏差,这就迫使我们只能使用较小的学习率来防止梯度消失或者梯度爆炸。

BN的具体做法就是对每一小批数据,在批这个方向上做归一化。如下图所示:

5924a00ac68c4f37e28d18ecf35383ef.png

Figure 6. Batch normalization example.(From theneuralperspective.com)

可以看到,右半边求均值是沿着数据批量N的方向进行的!

Batch normalization的计算公式如下:

764017ff469af33e33741473d64842b5.png

具体的实现可以查看上图的链接文章。

说完Batch normalization,就该说说咱们今天的主角Layer normalization

那么什么是Layer normalization呢?:它也是归一化数据的一种方式,不过LN是在每一个样本上计算均值和方差,而不是BN那种在批方向计算均值和方差

下面是LN的示意图:

0874ccc299555821afb153441b90ac7a.png

Figure 7. Layer normalization example.

和上面的BN示意图一比较就可以看出二者的区别啦!

下面看一下LN的公式,也BN十分相似:

2d2b706da66f8d1e91f33850ff2f2088.png

Layer normalization的实现

上述两个参数2140851d-c02d-eb11-8da9-e4434bdf6706.svg2440851d-c02d-eb11-8da9-e4434bdf6706.svg都是可学习参数。下面我们自己来实现Layer normalization(PyTorch已经实现啦!)。代码如下:

import torchimport torch.nn as nnclass LayerNorm(nn.Module):    """实现LayerNorm。其实PyTorch已经实现啦,见nn.LayerNorm。"""    def __init__(self, features, epsilon=1e-6):        """Init.        Args:            features: 就是模型的维度。论文默认512            epsilon: 一个很小的数,防止数值计算的除0错误        """        super(LayerNorm, self).__init__()        # alpha        self.gamma = nn.Parameter(torch.ones(features))        # beta        self.beta = nn.Parameter(torch.zeros(features))        self.epsilon = epsilon    def forward(self, x):        """前向传播.        Args:            x: 输入序列张量,形状为[B, L, D]        """        # 根据公式进行归一化        # 在X的最后一个维度求均值,最后一个维度就是模型的维度        mean = x.mean(-1, keepdim=True)        # 在X的最后一个维度求方差,最后一个维度就是模型的维度        std = x.std(-1, keepdim=True)        return self.gamma * (x - mean) / (std + self.epsilon) + self.beta

顺便提一句,Layer normalization多用于RNN这种结构。

Mask是什么?

现在终于轮到讲解mask了!mask顾名思义就是掩码,在我们这里的意思大概就是对某些值进行掩盖,使其不产生效果

需要说明的是,我们的Transformer模型里面涉及两种mask。分别是padding masksequence mask。其中后者我们已经在decoder的self-attention里面见过啦!

其中,padding mask在所有的scaled dot-product attention里面都需要用到,而sequence mask只有在decoder的self-attention里面用到。

所以,我们之前ScaledDotProductAttentionforward方法里面的参数attn_mask在不同的地方会有不同的含义。这一点我们会在后面说明。

Padding mask

什么是padding mask呢?回想一下,我们的每个批次输入序列长度是不一样的!也就是说,我们要对输入序列进行对齐!具体来说,就是给在较短的序列后面填充0。因为这些填充的位置,其实是没什么意义的,所以我们的attention机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。

具体的做法是,把这些位置的值加上一个非常大的负数(可以是负无穷),这样的话,经过softmax,这些位置的概率就会接近0

而我们的padding mask实际上是一个张量,每个值都是一个Boolen,值为False的地方就是我们要进行处理的地方。

下面是实现:

def padding_mask(seq_k, seq_q):  # seq_k和seq_q的形状都是[B,L]    len_q = seq_q.size(1)    # `PAD` is 0    pad_mask = seq_k.eq(0)    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)  # shape [B, L_q, L_k]    return pad_mask

Sequence mask

文章前面也提到,sequence mask是为了使得decoder不能看见未来的信息。也就是对于一个序列,在time_step为t的时刻,我们的解码输出应该只能依赖于t时刻之前的输出,而不能依赖t之后的输出。因此我们需要想一个办法,把t之后的信息给隐藏起来。

那么具体怎么做呢?也很简单:产生一个上三角矩阵,上三角的值全为1,下三角的值权威0,对角线也是0。把这个矩阵作用在每一个序列上,就可以达到我们的目的啦。

具体的代码实现如下:

def sequence_mask(seq):    batch_size, seq_len = seq.size()    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),                    diagonal=1)    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]    return mask

哈佛大学的文章The Annotated Transformer有一张效果图:

0ef80511fcda4031ec447df363280ad5.png

Figure 8. Sequence mask.

值得注意的是,本来mask只需要二维的矩阵即可,但是考虑到我们的输入序列都是批量的,所以我们要把原本二维的矩阵扩张成3维的张量。上面的代码可以看出,我们已经进行了处理。

回到本小结开始的问题,attn_mask参数有几种情况?分别是什么意思?

  • 对于decoder的self-attention,里面使用到的scaled dot-product attention,同时需要padding masksequence mask作为attn_mask,具体实现就是两个mask相加作为attn_mask。

  • 其他情况,attn_mask一律等于padding mask

至此,mask相关的问题解决了。

Positional encoding是什么?

好了,终于要解释位置编码了,那就是文字开始的结构图提到的Positional encoding

就目前而言,我们的Transformer架构似乎少了点什么东西。没错,就是它对序列的顺序没有约束!我们知道序列的顺序是一个很重要的信息,如果缺失了这个信息,可能我们的结果就是:所有词语都对了,但是无法组成有意义的语句!

为了解决这个问题。论文提出了Positional encoding。这是啥?一句话概括就是:对序列中的词语出现的位置进行编码!如果对位置进行编码,那么我们的模型就可以捕捉顺序信息!

那么具体怎么做呢?论文的实现很有意思,使用正余弦函数。公式如下:

be0d82b5c784b3b17ea42ca005fb8c95.png

其中,pos是指词语在序列中的位置。可以看出,在偶数位置,使用正弦编码,在奇数位置,使用余弦编码

上面公式中的2d40851d-c02d-eb11-8da9-e4434bdf6706.svg是模型的维度,论文默认是512

这个编码公式的意思就是:给定词语的位置2f40851d-c02d-eb11-8da9-e4434bdf6706.svg,我们可以把它编码成2d40851d-c02d-eb11-8da9-e4434bdf6706.svg维的向量!也就是说,位置编码的每一个维度对应正弦曲线,波长构成了从3540851d-c02d-eb11-8da9-e4434bdf6706.svg3740851d-c02d-eb11-8da9-e4434bdf6706.svg的等比序列。

上面的位置编码是绝对位置编码。但是词语的相对位置也非常重要。这就是论文为什么要使用三角函数的原因!

正弦函数能够表达相对位置信息。,主要数学依据是以下两个公式:

9885fac246b4894c2749b70bdc026c38.png

上面的公式说明,对于词汇之间的位置偏移k3d40851d-c02d-eb11-8da9-e4434bdf6706.svg可以表示成4040851d-c02d-eb11-8da9-e4434bdf6706.svg4440851d-c02d-eb11-8da9-e4434bdf6706.svg的组合形式,这就是表达相对位置的能力!

以上就是4740851d-c02d-eb11-8da9-e4434bdf6706.svgE的所有秘密。说完了positional encoding,那么我们还有一个与之处于同一地位的word embedding

Word embedding大家都很熟悉了,它是对序列中的词汇的编码,把每一个词汇编码成2d40851d-c02d-eb11-8da9-e4434bdf6706.svg维的向量!看到没有,Postional encoding是对词汇的位置编码,word embedding是对词汇本身编码

所以,我更喜欢positional encoding的另外一个名字Positional embedding

Positional encoding的实现

PE的实现也不难,按照论文的公式即可。代码如下:

import torchimport torch.nn as nnclass PositionalEncoding(nn.Module):        def __init__(self, d_model, max_seq_len):        """初始化。                Args:            d_model: 一个标量。模型的维度,论文默认是512            max_seq_len: 一个标量。文本序列的最大长度        """        super(PositionalEncoding, self).__init__()                # 根据论文给的公式,构造出PE矩阵        position_encoding = np.array([          [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]          for pos in range(max_seq_len)])        # 偶数列使用sin,奇数列使用cos        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])        # 在PE矩阵的第一行,加上一行全是0的向量,代表这`PAD`的positional encoding        # 在word embedding中也经常会加上`UNK`,代表位置单词的word embedding,两者十分类似        # 那么为什么需要这个额外的PAD的编码呢?很简单,因为文本序列的长度不一,我们需要对齐,        # 短的序列我们使用0在结尾补全,我们也需要这些补全位置的编码,也就是`PAD`对应的位置编码        pad_row = torch.zeros([1, d_model])        position_encoding = torch.cat((pad_row, position_encoding))                # 嵌入操作,+1是因为增加了`PAD`这个补全位置的编码,        # Word embedding中如果词典增加`UNK`,我们也需要+1。看吧,两者十分相似        self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)        self.position_encoding.weight = nn.Parameter(position_encoding,                                                     requires_grad=False)    def forward(self, input_len):        """神经网络的前向传播。        Args:          input_len: 一个张量,形状为[BATCH_SIZE, 1]。每一个张量的值代表这一批文本序列中对应的长度。        Returns:          返回这一批序列的位置编码,进行了对齐。        """                # 找出这一批序列的最大长度        max_len = torch.max(input_len)        tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor        # 对每一个序列的位置进行对齐,在原序列位置的后面补上0        # 这里range从1开始也是因为要避开PAD(0)的位置        input_pos = tensor(          [list(range(1, len + 1)) + [0] * (max_len - len) for len in input_len])        return self.position_encoding(input_pos)

Word embedding的实现

Word embedding应该是老生常谈了,它实际上就是一个二维浮点矩阵,里面的权重是可训练参数,我们只需要把这个矩阵构建出来就完成了word embedding的工作。

所以,具体的实现很简单:

import torch.nn as nnembedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)# 获得输入的词嵌入编码seq_embedding = seq_embedding(inputs)*np.sqrt(d_model)

上面vocab_size就是词典的大小,embedding_size就是词嵌入的维度大小,论文里面就是等于4b40851d-c02d-eb11-8da9-e4434bdf6706.svg。所以word embedding矩阵就是一个vocab_size*embedding_size的二维张量。

如果你想获取更详细的关于word embedding的信息,可以看我的另外一个文章word2vec的笔记和实现。

Position-wise Feed-Forward network是什么?

这就是一个全连接网络,包含两个线性变换和一个非线性函数(实际上就是ReLU)。公式如下:

db133b323b7b0b8b4e2eb4801e7a91aa.png

这个线性变换在不同的位置都表现地一样,并且在不同的层之间使用不同的参数。

论文提到,这个公式还可以用两个核大小为1的一维卷积来解释,卷积的输入输出都是4b40851d-c02d-eb11-8da9-e4434bdf6706.svg,中间层的维度是5640851d-c02d-eb11-8da9-e4434bdf6706.svg

实现如下:

import torchimport torch.nn as nnclass PositionalWiseFeedForward(nn.Module):    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):        super(PositionalWiseFeedForward, self).__init__()        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)        self.w2 = nn.Conv1d(model_dim, ffn_dim, 1)        self.dropout = nn.Dropout(dropout)        self.layer_norm = nn.LayerNorm(model_dim)    def forward(self, x):        output = x.transpose(1, 2)        output = self.w2(F.relu(self.w1(output)))        output = self.dropout(output.transpose(1, 2))        # add residual and norm layer        output = self.layer_norm(x + output)        return output

Transformer的实现

至此,所有的细节都已经解释完了。现在来完成我们Transformer模型的代码。

首先,我们需要实现6层的encoder和decoder。

encoder代码实现如下:

import torchimport torch.nn as nnclass EncoderLayer(nn.Module):  """Encoder的一层。"""    def __init__(self, model_dim=512, num_heads=8, ffn_dim=2018, dropout=0.0):        super(EncoderLayer, self).__init__()        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)    def forward(self, inputs, attn_mask=None):        # self attention        context, attention = self.attention(inputs, inputs, inputs, padding_mask)        # feed forward network        output = self.feed_forward(context)        return output, attentionclass Encoder(nn.Module):  """多层EncoderLayer组成Encoder。"""    def __init__(self,               vocab_size,               max_seq_len,               num_layers=6,               model_dim=512,               num_heads=8,               ffn_dim=2048,               dropout=0.0):        super(Encoder, self).__init__()        self.encoder_layers = nn.ModuleList(          [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in           range(num_layers)])        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)    def forward(self, inputs, inputs_len):        output = self.seq_embedding(inputs)        output += self.pos_embedding(inputs_len)        self_attention_mask = padding_mask(inputs, inputs)        attentions = []        for encoder in self.encoder_layers:            output, attention = encoder(output, self_attention_mask)            attentions.append(attention)        return output, attentions

通过文章前面的分析,代码不需要更多解释了。同样的,我们的decoder代码如下:

import torchimport torch.nn as nnclass DecoderLayer(nn.Module):    def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):        super(DecoderLayer, self).__init__()        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)    def forward(self,              dec_inputs,              enc_outputs,              self_attn_mask=None,              context_attn_mask=None):        # self attention, all inputs are decoder inputs        dec_output, self_attention = self.attention(          dec_inputs, dec_inputs, dec_inputs, self_attn_mask)        # context attention        # query is decoder's outputs, key and value are encoder's inputs        dec_output, context_attention = self.attention(          enc_outputs, enc_outputs, dec_output, context_attn_mask)        # decoder's output, or context        dec_output = self.feed_forward(dec_output)        return dec_output, self_attention, context_attentionclass Decoder(nn.Module):    def __init__(self,               vocab_size,               max_seq_len,               num_layers=6,               model_dim=512,               num_heads=8,               ffn_dim=2048,               dropout=0.0):        super(Decoder, self).__init__()        self.num_layers = num_layers        self.decoder_layers = nn.ModuleList(          [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in           range(num_layers)])        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)    def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):        output = self.seq_embedding(inputs)        output += self.pos_embedding(inputs_len)        self_attention_padding_mask = padding_mask(inputs, inputs)        seq_mask = sequence_mask(inputs)        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)        self_attentions = []        context_attentions = []        for decoder in self.decoder_layers:            output, self_attn, context_attn = decoder(            output, enc_output, self_attn_mask, context_attn_mask)            self_attentions.append(self_attn)            context_attentions.append(context_attn)        return output, self_attentions, context_attentions

最后,我们把encoder和decoder组成Transformer模型!

代码如下:

import torchimport torch.nn as nnclass Transformer(nn.Module):    def __init__(self,               src_vocab_size,               src_max_len,               tgt_vocab_size,               tgt_max_len,               num_layers=6,               model_dim=512,               num_heads=8,               ffn_dim=2048,               dropout=0.2):        super(Transformer, self).__init__()        self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim,                               num_heads, ffn_dim, dropout)        self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim,                               num_heads, ffn_dim, dropout)        self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)        self.softmax = nn.Softmax(dim=2)    def forward(self, src_seq, src_len, tgt_seq, tgt_len):        context_attn_mask = padding_mask(tgt_seq, src_seq)        output, enc_self_attn = self.encoder(src_seq, src_len)        output, dec_self_attn, ctx_attn = self.decoder(          tgt_seq, tgt_len, output, context_attn_mask)        output = self.linear(output)        output = self.softmax(output)        return output, enc_self_attn, dec_self_attn, ctx_attn

至此,Transformer模型已经实现了!

参考文章

1.为什么ResNet和DenseNet可以这么深?一文详解残差块为何有助于解决梯度弥散问题
2.GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION
3.The Annotated Transformer
4.Building the Mighty Transformer for Sequence Tagging in PyTorch : Part I
5.Building the Mighty Transformer for Sequence Tagging in PyTorch : Part II
6.Attention?Attention!

参考代码

1.jadore801120/attention-is-all-you-need-pytorch2.JayParks/transformer


作者:luozhouyang
链接: https://juejin.im/post/5b9f1af0e51d450e425eb32d
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

48d9d930cb508c9f7079c6c88d3aafbf.png

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

encoder decoder模型_Transformer 模型的 PyTorch 实现 的相关文章

随机推荐

  • 网页 序列号 逆向linux,逆向序列号生成算法(一)

    对逆向工程一直很感兴趣 工作之余自己也研究一下 好久没有练手了 OllyDBG的使用都感觉生疏了 晚上抽空先去补了补OllyDBG的使用方法 然后看到一个叫做CycleCrackMe 的序列号保护练手程序 如图1 刚好是OllyDBG入门文
  • GGally与pairs相关关系图_史上最全(二)

    作者 李誉辉 四川大学在读研究生 接上一篇 GGally与pairs相关关系图 史上最全 一 2 4 wrap 封装 其它需要指定到geom xxx 中的参数 可以通过wrap 传递给lower upper 或diag 语法 1wrap f
  • 【Pytorch with fastai】第 2 章:从模型到生产

    大家好 我是Sonhhxg 柒 希望你看完之后 能对你有所帮助 不足请指正 共同学习交流 个人主页 Sonhhxg 柒的博客 CSDN博客 欢迎各位 点赞 收藏 留言 系列专栏 机器学习 ML 自然语言处理 NLP 深度学习 DL fore
  • MySQL常见的数据类型

    MySQL的常见数据类型 数据类型是什么 数据类型是列 存储过程的参数 表达式和局部变量的数据特征 它决定了数据的存储格式 代表了不同的信息类型 有一些数据是要存储为数字的 数字当中有些是要存储为整数 小数 日期型等 MySQL常见的数据类
  • 【第十四届蓝桥杯三月真题刷题训练——第 24 天 (3.27)& 旋转 & 附近最小 & 扫地机器人 & 窗口】

    第一题 旋转 import java util Scanner public class Main static int N 300 static int a new int N N static int b new int N N pub
  • 数模培训第二周——图论模型

    图论中最短路算法与程序实现 图论中的最短路问题 包括无向图和有向图 是一个基本且常见的问题 主要的算法有Dijkstra算法和Floyd算法 Floyd算法 简介 Floyd Warshall算法 英语 Floyd Warshall alg
  • 机器学习入门-数值特征-时间特征处理

    我们可以将一连串的时间特征进行拆分 比如 2015 03 08 10 30 00 360000 00 00 我们可以将其转换为日期类型 然后从里面提取年 月 日等时间信息 对于一些hour month等信息 我们也可以使用pd cut将ho
  • 单个IMU实现精确的轨迹重构

    惯性传感器 IMU 被广泛用于导航 运动状态研究 人体运动和步态分析等领域 然而 由于IMU的固有误差和测量误差 尤其是漂移误差 很少有人尝试基于IMU实现精确的轨迹重建 尤其是使用单个IMU 尽管如此 与视觉 红外线和超声波定位技术相比
  • C#——反射和特性

    元数据 程序是用来处理数据的 文本和特性都是数据 而我们程序本身这些也是数据 有关程序及其类型的数据被称为元数据 他们保存在程序的程序集中 反射 程序在运行时 可以查看其它程序集或其本身的元数据 一个运行的程序查看本身的元数据或者其他程序集
  • Java中基本类型自动转换与强制转换

    类型转换 Java 语言是一种强类型的语言 强类型的语言有以下几个要求 变量或常量必须有类型 要求声明变量或常量时必须声明类型 而且只能在声明以后才能使用 赋值时类型必须一致 值的类型必须和变量或常量的类型完全一致 运算时类型必须一致 参与
  • Python学习(3):批量修改文件名(以excel文件为例)

    coding utf 8 import os dir input 请输入文件路径 for root dirs files in os walk dir for i in range len files filename files i ne
  • python django 学习第3天 文件长传

    在根目录下新建media目录 在settings py 加入代码 为上传文件操作做准备 MEDIA ROOT os path join BASE DIR media MEDIA URL media 做一个新闻调查页面 在views 中加入
  • bash 括号(小括号,双小括号,中括号,双中括号,大括号)

    小括号 和大括号 主要包括一下几种 var cmd 和 exp var string var string var string var string var pattern var pattern var pattern var patt
  • 计算机网络运输层运输层协议概述

    运输层协议概述 进程之间的通信 下图说明运输层的作用 可以看出网络层为主机之间提供逻辑通信 而运输层为应用进程之间提供端到端的逻辑通信 根据应用程序的不同需求 运输层有两种不同的运输协议 1 面向连接的TCP 2 无连接的UDP 运输层的两
  • Vue-cli3更改项目logo图标

    1 图标切成对应大小 2 图标名称后缀与vue原有图标logo名称 后缀一致 favicon ico 并替换 3 vue项目根目录下 新建 vue config js 添加下列代码 module exports pwa iconPaths
  • 网络爬虫 - 1 网络爬虫基本概念和相关工具

    网络爬虫基本概念和相关工具 1 基本概念 1 什么是网络爬虫 web crawler 以前经常称之为网络蜘蛛 spider 是按照一定的规则自动浏览万维网并获取信息的机器人程序 或脚本 曾经被广泛的应用于互联网搜索引擎 使用过互联网和浏览器
  • Linux环境下的VScode使用教程

    前言 1 对于学习本文需要先有自行安装好VMware 对VMware有简单的了解 2 对于绝大多数使用Linux的人而言 经常在Windows环境下使用source insight进行编译程序 然后利用FileZilla将Windows的文
  • Vue出现弹出层时,禁止底部页面跟随滑动

    背景 最近在写一个vue项目 当出现弹出层时 发现底部页面跟随滚动 但是产品不想要这种效果 于是找各种资料 发现很多说法 但是试了试 发现有的根本就不行 比如说有人提出用vue中提供的 touchmove prevent方法来解决 但是我试
  • 算法设计与分析——算法设计工具Standard Template Library即STL(C++模板库)概述

    算法设计工具 STL 前言 STL是一个功能强大的基于模板的容器库 通过直接使用这些现成的标准化组件可以大大提高算法设计的效率和可靠性 一 STL构成 Container 容器 Algorithm 算法 Iterator 迭代器 二 什么是
  • encoder decoder模型_Transformer 模型的 PyTorch 实现

    Google 2017年的论文 Attention is all you need 阐释了什么叫做大道至简 该论文提出了Transformer模型 完全基于Attention mechanism 抛弃了传统的RNN和CNN 我们根据论文的结