基于Pytorch实现LSTM(多层LSTM,双向LSTM)进行文本分类

2023-10-27

LSTM原理请看这:点击进入

LSTM:

在这里插入图片描述

nn.LSTM(input_size, 
		hidden_size, 
		num_layers=1, 
		nonlinearity=tanh, 
		bias=True, 
		batch_first=False, 
		dropout=0, 
		bidirectional=False)


input_size:表示输入 xt 的特征维度
hidden_size:表示输出的特征维度
num_layers:表示网络的层数
nonlinearity:表示选用的非线性激活函数,默认是 ‘tanh’
bias:表示是否使用偏置,默认使用
batch_first:表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
dropout:表示是否在输出层应用 dropout
bidirectional:表示是否使用双向的 LSTM,默认是 False。

import  torch
from    torch import nn
from    torch.nn import functional as F 
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,bidirectional=False)
# 可理解为一个字串长度为5, batch size为3, 字符维度为10的输入
input_tensor  = torch.randn(5, 3, 10)
# 两层LSTM的初始H参数,维度[layers, batch, hidden_len]
 #在lstm中c和h是不一样的,而RNN中是一样的
h0,c0 = torch.randn(2,3, 20),torch.randn(2,3, 20)
# output_tensor最后一层所有的h输出, hn表示两层最后一个时序的输出, cn表示两层最后一个时刻的状态的输出
output_tensor, (hn,cn) =lstm(input_tensor, (h0,c0))
print(output_tensor.shape, hn.shape,cn.shape)

torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

从上面可以看到输出的h,x,和输入的h,x维度一致。
上面的参数中,num_layers=2相当于有两个rnn cell串联,即一个的输出h作为下一个的输入x。也可单独使用两个nn.LSTMCell实现

而当我们设置成双向LSTM时,即bidirectional=True

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,bidirectional=True) 
h0,c0 = torch.randn(4,3, 20),torch.randn(4,3, 20)

torch.Size([5, 3, 40]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])

一共5个时刻,可以看到最后一时刻的output维度是[3, 40],因为nn.LSTM模块他在最后会将正向和反向的结果进行拼接concat。而hn中的4是指正反向,还有因为num_layers是两层所以为4。

