源码剖析transformer、self-attention

2023-11-09

原文链接

首先给大家引入一个github博客,这份代码是我在看了4份transformer的源码后选出来的,这位作者的写法非常易懂,代码质量比较高。GitHub - Separius/BERT-keras: Keras implementation of BERT with pre-trained weights

这篇文章主要跟大家分享四个点:多头机制(multi-head)、LN和GELU、位置编码。

在这再给大家安利几篇博客,便于大家更具体的理解自注意力的内在原理。

【NLP】Transformer模型原理详解 - 知乎

Attention机制详解(二)——Self-Attention与Transformer - 知乎(精华)

自然语言处理中的自注意力机制(Self-attention Mechanism) - robert_ai - 博客园

transformer是self-attention的落地或者说扩展,多头机制把自注意力机制发挥得淋漓尽致。transformer最亮眼的地方就是完全抛弃了常规的链式RNN结构(包括LSTM等其他变体),即:并行计算能力特别弱的计算方法。它应该会是早期NLP训练技术跟当期技术的一个里程碑,毕竟人家BERT是吧,刷新了不造几个记录,虽然XLNET又刷新了BERT的记录,但是这也正证实了这种设计理念的优秀!优秀啊。。。[斜眼笑]。。。

言归正传!

一、自注意力机制(self-attention)和多头机制(mutil-head)

常规的语言生成模型长这样

下一个字的生成,依靠且只依靠上一个字的输出状态和当前输入的输入状态,也就是说,预测值在某一程度上说只跟上一个字关系大一些,而自注意力模型,差不多长下面这个样子。

这个图的意思是,每一个字的生成,会跟所有的字(encode的时候)都有关系,这就是所谓的“注意力”机制。整个文字的生成过程中,每一个字都可能会跟所有的字做加权,为什么是“可能”呢,因为有mask嘛,随机给屏蔽掉一些词,屏蔽掉的那就没办法顾及了。这样的好处有两个:

一是能“照顾”所有的词,也就是我们理解的“语境”,

比如,句子1:“优秀!这就很优秀了!我做梦都没想到他这么175的个子能在中场投篮投进了!”;

和句子2:“优秀!这就很优秀了!他这185的个子在篮下那么好的位置还是没进球!”。

同样位置的一个词“优秀!”,一模一样的字,它的意思的完全相反的(中华文化博大精深),自注意力机制就需要在即使后面说了一堆废话的基础上,还是能学出这个词是褒义还是贬义。换句话说,它在判断优秀是褒义还是贬义时,甚至需要看到最后几个关键语气的词,才能做出判断,而这个功能正是RNN系列模型做不到的!数学意义上可以叫做“贡献度”。

二是,这样所以词就能并行计算了,至少这一步是可以并行计算了!

OK,自注意力就是大致这样个流程,多头又是什么鬼!很简单,经过嵌入层后,每个词有多个维度(代码嵌入为768列),把这些维度均分成n_head(12)份,每一份都去做这么一件事,就是多头机制。简而言之,就是自注意力的模式,复制了几次,这个“几次”就是“多头”,12次就是12头。。。只不过,不是做复制,二是做拆分,均分成12次来进行注意力的计算。

原理懂了哈,咱们看下人家是怎么实现的。

(1)funcs的multihead_attention函数

_q, _k, _v = x[:, :, :n_state], x[:, :, n_state:2 * n_state], x[:, :, -n_state:]

# [B, n_head, max_len, 768 // n_head]
q = split_heads(_q, n_head)  # B, H, L, C//H
# [B, n_head, 768 // n_head, max_len]
k = split_heads(_k, n_head, k=True)  # B, H, C//H, L
# [B, n_head, max_len, 768 // n_head]
v = split_heads(_v, n_head)

x是embedding后的输入,经过3*768个1x1卷积,变成[B, max_len, 3*768]的特征矩阵,q、k和v各占1/3,即:[B, max_len, 768]

(2)funcs的split_heads函数

