【Pytorch】利用Pytorch+GRU实现情感分类(附源码)

2023-10-30

在这个实验中,数据的预处理过程以及网络的初始化及模型的训练等过程同前文《利用Pytorch+LSTM实现中文新闻分类》,具体这里就不再重复解释了。如果有读者在对数据集的预处理过程中有疑问,请参考我的其他博客,里面对这些方法均有我的一些个人体会,这里直接贴上源码。

## 导入本章所需要的模块
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import time
import copy


import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchtext import data
from torchtext.vocab import Vectors
## 使用torchtext库进行数据准备
# 定义文件中对文本和标签所要做的操作

## 定义文本切分方法,直接使用空格切分即可
mytokenize = lambda x: x.split()
TEXT = data.Field(sequential=True, tokenize=mytokenize, 
                  include_lengths=True, use_vocab=True,
                  batch_first=True, fix_length=200)
LABEL = data.Field(sequential=False, use_vocab=False, 
                   pad_token=None, unk_token=None)
## 对所要读取的数据集的列进行处理
train_test_fields = [
    ("label", LABEL), # 对标签的操作
    ("text", TEXT) # 对文本的操作
]
## 读取数据
traindata,testdata = data.TabularDataset.splits(
    path="./data/chap6", format="csv", 
    train="imdb_train.csv", fields=train_test_fields, 
    test = "imdb_test.csv", skip_header=True
)
# ## 加载预训练的词向量和构建词汇表
## Vectors导入预训练好的词向量文件
vec = Vectors("glove.6B.100d.txt", "./data")
# ## 使用训练集构建单词表,导入预先训练的词嵌入
TEXT.build_vocab(traindata,max_size=20000,vectors = vec)
# TEXT.build_vocab(traindata,max_size=20000)
LABEL.build_vocab(traindata)
## 训练集、验证集和测试集定义为迭代器
BATCH_SIZE = 32
train_iter = data.BucketIterator(traindata,batch_size = BATCH_SIZE)
test_iter = data.BucketIterator(testdata,batch_size = BATCH_SIZE)
##  获得一个batch的数据,对数据进行内容进行介绍
for step, batch in enumerate(train_iter):  
    textdata,target = batch.text[0],batch.label
    if step > 0:
        break
class GRUNet(nn.Module):
    def __init__(self, vocab_size,embedding_dim, hidden_dim, layer_dim, output_dim):
        """
        vocab_size:词典长度
        embedding_dim:词向量的维度
        hidden_dim: GRU神经元个数
        layer_dim: GRU的层数
        output_dim:隐藏层输出的维度(分类的数量)
        """
        super(GRUNet, self).__init__()
        self.hidden_dim = hidden_dim ## GRU神经元个数
        self.layer_dim = layer_dim ## GRU的层数
        ## 对文本进行词项量处理
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # LSTM + 全连接层
        self.gru = nn.GRU(embedding_dim, hidden_dim, layer_dim,
                          batch_first=True)
        self.fc1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Dropout(0.5),
            torch.nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        embeds = self.embedding(x)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size) 
        r_out, h_n = self.gru(embeds, None)   # None 表示初始的 hidden state 为0
        # 选取最后一个时间点的out输出
        out = self.fc1(r_out[:, -1, :]) 
        return out
    
