[Pytorch系列-53]:循环神经网络 - torch.nn.LSTM()参数详解

2023-10-28

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121644547


目录

第1章 LSTM基本理论

第2章 torch.nn.LSTM类的参数详解

2.1 类的原型

 2.2 类的参数:用于构建LSTM神经网络实例

第3章 前向传播输入详解

3.1 前向传播的格式

3.2 input的格式

3.3 h_0的格式

3.4 c_0的格式

第4章 前向传播输出详解

4.1 输出返回值的格式

4.2 output的格式

4.3 h_n的格式

4.4 c_n的格式



第1章 LSTM基本理论

[人工智能-深度学习-52]:RNN的缺陷与LSTM的解决之道_文火冰糖(王文兵)的博客-CSDN博客作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:目录第1章 RNN的缺陷1.1 RNN的前向过程1.2RNN反向求梯度过程1.2 梯度爆炸(每天进一步一点点,N天后,你就会腾飞)1.3 梯度弥散/消失(每天堕落一点点,N天后,你就彻底完蛋)1.4RNN网络梯度消失的原因1.5 解决“梯度消失“的方法主要有:1.5 RNN网络的功能缺陷第2章 LSTM长短期记忆网络2.1 LSTM概述2.3LSTM网...https://blog.csdn.net/HiWangWenBing/article/details/121547541

第2章 torch.nn.LSTM类的参数详解

2.1 类的原型

图1 LSTM内部原理介绍 没啥好说的

 2.2 类的参数:用于构建LSTM神经网络实例

图2 关键参数介绍

input_size: 输入序列的一维向量的长度。

hidden_size: 隐层的输出特征的长度。

num_layers:隐藏层堆叠的高度,用于增加隐层的深度。

bias:是否需要偏置b

batch_first:于确定batch size是否需要放到输入输出数据形状的最前面。

若为真,则输入、输出的tensor的格式为(batch , seq , feature)

若为假,则输入、输出的tensor的格式为(seq,  batch , feature)

为什么需要该参数呢?

在CNN网络和全连接网络,batch通常位于输入数据形状的最前面

而对于具有时间信息的序列化数据,通常需要把seq放在最前面,需要把序列数据串行地输入网络中。

dropout: 默认0 若非0,则为dropout率。
bidirectional:是否为双向LSTM, 默认为否

第3章 前向传播输入详解

3.1 前向传播的格式

lstm (input, (h_0, c_0))

  • input:输入序列样本
  • h_0:先前的短期状态记忆
  • c_0:先前的长期状态记忆

3.2 input的格式

(1)当batch_first=false(默认)

input的形状为(seq_len,  batch,  input_size)

(2)当batch_first=true

input的形状为(batch, seq_len,  input_size)

(3)参数解读

  • batch:batch的长度
  • input_size:输入样本的向量长度。
  • seq_len:输入序列的长度,一次可以串行输入多个输入样本

3.3 h_0的格式

若h_0和c_0不提供,则默认为全0 .

h_0是格式为: (num_layers * num_directions, batch, hidden_size) 的tensor

  • num_directions:由隐层的层数决定
  • batch:batch
  • hidden_size:隐层输出特征的长度

3.4 c_0的格式

若c_0不提供时,则默认为全0 .

c_0是格式为: (num_layers * num_directions, batch, hidden_size) 的tensor

  • num_directions:由隐层的层数决定
  • batch:batch
  • hidden_size:隐层输出特征的长度

第4章 前向传播输出详解

4.1 输出返回值的格式

output, (h_n, c_n) = lstm (input, (h_0, c_0))

4.2 output的格式

(1)当batch_first=false(默认)

input的形状为(seq_len,  batch,  num_layers * num_directions,)

(2)当batch_first=true

input的形状为(batch, seq_len,  num_layers * num_directions,)

(3)参数解读

  • batch:batch的长度
  • num_directions:取决于堆叠的长度
  • seq_len:输入序列的长度,输出的长度与输入序列的长度一致。

