PyTorch学习(9):实战
- Pytorch官方文档: https://pytorch-cn.readthedocs.io/zh/latest/
- Pytorch学习文档: https://github.com/tensor-yu/PyTorch_Tutorial
- Pytorch模型库: https://github.com/pytorch/vision/tree/master/torchvision/models
文章目录
- PyTorch学习(9):实战
- 前言
- 1.ShuffleNet系列
- (1)数据处理
- (2)模型设计
- (3)训练配置
- (4)训练过程
- (5)模型保存
- 2.EfficientNet系列
-
- 总结
前言
Pytorch模型库已经提供了相当多可用的模型了。
本文从两个比较好的开源项目入手:ShuffleNet 和 EfficientNet。
数据集采用flowers17:
链接: https://pan.baidu.com/s/1UxcW2AFE2OjWfgsLIHtCYA
提取码: brxm
1.ShuffleNet系列
Github链接: https://github.com/megvii-model/ShuffleNet-Series
megvii开源的ShuffleNet系列模型
以ShuffleNetV2为例: 学习深度学习模型搭建的整体流程。
(1)数据处理
train_dataset = datasets.ImageFolder(
args.train_dir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomHorizontalFlip(0.5),
ToBGRTensor(),
])
)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=1, pin_memory=use_gpu)
train_dataprovider = DataIterator(train_loader)
(2)模型设计
architecture = [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2]
model = ShuffleNetV2(model_size=args.model_size)
(3)训练配置
optimizer = torch.optim.SGD(get_parameters(model),
lr=args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay)
criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)
(4)训练过程
output = model(data)
loss = loss_function(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
prec1, prec5 = accuracy(output, target, topk=(1, 5))
(5)模型保存
def save_checkpoint(state, iters, tag=''):
if not os.path.exists("./models"):
os.makedirs("./models")
filename = os.path.join("./models/{}checkpoint-{:06}.pth.tar".format(tag, iters))
torch.save(state, filename)
至此,基于PyTorch框架的ShuffleNetV2模型的训练框架搭建完成。具体细节参考官方code。
2.EfficientNet系列
Github链接: https://github.com/lukemelas/EfficientNet-PyTorch
非官方开源的EfficientNet系列模型,目前仅有V1版本,坐等开源V2版本。
作者的训练代码在examples文件夹内,整理流程同ShuffleNet系列一致。
以Efficientnet-b0模型为基础,训练flowers17数据集:
(1)数据集划分
Flowers17数据由17种不同的花构成,本文以每种花的0.8作为训练集,其余0.2作为测试集。
(2)模型训练
训练结果:
无数据增强 | +数据增强 | +预训练模型 |
---|
59.56% | 84.92% | 96.32% |
大致训了一下,采用预训练模型进行迁移训练的效果最好。
总结
至此,基于PyTorch框架的入门已基本完成。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)