深度学习:将新闻报道按照不同话题性质进行分类

2023-11-17

深度学习的广泛运用之一就是对文本按照其内容进行分类。例如对新闻报道根据其性质进行划分是常见的应用领域。在本节,我们要把路透社自1986年以来的新闻数据按照46个不同话题进行划分。网络经过训练后,它能够分析一篇新闻稿,然后按照其报道内容,将其归入到设定好的46个话题之一。深度学习在这方面的应用属于典型的“单标签,多类别划分”的文本分类应用。

我们这里采用的数据集来自于路透社1986年以来的报道,数据中每一篇新闻稿附带一个话题标签,以用于网络训练,每一个话题至少含有10篇文章,某些报道它内容很明显属于给定话题,有些报道会模棱两可,不好确定它到底属于哪一种类的话题,我们先把数据加载到机器里,代码如下:

from keras.datasets import reuters
(train_data, train_label), (test_data, test_labels) = reuters.load_data(num_words=10000)

keras框架直接附带了相关数据集,通过执行上面代码就可以将数据下载下来。上面代码运行后结果如下:
这里写图片描述
从上面运行结果看,它总共有8982条训练数据和2246条测试数据。跟我们上节数据类型一样,数据里面对应的是每个单词的频率编号,我们可以通过上一节类似的代码,将编号对应的单词从字典中抽取出来结合成一篇文章,代码如下:

word_index = reuters.get_word_index()
reverse_word_index = dict([value, key] for (key, value) in word_index.items())
decoded_newswire = ' '.join([reverse_word_index.get(i-3, '?') for i in train_data[0]])
print(decoded_newswire)

上面代码运行后结果如下:
这里写图片描述

如同上一节,我们必须要把训练数据转换成数据向量才能提供给网络进行训练,因此我们像上一节一样,对每条新闻创建一个长度为一万的向量,先把元素都初始为0,然后如果某个对应频率的词在文本中出现,那么我们就在向量中相应下标设置为1,代码如下:

import numpy as np
def vectorize_sequences(sequences, dimension=10000):
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.
    return results

x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

print(x_train[0])

上面代码运行后,我们就把训练数据变成含有1或0的向量了:

这里写图片描述

其实我们可以直接调用keras框架提供的接口一次性方便简单的完成:

from keras.utils.np_utils import to_categorical

one_hot_train_labels = to_categorical(train_label)
one_hot_test_labels = to_categorical(test_labels)

接下来我们可以着手构建分析网络,网络的结构与上节很像,因为要解决的问题性质差不多,都是对文本进行分析。然而有一个重大不同在于,上一节我们只让网络将文本划分成两种类别,而这次我们需要将文本划分为46个类别!上一节我们构造网络时,中间层网络我们设置了16个神经元,由于现在我们需要在最外层输出46个结果,因此中间层如果只设置16个神经元那就不够用,由于输出的信息太多,如果中间层神经元数量不足,那么他就会成为信息过滤的瓶颈,因此这次我们搭建网络时,中间层网络节点扩大为6个,代码如下:

from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(64, activation='relu'))
#当结果是输出多个分类的概率时,用softmax激活函数,它将为46个分类提供不同的可能性概率值
model.add(layers.Dense(46, activation='softmax'))

#对于输出多个分类结果,最好的损失函数是categorical_crossentropy
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

像上一节一样,在网络训练时我们要设置校验数据集,因为网络并不是训练得次数越多越好,有了校验数据集,我们就知道网络在训练几次的情况下能够达到最优状态,准备校验数据集的代码如下:

x_val = x_train[:1000]
partial_x_train = x_train[1000:]

y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]

有了数据,就相当于有米入锅,我们可以把数据输入网络进行训练:

history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512, 
                   validation_data=(x_val, y_val))

代码进行了20个周期的循环训练,由于数据量比上一节小,因此速度快很多,与上一节一样,网络的训练并不是越多越好,它会有一个拐点,训练次数超出后,效果会越来越差,我们把训练数据图形化,以便观察拐点从哪里开始:

import matplotlib.pyplot as plt
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(loss) + 1)

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

上面代码运行后结果如下:

这里写图片描述

通过上图观察我们看到,以蓝点表示的是网络对训练数据的判断准确率,该准确率一直在不断下降,但是蓝线表示的是网络对校验数据判断的准确率,仔细观察发现,它一开始是迅速下降的,过了某个点,达到最低点后就开始上升,这个点大概是在epochs=9那里,所以我们把前面对网络训练的循环次数减少到9:

from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(64, activation='relu'))
#当结果是输出多个分类的概率时,用softmax激活函数,它将为46个分类提供不同的可能性概率值
model.add(layers.Dense(46, activation='softmax'))

#对于输出多个分类结果,最好的损失函数是categorical_crossentropy
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(partial_x_train, partial_y_train, epochs=9, batch_size=512, 
                   validation_data=(x_val, y_val))

完成训练后,我们把结果输出看看:

results = model.evaluate(x_test, one_hot_test_labels)
print(results)

上面两句代码运行结果为:
这里写图片描述
右边0.78表示,我们网络对新闻进行话题分类的准确率达到78%,差一点到80%。我们从测试数据集中拿出一条数据,让网络进行分类,得到结果再与其对应的正确结果比较看看是否一致:

predictions = model.predict(x_test)
print(predictions[0])
print(np.sum(predictions[0]))
print(np.argmax(predictions[0]))
print(one_hot_test_labels[0])

我们让网络对每一条测试数据一一进行判断,并把它对第一条数据的判断结果显示出来,最后我们打印出第一条测试数据对应的分类,最后看看网络给出去的结果与正确结果是否一致,上面代码运行后结果如下:

这里写图片描述

从上面运行结果看到,网络对第一条数据给出了属于46个分类的概率,其中下标为3的概率值最大,也就是第一条数据属于分类4的概率最大,最后打印出来的测试数据对应的正确结果来看,它也是下标为3的元素值为1,也就是说数据对应的正确分类是4,由此我们网络得到的结果是正确的。

前面提到过,由于网络最终输出结果包含46个元素,因此中间节点的神经元数目不能小于46,因为小于46,那么有关46个元素的信息就会遭到挤压,于是在层层运算后会导致信息丢失,最后致使最终结果的准确率下降,我们试试看是不是这样:

from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(4, activation='relu'))
#当结果是输出多个分类的概率时,用softmax激活函数,它将为46个分类提供不同的可能性概率值
model.add(layers.Dense(46, activation='softmax'))

#对于输出多个分类结果,最好的损失函数是categorical_crossentropy
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(partial_x_train, partial_y_train, epochs=9, batch_size=512, 
                   validation_data=(x_val, y_val))
results = model.evaluate(x_test, one_hot_test_labels)
print(results)

上面代码运行后,输出的results结果如下:
[1.4625472680649796, 0.6705253784505788]

从上面结果看到,我们代码几乎没变,致使把第二层中间层神经元数量改成4,最终结果的准确率就下降10个点,所以中间层神经元的减少导致信息压缩后,最后计算的准确度缺失。反过来你也可以试试用128个神经元的中间层看看准确率有没有提升。

到这里不知道你发现没有,神经网络在实际项目中的运用有点类似于乐高积木,你根据实际需要,通过选定参数,用几行代码配置好基本的网络结构,把训练数据改造成合适的数字向量,然后就可以输入到网络中进行训练,训练过程中记得用校验数据监测最优训练次数,防止过度拟合。

在网络的设计过程中,其背后的数学原理我们几乎无需了解,只需要凭借经验,根据项目的性质,设定网络的各项参数,最关键的其实在根据项目数据性质对网络进行调优,例如网络设置几层好,每层几个神经元,用什么样的激活函数和损失函数等等,这些操作与技术无关,取决以个人经验,属于“艺术”的范畴。