# [-1, max_len, 768]
x_shape = shape_list(x)
# 768
m = x_shape[-1]
# [-1, max_len, n_head,  768 // n_head]
new_x_shape = x_shape[:-1] + [n, m // n]
# [B, max_len, n_head,  768 // n_head]

new_x = K.reshape(x, new_x_shape)
# return [B, n_head, max_len, 768 // n_head] False
# return [B, n_head, 768 // n_head, max_len] True
return K.permute_dimensions(new_x, [0, 2, 3, 1] if k else [0, 2, 1, 3])

这里对q和v拆分成[B, 12,  max_len, 768 // 12] = [B, 12, max_len, 64]的长度,k拆成[B, 12, 64, max_len]

也就是说,每个词拆成了12等分,每一等分特征由之前的768变成了64。整个分成了B个batch,12的小batch,每个小batch的句子是max_len个长度的词组成,每个词有64的特征。

(3)这里咱们细讲一下这个permute,permute是自注意力计算逻辑最核心最抽象的地方之一。先上一张图

permute跟numpy的transpose是异曲同工的,都可以理解为转置。只是我们在对张量做转置的时候,用常规的二维矩阵的思维去理解比较难。但是!解释还是得用二维矩阵的方法去解释!

假设我们的q是一个2X6的矩阵,每一个字被嵌入成6个特征。k跟q一毛一样,但是我把k做一下转置,变成6X2的矩阵。这样的形式,我们用矩阵相乘的方法,把2x6的矩阵跟6x2的矩阵相乘,是不是变成了2x2,这个2x2是不是就可以理解为2个字对自己的排列组合。换句话说,所有的字与字之间的加权值得出来了。如图:优跟优的加权值是50(1x1+2x2+3x3+4x4+5x5),优跟秀的加权值是70(1x2+2x3+3x4+4x5+5x6),秀跟优的加权值是70(1x2+2x3+3x4+4x5+5x6),秀跟秀的加权值是90(2x2+3x3+4x4+5x5+6x6)。

(4)funcs的scaled_dot_product_attention_tf函数

先看这个函数:w = K.batch_dot(q, k)

这一步就是上一步所说的矩阵乘法。算法中,对每一组嵌入数据

[B, max_len, 768*3]通过均分,拆成3份[B, max_len, 768],分别作为q、k和v的前身;通过reshape,都变成[B, max_len, 12, 64];在通过permute变成q=[B, 12, max_len, 64]、k=[B, 12, 64, max_len]和v=[B, 12, max_len, 64]。

其中,q和k进行矩阵乘法,把max_len个字彼此求加权值(q=[B, 12, max_len, 64] * k=[B, 12, max_len, 64]),变成w=[B, 12, max_len, max_len]。这里的max_len放在“优秀”一组词里面就是2,即:当前句子的长度。

接着,继续对w和v做矩阵乘法,再求一次加权值(至于这一步有什么实际的原理,我不太确定),这一次矩阵乘法,在数学上把维度还原到[B, 12, max_len, 64],以方便后期继续还原成输入的shape。我揣测,这一步w*v的意义跟全连接层的意义是接近的。只不过在设计的出发点和可解释性上,要稍强于全连接层的意义。

所以,我用我自己组织的方式,给大家解释一下整个自注意力的过程

a、首先是split_heads,就是切分出qkv

b、矩阵乘法求权值,即:scaled_dot_product_attention_tf

c、merge_heads

经过上述两步,O(output)=[B, 12, max_len, 64],在这个函数里面,还原成[B, max_len, 64]

到此,多头+自注意力机制暂告一段落!

接下来你想在这一步重复多少次就重复多少次,因为输入输出都是一个shape。

二、LN(Layer Normalization)

四图秒懂BN、LN和IN_罗小丰同学的博客-CSDN博客_ln和bn

往这看!!!

三、GELU

一个公式说明一切(这个是近似函数,不是本征函数)

0.5 * x * (1 + K.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * K.pow(x, 3))))

然后看看图像

当我看到这个图像的时候,我第一反应想起了Swish激活函数

他俩真的是异曲同工,几乎长得都一毛一样。从论文来看,GELU函数的收敛性会比RELU和ELU都有略好。

四、位置编码

位置信息对于理解一句话来说,也是很重要的。比如:

a、难受!滑板坏了,水还洒身上了!

b、滑板坏了,难受!水还洒身上了!

对于句子b,其实表达出来的意思,难受的重心是滑板,水是附加的负面影响。对于a,可能整体对心情造成的负面影响是差不多的。这就是词在不同位置可能对语境带了影响的一种情况。

常用的位置编码一般无外乎两种:一种是词嵌入,相当于先加一步全连接层,并且该层参数可学;另一种是自己设计位置编码方法。比如我印象中bert是用正余弦函数做编码的,以后看到再跟大家分享;或者,做一个递进的简单累加也不是不行哇,哈哈。

这里的Transformer阶段的位置编码只是使用了简单的词嵌入的方式,你也可以理解其实就是全连接层的一种应用方式。

结语:transformer以及后来的bert模型最核心的地方就是自注意力机制,大家能把自注意力的实现原理看懂,其核心思想也就一通百通了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

源码剖析transformer、self-attention 的相关文章

