mmdetection用mmclassification的backbone

2023-11-02

接上篇[1],现需要将 backbone 换成 DeiT-tiny[2,3]

MMDetection[4] 不直接支持 DeiT(backbones/ 下没有),但 MMClassification 有实现。参考 [6,7],可以直接在 MMDetection 中调用 MMClassification 的模型。

由于 DeiT 是 transformer 结构,而 MMDetection 直接支持的 DETR[8,9]也是,考虑基于它的配置文件[10]来改。

Configuration

配置文件的介绍见 [11-16]。对照已有的配置文件mmdetection/mmdet/models/ 的类定义来看,配置文件中 model 里配置的域,对应相应模型类的构造函数参数,所以替换 DeiT-tiny 做 backbone 时 model/backbone 要写哪些项,是看 deit.py 及其父类的构造函数有什么参数。

由 [10],它引用的数据集相关的配置文件是 coco_detection.py,类似 [1] 中,两个配置文件分别修改:

scannet_detection.py

  • 修数据集类集的方法见 [17,18],要改 / 加 classesdata/train/dataset/classesdata/val/classesdata/test/classes。类集见 [19]。
# Inherited from: mmdetection/configs/_base_/datasets/coco_detection.py
# fit to ScanNet-frames-25k

dataset_type = 'CocoDataset'
classes = (
    "wall", "floor", "cabinet", "bed", "chair",
    "sofa", "table", "door", "window", "bookshelf",
    "picture", "counter", "desk", "curtain", "refrigerator",
    "shower curtain", "toilet", "sink", "bathtub", "otherfurniture"
)
data_root = 'data/scannet-frames/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'scannet_objdet_train.json',
        img_prefix=data_root + 'train/',
        pipeline=train_pipeline,
        classes=classes),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'scannet_objdet_val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline,
        classes=classes),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'scannet_objdet_val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline,
        classes=classes))
evaluation = dict(interval=1, metric='bbox')

detr_deit_tiny_8x1_150e_scannet.py

  • model/backbone
  • model/bbox_head
    • num_classes 参考 [17,18],类数也是数自 [19];
    • in_channels 即 backbone 最后一层输出的特征维度,见 arch_zoo 中的 deit-tiny/embed_dims
    • transformer/encoderdecoder/transformerlayers/attn_cfgs/embed_dimsffn_cfgs/embed_dimsBaseTransformerLayer 中的 __init__/ffn_cfgs/embed_dims)也要相应改成跟 in_channels 同维;
    • positional_encoding/num_feats 要恰好是 in_channels 的一半。
  • data/samples_per_gpu 改成 1,即每块卡 batch_size = 1,否则我这会爆显存。按 [11] 的命名规则,配置文件名中改成 8x1
# Inherited from mmdetection/configs/detr/detr_r50_8x2_150e_coco.py
## Modifications
# 1. use DeiT-tiny as backbone
# 2. use ScanNet-frames-25k

