Seq2Seq模型学习(pytorch)

2023-11-09

在看pytorch的官方英文例子,做些笔记,如有纰漏请指正,原文:https://pytorch.org/tutorials/beginner/chatbot_tutorial.html

数据准备

  • 首先是单词编码。seq2seq的单词编码的方式可以参见seq2seq翻译教程,这篇文章采用了对单词建立索引,每个单词用一个index表示的方法。
  • 正常情况的句子矩阵是(句子长度,句子数量)的矩阵。为了批量训练,需要对其进行转置,每次读取各个句子一个时间步长的单词,如下图batches
  • 输入的内容包括以下几个部分:
  1. input_variable:转化为tensor的句子编码矩阵,形式为:
    tensor([[  34,   16,  147, 4556,   50],
            [   4, 5553,  379, 2528,    6],
            [  76, 3810,   45,    6,    2],
            [ 331,    4,  380,    2,    0],
            [ 117,    4, 1967,    0,    0],
            [  47,    4,    6,    0,    0],
            [   4,    2,    2,    0,    0],
            [   2,    0,    0,    0,    0]])
  2. lengths,每个句子的长度,可以理解为上面矩阵的每个纵向的没用0填充之前的长度,形式为
    tensor([8, 7, 7, 4, 3])
  3. target_variable:目标结果,即希望得到的输出
    tensor([[  50,   16,   25,   27, 1582],
            [ 331,  997,  197,  153,    4],
            [ 117,    4,  117,  112,    2],
            [  47,   50,   24, 3484,    0],
            [   6,  368,    4,  587,    0],
            [   2,    6,    2,    4,    0],
            [   0,    2,    0,    2,    0]])
  4. mask:即pad_token
    tensor([[1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 0],
            [1, 1, 1, 1, 0],
            [1, 1, 1, 1, 0],
            [0, 1, 0, 1, 0]], dtype=torch.uint8)

Seq2Seq模型原理

Seq2Seq意如其名,是使用固定大小的模型实现一个变长序列到变长序列的模型。输入一个变长的句子,返回一个变长的句子,故可以用来执行翻译任务和起标题任务,人机对话等。是rnn的衍生模型。

具体模型使用两个rnn模型来实现此功能。一个RNN充当编码器,他将输入的变长序列编码成固定长度的上下文向量,理论上,这个上下文向量,也就是RNN最终的隐藏层包含了输入的语义信息,(类似于CNN卷积后的中间层包含了图片的边缘、色彩等信息一样)。第二个RNN充当解码器,他接受输入和上下文向量,并返回对下一个单词的猜测和下一次迭代中使用的隐藏状态。model

编码器

编码器每次迭代接受输入语句中的一个标记(一个单词),并产生一个输出向量和一个隐藏状态向量。这个中间隐藏状态向量接着被送向下一个迭代中,而输出向量被记录。编码器将他在句子中每个点处看到的上下文转化成高维空间中的一组点,以便后面解码器用它来产生有意义的输出。

这里用的编码器的核心是有Cho等人在2014年发表的多层门控循环单元。这里使用GRU的双向变体,以为着这里有两个独立的RNN,一个以正常顺序输入序列,另一个以相反顺序输入序列。每一个迭代中输出将会求和,使用双向GRU将提供之前和之后的上下文的优势。

双向RNN原理如图:rnn_bidir

  • embedding层将把单词编码到任意大小的特征空间,对于我们的模型,该层将映射每个单词到大小为hidden_size的特征空间,当训练时,意义相同的单词在特征空间里的值也应当相近。
  • 填充序列时用了nn.utils.rnn.pack_padded_sequence,解包时需要用nn.utils.rnn.pad_packed_sequence

 模型流程:

  1. 将单词索引转化为embeddings
  2. 为RNN模块打包填充序列
  3. 正向通过GRU
  4. 解填充
  5. 对双向GRU输出进行求和
  6. 返回输出和最终的隐藏层状态

输入

  1. input_seq:一批输入的句子,矩阵大小为(最大句子长度,句子数量)
  2. input_lengths:每个句子对应的长度矩阵
  3. hidden:隐藏层的状态,大小为(n层,维数,数量,hidden_size)