## 初始化网络
vocab_size = len(TEXT.vocab)
embedding_dim = vec.dim #  词向量的维度
# embedding_dim = 128 #  词向量的维度
hidden_dim = 128
layer_dim = 1
output_dim = 2
grumodel = GRUNet(vocab_size, embedding_dim, hidden_dim, layer_dim, output_dim)
## 将导入的词项量作为embedding.weight的初始值
grumodel.embedding.weight.data.copy_(TEXT.vocab.vectors)
## 将无法识别的词'<unk>', '<pad>'的向量初始化为0
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
grumodel.embedding.weight.data[UNK_IDX] = torch.zeros(vec.dim)
grumodel.embedding.weight.data[PAD_IDX] = torch.zeros(vec.dim)
## 定义网络的训练过程函数
def train_model(model,traindataloader, testdataloader,criterion, 
                optimizer,num_epochs=25):
    """
    model:网络模型;traindataloader:训练数据集;valdataloader:验证数据集;
    criterion:损失函数;optimizer:优化方法;
    num_epochs:训练的轮数,scheduler:学习率变化器
    """
    train_loss_all = []
    train_acc_all = []
    test_loss_all = []
    test_acc_all = []
    learn_rate = []
    since = time.time()
    ## 设置等间隔调整学习率,每隔step_size个epoch,学习率缩小10倍
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 
    for epoch in range(num_epochs):
        learn_rate.append(scheduler.get_lr()[0])
        print('-' * 10)
        print('Epoch {}/{},Lr:{}'.format(epoch, num_epochs - 1,learn_rate[-1]))
        # 每个epoch有两个阶段,训练阶段和验证阶段
        train_loss = 0.0
        train_corrects = 0
        train_num = 0
        test_loss = 0.0
        test_corrects = 0
        test_num = 0
        model.train() ## 设置模型为训练模式
        for step,batch in enumerate(traindataloader):
            textdata,target = batch.text[0],batch.label
            out = model(textdata)
            pre_lab = torch.argmax(out,1) # 预测的标签
            loss = criterion(out, target) # 计算损失函数值
            optimizer.zero_grad()        
            loss.backward()       
            optimizer.step()  
            train_loss += loss.item() * len(target)
            train_corrects += torch.sum(pre_lab == target.data)
            train_num += len(target)
        ## 计算一个epoch在训练集上的损失和精度
        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects.double().item()/train_num)
        print('{} Train Loss: {:.4f}  Train Acc: {:.4f}'.format(
            epoch, train_loss_all[-1], train_acc_all[-1]))
        scheduler.step()  ## 更新学习率
        ## 计算一个epoch的训练后在验证集上的损失和精度
        model.eval() ## 设置模型为训练模式评估模式 
        for step,batch in enumerate(testdataloader):
            textdata,target = batch.text[0],batch.label
            out = model(textdata)
            pre_lab = torch.argmax(out,1)
            loss = criterion(out, target)   
            test_loss += loss.item() * len(target)
            test_corrects += torch.sum(pre_lab == target.data)
            test_num += len(target)
        ## 计算一个epoch在训练集上的损失和精度
        test_loss_all.append(test_loss / test_num)
        test_acc_all.append(test_corrects.double().item()/test_num)
        print('{} Test Loss: {:.4f}  Test Acc: {:.4f}'.format(
            epoch, test_loss_all[-1], test_acc_all[-1]))
        
    train_process = pd.DataFrame(
        data={"epoch":range(num_epochs),
              "train_loss_all":train_loss_all,
              "train_acc_all":train_acc_all,
              "test_loss_all":test_loss_all,
              "test_acc_all":test_acc_all,
              "learn_rate":learn_rate})  
    return model,train_process
# 定义优化器
optimizer = optim.RMSprop(grumodel.parameters(), lr=0.003)  
loss_func = nn.CrossEntropyLoss()  # 交叉熵作为损失函数
## 对模型进行迭代训练,对所有的数据训练EPOCH轮
grumodel,train_process = train_model(
    grumodel,train_iter,test_iter,loss_func,optimizer,num_epochs=10)
## 输出结果保存和数据保存
torch.save(grumodel,"data/chap7/grumodel.pkl")
## 导入保存的模型
grumodel = torch.load("data/chap7/grumodel.pkl")
grumodel
## 保存训练过程
train_process.to_csv("data/chap7/grumodel_process.csv",index=False)
## 可视化模型训练过程中
plt.figure(figsize=(18,6))
plt.subplot(1,2,1)
plt.plot(train_process.epoch,train_process.train_loss_all,
         "r.-",label = "Train loss")
plt.plot(train_process.epoch,train_process.test_loss_all,
         "bs-",label = "Test loss")