更详细的讲解和代码调试演示过程,请点击链接

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习:将新闻报道按照不同话题性质进行分类 的相关文章

  • Android 反编译Apk,修改资源,重新打包,签名发布

    本文简单介绍apk是如何修改logo ic launcher 类似的资源文件修改也可以通过此方式 不过要修改class的话就要涉及到smali的学习了 这里就暂且不谈 后续有需要再做更新 一 工具介绍 apktool 用来反编译apk ap
  • 【华为OD机试真题 JAVA】最多的连续胡杨棵树

    标题 最多的连续胡杨棵树 时间限制 1秒 内存限制 262144K 语言限制 不限 近些年来 我国防沙治沙取得显著成果 某沙漠新种植N棵胡杨 编号1 N 排成一排 一个月后 有M棵胡杨未能成活 现可补种胡杨K棵 请问如何补种 只能补种 不能
  • 【JS基础】通俗易懂的讲清楚去抖/防抖、节流。外加手写深度比较

    文章目录 去抖 防抖 思路解析 节流 两者在vue中结合计算属性使用 深度比较 去抖 防抖 去抖也叫防抖 为了照顾JS初学者的理解和记忆 我就简单的说明一下 我们生活中很多出现抖动的现象 都是没有规律的 例如人的发抖 树叶在风中的抖动 海浪
  • java mysql 断开连接_mysql java连接异常及断开解决秘籍

    3 The last packet sent successfully to the server was 0 milliseconds ago The driver has not received any packets from th
  • 前端一年的经验,面试官都会问一些什么问题呢?都是这样一些的问题

    面试准备阶段 学习以及复习基础知识 这一定是第一步需要做的事情 先制定规划 然后按照这一条既定的规划去学习以及复习 可分为六部分去准备 css部分 像 css这一部分 面试必问 但是它的东西很杂很多 我不知道有多少人和我感觉一样 学习前端最
  • Oracle中Delete和Commit操作的流程分析

    以后还会陆续加入其他各种操作的实现机制 1 删除 Delete 流程 Oracle读Block 数据块 到Buffer Cache 缓冲区 如果该Block在Buffer中不存在 在Redo Log Buffer 重做日志缓冲区 中记录De
  • Leetcode【DFS BFS】

    Leetcode 200 岛屿数量 题目 解题 思路 DFS解法 BFS解法 题目 给你一个由 1 陆地 和 0 水 组成的的二维网格 请你计算网格中岛屿的数量 岛屿总是被水包围 并且每座岛屿只能由水平方向和 或竖直方向上相邻的陆地连接形成
  • ES6 method写法与TypeError: is not a constructor

    公司前端最近开始强推ESlint 很多文件需要逐步修改为符合ESlint规则的形式 结果遇到了一个神奇的问题 有一段类似这样的代码 let obj init function el 此处ESlint检查提示 Expect method sh
  • k8s部署tomcat及web应用_在k8s部署tomcat

    小试牛刀 准备编排文件tomcat yaml 包含两部分 副本rc和service配置可为两个文件 不过我们此处合并为一个 rc副本相关 apiVersion extensions v1beta1 表示Deployment调度配置 kind