输出

  1. GRU最后一个隐藏层的输入特征(双向输出之和),大小为(max_length,batch_size,hidden_size)
  2. GRU的隐藏层更新状态,大小为(n_layers x num_directions,batch_size,hidden_size)

 

解码器

解码器RNN以token-by-token的形式生成回应。他使用编码器生成的上下文向量,和内部的隐藏层状态来生成序列中的下一个单词。他持续生成单词知道他输出了 EOS_token,即一个句子的末尾。一个常见的Seq2Seq问题是如我我们仅仅依赖于上下文向量来编码整个输入序列的含义,很可能会丢失信息,特别是在处理长输入序列的时候,这将极大限制解码器的能力。

为了克服这种现象,Bahdanau等人提出了注意机制来允许解码器来关注输入句子的一部分,而不是使用整个固定的上下文。

在更高的级别里,计算注意力是靠着解码器当前隐藏层的状态和编码器的输入。输出注意力的权重与输入句子有着相同形式的矩阵大小,以便允许将其和编码器的输入相乘,给出一个加权和表示要注意的编码器输出部分,下图很好的描述了这一点:

attn2

之后有一种改进的方法称为global attention,关键的变化是通过global attention,我们可以关注编码器的全部隐藏层状态,而不是之前提出的当前迭代中的局部隐藏层状态。另一个变化是计算注意力权重只使用当前迭代的隐藏层状态,而上一种还需要知道上一个迭代中解码器的状态,而且给出了各种方法来计算attention energies,被称为score functions

scores

ht代表当前目标解码器状态,hs代表所有解码器的状态

global attention机制可以通过下图总结,在pytorch中注意层实现为一个名为Attn的独立nn.Module,输出的shape为(batch_size,1,max_length).

global_attn

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Seq2Seq模型学习(pytorch) 的相关文章

