了解VGG网络结构特点,利用VGG完成图像分类

2023-11-05

学习目标

  • 知道VGG网络结构的特点
  • 能够利用VGG完成图像分类

2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司的研究员一起研发出了新的深度卷积神经网络:VGGNet,并取得了ILSVRC2014比赛分类项目的第二名,主要贡献是使用很小的卷积核(3×3)构建卷积神经网络结构,能够取得较好的识别精度,常用来提取图像特征的VGG-16和VGG-19

1.VGG的网络架构

VGG可以看成是加深版的AlexNet,整个网络由卷积层和全连接层叠加而成,和AlexNet不同的是,VGG中使用的都是小尺寸的卷积核(3×3),其网络架构如下图所示:

VGGNet使用的全部都是3x3的小卷积核和2x2的池化核,通过不断加深网络来提升性能。VGG可以通过重复使用简单的基础块来构建深度模型

在tf.keras中实现VGG模型,首先来实现VGG块,它的组成规律是:连续使用多个相同的填充为1、卷积核大小为3\times 3的卷积层后接上一个步幅为2、窗口形状为2\times 2的最大池化层。卷积层保持输入的高和宽不变,而池化层则对其减半。我们使用vgg_block函数来实现这个基础的VGG块,它可以指定卷积层的数量num_convs和每层的卷积核个数num_filters:

# 定义VGG网络中的卷积块:卷积层的个数,卷积层中卷积核的个数
def vgg_block(num_convs, num_filters):
    # 构建序列模型
    blk = tf.keras.models.Sequential()
    # 遍历所有的卷积层
    for _ in range(num_convs):
        # 每个卷积层:num_filter个卷积核,卷积核大小为3*3,padding是same,激活函数是relu
        blk.add(tf.keras.layers.Conv2D(num_filters,kernel_size=3,
                                    padding='same',activation='relu'))
    # 卷积块最后是一个最大池化,窗口大小为2*2,步长为2
    blk.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
    return blk

VGG16网络有5个卷积块,前2块使用两个卷积层,而后3块使用三个卷积层。第一块的输出通道是64,之后每次对输出通道数翻倍,直到变为512。

# 定义5个卷积块,指明每个卷积块中的卷积层个数及相应的卷积核个数
conv_arch = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))

因为这个网络使用了13个卷积层和3个全连接层,所以经常被称为VGG-16,通过制定conv_arch得到模型架构后构建VGG16:

# 定义VGG网络
def vgg(conv_arch):
    # 构建序列模型
    net = tf.keras.models.Sequential()
    # 根据conv_arch生成卷积部分
    for (num_convs, num_filters) in conv_arch:
        net.add(vgg_block(num_convs, num_filters))
    # 卷积块序列后添加全连接层
    net.add(tf.keras.models.Sequential([
        # 将特征图展成一维向量
        tf.keras.layers.Flatten(),
        # 全连接层:4096个神经元,激活函数是relu
        tf.keras.layers.Dense(4096, activation='relu'),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 全连接层:4096个神经元,激活函数是relu
        tf.keras.layers.Dense(4096, activation='relu'),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 全连接层:10个神经元,激活函数是softmax
        tf.keras.layers.Dense(10, activation='softmax')]))
    return net
# 网络实例化
net = vgg(conv_arch)

我们构造一个高和宽均为224的单通道数据样本来看一下模型的架构:

# 构造输入X,并将其送入到net网络中
X = tf.random.uniform((1,224,224,1))
y = net(X)
# 通过net.summay()查看网络的形状
net.summay()

网络架构如下:

Model: "sequential_15"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_16 (Sequential)   (1, 112, 112, 64)         37568     
_________________________________________________________________
sequential_17 (Sequential)   (1, 56, 56, 128)          221440    
_________________________________________________________________
sequential_18 (Sequential)   (1, 28, 28, 256)          1475328   
_________________________________________________________________
sequential_19 (Sequential)   (1, 14, 14, 512)          5899776   
_________________________________________________________________
sequential_20 (Sequential)   (1, 7, 7, 512)            7079424   
_________________________________________________________________
sequential_21 (Sequential)   (1, 10)                   119586826 
=================================================================
Total params: 134,300,362
Trainable params: 134,300,362
Non-trainable params: 0
__________________________________________________________________

