pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

2023-11-18

本文将对 Scaled Dot-Product Attention,Multi-head attentionSelf-attentionTransformer等概念做一个简要介绍和区分。最后对通用的 Multi-head attention 进行代码实现和应用

一、概念:

1. Scaled Dot-Product Attention

在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

  • Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
  • Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
  • Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.

2. Multi-head attention

是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。

3. Self-attention

自注意力机制 是在Scaled Dot-Product Attention 以及Multi-head attention的基础上的一种应用场景,就是指 QKV的来源是相同的自己和自己计算attention,类似于经过一个线性层等,输入输出等长。

如果QKV的来源是不同的,不能叫做 self-attention,只能是attention。比如GST中的KV是随机初始化的多个token,而Q是reference encoder得到的梅尔谱的一帧。同理,Q也可以是随机初始化的一个,而KV是来自于输入,这样就可以将某一变长长度为N的输入计算attention得到一个长度为1的向量。

4. Transformer

Transformer 是指 在Scaled Dot-Product Attention 以及Multi-head attention以及Self-attention的基础上的一种通用的模型框架,它包括Positional Encoding,Encoder,Decoder等等。Transformer不等于Self-attention。

二、代码实现

 平时经常会用到Attention操作,接下来对Multi-head Attention 进行代码整理和实现,方便以后可以直接调用接口,其中单头注意力机制作为其中的一种特殊情况。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim] 
        key --- [N, T_k, key_dim]
        mask --- [N, T_k]
    output:
        out --- [N, T_q, num_units]
        scores -- [h, N, T_q, T_k]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):

        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key, mask=None):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        ## score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)

        ## mask
        if mask is not None:
            ## mask:  [N, T_k] --> [h, N, T_q, T_k]
            mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1)
            scores = scores.masked_fill(mask, -np.inf)
        scores = F.softmax(scores, dim=3)

        ## out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out,scores

 三、实际应用:

1. 接口调用:

## 类实例化
attention = MultiHeadAttention(3,4,5,1)

## 输入
qurry = torch.randn(8, 2, 3)
key = torch.randn(8, 6 ,4)
mask = torch.tensor([[False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],])

## 输出
out, scores = attention(qurry, key, mask)
print('out:', out.shape)         ## torch.Size([8, 2, 5])
print('scores:', scores.shape)   ## torch.Size([1, 8, 2, 6])

2. mask的作用:

mask之前的 scores:

mask之后的 scores:

softmax之后的scores:

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用 的相关文章

