PyTorch学习(9):实战

2023-05-16

PyTorch学习(9):实战

  • Pytorch官方文档: https://pytorch-cn.readthedocs.io/zh/latest/
  • Pytorch学习文档: https://github.com/tensor-yu/PyTorch_Tutorial
  • Pytorch模型库: https://github.com/pytorch/vision/tree/master/torchvision/models

文章目录

  • PyTorch学习(9):实战
  • 前言
    • 1.ShuffleNet系列
      • (1)数据处理
      • (2)模型设计
      • (3)训练配置
      • (4)训练过程
      • (5)模型保存
    • 2.EfficientNet系列
      • (1)数据集划分
      • (2)模型训练
  • 总结


前言

Pytorch模型库已经提供了相当多可用的模型了。
本文从两个比较好的开源项目入手:ShuffleNetEfficientNet
数据集采用flowers17:
  链接: https://pan.baidu.com/s/1UxcW2AFE2OjWfgsLIHtCYA
  提取码: brxm


1.ShuffleNet系列

Github链接: https://github.com/megvii-model/ShuffleNet-Series
megvii开源的ShuffleNet系列模型
在这里插入图片描述
以ShuffleNetV2为例: 学习深度学习模型搭建的整体流程。

(1)数据处理

    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=1, pin_memory=use_gpu) #锁页内存
    train_dataprovider = DataIterator(train_loader)

(2)模型设计

    architecture = [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2]
    model = ShuffleNetV2(model_size=args.model_size)

(3)训练配置

   optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

(4)训练过程

    output = model(data)
    loss = loss_function(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() # 对参数更新
    prec1, prec5 = accuracy(output, target, topk=(1, 5))

(5)模型保存

def save_checkpoint(state, iters, tag=''):
	if not os.path.exists("./models"):
		os.makedirs("./models")
	filename = os.path.join("./models/{}checkpoint-{:06}.pth.tar".format(tag, iters))
	torch.save(state, filename)

至此,基于PyTorch框架的ShuffleNetV2模型的训练框架搭建完成。具体细节参考官方code


2.EfficientNet系列

Github链接: https://github.com/lukemelas/EfficientNet-PyTorch
非官方开源的EfficientNet系列模型,目前仅有V1版本,坐等开源V2版本。
在这里插入图片描述
作者的训练代码在examples文件夹内,整理流程同ShuffleNet系列一致。
以Efficientnet-b0模型为基础,训练flowers17数据集:

(1)数据集划分

Flowers17数据由17种不同的花构成,本文以每种花的0.8作为训练集,其余0.2作为测试集。

(2)模型训练

训练结果:

无数据增强+数据增强+预训练模型
59.56%84.92%96.32%

大致训了一下,采用预训练模型进行迁移训练的效果最好

总结

  至此,基于PyTorch框架的入门已基本完成。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch学习(9):实战 的相关文章

随机推荐