使用Pytorch的LSTM文本分类

2023-05-16

Photo by Christopher Gower on Unsplash
Christopher Gower在 Unsplash上的 照片

介绍 (Intro)

Welcome to this tutorial! This tutorial will teach you how to build a bidirectional LSTM for text classification in just a few minutes. If you haven’t already checked out my previous article on BERT Text Classification, this tutorial contains similar code with that one but contains some modifications to support LSTM. This article also gives explanations on how I preprocessed the dataset used in both articles, which is the REAL and FAKE News Dataset from Kaggle.

欢迎使用本教程! 本教程将教您如何在短短几分钟内构建用于文本分类的双向LSTM 。 如果您还没有签出我以前关于BERT文本分类的文章,那么本教程将包含与该文章相似的代码,但会进行一些修改以支持LSTM。 本文还提供了有关如何预处理这两篇文章中使用的数据集的说明,这是来自Kaggle 的REAL和FAKE News数据集

First of all, what is an LSTM and why do we use it? LSTM stands for Long Short-Term Memory Network, which belongs to a larger category of neural networks called Recurrent Neural Network (RNN). Its main advantage over the vanilla RNN is that it is better capable of handling long term dependencies through its sophisticated architecture that includes three different gates: input gate, output gate, and the forget gate. The three gates operate together to decide what information to remember and what to forget in the LSTM cell over an arbitrary time.

首先,什么是LSTM?为什么要使用它? LSTM代表长期短期记忆网络 ,它属于较大的神经网络类别,称为递归神经网络(RNN) 。 与香草RNN相比,它的主要优点是它具有复杂的体系结构,能够更好地处理长期依赖性,该体系结构包括三个不同的门:输入门,输出门和遗忘门。 这三个门共同操作,以决定在任意时间内在LSTM单元中要记住哪些信息和要忘记哪些信息。

LSTM Cell
LSTM电池

Now, we have a bit more understanding of LSTM, let’s focus on how to implement it for text classification. The tutorial is divided into the following steps:

现在,我们对LSTM有了更多的了解,让我们集中于如何为文本分类实现它。 本教程分为以下步骤:

  1. Preprocess Dataset

    预处理数据集
  2. Importing Libraries

    导入库
  3. Load Dataset

    加载数据集
  4. Build Model

    建立模型
  5. Training

    训练
  6. Evaluation

    评价

Before we dive right into the tutorial, here is where you can access the code in this article:

在我们直接学习本教程之前,您可以在这里访问本文中的代码:

  • Preprocessing of Fake News Dataset

    假新闻数据集的预处理

  • LSTM Text Classification Google Colab

    LSTM文本分类Google Colab

步骤1:预处理数据集 (Step 1: Preprocess Dataset)

The raw dataset looks like the following:

原始数据集如下所示:

Dataset Overview
数据集概述

The dataset contains an arbitrary index, title, text, and the corresponding label.

数据集包含任意索引,标题,文本和相应的标签。

For preprocessing, we import Pandas and Sklearn and define some variables for path, training validation and test ratio, as well as the trim_string function which will be used to cut each sentence to the first first_n_words words. Trimming the samples in a dataset is not necessary but it enables faster training for heavier models and is normally enough to predict the outcome.

对于预处理,我们导入Pandas和Sklearn并定义一些变量,用于路径,训练验证和测试比率,以及trim_string函数,该函数将每个句子剪切为第一个first_n_words单词。 修剪数据集中的样本不是必需的,但是它可以为较重的模型提供更快的训练,并且通常足以预测结果。

Next, we convert REAL to 0 and FAKE to 1, concatenate title and text to form a new column titletext (we use both the title and text to decide the outcome), drop rows with empty text, trim each sample to the first_n_words , and split the dataset according to train_test_ratio and train_valid_ratio. We save the resulting dataframes into .csv files, getting train.csv, valid.csv, and test.csv.