随机推荐

  • 十八.欧几里得算法

    欧几里得算法 unsigned int Gcd unsigned int M unsigned int N unsigned int Rem while N gt 0 Rem M N M N N Rem return M 此算法用来计算最大
  • Vue实现动画的几种方式

    vue内置组件transition 元素出现和消失都呈现动画
  • For循环结构的使用

    一 四个要素 初始化条件 循环条件 gt 是boolean类型 循环体 迭代条件 二 for循环的结构 for 执行过程 1 2 3 4 2 3 4 2 遍历100以内的偶数 输出所有偶数和 int sum 0 记录所有偶数的和 int c
  • FTP服务器版本信息可被获取(CVE-1999-0614)(建议修改源代码或者配置文件改变缺省banner信息。)

    漏洞扫描报告 1 测试查看 默认端口21 telnet localhost 21 下图所示即为漏洞信息描述的 可获取版本号 2 修改 vsftpd conf 配置文件 etc vsftpd vsftpd conf 找到 ftpd banne
  • linux 怎样停定时任务,linux停用cron定时执行任务的方法

    linux下用cron定时执行任务的方法 名称 crontab 使用权限 所有使用者 使用方式 crontab file u user 用指定的文件替代目前的crontab crontab u user 用标准输入替代目前的crontab
  • C++泛型编程

    C 泛型编程 1 泛型编程 1 1 模板 1 2 函数模板 1 2 1 语法 1 2 2 使用函数模板方式 1 2 3 普通函数和函数模板的区别 1 2 4 普通函数与函数模板的调用规则 1 2 5 模板的局限性 1 3 类模板 1 3 1
  • findBug 错误修改指南

    FindBugs错误修改指南 1 EC UNRELATED TYPES Bug Call to equals comparing different types Pattern id EC UNRELATED TYPES type EC c
  • spark_hadoop集群搭建自动化脚本

    bin bash 脚本使用说明 1 使用脚本前需要弄好服务器的基础环境 2 在hadoop的每个节点需要手动创建如下目录 data hdfs tmp 3 修改下面的配置参数 4 脚本执行完备后需要收到格式化namenode
  • APP移动端自动化测试(八)总览

    官网地址 https github com appium appium blob master docs en writing running appium server args md 项目中你是怎么结合自动化的 apk for andr
  • JavaScript 获取时间日期方法

    Date对象包含日期和时间的相关信息 Date对象没有任何属性 它只具有很多用于设置和获取日期时间的方法 方法 说明 getDate 返回Date对象中月份的天数 gateDay 返回Date对象中的星期几 getHours 返回Date对
  • 机器学习算法案例:泰坦尼克号乘客生存预测

    学习目标 通过案例进一步掌握决策树算法api的具体使用 1 案例背景 泰坦尼克号沉没是历史上最臭名昭着的沉船之一 1912年4月15日 在她的处女航中 泰坦尼克号在与冰山相撞后沉没 在2224名乘客和机组人员中造成1502人死亡 这场耸人听
  • ubuntu编译安装mmcv 1.6.2和mmsegmentation 0.28.0

    环境 ubuntu16 04 cuda10 1 python 3 8 pytorch 1 6 0 cuda10 1 对应的torch版本 lt 1 8 但是1 8和1 7都试了 mmcv没有编译成功 只有1 6成功了 1 编译MMCV 1
  • 行为型模式-策略模式

    package per mjn pattern strategy 抽象策略类 public interface Strategy void show package per mjn pattern strategy 具体策略类 用来封装算法
  • warning: function declared implicitly错误原因

    http blog sina com cn s blog 629f56a70100irbn html line 10 warning function declared implicitly 这是由于没有声明函数原型造成的 在a c中 vo
  • Vue插件

    目录 vue项目目录结构 es6导入导出语法 Vue项目开发规范单页面组件写法 vue项目集成axios vue项目前后端打通 前后端交互之登录功能 props配置项 父组件通过自定义属性与子组件通信 混入 插件 scoped样式 loca
  • c语言socket如何传输图片,socket文件传输功能的实现

    这节我们来完成 socket 文件传输程序 这是一个非常实用的例子 要实现的功能为 client 从 server 下载一个文件并保存到本地 编写这个程序需要注意两个问题 1 文件大小不确定 有可能比缓冲区大很多 调用一次 write se
  • 如何理解面向过程和面向对象?

    一句话理解面向对象 有人说 如果上帝是程序员 他怎么创造世界上的所有动物 理解这个问题就理解了面向对像 面向过程和面向对象区别 面向过程的思路 什么事都自己做 分析解决问题所需的步骤 用函数把这些步骤依次实现 面向对象的思路 什么事都指挥对
  • 可连接点对象及示例(二)

    转载请标明是引用于 http blog csdn net chenyujing1234 例子代码 包括客户端与服务端 http www rayfile com zh cn files de82908f 7309 11e1 9db1 0015
  • 域名怎么解析到服务器上

    今天无事说一说如何把自己的域名解析绑定到自己的服务器上 让访客们可以通过你的域名来访问你的网站 域名解析定义 域名解析是把域名指向网站空间IP 让人们通过注册的域名可以方便地访问到网站的一种服务 IP地址是网络上标识站点的数字地址 为了方便
  • pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

    本文将对 Scaled Dot Product Attention Multi head attention Self attention Transformer等概念做一个简要介绍和区分 最后对通用的 Multi head attenti