2.手写数字势识别

因为ImageNet数据集较大训练时间较长,我们仍用前面的MNIST数据集来演示VGGNet。读取数据的时将图像高和宽扩大到VggNet使用的图像高和宽224。这个通过tf.image.resize_with_pad来实现。

2.1 数据读取

首先获取数据,并进行维度调整:

import numpy as np
# 获取手写数字数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 训练集数据维度的调整:N H W C
train_images = np.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
# 测试集数据维度的调整:N H W C
test_images = np.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))

由于使用全部数据训练时间较长,我们定义两个方法获取部分数据,并将图像调整为224*224大小,进行模型训练:

# 定义两个方法随机抽取部分样本演示
# 获取训练集数据
def get_train(size):
    # 随机生成要抽样的样本的索引
    index = np.random.randint(0, np.shape(train_images)[0], size)
    # 将这些数据resize成22*227大小
    resized_images = tf.image.resize_with_pad(train_images[index],224,224,)
    # 返回抽取的
    return resized_images.numpy(), train_labels[index]
# 获取测试集数据 
def get_test(size):
    # 随机生成要抽样的样本的索引
    index = np.random.randint(0, np.shape(test_images)[0], size)
    # 将这些数据resize成224*224大小
    resized_images = tf.image.resize_with_pad(test_images[index],224,224,)
    # 返回抽样的测试样本
    return resized_images.numpy(), test_labels[index]

调用上述两个方法,获取参与模型训练和测试的数据集:

# 获取训练样本和测试样本
train_images,train_labels = get_train(256)
test_images,test_labels = get_test(128)

为了让大家更好的理解,我们将数据展示出来:

# 数据展示:将数据集的前九个数据集进行展示
for i in range(9):
    plt.subplot(3,3,i+1)
    # 以灰度图显示,不进行插值
    plt.imshow(train_images[i].astype(np.int8).squeeze(), cmap='gray', interpolation='none')
    # 设置图片的标题:对应的类别
    plt.title("数字{}".format(train_labels[i]))

结果为:

我们就使用上述创建的模型进行训练和评估。

2.2 模型编译

# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0)

net.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

2.3 模型训练

# 模型训练:指定训练数据,batchsize,epoch,验证集
net.fit(train_images,train_labels,batch_size=128,epochs=3,verbose=1,validation_split=0.1)

训练输出为:

Epoch 1/3
2/2 [==============================] - 34s 17s/step - loss: 2.6026 - accuracy: 0.0957 - val_loss: 2.2982 - val_accuracy: 0.0385
Epoch 2/3
2/2 [==============================] - 27s 14s/step - loss: 2.2604 - accuracy: 0.1087 - val_loss: 2.4905 - val_accuracy: 0.1923
Epoch 3/3
2/2 [==============================] - 29s 14s/step - loss: 2.3650 - accuracy: 0.1000 - val_loss: 2.2994 - val_accuracy: 0.1538

2.4 模型评估

# 指定测试数据
net.evaluate(test_images,test_labels,verbose=1)

输出为:

4/4 [==============================] - 5s 1s/step - loss: 2.2955 - accuracy: 0.1016
[2.2955007553100586, 0.1015625]

如果我们使用整个数据集训练网络,并进行评估的结果:

[0.31822608125209806, 0.8855]

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