随机推荐

  • Keras默认权值初始化方式

    20230117 在最初使用Keras进行神经网络编程的时候 除了设置神经元个数 层数 或者激活函数之后 基本上对神经网络内部就不怎么管了 所以最后很多参数都是默认的 这种情况一般遇到的数据集问题 都能轻易解决 一般不是层数非常深的神经网络
  • 【华为OD统一考试A卷

    华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一修改为OD统一考试 A卷 和OD统一考试 B卷 你收到的链接上面会标注A卷还是B卷 请注意 根据反馈 目前大部分收到的都是
  • Kali系统(Debian 10.3) 遇到的问题

    目录 问题一 Kali系统 相关技术网站 博客 文章 论坛 工具包 包跟踪 提交BUG 问题二 黑客入门 手痒地方 问题三 Kali系统 MySQL问题Can t connect to local MySQL server through
  • 边缘计算操作系统安装及测试实验报告

    边缘计算操作系统安装及测试 一 实验目的 二 实验环境 三 实验原理 1 系统组成部分 2 总体数据流程 四 实验步骤及结果 1 安装 Docker 和 Docker Compose 2 下载 EdgeX compose 文件 3 运行Ed
  • qt中clicked(bool checked)和toggled(bool checked)的区别

    先来看qt文档的解释 上面看出 共同点是 当点击按钮时 状态信号都会被发送 不同点 clicked this signal is not emitted if you call setDown setChecked or toggle to
  • 5年测试面试要20K,面试三个问题把我打发走了···

    都说金三银四 金九银十跳槽涨薪季 我是着急忙慌的准备简历 5年软件测试经验 可独立测试大型产品项目 熟悉项目测试流程 薪资要求 5年测试经验起码能要个20K吧 我加班肝了一页半简历 投出去一周 面试电话倒是不少 自信满满去面试 现场被问了这
  • Nmap源码分析(服务与版本扫描)

    Nmap源码分析 服务与版本扫描 2012年8月23日 在进行端口扫描后 Nmap可以进一步探测出运行在端口上的服务类型及应用程序的版本 目前Nmap可以识别几千种服务程序的签名 Signature 覆盖了180多种应用协议 比如 端口扫描
  • java写后端接口中mapper的一些操作

    文章目录 Mybatis Mapper的动态SQL语句问题 一 if 二 choose when otherwise 三 where 四 trim 元素来定制 where 元素的功能 五 set 动态地在行首插入 SET 关键字 六 for
  • PTA 7-4 统计学生平均成绩与及格人数 (15 分)

    本题要求编写程序 计算学生们的平均成绩 并统计及格 成绩不低于60分 的人数 题目保证输入与输出均在整型范围内 输入格式 输入在第一行中给出非负整数N 即学生人数 第二行给出N个非负整数 即这N位学生的成绩 其间以空格分隔 输出格式 按照以
  • C语言函数大全-- y 开头的函数

    y 开头的函数 1 yperror 1 1 函数说明 1 2 演示示例 2 yp match 2 1 函数说明 2 2 演示示例 3 y0 零阶第二类贝塞尔函数 3 1 函数说明 3 2 演示示例 3 3 运行结果 4 y1 一阶第二类贝塞
  • 在Vue中使用flex布局 echarts多图标不能自适应缩放问题

    前言 最近有个项目需要用到echarts绘制多个图表 需求是要支持大屏展示 还有需要支持不同比例的缩放和任意手动缩放 因此 深入学习了echarts和flex布局 虽然遇到很多问题 但都一一解决了收获良多 故此写下遇到的问题与坑 与之共勉
  • go 进阶 多路复用支持: 二. Accept/Read/Write

    目录 一 通过httpServer服务端引用Accept 二 Listener Accept 等待连接 三 Conn Read读数据 Conn Write写数据 四 gopark 阻塞 五 netpoll 唤醒等待队列中挂起的协程 什么时候
  • C#桌面应用程序打包

    使用微软的技术开发windows桌面应用程序是很快捷方便的 开发完之后肯定要打包安装才能发布 以前有做过但过长时间没有打包一下子还真有些遗忘 今天专门又重温了一些 干脆写下来算是加深些印象 以后需要时也好有个参考 感觉有很多技术上手都没有太
  • std::bind可以绑定成员变量

    include
  • java student类_java定义一个Student类,包含内容如下

    展开全部 public class Student 成员变量 学号 姓名 性别 班干部否 数学 语文 外语 成62616964757a686964616fe58685e5aeb931333337613166员方法 输入 总分 平均分 编程实
  • MeterSphere实践指南汇总,搬砖党

    闲来无事 整理了MeterSphere实践指南 我司用了MeterSphere一段时间 感觉挺好用 百度网盘下载链接 链接 https pan baidu com s 1s8sAuz31lgnvTRTLkWZuiQ pwd 98bg 提取码
  • 我的算法笔记(1)——冒泡排序

    我的算法笔记 1 冒泡排序 排序是指将一个无序序列按某个规则进行有序排列 而冒泡排序是排序算法中最基础的一种 现给出一个序列a 其中元素的个数为n 要求将他们按从小到大的顺序排序 冒泡排序的本质在于交换 即每次通过交换的方式把当前剩余元素的
  • BP神经网络阈值

    在基于神经网络的数据融合文章里 有改进权值和阈值 但是没有说清阈值到底是什么 神经网络是模仿大脑的神经元 当外界刺激达到一定的阈值时 神经元才会受到刺激 影响下一个神经元 简单来说 超过阈值 就会引起某一变化 不超过阈值 无论是多少 都不产
  • 【数据库实验】sql总结

    首先说明 以下大部分针对的是标准sql 目录 结构 关键词 关于模式 创建模式 删除模式 关于表 创建表 修改表 删除表 关于索引 建立索引 修改索引 删除索引 关于查询 几个点 指定列 全部列 经过计算的值 列的别名 方便查看 以及聚集函
  • 深度学习:将新闻报道按照不同话题性质进行分类

    深度学习的广泛运用之一就是对文本按照其内容进行分类 例如对新闻报道根据其性质进行划分是常见的应用领域 在本节 我们要把路透社自1986年以来的新闻数据按照46个不同话题进行划分 网络经过训练后 它能够分析一篇新闻稿 然后按照其报道内容 将其