随机推荐

  • FCN的代码解读

    目录 模型初始化 VGG初始化 FCN初始化 图片的预处理 图片处理 图片编码 计算相关参数 模型训练 一个小问题 完整代码 参考 最近浅研究了一下关于图像领域的图像分割的相关知识 发现水还是挺深的 因为FCN差不多也是领域的开山鼻祖 所以
  • Android无线网络调试手机

    adb tcpip 5555 adb下载地址 http download clockworkmod com test UniversalAdbDriverSetup msi 3 在设备中下载超级终端 是andriod软件 设置端口 su s
  • JVM笔记-黑马-1

    文章目录 视频资源地址 笔记资源地址 我的笔记 1 什么是JVM 2 学习jvm的作用 3 常见的jvm 4 学习路线 5 内存结构 程序计数器 作用 6 内存结构 程序计数器 特点 7 内存结构 虚拟机栈 8 内存结构 虚拟机栈的演示 9
  • C++文件操作和文件流

    C 文件操作和文件流 1文件的概念 2 文件流的分类 2 打开文件 2 1 通过类对象调用 open 函数打开一个文件 2 2 通过类对象构造函数打开文件 3 关闭文件 4 读写文件 4 1 文本文件的读写 4 2 二进制文件的读写 1文件
  • ESP8266之AT指令

    一 8266作为client 1 AT 功能 测试8266能否工作 2 AT CWMODE 3 功能 设置工作模式 1 station模式 2 ap模式 3 ap station复位保存当前值 3 AT RST 功能 复位 4 AT CWL
  • Android利用AIDL实现apk之间跨进程通信

    AIDL 最广泛与最简单的应用是与四大组件之一 Serivce 的配合使用了 我们都知道 启动一个 Serivce 有两种方式 1 通过 startService 的方式 2 通过 bindService 的方式 通过 binService
  • 图像处理之目标检测入门总结

    点击上方 小白学视觉 选择加 星标 或 置顶 重磅干货 第一时间送达 本文转自 机器学习算法那些事 本文首先介绍目标检测的任务 然后介绍主流的目标检测算法或框架 重点为Faster R CNN SSD YOLO三个检测框架 本文内容主要整理
  • linux安装自动化部署工具jenkins

    创建工程目录 mkdir home software jenkins 创建工作空间 mkdir home workspaces jenkins 进入工程目录 cd home software jenkins 下载Jenkins rpm安装包
  • 伪代码及其实例讲解

    伪代码 Pseudocode 是一种算法描述语言 使用伪代码的目的是为了使被描述的算法可以容易地以任何一种编程语言 Pascal C Java etc 实现 因此 伪代码必须结构清晰 代码简单 可读性好 并且类似自然语言 介于自然语言与编程
  • 基于svg.js实现对图形的拖拽、选择和编辑操作

    本文主要记录如何使用 svg js 实现对图形的拖拽 选择 图像渲染及各类形状的绘制操作 1 关于SVG SVG 是可缩放的矢量图形 使用XML格式定义图像 可以生成对应的DOM节点 便于对单个图形进行交互操作 比CANVAS更加灵活一点
  • 分享一下Python数据分析常用的8款工具

    Python是数据处理常用工具 可以处理数量级从几K至几T不等的数据 具有较高的开发效率和可维护性 还具有较强的通用性和跨平台性 Python可用于数据分析 但其单纯依赖Python本身自带的库进行数据分析还是具有一定的局限性的 需要安装第
  • 移动端开发技术小结(前端)

    移动端开发技术小结 前端 移动端处理webkit内核即可 浏览器的私有前缀只需要考虑添加 webkit 布局视口 视觉视口 理想视口 将布局宽度改为视觉视口 图片 2倍图 3倍图 背景缩放 background size 背景图片宽度 背景
  • windows 系统安装sonarqube

    SonarQube是一种自动代码审查工具 用于检测代码中的错误 漏洞和代码异味 它可以与您现有的工作流程集成 以便在项目分支和拉取请求之间进行连续的代码检查 官方网站 https www sonarqube org 1 使用前提条件 运行S
  • FTP-读取指定目录下的文件,上传到FTP服务器,一键复制黏贴,就是这么丝滑~

    背景 需要定时将服务器下的日志文件上传到指定FTP服务器的目录下 并通知第三方平台文件已上传 FTP服务器模拟工具 application yml配置 spring logfilepath home jboss server default
  • Cesium Terrain Builder (CTB) 简单使用_地形切片

    Cesium Terrain Builder CTB 简单使用 地形切片 目录 Cesium Terrain Builder CTB 简单使用 地形切片 官网地址 win r cmd 打开命令提示符工具运行 Create a GDAL Vi
  • windows计算机无法打开,为什么我的电脑无法运行Win11?原因可能是这个

    原标题 为什么我的电脑无法运行Win11 原因可能是这个 为什么我的电脑无法运行Win11 原因可能是这个 微软已经在日前正式发布Windows 11操作系统 虽然新系统的更新升级与发布并非同步进行 甚至现在连官方都未公开预览版 但由于此前
  • eccms静态页面出现出现基础链接已关闭,无法链接到远程服务器错误的解决办法

    出现 基础链接已关闭 无法链接到远程服务器 错误 一 系统组件错误 如果属于系统Socket组件错误 可以重启socket组件 netsh winsock reset 进行解决 二 实际发生的原因 由于实际情况所需 禁止服务器访问外网 解决
  • 【C++笔记】《C++编程思想-卷一》笔记

    C 编程思想 笔记 Volume 1 第一章 对象导言 OOP ObjectOriented Programming 面对对象编程 UML Unified Model Language 统一建模语言 堆 stack 和栈 heap 预备知识
  • dgl 操作

    dgl图的基本操作 dgl简单使用 udf函数怎么写 通过edges 打得到两端的列表 将所有nodes换成edges 截错图了 通过nodes则不是很好使 返回很多的tensor原因是这个函数运行好几遍
  • Seq2Seq模型学习(pytorch)

    在看pytorch的官方英文例子 做些笔记 如有纰漏请指正 原文 https pytorch org tutorials beginner chatbot tutorial html 数据准备 首先是单词编码 seq2seq的单词编码的方式