计算机视觉之VGGNet

2023-11-14

1 VGGNet介绍

        VGGNet是牛津大学视觉几何组(Visual Geometry Group)提出的模型,故简称VGGNet, 该模型在2014年的ILSVRC中取得了分类任务第二、定位任务第一的优异成绩。该模型证明了增加网络的深度能够在一定程度上影响网络最终的性能。

        论文地址:原文链接

        根据卷积核大小与卷积层数目不同,VGG可以分为6种子模型,分别是A、A-LRN、B、C、D、E,分别对应的模型为VGG11、VGG11-LRN(第一层采用LRN)、VGG13、VGG16-1、VGG16-3和VGG19。不同的后缀代表不不同的网络层数。VGG16-1表示后三组卷积块中最后一层卷积采用卷积核尺寸为1*1,VGG16-3为3*3。VGG19位后三组每组多一层卷积,VGG19为3*3的卷积。我们常看到的基本是D、E这两种模型,官方给出的6种结构图如下:

2 VGG16网络结构     

        VGG16的网络结果如上图所示:在卷积层1(conv3-64),卷积层2(conv3-128),卷积层3(conv3-256),卷积层4(conv3-512)分别有64个,128个,256个,512个3X3卷积核,在每两层之间有池化层为移动步长为2的2X2池化矩阵(maxpool)。在卷积层5(conv3-512)后有全连接层,再之后是soft-max预测层。

 处理过程的直观表示:

3 VGG16在pytorch下,基于cifar-10数据集的实现

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import datetime
from torchvision import datasets
from torch.utils.data import DataLoader

VGG_types = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512,
              "M", 512, 512, 512, "M"],
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512,
              "M", 512, 512, 512, 512, "M"]
}

VGGType = "VGG16"


class VGGnet(nn.Module):
    def __init__(self, in_channels=3, num_classes=1000):
        super(VGGnet, self).__init__()
        self.in_channels = in_channels
        self.conv_layers = self._create_layers(VGG_types[VGGType])
        self.fcs = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fcs(x)
        return x

    def _create_layers(self, architecture):
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if type(x) == int:
                out_channels = x
                layers += [
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=(3, 3),
                        stride=(1, 1),
                        padding=(1, 1),
                    ),
                    nn.BatchNorm2d(x),
                    nn.ReLU(),
                ]
                in_channels = x
            elif x == "M":
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]

        return nn.Sequential(*layers)


transform_train = transforms.Compose(
    [
        transforms.Pad(4),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomGrayscale(),
        transforms.RandomCrop(32, padding=4),
        transforms.Resize((224, 224))
    ])

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224))
    ]
)

train_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform_train,
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform_test,
)


def get_format_time():
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


