Pytorch训练流程

2023-05-16

调试了很久YOLO的标准Pytorch实现,将Train代码进行了重新整理,抽出了标准的Pytorch训练框架。现在整理下来,以备后用。整理后的代码分为三个大部分,每个大部分分为三个小部分:

1、初始化(Init):训练之前先分别创建Model、Dataset&Dataloader、Optimizer;

2、轮次内部(Epoch):分别进行:Dataloader遍历训练、Save模型(间隔)、Eval模型(间隔);

3、训练(Train):其实隶属于Epoch中的Dataloader遍历,最核心的训练步骤:Forward、Backward、Optimize参数;

官方YOLO的Pytorch训练代码整理以后,再简化之后就是下面这样。

其中一些小地方需要注意,例如:在模型进行训练之前,一定要调成训练模式,评估时要调成评估模式,以固定BN层和Dropout层的参数。优化器在定义时要指定需要优化的模型参数。封装输入图像和标签时,标签不需要梯度。优化器使用之后需要清零。

其他注意事项:按照惯例,一些项目上的设定参数都是需要通过argparse传入工程的,为了项目的清晰,我把全部的工程参数设定放到了"__main__"部分,核心的训练部分做为一个独立的函数存在于文件中,这样的安排可以增加代码的可读性,方便整理。

def Quan_train(opt, logger):
    ### Init Step 1: Create Model
    model, device, start_epoch = create_model(opt)

    ### Init Step 2: Create Dataset
    dataloader, train_path, valid_path, class_names = create_dataset(opt)

    ### Init Step 3: Create Optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Epoch
    for epoch in range(start_epoch, opt.epochs):
        # Set model in train.
        model.train()

        ### Epoch Step 1: Train
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            # Load input and target
            imgs = Variable(imgs.to(device))
            targets = Variable(targets.to(device), requires_grad=False)

            ### Train Step 1: Forward pass, get loss
            loss, outputs = model(imgs, targets)

            ### Train Step 2: Backward pass, get gradient
            loss.backward()

            ### Train Step 3: Optimize params
            if batches_done % opt.gradient_accumulations:  # Accumulates gradient before each step
                optimizer.step()
                optimizer.zero_grad()

        ### Epoch Step 2: Save
        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), f"checkpoints/yolov3-tiny_quan_ckpt_%d.pth" % epoch)

        ### Epoch Step 3: Eval
        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            # Evaluate the model on the validation set
            precision, recall, AP, f1, ap_class, IoU_total = evaluate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Super-Params
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="size of each image batch")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    # ......Other Params
    opt = parser.parse_args()

    # Set Logger
    logger = Logger("logs")

    # Set env GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)

    # Train
    Quan_train(opt, logger)

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch训练流程 的相关文章

