pytorch学习笔记 —— torch.nn.LSTM

2023-10-29

使用 torch.nn.LSTM 可以方便的构建 LSTM,不熟悉 LSTM 的可以先看这两篇文章:

RNN:https://blog.csdn.net/yizhishuixiong/article/details/105588233

LSTM:https://blog.csdn.net/yizhishuixiong/article/details/105572296


下面详细讲述 torch.nn.LSTM 的使用

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)

  • input_size:输入数据的大小;
  • hidden_size:隐藏层的大小(节点数量),输出向量的维度等于节点数量;
  • num_layers:recurrent layer 的数量(默认为1);
  • bias:默认为 True;
  • batch_first:输入输出维度的第一维是否为 batch_size。若为True,则 batch_size 在第一维,若为 False(默认),则 batch_size 在第二维;
  • dropout:若非0,则在除了最后一层的各层都使用 dropout 层,默认为0;
  • bidirectional:若为 True,则使用双向 LSTM,默认为 False;

LSTM 的输入:input,(h_0,c_0)

  • input:输入数据,shape 为(句子长度seq_len, 句子数量batch, 每个单词向量的长度input_size);
  • h_0:默认为0,shape 为(num_layers * num_directions单向为1双向为2, batch, 隐藏层节点数hidden_size);
  • c_0:默认为0,shape 为(num_layers * num_directions, batch, hidden_size);

LSTM 的输出:output,(h_n,c_n)

  • output:输出的 shape 为(seq_len, batch, num_directions * hidden_size);
  • h_n:shape 为(num_layers * num_directions, batch, hidden_size);
  • c_n:shape 为(num_layers * num_directions, batch, hidden_size);

代码演示

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3)   # 一个单词向量长度为10,隐藏层节点数为20,LSTM数量为3
input = torch.randn(8, 3, 10)   # batch_size为3(输入数据有3个句子),每个句子有8个单词,每个单词向量长度为10
h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)

双向:

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 3, bidirectional=True)   # 一个单词向量长度为10,隐藏层节点数为20,LSTM数量为3,双向
input = torch.randn(8, 3, 10)   # batch_size为3(输入数据有3个句子),每个句子有8个单词,每个单词向量长度为10
h_0, c_0 = torch.randn(6, 3, 20), torch.randn(6, 3, 20)
output, (h_n, c_n) = rnn(input, (h_0, c_0))

print("input.shape:", input.shape)
print("h_0.shape:", h_0.shape)
print("c_0.shape:", c_0.shape)
print("*" * 50)
print("output.shape:", output.shape)
print("h_n.shape:", h_n.shape)
print("c_n.shape:", c_n.shape)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch学习笔记 —— torch.nn.LSTM 的相关文章

