Pytorch实战笔记(1)——BiLSTM 实现情感分析

2023-10-27

本文展示的是使用 Pytorch 构建一个 BiLSTM 来实现情感分析。本文的架构是第一章详细介绍 BiLSTM,第二章粗略介绍 BiLSTM(就是说如果你想快速上手可以跳过第一章),第三章是核心代码部分。

1. BiLSTM的详细介绍

坦白的说,其实我也不懂 LSTM,但是我这里还是尽我最大的可能解释这个模型。这里我就盗个图 [1](懒得自己画了,而且感觉好像他也是盗的李宏毅老师课件的图)。
LSTM
简单来说,LSTM 在每个时刻的输入都是由该时刻输入的序列信息 X t X^t Xt 与上一时刻的隐藏状态 h t − 1 h^{t-1} ht1 通过四种不同的非线性变化映射而成,分别为:

  1. 遗忘门控信号:遗忘门控信号 z f z^f zf 的计算公式如下:
    z f = s i g m o i d ( W f [ X t ; h t − 1 ] ) , z^f = {\rm sigmoid}(W^f\left[ X^t; h^{t-1} \right]), zf=sigmoid(Wf[Xt;ht1]),
    其中, [ X t ; h t − 1 ] [X^t;h^{t-1}] [Xt;ht1] 是将 X t X^t Xt h t − 1 h^{t-1} ht1 拼接起来; W f W^f Wf 是权重; S i g m o i d ( ⋅ ) {\rm Sigmoid}(\cdot) Sigmoid() 是 Sigmoid 激活函数,用于将数据映射到 (0, 1) 的区间范围内。
  2. 记忆门控信号:记忆门控信号 z i z^i zi 的计算公式如下:
    z i = s i g m o i d ( W i [ X t ; h t − 1 ] ) . z^i={\rm sigmoid}(W^i\left[ X^t; h^{t-1} \right]). zi=sigmoid(Wi[Xt;ht1]).
  3. 输出门控信号:输出门控信号 z o z^o zo 的计算公式如下:
    z o = s i g m o i d ( W o [ X t ; h t − 1 ] ) . z^o = {\rm sigmoid}(W^o\left[ X^t; h^{t-1} \right]). zo=sigmoid(Wo[Xt;ht1]).
  4. 当前时刻的信息:当前时刻的信息 z z z 的计算公式如下:
    z = t a n h ( W [ X t ; h t − 1 ] ) , z = {\rm tanh}(W\left[ X^t; h^{t-1} \right]), z=tanh(W[Xt;ht1]),
    其中, t a n h ( ⋅ ) {\rm tanh}(\cdot) tanh() 是将数据放缩到 (-1, 1) 的区间内。

通过以上的公式,我们可以发现, z f , z i , z o z^f, z^i, z^o zf,zi,zo 都是 (0, 1) 区间的值,而 z z z(-1, 1) 区间的值。