随机推荐

  • IP地址与mac地址是什么?dhcp与arp又是什么?

    计算机网络中 数据的通信就类比写信 两个人写信 需要知道家庭住址以及收件人信息 1 IP地址与mac地址 IP地址就是家庭住址 mac地址就是收件人 例如 一个网卡 硬件设备 就是一个公寓 一个网卡有一个mac地址 出厂时已写入 全球唯一地
  • react+umi3配置代理问题

    前端配置项目代理 一般是为了解决浏览器跨域策略 在umi中有非常方便的方式可以供我们快速配置代理 在我的项目中配置代理也遇到了一个坑点 特此记录一下 环境 react 17 0 0 umi 3 5 0 开始 umi项目中 在项目根目录下创建
  • sonarqube汉化

    参考文档SonarQube基础 中文设定设定方法 知行合一 止于至善 CSDN博客 sonarqube设置中文 用方法一解决
  • 递归实现逆序输出整数

    在这里插入代码片 本题目要求读入1个正整数n 然后编写递归函数reverse int n 实现将该正整数逆序输出 输入格式 输入在一行中给出1个正整数n 输出格式 对每一组输入 在一行中输出n的逆序数 输入样例 12345 样例 54321
  • Java中的的类和对象

    类的概念 类是对生活中具有相同属性和行为的事物的抽象 它是一个大概的范围 类包含属性和行为 属性和行为在程序中也叫做成员变量和成员方法 对象的概念 是能够看得到的具备行为和属性的真实存在的实体 类和对象的关系 类是对象的抽象的范围表达 对象
  • HTML5中把一首古诗变大缩小和变颜色并用数据储存起来

    效果图如下 代码如下
  • 可中断睡眠 sleep

    可中断睡眠 可中断睡眠的执行情况 进程开始时处于可中断睡眠状态 那么如果进程接收到信号后 进程将被唤醒而不在阻塞 当执行完信号处理函数后 就不再睡眠了 直接向下执行代码 sleep 函数 unsigned int sleep unsigne
  • bugku各种绕过

    题目要求uname passwd 但是他们的SHA1值要相同 且id值为margin 利用PHP的sha1漏洞 当参数为数组时返回false 判断成立
  • 嵌入式入门基础知识有哪些?

    嵌入式系统是指在特定应用领域内为满足特定要求而设计的计算机系统 通常被嵌入到设备中 具有实时性 可靠性 低功耗等特点 嵌入式系统应用广泛 例如 智能家居 智能手表 汽车控制系统 医疗设备等 在本篇博客中 我们将讨论嵌入式入门基础知识 包括嵌
  • 狂神说Mybatis笔记(全网最全)

    Mybatis 环境说明 jdk 8 MySQL 5 7 19 maven 3 6 0 IDEA 学习前需要掌握 JDBC MySQL Java 基础 Maven Junit 1 Mybatis简介 1 1 什么是MyBatis MyBat
  • 认识区块链,认知区块链— —数据上链

    上周末参加一次长沙本地胡子互联网俱乐部举办的区块链分享会 颇受启发 同时感谢俱乐部提供的这个交流平台 祝好 好吧 还是先把前些天对区块链的一点理解简单整理下 再回顾下上周末的参会纪要比较好 下篇给大家分享出来 个人区块链思考第一篇 认识区块
  • Yolov3计算准确率、误报率、漏检率等

    思想很简单 将标注的yolo数据转下格式 转为 类别 xmin ymin xmax ymax 转换valid后的信息 两个信息进行对比 完事 具体的 在终端执行 darknet detector valid cfg voc data cfg
  • 【SSM框架】之Spring

    SSM框架笔记 自用 Spring Spring Framework系统架构 Spring程序开发步骤 核心概念 IoC Inversion of Control 控制反转 使用对象时 由主动new产生对象转换为由外部提供对象 此过程中对象
  • 计算机毕业设计看这篇就够了(二)毕设流程

    本篇将为大家介绍计算机专业毕业设计流程 提前了解毕设流程可以让同学们从宏观角度去看毕设要做些什么样的事情 大概知道每个阶段要去做哪些工作 为后续毕设任务的真正开展打下心理预期 也不至于一脸懵 计算机毕设分为以下主流程 选题 确定导师 完成前
  • 【Proteus仿真】【STM32单片机】基于stm32的智能书桌设计

    文章目录 一 功能简介 二 软件设计 三 实验现象 联系作者 一 功能简介 系统运行后 默认为手动模式 当检测有人 可通过K2键开关灯 如果姿势不对 警示灯亮 否则灭 可通过K3和K4键调节桌子高度 按下K1键切换为自动模式 此时有人 且光
  • Sentinel原理与Demo

    Sentinel 是什么 随着微服务的流行 服务和服务之间的稳定性变得越来越重要 Sentinel 以流量为切入点 从流量控制 熔断降级 系统负载保护等多个维度保护服务的稳定性 Sentinel 具有以下特征 丰富的应用场景 Sentine
  • 【FreeRTOS(三)】任务状态

    文章目录 任务状态 任务挂起 vTaskSuspend 取消任务挂起 vTaskResume 挂起任务调度器 vTaskSuspendAll 取消挂起任务调度器 xTaskResumeAll 代码示例 任务挂起 取消任务挂起 代码示例 挂起
  • Docker help帮助文档

    1 查看 docker help 帮助 docker help 2 用法 docker 选项 命令 3 选项 客户端配置文件的配置字符串位置 默认为 root docker D 启用调试模式 H 要连接的主机列表守护进程套接字 l 设置日志
  • Centos7.3安装和配置Mysql5.7

    第一步 获取mysql YUM源 进入mysql官网获取RPM包下载地址 https dev mysql com downloads repo yum 点击 下载 右击 复制链接地址 https dev mysql com get mysq
  • 源码剖析transformer、self-attention

    原文链接 首先给大家引入一个github博客 这份代码是我在看了4份transformer的源码后选出来的 这位作者的写法非常易懂 代码质量比较高 GitHub Separius BERT keras Keras implementatio