如何把“自己的”网络中的conv2d替换为dcnv2

2023-10-26

1、dcnv2的实现测试了两种,一种是官方版dcnv2,git链接:https://github.com/CharlesShang/DCNv2.git,编译直接cd到DCNv2,然后./make.sh即可,第二种是mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dPack

2、实验以resnet50、mnist训练为例

from torch_op.common import init_seed
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet50
from torch import nn,optim
import torch

with_dcn = True
if with_dcn:
    #两种dcn导入方法
    # from dcn_v2 import DCN
    from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2dPack as DCN

init_seed(0)

batch_size =64
num_workers = 16
epochs = 20
lr = 0.1

train_dataset = MNIST('torch_op/mnist/data',train=True,download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          # transforms.Normalize((0.1307,),(0.3081))
                      ])
                      )

test_dataset = MNIST('torch_op/mnist/data',train=False,download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          # transforms.Normalize((0.1307,),(0.3081))
                      ])
                      )

train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers)

#替换原网络中每个bottleneck中第二个conv为dcn
def replace(layers):
    for name,module in layers.named_children():
        if isinstance(module,torch.nn.Conv2d):
            if 'conv2' in name:
                new_module = DCN(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding,
                                 module.dilation, 1)
                layers.add_module(name,new_module)
        else:
            replace(module)


model = resnet50(False,num_classes=10)

#需要将conv替换成dcn的module列表
if with_dcn:
    replace_list = ['layer2','layer3','layer4']
    for name,module in model.named_children():
        if name not in replace_list:
            continue
        # print(module)
        replace(module)
        # print('*'*20)
        # print(module)
        # print('-'*20)

print(model)
model = model.cuda()

optimizer = optim.SGD(model.parameters(),lr=lr)
criterion = nn.CrossEntropyLoss()

def train():
    model.train()
    for batch_idx,(data,target) in enumerate(train_loader):
        if data.size(1)==1:
            data = torch.cat([data,data,data],dim=1)
        # print('start_train')
        data = data.cuda()
        target = target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output,target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('loss:{}'.format(loss.item()))

def test(e):
    model.eval()
    total_right = 0
    total_wrong = 0
    with torch.no_grad():
        for batch_idx,(data,target) in enumerate(test_loader):
            if data.size(1) == 1:
                data = torch.cat([data, data, data], dim=1)
            data = data.cuda()
            target = target.cuda()
            output = model(data)
            output = output.max(dim=1)[1]
            right = torch.sum(output==target).item()
            wrong = len(data) - right
            total_right += right
            total_wrong += wrong
    print('epoch-{} acc:{}/{}={}'.format(e,total_right,total_right+total_wrong,total_right/(total_wrong+total_right)))


for e in range(epochs):
    train()
    test(e)

3、总结

当把layer2、3、4的conv全部替换为dcn时,训练特别慢,而且acc一直很低,大概只有10%左右

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何把“自己的”网络中的conv2d替换为dcnv2 的相关文章

