中文新闻文本标题分类(基于飞桨、Text CNN)

2023-10-26

目录

一、设计方案概述

二、具体实现

三、结果及分析 

四、总结

一、设计方案概述

主要网络模型设计:

设计所使用网络模型为TextCNN,由于其本身就适用于短中句子,在标题分类这一方面应该能发挥其优势。

TextCNN是Yoon Kim在2014年提出的模型,开创了用CNN编码n-gram特征的先河

图1-1

模型结构如图,图像中的卷积都是二维的,而TextCNN则使用「一维卷积」,即filter_size * embedding_dim,有一个维度和embedding相等。这样就能抽取filter_size个gram的信息。以1个样本为例,整体的前向逻辑是:

对词进行embedding,得到[seq_length, embedding_dim]

用N个卷积核,得到N个seq_length-filter_size+1长度的一维feature map

对feature map进行max-pooling(因为是时间维度的,也称max-over-time pooling),得到N个1x1的数值,拼接成一个N维向量,作为文本的句子表示

将N维向量压缩到类目个数的维度,过Softmax

网络结构图:

图1-2

在TextCNN的实践中,有很多地方可以优化。

Filter尺寸:这个参数决定了抽取n-gram特征的长度,这个参数主要跟数据有关,平均长度在50以内的话,用10以下就可以了,否则可以长一些。在调参时可以先用一个尺寸grid search,找到一个最优尺寸,然后尝试最优尺寸和附近尺寸的组合

Filter个数:这个参数会影响最终特征的维度,维度太大的话训练速度就会变慢。使用100-600之间即可

CNN的激活函数:选择Identity、ReLU、tanh

正则化:指对CNN参数的正则化,可以使用dropout或L2,但能起的作用很小,可以试下小的dropout率(<0.5),L2限制大一点

Pooling方法:根据情况选择mean、max、k-max pooling,大部分时候max表现就很好,因为分类任务对细粒度语义的要求不高,只抓住最大特征就好了。

Embedding表:中文选择char或word级别的输入,也可以两种都用,会提升些效果。如果训练数据充足(10w+),也可以从头训练

蒸馏BERT的logits,利用领域内无监督数据。

加深全连接:加到3、4层左右效果会更好。

TextCNN是很适合中短文本场景的强baseline,但不太适合长文本,因为卷积核尺寸通常不会设很大,无法捕获长距离特征。同时max-pooling也存在局限,会丢掉一些有用特征。

简单流程图:

图1-3

二、具体实现

完整代码:

import os
from multiprocessing import cpu_count
import numpy as np
import paddle
import paddle.fluid as fluid
import matplotlib.pyplot as plt
paddle.enable_static()
data_root_path='./data/'

#创建数据集
def create_data_list(data_root_path):
   with open(data_root_path + 'test_list.txt', 'w') as f:
       pass
   with open(data_root_path + 'train_list.txt', 'w') as f:
       pass

   with open(os.path.join(data_root_path, 'dict_txt.txt'), 'r', encoding='utf-8') as f_data:
       dict_txt = eval(f_data.readlines()[0])

   with open(os.path.join(data_root_path, 'data/Train.txt'), 'r', encoding='utf-8') as f_data:
       lines = f_data.readlines()
   i = 0
   for line in lines:
       title = line.split('\t')[-1].replace('\n', '')
       l = line.split('\t')[0]
       labs = ""
       if i % 10 == 0:
           with open(os.path.join(data_root_path, 'test_list.txt'), 'a', encoding='utf-8') as f_test:
               for s in title:
                   lab = str(dict_txt[s])
                   labs = labs + lab + ','
               labs = labs[:-1]
               labs = labs + '\t' + l + '\n'
               f_test.write(labs)
       else:
           with open(os.path.join(data_root_path, 'train_list.txt'), 'a', encoding='utf-8') as f_train:
               for s in title:
                   lab = str(dict_txt[s])
                   labs = labs + lab + ','
               labs = labs[:-1]
               labs = labs + '\t' + l + '\n'
               f_train.write(labs)
       i += 1
   print("数据列表生成完成!")