if __name__ == "__main__":

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VGGnet(in_channels=3, num_classes=10).to(device)
    print(model)

    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-3)
    loss_func = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)

    epochs = 40
    total = 0
    accuracy_rate = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        print(f"{get_format_time()},train epoch: {epoch}/{epochs}")
        for step, (images, labels) in enumerate(train_loader, 0):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).to(device)
            loss = loss_func(outputs, labels).to(device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct = torch.sum(predicted == labels)
            train_correct += correct
            train_total += images.shape[0]
            train_loss += loss.item()
            if step % 1 == 0 and step > 0:
                print(f"{get_format_time()},train epoch = {epoch}, step = {step}, "
                      f"train_loss={train_loss}")
                train_loss = 0.0
                break

        # 在测试集上进行验证
        model.eval()
        test_correct = 0
        test_total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                outputs = model(images).to(device)
                _, predicted = torch.max(outputs, 1)
                test_total += labels.size(0)
                test_correct += torch.sum(predicted == labels)
                break
        accuracy = 100 * test_correct / test_total
        accuracy_rate.append(accuracy)

        print(f"{get_format_time()},test epoch = {epoch}, accuracy={accuracy}")
        scheduler.step()

    accuracy_rate = np.array(accuracy_rate)
    times = np.linspace(1, epochs, epochs)
    plt.xlabel('times')
    plt.ylabel('accuracy rate')
    plt.plot(times, accuracy_rate)
    plt.show()

    print(f"{get_format_time()},accuracy_rate={accuracy_rate}")

模型形状打印输出:

VGGnet(
  (conv_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU()
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU()
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU()
    (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU()
    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU()
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU()
    (33): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (36): ReLU()
    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (39): ReLU()
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU()
    (43): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (fcs): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

计算机视觉之VGGNet 的相关文章

随机推荐

  • < C++ >:C++ 类和对象(上)

    目录 1 面向过程和面向对象的初步认识 2 类的引入 3 类的访问限定符及封装 3 1 访问限定符 3 2 封装 4 类的声明和定义或类的定义 可以理解成声明和定义 5 类的作用域 6 类的实例化 7 类对象模型 7 1 如何计算类对象的大
  • 释放技术的想象-解码腾讯云软件架构与应用

    欢迎大家前往腾讯云社区 获取更多腾讯海量技术实践干货哦 关于腾讯 你可能玩过 王者荣耀 你可能用过 微信 和它的 小程序 你可能看过 腾讯视频 并且曾为之付费 你可能已经是多年的 QQ 老手但还不知道什么是 MQ 作为成立多年的老牌互联网公
  • 为何零信任架构身份管理平台更可靠?

    随着信息技术的不断进步 云计算 物联网以及移动设备的普及 信息泄露等安全问题愈发频繁 近期 一起某高校学生的信息泄露事件引发了大家的广泛讨论 该校学生利用其身份便利 非法获取了大量学生的姓名 学号 照片等隐私信息 这次热门话题的讨论后 人们
  • python实现排列组合代码

    def combination n c com 1 limit 0 per for pos in range limit n t per pos if len set t len t if len t c yield pos else fo
  • 22道常见RocketMQ面试题以及答案

    面试宝典到手 搞定面试 不再是难题 系列文章传送地址 请点击本链接 1 RocketMQ是什么 2 RocketMQ有什么作用 3 RoctetMQ的架构 4 RoctetMQ的优缺点 8 消息过滤 如何实现 9 消息去重 如果由于网络等原
  • 基于WSL2+NVIDIA Docker的开发环境最佳实践

    1 Windows 11 安装WSL2 Ubuntu 22 04 LTS 1 1 安装windows附加功能 点击 设置 gt 应用 gt 可选功能 gt 更多windows功能 弹出的窗口 勾选 适用于Linux的Windows子系统 和
  • 学习dubbo遇到的报错:UnsatisfiedDependencyException: Error creating bean with name ‘us

    学习dubbo直连方式遇到的报错 记录一下 org springframework beans factory UnsatisfiedDependencyException Error creating bean with name use
  • win10企业版更新和安全中没有 “恢复”这个选项_永别了您内,整疯我的Win10自动更新...

    朋友终于买了台微星GP75 用来驾驭重制版剑三的网游史最大客户端 号称几乎无短板 处处贵到痛点的香香75的确值得 游戏流畅丝滑 开强冷模式手托半点不热 不过就在昨天 朋友突然问我 微星笔记本自带的win10系统可以退回win7吗 害 这哥果
  • centos 安装k8s

    第一步 每台机子都做 关闭防火墙 systemctl stop firewalld systemctl disable firewalld 第二步 每台机子都做 永久关闭selinux sed i s enforcing disabled
  • 30-10-010-编译-kylin-on-druid-2.6.0-CDH57编译

    1 视界 1 下载kylin git clone https github com apache kylin kylin 2 安装maven nodejs 1 maven的安装参照百度 这里不再赘述 2 nodejs的安装参考
  • ssh连接慢解决办法

    ssh连接慢解决办法 成功 用真机连接虚拟机卡的话 1 进入虚拟机vim etc ssh sshd config 2 将 UseDNS yes改为UseDNS no即可如下 使用 UseDNS找到地方然后添加 UseDNS yes UseD
  • VS2019+QT5.15.2+QGIS二次开发环境搭建

    VS2019 QT5 15 2 QGIS二次开发环境搭建 1 开发环境 VS2019 QT5 15 2 QGIS 注意 QT 平台的版本与qgis下载的版本有关 以前采用OSGeo4w64来下载qgis时会区分32和64位 但现在官网已经不
  • QT中QWeight与QMainWindow的区别

    在Qt中 QWidget 和 QMainWindow 是两个常用的类 用于创建用户界面 它们之间有一些区别 1 QWidget 是Qt中所有用户界面类的基类 而 QMainWindow 是一个特殊的窗口类 通常用于创建应用程序的主窗口 QM
  • elasticsearch(磁盘删除data后kibana自动进入只读模式)

    Elasticsearch 基于磁盘的碎片分配 向index插入 删除数据时发生报错 index kibana 1 blocked by FORBIDDEN 12 index read only allow delete api clust
  • FreeRTOS笔记(一)简介

    这个笔记主要依据韦东山freertos快速入门系列记录 感谢韦东山老师的总结 什么是实时操作系统 操作系统是一个控制程序 负责协调分配计算资源和内存资源给不同的应用程序使用 并防止系统出现故障 操作系统通过一个调度算法和内存管理算法尽可能把
  • DEDE自动调用轮播图/幻灯片

    备注 以下示例是以自动调取轮播图为例 具体使用时 步骤不变 内容据实调整即可 一 创建 1 新建模型 2 在新模型下依次添加字段 本例字段 datu xiaotu 分别给PC端和手机端用 据实调整即可
  • 雅思词汇表8000词版_考“鸭”干货丨雅思词汇备考技巧!

    点击蓝字 关注我们 考 鸭 干货第3期 雅思词汇备考技巧 语言学家TERREL认为 只要掌握了足够的词汇 即使没有多少语法知识 外语学习者也能较好理解外语和用外语进行表达 语言学家WILKINS有一句经典名言 没有语法只能传达很少的信息 没
  • 【STM32】STM32之timer1产生PWM(互补通道)

    本篇博文最后修改时间 2017年01月14日 23 50 一 简介 本文介绍STM32系列如何使用timer1的第TIM1 CH2N通道 PB14 产生PWM 二 实验平台 库版本 STM32F10x StdPeriph Lib V3 5
  • 反序列化攻击原理及防御措施(已解决)

    反序列化攻击原理及防御措施 已解决 java序列化算法透析 Serialization 序列化 是一种将对象以一连串的字节描述的过程 反序列化deserialization是一种将这些字节重建成一个对象的过程 Java序列化API提供一种处
  • 计算机视觉之VGGNet

    1 VGGNet介绍 VGGNet是牛津大学视觉几何组 Visual Geometry Group 提出的模型 故简称VGGNet 该模型在2014年的ILSVRC中取得了分类任务第二 定位任务第一的优异成绩 该模型证明了增加网络的深度能够