output_tenso只输出最后一层!!!的所有时刻的状态输出(且正向和反向拼接好了。而hn和cn包含所有层,所有方向的最后时刻的输出。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于Pytorch实现LSTM(多层LSTM,双向LSTM)进行文本分类 的相关文章

  • 将静态数据(不随时间变化)添加到 LSTM 中的序列数据

    我正在尝试建立一个如下图所示的模型 请看下图 我想在 LSTM 层中传递序列数据 在另一个前馈神经网络层中传递静态数据 血型 性别 后来我想将它们合并 然而 我对这里的维度感到困惑 如果我的理解是正确的 如图所示 5维序列数据如何与4维静态
  • 为什么我的 keras LSTM 模型陷入无限循环?

    我正在尝试构建一个小型 LSTM 它可以通过在现有 Python 代码上进行训练来学习编写代码 即使是垃圾代码 我已将数百个文件中的数千行代码连接到一个文件中 每个文件以
  • 如何在 python-gensim 中使用潜在狄利克雷分配(LDA)来抽象二元组主题而不是一元组?

    LDA 原始输出 一元语法 主题1 水肺 水 蒸汽 潜水 主题2 二氧化物 植物 绿色 碳 所需输出 二元组主题 主题1 水肺潜水 水蒸气 主题2 绿色植物 二氧化碳 任何想法 鉴于我有一个名为docs 包含文档中的单词列表 我可以使用 n
  • 如何使用 python 中的 spacy 库将句子转换为问题 [请参阅下面的我的代码进行更正]

    我需要使用 python 中的 spacy 将任何句子转换为问题 我下面的代码太长了 我需要做更多的工作才能将任何句子完成为问题格式 现在在这段代码中我根据以下条件制定条件是形式 需要形式 有形式 做形式通过检查过去时和现在时 输入 尼娜拉
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • word2vec中单词的向量代表什么?

    word2vec https code google com p word2vec 是 Google 的开源工具 它为每个单词提供一个浮点值向量 它们到底代表什么 还有一篇论文关于段落向量 http cs stanford edu quoc
  • 在Python或Sklearn中用整数值对具有字符串值的列变量进行编码

    如何用整数值对数据表中字符串类型的列值进行编码 例如 我有两个特征变量 颜色 可能的字符串值 R G 和 B 和技能 可能的字符串值 C Java SQL 和 Python 给定数据表有两列 Color gt R G B B G R B G
  • SGDClassifier 每次为文本分类提供不同的准确度

    我使用 SVM 分类器将文本分类为好文本和乱码 我正在使用 python 的 scikit learn 并按如下方式执行 Created on May 5 2017 import re import random import numpy
  • IOB 准确度和精密度之间的差异

    我正在使用命名实体识别和分块器对 NLTK 进行一些工作 我使用重新训练了分类器nltk chunk named entity py为此 我采取了以下措施 ChunkParse score IOB Accuracy 96 5 Precisi
  • 如何有效计算文档流中文档之间的相似度

    我收集文本文档 在 Node js 中 其中一个文档i表示为单词列表 考虑到新文档以文档流的形式出现 计算这些文档之间相似性的有效方法是什么 我目前对每个文档中单词的归一化频率使用余弦相似度 我不使用 TF IDF 词频 逆文档频率 因为我
  • scikit加权f1分数计算及使用

    我有一个关于weightedsklearn metrics f1 score 中的平均值 sklearn metrics f1 score y true y pred labels None pos label 1 average weig
  • Keras:嵌入 LSTM

    在 LSTM 的 keras 示例中 用于对 IMDB 序列数据进行建模 https github com fchollet keras blob master examples imdb lstm py https github com
  • 使用我自己的训练示例训练 spaCy 现有的 POS 标记器

    我正在尝试在我自己的词典上训练现有的词性标注器 而不是从头开始 我不想创建一个 空模型 在spaCy的文档中 它说 加载您想要统计的模型 下一步是 使用add label方法将标签映射添加到标记器 但是 当我尝试加载英文小模型并添加标签图时
  • 生成易于记忆的随机标识符

    与所有开发人员一样 我们在日常工作中不断处理某种标识符 大多数时候 它与错误或支持票有关 我们的软件在检测到错误后 会创建一个包 该包的名称由时间戳和版本号格式化 这是创建合理唯一标识符以避免混淆包的一种廉价方法 例子 错误报告 20101
  • 语音识别中如何处理同音词?

    对于那些不熟悉什么是同音字 https en wikipedia org wiki Homophone是的 我提供以下示例 我们的 是 嗨和高 到 太 二 在使用时语音API https developer apple com docume
  • 用于估计(一元)困惑度的 NLTK 包

    我正在尝试计算我所拥有的数据的困惑度 我正在使用的代码是 import sys sys path append usr local anaconda lib python2 7 site packages nltk from nltk co
  • 将复数名词转换为单数名词

    如何使用 R 将复数名词转换为单数名词 我使用 tagPOS 函数来标记每个文本 然后提取所有标记为 NNS 的复数名词 但是如果我想将这些复数名词转换为单数该怎么办 library openNLP library tm acq o lt
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下
  • 如何在R中使用OpenNLP获取POS标签?

    这是 R 代码 library NLP library openNLP tagPOS lt function x s lt as String x word token annotator lt Maxent Word Token Anno
  • 将 Dropout 与 Keras 和 LSTM/GRU 单元结合使用

    在 Keras 中 您可以像这样指定 dropout 层 model add Dropout 0 5 但对于 GRU 单元 您可以将 dropout 指定为构造函数中的参数 model add GRU units 512 return se

随机推荐

  • 【C++】_4.内存分布

    目录 1 C C 内存分布 2 C语言的动态内存管理方式 3 C 内存管理方式 3 1 new delete 操作内置类型 3 2 new delete 操作自定义类型 4 operator new 与operator delete函数 5
  • RuntimeError: cuda runtime error (30)解决

    程序出错如上 而且总是伴随着黑屏 一开始以为是cuda跑出问题 而且该问题必须重启才能解决 但是一直很好奇我的电脑Ubuntu18 04设置了黑白屏从不 还是出现该错误 最后为了复现该错误就强制锁屏 果然错误复现 找到原因之后就可以比较好解
  • vue前端使用Docker部署

    在上一篇文章中 我们介绍了如果在CentOS上安装docker环境 本文则是介绍docker的具体项目实践 主要介绍如果通过docker容器来部署vue前端项目 本文需要基于vue项目已经开发完成 并且docker环境已经准备好 思路是Do
  • SQL基础(1)

    1 Where条件语句 使用Were语句指定搜索条件过滤返回的数据 用于提取满足指定条件 语法 select b Sid b Sname a score from sc a join Student b on a Sid b Sid whe
  • ChatGPT和智能化能源:如何应用于能源领域的智能化生产和能源管理?

    Chatgpt Chat Gpt 小智Ai Chat小智 Gpt小智 ChatGPT小智Ai GPT小智 GPT小智Ai Chat小智Ai 丨 随着社会的发展和工业化的进程 能源需求不断增加 如何实现能源的高效 低碳 安全 可持续发展成为了
  • python数据分析绘图

    ROC AUC曲线 分类模型 混淆矩阵 混淆矩阵中所包含的信息 True negative TN 称为真阴率 表明实际是负样本预测成负样本的样本数 预测是负样本 预测对了 False positive FP 称为假阳率 表明实际是负样本预测
  • Qt:Drag-Drop操作在QGraphicsView及Model/View框架下的实现

    最近使用到Qt的Drag Drop功能 结合自己的例子写下来给大家分享一下 实现从QTreeView拖动Node到QGraphicsView上 以及QGraphicsView上item之间的拖动 先来说Model View中的实现 1 Mo
  • WPF的TChart控件使用---添加直线---标题勾选---提示

    TChart1 Aspect View3D false 控件3D效果 Steema TeeChart WPF Styles Line line1 new Steema TeeChart WPF Styles Line 直线 line1 Ti
  • 微信公众号 Jssdk调用错误码:63002, 获取access_token错误代码 errcode 40164的解决方法,如何解决,微信公众号的坑。

    今晚在开发公众号 需要调用到 Jssdk 结果配置好了 一运行就提示 Errmsg config fail Error 系统错误 错误码 63002 invalid signature 20200108 00 04 41 我的心突然就好慌
  • HTTPS理论基础

    目录 HTTPS原理 密码学基础 HTTPS通信过程 数字证书 本文链接 https blog csdn net iispring article details 51615631 HTTPS原理 我们知道 HTTP请求都是明文传输的 所谓
  • POSTGIS教程

    一 什么是PostgreSQL和PostGIS 1 1 什么是PostgreSQL 说起数据库 大家耳熟能详的商业数据库产品当推Oracle 微软的SqlServer和IBM的 DB2等 而开源数据库中则有两大产品MySQL和Postgre
  • 数据结构PTA 案例6-1.4 地下迷宫探索

    案例6 1 4 地下迷宫探索 题目 解法 题目 假设有一个地下通道迷宫 它的通道都是直的 而通道所有交叉点 包括通道的端点 上都有一盏灯和一个开关 请问你如何从某个起点开始在迷宫中点亮所有的灯并回到起点 输入格式 输入第一行给出三个正整数
  • Linux内核中断系统结构——软中断

    在 Linux异常 中断 处理体系结构 这篇文章 我们详细描写了内核如何进行中断 异常 向量表的初始化 如何初始化硬件中断 IRQ 的操作 在这篇文章中 我们将重心放在软件中断上 也就是 CPU 本身的中断 这篇文章包括五个内容 软中断 t
  • 当电桥为恒流源时惠斯通电桥电压的计算方法

    http wenku baidu com link url S55C CbY IQBl7oqgICODIz765KasqscVU2ACb6xV1OJB1zhLWwvryumayUWtB7V0b3 uHiclyhZtHHMfejUVFuYfd
  • 大数据毕设 - 校园卡数据分析与可视化(python 大数据)

    文章目录 0 前言 1 课题介绍 2 数据预处理 2 1 数据清洗 2 2 数据规约 3 模型建立和分析 3 1 不同专业 性别的学生与消费能力的关系 3 2 消费时间的特征分析 4 Web系统效果展示 5 最后 0 前言 Hi 大家好 这
  • python和opencv利用摄像头进行视频捕获

    python容易上手 利用opencv进行视频录制及后期的人脸识别 都是比较简单易上手的方案 工具 python3 10 opencv4 54 平台 win10 vscode 摄像头捕获程序 import cv2 as cv cap cv
  • Arduino从零开始(2)——控制舵机与步进电机

    0 前言 本文主要介绍通过Arduino控制舵机 步进电机以及循环的使用 目录 0 前言 1 介绍 2 Arduino控制舵机 2 1方法一 2 2方法二 3 Arduino控制步进电机 1 介绍 对于Arduino控制舵机的方法是通过其输
  • 做方差分析需要正态性检验吗_方差分析(SPSS版)

    方差分析 SPSS版 原创 Gently spss学习乐园 2019 10 15 文章同步于 微信公众号 SPSS学习乐园 方差分析 SPSS版 方差分析的基本思想 R A Fisher提出的统计理论基础 将总变异分解为由研究因素所产生的变
  • 计算机系统结构:流水线技术总结

    文章目录 什么是流水线 流水线的分类 流水线的性能指标 流水线设计中的若干问题 非线性流水线的调度 单功能非线性流水线的最优调度 多功能非线性流水线的调度 一条经典的5段流水线 相关与流水线冲突 结构冲突 因硬件资源满足不了指令重叠执行的要
  • 基于Pytorch实现LSTM(多层LSTM,双向LSTM)进行文本分类

    LSTM原理请看这 点击进入 LSTM nn LSTM input size hidden size num layers 1 nonlinearity tanh bias True batch first False dropout 0
Powered by Hwhale