初步使用LSTM网络

2023-10-27

声明:文中所有使用图片均来自网络,侵删。

什么是LSTM

LSTM(Long Short-Term Memory网络),是一种特殊的RNN类型,可以有效解决RNN神经网络存在的长期依赖问题。通过模仿人脑可以进行遗忘的功能,在每一个LSTM模块中加入了遗忘门,对信息进行处理,具体如下:
在这里插入图片描述

遗忘门

通过全部的LSTM网络图片可以观察到,每一个单元的输入不仅仅只是X_t,还包括前一个单元的输出状态H_t-1,下面我们对单个的LSTM单元进行分析
在这里插入图片描述
输入的数据X_t和上一时刻状态H_t-1一起传入当前时刻的LSTM单元,通过第一个SIGMOD函数,即所谓的遗忘门,因为SIGMOD函数输出值如下图所示在0-1之间,所以相当于给信息赋权,决定哪些信息会被遗忘。
在这里插入图片描述

输入门

在这里插入图片描述
这一道门决定哪些信息会被真正地输入到记忆中(就好像人脑一样,记不下所有的输入信息)。其中,由于SIGMOD函数输出为0-1的特性,它可以决定输入多少比列的信息(加权)
tanh函数决定着输入什么信息(对信息进行加工处理)
在这里插入图片描述

输出门

在这里插入图片描述
输出当前状态和隐藏状态

如何使用(基于pytorch)

上述理论看似需要很多数学运算,但实际上我们使用Python编写LSTM时算法时,不需要自己编写这些,只需要调用库文件里的封装好的API就行。这里我们使用pytorch库进行编写

我们先进行简单的运用,生成一组数据表示要处理的文本,对它运用LSTM

import torch

# 1 设置参数
batch_size = 10     # 设置每一组取词数量
seq_len = 20        # 设置每一次取多少组
embedding_dim = 30  # 将文件用多少维的数据表示
hidden_size = 22    # 每一层隐藏层有多少LSTM单元
num_layer = 2       # 有多少层隐藏层
voc_doc = 200       # 要训练的文本中有多少不一样的词

# 2 导入文本对象
text = torch.randint(low=0, high=100, size=(seq_len, batch_size))

# 3 实例化API
'''
 torch.nn.Embedding(voc_doc, embedding_dim)
  一个将文本转化成数字数据的API
'''
embedding = torch.nn.Embedding(voc_doc, embedding_dim)
embed = embedding(text)
lstm = torch.nn.LSTM(embedding_dim, hidden_size, num_layer)

# 4 开整
out, (h_n, c_n) = lstm(embed)

注意:

  1. 在第12行代码中,如果batch_size和seq_len位置调换,在实例化LSTM时必须添加一个参数batch_first=True,如下
text = torch.randint(low=0, high=100, size=(batch_size, seq_len))
lstm = torch.nn.LSTM(embedding_dim, hidden_size, num_layer, batch_first=True)

具体原因请查阅torch.nn.LSTM函数文档

  1. 实例化LSTM对象之后,不仅需要传入数据,还需要前一次的h_0(前一次的隐藏状态)和c_0(前一次memory),即:lstm(input,(h_0,c_0)),如果不进行设置,则LSTM的默认输出为output, (h_n, c_n)

可以在上述代码中添加以下代码查看输出结果

print(out)
print("*"*100)
print(h_n)
print("*"*100)
print(c_n)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