随机推荐

  • lr中需要注意的点--安装后打不开ie需要设置的地方

    前提 Loadrunner11仅支持ie9向下版本 若安装了ie10则需要在查看一安装的更新中卸载 1 设置ie浏览器为默认浏览器 2 计算机 gt 属性 gt 高级系统管理 gt 性能 gt 设置 过程中会提示重启 3 tools gt
  • spring 和springboot 整合rabbitmq

    文章目录 spring springboot整合 rabbitmq 4 1 spring 整合rabbitmq 4 2 springboot 整合rabbitmq spring springboot整合 rabbitmq rabbitmq
  • Java文档注释

    Java文档注释 Doc umentation Comments 注意不要将注解 Annotation 与注释 Comments 混淆 Java的有三种注释 1 单行注释 注释内容 2 多行注释 注释内容 3 文档注释 注释内容 Java文
  • vue z-index层级显示问题

    一个单页面 顶部有fixed的nav 当向上滑动页面时 发现nav里有的组件被下放的组件遮盖 第一时间明白这时需要修改层级设置 将下方的组件z index设为 1 nav的组件z index调高 发现还是有各种遮盖的问题 然后花了点时间找资
  • linux查看某个应用占用多少线程

    以tomcat为例 获取tomcat进程pid ps ef grep tomcat 10090 统计该tomcat进程内的线程个数 ps Lf 10090 wc l 数量就是该tomcat启动了多少线程
  • java异常(机制和捕捉(常见异常类))详解 +练习题

    Java 中的异常处理机制 1 什么是异常 异常 程序在运行过程中产生的不正常情况 程序在运行的时候 发生了一些不被预期的事件 从而没有按照我们编写的代码执行 这就是异常 异常是Java中的错误 但是并不是所有的错误都是异常 比如说 你在定
  • zookeeper的动态扩容

    附属意义的扩容 扩容的新增节点为观察者observer 1 观察者概念 a 在zookeeper引入此新的zookeeper节点类型为observer 是为了帮助处理投票成本随着追随者增加而增加的问题并且进一步完善了zookeeper的可扩
  • 研一寒假C++复习笔记--运算符重载实例

    目录 1 运算符重载 2 加号运算符重载 3 左移运算符重载 lt lt 4 递增运算符重载 5 赋值运算符重载 6 关系运算符重载 7 函数调用运算符重载 1 运算符重载 对已有运算符重新进行定义 赋予其另一种功能 以适应不同的数据类型
  • 数组扁平化flat方法的多种实现

    let arr 1 2 3 4 5 6 7 8 9 10 11 12 1 flat console log arr flat Infinity 2 toString console log arr toString split map it
  • puppet配置

    作为重量级批量自动化运维利器 puppet可以方便大批量停止或启动服务 比如我们经常需在一下停止几十台 mysql服务器 使用puppet配置分分钟搞定 而不需要一台台去手动停止 非常方便 确认服务器端和客户端正常工作 开始编写module
  • LaTeX出现图片错误代码:Paragraph ended before \Gin@iii was complete.

    问题 LaTeX出现图片错误代码 Paragraph ended before Gin iii was complete 答案 将导言区的 usepackage graphics 替换为 usepackage graphicx
  • node.js中res.writeHead的用法总结

    向请求的客户端发送响应头 该函数在一个请求内最多只能调用一次 如果不调用 则会自动生成一个响应头 因为实际开发中 我们需要返回对应的中文以及对应的的文本格式 所以我们需要设置对应的响应头 响应头决定了对应的返回数据的格式以及编码格式 使用方
  • 机器学习实战第十章 k均值聚类

    k均值聚类 文章目录 k均值聚类 什么是k均值聚类 具体实现 二分k均值聚类 实验 小结 什么是k均值聚类 试想一下 如果给一张图如下 要求对这张图中的点分类 你会怎么进行呢 我们当然可以认为所有的点都只有一个种类 毕竟他们本身只有坐标不同
  • 2023华为OD机试真题【数组合并】

    题目内容 现在有多组整数数组 需要将他们合并成一个新的数组 合并规则 从每个数组里按顺序取出固定长度的内容合并到新的数组中 取完的内容会删除掉 如果该行不足固定长度或者已经为空 则直接取出剩余部分的内容放到新的数组中 继续下一行 如样例1
  • 数据挖掘中常用的数据清洗方法

    在数据挖掘过程中 数据清洗主要根据探索性分析后得到的一些结论入手 然后主要对四类异常数据进行处理 分别是缺失值 missing value 异常值 离群点 去重处理 Duplicate Data 以及噪音数据的处理 1 探索性分析 探索性分
  • windows 设置exe文件开机自启动

    设置本地exe服务文件开机自启动 编辑up bat 内容如下 注意 binPath 后面必须有一个空格 echo off sc create Test binPath C Users test exe start auto start C
  • SUSAN边缘检测

    核同值区 USAN 相对于模板的核 模板中总有一定的区域与它有相同的灰度 这部分区域称为USAN区域 当核像素处在图像中的灰度一致区域 USAN的面积最大 当核处在直边缘处面积约为最大值的一半 当核处在角点处时则为最大值的1 4 因此 使用
  • 洛谷 P1009 [NOIP1998 普及组] 阶乘之和

    题目链接 https www luogu com cn problem P1009 思路 计算阶乘相当于大整数 1 1 1 依次乘以 1 n 1 sim n
  • unity日记4(鼠标键盘交互、实例)

    目录 鼠标事件 鼠标点击 抬起 长按事件 键盘事件 键盘点击 抬起 长按事件 键盘键位替换 实例 鼠标 音乐播放 暂停 实例 调用其他对象的组件 双方法 实例 调整其他对象的公有参数 鼠标事件 鼠标点击 抬起 长按事件 左键0 右键1 中键
  • pytorch学习笔记 —— torch.nn.LSTM

    使用 torch nn LSTM 可以方便的构建 LSTM 不熟悉 LSTM 的可以先看这两篇文章 RNN https blog csdn net yizhishuixiong article details 105588233 LSTM