PytorchCNN项目搭建 6--- 训练、验证CNN

2023-11-05

整体的代码在我的github上面可以查阅


上几次的实验已经下载了数据集,并且写好了models,并进行了一些基础的配置,这次的主要目标是写好训练过程。主要的流程如下:

  • 配置args, cfg, log等
  • 将之前的数据集datasetset经过DataLoader变成data_loader
  • 加载网络net
  • 选择损失函数和优化器
  • 训练网络,得到损失值loss

import os
import pdb
import argparse
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from mmcv import Config
import numpy as np
from log.logger import Logger
from dataset.dataset import Cifar10Dataset
from utils.get_net import get_network
from utils.visualization import plot_acc_loss
from loss import MyCrossEntropy

1. 配置args,cfg,log,等

def parser():
    parse = argparse.ArgumentParser(description='Pytorch Cifar10 Training') 
    parse.add_argument('--config','-c',default='./config/config.py',help='config file path') # 配置config
    parse.add_argument('--net','-n',type=str,required=True,help='input which model to use') # 配置网络net
    parse.add_argument('--pretrain','-p',action='store_true',help='Location pretrain data') # 网络是否进行预训练?默认为否
    parse.add_argument('--resume','-r',action='store_true',help='resume from checkpoint') # 是否进行断点续训?默认为否
    parse.add_argument('--epoch','-e',default=None,help='resume from epoch') # 断点续训从哪个epoch开始
    parse.add_argument('--gpuid','-g',type=int,default=0,help='GPU ID') # 是否指定进行训练的 GPU_ID
    args = parse.parse_args()
    return args

2. 将数据集dataset经DataLoader变成dataloader

def DataLoad(cfg):
    trainset = Cifar10Dataset(txt=cfg.PARA.cifar10_paths.train_data_txt, transform='for_train')
    validset = Cifar10Dataset(txt=cfg.PARA.cifar10_paths.valid_data_txt, transform='for_valid')
    train_loader = DataLoader(dataset=trainset, batch_size=cfg.PARA.train.batch_size, drop_last=True, shuffle=True, num_workers=cfg.PARA.train.num_workers)
    valid_loader = DataLoader(dataset=validset, batch_size=cfg.PARA.train.batch_size, drop_last=True, shuffle=True, num_workers=cfg.PARA.train.num_workers)
    return train_loader, valid_loader

3. 加载网络net

4. 选择损失函数和优化器

net = get_network(args, cfg).cuda(args.gpuid)
criterion = MyCrossEntropy().cuda(args.gpuid)
optimizer = optim.SGD(net.parameters(), lr=cfg.PARA.train.lr, momentum=cfg.PARA.train.momentum)

5. 训练网络,得到损失值

def train(net,criterion,optimizer, train_loader, valid_loader, args, log, cfg):
    for epoch in range(cfg.PARA.train.epochs): #在一个epoch下,既进行训练,也进行验证
        net.train()
        train_loss = 0.0
        train_total = 0.0
        for i, data in enumerate(train_loader, 0):
            length = len(train_loader) #length = 47500 / batch_size
            inputs, labels = data
            inputs, labels = Variable(inputs.cuda(args.gpuid)), Variable(labels.cuda(args.gpuid))
            # pdb.set_trace()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs,labels)
            # pdb.set_trace()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            if (i+1+epoch*length)%100==0:
                log.logger.info('[Epoch:%d, iter:%d] Loss: %.5f '
                            %(epoch+1, (i+1+epoch*length), train_loss/ (i+1)))
        with open(cfg.PARA.utils_paths.visual_path + args.net + '_train.txt', 'a') as f:
            f.write('epoch=%d,loss=%.5f\n' % (epoch + 1, train_loss / length))


        net.eval()
        valid_loss = 0.0
        valid_total = 0.0
        with torch.no_grad():  # 强制之后的内容不进行计算图的构建,不使用梯度反传
            for i, data in enumerate(valid_loader, 0):
                length = len(valid_loader)
                inputs, labels = data
                inputs, labels = Variable(inputs.cuda(args.gpuid)), Variable(labels.cuda(args.gpuid))
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                valid_total += labels.size(0)
                # correct += (predicted == labels).sum()
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
            log.logger.info('Validation | Loss: %.5f' % (valid_loss / length))
            with open(cfg.PARA.utils_paths.visual_path+args.net+'_valid.txt','a') as f:
                f.write('epoch=%d,loss=%.5f\n' %(epoch+1, valid_loss/length))

        '''save model's net & epoch to checkpoint'''
        log.logger.info('Save model to checkpoint ' )
        checkpoint = { 'net': net.state_dict(),'epoch':epoch}
        if not os.path.exists(cfg.PARA.utils_paths.checkpoint_path+args.net):os.makedirs(cfg.PARA.utils_paths.checkpoint_path+args.net)
        torch.save(checkpoint, cfg.PARA.utils_paths.checkpoint_path+args.net+'/'+str(epoch+1)+'ckpt.pth')