接下来,我们将REAL转换为0,将FAKE转换为1,将标题文本连接起来以形成新的列标题 文本 (我们使用标题和文本来确定结果),删除带有空文本的行,将每个样本修剪为first_n_words ,然后根据train_test_ratiotrain_valid_ratio分割数据集。 我们将结果数据帧保存到.csv文件中,获得train.csvvalid.csvtest.csv

步骤2:导入库 (Step 2: Importing Libraries)

We import Pytorch for model construction, torchText for loading data, matplotlib for plotting, and sklearn for evaluation.

我们导入Pytorch用于模型构建,torchText用于加载数据,matplotlib用于绘图,而sklearn用于评估。

步骤3:载入资料集 (Step 3: Load Dataset)

First, we use torchText to create a label field for the label in our dataset and a text field for the title, text, and titletext. We then build a TabularDataset by pointing it to the path containing the train.csv, valid.csv, and test.csv dataset files. We create the train, valid, and test iterators that load the data, and finally, build the vocabulary using the train iterator (counting only the tokens with a minimum frequency of 3).

首先,我们使用torchText为数据集中的标签创建一个标签字段,并为titletexttitletext创建一个文本字段。 然后,我们通过将TabularDataset指向包含train.csvvalid.csvtest.csv数据集文件的路径来构建它。 我们创建用于加载数据的训练迭代器,有效迭代器和测试迭代器,最后,使用训练迭代器构建词汇表(仅计算最小频率为3的令牌)。

步骤4:建立模型 (Step 4: Build Model)

We construct the LSTM class that inherits from the nn.Module. Inside the LSTM, we construct an Embedding layer, followed by a bi-LSTM layer, and ending with a fully connected linear layer. In the forward function, we pass the text IDs through the embedding layer to get the embeddings, pass it through the LSTM accommodating variable-length sequences, learn from both directions, pass it through the fully connected linear layer, and finally sigmoid to get the probability of the sequences belonging to FAKE (being 1).

我们构造了从nn.Module继承的LSTM类。 在LSTM内部,我们构造了一个Embedding层,然后是bi-LSTM层,最后是一个完全连接的线性层。 在Forward函数中,我们将文本ID穿过嵌入层以获取嵌入,将其穿过LSTM容纳可变长度序列,从两个方向进行学习,将其穿过完全连接的线性层,最后再通过Sigmoid来获得属于FAKE的序列的概率(为1)。

步骤5:训练 (Step 5: Training)

Before training, we build save and load functions for checkpoints and metrics. For checkpoints, the model parameters and optimizer are saved; for metrics, the train loss, valid loss, and global steps are saved so diagrams can be easily reconstructed later.

在训练之前,我们为检查点和指标构建保存和加载功能。 对于检查点,将保存模型参数和优化器; 对于度量,可以保存火车损耗,有效损耗和全局步长,以便以后可以轻松地重建图表。

We train the LSTM with 10 epochs and save the checkpoint and metrics whenever a hyperparameter setting achieves the best (lowest) validation loss. Here is the output during training:

我们用10个时期训练LSTM,并在超参数设置达到最佳(最低)验证损失时保存检查点和度量。 这是训练期间的输出:

The whole training process was fast on Google Colab. It took less than two minutes to train!

在Google Colab上,整个培训过程非常快捷。 培训不到两分钟!

Once we finished training, we can load the metrics previously saved and output a diagram showing the training loss and validation loss throughout time.

完成训练后,我们可以加载先前保存的指标,并输出一个图表,显示整个时间的训练损失和验证损失。

步骤6:评估 (Step 6: Evaluation)

Finally for evaluation, we pick the best model previously saved and evaluate it against our test dataset. We use a default threshold of 0.5 to decide when to classify a sample as FAKE. If the model output is greater than 0.5, we classify that news as FAKE; otherwise, REAL. We output the classification report indicating the precision, recall, and F1-score for each class, as well as the overall accuracy. We also output the confusion matrix.

