pytorch seq2seq+attention机器翻译注

2023-11-02

准备深入学习一下神经网络的搭建方法的时候,选了机器翻译来试试,正好查了很多资料,发现pytorch里有例子。就结合自己的理解和探究记录一下。原文实现代码:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
其他博主的中文翻译及解释:
https://blog.csdn.net/u014514939/article/details/89410425?utm_medium=distribute.pc_relevant.none-task-blog-title-2&spm=1001.2101.3001.4242
(此处只做大致流程和细节分析,完整代码上述原文就可获取)
机器翻译流程:

喂进神经网络前的数据准备与处理

1.准备语料,也就是句子对,我找到的是英语-法语的语料,下载地址:https://download.pytorch.org/tutorial/data.zip
在这里插入图片描述
2.明确输入进神经网络的是怎样的数据,也就是怎样将数据处理成神经网络能理解的形式。我们这里不讨论one-hot、embedding等概念,只需要明白,将一个句子转化成数字来表示就可以,比如“i love you”,可能对应的向量形式就是[3,63,8],"i hate you"可能是[3,99,8],也就是句子中的每一个词都是由一个能表示他的数字来代指的(这个数字在此时并没有什么高深的意思,仅仅是个数字而已,就跟取名字一样)
在这里插入图片描述
为句子加上首尾标识SOS、EOS后转化为向量:
(此处是我自己实现的代码,格式与原文有点不一样,但是表述的是这个意思)
在这里插入图片描述

Seq2Seq训练细节

每一轮训练时数据的使用

将句子对转化为向量后,就可以将其喂进神经网络中,这个机器翻译模型在开始训练后trainIters函数里。在训练时,每次喂进一对句子向量
在这里插入图片描述
需要注意的是,原文实现的代码里,将此处喂进去的向量全部转化为了列向量喂进去,如果自己实现代码不注意会报错。
在这里插入图片描述

而在encoder_decoder模型训练时每次喂进去的是一个句子中的一个词。
在这里插入图片描述

神经网络中train的细节

在这里插入图片描述

encoder的输入input_tensor[ei], encoder_hidden,在原文中input_tensor[ei] 是一个tensor([1])的单一元素,encoder_hidden在初始化时大小为tensor[1,1,256]
在这里插入图片描述
而在输入后,input_tensor经过embedding层输出形状是tensor([1,256]),之后经过.view(1,1,-1),形状变为tensor([1,1,256]),encoder的输出也是torch.Size([1, 1, 256])
在这里插入图片描述
在每一轮的encoder训练完后得到的输出encoder_output,encoder_hidden,由于是单个单词训练的结果,其实从内容上看是一样的,因此,encoder_outputs[ei]相当于保留了每一步每一个单词的state。

在这里插入图片描述
decoder中由于名义上不知道句子的第一个开头词是什么,因此使用通用标识SOS来作为句子的第一个输入。
在这里插入图片描述
对于decoder的三个输入decoder_input, decoder_hidden, encoder_outputs:
decoder_input即是目标句子的此轮输入进去的单词向量,大小为tensor.size([1])
decoder_hidden是直接继承自encoder_hidden,此时的encoder_hidden是原句子最后一个单词的输出,大致来说,这样一个输出可以看做是包含了一整个句子的信息,大小是tensor.size([1,1,256])
encoder_outputs[ei]则是保存了一个句子中所有词的当时输出的状态。encoder_outputs的大小是torch.Size([10, 256])

decoder内部
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        print(342, torch.cat((embedded[0], hidden[0]), 1).size(), attn_weights)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))  #必须有三个维度
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        print(output.size(), hidden.size(), attn_weights.size())
        return output, hidden, attn_weights

这里首先要注意的是decoder的输入和输出output_size都是法语词表的大小,而具体到最后的三个输出output,hidden,attn_weight各自的大小为torch.Size([1, 79]) torch.Size([1, 1, 256]) torch.Size([1, 10])。至于output的大小原因,因为对于decoder来说,最终是为了确定在法语词表中哪个词的可能性最大,因此它最后一层实际上可以看做一个分类类别为词表大小的分类模型。

attention的问题

