Pytorch源码学习之二:torchvision.models.vgg

2023-05-16

0. VGG的网络结构

VGG网络结构

一、torchvision源码

这种通过配置文件一次性搭建相似网络的结构的方法十分值得学习和模仿.这也是相对于AlexNet的实现过程不同之处.
我对其做了一丁点修改,源码网址可见torchvision.models.vgg源码网址

'''
VGG的torchvison实现重写,
'''
import torch
import torch.nn as nn
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.init_weights= init_weights
        self.avgpool = nn.AdaptiveAvgPool2d((7,7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if self.init_weights:
            self._initialize_weights()

        def forward(self, x):
            x = self.features(x)
            x = torch.avgpool(x)
            x = torch.flatten(x, start_dim=1)
            x = self.classifier(x)
            return x

        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, model='fan_out',
                                            nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.BatchNorm1d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.Linear):
                        nn.init.normal_(m.weight, 0, 0.01)
                        nn.init.constant_(m.bias, 0)

def make_layer(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels=in_channels, out_channels=v, kernel_size=3, stride=1, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
        in_channels = v
    return nn.Sequential(*layers)
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }

def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    '''
    搭建vgg网络
    :param arch:网络名称,用来加载预训练模型
    :param cfg: 配置,用来搭建网络
    :param batch_norm: bool,是否采用BN
    :param pretrained: bool,是否采用Pretrained
    :param progress: bool,下载时是否显示进度条
    :param kwargs:其它参数
    :return:返回搭建的vgg网络
    '''
    if pretrained:
        kwargs['init_weights']  = False
    model = VGG(make_layer(cfg[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def vgg11(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)

def vgg11_bn(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)

def vgg13(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)

def vgg13_bn(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)

def vgg16(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)

def vgg16_bn(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg16_bn', 'D', True, pretrained, progress)

def vgg19(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)

def vgg19_bn(pretrained=False, progress=True, **kwargs):

    return _vgg('vgg19_bn', 'E', pretrained, progress, **kwargs)

二、一些值得学习的用法笔记

 #将start_dim至end_dim展成一维向量
torch.flatten(tenor, start_dim, end_dim)
x = torch.flatten(x, start_dim=1)
#效果同下
x = x.view(x.size(0), -1)
#使用何大佬在2015年提出的方法
torch.nn.init.kaiming_normal_(tensor, a=0, 
                          model='fan_in', nonlinearity='leaky_relu')
nn.init.kaiming_normal_(m.weight, model='fan_out',
                                            nonlinearity='relu')
#使用均值为mean,标准差为std的正态分布填充输入tensor
torch.nn.init.normal_(tensor, mean=0., std=1.) 
#使用浮点数val填充tensor
nn.init.constant_(tensor, val) 
#搭建网络的一种范式
layer = []
layer += [nn.Conv2d(...), nn.ReLU(inplace=True)]
layer += [nn.BatchNorm2d(...)]
nn.Sequential(*layers)
#从网络上加载参数
torch.hub.load_state_dict_from_url(url,  model_dir=None, map_location=None, progress=True)
#url-下载的目标网址
#model_dir - 保存参数的目录
#map_location - a function or a dict specifying how to remap storeage locations.
state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
#progress - 是否展示下载的进度条
#载入参数到模型
#torch.nn.modules.module.Module
#def load_state_dict(self, state_dict, strict=True)
model.load_state_dict(state_dict)
#*args 和 **kwargs都代表1个或多个参数的意思.*args传入tuple类型的无名参数,而**kwargs传入的参数是dict类型
def myprint(*args):
    print(*args)
myprint(10, 2) #10 2

def mykwprint(**kwargs):
    key = kwargs.keys()
    value = kwargs.values()
    print(key) #dict_keys(['epoch', 'LR'])
    print(value) #dict_values([10, 2])
mykwprint(epoch=10, LR=2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch源码学习之二:torchvision.models.vgg 的相关文章

  • 【大陆ARS408毫米波雷达】一种利用串口解析雷达数据的方法

    硬件平台 xff1a ARS408毫米波雷达 can转485转换器 485转串口转换器 软件平台 xff1a Windows10 python3 本篇博客实现的功能 xff1a 一 通过两个转换器将毫米波雷达的原始数据传入电脑端的串口中 二
  • ubuntu14.04系统下对SD卡分区

    在ubuntu14 04系统下对SD卡进行分区分为3步 xff1a 注意 xff1a 进行SD卡分区时 xff0c 用户操作权限为root权限 xff01 1 umount SD卡 查看SD卡挂载目录 xff0c 一般在 media目录下
  • docker安装图形化管理界面

    首先看下这个界面的样子 还是比较好看 xff0c 而且在同一个局域网中都可以登录进行管理 说下安装教程吧 首先下载这个镜像 xff1a sudo docker pull portainer portainer 然后创建这个容器 sudo d
  • JSP小脚本学习

    小脚本 可以将任何数量的小脚本包含在页面中 xff0c 小脚本是有效的JAVA语言语句 xff0c 变量或方法声明或表达式 小脚本的语法 xff0c lt code fragment gt 入门示例 xff1b lt 64 page lan
  • 34. 在排序数组中查找元素的第一个和最后一个位置(C语言)

    笨办法 xff0c 先找第一个等于target的位置 xff0c 再找最后一个等于target的位置 Note The returned array must be malloced assume caller calls free int
  • 详解如何将TensorFlow训练的模型移植到Android手机

    前言 本文中出现的TF皆为TensorFlow的简称 先说两句题外话吧 xff0c TensorFlow 前两天热热闹闹的发布了正式版r1 0 xff0c 可感觉自己才刚刚上手 r0 12 xff0c 这个时代发展的太快 xff0c 脚步是
  • ROS实践手册(一)ROS安装教程

    笔者根据 古月居 ROS入门21讲 学习整理 xff0c 并参考 ROS机器人开发实践 一书 虚拟机安装 注 该部分可参考 古月居 ROS入门21讲 第2讲下载并安装 VMware Workstation Pro下载 Ubuntu18 04
  • Shell系统学习之如何执行Shell程序

    系列文章目录 Shell系统学习之什么是Shell Shell系统学习之创建一个Shell程序 Shell系统学习之向Shell脚本传递参数 Shell系统学习之如何执行Shell程序 Shell系统学习之Shell变量和引用 Shell系
  • target_link_libraries接口的使用

    target link libraries需要放在add executable之后 xff0c 用于指明连接进来的库 xff0c 官方推荐使用这个接口 xff0c 而不推荐使用link libraries xff0c link librar
  • TTL和RS232之间的详细对比

    背景 之前就听过TTL xff0c 一直没搞懂其和RS232的区别 最近 xff0c 打算去买个USB转RS232的芯片 xff0c 结果找到此产品 xff1a 六合一多功能USB转UART串口模块CP2102 usb TTL485 232
  • STL 解算法题目例子

    STL解算法题目例子
  • 双子天蝎,爱情是不老的传说

    双子天蝎 xff0c 爱情是不老的传说 自注 此文章乃双子座所写 定有主观上的个人倾向 转帖者 xff1a 就是我啦 xff0c 一个双子座的帅哥 xff08 自封 xff09 关于双子和天蝎 xff0c 我是很想很完整的写一些 xff0c
  • ASP2.0-130道ASP.NET面试题

    1 简述 private protected public internal 修饰符的访问权限 答 private 私有成员 在类的内部才可以访问 protected 保护成员 xff0c 该类内部和继承类中可以访问 public 公共成员
  • opencv 图像去噪学习总结

    OpenCV图像处理篇之图像平滑 图像平滑算法 程序分析及结果 图像平滑算法 图像平滑与图像模糊是同一概念 xff0c 主要用于图像的去噪 平滑要使用滤波器 xff0c 为不改变图像的相位信息 xff0c 一般使用线性滤波器 xff0c 其
  • Ubuntu18.04 装系统、cuda、cudnn,主要是Ubuntu的内核版本不能太高,亲测很成功

    一 装系统 简单的我就不说了 xff0c 之说要点 1 选择为图形或无线硬件 安装第三方软件 2 在安装类型中 xff0c 选择其他选项 3 分区 xff0c 我选择分区2 3个 EFI分区 xff0c 主空间 xff0c 空间起始位置 x
  • 被透明元素遮挡的元素还可以被点击到吗?

    遮挡 关于是否被遮挡的判断 xff0c 可以从对层叠级别的判断而确定 见 xff1a 说说标准 CSS核心可视化格式模型 visual formatting model 之十三 xff1a 分层的显示 Layered presentatio
  • 闲谈两句windows,linux

    今天无意在一个群里说了一句 34 我觉得ubuntu比vista还好用 34 马上引来一帮人的反击 xff0c 所用伎俩仍然没有新意 1偷换概念 xff0c 开始用winxp说事 2游戏 xff0c 网银 3windows的系统很稳定 xf
  • python类的基本操作

    本节给出类的基本操作函数 xff0c 方法查阅备用 0 定义类 span class token keyword class span span class token class name student span span class
  • python的异常类型

    1 内建的异常类 异常类含义Exception所有异常的基类AttributeError特性引用或赋值失败引发IOError试图打开不存在文件 包括其他情况 时引发IndexError使用序列中不存在的索引时引发KeyError在使用映射时
  • TFLearn代码示例

    span class token keyword import span tflearn span class token keyword from span tflearn span class token punctuation spa

随机推荐