VGGNet网络详解与模型搭建

2023-11-20

1 模型介绍

​ VGGNet是由牛津大学视觉几何小组(Visual Geometry Group, VGG)提出的一种深层卷积网络结构,他们以7.32%的错误率赢得了2014年ILSVRC分类任务的亚军(冠军由GoogLeNet以6.65%的错误率夺得)和25.32%的错误率夺得定位任务(Localization)的第一名(GoogLeNet错误率为26.44%),网络名称VGGNet取自该小组名缩写。VGG网络原论文《Very Deep Convolutional Networks For Large-Scale Image Recognition》发表于ICLR-2015,VGGNet所提出的 3 × 3 3\times3 3×3卷积核的思想为后来许多模型所沿用。

2 模型结构

​ 在原论文中,作者尝试了不同深度的配置(11层,13层,16层,19层),是否使用LRN(Local Response Normalization)以及卷积核1x1与卷积核3x3的差异,VGGNet尝试使用了6种不同的模型结构,分别对应VGG11、VGG11-LRN、VGG13、VGG16-1、VGG16-3和VGG19,不同的后缀数值表示不同的网络层数(VGG11-LRN表示在第一层中采用了LRN的VGG11,VGG16-1表示后三组卷积块中最后一层卷积采用卷积核尺寸为 1 × 1 1\times1 1×1,相应的VGG16-3表示卷积核尺寸为 3 × 3 3\times3 3×3)。下表是从原论文中截取的几种VGG模型的配置表,VGGNet网络模型结构非常工整,其卷积层全部都采用了大小为3x3,步距为1,padding为1的卷积操作(即same卷积,经过卷积后不会改变特征矩阵的高和宽);最大池化下采样层全部都是池化核大小为2,步距为2的池化操作,每次通过最大池化下采样后特征矩阵的高和宽都会缩减为原来的一半。

在这里插入图片描述

​ 我们通常使用的VGG模型是表格中的VGG16(D)配置,根据表格中的配置信息以及上文所讲的卷积层和池化层的详细参数,可以搭建如下图所示的feature map大小的变化图。在VGG模型中,卷积操作不会改变feature map的大小,池化操作会使feature map大小减小为原来的一半。
在这里插入图片描述

3 模型特性

(1)通过堆叠多个3x3的卷积核来替代大尺度卷积核

​ 论文中提到,可以通过堆叠两层 3 × 3 3\times 3 3×3的卷积核替代一层 5 × 5 5\times 5 5×5的卷积核,堆叠三层 3 × 3 3\times3 3×3的卷积核替代一层 7 × 7 7\times7 7×7的卷积核。这样的连接方式使得网络参数量更小(见下例),而且多层的激活函数令网络对特征的学习能力更强。

  • 如果使用一层卷积核大小为7的卷积所需参数(第一个C代表输入特征矩阵的channel,第二个C代表卷积核的个数也就是输出特征矩阵的深度): 7 × 7 × C × C = 49 C 2 7\times 7\times C\times C=49C^2 7×7×C×C=49C2

  • 如果使用三层卷积核大小为3的卷积所需参数: 3 × 3 × C × C + 3 × 3 × C × C + 3 × 3 × C × C = 27 C 2 3\times 3\times C\times C + 3\times 3\times C\times C + 3\times 3\times C\times C=27C^2 3×3×C×C+3×3×C×C+3×3×C×C=27C2

    经过对比明显使用3层大小为3x3的卷积核比使用一层7x7的卷积核参数更少

(2)整个网络都使用了同样大小的卷积核尺寸 3 × 3 3\times3 3×3和最大池化尺寸 2 × 2 2\times2 2×2,模型十分工整。

(3)VGGNet在训练时有一个小技巧,先训练浅层的的简单网络VGG11,再复用VGG11的权重来初始化VGG13,如此反复训练并初始化VGG19,能够使训练时收敛的速度更快。

4 Pytorch模型搭建代码

注:由于LRN层对训练结果影响不大,故代码中去除了LRN层

import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000):
        super().__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )

    def forward(self, inputs):
        x = self.features(inputs)  # [N, 3, 224, 224]  --> [N, 512, 7, 7]
        x = torch.flatten(x, start_dim=1)  # [N, 512, 7, 7]  --> [N, 512 * 7 * 7]
        outputs = self.classifier(x)  # [N, 512 * 7 * 7]  --> [N, num_classes]
        return outputs


# VGGNet的配置文件,数字表示卷积层输出的feature map大小,'M'表示最大池化下采样
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

def make_features(cfg: list):
  	"""根据cfgs配置制作vgg的特征提取层"""
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            layers.append(maxpool2d)
        else:
            conv2d = nn.Conv2d(in_channels=in_channels, out_channels=v, kernel_size=3, padding=1)
            layers.append(conv2d)
            in_channels = v
    return nn.Sequential(*layers)


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: {} not in config dict!".format(model_name)

    cfg = cfgs[model_name]
    model = VGG(features=make_features(cfg), **kwargs)
    return model
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

VGGNet网络详解与模型搭建 的相关文章