接着就是 LSTM 的内部计算公式,即图上所示的那几个,分别为:

  1. 当前时刻的细胞状态 c t c^t ct 的计算公式如下:
    c t = z f ⊙ c t − 1 + z i ⊙ z , c^t = z^f \odot c^{t-1} + z^i \odot z, ct=zfct1+ziz,
    其中, ⊙ \odot 是哈达玛积,即矩阵元素对位相乘,但是需要注意的是,哈达玛积数学上不可解释,但是跑出来效果好
  2. 当前时刻的隐藏状态 h t h^t ht 的计算公式如下:
    h t = z o ⊙ t a n h ( c t ) . h^t = z^o \odot {\rm tanh} (c^t). ht=zotanh(ct).
  3. 当前时刻的输出 y t y^t yt 的计算公式如下:
    y t = σ ( W ′ h t ) . y^t = \sigma (W'h^t). yt=σ(Wht).

公式列举完后,这里说一下我对这些公式的理解(不一定是对的哈)。

  • 首先是 c t c^t ct 的计算。我们看到 c t c^t ct 的计算分为了两部分。一部分是 z f ⊙ c t − 1 z^f \odot c^{t-1} zfct1,这一部分是 LSTM 的遗忘过程,由于刚刚提到, z f z^f zf 是 (0, 1) 区间范围内的值,同时,sigmoid 函数是一个无限趋近于 0 或者 1 的函数,也就是说, c t − 1 c^{t-1} ct1 无论怎样,都会有些数据被遗弃,始终不会完全保留下来,这也就模拟了一个遗忘的过程。同理,对于记忆部分 z i ⊙ z z^i \odot z ziz,这一步也是只会保留部分 z z z 的信息,也就模拟了人的记忆是由些许失真的过程。同时,两者相加后,那么就代表了当前细胞状态 c t c^t ct 中保留的是没有被遗忘掉的过去的信息和当前时刻被记忆下来的信息
  • 接着是 h t h^t ht 的计算。首先是为什么要先对 c t c^t ct 做一次 t a n h ( ⋅ ) {\rm tanh}(\cdot) tanh(),这是因为由于 c t c^t ct 的区间范围不是 (-1, 1),因为 z i ⊙ z z^i \odot z ziz 的区间范围是 (-1, 1),再与 z f ⊙ c t − 1 z^f \odot c^{t-1} zfct1 相加,那么 c t c^t ct 的范围就有可能超出 (-1, 1),所以先用一个 tanh 将数值给放缩到 (-1, 1) 内。接着再与 z o z^o zo 做一次哈达玛积后,得到的隐藏状态就是 (-1, 1) 的数据,那么该数据放到后续模块中,就可以代表当前时刻的输入是正的还是负的,同时有多大。
  • 最后就是 y t y^t yt 的计算,实际上这就是个全连接层,将隐藏状态进行一次映射,再通过一个非线性变化的激活函数。

2. BiLSTM 的简单介绍

当然,其实你没看懂上面的部分也不重要,从使用的角度上来讲,会用就行了,就像你用手机,你不会去搞懂里面每个元器件是怎么做出来的,每个 APP 是怎么写出来的;就像你去打篮球,也不用梳个中分,穿个背带裤。

那么对于 BiLSTM,你需要了解的是什么?

  • 首先,这是一个序列模型,它接受一个序列的输入,并且输出这个序列的信息。对于序列中每个位置的输出,它会包含该位置的信息以及之前的信息。就是说 LSTM 能够捕获到位置 t t t 及其之前位置的信息。而对于 BiLSTM 的话,则能捕获到 t t t 的双向信息。
  • 如果是 BiLSTM,它的每个位置的输出,是前向 L S T M → \overrightarrow{LSTM} LSTM 的输出 y → \overrightarrow{y} y 与反向 L S T M ← \overleftarrow{LSTM} LSTM 的输出 y ← \overleftarrow{y} y 拼接在一起的, [ y → ; y ← ] [\overrightarrow{y}; \overleftarrow{y}] [y ;y ]。所以假设你设置 LSTM 的隐藏层维度为 128,那么单向 LSTM 的输出维度是 128,但是双向就是 256 (128*2).
  • 但是虽然说 LSTM 好像大概可能也许 maybe possibly 能够捕获长距离依赖信息哈,毕竟 LSTM 的全称都是 Long Short-Term Memory,但是实际上这是 LSTM 的骗局,LSTM 并没有捕获长距离依赖信息的能力!LSTM 并没有捕获长距离依赖信息的能力!LSTM 并没有捕获长距离依赖信息的能力! 从数学上说,你经过这么多次 sigmoid,还能保留个啥?当然,在《An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling》这篇论文[2]中,作者用了大量的实验来说明了,LSTM 不仅并行计算能力差(因为要上一个时间步的信息才能计算下一个时间步,所以 LSTM 不是个并行系统),同时在它最吹嘘的长距离信息捕获能力上,都不如 CNN,所以以后在跑实验的时候,可以尝试使用 TextCNN 来试试,说不定效果比 BiLSTM 好(反正我做过的实验中 TextCNN 性能一般比 BiLSTM 高8-10个点)。

3. BiLSTM 实现情感分析

在本博客中仅介绍模型部分,详细代码见 github。

模型图如图所示:
模型图
具体而言,就是输入序列输入到一个双向 LSTM 中,并将双向 LSTM 的最后一个隐藏状态(即句向量)输入到一个全连接层(也可以说是分类器)中,输出最后的分类结果,具体模型的代码如下:

import torch.nn as nn

class BiLSTM_SA(nn.Module):

    def __init__(self, embed, config):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embed, freeze=False)
        self.LSTM = nn.LSTM(config.embed_size, config.lstm_hidden_size,
                            num_layers=config.num_layers, batch_first=True,
                            bidirectional=True)
        # 因为是双向 LSTM, 所以要乘2
        self.ffn = nn.Linear(config.lstm_hidden_size * 2,
                             config.dense_hidden_size)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(config.dense_hidden_size,
                                    config.num_outputs)

    def forward(self, inputs):
        # shape: (batch_size, max_seq_length, embed_size)
        embed = self.embedding(inputs)
        # shape: (batch_size, max_seq_length, lstm_hidden_size * 2)
        lstm_hidden_states, _ = self.LSTM(embed)
        # LSTM 的最后一个时刻的隐藏状态, 即句向量
        # shape: (batch, lstm_hidden_size * 2)
        lstm_hidden_states = lstm_hidden_states[:, -1, :]
        # shape: (batch, dense_hidden_size)
        ffn_outputs = self.relu(self.ffn(lstm_hidden_states))
        # shape: (batch, num_outputs)
        logits = self.classifier(ffn_outputs)

        return logits

