PyTorch学习笔记(16)——预训练模型微调

2023-11-15

完整工程

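  • 工程目录结构（原截图缺失；根据代码中使用的路径，目录应如下）：
    ./data/train/<类别名>/    训练集图片，每个类别一个子文件夹（ImageFolder 的约定）
    ./data/val/<类别名>/      验证集图片，结构同上
    ./models/                 保存微调后得到的最佳模型权重（model_50.pth）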
  • Code
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


# ---------------------------------------------------------
# Load AlexNet with ImageNet-pretrained weights.
# torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum
# (the boolean flag is deprecated there); IMAGENET1K_V1 is the same checkpoint
# the old flag loaded. Fall back to the old flag on older torchvision.
if hasattr(models, 'AlexNet_Weights'):
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    model = models.alexnet(pretrained=True)
# Replace the final classifier layer with a 2-class head. Reuse the existing
# layer's input width rather than hard-coding it.
model.classifier[6] = nn.Linear(in_features=model.classifier[6].in_features, out_features=2)


# -------------------------Dataset----------------------------------------------------

# The pretrained AlexNet weights were trained on ImageNet images normalized with
# these per-channel statistics, so fine-tuning inputs are normalized the same way.
# Any inference code that loads the saved weights must apply the same transform.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([transforms.Resize((227, 227)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

# Each root is expected to contain one sub-directory per class.
train_dataset = ImageFolder(root='./data/train', transform=transform)
val_dataset = ImageFolder(root='./data/val', transform=transform)

# NOTE: num_workers > 0 starts worker processes. On platforms that use the
# "spawn" start method (e.g. Windows), the script's top-level code must sit
# under an `if __name__ == '__main__':` guard, otherwise workers re-execute it.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, num_workers=4, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)


# --------------------Device selection (GPU if available)-------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# The torch.optim docs recommend moving the model to its device *before*
# constructing the optimizer over its parameters. Module.to() moves the
# parameters in place, so no reassignment is needed here.
model.to(device)

# ------------------Optimizer, loss function, LR schedule-------------------------------
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fc = nn.CrossEntropyLoss()
# Multiply the learning rate by 0.1 every 20 scheduler steps (one step per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# -------------------Training-------------------------------------------------------------


def _evaluate(net, loader, target_device):
    """Return the top-1 accuracy of *net* over every batch in *loader*.

    Leaves *net* in eval mode; the caller must switch it back to train mode.
    Returns 0.0 if *loader* yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    # Validation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            _, prediction = torch.max(net(images), 1)
            correct += (prediction == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


epoch_nums = 50
# Deep-copy so the snapshot is not aliased to the live parameters the
# optimizer keeps updating.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0
for epoch in range(epoch_nums):
    running_loss = 0.0
    model.train()

    for i, (inputs, labels) in enumerate(train_dataloader):
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result must be reassigned (otherwise inputs stay on the CPU).
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fc(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Every 10 batches: report the mean loss of those batches, validate,
        # and keep a copy of the best-scoring weights seen so far.
        if i % 10 == 9:
            accuracy = _evaluate(model, val_dataloader, device)
            model.train()  # _evaluate left the model in eval mode
            print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(epoch + 1, i + 1, running_loss/10, accuracy))
            running_loss = 0.0
            if accuracy > best_acc:
                best_acc = accuracy
                best_model_wts = copy.deepcopy(model.state_dict())

    # Since PyTorch 1.1 the LR scheduler must be stepped after the epoch's
    # optimizer steps; stepping it first skips the initial learning rate.
    scheduler.step()


print('Train finish')
# Make sure the output directory exists before saving.
os.makedirs('./models', exist_ok=True)
torch.save(best_model_wts, './models/model_50.pth')

原文链接：https://www.jianshu.com/p/2e5a9bd5ad36

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch学习笔记(16)———预训练模型微调 的相关文章

随机推荐