4.3 h_n的格式

h_0是格式为: (num_layers * num_directions, batch, hidden_size) 的tensor

  • num_directions:由隐层的层数决定
  • batch:batch
  • hidden_size:隐层输出特征的长度

4.4 c_n的格式

c_0是格式为: (num_layers * num_directions, batch, hidden_size) 的tensor

  • num_directions:由隐层的层数决定
  • batch:batch
  • hidden_size:隐层输出特征的长度

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121644547

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

[Pytorch系列-53]:循环神经网络 - torch.nn.LSTM()参数详解 的相关文章

  • 时间序列数据和 LSTM 中分类的实体嵌入

    我正在尝试解决时间序列问题 简而言之 对于每个客户和材料 SKU代码 我过去都下了不同的订单 我需要建立一个模型来预测每个客户和材料下一次订单之前的天数 我想做的是在 Keras 中构建一个 LSTM 模型 其中对于每个客户和材料 我有 5
  • Keras LSTM 中的维度不匹配

    我想创建一个可以添加两个字节的基本 RNN 以下是输入和输出 需要进行简单的加法 X 0 0 0 1 1 1 0 1 1 0 1 0 1 1 1 0 那是 X1 00101111 and X2 01110010 Y 1 0 1 0 0 0
  • 无法将 NumPy 数组转换为张量(不支持的对象类型字典)

    我的方法我认为问题是 history model fit generator train generator epochs epochs steps per epoch train steps verbose 1 callbacks che
  • 如何在 Keras 中使用 model.reset_states() ?

    我有顺序数据 并且声明了一个 LSTM 模型来预测y with x在喀拉斯 所以如果我打电话model predict x1 and model predict x2 调用是否正确model reset states两者之间predict
  • 如何将 Shap 与 LSTM 神经网络结合使用?

    我正在与 keras 合作生成 LSTM 神经网络模型 我想使用 shap 包查找模型每个特征的 Shapley 值 当然 问题在于模型的 LSTM 层需要三维输入 样本 时间步长 特征 但 shap 包需要二维输入 无论如何 这个问题有解
  • 如何实现每个时间步都带有向量输入的LSTM网络?

    我正在尝试在 Tensorflow 中创建一个生成 LSTM 网络 我有这样的输入向量 0 0 1 0 1 0 0 0 1 0 1 0 0 0 0 1 0 1 该矩阵中的每个向量都是一个时间步 或者换句话说 每个向量应该是 LSTM 的一个
  • 带有嵌入层的 Keras LSTM 自动编码器

    我正在尝试在 Keras 中构建一个文本 LSTM 自动编码器 我想使用嵌入层 但我不确定如何实现 代码如下所示 inputs Input shape timesteps input dim embedding layer Embeddin
  • 在 Keras 中,当我创建具有 N 个“单元”的有状态“LSTM”层时,我到底要配置什么?

    正常的第一个参数Dense层也是units 是该层中神经元 节点的数量 然而 标准 LSTM 单元如下所示 这是 的修改版本 了解 LSTM 网络 http colah github io posts 2015 08 Understandi
  • 多特征因果 CNN - Keras 实现

    我目前正在使用基本的 LSTM 进行回归预测 并且我想实现一个因果 CNN 因为它的计算效率应该更高 我正在努力弄清楚如何重塑当前的数据以适应因果 CNN 单元并表示相同的数据 时间步关系以及扩张率应设置为多少 我当前的数据是这样的 num
  • Keras LSTM:检查模型输入维度时出错

    我是 keras 的新用户 正在尝试实现 LSTM 模型 为了测试 我声明了如下所示的模型 但由于输入维度的差异而失败 虽然我在这个网站上发现了类似的问题 但我自己无法发现我的错误 ValueError Error when checkin
  • 了解 Tensorflow LSTM 模型输入?

    我在理解 TensorFlow 中的 LSTM 模型时遇到一些困难 我用tflearn http tflearn org 作为包装器 因为它自动完成所有初始化和其他更高级别的工作 为了简单起见 我们考虑这个示例程序 https github
  • 如何使用有状态 LSTM 和 batch_size > 1 布置训练数据

    背景 我想在 Keras 中对 有状态 LSTM 进行小批量训练 我的输入训练数据位于一个大矩阵 X 中 其维度为 m x n 其中 m number of subsequences n number of time steps per s
  • .fit() 层的 shuffle = 'batch' 参数如何在后台工作?

    当我使用以下方法训练模型时 fit 层的参数 shuffle 预设为 True 假设我的数据集有 100 个样本 批量大小为 10 当我设置shuffle True然后 keras 首先随机选择样本 现在 100 个样本具有不同的顺序 根据
  • 如何为 keras lstm 输入重塑数据?

    我是 Keras 新手 我发现很难理解 LSTM 层输入数据的形状 Keras 文档表示输入数据应该是形状为 nb samples timesteps input dim 的 3D 张量 我有808信号 每个信号有22个通道和2000个数据
  • Keras:嵌入 LSTM

    在 LSTM 的 keras 示例中 用于对 IMDB 序列数据进行建模 https github com fchollet keras blob master examples imdb lstm py https github com
  • 张量流:简单 LSTM 网络的共享变量错误

    我正在尝试构建一个最简单的 LSTM 网络 只是想让它预测序列中的下一个值np input data import tensorflow as tf from tensorflow python ops import rnn cell im
  • 在 Tensorflow 2.0 中的简单 LSTM 层之上添加 Attention

    我有一个由一个 LSTM 和两个 Dense 层组成的简单网络 如下所示 model tf keras Sequential model add layers LSTM 20 input shape train X shape 1 trai
  • LSTM 批次与时间步

    我按照 TensorFlow RNN 教程创建了 LSTM 模型 然而 在这个过程中 我对 批次 和 时间步长 之间的差异 如果有的话 感到困惑 并且我希望得到帮助来澄清这个问题 教程代码 见下文 本质上是根据指定数量的步骤创建 批次 wi
  • Caffe 的 LSTM 模块

    有谁知道 Caffe 是否有一个不错的 LSTM 模块 我从 russel91 的 github 帐户中找到了一个 但显然包含示例和解释的网页消失了 以前是http apollo deepmatter io http apollo deep
  • 验证 Transformer 中多头注意力的实现

    我已经实施了MultiAttention head in Transformers 周围有太多的实现 所以很混乱 有人可以验证我的实施是否正确 DotProductAttention 引用自 https www tensorflow org

