mmdetection入门介绍-train.py解析

2023-11-19

四、train.py解析

同样,上面有单GPU测试和多GPU测试,其实上面的测试是由训练导致的。

单GPU训练

python tools/train.py ${CONFIG_FILE}

如果要在命令中指定工作目录,则可以添加参数–work_dir $ {YOUR_WORK_DIR}。如果没有指定的话就用的就是默认在config/**.py下的work_dir = './work_dirs/faster_rcnn_r50_fpn_1x_voc0712’下
参数解释:

  • –validate(强烈建议):在训练过程中,每隔k个时期(默认值是1,可以像这样修改)执行评估。
  • –work_dir $ {WORK_DIR}:覆盖配置文件中指定的工作目录。
  • –resume_from $ {CHECKPOINT_FILE}:从先前的检查点文件恢复。
  • –gpus:是指使用的GPU数量,默认值为1颗;–launcher:是指分布式训练的任务启动器(job launcher),默认值为none表示不进行分布式训练;

其中有几点需要说明的是:

–validate只支持多GPU训练,不支持单GPU训练,甚至包括后面会遇到的workflow = [(‘train’, 1)(‘val’,1)],即训练一次验证一次对单个GPU的场景也是不适用的;

resume_from和load_from之间的区别:resume_from同时加载模型权重和优化器状态,并且epoch也从指定的检查点继承。它通常用于恢复意外中断的训练过程。 load_from仅加载模型权重,并且训练时期从0开始。通常用于微调。

多GPU训练

./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]

我们可以看一下dist_train.sh的内容

#!/usr/bin/env bash

PYTHON=${PYTHON:-"python"}

CONFIG=$1
GPUS=$2

$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}

其实还是调用train.py,不过这里配置了–launch来启动分布式训练。

这里需要知道一下,关于学习率的一个计算过程:
配置文件中的默认学习率是8个GPU和2 img/gpu(batch_size= 8 * 2 =16)。根据线性缩放规则,如果使用不同的GPU或img/gpu,则需要按照batch_size大小设置学习率,例如,对于4个GPU,lr = 0.01 * 2 img/gpu;对于16个GPU,lr = 0.08 * 4 img/gpu。

知道如何测试别人或者已下好的模型后,就可以转到训练模型,首先还是打开train.py的主要功能

def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    # 在图片输入尺度固定时开启,可以加速.一般都是关的,只有在固定尺度的网络如SSD512中才开启
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        # 创建工作目录存放训练文件,如果不设置,会自动按照py配置文件生成对应的目录
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        # 断点继续训练的权值文件
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    # 搭建模型
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # 将训练配置传入
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(cfg.data.val))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.text,
            CLASSES=datasets[0].CLASSES)  
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

if __name__ == '__main__':
    main()

同样,还是参数的读取

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume_from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args

设置训练命令后,cfg就会读取相关的配置信息

    args = parse_args()
    cfg = Config.fromfile(args.config)

可以看到训练模型进行了一定的检查

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

这里面我们其实可以看到学习率args.autoscale_lr的设置,这里也明确说了是linear scaling rule。

然后从配置文件中读取信息,设置模型和数据集

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]

设置验证集

    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(cfg.data.val))

设置checkpoint

    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.text,
            CLASSES=datasets[0].CLASSES)

设置模型信息

    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

我们看一下这个train_detector函数的定义

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)

可以看到,模型分分布式训练和非分布式训练

可以看到分布式训练配置

def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
        for ds in dataset
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
                                             **fp16_cfg)
    else:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        eval_cfg = cfg.get('evaluation', {})
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(
                CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
        else:
            dataset_type = DATASETS.get(val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(
                    CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
            else:
                runner.register_hook(
                    DistEvalmAPHook(val_dataset_cfg, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

非分布式训练

def _non_dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False) for ds in dataset
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

我们可以看到,非分布式训练时没有validate的(这里有个想法,为什么非分布式训练没有加val?如果我把代码强行加进去会怎么样?)

其他参考:
这里是mmdetection入门介绍 前言 部分
这里是mmdetection入门介绍 test.py解析 部分
这里是mmdetection入门介绍 train.py解析 部分
这里是mmdetection入门介绍 模型解析 部分

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mmdetection入门介绍-train.py解析 的相关文章

  • matlab中国官网下载,首页 - MATLAB中文论坛

    阅读来自MathWorks资深工程师的技术博客 I published Round With Ties to Even a couple of days ago Steve Eddins and Daniel Dolan immediate
  • Android Spider JDAX-GUI 反编译工具下载使用以及相关技术介绍

    文章目录 前言 一 JDAX下载 二 基本使用 2 1 解压zip 2 2 Java环境 2 3 进入Dos命令窗口启动Jdax Gui 2 4 正常使用 三 常见的反编译工具以及简单分析介绍 1 Android Killer 2 Dex2
  • qsort(),sort()排序函数

    一 qsort 函数 功 能 使用快速排序例程进行排序 头文件 stdlib h 用 法 void qsort void base int nelem int width int fcmp const void const void 参数
  • 面试题六道-2022-1-6

    CopyOnWriteArrayList的底层原理是怎样的 1 首先CopyOnWriteArraylist内部也是用过数组来实现的 在向CopyOnWriteArrayLlist添加元素时 会复制一个新的数组 写操作在新数组上进行 读操作
  • 尚筹网-前台-会员系统(springboot,springcloud 实战)

    总目标 环境搭建 会员登录注册 发起众筹项目 展示众筹项目 支持众筹项目 订单 支付 1 会员系统架构 1 1 架构图 1 2 需要创建的工程 父工程 聚合工程 shangcouwang01 member parent 唯一的pom工程 注
  • 强化学习 9 —— DQN 改进算法 DDQN、Dueling DQN 详解与tensorflow 2.0实现

    上篇文章强化学习 详解 DQN 算法介绍了 DQN 算法 但是 DQN 还存在一些问题 本篇文章介绍针对 DQN 的问题的改进算法 一 Double DQN 算法 1 算法介绍 DQN的问题有 目标 Q 值 Q Target 计算是否准确
  • VUE常用UI组件插件及框架-vue前端UI框架收集

    UI组件及框架 element 饿了么出品的Vue2的web UI工具套件 mint ui Vue 2的移动UI元素 iview 基于 Vuejs 的开源 UI 组件库 Keen UI 轻量级的基本UI组件合集 vue material 通
  • Java jar包启动及停止

    此处以SpringBoot maven工程为基础 基于Windows系统 工程开发好后 打包成jar包 一 启动jar包 在包所在的目录下运行cmd命令 执行命令 java jar jar包名 二 停止 1 用管理员打开cmd命令窗口 2
  • Java安全之SSL/TLS

    在前面所讲到的一些安全技术手段如 消息摘要 加解密算法 数字签名和数据证书等 一般都不会由开发者直接地去使用 而是经过了一定的封装 甚至形成了某些安全协议 再暴露出一定的接口来供开发者使用 因为直接使用这些安全手段 对开发者的学习成本太高
  • 初学树莓派——(六)树莓派安装OpenCV及USB摄像头配置

    目录 1 安装OpenCV 1 1前言 1 2换源及源内容更新 1 3安装依赖 1 4下载whl包 1 5安装OpenCV 1 6检查安装 2 USB摄像头配置 同时检查OpenCV安装情况 2 1前言 2 2Python调用cv2库来检查
  • sem_init函数用法

    sem init函数 sem init函数是Posix信号量操作中的函数 sem init 初始化一个定位在 sem 的匿名信号量 value 参数指定信号量的初始值 pshared 参数指明信号量是由进程内线程共享 还是由进程之间共享 如
  • 最优化算法概述以及常见分类

    1 最优化问题概述 通俗的来说 最优化问题就是在一定的条件约束下 使得效果最好 最优化问题是一种数学问题 是研究在给定的约束之下如何求得某些因素的量 来使得某一指标达到最优的学科 工程设计中最优化问题的一般说法是 选择一组参数 在满足一系列
  • 数据结构笔记(六)——散列(Hash Table)之散列函数(1)

    散列表 hash table 的实现叫做散列 hashing 这是以常数平均时间O 1 进行插入 删除和查找的技术 散列表没有顺序 需要元素间排序信息的操作 如findMin findMax不会得到有效支持 就是这东西不是这么用的 你可以实
  • RocketMq顺序发送消息

    错乱消息出现的原因 1 在RocketMq为啥消息不是按照顺序来的呢 首先您需要了解 队列是一个先进先出的一个数据的结构 生产者 您可以将topic理解为里面有一个一个的队列 你将一个消息发送到topic的时候 当前的消息不一定是往当前的这
  • win 10 搭建FTP服务,并使用的FTP进行传输文件(很详细)

    1 安装IIS工具 打开控制面板 点击 程序 点击 启用或关闭Windows功能 找到 internet information services 全部都选上 如下图 点击 确定 会出现以下页面 点击 关闭 即可 2 设置开机启动FTP服务
  • 高光谱图像中的Hughes(休斯)现象

    注解 在高光谱图像的分析中 随着参与运算波段数目的增加 分类精度 先增后降 的现象 场景 高光谱影像 由于维数的大幅度增加 在深度学习中 可以理解成模型提取的特征维数的增加 导致用于参数训练的所需样本数也急剧增加 如果样本数过少 那么估计出
  • Fiddler 详尽教程与抓取移动端数据包

    转载自 http blog csdn net qq 21445563 article details 51017605 阅读目录 1 Fiddler 抓包简介 1 字段说明 2 Statistics 请求的性能数据分析 3 Inspecto
  • C++面试题目集合(持续跟新)

    与我前面写的C语言进阶知识点遥相呼应 这才是C 面试 网上的面试题有些太简单了 C 面试题目最多集中在对象的内存模型 记住了 如果用c c 内存都不清楚 还写个屁的程序 1 C 的虚函数是怎样实现的 C 的虚函数使用了一个虚函数表来存放了每
  • 我的世界服务器物品不掉落指令是什么,我的世界死亡物品怎么不掉落 我的世界物品不掉落指令...

    我的世界死亡不了多指令是gamerulekeepInventorytrue 玩家们要注意我的世界死亡不掉落指令默认是关闭状态的哦 死亡不掉落指令在 我的世界 游戏里面就是当玩家们死亡以后仍然保留其物品栏中的所有物品 包括附魔死亡消失魔咒的物

随机推荐

  • 安装WSL + zsh & Pure (ZSH prompt) 美化【Windows11】

    文章目录 前言 WSL 安装 ZSH 安装ZSH Pure ZSH prompt 安装插件 下载插件 编辑配置文件 插件作用 啊 PS 如果在启动过程中提示 请启用虚拟机平台 windows 功能并确保在 bios 中启用虚拟化 前言 之前
  • 数据分析36计(28):Python 使用 Flask+Docker, 100行代码内实现机器学习实时预测​...

    本文的想法是快速轻松地构建 Docker 容器 Python 以使用 Flask 实现机器学习模型执行在线预测 API 我们将使用 Docker 和 Flask RESTful 实现线性判别分析和多层感知器神经网络模型的实时预测 项目包括的
  • Android中的自绘View的那些事儿(八)之 Paint的高级用法

    我们在 Android中的自绘View的那些事儿 一 中简单介绍过Paint和Canvas的一些常用方法和实例使用 其中 一句话提到Paint中有方法 setStrokeCap setStrokeJoin 和 setPathEffect 今
  • nodejs如何利用libuv实现事件循环和异步

    本文是根据之前在公司内部做的分享整理而成 是早期对nodejs的一个认识 源码版本10 x nodejs是什么 libuv的工作原理 nodejs的工作原理 nodejs如何使用libuv实现事件循环和异步 1 nodejs是什么 Node
  • pyinstaller打包最小体积安装python程序 命令行传参执行

    文章目录 创建虚拟环境 进入虚拟环境安装库 pycharm配置虚拟环境 pycharm 打开terminal进入虚拟环境 运行参数传入 sys argv 是获取运行python文件的时候命令行参数 且以list形式存储参数 打包后的文件运行
  • js记录密码出错次数并锁定账号30分钟

    下面要说的是网站中一个常见的功能 在客人使用抵用券或者其他来支付的时候需要验证密码 如果密码输入错误5次就锁定 不在让客人使用抵用券了 在这里是使用的cookie来实现的 不太严谨 思路很简单 在输入密码错误的时候 使用cookie保存2个
  • 基于vue项目的上拉刷新,下拉加载的效果

    使用插件 better scroll 安装使用教程http ustbhuangyi github io better scroll doc installation html npm 还是看官网比较好 子组件
  • 28_content 阶段的 index 模块

    文章目录 content 阶段的 index 模块 显示目录内容 content 阶段的 autoindex 模块 autoindex 模块的指令 index autoindex 示例配置 content 阶段的 index 模块 ngx
  • 6、基于STM32呼吸灯(PWM)

    之前定时器中有提到输入和输出比较部分 https blog csdn net qq 45764141 article details 125286260 参考有江科大自化协的视频和正电原子的视频 这个文章主要讲输出部分 文章目录 一 OC
  • 全面解析并实现逻辑回归(Python)

    本文以模型 学习目标 优化算法的角度解析逻辑回归 LR 模型 并以Python从头实现LR训练及预测 一 逻辑回归模型结构 逻辑回归是一种广义线性的分类模型且其模型结构可以视为单层的神经网络 由一层输入层 一层仅带有一个sigmoid激活函
  • MiniUI - 快速开发WebUI

    http www miniui com index html
  • GPT4来了?10秒钟做一个网站

    GPT4来了 10秒钟做一个网站 好了 我可以像雪容融一样躺平了 为什么雪容融都会wei gui 言归正传 3月15日 GPT4做一个网站只要十秒 登上热搜 根据视频中的演示 首先在草稿纸上画出一个基本的网页框架 图源视频截图 过了仅仅10
  • 【小沐学C++】C++ 常用命令行开发工具(Linux)

    文章目录 1 简介 2 gcc g 2 1 system 执行shell 命令 2 2 popen 建立管道I O 2 3 vfork exec 新建子进程 3 clang 3 1 下载和安装clang 3 2 clang和gcc比较 3
  • Blow Up 3macOS图片放大锐利的详细使用教程与安装方法

    软件介绍 Blow Up 3 macOS是一个Photoshop和Lightroom插件 亲测有效 适合于Adobe Photoshop CS6和Adobe Photoshop CC 2015或更高版本 Adobe Lightroom 6或
  • VsCode远程调试服务器python代码(解决相对路径相关问题)

    1 首先在本地使用VsCode调试python代码 可参考链接 VSCode启动Debug模式调试Python文件 2 vscode远程连接服务器 调试python文件 可参考链接 一文掌握vscode远程调试python代码 3 调试时
  • Google guava之Multiset简介说明

    转自 Google guava之Multiset简介说明 下文笔者讲述guava中Multiset集合的简介说明 如下所示 guava之Multiset集合简介 Multiset集合 可用于存储重复元素 Multiset是ArrayList
  • 一文1000字彻底搞懂Web测试与App测试的区别

    总结分享一些项目需要结合Web测试和App测试的工作经验给大家 从功能测试区分 Web测试与App测试在测试用例设计和测试流程上没什么区别 而两者的主要区别体现在如下几个方面 1 系统结构方面 Web项目 B S架构 基于浏览器的 Web测
  • Unity编辑器界面概述

    了解界面 如果您对编辑器界面没有非常地了解 那么请花一些时间查看并熟悉 Editor 编辑器 界面 Editor 主窗口由选项卡式窗口组成 这些窗口可重新排列 因此 Editor 的外观可能因项目或者开发者而异 具体取决于个人偏好 Wind
  • GitHub博客搭建

    git官网文档 https git scm com book zh v2 E6 9C 8D E5 8A A1 E5 99 A8 E4 B8 8A E7 9A 84 Git E7 94 9F E6 88 90 SSH E5 85 AC E9
  • mmdetection入门介绍-train.py解析

    四 train py解析 同样 上面有单GPU测试和多GPU测试 其实上面的测试是由训练导致的 单GPU训练 python tools train py CONFIG FILE 如果要在命令中指定工作目录 则可以添加参数 work dir