最后,为了进行评估,我们选择了先前保存的最佳模型,并根据测试数据集对其进行了评估。 我们使用默认阈值0.5来决定何时将样本分类为FAKE。 如果模型输出大于0.5,我们将该新闻分类为FAKE;否则,将其分类为FAKE。 否则为REAL。 我们输出分类报告,指示每个类别的精度,召回率和F1得分以及整体准确性。 我们还输出混淆矩阵。

We can see that with a one-layer bi-LSTM, we can achieve an accuracy of 77.53% on the fake news detection task.

我们可以看到,使用双层Bi-LSTM,我们可以在假新闻检测任务上达到77.53%的准确性。

结论 (Conclusion)

This tutorial gives a step-by-step explanation of implementing your own LSTM model for text classification using Pytorch. We find out that bi-LSTM achieves an acceptable accuracy for fake news detection but still has room to improve. If you want a more competitive performance, check out my previous article on BERT Text Classification!

本教程分步说明了如何使用Pytorch为文本分类实现您自己的LSTM模型。 我们发现bi-LSTM在伪造新闻检测方面达到了可接受的准确性,但仍有改进的空间。 如果您想获得更具竞争力的性能,请查看我以前关于BERT文本分类的文章!

If you want to learn more about modern NLP and deep learning, make sure to follow me for updates on upcoming articles :)

如果您想了解有关现代NLP和深度学习的更多信息,请确保关注我以获取即将发表的文章的更新:)

