mmsegmentation之tools/train.py文件解析(部分,持续更新)

2023-05-16

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
import torch.distributed as dist
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash

from mmseg import __version__
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger, setup_multi_processes


def parse_args():
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--load-from', help='the checkpoint file to load weights from')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--diff_seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help="--options is deprecated in favor of --cfg_options' and it will "
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
        'marks are necessary and that no white space is allowed.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--auto-resume',
        action='store_true',
        help='resume from the latest checkpoint automatically.')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options. '
            '--options will not be supported in version v0.22.0.')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options. '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args() # 解析命令行参数

    cfg = Config.fromfile(args.config) # 读入cfg参数
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn('`--gpus` is deprecated because we only support '
                      'single GPU mode in non-distributed training. '
                      'Use `gpus=1` now.')
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed training. Use the first GPU '
                      'in `gpu_ids` now.')
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]

    cfg.auto_resume = args.auto_resume

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # gpu_ids is used to calculate iter when resuming checkpoint
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) # 检测work_dir,若不存在创建work_dir
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) # dump cfg参数
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) # time记录
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log') # 创建log文件
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) #

    # set multi-process settings
    setup_multi_processes(cfg)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env() # 收集环境信息
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    seed = seed + dist.get_rank() if args.diff_seed else seed
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic) # 设置随机种子
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # SyncBN is not support for DP
    if not distributed:
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_train.sh which can '
            'avoid this error.')
        model = revert_sync_batchnorm(model)

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    # passing checkpoint meta for saving best checkpoint
    meta.update(cfg.checkpoint_config.meta)
    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mmsegmentation之tools/train.py文件解析(部分,持续更新) 的相关文章

  • Clion-安装

    Clion安装 1 注册Jetbraions账号 https www jetbrains com 2 学生认证使用 baipiao 一年 https www jetbrains com shop eform students 3 下载Min
  • Tools --vscode配置verilog环境:语法检测,自动补全,生成testbench

    vscode配置verilog环境 语法检测 自动补全 生成testbench 语法高亮 关键词 数字等 支持Verilog和System Verilog 支持自动补全 包括关键词和定义的变量等 语法纠错 Vivado逻辑仿真 xvlog
  • 69个网盘搜索引擎资源(最全)

    呵呵 今天博主今天整理了一个下午 把网上的能找到69个网盘搜索引擎都放在这了 希望能帮到有需要的小伙伴 1 盘多多 http www panduoduo net 2 Bdsola http www 3134 c 3 探索云盘搜索 http
  • XNA是什么?

    Software will be the single most important force in digital entertainment over the next decade XNA underscores Microsoft
  • Hexagon GDB Debugger介绍(47)

    Hexagon GDB Debugger介绍 47 4 5 2 8 Python 中的命令 4 5 2 9 Writing new convenience functions 4 5 2 8 Python 中的命令 新的调试器命令可以在Py
  • 最好用的五个黑科技搜索引擎推荐

    一 数据搜 http data chongbuluo com 数据搜 这个网站就是搜索一些热词和数据指数的 包括百度指数 阿里指数 微博指数 微信指数 搜狗指数等等 当然 还有一些汽车数据 腾讯大数据 票房数据相关数据查询网站 估计很多人经
  • UBOOT命令总结(转)

    UBOOT命令总结 转 很好的UBOOT命令总结 我在起步时就是看的这篇东西 熟悉了以后就用 看自带帮助就行 Printenv 打印环境变量 Uboot gt printenvbaudrate 115200ipaddr 192 168 1
  • qtp的基本使用方法(1)

    1 action qtp为每一个action生成相应的测试文件和目录 对象库也是和action绑定的 用action 来划分和组织测试流程 编辑action 修改action的名字 action properties 增加action in
  • MD5,SHA1,SHA256,NTLM,LM等Hash在线破解网站收集

    MD5 http hashchecker de find html http paste2 org p 441222 http r0ot podzemlje net x md5 http hashkiller com index php a
  • [笔记] Unikernel原型:Docker镜像秒变虚拟机镜像以及无ssh开terminal调试

    现如今docker run各种流行 可是虚拟机并没有消退 有没有办法把镜像部署成一个虚拟机来个暂时的转换呢 有没有好办法在没有ssh的container上可以开个terminal进行调试呢 其实很简单 docker save就可以把整个镜像
  • 主机与VMware的Linux虚拟机之间共享交换文件

    搭建环境 主机系统 Windows7 Ultimate VM软件 VMware Workstation 7 1 3 虚拟机系统 Linux Ubuntu 10 10 操作步骤 1 在主机上新建一个共享路径 用于将来和虚拟机之间进行共享文件
  • 找不到本地组策略编辑器找不到gpedit.msc 的解决方法

    通常打开本地组策略编辑器 只需要 win R 在运行里输入 gpedit msc 就可以打开 但是 在windows 家庭版和学生版里 会提示找不到路径 可以用以下办法解决 新建一个文本文档 名字随便取 编辑以下批处理内容 保存后将后缀名改
  • git clone 下载 其他分支

    总是记不住 可能是因为用得少 如果 已经 clone了 master分支 方法 1 那么 本地 git pull 然后执行 git checkout b 本地分支名 origin 远程分支名 这样就能下载 到远程分支 并建立本地关联 方法2
  • Linux安装Phabricator

    背景 公司管理需要 为开发团队寻找一款代码审查工具 最终选择了Phabricator ubuntu14 04 LTS 安装 使用官网上提供的install sh即可 参考链接 https secure phabricator com div
  • MMSegmention官方文档阅读系列之三(MMSegmentation 算法库目录结构、了解配置文件信息)

    1 MMSegmentation 算法库目录结构的主要部分 1 mmsegmentation configs 配置文件 base 基配置文件 datasets 数据集相关配置文件 models 模型相关配置文件 schedules 训练日程
  • VS Code-此时不应有 &

    描述 用vs code运行python代码时 报 此时不应有 解决方式 ctrl 调出终端 将默认终端设置成powershell 退出 重新加载代码
  • 什么是思维导图?6 个开源免费的思维导图软件

    目录 15款思维导图工具推荐 什么是思维导图 6 个开源免费的思维导图软件 当前推荐 Freeplane 离线应用 有免安装版本 跨平台 目前 2023年 还在更新中 下载 https sourceforge net projects fr
  • QCC300x笔记(3) -- QCC3007开发调试经验

    哈喽大家好 这是该系列博文的第三篇 篇 lt lt 系列博文索引 快速通道 gt gt 写在前面 这篇博客主要记录 在使用QCC300x平台中所遇到的问题以及解决方法 会不定时更新 1 使用的堆栈空间大小超出或者全局变量超出 会报以下错误
  • token保活设计.md

    如果我们要使用token机制用以标识用户登录状态 以获得请求相关资源接口的权限 让你来设计一套方案 以为怎么设计呢 通常有两种思路 1 使用refreshtoken获取新的accesstoken 登录成功之后 返回一个返回refreshto
  • AES加密及解密

    public class AesUtil static Security addProvider new BouncyCastleProvider private static final String ALGORITHM AES ECB

随机推荐