vggNet网络学习(网络架构及代码搭建)

2023-11-15

原论文 翻译链接:VERY DEEP CONVOLUTIONAL NETWORKSFOR LARGE-SCALE IMAGE RECOGNITION(VGGnet论文翻译(附原文))_机器学习我不学习的博客-CSDN博客

 网络架构

        vggnet使用了更小的卷积核和更深的卷积神经网络在分类任务中获得了更好的效果。vggnet的输入是固定大小为224×224的RGB图像,预处理是从每个像素上减去在训练集上计算出的平均RGB值。和过去的神经网络不同,vggnet卷积层没有使用大的卷积核,而是使用了非常小的卷积核进行卷积操作。卷积核大小为3×3,步长为1,padding为1。也使用到了1×1的卷积核,这可以看作是输入通道的线性变换。有五个最大池化层,池化窗口大小为2×2,步长为2。之后跟着三个全连接层,其中前两个层每个有4096个通道,第三个层执行1000路ILSVRC分类,因此包含1000个通道(每个类一个),后一层是softmax层。所有的隐层都使用ReLU作为激活函数。

        3×3的卷积核是能包含图像中一个像素点上下左右的最小尺度。

        该网络的亮点:通过堆叠多个3×3的卷积核来替代大尺度卷积核,这样操作可以减少训练中的参数。由两个3×3 卷积层堆叠一起具有5×5的有效感受野;三个这样的3×3的卷积层具有7×7的有效感受野。通过堆叠两个3×3的卷积核可以替代5×5的卷积核,通过堆叠三个3×3的卷积核可以替代7×7的卷积核。

在这里插入图片描述

        如上图所示,两个3×3的卷积核可以堆叠成一个5×5的卷积核(左图),三个3×3的卷积核可以堆叠成一个7×7的卷积核(右图)。

使用7×7的卷积核所需要的参数:7×7×C×C = 49C^{2}   (假设输入输出chan'n)

使用3个3×3的卷积核堆叠成7×7的卷积核所需要的参数:3×3×C×C+3×3×C×C+3×3×C×C=27C^{2}

         vggnet训练了几种深度不同的网络,从网络A~网络E的网络架构如下图所示:

        下图是vgg-16的网络架构: 

         输入的是一个224×224×的RGB图像,所有卷积核的stride为1,padding为1;maxpooling的size为2,stride为2。

        ①:经过两个3×3的卷积层,每个卷积层中卷积核的数量为64,所以得到的输出尺寸为224×224×64;再经过一个maxpooling,尺寸变成112×112×64。

        ②:经过两个3×3的卷积层,每个卷积层中卷积核的数量为128,所以得到的输出尺寸为112×112×128;再经过一个maxpooling,尺寸变成56×56×128。

        ③:经过三个3×3的卷积层,每个卷积层中卷积核的数量为256,所以得到的输出尺寸为56×56×256;再经过一个maxpooling,尺寸变成28×28×256。

        ④:经过三个3×3的卷积层,每个卷积层中卷积核的数量为512,所以得到的输出尺寸为28×28×512;再经过一个maxpooling,尺寸变成14×14×512。

        ⑤:经过三个3×3的卷积层,每个卷积层中卷积核的数量为512,所以得到的输出尺寸为14×14×512;再经过一个maxpooling,尺寸变成7×7×512。

        ⑥:是三层全连接层,其中第一层和第二层全连接层的节点个数都是4096个,第三个全连接层有1000个节点,对应着ImageNet中的1000个分类。其中前两个全连接层后跟着Relu激活函数,但是第三个全连接层不需要加Relu激活函数,因为最后会有一层softmax层。

        ⑦softmax层,将预测结果转化为概率分布。

vgg-16网络搭建与训练

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

model.py

import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)
    
    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

vggNet网络学习(网络架构及代码搭建) 的相关文章