_base_ = [
    '../_base_/datasets/scannet_detection.py',
    '../../mmdetection/configs/_base_/default_runtime.py'
]
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
model = dict(
    type='DETR',
    backbone=dict(
        # _delete_=True, # Delete the backbone field in _base_
        # from: mmclassification/configs/deit/deit-tiny_pt-4xb256_in1k.py
        type='mmcls.VisionTransformer',
        arch='deit-tiny',
        img_size=224,
        patch_size=16,
        with_cls_token=False,
        output_cls_token=False,
        out_indices=-1,
        # norm_cfg=dict(type='BN', requires_grad=False),
        # norm_eval=True,
        # style='pytorch',
        init_cfg=dict(
            type='Pretrained',
            # from: mmclassification/configs/deit/README.md -> DeiT-tiny
            checkpoint='https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth',
            prefix='backbone.',
        ),
    ),
    bbox_head=dict(
        type='DETRHead',
        num_classes=20,  # from: convert-scannet-coco-objdet.py
        # from: mmclassification/mmcls/models/backbones/vision_transformer.py
        #   -> arch_zoo["deit-tiny"]["embed_dims"]
        in_channels=192,
        transformer=dict(
            type='Transformer',
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=192,#256,
                            num_heads=8,
                            dropout=0.1)
                    ],
                    feedforward_channels=2048,
                    ffn_dropout=0.1,
                    # from: mmcv/mmcv/cnn/bricks/transformer.py
                    #    -> BaseTransformerLayer/__init__/ffn_cfgs
                    ffn_cfgs=dict(
                        embed_dims=192,
                        # feedforward_channels=2048,
                        # ffn_drop=0.1,
                    ),
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='DetrTransformerDecoder',
                return_intermediate=True,
                num_layers=6,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=192,#256,
                        num_heads=8,
                        dropout=0.1),
                    feedforward_channels=2048,
                    ffn_dropout=0.1,
                    # from: mmcv/mmcv/cnn/bricks/transformer.py
                    #    -> BaseTransformerLayer/__init__/ffn_cfgs
                    ffn_cfgs=dict(
                        embed_dims=192,
                        # feedforward_channels=2048,
                        # ffn_drop=0.1,
                    ),
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm')),
            )),
        positional_encoding=dict(
            # type='SinePositionalEncoding', num_feats=128, normalize=True),
            type='SinePositionalEncoding', num_feats=int(192 // 2), normalize=True),
        loss_cls=dict(
            type='CrossEntropyLoss',
            bg_cls_weight=0.1,
            use_sigmoid=False,
            loss_weight=1.0,
            class_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            cls_cost=dict(type='ClassificationCost', weight=1.),
            reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
            iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))),
    test_cfg=dict(max_per_img=100))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[[
            dict(
                type='Resize',
                img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                           (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                           (736, 1333), (768, 1333), (800, 1333)],
                multiscale_mode='value',
                keep_ratio=True)
        ], [
            dict(
                type='Resize',
                img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                multiscale_mode='value',
                keep_ratio=True),
            dict(
                type='RandomCrop',
                crop_type='absolute_range',
                crop_size=(384, 600),
                allow_negative_crop=True),
            dict(
                type='Resize',
                img_scale=[(480, 1333), (512, 1333), (544, 1333),
                            (576, 1333), (608, 1333), (640, 1333),
                            (672, 1333), (704, 1333), (736, 1333),
                            (768, 1333), (800, 1333)],
                multiscale_mode='value',
                override=True,
                keep_ratio=True)
        ]]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=1),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
# test_pipeline, NOTE the Pad's size_divisor is different from the default
# setting (size_divisor=32). While there is little effect on the performance
# whether we use the default setting or use size_divisor=1.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=1),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=1,#2,
    workers_per_gpu=2,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.0001,
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[100])
runner = dict(type='EpochBasedRunner', max_epochs=150)

Training

代码结构类似 [1],这里只展示必要的部分:

my-project/
|- mmdetection/
|- configs/
|  |- _base_/
|  |  `- datasets/
|  |     `- mstrain_3x_scannet.py
|  `- detr/
|     `- detr_deit_tiny_8x1_150e_scannet.py
`- scripts/
   |- find_gpu.sh
   `- train-scannet-frames.sh

其中,训练脚本:

#!/bin/bash
# train-scannet-frames.sh
clear

# run `conda activate openmmlab` first

config=configs/detr/detr_deit_tiny_8x2_150e_scannet.py

. scripts/find_gpu.sh -1 14787

echo begin: $(date) > scripts/RUN-`basename $0`.txt

PATH=/usr/local/cuda/bin:$PATH \
PYTHONPATH=mmdetection/mmdet:$PYTHONPATH \
CUDA_VISIBLE_DEVICES=${gpu_id} \
MMDET_DATASETS=`pwd`/data/scannet-frames/ \
bash mmdetection/tools/dist_train.sh \
    $config ${n_gpu_found}
# python mmdetection/tools/train.py \
#     $config