plt.legend()
plt.xlabel("Epoch number",size = 13)
plt.ylabel("Loss value",size = 13)
plt.subplot(1,2,2)
plt.plot(train_process.epoch,train_process.train_acc_all,
         "r.-",label = "Train acc")
plt.plot(train_process.epoch,train_process.test_acc_all,
         "bs-",label = "Test acc")
plt.xlabel("Epoch number",size = 13)
plt.ylabel("Acc",size = 13)
plt.legend()
plt.show()
## 对测试集进行预测并计算精度
grumodel.eval() ## 设置模型为训练模式评估模式 
test_y_all = torch.LongTensor()
pre_lab_all = torch.LongTensor()
for step,batch in enumerate(test_iter):
    textdata,target = batch.text[0],batch.label.view(-1)
    out = grumodel(textdata)
    pre_lab = torch.argmax(out,1)
    test_y_all = torch.cat((test_y_all,target)) ##测试集的标签
    pre_lab_all = torch.cat((pre_lab_all,pre_lab))##测试集的预测标签

acc = accuracy_score(test_y_all,pre_lab_all)
print("在测试集上的预测精度为:",acc)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Pytorch】利用Pytorch+GRU实现情感分类(附源码) 的相关文章

  • 使用 python requests 模块时出现 HTTP 503 错误

    我正在尝试发出 HTTP 请求 但当前可以从 Firefox 浏览器访问的网站响应 503 错误 代码本身非常简单 在网上搜索一番后我添加了user Agent请求参数 但也没有帮助 有人能解释一下如何消除这个 503 错误吗 顺便说一句
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 处理 Python 行为测试框架中的异常

    我一直在考虑从鼻子转向行为测试 摩卡 柴等已经宠坏了我 到目前为止一切都很好 但除了以下之外 我似乎无法找出任何测试异常的方法 then It throws a KeyError exception def step impl contex
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 使用Python请求登录Google帐户

    在多个登录页面上 需要谷歌登录才能继续 我想用requestspython 中的库以便让我自己登录 通常这很容易使用requests库 但是我无法让它工作 我不确定这是否是由于 Google 做出的一些限制 也许我需要使用他们的 API 或
  • 如何使用 Pandas、Numpy 加速 Python 中的嵌套 for 循环逻辑?

    我想检查一下表的字段是否TestProject包含了Client端传入的参数 嵌套for循环很丑陋 有什么高效简单的方法来实现吗 非常感谢您的任何建议 def test parameter a list parameter b list g
  • datetime.datetime.now() 返回旧值

    我正在通过匹配日期查找 python 中的数据存储条目 我想要的是每天选择 今天 的条目 但由于某种原因 当我将代码上传到 gae 服务器时 它只能工作一天 第二天它仍然返回相同的值 例如当我上传代码并在 07 01 2014 执行它时 它
  • 从Python中的字典列表中查找特定值

    我的字典列表中有以下数据 data I versicolor 0 Sepal Length 7 9 I setosa 0 I virginica 1 I versicolor 0 I setosa 1 I virginica 0 Sepal
  • Python,将函数的输出重定向到文件中

    我正在尝试将函数的输出存储到Python中的文件中 我想做的是这样的 def test print This is a Test file open Log a file write test file close 但是当我这样做时 我收到
  • 在Python中检索PostgreSQL数据库的新记录

    在数据库表中 第二列和第三列有数字 将会不断添加新行 每次 每当数据库表中添加新行时 python 都需要不断检查它们 当 sql 表中收到的新行数低于 105 时 python 应打印一条通知消息 警告 数量已降至 105 以下 另一方面
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • 如何通过 TLS 1.2 运行 django runserver

    我正在本地 Mac OS X 机器上测试 Stripe 订单 我正在实现这段代码 stripe api key settings STRIPE SECRET order stripe Order create currency usd em
  • 不同编程语言中的浮点数学

    我知道浮点数学充其量可能是丑陋的 但我想知道是否有人可以解释以下怪癖 在大多数编程语言中 我测试了 0 4 到 0 2 的加法会产生轻微的错误 而 0 4 0 1 0 1 则不会产生错误 两者计算不平等的原因是什么 在各自的编程语言中可以采
  • 从 NumPy ndarray 中选择行

    我只想从 a 中选择某些行NumPy http en wikipedia org wiki NumPy基于第二列中的值的数组 例如 此测试数组的第二列包含从 1 到 10 的整数 gt gt gt test numpy array nump
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 为什么 Pickle 协议 4 中的 Pickle 文件是协议 3 中的两倍,而速度却没有任何提升?

    我正在测试 Python 3 4 我注意到 pickle 模块有一个新协议 因此 我对 2 个协议进行了基准测试 def test1 pickle3 open pickle3 wb for i in range 1000000 pickle
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 如何解决 PDFBox 没有 unicode 映射错误?

    我有一个现有的 PDF 文件 我想使用 python 脚本将其转换为 Excel 文件 目前正在使用PDFBox 但是存在多个类似以下错误 org apache pdfbox pdmodel font PDType0Font toUnico
  • 模拟pytest中的异常终止

    我的多线程应用程序遇到了一个错误 主线程的任何异常终止 例如 未捕获的异常或某些信号 都会导致其他线程之一死锁 并阻止进程干净退出 我解决了这个问题 但我想添加一个测试来防止回归 但是 我不知道如何在 pytest 中模拟异常终止 如果我只
  • 使用 z = f(x, y) 形式的 B 样条方法来拟合 z = f(x)

    作为一个潜在的解决方案这个问题 https stackoverflow com questions 76476327 how to avoid creating many binary switching variables in gekk

