深度学习 从零开始 —— 神经网络(四),二分类问题,IMDB数据集使用

2023-11-08

IMDB数据集

互联网电影数据,包含50000条严重两极分化的评论。正面和负面评论各占50%。
而该数据集也同样被内置于Keras库中了。

其中的评论数据已经经过了预处理,评论(单词)被转化为了整数序列,每个整数都对应词典里面的一个单词。

加载数据集

from keras.datasets import imdb

(train_data,train_labels),(test_data,test_labels) = imdb.load_data(num_words=10000)
#第一条评论的单词索引列表
print(train_data[0])
#1表示正面品论,0表示负面评论
print(train_labels[0])
#取所有测试单词所有的最大的索引值
print(max([max(sequence) for sequence in train_data]))

num_words=10000是取最常出现的10000个单词,舍弃低频词语。减小向量数据量。
train_data和test_data都是评论的列表(单词索引组成)。
train_labels和test_labels都是0和1组成的列表。
在这里插入图片描述可以看到最大9999就是说索引不超过10000

第一次加载时会下载数据集
在这里插入图片描述

解码为英语

#某条评论解码为英文
word_index = imdb.get_word_index()
reverse_word_index = dict(
    [(value,key) for (key,value) in word_index.items()]
)
decoded_review = ' '.join(
    [reverse_word_index.get(i-3,'?') for i in train_data[0]]
)
print(decoded_review)

imdb.get_word_index()这个是单词映射整数的字典,将其键值颠倒,变为根据索引查单词。
其中i-3是因为0,1,2是为填充,序列开始,未知词保留的索引
在这里插入图片描述

处理数据

import numpy as np

#转换为10000维的向量,索引的位置是1,其他位置是0
def vectorize_sequences(sequences,dimension=10000):
    results = np.zeros((len(sequences),dimension))
    for i, sequence in enumerate(sequences):
        results[i,sequence] = 1.
    return results
#数据向量化
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
print(x_train[0])
#标签向量化
y_train = np.asarray(train_labels).astype('float32')
y_test = np.asarray(test_labels).astype('float32')

将每一条评论的索引值转化到10000维向量上。有该索引值的设为1,其余位置设为0.
例如一条评论为[1,3,5],则变为[0,1,0,1,0,1,0,0,0,0…,0]。
在这里插入图片描述

建立网络

from keras import models
from keras import layers