echo end: $(date) >> scripts/RUN-`basename $0`.txt

References

  1. MMDetection在ScanNet上训练
  2. (ICLR 2021) Training data-efficient image transformers & distillation through attention
  3. facebookresearch/deit
  4. open-mmlab/mmdetection
  5. open-mmlab/mmclassification
  6. How to change the model of mmclassification to mmdetection? #7761
  7. Use backbone network implemented in MMClassification / Use backbone network implemented in MMClassification
  8. (ECCV 2020) End-to-End Object Detection with Transformers - paper, supplementary
  9. facebookresearch/detr
  10. open-mmlab/mmdetection/configs/detr/detr_r50_8x2_150e_coco.py
  11. Tutorial 1: Learn about Configs
  12. Config
  13. MMDetection框架入门教程(二):快速上手教程
  14. MMDetection框架入门教程(三):配置文件详细解析
  15. 【MMDetection-学习记录】config配置文件说明
  16. mmdetection的config配置文件参数介绍
  17. AssertionError: The num_classes (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of CLASSES 80) in CocoDataset #4828
  18. Prepare a config / Prepare a config
  19. ScanNet/BenchmarkScripts/convert2panoptic.py
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mmdetection用mmclassification的backbone 的相关文章

  • 1380. 矩阵中的幸运数

    class Solution public vector
  • oracle聚合函数

    1 COUNT 计算元组的个数 2 COUNT DISTINCT ALL col 对一列中的值计算个数 distinct去重复 缺省时是ALL 3 SUM DISTINCT ALL lt 列名 gt 求某一列值的总和 数值型 4 AVG D