初步使用LSTM网络 的相关文章

  • 将字符串转换为浮点数列表(在 python 中)

    出于数据存储的目的 我尝试从 txt 文件恢复浮点列表 从字符串中可以看出 a 1 3 2 3 4 5 我想恢复 a 1 3 2 3 4 5 我期待一个简单的解决方案 例如list a 但我找不到类似的东西 Use the AST模块 Ex
  • 如何打印脚本的每一行,因为它仅针对正在运行的顶级脚本运行?

    python 跟踪模块将允许您运行一个脚本 打印每一行代码 因为它在脚本和所有导入的模块中运行 如下所示 python m trace trace myscript py 有没有办法做同样的事情 但是only打印顶级调用 即仅打印以下行my
  • 如何使用 tkinter 使用网格功能显示不同的图像?

    我想使用显示文件夹中的图像grid 但是当我尝试使用以下代码时 我得到了迭代单个图像的输出 My code def messageWindow win Toplevel path C Users HP Desktop dataset for
  • 导入 SciPy 或 scikit-image,“from scipy.linalg import _fblas:导入错误:DLL 失败”

    我正在导入 from scipy import misc io 但我收到这些错误 Traceback most recent call last File C work asaaki code generateProposals py li
  • 如何搜索一列并用找到的内容填充另一列?

    我有一个带有虚构人物数据的大熊猫数据框 下面是一个小例子 每个人都由一个数字定义 import pandas as pd import numpy as np df pd DataFrame Number 5569 3385 9832 64
  • 清理 MongoDB 的输入

    我正在为 MongoDB 数据库程序编写 REST 接口 并尝试实现搜索功能 我想公开整个 MongoDB 接口 我确实有两个问题 但它们是相关的 所以我将它们放在一篇文章中 使用 Python json 模块解码不受信任的 JSON 是否
  • 按升序对数字字符串列表进行排序

    我创建了一个SQLite https en wikipedia org wiki SQLite数据库有一个存储温度值的表 第一次将温度值按升序写入数据库 然后 我将数据库中的温度值读入列表中 然后将该列表添加到组合框中以选择温度 效果很好
  • 熊猫系列到二维数组

    所以 我使用了来自的答案将二维数组放入 Pandas 系列中 https stackoverflow com questions 38840319 put a 2d array into a pandas series将 2D numpy
  • Python:处理图像并保存到文件流

    我需要使用 python 处理图像 应用过滤器和其他转换 然后使用 HTTP 将其提供给用户 现在 我正在使用 BaseHTTPServer 和 PIL 问题是 PIL 无法直接写入文件流 因此我必须写入临时文件 然后读取该文件 以便将其发
  • 私有属性,但却是一个神秘的领域

    我想将属性设为私有 但带有 pydantic 字段 from pydantic import BaseModel Field PrivateAttr validator class A BaseModel a str I want a py
  • python os.fork 使用相同的 python 解释器吗?

    据我所知 Python 中的线程使用相同的 Python 解释器实例 我的问题是与创建的流程相同os fork 或者每个进程创建的os fork有自己的翻译吗 每当你 fork 时 整个 Python 进程都会在内存中复制 包括Python
  • 在 Python 中引发异常的正确方法是什么? [复制]

    这个问题在这里已经有答案了 这是简单的代码 import sys class EmptyArgs StandardError pass if name main The first way to raise an exception if
  • 如何在 tkinter 后台运行函数[重复]

    这个问题在这里已经有答案了 我是 GUI 编程新手 我想用 tkinter 编写一个 Python 程序 我想要它做的就是在后台运行一个可以通过 GUI 影响的简单函数 该函数从 0 计数到无穷大 直到按下按钮为止 至少这是我想要它做的 但
  • 如何在使用 Flask for Python 3 的同时使用 Bootstrap 4?

    我检查过 发现默认安装时 Flask Bootstrap 原生使用 Bootstrap 3 3 7 但实际上我想通过使用 Flask Bootstrap 包在我的项目中使用 Bootstrap 4 任何有关如何更新它或类似内容的帮助将不胜感
  • 在 Keras 中使用有状态 LSTM 训练多变量多级数回归问题

    我有时间序列P过程 每个过程的长度各不相同 但都有 5 个变量 维度 我试图预测测试过程的估计寿命 我正在用有状态的方法来解决这个问题LSTM在喀拉斯 但我不确定我的训练过程是否正确 我将每个序列分成长度的批次30 所以每个序列都是这样的形
  • model.predict() 返回类而不是概率

    Hello 我是第一次使用 Keras 我训练并保存了一个模型 作为 json 文件及其权重 该模型旨在将图像分为 3 个类别 我的编译方法 model compile loss categorical crossentropy optim
  • 阻止 BeautifulSoup 将我的 XML 标签转换为小写

    我正在使用 BeautifulStoneSoup 来解析 XML 文档并更改一些属性 我注意到它会自动将所有 XML 标签转换为小写 例如我的源文件有
  • Docker Python 脚本找不到文件

    我已经成功构建了一个 Docker 容器 并将应用程序的文件复制到 Dockerfile 中的容器中 但是 我正在尝试执行引用输入文件 在 Docker 构建期间复制到容器中 的 Python 脚本 我似乎无法弄清楚为什么我的脚本告诉我它无
  • Python 中的可逆 STFT 和 ISTFT

    有没有通用的形式短时傅立叶变换 https en wikipedia org wiki Short time Fourier transform与内置于 SciPy 或 NumPy 或其他什么中的相应逆变换 这是pyplotspecgram
  • Python 子进程:无法转义引号

    我知道以前曾问过类似的问题 但它们似乎都是通过重新设计参数的传递方式 即使用列表等 来解决的 但是 我这里有一个问题 因为我没有这个选项 有一个特定的命令行程序 我使用的是 Bash shell 我必须向其传递带引号的字符串 它不能不被引用