全连接层我采用了两个全连接层,一个将维度从 256 压缩到 128,另外一个是分类器。

这里有个小细节要注意一下,通常在论文的公式里面,我们都会看到别人写的分类器的公式如下: y ^ = S o f t m a x ( W h + b ) \hat{y} = {\rm Softmax}(Wh+b) y^=Softmax(Wh+b),有个 softmax 的激活函数,但是在 pytorch 中实际不需要,就比如我代码里面是写的:

logits = self.classifier(ffn_outputs)

而不是:

y_hat = self.softmax(self.classifier(ffn_outputs))

这是因为如果你后面选用交叉熵作为损失函数,而且调用的是torch中的 nn.CrossEntropyLoss(),那么就没必要在输出的时候用 softmax,这是因为 nn.CrossEntropyLoss() 中自带有 softmax 操作,虽然这样对你的分类结果不会产生任何影响,但是你得损失会变得很大。

最后的测试集的实验结果为:

test loss 0.419664 | test accuracy 0.813760 | test precision 0.804267 | test recall 0.829360 | test F1 0.816621

参考

[1] 陈诚. 人人都能看懂的LSTM[EB/OL]. https://zhuanlan.zhihu.com/p/32085405, 2018
[2] Shaojie Bai, J. Zico Kolter, Vladlen Koltun. An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling [EB/OL]. https://arxiv.org/abs/1803.01271, 2018

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch实战笔记(1)——BiLSTM 实现情感分析 的相关文章