随机推荐

  • Python3 goto 语句的使用

    熟悉 C 语言的小伙伴一定对 goto 语句不陌生 它可以在代码之间随意的跳来跳去 但是好多老鸟都告诫大家 不要使用 goto 因为 goto 会使你的代码逻辑变的极其混乱 但是有时候我们不得不用它 因为它太高效了 比如进入循环内部深层一个
  • 学生管理系统C语言

    include
  • Win10 隐藏远程桌面,连接栏

    https www cnblogs com tuhong p 13307579 html 快捷键 Ctrl Alt Home
  • Django 省、市、区 三级联动 及数据库的地址添加 !!!

    应用场景 淘宝 京东 等需要地址的地方 Models py模型 from django db import models Create your models here class Area models Model name models
  • Redis可视化工具RedisInsight

    今天是老苏居家隔离的第 39 天 周五抗原 周六 周日 周一每天两次抗原 上午一次 下午一次 没完没了的捅鼻子 感觉都要捅出鼻炎了 虽然小区早就是防范区了 但是一直处于提级管理中 还是不能出小区 也看不到任何松动的迹象 最近几天都在传 一人
  • R reason ‘拒绝访问‘的解决方案

    Win11系统 安装rms的时候报错 Error in loadNamespace j lt i 1L c lib loc libPaths versionCheck vI j namespace Matrix 1 5 4 1 is alr
  • 使用Thinkphp5.0 中 {include file="index/left" /} 引入模板 影响样式

    在使用Thinkphp 5 0框架开发后台的时候 需要在模板中引用公共头部 我使用 include file index header 的方式引用了公共头部 引用之后发现头部和页面顶端之间出现了间距 未引用之前 头部和页面顶端是没有间距的
  • Azure文件同步服务的创建和配置

    将Azure FileShare share1同步到Server Endpoint 在这没法添加 只能管理服务 选择 Create a resource 查找 azure file sync 注意 选择的Location 一定要与File服
  • PLY文件格式及其MATLAB读写操作

    PLY是一种电脑档案格式 全名为多边形档案 Polygon File Format 或 斯坦福三角形档案 Stanford Triangle Format 史丹佛大学的 The Digital Michelangelo Project计划采
  • Arduino实验三:伺服马达

    目录 前言 1 伺服马达 1 1 相关参数 1 2实物图 1 3连接线路图 1 4程序代码 1 5运行结果 前言 伺服马达和直流马达的区别 伺服马达有3条接入线 在输入信号的控制下 能够转动特定角度 其中三条线中 红色线接正极 棕色线接地
  • App常见内存泄漏以及解决方法

    如果是想认真学习的话 请先不要看以下内容 此记录只是为加深记忆 可能会有错误的地方 以免有误导 学习转载链接 https blog csdn net u014674558 article details 62234008 App常见内存泄漏
  • python怎么实现类似#define宏定义

    我怎么了 怎么突然问出这个问题 一时还认真的点进了论坛 面壁思过一下 python是解释性语言 不需要编译 define是预编译阶段起作用的 python没得必要 在c语言中 define在调试或者多平台兼容的时候很有用 特别是 defin
  • [内核文档系列] NMI 看门狗

    内核文档系列 NMI 看门狗 秦白衣 qinchenggang sict ac cn X86和X86 64体系结构均支持NMI看门狗 你的系统是不是会经常被锁住 Lock up 直至解锁 系统不再响应键盘 你希望帮助我们解决类似的问题吗 如
  • AttributeError: 'numpy.dtype' object has no attribute 'base_dtype'

    AttributeError numpy dtype object has no attribute base dtype 这个错误其实是有说keras版本有点高的问题 我调低了 Keras 2 0 2 具体他有没有影响 我没有验证 但是下
  • 机器学习K-均值——nonzero(clusterAssment[冒号,0].A==cent的一步步操作演示,看完你就明白了

    先准备测试数据 如下 上面都是准备数据 下面才是一步步的告诉你怎么生成我们要的数据 原文链接 https blog csdn net xinjieyuan article details 81477120
  • Educoder C&C++线性表实训

    目录 第1关 顺序构建线性表 第2关 逆序构建线性表 第3关 排序构建线性表 第4关 查找元素 第5关 删除指定位置的结点 第6关 删除包含特定数据的结点 第7关 线性表长度 第8关 线性表应用一 栈 第9关 线性表应用二 队列 第10关
  • Min Difference

    C Min Difference Problem Statement You are given two sequences A A1 A2 AN consisting of NN positive integers and B B1 BM
  • JenKins工作流程

    程序员提交代码到Git SVN仓库 触发钩子程序向 JenKins 进行通知 Jenkins 调用Git SVN插件获取源码 调用Maven打包为war包 调用Deploy to web container插件部署到Tomcat服务器
  • 无法登陆宝塔面板?宝塔界面为什么无法访问?宝塔面板登陆不上?宝塔面板打不开解决办法

    很多小伙伴很久没有登陆宝塔面板 再次打开宝塔面板就出现了上面这种情况 下面张大哥介绍几个排查方法 帮助大家协助解决一下此类问题 第一 请检查你是否在安全组开放8888端口 一般安装环境时默认为8888端口 如果更改为其他自定义端口的话 需要
  • [Pytorch系列-53]:循环神经网络 - torch.nn.LSTM()参数详解

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 121644547 目录 第1章 LST