Pytorch从0实现Transformer

2023-11-08

摘要

With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.

随着时序预测的不断发展,Transformer类模型凭借强大的优势,在CV、NLP领域逐渐取代传统模型。其中Informer在长时序预测上远超传统的RNN模型,Swin Transformer在图像识别上明显强于传统的CNN模型。深层次掌握Transformer已经成为从事人工智能领域的必然要求。本文将用Pytorch框架,实现Transformer中的位置编码、多头注意力机制、自掩码、因果掩码等功能,从0搭建一个Transformer网络。


一、构造数据

1.1 句子长度

# 关于word embedding,以序列建模为例
# 输入句子有两个,第一个长度为2,第二个长度为4
src_len = torch.tensor([2, 4]).to(torch.int32)
# 目标句子有两个。第一个长度为4, 第二个长度为3
tgt_len = torch.tensor([4, 3]).to(torch.int32)
print(src_len)
print(tgt_len)

输入句子(src_len)有两个,第一个长度为2,第二个长度为4
目标句子(tgt_len)有两个。第一个长度为4, 第二个长度为3
在这里插入图片描述

1.2 生成句子

用随机数生成句子,用0填充空白位置,保持所有句子长度一致

src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)

src_seq为输入的两个句子,tgt_seq为输出的两个句子。
为什么句子是数字?在做中英文翻译时,每个中文或英文对应的也是一个数字,只有这样才便于处理。
在这里插入图片描述

1.3 生成字典

在该字典中,总共有8个字(行),每个字对应8维向量(做了简化了的)。注意在实际应用中,应当有几十万个字,每个字可能有512个维度。

# 构造word embedding
src_embedding_table = nn.Embedding(9, model_dim)
tgt_embedding_table = nn.Embedding(9, model_dim)
# 输入单词的字典
print(src_embedding_table)
# 目标单词的字典
print(tgt_embedding_table)

字典中,需要留一个维度给class token,故是9行。
在这里插入图片描述

1.4 得到向量化的句子

通过字典取出1.2中得到的句子

# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

在这里插入图片描述

该阶段总程序

import torch
# 句子长度
src_len = torch.tensor([2, 4]).to(torch.int32)
tgt_len = torch.tensor([4, 3]).to(torch.int32)
# 构造句子,用0填充空白处
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
# 构造字典
src_embedding_table = nn.Embedding(9, 8)
tgt_embedding_table = nn.Embedding(9, 8)
# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

二、位置编码

位置编码是transformer的一个重点,通过加入transformer位置编码,代替了传统RNN的时序信息,增强了模型的并发度。位置编码的公式如下:(其中pos代表行,i代表列)
2.1

2.1 计算括号内的值

# 得到分子pos的值
pos_mat = torch.arange(4).reshape((-1, 1))
# 得到分母值
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8)
print(pos_mat)
print(i_mat)

在这里插入图片描述

2.2 得到位置编码

# 初始化位置编码矩阵
pe_embedding_table = torch.zeros(4, 8)
# 得到偶数行位置编码
pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat)
# 得到奇数行位置编码
pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat)
pe_embedding = nn.Embedding(4, 8)
# 设置位置编码不可更新参数
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight)

在这里插入图片描述

三、多头注意力

3.1 self mask

有些位置是空白用0填充的,训练时不希望被这些位置所影响,那么就需要用到self mask。self mask的原理是令这些位置的值为无穷小,经过softmax后,这些值会变为0,不会再影响结果。

3.1.1 得到有效位置矩阵

# 得到有效位置矩阵
vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix)

在这里插入图片描述
3.1.2 得到无效位置矩阵

invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)

True代表需要对该位置mask
在这里插入图片描述
3.1.3 得到mask矩阵
用极小数填充需要被mask的位置

# 初始化mask矩阵
score = torch.randn(2, max(src_len), max(src_len))
# 用极小数填充
mask_score = score.masked_fill(mask_encoder_self_attention, -1e9)
print(mask_score)

在这里插入图片描述
算其softmat

mask_score_softmax = F.softmax(mask_score)
print(mask_score_softmax)

可以看到,已经达到预期效果
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch从0实现Transformer 的相关文章

