实战:利用Pytorch复现Deep Residual Learning for Image Recognition中的 Resnet34

2023-10-27

残差网络Resnet:Deep Residual Learning for Image Recognition 论文阅读笔记

  • 在学习了resnet的论文之后,为了便于理解,变想要复现文论中34层的resnet模型。

即以下这张图:

在这里插入图片描述
实际上,Resnet和VGGnet的骨干网络相差无几,只是深度上更胜后者,并且增加了恒等映射(identity mapping)。

在复现VGG16的时候,采用的是逐层搭建,显然在此处并不可行,一共34层,每层需要正则化和激活函数,工作量相当大。

首先我们先将34层分为3部分

  • 1 初始部分
  • 2-33 主体部分
  • 34 全连接部分

而主题部分根据通道的改变可以分为4个部分:

  • 2-7 64
  • 8-15 128
  • 16-27 256
  • 27-33 512

将每两层设置为Bottleneck:

class BottleNeck(nn.Module):
    def __init__(self,in_chanels,out_chanels,stride=1,downsample=False):
        super(BottleNeck,self).__init__()

        self.conv1=nn.Conv2d(in_chanels,out_chanels,kernel_size=3,padding=1,stride=stride)
        self.BN=nn.BatchNorm2d(out_chanels)
        self.ReLu=nn.ReLU(inplace=True)
        self.conv2=nn.Conv2d(out_chanels,out_chanels,kernel_size=3,padding=1,stride=1)
        self.downsample = downsample
        self.wi = nn.Sequential(
            nn.Conv2d(in_chanels,out_chanels,kernel_size=1,padding=0,stride=stride),
            nn.BatchNorm2d(out_chanels)
        )

    def forward(self,x):

        identiey = x
        out = self.conv1(x)
        out = self.BN(out)
        out = self.ReLu(out)

        out = self.conv2(out)
        out = self.BN(out)

        if self.downsample == True:
            identiey = self.wi(x)

        out = out + identiey
        out = self.ReLu(out)

        return out

需要注意的是当Bottleneck层的输入和输出不符,需要对idenetity mapping 进行转换(等同于shortcut ),因此设置的参数downsample,看是否需要转换。

代码

import torch
from torch import nn

class BottleNeck(nn.Module):
    def __init__(self,in_chanels,out_chanels,stride=1,downsample=False):
        super(BottleNeck,self).__init__()

        self.conv1=nn.Conv2d(in_chanels,out_chanels,kernel_size=3,padding=1,stride=stride)
        self.BN=nn.BatchNorm2d(out_chanels)
        self.ReLu=nn.ReLU(inplace=True)
        self.conv2=nn.Conv2d(out_chanels,out_chanels,kernel_size=3,padding=1,stride=1)
        self.downsample = downsample
        self.wi = nn.Sequential(
            nn.Conv2d(in_chanels,out_chanels,kernel_size=1,padding=0,stride=stride),
            nn.BatchNorm2d(out_chanels)
        )

    def forward(self,x):

        identiey = x
        out = self.conv1(x)
        out = self.BN(out)
        out = self.ReLu(out)

        out = self.conv2(out)
        out = self.BN(out)

        if self.downsample == True:
            identiey = self.wi(x)

        out = out + identiey
        out = self.ReLu(out)

        return out