了解VGG网络结构特点,利用VGG完成图像分类 的相关文章

  • 总结的快速排序

    很多时候对快速排序的具体格式记得都不太清楚 在网上搜了一下 加上自己的理解就摆到了这里 先声明一下 头文件至少要包括以下几个 最好都写上 写上不扣分 include
  • 2023年超实用的27个VSCode插件推荐

    Visual Studio Code 或者称作VS Code 是一个广为人知且评价很高的代码编辑器 它有许多特性和扩展功能 以增强开发体验 使用VS Code的主要好处之一是它的灵活性 允许开发人员根据自己的特定需求进行自定义 此外 VS
  • H2介绍 – Java嵌入式数据库

    H2是一个用Java开发的嵌入式数据库 这里指的嵌入式不是手持设备之类的 而是H2数据库作为一个类库 直接嵌入到上层的应用程序中 与应用运行在同一个进程中 最大的优势在于可以同应用程序打包在一起发布 对于客户端应用来说 非常方便 比如说腾讯
  • 计算机中CPU的运行到函数的调用过程

    以下内容是摘抄博客 https www cnblogs com liunlls p cpu html CPU的内部结构 我们都知道CPU是一台电脑的核心部件 所有的程序都是通过它运行的 那么CPU是如何让一个程序跑起来的呢 我们今天就来一起
  • 测多少数据量?几个G?多少reads?如何换算?

    关键词 lncRNA表达量低 所以要看lncRNA的表达量变化 就要比普通RNA seq多测一些 要兼顾SNP和低表达量的lncRNA 要测得更深一些 到底需要测多少数据量呢 我们看看权威的ENCODE对RNA seq的测序深度是如何评价的
  • vue自学笔记(1)

    环境配置 vue官网 编写一个helloworld程序 我使用的是vscode 你可以使用官方网站推荐的hbuilderx 项目结构 导入vue js html中的代码 h1 hello world h1 hr div message di