翻译自: https://towardsdatascience.com/lstm-text-classification-using-pytorch-2c6c657f8fc0

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用Pytorch的LSTM文本分类 的相关文章

  • 将 CNN 的输出传递给 BILSTM

    我正在开发一个项目 其中我必须将 CNN 的输出传递给双向 LSTM 我创建了如下模型 但它抛出 不兼容 错误 请让我知道哪里出了问题以及如何解决这个问题 model Sequential model add Conv2D filters
  • Tensorflow:如何使用dynamic_rnn从LSTMCell获取中间细胞状态(c)?

    默认情况下 函数dynamic rnn仅输出隐藏状态 称为m 对于每个时间点可以通过如下方式获得 cell tf contrib rnn LSTMCell 100 rnn outputs tf nn dynamic rnn cell inp
  • 我可以使用逻辑索引或索引列表对张量进行切片吗?

    我正在尝试使用列上的逻辑索引对 PyTorch 张量进行切片 我想要与索引向量中的 1 值相对应的列 切片和逻辑索引都是可能的 但是它们可以一起吗 如果是这样 怎么办 我的尝试不断抛出无用的错误 类型错误 使用 ByteTensor 类型的
  • 在 Tensorflow 中检索 LSTM 序列的最后一个值

    我有不同长度的序列 想在 Tensorflow 中使用 LSTM 进行分类 对于分类 我只需要每个序列最后一个时间步长的 LSTM 输出 max length 10 n dims 2 layer units 5 input tf place
  • 如何获取基于Keras的LSTM模型中每个epoch的一层权重矩阵?

    我有一个基于 Keras 的简单 LSTM 模型 X train X test Y train Y test train test split input labels test size 0 2 random state i 10 X t
  • 在pytorch张量中过滤数据

    我有一个张量X like 0 1 0 5 1 0 0 1 2 0 我想实现一个名为的函数filter positive 它可以将正数据过滤成新的张量并返回原始张量的索引 例如 new tensor index filter positive
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • 使 CUDA 内存不足

    我正在尝试训练网络 但我明白了 我将批量大小设置为 300 并收到此错误 但即使我将其减少到 100 我仍然收到此错误 更令人沮丧的是 在 1200 个图像上运行 10 epoch 大约需要 40 分钟 有什么建议吗 错了 我怎样才能加快这
  • pytorch 中的 autograd 可以处理同一模块中层的重复使用吗?

    我有一层layer in an nn Module并在一次中使用两次或多次forward步 这个的输出layer稍后输入到相同的layer pytorch可以吗autograd正确计算该层权重的梯度 def forward x x self
  • BatchNorm 动量约定 PyTorch

    Is the 批归一化动量约定 http pytorch org docs master modules torch nn modules batchnorm html 默认 0 1 与其他库一样正确 例如Tensorflow默认情况下似乎
  • 将 Keras (Tensorflow) 卷积神经网络转换为 PyTorch 卷积网络?

    Keras 和 PyTorch 使用不同的参数进行填充 Keras 需要输入字符串 而 PyTorch 使用数字 有什么区别 如何将一个转换为另一个 哪些代码在任一框架中获得相同的结果 PyTorch 还采用参数 in channels o
  • Pytorch Tensor 如何获取元素索引? [复制]

    这个问题在这里已经有答案了 我有 2 个名为x and list它们的定义如下 x torch tensor 3 list torch tensor 1 2 3 4 5 现在我想获取元素的索引x from list 预期输出是一个整数 2
  • Pytorch“展开”等价于 Tensorflow [重复]

    这个问题在这里已经有答案了 假设我有大小为 50 50 的灰度图像 在本例中批量大小为 2 并且我使用 Pytorch Unfold 函数 如下所示 import numpy as np from torch import nn from
  • Tensorflow 的 LSTM 输入

    I m trying to create an LSTM network in Tensorflow and I m lost in terminology basics I have n time series examples so X
  • LSTM 批次与时间步

    我按照 TensorFlow RNN 教程创建了 LSTM 模型 然而 在这个过程中 我对 批次 和 时间步长 之间的差异 如果有的话 感到困惑 并且我希望得到帮助来澄清这个问题 教程代码 见下文 本质上是根据指定数量的步骤创建 批次 wi
  • 在Pytorch中计算欧几里得范数..理解和实现上的麻烦

    我见过另一个 StackOverflow 线程讨论计算欧几里德范数的各种实现 但我很难理解特定实现的原因 如何工作 该代码可以在 MMD 指标的实现中找到 https github com josipd torch two sample b
  • Fine-Tuning DistilBertForSequenceClassification:不是学习,为什么loss没有变化?权重没有更新?

    我对 PyTorch 和 Huggingface transformers 比较陌生 并对此尝试了 DistillBertForSequenceClassificationKaggle 数据集 https www kaggle com c
  • Caffe 的 LSTM 模块

    有谁知道 Caffe 是否有一个不错的 LSTM 模块 我从 russel91 的 github 帐户中找到了一个 但显然包含示例和解释的网页消失了 以前是http apollo deepmatter io http apollo deep