class ResNet34(nn.Module):
    def __init__(self,num_classes):
        super(ResNet34,self).__init__()
        # 最开始的7x7卷积部分和最大池化,即第一层卷积
        self.start = nn.Sequential(
            # 使用7*7的卷积核,卷积步长为2,使得维度/2,padding长度根据公式计算得为3
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 使用3*3得卷积核,卷积步长为2,使得维度/2,padding长度根据公式计算得为1
            nn.MaxPool2d(kernel_size=3, stride=2,padding=1)
        )
        # 2-33层卷积
        self.layers = nn.Sequential(
            self._make_layer(64,64,False,3),
            self._make_layer(64,128,True,4),
            self._make_layer(128,256,True,6),
            self._make_layer(256,512,True,3)
        )
        # 第34层全连接
        self.fc = nn.Sequential(
            # 自适应平均池化层,输出大小(1,1),将起变成1*1*512
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(start_dim=1,end_dim=-1),
            nn.Linear(512, num_classes)
        )

    def forward(self,x):
        return self.fc(self.layers(self.start(x)))

    def _make_layer(self,in_chanels,out_chanels,downsample,num_blocks):
        layes = []
        layes.append(BottleNeck(in_chanels,out_chanels,downsample=downsample))
        for _ in range(1,num_blocks):
            layes.append(BottleNeck(out_chanels,out_chanels))
        return nn.Sequential(*layes)

if __name__ =='__main__':
    inputs = torch.rand((8,3,224,224)).cpu()
    model = ResNet34(num_classes=1000).cpu().train()
    outputs = model(inputs)
    print(outputs.shape)

  • 考虑到其他部分(train,test)都与之前复现的VGG16类似,就不再重复
  • 相较于VGG的串行,Resnet更加考验代码的复用性,如何合理设置模块能够减少码量,事实上官方版本为我们展现了良好代码构建能力。

代码参考

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

实战:利用Pytorch复现Deep Residual Learning for Image Recognition中的 Resnet34 的相关文章