随机推荐

  • PDF 的各种操作,我用 Python 来实现(附网站和操作指导)

    导言 PDF 处理是日常工作中的常见需求 包括 PDF 合并 删除 提取等 更复杂的任务如 将 PDF 转换成 图像 下面通过几个简单的例子和一份代码 帮助大家解决上面的需求 操作非常简单 在文末我会提供一份源码和一个神奇的 PDF 处理网
  • outside of class is not definition

    有一种可能的情况 You have semicolons at the end of all your function definitions making the compiler think they re declarations
  • 解决Base64报java.lang.IllegalArgumentException: Illegal base64 character 20

    报错 java lang IllegalArgumentException Illegal base64 character 20 原因 base64编码时使用加号 在URL传递时加号会被当成空格让base64字符串更改 服务器端解码出错
  • ROS主从机配置,并实现远程登陆

    第一步 主从机配置 首先确保主从机在同一个局域网中 1 编辑主机的bashrc文件 机器人平台 gedit bashrc 主机的bashrc文件添加如下的内容 export ROS MASTER URI http 主机的ip 11311 e
  • stm32F4 IAP实现原理讲解以及中断向量表的偏移

    一 IAP原理 IAP即是在应用编程 IAP 是用户自己的程序在运行过程中对User Flash 的部分区域进行烧写 目的是为了在产品发布后可以方便地通过预留的通信口对产 品中的固件程序进行更新升级 通常实现IAP 功能时 即用户程序运行中
  • 生命在于磨炼——连续两年参加4C大赛心得

    一 4C大赛简介 1 大赛简介 中国大学生计算机设计大赛 下面简称 大赛 是由教育部高等学校计算机类专业教学指导委员会 教育部高等学校软件工程专业教学指导委员会 教育部高等学校大学计算机课程教学指导委员会 教育部高等学校文科计算机基础教学指
  • 操作系统笔记五(Linux存储管理)

    1 Buddy内存管理算法 内部碎片就是已经被分配出去 能明确指出属于哪个进程 却不能被利用的内存空间 外部碎片指的是还没有被分配出去 不属于任何进程 但由于太小了无法分配给申请内存空间的新进程的内存空闲区域 目的 努力让内存分配与相邻内存
  • Task2_MySQL_basic

    MySQL表数据类型 用SQL语句创建表 创建MySQL数据表需要以下信息 表名 表字段名 定义每个表字段 语句解释 设定列类型 大小 约束 设定主键 用SQL语句向表中添加数据 语句解释 多种添加方式 指定列名 不指定列名 用SQL语句删
  • Ubuntu16.04下搭建LAMP环境

    Ubuntu16 04下搭建LAMP环境 Ubuntu16 04下搭建LAMP环境 1 安装 Apache2 2 重启 apache2 3 测试apache2是否安装成功 4 安装php7 5 测试php是否安装成功 6 安装mysql数据
  • 序列化与反序列化之Flatbuffers(一):初步使用

    序列化与反序列化之Flatbuffers 一 初步使用 一 前言 在MNN中 一个训练好的静态模型是经过Flatbuffers序列化之后保存在硬盘中的 这带来两个问题 1 为什么模型信息要序列化不能直接保存 2 其他框架如caffe和onn
  • 深度学习在目标视觉检测中的应用进展与展望

    前言 文章综述了深度学习在目标视觉检测中的应用进展与展望 首先对目标视觉检测的基本流程进行总结 并介绍了目标视觉检测研究常用的公共数据集 然后重点介绍了目前发展迅猛的深度学习方法在目标视觉检测中的最新应用进展 最后讨论了深度学习方法应用于目
  • ORAN专题系列-0: O-RAN快速索引

    专题一 O RAN的快速概述 ORAN专题系列 1 什么是开放无线接入网O RAN ORAN专题系列 1 什么是开放无线接入网O RAN 文火冰糖的硅基工坊的博客 CSDN博客 什么是oran ORAN专题系列 2 O RAN的系统架构 O
  • C和C++安全编码笔记:动态内存管理

    4 1 C内存管理 C标准内存管理函数 1 malloc size t size 分配size个字节 并返回一个指向分配的内存的指针 分配的内存未被初始化为一个已知值 2 aligned alloc size t alignment siz
  • Spring Aop自定义注解用在Controller层

    前提项目用的框架是SpringMVC 切面类 Aspect Component 把这个注掉是为了不让Spring中扫描 应该让SpringMVC扫描 public class SysLogAop Pointcut annotation co
  • 图像识别毕业设计 opencv实现植物识别算法系统 - python 深度学习

    文章目录 0 前言 2 相关技术 2 1 VGG Net模型 2 2 VGG Net在植物识别的优势 1 卷积核 池化核大小固定 2 特征提取更全面 3 网络训练误差收敛速度较快 3 VGG Net的搭建 3 1 Tornado简介 1 优
  • Maven项目的jdk版本修改

    Maven项目的jdk版本修改 修改的办法有以下三种 一 选择项目 gt 右键 gt build path Configure build path 选择旧的jre 1 5 gt remove删除 gt add Library 添加新的jr
  • Activity 工作流引擎

    Activiti工作流引擎使用详解 http blog csdn net m0 37327416 article details 71743368 Activity用户手册 http www mossle com docs activiti
  • SpringBoot笔记:SpringBoot 集成 Dataway(一)

    文章目录 1 什么是 Dataway 2 主打场景 3 技术架构 4 整合SpringBoot 4 1 maven 依赖 4 2 初始化脚本 4 3 整合 SpringBoot 5 Dataway 接口管理 6 Mybatis 语法支持 7
  • Kafka3.0.0版本——文件清理策略

    目录 一 文件清理策略 1 1 文件清理策略的概述 1 2 文件清理策略的官方文档 1 3 日志超过了设置的时间如何处理 1 3 1 delete日志删除 将过期数据删除 1 3 2 compact日志压缩 一 文件清理策略 1 1 文件清
  • 【Pytorch】利用Pytorch+GRU实现情感分类(附源码)

    在这个实验中 数据的预处理过程以及网络的初始化及模型的训练等过程同前文 利用Pytorch LSTM实现中文新闻分类 具体这里就不再重复解释了 如果有读者在对数据集的预处理过程中有疑问 请参考我的其他博客 里面对这些方法均有我的一些个人体会