#模型定义
model = models.Sequential()
model.add(layers.Dense(16,activation='relu',input_shape=(10000,)))
model.add(layers.Dense(16,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
#定义优化器,损失函数,指标
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

模型为三层。

  • 一二层都是16个隐藏单元,输入10000维的向量。做的操作除了张量的操作外,relu是激活函数之一,将线性的变换集合变成非线性。
  • 第三层输出1维,sigmod是将输出归一化,变成[0,1]的概率。

二分类问题,输出是一个概率值(正面评价的概率)。
因此选用交叉熵(crossentropy)来计算损失。
而优化器使用rmsprop。这是前任总结的最优于此示例的,至于为什么选这两个,后面章节在做学习。

验证集

为了在训练过程中监控模型对于未见过的数据上的精度,从数据中去出10000个样本用来做验证集。

#取10000用于验证集
x_val = x_train[:10000] #验证集
partial_x_train = x_train[10000:]

y_val = y_train[:10000] #验证集
partial_y_trail = y_train[10000:]

训练模型

#训练模型
history = model.fit(partial_x_train,
                    partial_y_trail,
                    epochs=20,
                    batch_size=512,
                    validation_data=(x_val,y_val))
                    
history_dict = history.history
print(history_dict.keys())

我们输入训练数据集partial_x_train,训练标签集partial_y_trail。
全部数据训练次数*20。
一次取512的个数据。
history 中包含了训练过程中的所有数据。
validation_data是前一步取出来的验证集。

history包含训练过程中的所有数据。打印history_dict所包含的指标。
在这里插入图片描述
可以看到其中包括验证精度,训练精度,验证损失,训练损失。

我们来用matplot绘制出这几个指标的走势:

损失变化图:

loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']

epochs = range(1,len(loss_values)+1)

plt.plot(epochs,loss_values,'bo',label='Training loss') #bo是蓝色圆点
plt.plot(epochs,val_loss_values,'b',label='Validation loss') #b是蓝色实线
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

精度变化图:

plt.clf() #清空图表

acc_values = history_dict['acc']
val_acc_values  = history_dict['val_acc']

plt.plot(epochs,acc_values,'bo',label='Training acc') #bo是蓝色圆点
plt.plot(epochs,val_acc_values,'b',label='Validation acc') #b是蓝色实线
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

在这里插入图片描述
在这里插入图片描述可以看出训练损失持续降低,训练精度持续增长。这满足梯度下降的优化预期。
但是可以看到验证损失验证精度,可以从图中看到,他们在第三到第四个周期达到了最佳。

总之,在训练数据上表现变好,但是在没有见过的验证数据上表现有变动。这就是过拟合(overfit)。
在第2轮后,对训练数据过度优化,最终学得的结果仅针对训练数据。无法泛华到训练集之外的数据。

就需要一种降低过拟合的方案。在后面的章节再来学习。

这里先用一种粗糙简单的方案。
我们只训练四轮。
修改训练参数epochs=4

#训练模型
history = model.fit(partial_x_train,
                    partial_y_trail,
                    epochs=4,
                    batch_size=512,
                    validation_data=(x_val,y_val))

在调用 fit() 训练之后添加测试集的测试代码,并输出结果:

result = model.evaluate(x_test,y_test)
print(result)

在这里插入图片描述
可见这种粗略的策略达到88%的精度。后面再来研究更好的降低过拟合,多做训练的策略。

训练好的网络,使用predict来对评论进行正面的可能性做预测。

predictResult = model.predict(x_test)
print(predictResult)

在这里插入图片描述
可以看到有非常确定的(>0.99或者<0.01),也有不确信的(0.4~0.6)。

这一节学习就到这。之后可以自行尝试
1.使用1层或者3层隐藏层
2.隐藏单元换成32或者64个
3.用损失函数mse
4.用激活函数tanh

整合上面全部代码:(epochs次数自行修改)

from keras.datasets import imdb
import numpy as np
from keras import models
from keras import layers
import matplotlib.pyplot as plt

(train_data,train_labels),(test_data,test_labels) = imdb.load_data(num_words=10000)
# #第一条评论的单词索引列表
# print(train_data[0])
# #1表示正面品论,0表示负面评论
# print(train_labels[0])
# #取所有测试单词所有的最大的索引值
# print(max([max(sequence) for sequence in train_data]))

# #某条评论解码为英文
# word_index = imdb.get_word_index()
# reverse_word_index = dict(
#     [(value,key) for (key,value) in word_index.items()]
# )
# decoded_review = ' '.join(
#     [reverse_word_index.get(i-3,'?') for i in train_data[0]]
# )
# print(decoded_review)

#转换为10000维的向量,索引的位置是1,其他位置是0
def vectorize_sequences(sequences,dimension=10000):
    results = np.zeros((len(sequences),dimension))
    for i, sequence in enumerate(sequences):
        results[i,sequence] = 1.
    return results
#数据向量化
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
print(x_train[0])
#标签向量化
y_train = np.asarray(train_labels).astype('float32')
y_test = np.asarray(test_labels).astype('float32')

#模型定义
model = models.Sequential()
model.add(layers.Dense(16,activation='relu',input_shape=(10000,)))
model.add(layers.Dense(16,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
#定义优化器,损失函数,指标
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

#取10000用于验证集
x_val = x_train[:10000] #验证集
partial_x_train = x_train[10000:]

y_val = y_train[:10000] #验证集
partial_y_trail = y_train[10000:]

#训练模型
history = model.fit(partial_x_train,
                    partial_y_trail,
                    epochs=4,
                    batch_size=512,
                    validation_data=(x_val,y_val))

history_dict = history.history
print(history_dict.keys())

loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']

epochs = range(1,len(loss_values)+1)

plt.plot(epochs,loss_values,'bo',label='Training loss') #bo是蓝色圆点
plt.plot(epochs,val_loss_values,'b',label='Validation loss') #b是蓝色实线
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.clf() #清空图表

acc_values = history_dict['accuracy']
val_acc_values  = history_dict['val_accuracy']

plt.plot(epochs,acc_values,'bo',label='Training acc') #bo是蓝色圆点
plt.plot(epochs,val_acc_values,'b',label='Validation acc') #b是蓝色实线
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

result = model.evaluate(x_test,y_test)
print(result)

predictResult = model.predict(x_test)
print(predictResult)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习 从零开始 —— 神经网络(四),二分类问题,IMDB数据集使用 的相关文章

  • 使用单个文件的 Python 日志记录(函数名、文件名、行号)

    我正在尝试了解应用程序的工作原理 为此 我将调试命令插入作为每个函数主体的第一行 目的是记录函数的名称以及向日志输出发送消息的行号 代码内 最后 由于这个应用程序由许多文件组成 我想创建一个日志文件 以便我可以更好地理解应用程序的控制流 这
  • Python Numpy Reshape错误[关闭]

    Closed 这个问题是无法重现或由拼写错误引起 help closed questions 目前不接受答案 我在尝试重塑 3D numpy 数组时遇到一个奇怪的错误 数组 x 的形状为 6 10 300 我想将其重塑为 6 3000 我正
  • 高效地将大型 Pandas 数据帧写入磁盘

    我正在尝试找到使用 Python Pandas 高效地将大型数据帧 250MB 写入磁盘或从磁盘写入的最佳方法 我已经尝试了所有方法Python 数据分析 但表现却非常令人失望 这是一个更大项目的一部分 该项目探索将我们当前的分析 数据管理
  • 检查子字符串是否在字符串列表中?

    我之前已经找到了这个问题的一些答案 但它们对于当前的Python版本来说似乎已经过时了 或者至少它们对我不起作用 我想检查字符串列表中是否包含子字符串 我只需要布尔结果 我找到了这个解决方案 word to check or wordlis
  • 如何将脚本作为 pytest 测试运行

    假设我有一个用简单脚本表示的测试assert 陈述 请参阅背景了解原因 例如 import foo assert foo 3 4 我如何以一种好的方式将该脚本包含在我的 pytest 测试套件中 我尝试了两种有效但不太好的方法 一种方法是将
  • 如何调试 numpy 掩码

    这个问题与this one https stackoverflow com q 73672739 11004423 我有一个正在尝试矢量化的函数 这是原来的函数 def aspect good angle float planet1 goo
  • 当我从本地计算机更改为虚拟主机时,从 python 脚本调用 pdftotext 不起作用

    我编写了一个小的 python 脚本来解析 提取 PDF 中的信息 我在本地机器上测试了它 我有 python 2 6 2 和 pdftotext 版本 0 12 4 我正在尝试在我的虚拟主机服务器 dreamhost 上运行它 它有 py
  • 样本()和r样本()有什么区别?

    当我从 PyTorch 中的发行版中采样时 两者sample and rsample似乎给出了类似的结果 import torch seaborn as sns x torch distributions Normal torch tens
  • 一个类似 dict 的 Python 类

    我想编写一个自定义类 其行为类似于dict 所以 我继承自dict 不过 我的问题是 我是否需要创建一个私有的dict我的成员 init 方法 我不明白这个有什么意义 因为我已经有了dict如果我只是继承自的行为dict 谁能指出为什么大多
  • 错误:尝试使用 scrappy 登录时出现 raise ValueError("No element found in %s" % response)

    问题描述 我想从我大学的bbs上抓取一些信息 这是地址 http bbs byr cn http bbs byr cn下面是我的蜘蛛的代码 from lxml import etree import scrapy try from scra
  • 将 Python Selenium 输出写入 Excel

    我编写了一个脚本来从在线网站上抓取产品信息 目标是将这些信息写入 Excel 文件 由于我的Python知识有限 我只知道如何在Powershell中使用Out file导出 但结果是每个产品的信息都打印在不同的行上 我希望每种产品都有一条
  • 操作错误:尝试在 ubuntu 服务器中写入只读数据库

    我正在使用 FlaskApp 运行mod wsgi and apache2在 Ubuntu 服务器上 我尝试运行烧瓶应用程序localhost成功 然后部署到ubuntu服务器上 但是当我尝试更新数据库时 出现错误 Failed to up
  • 无法将matplotlib安装到pycharm

    我最近开始使用Python速成课程学习Python编程 我陷入困境 因为我无法让 matplotlib 在 pycharm 中工作 我已经安装了pip 我已经通过命令提示符使用 pip 安装了 matplotlib 现在 当我打开 pych
  • 如何使用logging.conf文件使用RotatingFileHandler将所有内容记录到文件中?

    我正在尝试使用RotatingHandler用于 Python 中的日志记录目的 我将备份文件保留为 500 个 这意味着我猜它将创建最多 500 个文件 并且我设置的大小是 2000 字节 不确定建议的大小限制是多少 如果我运行下面的代码
  • 在不同的 GPU 上同时训练多个 keras/tensorflow 模型

    我想在 Jupyter Notebook 中同时在多个 GPU 上训练多个模型 我正在使用 4GPU 的节点上工作 我想将一个 GPU 分配给一个模型并同时训练 4 个不同的模型 现在 我通过 例如 为一台笔记本选择 GPU import
  • 避免“散点/点/蜂群”图中的数据点重叠

    使用绘制点图时matplotlib 我想偏移重叠的数据点以使它们全部可见 例如 如果我有 CategoryA 0 0 3 0 5 CategoryB 5 10 5 5 10 我想要每一个CategoryA 0 数据点并排设置 而不是彼此重叠
  • 为什么我的 PyGame 应用程序根本不运行?

    我有一个简单的 Pygame 程序 usr bin env python import pygame from pygame locals import pygame init win pygame display set mode 400
  • python 日志记录替代方案 [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 蟒蛇记录模块 http docs python org library logging html使用起来
  • PYTHON:从 txt 文件中删除 POS 标签

    我有以下 txt 文件 其中包含 POS 词性 http en wikipedia org wiki Part of speech tagging 每个单词的标签 不用 jj到 说 vb 我 ppss是 bedz愤怒 jj在 在 dt无与伦
  • 防止 Ada DLL 中的名称损坏

    有没有一种简单的方法可以防止在创建 Ada DLL 时 Ada 名称被破坏 这是我的 adb 代码 with Ada Text IO package body testDLL is procedure Print Call is begin

随机推荐

  • 蓝牙简单配对(Simple Pairing)协议及代码流程简述

    DESCRIPTION 在BT2 1及之后版本 蓝牙协议有在传统的密码配对 PIN Code Pairing 之外 新增一种简单配对 Simple Pairing 的方式 这种新的配对方式操作更为简单 安全性也更强 目前市面上大部分蓝牙设备
  • 《算法笔记》01

    1 比较交换3个实数值 并按序输出 从键盘输入3个实数a b c 通过比较交换 将最小值存储在变量a中 最大值存储在变量c中 中间值存储在变量b中 并按照从小到大的顺序输出这三个数a b c 末尾输出换行 include
  • kafka客户端连接测试

    客户端代码 package main import fmt github com Shopify sarama kafka 示例代码 func main 配置 config sarama NewConfig 等待服务器所有副本都保存成功后的
  • 【Qt多线程之线程的等待和唤醒】

    QWatiCondition的成员函数 QWaitCondition QWaitCondition bool wait QMutex mutex unsigned long time ULONG MAX void wakeOne void
  • 哪些你朝思暮想的动漫网站-搜嗖工具箱

    AcFun是国内首家弹幕视频网站 这里有全网独家动漫新番 友好的弹幕氛围 有趣的UP主 好玩有科技感的虚拟偶像 年轻人都在用www acfun cn 哔哩哔哩是国内知名的视频弹幕网站 这里有及时的动漫新番 活跃的ACG氛围 有创意的Up主
  • docker pull 设置代理

    简介 你在终端设置代理的时候docker pull的时候是不会走代理的 下面是docker pull设置代理的正确方式 操作 环境是在centos下 如果没有新建下面这个文件夹 sudo mkdir p etc systemd system
  • 毕业设计-基于大数据的新闻推荐系统-python

    目录 前言 课题背景和意义 实现技术思路 实现效果图样例 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求的毕设项目越来越难 有不少课题是研究生级别难度
  • 组件的生命周期

    一 组件的生命周期 1 组件的生命周期 至一个组件从 创建 gt 运行 gt 销毁的过程 2 声明周期函数 由Vue提供的内置函数 伴随组件生命周期按次序自动运行 gt 钩子函数 3 生命周期的阶段划分 1 创建阶段 beforeCreat
  • Flutter从入门到放弃之坑的神奇之处?

    坑一 关于环境变量的配置 这里要注意几点 不然你将会在这里卡死 这里只说Mac OS环境变量的配置 因为我是Mac 首先 command shift 打开隐藏文件 如果你是用的是自带的终端 请在这个文件中配置 如果你使用的是zsh请在 这个
  • Pytorch nn.Module的基本使用

    文章目录 nn Module的基本用法 nn Module的其他常用方法 参考资料 nn Module的基本用法 nn Module是所有神经网络的基类 所以你的神经网络类也应该要继承这个基类 当使用时 主要需要实现其两个方法 init 初
  • SPSS问卷数据处理步骤

    SPSS问卷数据处理步骤 一 准备 界面与数据准备工作 1 先处理显示界面问题 改成中文输出 优化操作过程 编辑 选项 2 数据字典 定义变量属性 几个代表性的 复制数据属性 数据 定义变量属性 设好以后 数据 复制数据属性 把几个代表性的
  • SQL 左连接 右表多条数据处理方式

    前言 多表连接经常使用 而一对多的数据处理比较麻烦 用于记录 便于以后使用 正文 SQL server语句 select a t from table1 a left join select from select b ROW NUMBER
  • STL模板库 常用函数 vector向量容器

    STL模板库 STL是Standard Template Library缩写 中文名字叫标准模板库 由惠普实验室提供 共有三类内容 算法 以函数模板形式实现的常用算法 如 max min swap find sort 容器 以类模板形式实现
  • 做一个 加减运算 利用JavaScript

    首先做个运算需要用到三个文本框 显示数字 这里我用input做了3个框 并且赋值他们的属性值 id 并且做了一个button按钮来调动 接着调用button这个函数 接着我们需要获取第一个和第第二个input的值 为什么用 parseInt
  • 【Linux专题_05】Linux统计行数命令

    Linux统计行数几种常用命令 wc l 这是最常用的命令 用于统计文件中的行数 它会输出文件的行数以及文件名 示例 wc l filename txt nl 该命令会给文件中的每一行添加行号 并将结果输出到标准输出 通过查看行号的最后一个
  • LeetCode 220. 存在重复元素 III

    题目链接 点击这里 class Solution public boolean containsNearbyAlmostDuplicate int nums int k int t TreeSet
  • Android Studio解除全局搜索100条限制

    1 点击Help gt Find Action选项 2 输入Registry 并选中进入 3 将ide usages page size的value设置为自己想要的数值即可
  • 修改块的方法+AcGeMatrix3d+AcGeScale3d+两点之间的距离

    开发过程中 当从外部获取了一个 需要修改块中的实体时 有以下几种方法 1 第一个通过explode函数炸开AcDbBlockReference 获取块参照中的实体对象 然后遍历对象 找到修改的实体 完成修改后将所有的实体插入到模型空间 注意
  • 第四章CSS基础

    文章目录 学习CSS的目的 引入的三种方式 内部样式表 行内样式表 外部样式表 选择器的分类 基础选择器 标签选择器 类选择器 id选择器 通配符选择器 复合选择器 后代选择器 子选择器 并集选择器 伪类选择器 盒子模型 不同浏览器下盒子模
  • 深度学习 从零开始 —— 神经网络(四),二分类问题,IMDB数据集使用

    IMDB数据集 互联网电影数据 包含50000条严重两极分化的评论 正面和负面评论各占50 而该数据集也同样被内置于Keras库中了 其中的评论数据已经经过了预处理 评论 单词 被转化为了整数序列 每个整数都对应词典里面的一个单词 加载数据