VGG16训练RAF-DB

2023-05-16

使用VGG16对本地数据集RAF-DB中的basic图片进行训练,官方已经在图片命名时分好了train与test,train和test的label在同一个txt文件里,方便起见,把这两种label分成两个txt;需要自行重写导入数据集的函数。
注意:data.DataLoader后输出的batch中label要变成一维;load训练数据时要打乱顺序
进行50个epoch,最后在测试集中得到83.1%的准确率。

import torch, cv2, os, random
from torch.utils import data
from torchvision import transforms
import torchvision.models as models
import torch.nn as nn

class TxtImage(data.Dataset):
    def __init__(self,label,dataRoot,transform=None,size=(120,120),index=0):
        self.tranform=transform
        self.dataRoot=dataRoot
        self.size=size
        self.imgList=[]
        self.labelList=[]
        self.index=index
        for xx in label:
            x =xx.split(' ')
            self.imgList.append(x[0])
            self.labelList.append(int(x[1])-1)

    def __getitem__(self,index):
        imgName=self.imgList[index]
        imgName_=list(imgName)
        imgName_.insert(self.index,'_aligned')
        imgName="".join(imgName_)
        imgPath=os.path.join(self.dataRoot,imgName)
        img = cv2.imread(imgPath)
        if self.size is not None:
            img=cv2.resize(img,self.size)
        if self.tranform is not None:
            img=self.tranform(img)
        label=torch.IntTensor([self.labelList[index]])
        return img,label

    def __len__(self):
        return len(self.imgList)

if __name__=='__main__':
    transform=transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
    TestlabelPath = os.path.join('/home/msy/FaceData/RAF-DB/basic/alignByMyself/test.txt')
    TrainlabelPath = os.path.join('/home/msy/FaceData/RAF-DB/basic/alignByMyself/train.txt')
    ImageRoot='/home/msy/FaceData/RAF-DB/basic/Image/aligned'#aligned

    with open(TestlabelPath,'r') as testf:
        Testlabels=testf.readlines()
    TestDataset=TxtImage(Testlabels,ImageRoot,transform,index=9)
    TestLoader=data.DataLoader(TestDataset,batch_size=20,num_workers=2)
    TestDataLen=TestDataset.__len__()

    with open(TrainlabelPath,'r') as trainf:
        Trainlabels=trainf.readlines()
    TrainDataset=TxtImage(Trainlabels,ImageRoot,transform,index=11)
    TrainLoader=data.DataLoader(TrainDataset,batch_size=20,num_workers=2,shuffle=True)

    vgg16 = models.vgg16(pretrained=True)
    vgg16.classifier[6] = nn.Linear(4096, 7) #输出改为7个基础表情
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vgg16.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(vgg16.parameters(), lr=0.0001)
    model_name = "vgg16"
    save_path = './{}Net.pth'.format(model_name)
    best_acc = 0.0

    print("start to train...")

    for epoch in range(50):
        vgg16.train()
        running_loss=0.0
        for step,(img,label) in enumerate(TrainLoader):

            optimizer.zero_grad()
            outputs = vgg16(img.to(device))

            label = label.squeeze()   # if parameters are none ,tensor is squeezed into one dimension
            label = torch.as_tensor(label, dtype=torch.int64, device=device)
            loss = loss_function(outputs, label)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print train process
            rate = (step + 1) / len(TrainLoader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
        print()

        # validate
        vgg16.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            for data_test in TestLoader:
                test_images, test_labels = data_test
                optimizer.zero_grad()
                outputs = vgg16(test_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                test_labels = test_labels.squeeze()  # if parameters are none ,tensor is squeezed into one dimension
                
                acc += (predict_y == test_labels.to(device)).sum().item()
            accurate_test = round(acc / TestDataLen, 4)
            print()
            print(acc)
            if accurate_test >= best_acc:
                best_acc = accurate_test
                torch.save(vgg16.state_dict(), save_path)
                print("save weights")
            print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
                  (epoch + 1, running_loss / step, accurate_test))

    print('Finished Training')

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

VGG16训练RAF-DB 的相关文章

  • Mavlink地面站编写之一--协议分析

    阿木社区 转载注意出处 http www amovauto com Pixhawk APM都是采用MAVLINK协议实现的飞控的数据链路传输 先简单介绍下mavlink协议 Mavlink协议最早由 苏黎世联邦理工学院 计算机视觉与几何实验
  • 航模基础知识之电机,电调,电池的选择

    在设计一款小型无人车 xff0c 无人船 xff0c 多旋翼 xff0c 固定翼的时候 一定会面对的一个问题是怎么选择合适的电机 xff0c 电池 xff0c 电调 这三者配合的好 xff0c 机器才可以发挥优秀的性能 xff0c 否则会有
  • APM2.8 Rover 自动巡航车设计(硬件连接)

    APM PIX4系类飞控是美国3DR公司的开源项目 xff0c 是目前在开源无人机领域使用最多人数最多的开源控制板 整个项目开源 xff0c 从硬件到软件 有非常优秀的地面站系统和适应多旋翼 xff0c 无人车 xff0c 无人船 xff0
  • APM2.8 Rover 自动巡航车设计(固件安装和设置)

    1 2 APM2 8软件安装与固件下载 下载Mission Planner这个地面基站软件 xff0c 这里介绍的是windoews平台下的 xff0c 在MAC或者linux下可以使用QGroundCont基于QT编写的地面站软件 xff
  • 树莓派4b安装ubuntu18.04并安装ros

    ubuntu官方已经支持树莓派了 xff0c 官方镜像如下 https ubuntu com download raspberry pi 这是一个为树莓派准备的arm架构的ubuntu xff0c 不是ubuntu mate xff0c 功
  • 利用Session完成用户的登录和注销

    用户的登录和注销是最常见的Web应用案例 xff0c 当一个应用的客户登录了以后 xff0c 其他所有的会话都得知道这个用户已经登录还很有可能得提取用户的昵称予以显示等等 xff0c 所以 xff0c 只有把登录成功的用户的信息放入到Ses
  • Mavlink地面站编写之五-Mission Planner中ProgressReporterDialogue和读串口线程serialreaderthread的分析

    转载请注明出处 http www amovauto com p 61 660 阿木社区 xff0c 玩也要玩的专业 QQ群 526221258 ProgressReporterDialogue 这个对话框很有意思 xff0c 在MP中连接阶
  • Mavlink地面站编写之六---MP源码多线程读写框架分析

    转载请注明出处 xff01 阿木开源社区 玩也要玩的专业 http www amovauto com p 61 743 more 743 对于MissionPlanner这种多任务的程序 xff0c 我们知道要采用多线程的方式来实现 xff
  • APM/PIXhawk 最全资料总汇(欢迎补充更新)

    转载请注明出处 http www amovauto com cat 61 11 AMOV社区玩也要玩的专业 xff01 欢迎加入社区QQ群 APM PIX UAV 相关中英文网站链接总汇 xff1a 1 国内外知名论坛 无人机开源基金会 D
  • PIXHAWK开发环境建立(固件编译)

    阿木社区 xff1a 玩也要玩的专业 xff01 http www amovauto com p 61 842 QQ群 526221258 目前有很多基于PIXHAWK的开发的无人机 xff0c 业界对于APM PIXHAWK也比较认可 我
  • Mavlink地面站编写之七—发送控制指令

    转载请注明出处 http www amovauto com cat 61 19 阿木UAV社区 好久没更新MAVLINK系列文章了 xff0c 最近事情比较多 xff0c 中间去了趟深圳 见了老朋友顺便去了趟华强北溜了圈 所以中间耽误更新的

随机推荐