《Keras深度学习:入门、实战与进阶》CIFAR-10图像识别

2023-11-17

本文摘自《Keras深度学习:入门、实战与进阶》。
https://item.jd.com/10038325202263.html
在这里插入图片描述
这个数据集由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集整理,共包含了60000张32×32的彩色图像,50000张用于训练模型、10000张用于评估模型。可以从其主页(http://www.cs.toronto.edu/~kriz/cifar.html)下载。共有10个类别,它们是:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。每个分类有6000个图像。
在这里插入图片描述

1、加载CIFAR-10数据

Keras提供了dataset_cifar10()函数用于下载或读取CIFAR-10数据。第一次运行dataset_cifar10()时,程序会检查是否有cifar-10-batches-py.tar.gz文件,如果还没有,就会下载文件,并且解压下载的文件。第一次运行因为需要下载文件,所以运行时间可能会比较长,之后就可以直接从本地加载数据,用于神经网络模型的训练。
如果是Windows环境,文件将存放在C:\Users\用户名\Documents.keras\datasets中。我们来查看解压后的cifar-10-batches-py目录下的内容。

# 查看cifar-10目录下的文件
> file <- 'C:/Users/Daniel/Documents/.keras/datasets/cifar-10-batches-py'
> list.files(file) # 查看目录下文件
[1] "batches.meta" "data_batch_1" "data_batch_2" "data_batch_3" "data_batch_4"
[6] "data_batch_5" "readme.html"  "test_batch"

CIFAR-10数据集分为训练集和测试集两部分。训练集构成了5个训练批次(data_batch_1、data_batch_2、data_batch_3、data_batch_4、data_batch_5),每一批次10000张图。另外用于测试的10000张图单独构成一批(test_batch)。注意一个训练批次中的各类图像数量并不一定相同,总的训练样本包含来自每一类的5000张图。数据导入时,会直接被分割成训练集和测试集两部分,训练和测试数据又由图像数据和标签所组成。

> library(keras)
> c(c(x_train,y_train),c(x_test,y_test)) %<-% dataset_cifar10()
> # 查看数据维度
> dim(x_train);dim(x_test)
[1] 50000    32    32     3
[1] 10000    32    32     3
> dim(y_train);dim(y_test)
[1] 50000     1
[1] 10000     1

train训练数据集有50000项,test测试数据集10000项。x_train和x_test是四维数组,第一维是样本数,第二、三维是指图像大小为32×32,第四维是RGB三原色,所以是3。y_train和y_test是矩阵(二维数组),第一维是样本数,第二维是图像数据的实际真实值。每一个数字代表一种图像类别的名称:0:飞机(airplane)airplane、1:汽车(automobile)automobile、2:鸟(bird)bird、3:猫(cat)、43:鹿(deer)deer、4:dog、5:狗(dog)、6:青蛙(frog)frog、7:马(horse)horse、8:船(ship)ship、9:卡车(truck)truck。
运行以下程序代码,绘制train数据集中前10张图像

> # 绘制前10张图像
> label_dict <- data.frame('label' = 0:9,
+                          'name' = c("airplane","automobile","bird","cat","deer",
+                                     "dog","frog","horse","ship","truck"))
> 
> par(mfrow=c(2,5))
> for(i in 1:10){
+   plot(as.raster(x_train[i,,,],max=255))
+   title(main = paste0(i-1,",",
+                       label_dict[label_dict$label==y_train[i],2]))
+ }
> par(mfrow=c(1,1))

在这里插入图片描述

2、CIFAR-10数据预处理

为了将数据送入卷积神经网络模型进行训练与预测,必须进行数据的预处理。前面的维度分析可知,x_train和x_test的图像数据已经是四维数组,符合卷积神经网络模型的维度要求。

> x_train <- x_train / 255
> x_test <- x_test / 255
> min(x_train);max(x_train)
[1] 0
[1] 1
> min(x_test);max(x_test)
[1] 0
[1] 1

对于CIFAR-10数据集,我们希望预测图像的类型,例如“船”图像的label是8,经过独热编码(One-Hot Encoding)转换为0000000010,10个数字正好对应输出层10个神经元。可以利用to_categorical()函数进行转换。

> y_train_onehot <- to_categorical(y_train,num_classes = 10)
> y_test_onehot <- to_categorical(y_test, num_classes = 10)
> dim(y_train_onehot)
[1] 50000    10
> dim(y_test_onehot)
[1] 10000    10

3、构建简单卷积神经网络识别CIFAR-10图像

首先构建一个简单的卷积神经网络,来验证卷积神经网络在这个数据集上的性能,并以此为基础对网络进行优化,逐步提高模型的准确度。
这个简单的卷积神经网络具有两个卷积层、一个最大值池化层、一个Flatten层和一个全连接层,网络拓扑结构如下:
卷积层,具有32个特征图,卷积核大小为3×3,激活函数为Relu。
Dropout概率为20%的Dropout层。
卷积层,具有32个特征图,卷积核大小为3×3,激活函数为Relu。
Dropout概率为20%的Dropout层。
采样因子(pool_size)为2×2的最大值池化层。
Flatten层。
具有512个神经元和ReLU激活函数的全连接层。
Dropout概率为50%的Dropout层。
具有10个神经元的输出层,激活函数为softmax。
编译模型时,采用RMSProp优化器,categorical_crossentropy作为损失函数,同时采用准确率(accuracy)来评估模型的性能。
构建模型build_simple_cnn ()程序代码如下。

> build_simple_cnn <- function(X=trainx) {
+   model <- keras_model_sequential() %>%
+     layer_conv_2d(filters = 32, 
+                   kernel_size = c(3,3),
+                   activation = 'relu',
+                   input_shape = dim(X)[-1]) %>%
+     layer_dropout(rate = 0.2) %>%
+     layer_conv_2d(filters = 32, 
+                   kernel_size = c(3,3),
+                   activation = 'relu') %>%
+     layer_dropout(rate = 0.2) %>%
+     layer_max_pooling_2d(pool_size = c(2,2)) %>%
+     layer_flatten() %>%
+     layer_dense(units = 512, activation = 'relu') %>% 
+     layer_dropout(rate = 0.5) %>%
+     layer_dense(units = 10, activation = 'softmax')
+   # Compile
+   model %>% compile(
+     loss = 'categorical_crossentropy',
+     optimizer = optimizer_rmsprop(),
+     metrics = 'accuracy')
+   model
+ }

模型构建后,使用fit()函数进行模型训练。将训练周期参数epochs设置为25,batch_size参数为256,validation_split参数为0.2,说明从训练样本中抽取20%作为验证集。`

> simple_cnn_model <- build_simple_cnn(x_train)
> history <- simple_cnn_model %>%
+   fit(x_train,
+       y_train_onehot,
+       epochs = 25,
+       batch_size = 256,
+       validation_split = 0.2)
> plot(history)

在这里插入图片描述
经过30个训练周期后,训练集的准确率为93%,验证集的准确率为70%,出现过拟合现象。可使用当监测值不再改善时将终止训练的callback_early_stopping()回调函数来监控模型,防止出现过拟合现象。
利用训练好的简单卷积神经网络模型对测试进行预测,并查看混淆矩阵。

> pred <- simple_cnn_model %>% predict_classes(x_test)
> t <- table(Actual = y_test,Predicted = pred)
> t
      Predicted
Actual   0   1   2   3   4   5   6   7   8   9
     0 788  24  28  18  27   1  12  10  55  37
     1  23 817   5  13   3   2   8   4  27  98
     2 100  11 470  85 117  68  61  47  22  19
     3  37  18  46 525 102 137  42  38  22  33
     4  34   4  33  63 691  33  47  69  13  13
     5  23   9  38 226  63 535  18  57  14  17
     6  14  12  28  73  57  25 755   8  11  17
     7  27   4  25  43  78  40   4 739   8  32
     8  77  46   7  19   5   2   4   6 795  39
     9  44  98   7  14   4   1   6  13  24 789

模型对汽车(1:automobile)的预测能力最好,有817个样本被正确预测,准确率超过81%;其次是船(8:ship),有795个样本被正确预测。
最后,让我们绘制实际是鸟,但预测错误的50张图像

> ind <- which(as.vector(y_test)==2 & pred != 2) # 提取实际为2,但预测不为2的下标集
> # 绘制预测错误的图像
> par(mfrow=c(5,10)) 
> for(i in 1:50){
+   plot(as.raster(x_test[ind[i],,,]))
+   title(main = paste0(label_dict[label_dict$label==y_test[ind[i]],2],">>",
+                   label_dict[label_dict$label==pred[ind[i]],2]))
+ 
+ }
> par(mfrow=c(1,1))

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

《Keras深度学习:入门、实战与进阶》CIFAR-10图像识别 的相关文章

  • 保存为 HDF5 的图像未着色

    我目前正在开发一个将文本文件和 jpg 图像转换为 HDF5 格式的程序 用HDFView 3 0打开 似乎图像仅以灰度保存 hdf h5py File Sample h5 img Image open Image jpg data np
  • 为什么从 Pandas 1.0 中删除了日期时间?

    我在 pandas 中处理大量数据分析并每天使用 pandas datetime 最近我收到警告 FutureWarning pandas datetime 类已弃用 并将在未来版本中从 pandas 中删除 改为从 datetime 模块
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 使用 Python 从文本中删除非英语单词

    我正在 python 上进行数据清理练习 我正在清理的文本包含我想删除的意大利语单词 我一直在网上搜索是否可以使用像 nltk 这样的工具包在 Python 上执行此操作 例如给出一些文本 Io andiamo to the beach w
  • 跟踪 pypi 依赖项 - 谁在使用我的包

    无论如何 是否可以通过 pip 或 PyPi 来识别哪些项目 在 Pypi 上发布 可能正在使用我的包 也在 PyPi 上发布 我想确定每个包的用户群以及可能尝试积极与他们互动 预先感谢您的任何答案 即使我想做的事情是不可能的 这实际上是不
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 将 python2.7 与 Emacs 24.3 和 python-mode.el 一起使用

    我是 Emacs 新手 我正在尝试设置我的 python 环境 到目前为止 我已经了解到在 python 缓冲区中使用 python mode el C c C c将当前缓冲区的内容加载到交互式 python shell 中 显然使用了什么
  • 如何将张量流模型部署到azure ml工作台

    我在用Azure ML Workbench执行二元分类 到目前为止 一切正常 我有很好的准确性 我想将模型部署为用于推理的 Web 服务 我真的不知道从哪里开始 azure 提供了这个doc https learn microsoft co
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • 从Python中的字典列表中查找特定值

    我的字典列表中有以下数据 data I versicolor 0 Sepal Length 7 9 I setosa 0 I virginica 1 I versicolor 0 I setosa 1 I virginica 0 Sepal
  • 如何在不丢失注释和格式的情况下更新 YAML 文件 / Python 中的 YAML 自动重构

    我想在 Python 中更新 YAML 文件值 而不丢失 Python 中的格式和注释 例如我想改造 YAML 文件 value 456 nice value to value 6 nice value 界面类似于 y yaml load
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • 使用 z = f(x, y) 形式的 B 样条方法来拟合 z = f(x)

    作为一个潜在的解决方案这个问题 https stackoverflow com questions 76476327 how to avoid creating many binary switching variables in gekk

随机推荐

  • Linux中创建文件与文件夹

    一 创建文件夹 命令 mkdir 文件夹名 例 一开始home目录下没有test文件夹 命令创建后生成 二 创建文件 命令 touch 文件名 例 一开始test文件夹下没有boot properties 命令创建后生成 三 注意事项 创建
  • linux下挂载移动硬盘(ntfs格式),Linux下挂载移动硬盘(NTFS格式)

    工作中遇到linux系统 Red Hat Enterprise5 7 挂载希捷ntfs格式移动硬盘 会跳出一个ERROR提示框 The volume EAGET NQH user the ntfs file system which is
  • arcpy导入报错 “ImportRrror: No module named arcpy”

    在使用ArcGIS自带的Python IDLE处理数据的时候 导入arcpy报错 ImportError No module named arcpy 我遍历了各解决方法依然无法成功导入arcpy 后经过查询 探索 通过如下方法得以成功解决
  • aoj1303

    继续python系列 python能够自动推断类型这个太好用了 根本不用声明类型 自己根据运行情况推断出所用的类型 所以在定义函数的时候根本不用声明参数的类型 下面这个题目aoj1303 求2的指数 如下 def gethex a li w
  • 关于飞书的告警通知,这里有个更好的办法

    飞书 是字节跳动于2016年自研的新一代一站式协作平台 是保障字节跳动全球五万人高效协作的办公工具 飞书将即时沟通 日历 云文档 云盘和工作台深度整合 通过开放兼容的平台 让成员在一处即可实现高效的沟通和流畅的协作 全方位提升企业效率 20
  • Git 使用

    Git 一 Git基础 1 Git介绍 Git是目前世界上最先进的分布式版本控制系统 2 Git与Github 2 1 两者区别 Git是一个分布式版本控制系统 简单的说其就是一个软件 用于记录一个或若干文件内容变化 以便将来查阅特定版本修
  • 模板类、模板函数的模板类型显式实例化及其用途(转载)

    转载自 C 11模板隐式实例化 显式实例化声明 定义 简单易懂 云飞扬 Dylan的博客 CSDN博客 模板隐式实例化 1 隐式实例化 在代码中实际使用模板类构造对象或者调用模板函数时 编译器会根据调用者传给模板的实参进行模板类型推导然后对
  • 【LAMMPS系列】LAMMPS软件安装资料包

    大家好 我是粥粥 LAMMPS 是一种经典的分子动力学代码 专注于材料建模 它是大型原子 分子大规模并行模拟器的首字母缩略词 LAMMPS 具有固态材料 金属 半导体 和软物质 生物分子 聚合物 以及粗粒或中等系统的势函数 它可用于模拟原子
  • 自定义多数据源JDBC连接池

    背景 公司需要对各个客户的数据库进行统一管理 故涉及到对多个不同数据库进行连接 传统的数据库连接池无法满足需求 故结合网上的自定义数据库连接池 进行的改进 代码如下 注意 由于代码处于公司环境 有直接使用肯定是会有报错 相信这种简单的修补是
  • android Stopwatch实例

    Stopwatch 实例 package net baisoft stopwatch import java util ArrayList import java util Date import java util HashMap imp
  • electron vue 打开新窗口

    1 主进程 background js文件 const winURL process env NODE ENV development http localhost 8080 file dirname index html 事件名 open
  • 网页设计期末大作业-景点旅游网站(含导航栏,轮播图,样式精美)

    景点旅游网站 资源链接在文末 页设计期末结课的作业 样式很精美 链接基本正常 详细情况入下图所示 资源下载链接 https download csdn net download weixin 43474701 85514120
  • AIX显示版本的最高全包含版本原则

    复杂度2 5 机密度4 5 最后更新2021 05 02 专题其它章节说过AIX对所有程序包管理会检验完整性 并且内置了一个验证列表 包含其所能识别的最新版应当包含的各个程序包的版本 如果当前安装的TL Patch不完整 则只会显示可以实现
  • CSS transform属性的简单应用——双开门动画效果

    1 效果演示 CSS transform属性有许多效果 平移 旋转 缩放等 这里简单应用平移效果 实现双开门动画 以下为效果图 2 设计思路 设置一张居中的需要隐藏的底图 设置封面图 平分成左右两部分 鼠标悬浮在封面图上 触发 开门 效果
  • 在C/C++代码中使用SSE等指令集的指令(4)SSE指令集Intrinsic函数使用

    在http blog csdn net gengshenghong article details 7008682里面列举了一些手册 其中Intel Intrinsic Guide可以查询到所有的Intrinsic函数 对应的汇编指令以及如
  • centos7的安装和创建用户

    1 centos7 2的安装 打开安装包之后解压 然后双击 进入下面的界面 选择语言 点击下一步 2 然后来到了配置页面 可以配置时间 选择中国的时区 3 其他的选择默认就好 重要的是选择安装类型和磁盘分区 4 选择安装类型 一般默认是mi
  • npm开发微信小程序--使用vantui 详解干货

    更新微信开发者工具创建项目 1 创建项目 放在一个合适的文件夹中 没有APPID时 请点击测试号 或去注册一个 2 进入项目的根目录 npm init 一路回车 要先npm init 初始化项目 否则会报错 官方文档中没有提到的东东 里面有
  • 爬虫实战——58同城租房数据爬取

    背景 自己本人在暑期时自学了python 还在中国大学mooc上学习了一些爬虫相关的知识 对requests库 re库以及BeautifulSoup库有了一定的了解 但是没有过爬虫方面的实战 刚好家人有这方面需求 就对58同城上的租房数据进
  • 简单工厂模式

    提示 文章写完后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 前言 一 创建头文件 二 创建 c文件 1 cat c 2 dog c 3 person c 三 创建main c 四 运行结果 总结 前言 工厂模式 常用的设计模
  • 《Keras深度学习:入门、实战与进阶》CIFAR-10图像识别

    本文摘自 Keras深度学习 入门 实战与进阶 https item jd com 10038325202263 html 这个数据集由Alex Krizhevsky Vinod Nair和Geoffrey Hinton收集整理 共包含了6