0. VGG的网络结构
一、torchvision源码
这种通过配置文件一次性搭建相似网络的结构的方法十分值得学习和模仿.这也是相对于AlexNet的实现过程不同之处.
我对其做了少量修改,原始源码可参见 torchvision.models.vgg。
'''
Re-implementation of VGG, adapted from the torchvision source.
'''
import torch
import torch.nn as nn
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# Public names exported by this module: the VGG class and one factory
# function per variant (with and without batch normalization).
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',
]
# Checkpoint download URLs keyed by architecture name. `_vgg` looks an entry
# up as `model_urls[arch]` when pretrained weights are requested.
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
class VGG(nn.Module):
    """VGG classifier: a convolutional feature extractor followed by an MLP head.

    Args:
        features: Module applied first to the input. Its output is pooled to
            7x7 and flattened, so it must produce 512 channels to match the
            first ``nn.Linear(512 * 7 * 7, 4096)`` layer.
        num_classes: Size of the final linear layer's output.
        init_weights: If True, run ``_initialize_weights`` on construction.
    """

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.init_weights = init_weights
        # Adaptive pooling fixes the spatial size at 7x7 regardless of the
        # input resolution, so the classifier input size is constant.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Dropout follows *each* hidden ReLU. This keeps the Linear layers at
        # indices 0, 3 and 6, which is the layout the pretrained checkpoints
        # expect (classifier.0 / classifier.3 / classifier.6).
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if self.init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run features -> adaptive average pool -> flatten -> classifier."""
        x = self.features(x)
        # The pooling layer is a submodule of this model, not a torch function.
        x = self.avgpool(x)
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # The keyword is `mode`; `model=` raises TypeError.
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # The feature extractor uses 2-D batch norm, so match that type.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
def make_layer(cfg, batch_norm=False):
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: Iterable of layer specs. The string ``'M'`` adds a 2x2 max-pool
            with stride 2; any other value is used as the ``out_channels`` of
            a 3x3 convolution (stride 1, padding 1).
        batch_norm: If True, insert ``nn.BatchNorm2d`` between each
            convolution and its ReLU.

    Returns:
        An ``nn.Sequential`` of the layers. The first convolution takes
        3 input channels.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(in_channels=in_channels, out_channels=v,
                                kernel_size=3, stride=1, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        # The ReLU is needed in both variants. Without it the batch-norm
        # network has no non-linearity and its layer indices no longer match
        # the pretrained *_bn checkpoints.
        layers.append(nn.ReLU(inplace=True))
        in_channels = v
    return nn.Sequential(*layers)
# Layer configurations consumed by `make_layer`: an integer is the number of
# output channels of a 3x3 convolution, 'M' is a max-pooling layer.
# The factories below use 'A' for vgg11, 'B' for vgg13, 'D' for vgg16 and
# 'E' for vgg19 (and the matching *_bn variants).
cfgs = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Build a VGG network and optionally load pretrained weights.

    Args:
        arch: Architecture name; used as the key into ``model_urls`` when
            loading pretrained weights.
        cfg: Key into ``cfgs`` selecting the layer configuration.
        batch_norm: Whether to use batch normalization in the feature layers.
        pretrained: Whether to download and load pretrained weights.
        progress: Whether to show a progress bar while downloading.
        **kwargs: Extra keyword arguments forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` model.
    """
    if pretrained:
        # Random initialisation is wasted work when weights are loaded next.
        kwargs['init_weights'] = False
    # Look the configuration up in the module-level table; `cfg` itself is
    # only the key (e.g. 'A'), so `cfg[cfg]` would raise TypeError.
    model = VGG(make_layer(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
def vgg11(pretrained=False, progress=True, **kwargs):
    """Build VGG-11 (configuration 'A') without batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    return _vgg(arch='vgg11', cfg='A', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
def vgg11_bn(pretrained=False, progress=True, **kwargs):
    """Build VGG-11 (configuration 'A') with batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    return _vgg(arch='vgg11_bn', cfg='A', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
def vgg13(pretrained=False, progress=True, **kwargs):
    """Build VGG-13 (configuration 'B') without batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    return _vgg(arch='vgg13', cfg='B', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
def vgg13_bn(pretrained=False, progress=True, **kwargs):
    """Build VGG-13 (configuration 'B') with batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    return _vgg(arch='vgg13_bn', cfg='B', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
def vgg16(pretrained=False, progress=True, **kwargs):
    """Build VGG-16 (configuration 'D') without batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    return _vgg(arch='vgg16', cfg='D', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
def vgg16_bn(pretrained=False, progress=True, **kwargs):
    """Build VGG-16 (configuration 'D') with batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``. Remaining keyword
    arguments (e.g. ``num_classes``) are forwarded as well, matching the other
    factory functions; previously they were accepted but silently dropped.
    """
    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
def vgg19(pretrained=False, progress=True, **kwargs):
    """Build VGG-19 (configuration 'E') without batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    return _vgg(arch='vgg19', cfg='E', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
def vgg19_bn(pretrained=False, progress=True, **kwargs):
    """Build VGG-19 (configuration 'E') with batch normalization.

    ``pretrained`` and ``progress`` are passed to ``_vgg``; remaining keyword
    arguments are forwarded unchanged.
    """
    # `batch_norm=True` must be passed explicitly. Omitting it shifted
    # `pretrained` into the batch_norm slot and left `progress` unfilled,
    # so the call raised TypeError.
    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
二、一些值得学习的用法笔记
torch.flatten(tensor, start_dim, end_dim)
x = torch.flatten(x, start_dim=1)
x = x.view(x.size(0), -1)
torch.nn.init.kaiming_normal_(tensor, a=0,
mode='fan_in', nonlinearity='leaky_relu')
nn.init.kaiming_normal_(m.weight, mode='fan_out',
nonlinearity='relu')
torch.nn.init.normal_(tensor, mean=0., std=1.)
nn.init.constant_(tensor, val)
layers = []
layers += [nn.Conv2d(...), nn.ReLU(inplace=True)]
layers += [nn.BatchNorm2d(...)]
nn.Sequential(*layers)
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True)
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
def myprint(*args):
    """Print every positional argument on one line, separated by spaces."""
    print(" ".join(str(item) for item in args))


# Demo: positional arguments are collected into the tuple `args`.
myprint(10, 2)
def mykwprint(**kwargs):
    """Print the keyword names, then the keyword values, as dict views."""
    for view in (kwargs.keys(), kwargs.values()):
        print(view)


# Demo: keyword arguments are collected into the dict `kwargs`.
mykwprint(epoch=10, LR=2)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)