随机推荐

  • SQLi LABS Less-19

    第19关使用POST请求提交参数 后端对用户名和密码进行了特殊字符转译 难度较大源码如下 但后面插入HTTP Referer时 并没有对参数进行过滤 我们可以从Referer入手 首先 输入正确的账号和密码 只有账号和密码都正确 才能操作R
  • MySQL8数据库原理与应用(微课版)课后笔记-实训7

    最近学习笔记记录 仅供学习参考 在完成课后实训7前所需的建表语句如下 CREATE TABLE bmdmb bmh varchar 10 NOT NULL COMMENT 部门号 bmmc char 50 NOT NULL COMMENT
  • 理解ConvNeXt网络(结合代码)

    目录 1 简介 2 ConvNeXt的设计与实验 2 1 macro design 大的结构上的设计 2 1 1 Changing stage compute ratio 改变每个stage的堆叠次数 2 1 2 Changing stem
  • 简单制作后台系统页面(含菜单)

    第一步 制作数据库表 我个人喜欢在PowerDesigner先建好数据库模型 然后导入到mysql里 导入方式 在PowerDesigner导航栏点开Datebase选择Datebase Generation再选择最后的preview 然后
  • Day.js 常用用法

    我认为克服恐惧最好的办法理应是 面对内心所恐惧的事情 勇往直前地去做 直到成功为止 罗斯福 Day js 时间戳转换 const nowTime dayjs format console log 获取当前时间 nowTime const n
  • GPT4来了!微软云能否反超亚马逊夺冠,就靠它了

    文 光锥智能 作者 刘雨琦 Azure 微软云 能否反超AWS 亚马逊云 夺冠 就靠ChatGPT了 今天凌晨 GPT4横空出世 支持图像输入和混合输入 多模态大模型的出现 将对算力产生更高的需求 一场由ChatGPT引发的算力革命 即将给
  • TCP的三次握手(一个好男人追女孩的故事)一看必懂系列

    网络世界如情场 女生指代服务端 在网络协议内 和TCP是纯情男的作风 UDP作风则称为 渣男 理由非常的简单 由于UDP的行为就是从来不会和任何女人产生感情 不建立连接 因此追女生的效率 具有高效率的特性 就比TCP作风高的多 从来不付出
  • 通过U盘向服务器拷贝文件

    目录 完整操作流程 检查U盘是否被识别 gt 挂载U盘 gt 拷贝文件 gt 卸载U盘 检查U盘是否被识别 挂载U盘 拷贝文件 卸载U盘 完整操作流程 检查U盘是否被识别 gt 挂载U盘 gt 拷贝文件 gt 卸载U盘 检查U盘是否被识别
  • 数据结构算法设计——深搜DFS(走迷宫)

    一 什么是深搜 深搜就是 深度搜索 也就是 深度优先的搜索 那什么是 深度优先 呢 我们拿最常见的迷宫问题举例 深度优先就是你照着一条路死命的走 有个形象的说法叫 不撞南墙不回头 一直到这条路走不通了 再返回上一步选择其他的方向 在算法中我
  • Java8 Streams用法总结大全 之 Collector用法详解

    1 前言 在 Java8 Streams用法总结大全 之 Stream中的常见操作 中 我们已经学习了Stream中的常用操作 其中也提到了collect 的用法 当时只是通过入参Collectors toList 实现了把Stream转为
  • [SQL]postgreSQL中如何查找无主键的sql语句

    查找postgreSQL数据库中 查找无主键的表 可以通下面语句查找 select from pg tables where hasindexes is false and schemaname public
  • 新编法学概论--吴祖谋

    新编法学概论 吴祖谋 2007 pdf 介绍法学概论的书籍 但是写的太官僚了 什么阶级论 之类的开头 让我读着那样的不理解 能不能有本写的比较通俗易懂的法学概论 这样的书籍 真心的不喜欢看 但是没办法 还是看一看吧 1 宪法 三次完全的更新
  • 什么是代码区、常量区、静态区(全局区)、堆区、栈区?

    前言 和作者有同样的感觉 对代码区 常量区 静态区 全局区 堆区 栈区没有较深刻的认识 通过查找网络找到本篇比较优秀的文章 特此转发 摘自CSDN https blog csdn net u014470361 article details
  • oracle中translate与replace的使用

    1 translate 语法 TRANSLATE char from to 用法 返回将出现在from中的每个字符替换为to中的相应字符以后的字符串 若from比to字符串长 那么在from中比to中多出的字符将会被删除 三个参数中有一个是
  • OpenCV中的图像腐蚀和膨胀操作有哪些?

    在OpenCV中 图像腐蚀 Erosion 和膨胀 Dilation 是常用的图像形态学操作 它们可以用于去除噪声 填充空洞 提取图像中的结构等 下面是几种常见的腐蚀和膨胀操作 腐蚀操作 图像腐蚀可以通过函数cv2 erode 来实现 腐蚀
  • [Linux Audio Driver] 音频POP音问题归纳总结

    1 板级电容 电感发声 情况就是你设备开机之后 啥也没干 然后听到呲啦刺啦的声音 这种情况我遇到过一次 这个是 不合理的结构设计或者走线导致的 硬件实力挖坑 需要改版解决 2 播放声音长时间有杂音 这个锅我们送给硬件 这个是芯片之间有干扰
  • SVN 启动模式

    首先 在服务端进行SVN版本库的相关配置 手动新建版本库目录 mkdir opt svn 利用svn命令创建版本库 svnadmin create opt svn runoob 使用命令svnserve启动服务 svnserve d r 目
  • java tomcat远程调试端口_tomcat开发远程调试端口以及利用eclipse进行远程调试

    一 tomcat开发远程调试端口 方法1 WIN系统 在catalina bat里 SET CATALINA OPTS server Xdebug Xnoagent Djava compiler NONE Xrunjdwp transpor
  • LeetCode2.两数相加

    给你两个 非空 的链表 表示两个非负的整数 它们每位数字都是按照 逆序 的方式存储的 并且每个节点只能存储 一位 数字 请你将两个数相加 并以相同形式返回一个表示和的链表 你可以假设除了数字 0 之外 这两个数都不会以 0 开头 示例 1
  • 了解VGG网络结构特点,利用VGG完成图像分类

    学习目标 知道VGG网络结构的特点 能够利用VGG完成图像分类 2014年 牛津大学计算机视觉组 Visual Geometry Group 和Google DeepMind公司的研究员一起研发出了新的深度卷积神经网络 VGGNet 并取得