#把下载得数据生成一个字典
def create_dict(data_path, dict_path):
    dict_set = set()
    # 统计有多少种类别,分别对应的id,并显示出来以供后面的预测应用
    id_and_className = {}

    # 读取已经下载得数据
    with open(data_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    # 把数据生成一个元组
    for line in lines:
        lineList = line.split('\t')
        title = lineList[-1].replace('\n', '')
        classId = lineList[0]
        if classId not in id_and_className.keys():
            id_and_className[classId] = lineList[1]

        for s in title:
            dict_set.add(s)
    # 把元组转换成字典,一个字对应一个数字
    dict_list = []
    i = 0
    for s in dict_set:
        dict_list.append([s, i])
        i += 1
    # 添加未知字符
    dict_txt = dict(dict_list)
    end_dict = {"<unk>": i}
    dict_txt.update(end_dict)
    # 把这些字典保存到本地中
    with open(dict_path, 'w', encoding='utf-8') as f:
        f.write(str(dict_txt))
    print("数据字典生成完成!")
    print('类Id及其类别名称:', id_and_className)
# 获取字典的长度
def get_dict_len(dict_path):
    with open(dict_path, 'r', encoding='utf-8') as f:
        line = eval(f.readlines()[0])

    return len(line.keys())
def data_mapper(sample):
   data, label = sample
   dataList=[]
   for e in data.split(','):
        if e=='':
           print('meet blank')
        else:
            dataList.append(np.int64(e))
   return dataList, int(label)

# 创建数据读取器train_reader
def train_reader(train_list_path):
   def reader():
       with open(train_list_path, 'r') as f:
           lines = f.readlines()
           # 打乱数据
           np.random.shuffle(lines)
           # 开始获取每张图像和标签
           for line in lines:
               data, label = line.split('\t')
               yield data, label
   return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)
#  创建数据读取器test_reader
def test_reader(test_list_path):
   def reader():
       with open(test_list_path, 'r') as f:
           lines = f.readlines()
           for line in lines:
               data, label = line.split('\t')
               yield data, label
   return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)

#   网络定义
def CNN_net(data, dict_dim, class_dim=14, emb_dim=128, hid_dim=128, hid_dim2=98):
    emb = fluid.layers.embedding(input=data,size=[dict_dim, emb_dim])
    conv_1 = fluid.nets.sequence_conv_pool(
           input=emb,
           num_filters=hid_dim,
           filter_size=3,
           act="tanh",
           pool_type="max")
    conv_2 = fluid.nets.sequence_conv_pool(
           input=emb,
           num_filters=hid_dim,
           filter_size=4,
           act="tanh",
           pool_type="max")
    conv_3 = fluid.nets.sequence_conv_pool(
            input=emb,
            num_filters=hid_dim2,
            filter_size=4,
            act="tanh",
            pool_type="max")
    fc1 = fluid.layers.fc(input=[conv_1, conv_2,conv_3], size=128, act='softmax')
    bn = fluid.layers.batch_norm(input=fc1, act='relu')
    fc2 = fluid.layers.fc(input= bn, size=64, act='softmax')
    bn1 = fluid.layers.batch_norm(input=fc2, act='relu')
    fc3 = fluid.layers.fc(input= bn1, size=class_dim, act='softmax')
    return fc3

# 定义绘制训练过程的损失值和准确率变化趋势的方法draw_train_process
all_train_iter=0
all_train_iters=[]
all_train_costs=[]
all_train_accs=[]
def draw_train_process(title,iters,costs,accs,label_cost,lable_acc):
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=20)
    plt.ylabel("cost/acc", fontsize=20)
    plt.plot(iters, costs,color='red',label=label_cost)
    plt.plot(iters, accs,color='green',label=lable_acc)
    plt.legend()
    plt.grid()
    plt.show()

# 把生产的数据列表都放在自己的总类别文件夹中
data_path = os.path.join(data_root_path, 'data/Train.txt')
dict_path = os.path.join(data_root_path, "dict_txt.txt")
# 创建数据字典
create_dict(data_path, dict_path)
# 创建数据列表
create_data_list(data_root_path)