随机推荐

  • 利用IDEA 进行debug发现错误的一次经历

    利用IDEA debug发现错误的一次经历 今天在做实训项目的时候遇到了一个问题 就是在进行添加学生信息的时 发现总是提示手机号码格式不正确 可是明明是以正确的格式输入的却总是提示格式错误 于是在这里打上断点并且按右上角的虫子按钮 同样输入
  • vue项目怎么安装依赖

    安装node js 从node js官网下载并安装node 安装过程很简单 一路 下一步 就可以了 傻瓜式安装 安装完成之后 打开命令行工具 输入 node v 如下图 如果出现相应的版本号 则说明安装成功 npm包管理器 是集成在node
  • 不想学习的时候如何逼迫自己去学习?(长文预警)

    尼采曾用酒神和日神来比喻人类艺术活动的两种方式 一种是日神的 走向世界 追求成功 类的理性 一种是酒神的 走向内心 寻求超越 类的情感 而从学习上来看 由于中国特殊的教育环境 几乎不可能有后者的闲情逸致 家长们送孩子们上学 除了超一线城市确
  • Hive概论、架构和基本操作

    Hive是一个构建在Hadoop上的数据仓库框架 最初 Hive是由Facebook开发 后台移交由Apache软件基金会开发 并做为一个Apache开源项目 Hive是基于Hadoop的一个数据仓库工具 可以将结构化的数据文件映射为一张数
  • vue简单实现查询排序功能

  • Jenkins与SonarQube配置

    Jenkins与SonarQube Jenkins 配置 SonarQube 在 SonarQube 中生成 Server authentication token 登录 SonarQube 后 在 My Account gt Securi
  • 2020年蓝桥杯B组个人题解(热的,不知道对错)

    文章目录 A B C D E F G H I J 总结 结果 现在是蓝桥杯刚结束 趁着有记忆 写下这篇博客 不知道对错 如果我错了 请指出 A 因为是到0就结束了 那么每次看看 600是否结束 如果没有结束就 300 然后时间 2 60 最
  • 抗渗等级p6是什么意思_混凝土抗渗等级w4是什么意思?

    混凝土抗渗等级可按28d龄期的标准试件测定 混凝土抗渗等级分为 W2 W4 W6 W8 W10 W12六级 根据建筑物开始承受水压力的时间 也可利用60d或90d龄期的试件测定抗渗等级 抗渗等级是以28d龄期的标准试件 按标准试验方法进行试
  • jeeplus-js-获取table中复选框选中的列

    function getSelectedIds var str var ids contentTable tbody tr td input i checks checkbox each function if true this is c
  • 解决error C2065:"..."未声明的标识符,C2065:语法错误: 标识符“...”

    网狐项目工程中有时候会出现 C2065错误 一般情况下有可能是 项目工程配置出错 只需要选择 Visual Studio 2013 v120 就可以了
  • 算法笔记(5)-K最近邻算法及python代码实现

    K最近邻算法既可以用于分类又可以用于回归 K最近邻 k Nearest Neighbor KNN 算法分类的基本原理 如果一个样本在特征空间中的k个最相似 即特征空间中最邻近 的样本中的大多数属于某一个类别 则该样本也属于这个类别 K最近邻
  • 分析解决【No module named ‘triton‘】的问题

    文章目录 一 现象 二 分析 三 安装 3 1 项目虚拟环境 3 2 环境版本问题 三 与主题无关 一 现象 在Windows11下训练Stable Diffusion的LoRA模型的时候 总是重复提示 A matching Triton
  • nginx 配置文件关键字

    L1 location root html index index html index htm L1可以匹配到请求127 0 0 1 127 0 0 1 root html 是一个相对路径 表示以ng安装路径下html目录查找 index
  • x99芯片组服务器版叫什么,Intel X99主板、Z97主板以及H97主板的区别是什么?

    Intel X99主板 Z97主板以及H97主板的区别是什么 虽然让他出来丢人现眼 但是这个帖子却是冒了一定风险码出来的 题目就已经让人败坏了兴致 拿出来只会让专业人士嗤之以鼻 作为目前最新的Intel主板芯片组 9系列主板并没有受到很多朋
  • 印刷纸张尺寸,纸张种类规格

    印刷纸张尺寸 纸张种类规格 2007 07 25 15 17 正度16开 185x260 大度16开 210x285 开 又是什么单位 全开的纸能开出来多少张 就是多少开 例如 16K的就是全开开出16张 对开就是开出两张 一开是多大 全开
  • Cadence Orcad原理图导出pdf文件

    1 安装虚拟打印机 通过打印实现pdf文件输出 略 2 设置输出格式 1 点击file gt print gt print setup 2 设置尺寸与输出方向 3 点击确定 gt OK 开始转化并设置保存路径与文件名 3 打印出现页码错乱问
  • jquery获取隐藏元素的宽度高度

    info show 50 function var w info outerWidth console log w 注意 show的第一个参数不能为0否则在刷新页面或页面默认载入并显示该隐藏元素的时候 w仍然为0 虽然通过单击事件可以获取到
  • LeetCode343-整数拆分

    昨天晚上写作业时 腾讯的一条笔试通知邮件 着实让我有点吃惊 我三月份就投了鹅厂 身边的朋友早就面试了 自己的简历杳无音讯 本早就放弃了 却没想还能最后有次机会 害 要好好珍惜了 题目描述 给定一个正整数 n 将其拆分为至少两个正整数的和 并
  • CVPR2023 多目标跟踪(MOT)汇总

    一 OVTrack Open Vocabulary Multiple Object Tracking 作者 Siyuan Li Tobias Fischer Lei Ke Henghui Ding Martin Danelljan Fish
  • 初步使用LSTM网络

    声明 文中所有使用图片均来自网络 侵删 什么是LSTM LSTM Long Short Term Memory网络 是一种特殊的RNN类型 可以有效解决RNN神经网络存在的长期依赖问题 通过模仿人脑可以进行遗忘的功能 在每一个LSTM模块中