随机推荐

  • ubuntu22.04安装podman及cockpit并在WEB中管理容器

    目录 前言 一 准备工具 二 安装步骤 1 更新系统到最新版本 2 使用以下命令安装podman 3 使用以下命令安装cockpit及相关插件 三 启动服务 四 登录管理界面 五 使用podman容器管理 1 创建容器 2 管理容器 六 总
  • sqli-labs————Less-33

    Less 33 查看源代码
  • QProcess处理带管道的shell

    代码中需要调用shell 原写法为 QProcess proc new QProcess QString qCmd find name so print0 xargs 0 objdump x grep oE T 0 9 a f A F 4
  • 护网

    在HVV期间 蓝队主要就是通过安全设备看告警信息 后续进行分析研判得出结论及处置建议 在此期间要注意以下内容 内网攻击告警需格外谨慎 可能是进行内网渗透 1 攻击IP是内网IP 攻击行为不定 主要包括 扫描探测行为 爆破行为 命令执行等漏扫
  • 笑脸工具COORD批量转换2000大地到空间坐标

    数据格式txt 1 31 48 14 118687N 119 38 07 130943E 2 32 3 19 06731100008N 119 31 20 422269001200302E 3 31 50 31 89348499992000
  • 变频调速系统c语言编程,基于8098单片机的SPWM变频调速系统

    数字控制的交流调速系统所选用的微处理器 功率器件及产生PWM波的方法是影响交流调速系统性能好坏的直接因素 在介绍了正弦脉宽调制 SPWM 技术的基础上 设计了一种以8098单片机作为控制器 以智能功率模块IPM为开关器件的变频调速系统 通过
  • 小样本学习(Few-shot Learning)综述

    作者丨耿瑞莹 李永彬 黎槟华 单位丨阿里巴巴智能服务事业部小蜜北京团队 分类非常常见 但如果每个类只有几个标注样本 怎么办呢 笔者所在的阿里巴巴小蜜北京团队就面临这个挑战 我们打造了一个智能对话开发平台 Dialog Studio 以赋能第
  • [Flutter]封装了个Toast组件

    Flutter官方插件市场上已经有了很多成熟的Toast组件 如 fluttertoast 等等 使用了一年多的Flutter框架 一时兴起 自己封装了一个简单的Toast组件 注 本人觉得 自动关闭的时候 不宜使用 Navigator p
  • 西门子PLC S7-1200的硬件中断组织块简介

    西门子PLC S7 1200系列是一款中小型西门子PLC 可以在各种自动化项目中进行应用 S7 1200系列设计较为紧凑 经济性较好 而且指令功能较为强大 因此在各种自动化控制解决方案中有较广泛的应用 作为西门子PLC S7 200系列的升
  • [1218]hive之Map Join使用方法

    文章目录 介绍 mapjoin的使用方法 介绍 MAPJION会把小表全部加载到内存中 在map阶段直接拿另外一个表的数据和内存中表数据做匹配 由于在map端是进行了join操作 省去了reduce运行的时间 算是hive中的一种优化 如上
  • 开放原子训练营(第三季)inBuilder低代码开发实验室之探秘

    一 活动介绍 以开放原子训练营为主办方的inBuilder低代码实验室活动现已开启 参与者无论身居计算机业界 偏好低代码开发抑或是普通用户 均可在社区版inBuilder低代码开发平台 一款基于UBML开源项目的广泛适用的发行版 中尝试向导
  • ECMAScript2020 可选链操作符(?.)的应用

    一 前言 const programmer user lin department name 技术部 getSite return 在以前的语法中 想要获得深层次的属性或方法 如果不做前置校验的话 那么就很容易出现这种错误 这可能会导致你整
  • MFC 之 重绘按键Cbutton

    上次我们学习了如何美化对话框的界面 这次我们为上次的对话框添加两个按钮 一个是关闭按钮 另一个是最小化按钮 好 现在我们先看一下效果 是不是很难看 因为我们的对话框美化了 所以我们的按钮也要美化 因为采用贴图的方式来美化 所以 我先给出这两
  • 笔试面试算法经典--矩阵的最短路径和(Java)

    题目 给定一个矩阵m 从左上角开始每次只能向右或者向下走 最后到达右下角的位置 路径上所有的数字累加起来就是路径和 返回所有路径中最小的路径和 例子 给定m如下 1 3 5 9 8 1 3 4 5 0 6 1 8 8 4 0 路径1 3 1
  • 信号去噪 - 基于SVD实现数字信号降噪含Matlab源码

    信号去噪 基于SVD实现数字信号降噪含Matlab源码 介绍 信号处理中的一个重要问题是如何降噪 这在各种应用领域中都有非常重要的作用 奇异值分解 SVD 是一种广泛使用的信号处理技术 可以用于有效地降低信号噪声 本文将介绍如何使用SVD进
  • Elasticsearch 安装及启动【Windows】

    一 下载 Elasticsearch 官网下载地址 https www elastic co cn downloads past releases elasticsearch 选择自己所需版本进行下载 这里以Elasticsearch 8
  • 【操作系统】王道考研 p64-66 IO软件层次结构、IO核心子系统、假脱机技术(SPOOLing技术)

    IO软件层次结构 IO核心子系统 假脱机技术 SPOOLing技术 以下是IO软件层次结构的内容 知识总览 用户层软件 实现了与用户交互的接口 将用户的请求翻译为格式化的IO请求 并通过 系统调用 请求操作系统内核的服务 设备独立性软件 又
  • PyQt5 QTableWidget内容复制功能

    为了更快速的将QTableWidget的内容复制到剪贴板 只需重写这个控件的keyPressEvent event 废话不多说 直接上代码 复制功能 def keyPressEvent self event Ctrl C复制表格内容 if
  • 大语言模型浅探一

    目录 1 前言 2 GPT模型解码 3 InstructGPT 4 基于RWKV微调模型 4 1 RWKV简介 4 2 增量预训练 4 3 SFT微调 4 4 RM和PPO 5 测试 6 总结 1 前言 近来 人工智能异常火热 ChatGP
  • 如何把“自己的”网络中的conv2d替换为dcnv2

    1 dcnv2的实现测试了两种 一种是官方版dcnv2 git链接 https github com CharlesShang DCNv2 git 编译直接cd到DCNv2 然后 make sh即可 第二种是mmcv ops modulat