随机推荐

  • 知道创宇研发技能列表v3.0

    Expand Collapse 知道创宇研发技能表v3 0 2015 8 21 发布 by 知道创宇 www knownsec com 余弦 404团队 后续动态请关注微信公众号 Lazy Thought 说明 关于知道创宇 知行合一 守正
  • go语言入门详细教程

    文章目录 一 前言 1 Go语言的创始人 2 go语言的发展 3 go语言优缺点 4 使用go语言的项目 5 学习go语言可以做什么 一 前言 1 Go语言的创始人 Go 语言的创始人是 Robert Griesemer Rob Pike
  • 全球第二大成人网站,也要“自宫”了。。

    兄弟们 一直以全球第二大成人网站自居的O站 全称 OnlyFans 可能又要搞事情了 众所周知 这个O站一直都是一个有梦想的成人网站 他们的目标从来都不只是单纯的做大做强 它一直都没有放弃过 想要上市的 梦想 只不过吧 成人网站想要上市 这
  • 调试cube生成的f107+lan8720代码

    之前用的w5500 无奈芯片越来越贵了 正好手头上有100来颗lan8720a 直接将方案改了吧 以前在深圳工作时公司的网关正好用的这个方案 直接抄吧 硬件设计网口无晶振 由mcu的mco脚输出 50Mhz模式 其他都是通用连接方式 接下来
  • ubuntu设置ssh登陆

    默认请况下 ubuntu是不允许远程登陆的 因为服务没有开 可以这么理解 想要用ssh登陆的话 要在需要登陆的系统上启动服务 即 安装ssh的服务器端 sudo apt get install openssh server 然后 启动服务
  • graphviz安装及使用、决策树生成

    一 graphviz下载安装 下载网址 http www graphviz org download 选择合适版本下载 1 1 双击安装 1 2 点击下一步 1 3 点击我接受 1 4 添加至系统路径 勾选添加至当前用户的系统路径 创建桌面
  • 诛仙服务器获取角色信息失败,架设诛仙提示游戏服务器正在维护中

    架设诛仙提示游戏服务器正在维护中 内容精选 换一换 一 系统信息相关命令本节内容主要是为了方便通过远程终端维护服务器时 查看服务器上当前 系统日期和时间 磁盘空间占用情况 程序执行情况本小结学习的终端命令基本都是查询命令 通过这些命令对系统
  • Interactive Image Segmentation

    FocalClick Towards Practical Interactive Image Segmentation 阿里巴巴 CVPR2022 Interactive segmentation allows users to extra
  • 大学物理绝不挂科期末考试复习

    大学物理 第一章走近物理 第二章 质点运动学 三角形法则 矢量平移不变性 V V0 at X V0t 1 2a t 2 变速运动 积分 建立自然坐标系比较好 微分积分 你
  • sobol灵敏度分析matlab_sobol全局灵敏性分析

    最近在研究全局敏感性分析方法中的 Sobol 方法 看了一些国内的论文 发现一个通病 就是公 式一挂就可以得出结果了 真心觉得这种论文很 恶心 主要原因是自己看不太懂 直到在维基百 科上面找到了这种方法的详细解释 今天我们就根据网页上的步骤
  • Sqli-Labs Less1-16关介绍

    Sqli Labs Less1 16关介绍 一 Http 请求方法 Get 对比 Post Get传输方式 Less1 10 Less1 4 Union Select注入 Less5 6 报错型注入 Less 7 写入数据 闭合符 Less
  • 设计模式学习(五):State状态模式

    目录 一 什么是State模式 二 State模式示例程序 2 1 伪代码 2 1 1 不使用State模式的伪代码 2 1 2 使用State模式的伪代码 2 2 各个类之间的关系 2 3 State接口 2 4 DayState类 2
  • 一文带你弄懂 JVM 三色标记算法

    最近和一个朋友聊天 他问了我 JVM 的三色标记算法 我脑袋一愣发现竟然完全不知道 于是我带着疑问去网上看了几天的资料 终于搞清楚啥事三色标记算法 它是用来干嘛的 以及它和 CMS 回收器和 G1 回收器的关系了 今天 就让树哥带着大家一起
  • npm 无法将“npm”项识别为 cmdlet、函数、脚本文件或可运行程序的名称

    1 问题描述 在vscode运行命令 npm run dev报错 2 分析解决 问题原因 npm环境变量配置问题 在cmd窗口输出node 回车后弹出信息node不是内部或外部命令 也不是可运行的程序 这时候就是环境变量配置的问题 方法一
  • 线性表——顺序表(含代码)

    线性是一种逻辑结构 表示元素与元素之间一对一的相邻关系 顺序表和链表是指存储结构 本文首先讨论的是顺序表 要构造顺序表首先要了解其结构 顺序表用一组连续地址一次存放线性表中元素 使得逻辑上相邻的元素物理上也相邻 顺序表使用数组来描述顺序存储
  • vue中使用element-tiptap

    安装 npm install save element tiptap或者yarn add element tiptap main js文件引入 全局引入 引入element tiptap import ElementTiptapPlugin
  • 在 Compose 中使用 Koin 进行依赖注入

    The pragmatic Kotlin Kotlin Multiplatform Dependency Injection framework 实用的Kotlin和Kotlin多平台依赖注入框架 Android Studio环境为 And
  • 浮动IP(FLOAT IP)

    主要谈一谈关于浮动IP的东西 介绍下浮动IP是什么 1 为什么要有浮动IP这个东西 现在有一个场景 在一台Linux上部署一个web应用 应用跑在tomcat里面 linux网卡上的ip是115 239 100 120 大致就是如下的部署关
  • 狂学数据库之关系模式的设计问题及数据的函数依赖

    关系模式的设计问题及数据的函数依赖 一 关系模式的设计问题 1 1 数据依赖 1 2 数据依赖对关系模式的影响 二 数据的函数依赖 2 1 函数依赖 2 1 1 函数依赖的定义 2 1 2 函数依赖的3种基本情形 2 2 函数依赖和码 关键
  • mmdetection用mmclassification的backbone

    接上篇 1 现需要将 backbone 换成 DeiT tiny 2 3 MMDetection 4 不直接支持 DeiT backbones 下没有 但 MMClassification 有实现 参考 6 7 可以直接在 MMDetect