随机推荐

  • Python虚拟环境——virtualenv

    林野哥推荐的虚拟环境 xff0c 这个跟Conda虚拟环境有点像 xff0c 但是和conda最大的区别就是virtualenv会创建一个单独的文件夹存放python环境 xff0c 感觉隔离程度更高 使用方法如下 xff1a 1 安装vi
  • 洛桑联邦理工 TPAMI-2008 MTMC 概率占用图POM建模过程推导 笔记

    一切都要从2019年9月的那个秋天讲起 xff0c 林野哥向我推荐了这篇洛桑联邦理工的2008年TPAMI论文 xff0c 于是一个半月的时间都花在了这上面 Multi Camera People Tracking with a Proba
  • 知识图谱笔记(小象学院课程)

    2018年寒假看小象学院课程的时候写的笔记 xff0c 一共写了10页 xff0c 记得比较乱 因为纸质笔记不容易保存 xff0c 所以把它扫成了PDF以备后用 希望大家能够指出不足和错误
  • 隐马尔可夫模型HHM重要公式推导

    我终于把HMM看完了 xff0c 这些笔记都是看的过程中自己对推导过程写的注释 xff0c 还有知识框架 原来卡尔曼和粒子滤波都是属于HMM模型里面的 笔记结构如下 xff1a 1 HMM简介 xff1a 知识体系 43 一个模型 43 两
  • MOT指标笔记《CLEAR Metrics-MOTA&MOTP》2008年·卡尔斯鲁厄大学

    搞了这么久的MOT xff0c 到头来发现最基本的MOTA和MOTP还没有搞懂 xff0c 实在有点说不过去 今天花了一上午的时间阅读2008年卡尔斯鲁厄大学的 Evaluating Multiple Object Tracking Per
  • 概率图模型-知识结构

    两周多 xff0c 终于把概率图模型这一章看完了 xff0c 由于只是看了知识框架 xff0c 很多具体细节都还不理解 内容真的是好多啊 xff0c 而且都是理论 xff0c 没有实践 希望日后用到的时候能回忆的起来这些内容吧
  • 软件工程概论-课后作业1

    需要网站系统开发需要掌握的技术 1 网页设计 xff1a Photoshop Flash max Dreamweaver 2 网站程序 xff1a Dreamweaver Visual Studio NET 会asp asp net php
  • 《强化学习》——CH2 多臂赌博机 笔记

  • 相机几何学——投影矩阵P的构成(实验报告版)

    最近在可视化WildTrack数据集 xff0c 由于要对棋盘格点进行映射和绘制 xff0c 涉及到了P矩阵的计算 现在对P的来源进行了系统的整理 xff0c 以备后忘 在最后对场地端点映射产生的问题进行了讨论 xff08 事情开始变得有意
  • 约束优化方法_2_——Frank-Wolfe方法

    Frank Wolfe方法属于约束优化中可行方向法的一种 上一篇博文对同类型的Zoutendijk可行性方法进行了介绍 xff0c 这一部分着重关注Frank Wolfe方法 Frank Wolfe方法的基本思想是 xff1a 每次迭代中使
  • 二次规划_1_——Lagrange方法

    二次规化是非线性规化中的一种特殊情形 xff0c 其目标函数是二次实函数 xff0c 约束是线性的 考试中会考到四种方法 xff0c 分别为 xff1a Lagrange方法 起作用集方法 直接消去法和广义消去法 前两种在教材上有详细描述
  • 二次规划_2_——起作用集方法

    这个算法很反人类 xff0c 迭代过程相当复杂 xff0c 最优化老师说 xff1a 明确地告诉你要考的 起作用集方法适用于消元法和Lagrange方法无法处理的不等式约束二次规化问题 其主要思想是 xff1a 以已知点为可行点 xff0c
  • 约束非线性优化:几何意义&对偶形式

    借助老师的PPT对约束非线性优化问题的几何意义 和对偶形式 进行阐述 一 几何意义 xff08 1 xff09 等式约束 考虑只有等式约束h x 的非线性优化问题 xff0c 形式为 xff1a 可视化结果如下图所示 xff0c 红色曲线为
  • 转载篇:优秀博文汇总

    1 Pytorch中堆网络语法 xff1a nn moduleList 和Sequential由来 用法和实例 写网络模型 https blog csdn net e01528 article details 84397174 2 CNN中
  • 批量归一化:Batch Normalization层 原理+Pytorch实现

    一 BN层概念明晰 BN层存在的意义 xff1a 让较深的神经网络的训练变得更加容易 BN层的工作内容 xff1a 利用小批量数据样本的均值和标准差 xff0c 不断调整神经网络中间输出 xff0c 使整个神经网络在各层的中间输出的数值更加
  • 模型量化——基础知识 & LSQ论文阅读

    感谢陈老师给的机会 xff0c 有幸能够参加2020年的DAC比赛 我在其中负责目标检测器的调试和量化 自己第一次接触量化这个任务 xff0c 很多东西都是一点一点学 一 量化基础 对于一个全精度的值 v v v xff0c 若量化步长为
  • python3安装tensorflow遇到的问题

    1 使用命令 xff1a sudo pip3 install upgrade https storage googleapis com tensorflow linux cpu tensorflow 1 1 0rc2 cp35 cp35m
  • argparse模块使用说明

    深度学习的工程中 xff0c py文件中的main函数一开始总会有大量的参数传入 xff0c 而通常使用的方法就是使用argparse通过命令行传入 xff0c 这篇博文旨在记录argparse的常用方法 一 语句简介 1 载入argpar
  • Tensorboard在网络训练中的Logger使用方法

    做为神经网络训练中最常用的可视化工具 xff0c Tensorboard经常在Pytorch和Tensorflow框架中做为可视化工具 但是其使用的确是有点繁琐 xff0c 所以开设了一个这样的专题 xff0c 用于总结见过的工程里面的数据
  • Pytorch训练流程

    调试了很久YOLO的标准Pytorch实现 xff0c 将Train代码进行了重新整理 xff0c 抽出了标准的Pytorch训练框架 现在整理下来 xff0c 以备后用 整理后的代码分为三个大部分 xff0c 每个大部分分为三个小部分 x