随机推荐

  • 各种类型的Writable

    各种类型的Writable xff08 Text ByteWritable NullWritable ObjectWritable GenericWritable ArrayWritable MapWritable SortedMapWri
  • C++ strtok的用法

    size 61 large align 61 center strtok的用法 align size 函数原型 xff1a char strtok char s char delim 函数功能 xff1a 把字符串s按照字符串delim进行
  • 读《遇见未知的自己》笔记

    为什么我不快乐 xff1f 为什么我不能拥有自己想要的生活 xff1f 此刻屏幕前的你 是否想过 xff0c 自己为什么会出现这种情况呢 xff1f 张德芬在 遇见未知的自己 一书给出了解释 xff1a 我们人类所有受苦的根源就是来自不清楚
  • PX4飞控问题汇总

    接触PX4飞控代码一年多了 xff0c 代码都是模块化 开发起来比APM的方便 xff0c 使用过程中也出现过各种怪异问题 xff0c 用的硬件是V5 nano 和V5 43 xff0c 测试的代码版本是1 9和1 10 今天总结一下遇到过
  • Sumo 搭建交叉路口交通流仿真平台

    Sumo安装 注意事项 xff1a 需要工具的使用需要环境变量的设置 需要包含文件Sumo安装路径下的bin和tools Sumo配置文件 Sumo中项目的配置文件的组成如下所示 节点文件 图 1 节点及边的拓扑图 Node的属性主要有id
  • OpenWRT 各种烧录方式及量产(三)

    界面烧录 不更新uboot 电脑连接WIFI xff08 或者通过网线连接电脑与路由器 xff09 通过浏览器访问路由器管理界面 xff0c 进行升级 注意不要断电 xff01 xff01 xff01 xff08 断电只能通过tftp方式恢
  • 华为手机root

    首先手机已解锁 xff42 xff4c 此方法针对 华为手机 可使用 xff0c 其他手机没有测试 xff0c 但应该也可以 官方的twrp没有对mate xff19 进行配适 xff0c 可以使用奇兔 twrp 提取码 ax6d 如果你没
  • 阿里云ubuntu 16.04 Server配置方案 2 远程控制桌面

    通过远程控制 xff0c 更好的管理服务器 1 XRDP远程控制 为了更好的远程管理 xff0c linux一般情况都用VNC进行远程连接 xff0c 如 TightVNC X11VNC ReadVNC等 Xrdp 是开放原始码的远端桌面通
  • 自顶向下(top down)简介

    无论是在实际生活中还是在学术问题上 xff0c 复杂的问题比比皆是 xff0c 当我们对此类问题毫无头绪的时候 xff0c 自顶向下 xff08 top down xff09 为我们提供了一种可靠的解决方法 自顶向下法将复杂的大问题分解为相
  • SecureCRT图形界面(通过设置调用Xmanager - Passive程序)

    首先 xff0c 在服务器进行设置 如果服务器是图形化界面启动的 xff0c xhost 43 命令可以不用执行 root 64 test xhost 43 xhost unable to open display 34 34 设置disp
  • 一种GPS辅助的多方位相机的VIO——Slam论文阅读

    34 A GPS aided Omnidirectional Visual Inertial State Estimator in Ubiquitous Environments 34 论文阅读 这里写目录标题 34 A GPS aided
  • docker & LXC

    目录 一 LXC1 了解Docker的前生LXC2 LXC与docker的关系3 与传统虚拟化对比4 LXC部署4 1 安装LXC软件包和依赖包4 2 启动服务4 3 创建虚拟机 5 LXC常用命令 二 doker1 什么是docker2
  • curl命令总结

    curl no cache d Users Administrator Desktop curl 7 73 0 3 win64 mingw bin gt curl Iv http abc gkmang cn 8081 index php l
  • 使用FastJSON 对Map/JSON/String 进行互转

    前言 Fastjson是一个Java语言编写的高性能功能完善的JSON库 xff0c 由阿里巴巴公司团队开发的 1 主要特性 高性能 fastjson采用独创的算法 xff0c 将parse的速度提升到极致 xff0c 超过所有json库
  • ai面向分析_2020年面向企业的顶级人工智能平台

    ai面向分析 In the long term artificial intelligence and automation are going to be taking over so much of what gives humans
  • 回答问题人工智能源码_回答21个最受欢迎的人工智能问题

    回答问题人工智能源码 Artificial intelligence sets the stage for a new era of solutions to be made with computers It allows us to s
  • 人工智能药物设计_用AI革新药物安全

    人工智能药物设计 介绍 Introduction Advances in the life sciences have brought about a transformative impact on healthcare with lif
  • 数据集分为训练验证测试_将数据集分为训练集,验证集和测试集

    数据集分为训练验证测试 测试我们的模型 Testing Our Model Supervised machine learning algorithms are amazing tools capable of making predict
  • 深度学习 场景识别_使用深度学习进行自然场景识别

    深度学习 场景识别 Recognizing the environment in one glance is one of the human brain s most accomplished deeds While the tremen
  • 使用Pytorch的LSTM文本分类

    Photo by Christopher Gower on Unsplash Christopher Gower在 Unsplash上的 照片 介绍 Intro Welcome to this tutorial This t