主函数

def main():
    args = parser()
    cfg = Config.fromfile(args.config)
    log = Logger(cfg.PARA.utils_paths.log_path+ args.net + '_trainlog.txt',level='info')
    start_epoch = 0

    log.logger.info('==> Preparing dataset <==')
    train_loader, valid_loader = DataLoad(cfg)

    log.logger.info('==> Loading model <==')
    if args.pretrain:
        log.logger.info('Loading Pretrain Data')

    net = get_network(args, cfg).cuda(args.gpuid)
    criterion = MyCrossEntropy().cuda(args.gpuid)
    optimizer = optim.SGD(net.parameters(), lr=cfg.PARA.train.lr, momentum=cfg.PARA.train.momentum)

    # net = resnet18().cuda(args.gpuid)
    log.logger.info('==> SUM NET Params <==')
    get_model_params(net,args,cfg)

    # if torch.cuda.device_count()>1:#DataParallel is based on Parameter server
    #     net = nn.DataParallel(net, device_ids=cfg.PARA.train.device_ids)
    torch.backends.cudnn.benchmark = True

    '''断点续训否'''
    if args.resume:
        log.logger.info('Resuming from checkpoint')
        checkpoint = torch.load(cfg.PARA.utils_paths.checkpoint_path+args.net+'/'+args.epoch + 'ckpt.pth')
        net.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']

    log.logger.info('==> Waiting Train <==')
    train(net=net,criterion=criterion,optimizer=optimizer,
          train_loader=train_loader,valid_loader=valid_loader,args=args,log=log,cfg=cfg)
    log.logger.info('==> Finish Train <==')

    log.logger.info('==> Plot Train_Vilid Loss & Save to Visual <==')
    plot_acc_loss(args, cfg=cfg)
    log.logger.info('*'*25)
    
if __name__ == '__main__':
    main()

说明:

主函数在后台运行,并指定net


参考文献

pytorch官网教程文档