随机推荐

  • MyBatis resultMap collection标签 返回基本类型集合 如:List<Long> List<String> List<Integer>等

    class xxDTO private Long id private Set
  • vc 判断某个盘符是否为移动硬盘盘符

    在使用GetDriveType获取磁盘类型时 一般小容量的U盘直接返回DRIVE REMOVABLE 倒是不用再进行下一步的判断 而大容量U盘和移动硬盘的盘符返回值和本地硬盘盘符返回值都是DRIVE FIXED 需要再进行判断 如果是IDE
  • 【paddlepaddle】一键人物抠图

    效果 环境准备 win11 python3 8 pip install paddlepaddle i https pypi tuna tsinghua edu cn simple pip install paddlehub i https
  • animation 动画的定义和使用

    keyframes 定义动画 keyframes myname 0 50 100 调用动画 div animation name myfirst animation duration 5s animation timing function
  • Unity 开发人员转CGE(castle Game engine)城堡游戏引擎指导手册

    Unity 开发人员的城堡游戏引擎概述 一 简介 2 Unity相当于什么GameObject 3 如何设计一个由多种资产 生物等组成的关卡 4 在哪里放置特定角色的代码 例如生物 物品 Unity 中 向 GameObject 添加 Mo
  • U盘启动盘制作(步骤详细)

    U盘启动盘制作 在制作启动盘之前我们需要先准备一个8G以上的U盘 和一台能上网的电脑 一 下载系统镜像 根据自己需要的系统版本去下载官方的镜像文件 记得要下载纯净的镜像 否则在后续安装好系统后会出现捆绑的现象 可以直接去下面这个网站下载 下
  • rsync实现文件服务器间文件同步

    rsync介绍 rsync命令工具可以实现服务器间的文件同步 全量或者增量 比如使用 size only来检查源端文件和目标端文件大小是否一致决定是否需要同步 由此同步的功能扩展 可以实现本机不同目录文件拷贝 快速删除海量文件等功能 但要注
  • MySQL隔离级别

    表结构和表数据如下 id 自增主键 uid 唯一索引 name price 普通索引 pictures 33 a Apple 12 NULL 34 b banana 5 NULL 35 c cherry 51 NULL 36 d date
  • Python语言 :关于使用装饰器的技巧介绍

    转自 微点阅读 https www weidianyuedu com 导语 装饰器 Decorator 是 Python 里的一种特殊工具 它为我们提供了一种在函数外部修改函数的灵活能力 它有点像一顶画着独一无二 符号的神奇帽子 只要将它戴
  • 抽象,内部类,接口,多态

    final 最终的 不能改变的 单独应用几率低 修饰变量 变量不能被改变 修饰方法 方法不能被重写 修饰类 类不能被继承 static final常量 应用率高 必须声明同时初始化 常常通过类名点来访问 不能被改变 建议 常量名所有字母都大
  • android — NDK生成so文件

    我们在安装环境的时候安装了NDK 可以在eclipse下直接生成so文件 NDK的压缩包里面自带了一些sample工程 NDK的文件直接解压到某个目录下即可 第一次生成so文件的时候 我们先使用NDK的sample下的hello jni的例
  • 【栈】逆波兰计算器

    文章目录 前言 一 基本概念 1 1 介绍 1 2 入栈和出栈示意图 1 3 栈的应用场景 二 使用数组模拟栈 2 1 思路分析 2 2 代码实现 2 3 测试 三 使用栈模拟中缀表达式计算器 3 1 整体思路 3 2 验证3 2 6 2
  • Qt基本窗口

    窗口类 1 Qt中最经常被使用的窗口类是QWidget QDialog 其中QDialog是继承于QWidget 它是一个顶级窗口 不能附着在其他QDialog上面 一般情况下QDialog基本都是用 在弹出窗口需求中被使用 而QWidge
  • OpenGL中光源的三种移动区别

    1 光源不动 需要在设置完视图模型变换之后 然后再设置光源的位置并且开启 伪代码如下 glmatrixmode gl projection glloadidentity xxxxxxxxxx glmatrixmode gl modelvie
  • Vue 移动端、PC 端适配

    Vue 移动端 PC 端适配可以使用 lib flexible amfe flexible postcss pxtorem postcss px2rem 和 postcss px to viewport 这几个插件 lib flexible
  • BLE蓝牙协议 — BLE连接建立过程梳理(一)

    文章出处 枫之星雨 转载文章 如有不妥 通知后我会立即删除 连接建立 应付比广播更为复杂的数据传输 或者要在设备之间实现可靠的数据交付 这些都要依赖于连接 连接使用数据信道在两个设备之间可靠地发送信息 它采取了自适应跳频增强鲁棒性 同时使用
  • Idea:applicationcontext in module file is included in 5 contexts

    今天使用IDEA做项目的时候出现了这个东西 经过查询资料 应该是编译器自动导入配置文件的时候发生了某些错误 提示修正 解决方法 依次打开 Project Settings gt Modules gt Spring 按减号删除右侧所有文件 然
  • 国产ChatGpt、AI模型盘点

    个人中心 DownLoad 一 百度 文心一言 百度的文心一言是一款基于深度学习技术的自然语言生成模型 能够生成各种类型的文本 包括新闻 小说 诗歌等 它采用了Transformer模型和GPT 2模型 能够生成高质量的文本 并且速度非常快
  • 2022-1-12 java运算符的学习记录

    2022 1 12 java运算符的学习记录 算数运算符 在java中有i 和 i俩种操作 前一种是先使用变量再自增 后一种是先自增再使用变量 因为java是强运算符号 所以不同的类型的变量加减 最终会趋向于高等级的类型的运算类型 是取整符
  • vggNet网络学习(网络架构及代码搭建)

    原论文 翻译链接 VERY DEEP CONVOLUTIONAL NETWORKSFOR LARGE SCALE IMAGE RECOGNITION VGGnet论文翻译 附原文 机器学习我不学习的博客 CSDN博客 网络架构 vggnet