随机推荐

  • websocket 发送ping_那些很重要,但是不常用的技术,websocket

    1 为什么会有websocket 2 websocket协议格式 3 协议具体实现 一 为什么需要 WebSocket 初次接触 WebSocket 的人 都会问同样的问题 我们已经有了 HTTP 协议 为什么还需要另一个协议 它能带来什么
  • RocketMQ 部署不当导致磁盘空间不释放

    背景 生产环境采用 RocketMQ 三主三从集群搭建 6 个实例部署在 3 台 Linux 服务器上 节省资源 每台服务器部署一主一从 生产上运行一段时间后 发现磁盘空间报警 发现df与du显示的空间不一致 相差几十G 问题原因 Rock
  • logback异步日志,支持滚动策略

    logback properties error日志保存路径 LOG ERROR HOME logs error info日志保存路径 LOG INFO HOME logs info 最长保存天数 MAX HISTORY 7 日志文件最大
  • 《Openwrt开发》第一章:newifi3 刷自己编译的Openwrt固件

    最近在淘宝入手了一个二手的newifi3 主要是因为它内存大 而且性价比相当高 512M的ddr2和32M的flash买下来才100左右 好了 废话不多说 开始第一章的源码编译征程 1 准备 源码编译宿主机 ubuntu14 04 64位
  • 三个维度看全球半导体格局变迁

    来源 世纪证券 费城半导体指数 SOX 的发展阶段反应了全球半导体的走势与兴衰更替 费半指数涵盖全球半导体设计 设备 制造 材料等方向 其走势可以是衡量全球半导体行业景气程度的主要指标 费城半导体指数发行于在 1993 年12 月 1 日
  • Python基础_如何搭建起一个PyWeb项目(入门篇)

    一 介绍 本文介绍如何从零开始利用pyCharm搭建起一个可用的web项目 基于pychram2020 2版本 二 步骤 1 在开发前我们需要为py工具设置一个python的编译环境 通过 file gt settings gt proje
  • jmeter线程组 bzm - Arrivals Thread Group & 阶梯式压测

    简介 BZM Arrivals Thread Group是jmeter的一个插件 它可以模拟并发到达的用户流量 按时间加压 可以有效地帮助测试人员评估系统在高压力和高并发情况下的性能表现 插件下载地址 jmeter版本不低于 5 2 0 h
  • Mysql-JDBC配置LoadBalance协议

    Mysql JDBC长期以来提供了有效的手段在MySql集群 多主Replication部署的情况下分发读写负载 自从mysql jdbc 5 1 3以来 你可以在不停用服务的情况下动态配置loadBalance连接 进程中的事务不丢失 实
  • 箭头函数(=>)和普通函数(function)的区别

    JavaScript中箭头函数 gt 和普通函数function的区别 2021前端高频面试题 转载自 作者 阮一峰 ECMAScript6 入门和博客园 一 区别 1 箭头函数与普通函数写法不同 箭头函数 var声明变量时 var fn
  • 自学成材的黑客很多,但还是得掌握方法,给你黑客入门与进阶建议

    建议一 黑客七个等级 仅供参考 黑客 对很多人来说充满诱惑力 很多人可以发现这门领域如同任何一门领域 越深入越敬畏 知识如海洋 黑客也存在一些等级 参考知道创宇 CEO ic 世界顶级黑客团队 0x557 成员 的分享如下 Level 1
  • 时间序列分类算法_时间序列分类算法简介

    时间序列分类算法 A common task for time series machine learning is classification Given a set of time series with class labels c
  • 使用ICE建立C++与C#的通讯

    使用ICE建立C 与C 的通讯 版权 三夏健 https www cnblogs com liwei81730 archive 2012 08 21 2649476 html ICE的优势是作为通讯中间件可支持跨平台的通讯 目前支持C C
  • 基于深度学习Seq2Seq框架的技术总结

    随着互联网经济的普及定位技术的快速发展 人们在日常生活中产生了大量的轨迹数据 例如出租车的GPS数据 快递配送员PDA产生的轨迹数据等 轨迹数据是一种典型的时空数据 Spatial Temporal Data 是按照时间顺序索引且空间变化的
  • gitleb+hexo部署搭建博客

    当你想发布自己的想法 或者学习内容时 这个时候可能你的选择就是在各大平台发布 比如说 简书 csdn 掘金等一些公开的平台 但是这样你的数据就是属于别人了 如果有一天那个平台关闭了 那不是你的多年记录的内容都没有了 可想而知你当时的心情是多
  • WIN10环境下配置hadoop+spark并运行实例的教程

    WIN10环境下配置 hadoop spark 并运行开发实例的教程 前期准备 基本环境配置 虚拟机的安装 配置虚拟机中的静态网络 关闭并禁用防火墙 配置主机名 编辑host文件 使用ssh传输文件 SSH免密配置 解压文件 配置文件 配置
  • imx6ul:uboot-2013.10启动过程解析

    1 源码结构分析 首先一个问题 老版本的u boot是没有SPL这个文件的 新版u boot开始包含SPL文件 原来u boot启动比如放到nand中 在cpu内部有一个stepping stone 可以拷贝nand中的u boot到ram
  • python. 创建虚拟环境 conda_python使用conda创建和管理python虚拟环境

    一 背景 前期使用过程中发现使用python3自带的venv创建虚拟环境时 无法指定python版本 也许可以 但我没找到方法 所以打算利用第三方的工具conda来管理python环境 二 Miniconda安装 本文主要是介绍环境管理相关
  • SAP FI 系列 (026) - 增值税的配置

    产品的销售 原料的采购 都要与增值税打交道 SAP 系统对于不同国家的销售和购置税 都提供了基于国家的计税程序 Tax Procedure 项目实施的时候 只需要选择预置的税码或者新增税码 针对这些税码配置记账的会计科目即可 税码包括的最重
  • 2023最新「阿里」Java 高级工程师面试高频题:JVM+Redis+ 并发 + 算法 + 框架

    前言 面对今年的大环境而言 跳槽成功的难度比往年高了很多 很明显的感受就是 对于今年的 java 开发朋友跳槽面试 无论一面还是二面 都开始考验一个 Java 程序员的技术功底和基础 对源码解读和核心原理理解也是成了加分项 特别是对 Jav
  • 实战:利用Pytorch复现Deep Residual Learning for Image Recognition中的 Resnet34

    残差网络Resnet Deep Residual Learning for Image Recognition 论文阅读笔记 在学习了resnet的论文之后 为了便于理解 变想要复现文论中34层的resnet模型 即以下这张图 实际上 Re