随机推荐

  • Java之String类

    作者简介 zoro 1 目前大二 正在学习Java 数据结构等 作者主页 zoro 1的主页 欢迎大家点赞 收藏 加关注哦 Java之String类 String的构造 String底层 String之间的比较 比较内容 比较地址 字符串查
  • [用python辅助学生中考与高考-3]:中考科技特长生知多少

    目录 前言 1 什么是科技特长生 2 科技特长生政策正在覆盖全国 3 科技特长生招生流程 1 制定章程 2 考生报名 考生自主报名 3 专业加试 考试 4 预录取 5 上报名单 6 名单公示 7 参加考试 8 正式录取 4 如何布局中考科技
  • 毕业设计-基于 MATLAB 的图像边缘检测算法的研究和实现

    目录 前言 课题背景和意义 实现技术思路 一 MATLAB概述 二 图像边缘检测 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求
  • 2021-11-24 qt 串口

    这货跟vb 比 就系更方便了 打完收工
  • 正则表达式 匹配以特定字符串开头 到任意第一个字符中间的空格

    lt p style text indent 2em S 正则表达式 匹配以特定字符串开头 到任意第一个字符中间的空格 lt p p style text indent 2em u4e00 u9fa5 正则表达式 匹配以特定字符串开头 到任
  • 2021年的保研之旅总结

    保研之旅 个人情况介绍 1 学校 末流211 2 专业 信息管理与信息系统 信管算管理学位 保研的时候有的时候会被认为是跨保的 3 绩点 1 36 4 比赛获奖 没有什么拿得出手的获奖 只有一些小奖 全国大学生物联网设计竞赛全国一等奖 美国
  • robotframework-ride安装注意点

    欢迎关注 无量测试之道 公众号 回复 领取资源 Python编程学习资源干货 Python Appium框架APP的UI自动化 Python Selenium框架Web的UI自动化 Python Unittest框架API自动化 资源和代码
  • Server2008R2:由于没有远程桌面授权服务器可以提供许可证,远程会话被中断.的解决方法,求大神们指导

    出现 由于没有远程桌面授权服务器可以提供许可证 远程会话被中断 问题是因为微软默认的远程登录只提供120天的使用期限 解决该问题的具体步骤如下 1 打开运行 在运行中输入注册表命令 regedit 然后回车通过命令打开注册表对话框 2 在注
  • 获取windows中活跃的Com口

    获取windows中活跃的Com口 记录于2021年11月9日 今天对我来说是个很特殊的一天 母胎SOLO二十一周年 无奈 Orz 闲暇之余写下此文章 记录一下我的日常 文章目录 获取windows中活跃的Com口 前言 一 如何寻找活跃C
  • “另一个程序正在使用此文件,进程无法访问”的解决方法

    另一个程序正在使用此文件 进程无法访问 的解决方法 参考文章 1 另一个程序正在使用此文件 进程无法访问 的解决方法 2 https www cnblogs com shiningrise archive 2012 12 02 279812
  • apache 2.4 配置php,Apache2.4 PHP 配置

    Apache2 4服务器 http www apachehaus com cgi bin download plx APACHE24VC14 64位 http www apachehaus com cgi bin download plx
  • Vue创建Demo项目

    Vue创建Demo项目 Vue 发音为 vju 类似 view 是一款用于构建用户界面的 JavaScript 框架 它基于标准 HTML CSS 和 JavaScript 构建 并提供了一套声明式的 组件化的编程模型 帮助你高效地开发用户
  • 魔兽怀旧服联盟服务器不稳定,魔兽世界怀旧服上次被联盟攻击至少三个月前,“单边服”何去何从...

    游戏中我们是朋友 聊天侃地 在这里我们可以无拘无束的发言 不会有任何人阻挠 还有大家最喜欢吐槽的小编 请把口水收集好 随时准备和小编一起吐槽 魔兽世界怀旧服上次被联盟攻击至少三个月前 单边服 何去何从 今天公会一个人表示他被联盟杀了 大家都
  • React性能优化的手段有哪些

    1 使用纯组件 2 使用 React memo 进行组件记忆 React memo 是一个高阶组件 对 于相同的输入 不重复执行 3 如果是类组件 使用 shouldComponentUpdate 这是在重新渲染组件之前触发的其中一个生命周
  • 21. 成语接龙

    小张非常喜欢与朋友们玩成语接龙的游戏 但是作为 文化沙漠 的小张 成语的储备量有些不足 现在他的大脑中存储了m个成语 成语中的四个汉字都用一个1000000以内的正整数来表示 现在小张的同学为了考验他给出了他一个成语做开头和一个成语做结尾
  • HDFS的基础练习--新建目录

    实验 1 在HDFS的 上创建10目录 data01 data10 在浏览器上查看 2 在HDFS data03下递归创建 data05 data06 data07 递归创建 使用命令 hdfs fs mkdir p xx1 xx2 xx3
  • IDEA菜单栏不见了怎么办

    开始时候我的IDEA主菜单不见了 解决方法 打开Idea 按两次shift 并在弹出框内的搜索框里输入 view 然后往下拉 找图里的这个View 点击它 会弹出新的框 然后就 这样主菜单栏就出来了
  • SpringBoot 基础

    1 认识Spring Boot Spring 不同于一般框架 它是一个聚合的框架 通过Spring 框架可以使Java 更为便捷和系统化 Java web 中最为使用的框架为 Spring Framework Spring boot 是 S
  • python中使用apscheduler二步简单完成定时任务设置,用于自动化任务的创建,无人值守后台任务创建

    一 apscheduler的安装 首先需要安装pip 打开CMD输入pip install apscheduler 安装apscheduler模块 安装过程如下图 二 导入apscheduler包 设置参数与需要执行的脚本 coding u
  • Pytorch从0实现Transformer

    文章目录 摘要 一 构造数据 1 1 句子长度 1 2 生成句子 1 3 生成字典 1 4 得到向量化的句子 该阶段总程序 二 位置编码 2 1 计算括号内的值 2 2 得到位置编码 三 多头注意力 3 1 self mask 摘要 Wit