最后感谢我的师兄,是他手把手教我搭建了整个项目,还有实验室一起学习的小伙伴~ 希望他们万事胜意,鹏程万里!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PytorchCNN项目搭建 6--- 训练、验证CNN 的相关文章

  • vue3 setup + ts + vite 项目问题解决:Cannot find module ... or its corresponding type declarations.(ts2307)

    昨日我尝试使用vue3 setup ts vite进行vue3项目的实现 遇到此问题 Cannot find module or its corresponding type declarations ts2307 文件报错类型以及ts官方
  • 转载:CCNP学习考试心得

    CCNP学习考试心得 当计算机屏幕上显示 Congralation时 我不禁长出一口气 心中想 终于考完了 我所说的终于考完是指 我终于完成了CCNP考试 四个月的学习 对于某些人来说可能太长了 但是要真正掌握ccnp的内容我感觉四个月还只
  • 手把手教你使用python发送邮件

    前言 发送电子邮件是个很常见的开发需求 平时如果有什么重要的信息怕错过 就可以发个邮件到邮箱来提醒自己 使用 Python 脚本发送邮件并不复杂 不过由于各家邮件的发送机制和安全策略不同 常常会因为一些配置问题造成发送失败 今天我们来举例讲
  • 混合模型简介与高斯混合模型

    高斯混合模型 混合模型概述 In statistics a mixture model is a probabilistic model for representing the presence of subpopulations wit
  • C++primer 阅读随记

    目 录 一 C 基础 1 变量和基本类型 2 字符串 向量和数组 3 表达式 4 语句 5 函数 6 类 二 C 标准库 1 IO库 2 顺序容器 3 泛型算法 4 关联容器 5 动态内存 三 类设计者的工具 1 拷贝控制 2 重载运算与类
  • 实施Microsoft Dynamics 365 CE-5. 配置Dynamics 365 CE组织,包括配置不同的Dynamics 365 CE设置。

    本章将帮助您了解Dynamics 365 CE中为个人和管理员提供的Dynamics 365配置选项 您将了解哪些选项可以为单个用户配置 哪些是管理员用户可以完成的配置 您将了解业务管理和服务管理设置下提供的不同配置选项 您还将了解Dyna
  • RobotFramework之高级API

    一 窗口跳转 跳转页面的时候需要获取句柄 Get Window Handles 获取窗口的句柄 Select Window By Handle 切换到新窗口 但是在seleniumLibrary中只有Select window 所以我们进入
  • Top K问题的两种解决思路

    Top K问题在数据分析中非常普遍的一个问题 在面试中也经常被问到 比如 从20亿个数字的文本中 找出最大的前100个 解决Top K问题有两种思路 最直观 小顶堆 大顶堆 gt 最小100个数 较高效 Quick Select算法 Lee