一般而言attention的思想用比较简单的话来说就是两个向量之间的相似度,他的方法有很多简单的乘法,点积,还有各种公式,但是我理解的便是两个向量之间越相似,方向越接近,他们点乘后就越大
在这里插入图片描述
如图两个红色向量因为方向近似,相乘后会是一个正的较大的值,而绿色的和红色的方向相悖(不相似),他们乘出来会使一个负值。
而关于原文attention实现的问题,我在看的时候就很困惑觉得无法理解,查阅大量资料也没有结果,最后在某个角落找到了这种实现attention的方法其实错的,因此如果无法理解建议去看论文或者别的例子的attention实现。
查错指路:https://zhuanlan.zhihu.com/p/68637282

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch seq2seq+attention机器翻译注 的相关文章

  • JSP数据交互(一)---内置对象》request(乱码解决)理解原理解决乱码问题

    Jsp内置对象之out JSp内置对象是Web容器创建的一组对象 没有进行声明创建但却可以使用out对象 不经常使用的内置对象 pageContext 内置对象的集大成者 config 指定Jsp页面初始配置的 Servlet page 当
  • C语言-----标识符、关键字、常量、变量

    这篇文章主要对C语言的标识符 关键字 常量 变量的一些细致知识点进行详细的讲解 比如 1 标识符的命名规范 也就是常量 变量 函数名的命名规则进行规范的讲解 2 C语言的关键字列表 3 常量的定义及其分类 4 对全局变量和局部变量的细节知识
  • 带你掌握Vue3新宠——快速Diff算法

    前言 我们都知道Vue 2中用的diff算法是双端Diff 而Vue 3的其中一个特性就是把底层的diff算法改成了快速Diff 与字面意思一样 快速diff是目前已知的最快的diff算法 本文将带大家解剖一下快速diff的原理 预处理 在
  • 常用服务器命令

    ssh 用户名 服务器地址 密码 nvidia smi查看当前显卡状态11 top 用户使用进程 Q 推出top CUDA VISIBLE DEVICESE 2 3 nohup python u test py gt test log 2
  • notepad++ json 排版插件 NPPJSONViewer dll

    将下载的32位或者64位插件dll文件拷贝到notepad 安装目录Notepad plugins 下面 例如我的目录 C Program Files Notepad plugins 拷贝后 或者在此路径下新建一个文件夹 重启notepad
  • cart回归树备忘录

    Cart回归树相关 决策树回顾 Cart树 分类树 回归树 最小二乘回归树 决策树回顾 1 决策树进化 ID3 C4 5 Cart 提问 异同点 2 决策树节点分裂评估准则 分类 信息增益 信息增益比率 gini系数 回归 MSE 提问 优
  • Python3,掌握这4个自动化脚本,让工作效率提升200%。

    4个自动化脚本 1 引言 2 自动发送多封邮件 2 1 模块介绍 2 2 代码实战 3 自动桌面提示 3 1 模块介绍 3 2 代码实战 4 自动生成素描草图 4 1 模块介绍 4 2 代码实战 5 自动化阅读网页新闻 5 1 模块介绍 5
  • kafka在Linux上的安装 运行,Linux下Kafka单机安装配置

    说明 操作系统 CentOS 6 x 64位 Kafka版本 kafka 2 11 0 8 2 1 实现目的 单机安装配置kafka 具体操作 一 关闭SELINUX 开启防火墙9092端口 1 关闭SELINUX vi etc selin
  • 8086汇编语言:标志寄存器的各个标志位的详细介绍

    一 基本介绍 CPU的内部的寄存器中 有一类特殊的寄存器 对于不同的处理机 其个数和结构都可能不同 它具有以下三种作用 这种特殊的寄存器在8086CPU中 被称为标志寄存器flag 8086CPU的标志寄存器有16位 其中存储的信息通常又被
  • 电路中VCC VDD VSS VEE GND的含义

    在电路中 芯片引脚经常会出现VCC VDD VSS VEE和GND这些标示 其中VCC一般表示通用芯片的电源引脚 比如一些模拟运放的正电源引脚 74系列数字芯片的电源引脚 VCC一般接相应的正电源电压 VDD一般表示数字芯片的电源引脚 如果
  • 手写Promise

    Promise是JS进行异步操作的重要API 也是开发基本上绕不开的技术 所以很有必要对其进行深入的了解 本文我们就 一步步手动实现Promise的相关功能 Promise属性和构造函数 原生功能 Promise对象的属性 验证原生Prom
  • python 工具变量回归_工具变量多重高维固定效应ivreghdfe

    Stable versionTo install reghdfe open Stata and run ssc install reghdfeNote that reghdfe requires at least Stata 11 2 an
  • win10启动项_win10系统开机启动项的设置教程

    小编给大家详解win10系统开机启动项的设置教程 使用win10系统过程中 有时会遇到启动项过多影响开机速度的问题 为此事困扰的用户 可参照以下的方法进行开机启动项的设置 win10系统的开机启动项如果过多的话 就会影响电脑的开机速度 其实
  • 计算机专业毕业设计一

    概述 从一个医学生转行成为一名程序员 对于我来说 是一个超前的跨越 好奇的朋友会问了 医学这么吃香的行业 怎么转行做码农呢 这个道理很简单 就是想象和显示差距太大了 距离梦想的专业差了点距离 请允许我去小黑屋哭上半个小时 想当年 我意气风发
  • JSP与Servlet之间的值传递种种

    这几天搞那个网上书店的过程中对JSP河Servlet有有了很多的认识 恩 下面是我遇到的问题解决了以后总结了一下 希望对大家有用吧 JSP与 servlet之间的传值有两种情况 JSP gt servlet servlet gt JSP 通
  • HTTPSConnectionPool(host='xxxxx', port=443): Max retries exceeded with url:xxxxxxxx (Caused by Ne...

    requests exceptions ConnectionError HTTPSConnectionPool host baike baidu com port 443 Max retries exceeded with url http
  • 提升mysql服务器性能(分库、分片与监控)

    节点一的建立 节点2 3 也要建立
  • MySQL学习5:事务、存储引擎

    事务 简介 事务是一组数据库操作的执行单元 它要么完全执行 要么完全不执行 事务是确保数据库中的数据一致性和完整性的重要机制之一 事务具有以下四个特性 称为ACID特性 原子性 Atomicity 事务作为一个整体被执行 要么全部操作成功
  • 一个开发的记单词小程序

    这里写目录标题 效果演示 功能1测试 功能简介 代码实现 效果演示 输入1 敲下Enter 回车键 后 进入第一个功能英译汉 给出Hello 用户输入中文意思 你好 敲下回车确定 进入下一个单词 功能1测试 功能简介 1 分别编辑中文和英文
  • Windows下的darknet安装

    1 下载darknet源码后 解压到文件夹 下载链接 https link zhihu com target https 3A github com AlexeyAB darknet 解压后的文件夹里面的内容是 2 打开build文件夹下的

随机推荐

  • 解决VMware“此主机支持Intel VT-x,但Intel VT-x处于禁用状态“

    1 问题 在启动VMware安装好的虚拟机时出现下图中的错误 2 问题原因 该主机处理器虚拟化技术处于禁用状态 可以在BIOS设置修改 3 问题处理 修改BIOS 本机型号为联想 开机点击F2进入BIOS 接着进入到bios的界面 选择标题
  • 被动与主动信息收集

    文章目录 信息收集 被动信息收集介绍 收集手段 收集内容 信息用途 信息收集 域名解析过程以手段 域名解析过程 信息收集 DNS DNS 信息收集 NSLOOKUP DNS 信息收集 DIG 查询网站的域名注册信息和备案信息 信息收集 被动
  • mysql 免安装版本下载地址

    5 7 32位 https dev mysql com get Downloads MySQL 5 7 mysql 5 7 19 win32 zip 5 7 64位 https dev mysql com get Downloads MyS
  • c源代码检查工具splint使用问题及方案

    splint使用时 可以使用splint help查看需要帮助的项目 然后针对需要了解的项目可以splint help 项目 查看具体的帮助 在splint使用过程中 老是出现Parse Error 下面是问题可能解决的方案 1 splin
  • cn_windows_7_ultimate_with_sp1_x64_dvd_u_677408.iso镜像下载

    链接 https pan baidu com s 1RvniUrq JpKQInKFs9bdvAhttps pan baidu com s 1RvniUrq JpKQInKFs9bdvA 提取码 zt88
  • Robot Framework 关于上传文件的问题的简单解决

    使用关键字选择文件 使用方式就是 解释一下 这里的xpath的 是输入标签的xpath的 而大多数的网络上传文件都会有这个输入标签 下边看几个简单的例子 本地上传按钮点开之后会弹出窗口选择文件 我们只需要获取这个本地上传的
  • VUE开发一个组件——Vue PC城市选择控件

    前言 前面用vue开发了三四个组件了 都是H5的 现在来看看PC是如何玩转组件的 其实和H5相同 样式不同而已 相关推荐 VUE开发一个组件 日历选择控件 VUE开发一个组件 移动端弹出层 IOS版 VUE开发一个组件 Vue tree树形
  • 建立和使用Python自定义模块

    文章目录 一 现状以及问题 二 Python模块 2 1 包的结构 2 2 包的位置 2 2 1 site packages目录 2 2 2 dist packages目录 2 3 自定义包 2 3 1 结构和位置 2 3 2 引用自定义包
  • 网关架构演进

    1 前言 天翼账号是中国电信打造的互联网账号体系产品 利用中国电信管道优势为企业提供用户身份认证能力 其中网关系统是天翼账号对外能力开放体系的重要组成 业务侧它以集中入口 集中计费 集中鉴权管控为目标 技术侧它支持隔离性 可配置 易开发 动
  • 什么是IO Pad?

    1 什么是IO pad IO pad是一个芯片管脚处理模块 即可以将芯片管脚的信号经过处理送给芯片内部 又可以将芯片内部输出的信号经过处理送到芯片管脚 输入信号处理包含时钟信号 复位信号等 输出信号包含观察时钟 中断等 IO pad模块可以
  • 防止浏览器嗅探音视频--blob对象在audio和video标签中的使用

    现在的浏览器很聪明 会对页面中的mp3 MP4等内容进行嗅探下载 但是对于部分付费或敏感内容 我们并不想版权资源被嗅探 这就需要使用html5 提供的 blob 对象对文件内容进行保护 blob格式的资源是无法被嗅探的 具体可以参考一下 b
  • 使用R语言计算DataFrame数据中指定范围多个数据列的两两相关系数

    使用R语言计算DataFrame数据中指定范围多个数据列的两两相关系数 在数据分析和统计建模中 了解数据列之间的相关性是非常重要的 R语言提供了许多函数来计算数据集中数据列之间的相关系数 本文将介绍如何使用R语言中的cor函数来计算Data
  • Hadoop的shuffle原理和过程图解

    wordcount为例详细阐述shuffle的实现过程 1 对HDFS输入的文件进行切割为KV形式 2 在mapper方法中执行 分割单词为KV形式 3 shuffle在Map端的三个操作 partition 多节点的相同K合并 sort
  • OpenCv中计算图像像素最大值、最小值、均值和方差

    1 寻找图像像素的最大值最小值 寻找图像最大值最小值的函数 minMaxLoc 函数 minMaxLoc 函数原型 void cv minMaxLoc InputArray src double minVal double maxVal 0
  • 浮点数——科学计数法、浮点数表示、加减运算和浮点数的使用

    目录 1 2浮点数 1 2 1 科学计数法 1 2 2 浮点数表示 1 符号位 2价码位 1 2 3 加减运算 1 2 4 浮点数的使用 1 2浮点数 浮点数是采用科学计数法来表示的 由符号位 有效数字 指数三部分组成 使用浮点数存储和计算
  • 性能测试------LoadRunner

    1 常见的性能问题 1 内存泄漏 软件运行的时候没有回收内存 导致内存越来越慢 2 CPU使用率达到了100 3 线程死锁 阻塞 造成系统运行越来越慢 4 查询的速度越来越慢 5 受外部系统的影响越来越大 2 为什么要进行性能测试 1 获取
  • three.js源码翻译及案例(五)-GLTFLoader.js

    写在前面 Three中的加载脚本很多 但是核心思想是差不多的 就是文件用文件解析器加载 图片用图片解析器加载 然后json转换为对象 但是由于gltf格式可以自己编辑所以有的源码参考意义不大 glb及拓展材质都没用上就还没有翻译 以后可能会
  • xcode APP 打包以及提交apple审核详细流程(新版本更新提交审核)

    最近项目到了最后的阶段 测试完一切ok后 准备打包以及提交 不料看到网上众教程 好多都是老版本的 现在IDE实现方式改了 那些方法好多都找不到 绕了一大圈 才搞明白流程 现在记录下来 以便朋友们查阅 开发环境 xcode4 4 1 ipho
  • 李胜溢9.5 : 最新黄金原油行情走势分析及布局操作指南。

    趋势价值交易 是所有投资者走向盈利的必经之路 没有捷径 也不要心存侥幸 任何一个投资者从初入市场 到走向盈利 都需要经历亏损再到保本再到盈利的过程 市场绝对不是投机者长久的天堂 一次的投机成功不代表可以善始善终 只有稳定不断的持续盈利才能成
  • pytorch seq2seq+attention机器翻译注

    准备深入学习一下神经网络的搭建方法的时候 选了机器翻译来试试 正好查了很多资料 发现pytorch里有例子 就结合自己的理解和探究记录一下 原文实现代码 https pytorch org tutorials intermediate se