随机推荐

  • 2023面试问答-计算机网络

    OSI 的七层模型分别是 各自的功能是什么 简要概括 物理层 底层数据传输 如网线 网卡标准 数据链路层 定义数据的基本格式 如何传输 如何标识 如网卡MAC地址 网络层 定义IP编址 定义路由功能 如不同设备的数据转发 传输层 端到端传输
  • 【ES实战】ES中关于segment的小结

    文章目录 ES中关于segment的小结 ES中segment相关的原理 在Lucene中的产生segment的过程 Lucene commit过程 ES为了实现近实时可查询做了哪些 缩短数据可被搜索的等待时长 增加数据的可靠性 优化seg
  • mysql更新一张表的字段来自另一张表的某个字段

    UPDATE tba a LEFT JOIN tbb b on a id b id set a xxx b xxxx where a id b id
  • 对于opencv摄像头调用与现实方向相反的问题怎么解决?

    可以对原始图像进行水平翻转 使用opencv自带的flip函数 例如 读取图像帧 ret frame cap read 水平翻转图像 frame cv2 flip frame 1 这样就可以了 后面的参数1代表水平翻转图像 而0代表垂直翻转
  • node.js与elasticsearch交互

    参考elasticsearch 以下简称es 官方javascript的API https www elastic co guide en elasticsearch client javascript api 6 x api refere
  • Sqli-Labs靶场(6--10)题详解

    目录 六 Less 6 GET Double Injection Double Quotes string GET 双重注入 双引号 字符串 七 Less 7 GET Dump into outfile string GET 导出文件 字符
  • Altium designer自动布线设置GND或其他网络不布线的方法

    1 在导航栏里面找到设计栏 找到类选项打开2 在Net Classes选项下 右击鼠标 找到添加类选项 会创建一个New Class 3 设置好需要布线的网络 以及不需要布线的网络 如下图 4 找到自动布线菜单栏下的网络类 点击进去如下图
  • Android下自定义的jar库文件编译和调用

    主要为了解决如下问题 项目中使用了Android未公开的API 在Eclipse下会有红叉显示 不同的项目抽出相同部分的代码共用 必需的前提条件 需要有Android源代码 编译的库文件主要是封装未公开API或者共用代码 工程1 Java库
  • h5单页面埋点问题(undefine)

    需求 商城页面里调用第三方资源埋点 代码实现 主要解决资源未加载就被调用问题 备注 把调用函数作为参数传进去 控制保证在资源加载完成后调用 let COLLECTURL http collect trc com index js 动态创建j
  • java的特点

    一 简单易学 1 java的风格类似于c 因而许多c c程序员初次接触java语言时会感到熟悉 从某种意义来说c 语言是从c语言继承而来 java语言是c 语言的一个变种 因此 学过c或c 的程序员可以更快速的掌握java编程技术 附图 编
  • 【mysql timeStamp默认值0000-00-00 00:00:00 报错:Invalid default value for ‘end_time’】

    mysql timeStamp默认值0000 00 00 00 00 00 报错 Invalid default value for end time 运行其中的sql文件时报错 nvalid default value for end t
  • python猜拳游戏编程代码_用python实现“猜拳"游戏

    原标题 用python实现 猜拳 游戏 用python实现 猜拳 游戏 先来练习一道用python编写的小程序 这道题是用for in 循环实现输入10个数并求和 这里用到了append 方法 append 方法 是一个很重要的方法 它是向
  • 计算机翻译的汉字,计算机系外文翻译(中英对照3000汉字左右).doc

    文档介绍 毕业设计 论文 外文资料翻译系别计算机信息与技术系专业计算机科学与技术班级姓名学号外文出处附件1 原文 2 putingMainarticle puter wasrecordedin1613 referringtoapersonw
  • 拓扑排序,广度优先

    使用一个队列来进行广度优先搜索 初始时 所有入度为 0 的节点都被放入队列中 它们就是可以作为拓扑排序最前面的节点 并且它们之间的相对顺序是无关紧要的 在广度优先搜索的每一步中 取出队首的节点 u 将 u 放入答案中 移除 u 的所有出边
  • hample滤波器的原理及其Python实现

    hample滤波器 1 作用及原理 2 Python实现 1 作用及原理 功能 检测并删除异常值 用一个一维向量 x x 1
  • 利用云原生数仓 Databend 构建 MySQL 的归档分析服务

    MySQL 常用 OLTP 业务环境 一般会使用比较好的硬件资源来提供对外服务 现在 MySQL 数据对外提供的数据动不动好几个 T 也是正常的 在很多业务中 数据有较强的生命周期 在线一段时间后 可能就是失去业务意义 如 某个业务下线 业
  • C语言通讯录

    主要知识 结构体 枚举 指针 递归 冒泡排序等 文章目录 一 前言 1 菜单 2 结构体创建 3 初始化通讯录 4 增加联系人 4 删除联系人 5 修改联系人信息 6 搜索联系人 7 显示联系人 8 联系人排序 三 代码展示 contect
  • 单片机FLASH操作

    FLASH 操作 查看程序已经占用的FLASH的扇区 剩余的扇区就是可以操作而不会使程序发生错乱的区域 找到listing文件夹下面的 map文件 搜索Memory Map of the image 查看占用的内存 起始地址是 0x8000
  • kafka数据丢包原因及解决方案

    数据丢失是一件非常严重的事情事 针对数据丢失的问题我们需要有明确的思路来确定问题所在 针对这段时间的总结 我个人面对kafka 数据丢失问题的解决思路如下 是否真正的存在数据丢失问题 比如有很多时候可能是其他同事操作了测试环境 所以首先确保
  • VGGNet网络详解与模型搭建

    文章目录 1 模型介绍 2 模型结构 3 模型特性 4 Pytorch模型搭建代码 1 模型介绍 VGGNet是由牛津大学视觉几何小组 Visual Geometry Group VGG 提出的一种深层卷积网络结构 他们以7 32 的错误率