words = fluid.layers.data(name='words', shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# 获取数据字典长度
dict_dim = get_dict_len('data/dict_txt.txt')
# 获取卷积神经网络
# 获取分类器
model = CNN_net(words, dict_dim)
# 获取损失函数和准确率
cost = fluid.layers.cross_entropy(input=model, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=model, label=label)

# 获取预测程序
test_program = fluid.default_main_program().clone(for_test=True)
# 定义优化方法
optimizer = fluid.optimizer.AdagradOptimizer(learning_rate=0.002)
opt = optimizer.minimize(avg_cost)

# 创建一个执行器,CPU训练速度比较慢
# 定义使用CPU还是GPU,使用CPU时use_cuda = False,使用GPU时use_cuda = True
use_cuda =  False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
#place = fluid.CPUPlace()
#place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
# 进行参数初始化
exe.run(fluid.default_startup_program())
train_reader = paddle.batch(reader=train_reader('./data/train_list.txt'), batch_size=128)
test_reader = paddle.batch(reader=test_reader('./data/test_list.txt'), batch_size=128)
feeder = fluid.DataFeeder(place=place, feed_list=[words, label])

EPOCH_NUM=5
model_save_dir = './infer_model/'
# 开始训练

for pass_id in range(EPOCH_NUM):
   # 进行训练
   for batch_id, data in enumerate(train_reader()):
       train_cost, train_acc = exe.run(program=fluid.default_main_program(),
                            feed=feeder.feed(data),
                            fetch_list=[avg_cost, acc])
       all_train_iter = all_train_iter + 100
       all_train_iters.append(all_train_iter)
       all_train_costs.append(train_cost[0])
       all_train_accs.append(train_acc[0])
       if batch_id % 100 == 0:
           print('Train Pass:%d, Batch:%d, Cost:%0.5f, Acc:%0.5f' % (pass_id, batch_id, train_cost[0], train_acc[0]))
   # 进行测试
   test_costs = []
   test_accs = []
   for batch_id, data in enumerate(test_reader()):
       test_cost, test_acc = exe.run(program=test_program,
                                             feed=feeder.feed(data),
                                             fetch_list=[avg_cost, acc])
       test_costs.append(test_cost[0])
       test_accs.append(test_acc[0])
   # 计算平均预测损失在和准确率
   test_cost = (sum(test_costs) / len(test_costs))
   test_acc = (sum(test_accs) / len(test_accs))
   print('Test:%d, Cost:%0.5f, ACC:%0.5f' % (pass_id, test_cost, test_acc))

if not os.path.exists(model_save_dir):
   os.makedirs(model_save_dir)
fluid.io.save_inference_model(model_save_dir,
                           feeded_var_names=[words.name],
                           target_vars=[model],
                           executor=exe)
print('训练模型保存完成!')
draw_train_process("training", all_train_iters, all_train_costs, all_train_accs, "trainning cost", "trainning acc")

说明:本次网络仿照TextCNN样式设计,相关参数可自行调试选出最优数值,预测模块已被删除,训练集的选择是经处理的训练集

训练集地址:中文新闻文本标题分类 - 飞桨AI Studio (baidu.com)

三、结果及分析 

5轮训练:

图2-1

图2-2

图2-3

图2-4

图2-5

图2-6

图2-7

分析:从数值变化上来看,从0到0.9所花的训练较少,趋向很快,由于数据集为THUCNews(740000多条数据)新闻标题,作为输入数据的它并没有图片大,所以CNN网络处理的速度一般较快,cpu运行所用时间约1.5-2.5个小时,改用GPU的话速度应该会更快。测试的数值呈线性,并没有发现随着训练的增加而出现数值下降倾向。

从图像(图2-7)上看,一目了然,随着训练量的增加,准确率上升、误差减少,趋向稳定后,准确率在0.9之间波动,而误差(损失值)0.1-0.4之间波动。

四、总结

处理NLP任务首先需要选择合适的网络模型,比如TextCNN、TextRNN、LSTM、GRU、BiLSTM、RCNN、EntNet等等。当然,有些NLP任务也可以用机器学习方法去解决,至于哪种任务用哪种方法,需要根据实际情况去选择。解决一个NLP的任务可能有多种方案,但是哪一个方案更合适需要我们不断地去分析尝试。比如,二次文本分类,可以尝试着去组合多种网络,以求达到最优效果。确定NLP任务后,首先需要对数据进行分析,任务具体是干什么需要什么功能,并且要深入地分析理解数据,知道数据的含义,这样可以帮助制定解决方案,同时也有利于进行数据预处理。数据要采取什么样的处理方式,需要对数据进行深入地分析后才能知道。确定好处理方式后,预处理数据。这一步和前面的数据分析关系很强,很多预处理操作都是基于对数据的分析而来,一般对文本预处理包含分词、去除停用词、训练词向量、文本序列化等等,当然,对于有的任务还包含同义词替换、训练词权重等等。再接着就是搭建模型,具体使用什么模型得根据具体任务来定。最后就是优化模型,常用的操作有调参、更改网络结构、针对评价指标优化等等。

参考:文本分类算法总结 - 知乎 (zhihu.com)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

中文新闻文本标题分类(基于飞桨、Text CNN) 的相关文章

随机推荐

  • 希尔排序

    目录 一 原理 二 示例代码 三 算法分析 希尔排序又称为缩小增量排序 是直接插入排序算法的一种更高效的改进版本 希尔排序是基于插入排序的以下两点性质而提出改进方法的 插入排序在对几乎已经排好序的数据操作时 效率高 即可以达到线性排序的效率
  • WebGL加载跨域纹理的解决方法

    本人一直对WebGL很有兴趣 也试着尝试用osgjs写了个DEMO 很成功的出现了效果 可是当自己用ASP net写了个服务端 想用自己写的服务器提供的数据来用做纹理 可是怎么也不出来 还报错 跟了下代码 发现是用作纹理的Image对象的问
  • 华为 5G、阿里检测病毒算法、腾讯 AI 一分钟诊断,国内抗疫科技大阅兵!

    作者 马超 责编 王晓曼 伍杏玲 出品 CSDN ID CSDNnews 近期以来 国际风云不断变换 而在2020年初疫情肆虐期间 也成为我国科技实力的 大检阅 近期人民网官微致敬我们中国科技企业的排头兵 可以说掌握硬核科技成了全民的共识
  • Metasploit(MSF)基础超级详细版

    MSF基础学习看这一篇就够了 Metasploit 常见名词解释 MSF简介 MSF框架结构 MSF配置数据库 内网主机发现 MSF命令查询 常用命令 数据库管理命令 核心命令 模块命令 进程命令 资源脚本命令 后台数据库命令 后端凭证命令
  • 假如“唐僧团队”裁员,你会先裁掉哪一位

    相信很多人看过水煮三国 大话西游 文中去西天取经的4人被影射成一个团队 其中 唐僧是TeamLeader 性格坚韧 目的明确 讲原则 懦弱没主意 孙悟空是团队中那个创意员工 业绩突出却个性极强 屡屡得罪人 猪八戒就好比那为人圆滑 偏偏干活时
  • 【ESP系列】AT指令案例

    前言 ESP系列芯片具有高性价比的联网功能 广受大家的认可 然而 在开发过程中 有时候我们想要使用ESP系列芯片的联网功能 却又不想为此编写繁杂的联网逻辑 串口交互逻辑等等 此时 我们可以运用AT指令来实现简洁的联网控制 本文将介绍这种基于
  • linux命令之查看jvm内存使用情况

    linux命令之查看jvm内存使用情况 1 使用 ps ef grep java 查询java的进程ID 2 使用jstat命令查看堆内存的使用情况 1 垃圾回收统计 jstat gc 进程ID 参数解释 S0C 第一个幸存区的大小 S1C
  • 【vue-treeselect】数据量大的时候懒加载并且可以搜索,树懒加载+搜索

    这两天快被这个懒加载加搜索搞崩溃了 今天小有收获 后面优化了再更新 主要说一下一棵树如何懒加载和搜索 1 ref不解释了 和本次代码无关 2 normalizer格式化内容不重要 3 load options很关键 4 search cha
  • Qt5 C++源码中使用中文的简单步骤

    本文不讲任何道理 当你在Qt5的C 源文件内使用中文时 你只需按顺序简单照做即可 不止是中文 其实你完全可以在代码中使用日韩法俄语等等各国语言 0 通用 源文件保存为带BOM的UTF 8格式 如果你准备跨平台 保存为带BOM的UTF 8是必
  • 计算机迭代步数英语,迭代计算

    迭代法是数值计算中一类典型方法 应用于方程求根 方程组求解 矩阵求特征值等方面 其基本思想是逐次逼近 先取一个粗糙的近似值 然后用同一个递推公式 反复校正此初值 直至达到预定精度要求为止 1 迭代计算次数指允许公式反复计算的次数 在Exce
  • 毕业设计记录(二):基于VUE框架与ECharts和Axios技术结合的Web移动高校实验室管理系统设计与实现

    目录 点击即跳转 参考文献阅读笔记 空间信息与规划系实验室情况统计表 毕业设计进度 前端设计 登陆界面 未美术优化 参考文献 总 参考文献阅读笔记 2 甄翠明 李克 基于Web的高校计算机实验室预约管理系统的研究与设计 J 现代信息科技 2
  • 【速度↑20%模型尺寸↓36%】极简开源人脸检测算法升级

    经过一年的各种尝试 调试 测试以及无数次失败 我们的开源人脸检测算法再次升级 我们团队专注人脸检测优化十几年 一直持续优化 向着最简单的算法努力 新版本提升 计算量更小 速度提升约20 模型尺寸精简36 85K参数降低至54K 准确率有所提
  • so库的反编译,反汇编

    Linux APP SO的反汇编工具 ida Pro 可以反汇编app和SO库 有函数名 但是不能反编译到code这一级别 下载最强的反编译工具 ida Pro 6 4 Plus rar 还有这个反汇编工具 没用过 转自 http bbs
  • protobuf的序列化和反序列化的分析

    一 protobuf的optional 数据类型序列化分析 1 optional 的protobuf的文件 格式 syntax proto2 message test proto optional int32 proto1 1 option
  • thinkphp5.0.24反序列化漏洞分析

    thinkphp5 0 24反序列化漏洞分析 文章目录 thinkphp5 0 24反序列化漏洞分析 具体分析 反序列化起点 toArray getRelationData分析 modelRelation生成 进入 call前的两个if c
  • 初步学习Oracle之PL/SQL

    PL SQL简介 PL SQL Procedure Language SQL 程序语言是 Oracle 对 sql 语言的过程化扩展 指在 SQL 命令语言中增加了过程处理语句 如分支 循环等 使 SQL 语言具有过程处理能力 把SQL 语
  • 【满分】【华为OD机试真题2023 JS】最差产品奖

    华为OD机试真题 2023年度机试题库全覆盖 刷题指南点这里 最差产品奖 知识点滑窗 时间限制 1s 空间限制 256MB 限定语言 不限 题目描述 A公司准备对他下面的N个产品评选最差奖 评选的方式是首先对每个产品进行评分 然后根据评分区
  • 在Android Studio中使用vulkan

    首先要确定手机是否支持Vulkan 可以下载一个aida64 在设备中如果能找到vulkan设备 说明支持 否则不支持 严格按照官方介绍的步骤一步步执行 就能获得官方推荐的可执行的例子 有很多 可以都试一试 那怎么在自己的工程中使用vulk
  • Vue模版语法&2种数据绑定方式

    Vue模板语法有2大类 1 插值语法 功能 用于解析标签体内容 写法 xx 其中xx是js表达式 且可以直接读取到data中的所有属性 p value p 在双大括号中 除了可以简单的传值外 还可以使用表达式 每个绑定都只能包含单个表达式
  • 中文新闻文本标题分类(基于飞桨、Text CNN)

    目录 一 设计方案概述 二 具体实现 三 结果及分析 四 总结 一 设计方案概述 主要网络模型设计 设计所使用网络模型为TextCNN 由于其本身就适用于短中句子 在标题分类这一方面应该能发挥其优势 TextCNN是Yoon Kim在201