随机推荐

  • 嵌入式数据结构(查找)(哈希表)

    嵌入式自学第十二天 1 2 代码实现 list c define CRT SECURE NO WARNINGS include list h include stdlib h include string h include stdio h
  • 2023版ChatGPT 能用来帮助谈恋爱吗,如果用 ChatGPT 来谈恋爱会发生什么?

    大家好啊 有没有和ChatGPT聊过天的 5G高手 们呢 ChatGPT是美国AI公司OpenAI开发出来的一款人工智能聊天机器 会通过学习和理解自然语言来跟我们聊天 不管你想聊啥 从诗歌到哲学 它都可以让你感觉像在跟一个超水平牛逼闪耀的老
  • 静态Web服务器-返回固定页面数据

    学习目标 能够写出组装固定页面数据的响应报文 1 开发自己的静态Web服务器 实现步骤 编写一个TCP服务端程序 获取浏览器发送的http请求报文数据 读取固定页面数据 把页面数据组装成HTTP响应报文数据发送给浏览器 HTTP响应报文数据
  • RISC-V IDE MRS使用笔记(九):使用WCH-LinkW实现无线下载、调试

    RISC V IDE MRS使用笔记 九 使用WCH LinkW实现无线下载 调试 1 硬件环境 WCH LinkW无线仿真调试器2块 CH32V307开发板1块 2 软件环境 MRS V185版本 3 无线仿真调试器配对与连接 通过WCH
  • Beam技术

    一 简介 在大数据处理中 流计算技术包括Storm Spark Streaming和Flink 实际应用中还包括Storm Trident Samura以及Google MillWhell和亚马逊的Kinesis等技术 离线处理基本上都基于
  • vue3 ts页面赋值发现不生效

    在 onMounted 周期中 dataForm value route params data 赋值不生效
  • MOOC浙大数据结构课后题记录——PTA数据结构题目集(全)

    目录 第一周 最大子列和算法 二分查找 01 复杂度1 最大子列和问题 20分 01 复杂度2 Maximum Subsequence Sum 25分 01 复杂度3 二分查找 20分 第二周 线性结构 02 线性结构1 两个有序链表序列的
  • element 中 表格设置滚动条

    element 中 表格设置滚动条 表格设置滚动条 1 使用header 直接设个表格的高度 就会为表填加上表格 2 表格自定义使用css样式添加滚动条 样式一 deep el table body wrapper height 200px
  • 进行页面跳转时,不将请求参数显示在url的方法

    在SSM项目中 ajax不能实现跳转 反正我是不知道 href会将传参显示在url上 但有些人不想在页面跳转时 将传参显示在url中 反正我不想 就比如这种 有以下两种方法 将传参数放在session中 用js创建动态form表单 页面跳转
  • html2canvas给指定区域添加满屏水印

    效果图如下 直接贴上代码 下载插件 npm i html2canvas
  • 劲爆!java架构师百度网盘

    第一份资料 Kafka实战笔记 Kafka入门 为什么选择Kafka Karka的安装 管理和配置 Kafka的集群 第一个Kafka程序 afka的生产者 Kafka的消费者 深入理解Kafka 可靠的数据传递 Spring和Kalka的
  • 2021-06-20

    conda换源后安装包报错 只搜索第一个源 为了安装qiskit包 首先给conda增加了多个源 如下图 而后在安装qiskit包时 conda报错 但是提示只有第一个404 其他源没有提示 所以问题是 conda安装时是否遍历了所有已添加
  • 英文数字汇总,KMGT,毫微纳

    以5MB为例 现在的习惯是读作 五兆 可是 兆的本意是万亿 在这里却成了百万 5MB的标准读法应该是 五百万字节 网络的带宽 100M 常读作 一百兆 若读作 一百百万 会有人反对 可1000 不是也读作 一千千米 吗 还有气象预报的 五百
  • 讯飞星火大模型申请及测试:诚意满满

    大家好 我是可夫小子 关注AIGC 读书和自媒体 解锁更多ChatGPT AI绘画玩法 加 keeepdance 备注 chatgpt 拉你进群 最近国产大模型跟下饺子似 隔几天就发布一个 厂家发布得起劲 大家看多了也麻木了 而且无一例外都
  • 计算图像帧的平均灰度值

    2016 7 15 在处理视频中 需要对视频流中的图像帧进行区分 分离出其中的亮暗帧图像 区分亮暗图像 是依据图像的平均灰度值来实现的 我们知道 对于一幅灰度图像 每个像素点的灰度值可以通过指针来访问 i j 处的灰度值 img gt im
  • 运行项目出现java.lang.ClassNotFoundException: org.springframework.web.util.IntrospectorCleanupListener

    java lang ClassNotFoundException org springframework web util IntrospectorCleanupListener at org apache catalina loader
  • spyder 出现ValueError: PyCapsule_GetPointer called with incorrect name

    我太难了 经过一堆试验 终于出了坑 总的来说 1 卸载pyqt5 命令 pip uninstall pyqt5 2 重新安装低版本的pyqt5 命令 pip install PyQt5 5 10 1 如果出现pip vendor urlli
  • meta-compilation

    RPython GraalVM 转载于 https my oschina net crcc blog 2239743
  • k8s的pv和pvc创建

    NFS使用PV和PVC 1 配置nfs存储 2 定义PV 实现 下图的pv和pvc测试 pv的定义 这里定义5个PV 并且定义挂载的路径以及访问模式 还有PV划分的大小 vim pv yaml apiVersion v1 kind Pers
  • Pytorch实战笔记(1)——BiLSTM 实现情感分析

    本文展示的是使用 Pytorch 构建一个 BiLSTM 来实现情感分析 本文的架构是第一章详细介绍 BiLSTM 第二章粗略介绍 BiLSTM 就是说如果你想快速上手可以跳过第一章 第三章是核心代码部分 目录 1 BiLSTM的详细介绍