随机推荐

  • 自适应表格中input框输入文字布局被打乱

    我今天在写一个新增用户表单的时候 发现我只要输入文字 input框的高度就会改变 导致布局被打乱 这是正常排列好的样式 这是我输入中文后的样子 后来我发现输入中文后 input的高度被撑开了 我一开始没有给盒子设置固定的高度以及行高 设置完
  • C 语言基础-什么是常量、变量?

    C 语言基础 常量和变量 常量 只读 常量是只读的固定值 在程序运行期间不会改变 不能被程序修改的量 可以是任意类型 定义常量的方式有两种 使用 define 宏定义 使用 const 关键字 常量大体分为 直接常量 字面常量 符号常量 d
  • python练习61:打印出杨辉三角形,包含二维列表的应用

    打印出杨辉三角形 要求打印出10行如下图 yanghui for i in range 10 yanghui append 构造二维列表 for j in range i 1 if j 0 or j i yanghui i append 1
  • CCF-CSP真题-2022-06-1归一化处理讲解

    题目传送门 这是CCF CSP2022 06的第一题 相比较还是比较简单 较难理解的是方差 每个样本值与全体样本值的平均数之差的平方值的平均数方差 数学计算公式是这样的 然而 用代码来写要简洁得多 这里采用暴力的复杂度算法 for int
  • MySQL utf8mb4 字符集,用于存储emoji表情

    最近在做微信相关的项目 其中MySQL 要存储emoji表情 因此发现我们常用的utf8 字符集根本无法存储表情 网上有不少替代方案 本人还是采用了修改MySQL字符集的方案简单快捷 首先将我们数据库默认字符集由utf8 更改为utf8mb
  • Pandas分组与排序

    Grouping and Sorting 分组 agg 排序 经常需要将数据根据某个字段划分为不同的组 group 进行分析 然后对组里的数据进行特定的操作 pandas的 groupby 操作便是实现这一功能 groupby的过程就是将原
  • jquery的两种常用自动加载方法

    一 jquery JavaScript的三种常用自动加载方法 1 function jQuery 2 function 3 window nl ad function 加载的先后顺序 第一步 代码块1加载 是在css html 信息加载完毕
  • Scala环境配置完成,在命令行还是不能运行

    刚开始我以为是版本兼容的问题 所以下载了很多个版本 发现没用 我找了很久都不知道是什么原因 网上也没找到跟我一样的问题 偶然我发现是系统环境变量PATHEXT中缺少东西 在PATHEXT中添加 bat 然后就可以了
  • AIX系统安装

    1 选择安装介质 CD ROM 现有备份的安装系统 网络安装 Token Ring Ethernet FDDI U盘 服务器通电启动系统 在控制台显示器出现keyboard字符时 按对应的按钮 1 进入系统管理服务模式 SMS 2 指定控制
  • C语言中结构体初始化并清零的方法有几种?

    结构体初始化清零方法 在C语言中 结构体初始化并清零的方法有以下几种 手动赋值为0 结构体定义后在函数内手动将每个成员都赋值为0 例如 struct MyStruct int a char b float c struct MyStruct
  • vue页面基本组成

    作为编写过html的人 vue页面的基本组成是什么呢 如何快速入手vue呢 我来讲下自己的思路 简介 vue是一个前端框架 运行它需要下载node js 后台支撑 下载vs code 代码编辑器 来编辑代码 可配合eliment ui 上百
  • nodejs处理图片文件上传

    如果使用express框架的话 其内置模块就可以直接处理文件上传了 而不需要饱含额外的模块 express版本 3 4 4 1 使用bodyParser过滤器 并且指定上传的目录为public upload 注意这里的目录为相对于expre
  • PyQt5学习笔记--GridLayout、FormLayout和StackedLayout布局

    目录 1 GridLayout布局 2 FormLayout布局 3 StackedLayout布局 1 GridLayout布局 import sys from PyQt5 QtWidgets import class MyWindow
  • select、poll、epoll

    因为实际需要所致 我们不得不考虑在现有的开源 商用的应用服务器之外开发一个 有性能要求 有并发要求的服务端应用 从技术要求的角度来分析一下 用Java实现这件事情我们可能关注的知识层面 在整体上 可能需要我们从下面几个层面出发来考虑 1 在
  • windows多个不同java共存

    windows多个不同java共存 如图我电脑存在java1 8和15 使用时 我会存在工具支持的java版本不一样 有的工具要8才能使用有的工具需要11或者15以上java才能正常使用 于是为了方便快捷便写了这个多java版本共存 jav
  • 微服务SpringCloud

    什么是SpringCloud SpringCloud是由Spring提供的一套能够快速搭建微服务架构程序的框架集 SpringCloud本身不是一个框架 而是一系列框架的统称 SpringClound就是为了搭建微服务架构才出现的 有人将S
  • Linux如何查看系统时间

    文章目录 一 使用date命令查看系统时间 二 通过 var log syslog文件查看系统时间 三 通过 proc uptime文件查看系统运行时间 四 通过hwclock命令查看硬件时间 五 通过timedatectl命令设置系统时区
  • Python实现普通二叉树

    Python实现普通二叉树 二叉树是每个节点最多有两个子树的树结构 本文使用Python来实现普通的二叉树 关于二叉树的介绍 可以参考 https blog csdn net weixin 43790276 article details
  • MES系统是什么?MES系统的主要功能是什么?看完本文就知道

    MES系统是什么 MES系统是一套面向制造企业车间执行层的生产信息化管理系统 MES可以为企业提供包括制造数据管理 计划排程管理 生产调度管理 库存管理 质量管理 人力资源管理 工作中心 设备管理 工具工装管理 采购管理 成本管理 项目看板
  • PytorchCNN项目搭建 6--- 训练、验证CNN

    PytorchCNN项目搭建 6 训练 验证CNN 1 配置args cfg log 等 2 将数据集dataset经DataLoader变成dataloader 3 加载网络net 4 选择损失函数和优化器 5 训练网络 得到损失值 主函