ResNet网络加入CBAM注意力机制

2023-11-03

首先定义resnet_cbam.py:如图

import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
           'resnet152_cbam']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet34_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet50_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet101_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet152_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model

然后你需要什么就引用什么,如图:

from resnet_cbam import resnet50_cbam 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ResNet网络加入CBAM注意力机制 的相关文章

随机推荐

  • 在ubuntu下安装vscode

    ubuntu22 04下通过命令安装vscode 1 为什么不用应用市场直接下载 最近下载ubuntu22 04版本 不知道为啥里面的应用软件下载不了vscode 尝试在网上解决 gt 卸载自带的应用市场 安装另外的一种 结果失败了 导致原
  • 【路由指令】

    一 linux route add net 192 0 0 0 netmask 255 0 0 0 gw 192 180 30 1 sudo route add net 192 180 0 0 netmask 255 255 0 0 gw
  • python爬取今日头条后台数据_爬虫爬取今日头条数据代码实现

    课程链接 讲师的公众号文章 今日头条数据抓取及持久化 完整代码版 含IP和用户代理 mp weixin qq com 课程代码 抓取并持久化user agent工具utils py 对于爬虫工具 需要设置发起请求的user agent im
  • Spring Boot 框架基础

    Spring Boot 框架基础 基础案例 pom xml
  • BUCK LX_OUT Snubber电路

    1 问题 开关节点振铃 过冲 开关节点过冲会导致LX OUT管脚的电压过高 如果超过datasheet上的maximum值 就有可能影响DCDC芯片寿命 2 产生振铃 过冲的原因 2 1 输入电容摆放不正确 2 2 输出电感 电容摆放不正确
  • STM32F103滴答计时器之delay函数

    如果使用FreeRTOS void delay us u32 nus u32 ticks u32 told tnow tcnt 0 u32 reload SysTick gt LOAD ticks nus fac us tcnt 0 del
  • k8s删除deployment_k8s灾备指南(Velero)

    最近验证了使用velero对k8s进行灾难恢复 操作验证步骤如下 1 下载verlero 解压 tar xvf
  • java案例15:模拟订单号生成

    思路 模拟订单号生成 超市购物时 小票上都会有一个订单号 且订单号唯一 编写程序模拟订单系统中订单号的生成 生成订单号时 使用年月日和毫秒值组合生成唯一订单号 例如 给一个包括年月日和毫秒值的数组arr 2023 0401 1100 将其拼
  • git解决代码冲突、合并代码

    共同开发时提交代码会遇到代码冲突 第一次遇到就手足无措的我 打算写一篇博客记录下来 下次遇到稳如老狗 一 远程代码已有更新记录 忘记拉取远程代码 直接提交 单人开发时 我没有先拉远程代码再提交的习惯 千万不要学习 一定要先拉代码再提交 导致
  • 关于odoo条码显示问题处理

    这里分几种情况 1 第一种情况 打印的单据不显示条码 这种情况比较常见 一般是 没有对应字体导致 不能正常显示条码 单据打印的条码 一片空白 无条码的情况 这种情况是因为 条码的字体没有安装 需要安装一下 这里我会把资源上传 大家可以下载
  • Mac texlive+texstudio 如何手动安装宏包

    1 从CTAN上搜索自己需要的package 2 以subfigure为例 选择第一个结果 下载下来 3 解压下载后的文件 这时候发现里面并没有 sty文件 在终端中输入latex subfigure即可生成需要的sty文件 将整个文件夹保
  • nuxt.js-------koa2项目,环境错误一次性解决

    nuxt js虽然好用但是自己的脚手架安装完全是坑 cnpm run dev 报错确实main js node环境nuxt版本不匹配 在网上找了很多解决方法没有解决 就一次性把所有脚手架和环境都升级到最新版本 npm install bac
  • Nginx 负载均衡 - fair

    学习在 Nginx 中使用 fair 模块 第三方 来实现负载均衡 fair 采用的不是内建负载均衡使用的轮换的均衡算法 而是可以根据页面大小 响应时间智能的进行负载均衡 1 准备工作 nginx upstream fair 官方下载地址
  • 使用Nuxt.js框架开发(SSR)服务端渲染项目

    SSR 服务端渲染的优缺点 优点 1 前端耗时少 首屏加载速度快 因为后端拼接完了html 浏览器只需要直接渲染出来 2 有利于SEO 因为在后端有完整的html页面 所以爬虫更容易爬取获得信息 更有利于seo 3 无需占用客户端资源 即解
  • flutter 文字从左到右轮播滚动,跑马灯

    参考 description begin 跑马灯 根据滚动方向可以分为 横向滚动 和 纵向滚动 此页实现的跑马灯是 横向滚动的 description end import dart async import package flutter
  • Shell—变量、字符串和数组

    本文主要讲解Shell变量 字符串和数组的相关知识 一 Shell变量 1 变量的定义 运行shell时 会同时存在三种变量 1 局部变量 局部变量在脚本或命令中定义 仅在当前shell实例中有效 其他shell启动的程序不能访问局部变量
  • 区块链+教育:区块链是底层技术,教育才是本质!

    建国君民 教育为先 国愚是智可以强国 国智则力可以强人 依教建国 以智强国 是中国古代先贤就教育强国思想的重要体现 每次的教育改革也牵动着无数人的心 也有很多人关心未来的教育会走向哪里 百年大计 教育为先 随着人们对教育行业的关注度逐渐提升
  • HECO使用docker部署单节点的开发网

    文章目录 一 编写说明 1 1 文档说明 1 2 配置信息 1 3 部署文档信息 二 heco开发网镜像生成 三 heco主链容器生成 3 1 配置文件编写 3 2 预先创建一个账户地址 3 3 创建genesis json 3 4 初始化
  • Android加密 看雪,Android加密与解密入门两题

    写在最前面 本次题目来自看雪2w班9月题 密码学一直是安全的基础 Android安全也不例外 这次9月份的题分别从java层和C层考察了密码学中常用的对称加密 hash函数以及一些基础的编码 但是不是单纯的算法分析题 可以说是很好的练习题了
  • ResNet网络加入CBAM注意力机制

    首先定义resnet cbam py 如图 import torch import torch nn as nn import math import torch utils model zoo as model zoo all ResNe