亚像素卷积网络(ESPCN)学习与Pytorch复现

2023-05-16

论文内容 

论文地址:Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network | IEEE Conference Publication | IEEE Xplore

或者:[1609.05158] Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network (arxiv.org)

ESPCN是2016年提出的,是一篇经典的超分辨率重建算法文章,虽然它的效果和现在的文章相比不算好,但是它所提出的Efficient Sub-pixel Convolution,也叫亚像素卷积/子像素卷积为后面网络PSNR的提升做出了很大贡献,关键这个Sub-pixel Convolution比插值,反卷积,反池化这些上采样方法计算量要更少,因此网络的运行速度会有很大提升,如下图所示。

 那么接下来看看这个Sub-pixel Convolution的结构,正常情况下,卷积操作会使feature map的高和宽变小,但当stride=\frac{1}{r}<1时,可以让卷积后的feature map的高和宽变大,就实现了分辨率的提升也就是超分辨重建,这个操作叫做sub-pixel convolution。

 对于sub-pixel convolution,作者将一个H × W的低分辨率输入图像(Low Resolution)作为输入,低分辨率图像特征提取完毕后,生成n1个特征图,然后经过中间一堆操作等,不管有多少,只要到该上采样的时候,在最后一个卷积调整成r^{2}C就可以通过Sub-pixel操作将其变为rH x rW的高分辨率图像(High Resolution)。但是其实现过程不是直接通过插值等方式产生这个高分辨率图像,而是通过卷积先得到r^{2}C个通道的特征图(特征图大小和输入低分辨率图像一致),然后通过周期筛选(periodic shuffing)的方法得到这个高分辨率的图像,其中r rr为上采样因子(upscaling factor),也就是图像的扩大倍率。

Pytorch复现

sub-pixel convolution这个操作在pytorch里面提供的有接口,只需要调用就可以了。

输入:  (n,channels\cdot upscalefactor^{2},height,width)
输出: (n,channels,upscalefactor\cdot height,upscalefactor\cdot width)
比如:

sub= nn.PixelShuffle(4)
input = torch.tensor(1, 4**2, 4, 4)
output = sub(input)
torch.Size为[1, 1, 16, 16]

SRCNN使用sub-pixel convolution

主干网络:

import torch.nn as nn
import torch.nn.init as init


class Net(nn.Module):
    def __init__(self, upscale_factor):
        super(Net, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.pixel_shuffle(x)
        return x

训练:

from __future__ import print_function

from math import log10

import torch
import torch.backends.cudnn as cudnn

from SubPixelCNN.model import Net
from progress_bar import progress_bar


class SubPixelTrainer(object):
    def __init__(self, config, training_loader, testing_loader):
        super(SubPixelTrainer, self).__init__()
        self.CUDA = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.CUDA else 'cpu')
        self.model = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.training_loader = training_loader
        self.testing_loader = testing_loader

    def build_model(self):
        self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
        self.criterion = torch.nn.MSELoss()
        torch.manual_seed(self.seed)

        if self.CUDA:
            torch.cuda.manual_seed(self.seed)
            cudnn.benchmark = True
            self.criterion.cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay

    def save(self):
        model_out_path = "model_path.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self):
        self.model.train()
        train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(data), target)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))

        print("    Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))

    def test(self):
        self.model.eval()
        avg_psnr = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.model(data)
                mse = self.criterion(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

        print("    Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))

    def run(self):
        self.build_model()
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            self.train()
            self.test()
            self.scheduler.step(epoch)
            if epoch == self.nEpochs:
                self.save()

 

EDSR使用sub-pixel convolution

import math

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, num_channels, base_channel, upscale_factor, num_residuals):
        super(Net, self).__init__()

        self.input_conv = nn.Conv2d(num_channels, base_channel, kernel_size=3, stride=1, padding=1)

        resnet_blocks = []
        for _ in range(num_residuals):
            resnet_blocks.append(ResnetBlock(base_channel, kernel=3, stride=1, padding=1))
        self.residual_layers = nn.Sequential(*resnet_blocks)

        self.mid_conv = nn.Conv2d(base_channel, base_channel, kernel_size=3, stride=1, padding=1)

        upscale = []
        for _ in range(int(math.log2(upscale_factor))):
            upscale.append(PixelShuffleBlock(base_channel, base_channel, upscale_factor=2))
        self.upscale_layers = nn.Sequential(*upscale)

        self.output_conv = nn.Conv2d(base_channel, num_channels, kernel_size=3, stride=1, padding=1)

    def weight_init(self, mean=0.0, std=0.02):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward(self, x):
        x = self.input_conv(x)
        residual = x
        x = self.residual_layers(x)
        x = self.mid_conv(x)
        x = torch.add(x, residual)
        x = self.upscale_layers(x)
        x = self.output_conv(x)
        return x


def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()


class ResnetBlock(nn.Module):
    def __init__(self, num_channel, kernel=3, stride=1, padding=1):
        super(ResnetBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, num_channel, kernel, stride, padding)
        self.conv2 = nn.Conv2d(num_channel, num_channel, kernel, stride, padding)
        self.bn = nn.BatchNorm2d(num_channel)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        x = self.bn(self.conv1(x))
        x = self.activation(x)
        x = self.bn(self.conv2(x))
        x = torch.add(x, residual)
        return x


class PixelShuffleBlock(nn.Module):
    def __init__(self, in_channel, out_channel, upscale_factor, kernel=3, stride=1, padding=1):
        super(PixelShuffleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel * upscale_factor ** 2, kernel, stride, padding)
        self.ps = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.ps(self.conv(x))
        return x

EDSR这个网络后面得单独写一篇。

实验结果

论文中给出的结果如下表,我实际跑出来比原文要略低,应该是因为训练不到位和训练数据集不一样的原因,论文的训练数据集是Image,我是用BSD300训练的。

从左到右分别是原图,Bicubic,ESPCN,效果还说的过去。

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

亚像素卷积网络(ESPCN)学习与Pytorch复现 的相关文章

  • EBYTE E103-W02 WIFI模块配置总结(TCP+UDP+HTTP+云透传)

    目录 1 硬件配置 1 1 原理图 1 2 管脚配置 2 AT指令集 3 AP模式配置 3 1AP介绍 3 2 AP配置TCP通信 3 3 AP配置UDP通信 4 STA模式配置 4 1STA介绍 4 2配置过程 4 3网页配置 5 基于亿
  • JavaSE-自定义单链表

    目录 1 自定义链表实现 2 基础操作 2 1 链表打印操作 2 2 链表逆序打印 2 3 链表逆置 3 进阶操作 3 1查找倒数第K个结点 3 2 不允许遍历链表 xff0c 在pos结点之前插入 3 3两个链表相交 xff0c 输出相交
  • SRGAN实现超分辨率图像重建之模型复现

    1 论文介绍 1 1简介 论文名称 Photo Realistic Single Image Super Resolution Using a Generative Adversarial Ledig C Theis L Huszar F
  • JavaSE-自定义队列+两栈实现队列+两队列实现栈

    1 顺序队列实现 与栈一样 xff0c 队列也是一种操作受限制的线性表 xff0c 但与栈不同的是 xff0c 栈是后进先出 xff0c 队列的特点是先进先出 实现与栈类似 xff0c 队列有一个队头指针和一个队尾指针 xff0c 入队的时
  • JavaSE-八大经典排序算法及优化算法思路与实现

    目录 1 冒泡排序 1 1算法思想 1 2 算法实现 1 3 算法优化 1 4 算法分析 2 简单选择排序 2 1 算法思想 2 2 算法实现 2 3 算法优化 2 4 算法分析 3 直接插入排序 3 1 算法思想 3 2 算法实现 3 3
  • 快速超分辨率重建卷积网络-FSRCNN

    1 网路结构图 2 改进点 SRCNN缺点 xff1a SRCNN将LR送入网络前进行了双三次插值上采样 xff0c 产生于真实图像大小一致的图像 xff0c 会增加计算复杂度 xff0c 因为插值后图像更大 xff0c 输入网络后各个卷积
  • PriorityQueue(优先级队列)的解读与底层实现

    目录 1 什么是优先级队列 xff1f 2 优先级队列的特性 3 如何使用优先级队列 4 JDK源码分析 4 1类的继承关系 4 2类的属性 4 3常用的方法 5 自定义优先级队列实现 5 1 优先级队列类定义 5 2 add 方法 5 3
  • HashMap的使用与底层结构剖析

    目录 一 基础概念 二 先使用再了解 三 底层结构 1 HashMap结构 xff08 JDK1 8以前 xff09 2 HashMap结构 xff08 JDK1 8以后 xff09 四 HashMap实现 1 成员变量 2 put实现 3
  • 线程基础与使用测试

    目录 一 进程和线程 二 线程的创建 1 继承Thread类 xff0c 重写run 方法 2 实现Runnable接口 xff0c 重写run方法 3 匿名线程 xff0c 匿名内部类 4 实现Callable接口 xff0c 重写cal
  • 线程生命周期及常用方法的使用

    一 守护线程 守护线程是什么 xff1f 守护线程是一类比较特殊的线程 xff0c 一般用于处理后台的工作 xff0c 比如JDK的垃圾回收线程 守护线程的作用 xff1f JVM xff08 Java Virtual Machine xf
  • git合并被fork的仓库的远程分支

    如果你 fork 了一个仓库并在自己的 forked 仓库中进行了更改 xff0c 而原始仓库也有一些更新 xff0c 此时想将原始仓库的更新合并到你的 forked 仓库 xff0c 可以按照以下步骤 xff1a 1 将原始仓库添加为远程
  • Linux-基础知识及常见操作命令汇总

    目录 1 终端操作 2 命令手册 3 关机重启 4 runlevel 5 目录结构 6 文件属性 7 Linux文件属主和属组 8 目录常用命令 9 VIM命令 10 进程管理命令 1 进程状态 2 ps命令 3 pstree命令 jobs
  • 关键字synchronized与volatile详解

    在多线程并发编程中synchronized和volatile都扮演着重要的角色 xff0c synchronized一直是元老级角色 xff0c 很多人都会称呼它为重量级锁 但是 xff0c 随着Java SE 1 6对synchroniz
  • 迁移学习与Transformer架构

    迁移学习 迁移学习干什么的 xff1f 迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务 Eg xff1a 学习识别苹果可能有助于识别梨 xff0c 学习骑自行车可能有助于学习骑摩托车 xff0c 学习打羽毛球可能有助于学习打网
  • 生产者消费者模型分析与实现

    生产者消费者模式就是通过一个容器来解决生产者和消费者的强耦合问题 生产者和消费者彼此之间不直接通讯 xff0c 而通过阻塞队列来进行通讯 xff0c 所以生产者生产完数据之后不用等待消费者处理 xff0c 直接扔给阻塞队列 xff0c 消费
  • ConcurrentHashMap优点与源码剖析

    哈希表是中非常高效 xff0c 复杂度为O 1 的数据结构 xff0c 在Java开发中 xff0c 我们最常见到最频繁使用的就是HashMap和HashTable xff0c 但是在线程竞争激烈的并发场景中使用都不够合理 HashMap
  • IO-字节流

    文件 amp File类的使用 1 文件的概念 文件可以分为文本文件 二进制文件 2 IO流的概念 流是有顺序 有起点和终点的集合 xff0c 是对数据传输的总称 流的本质就是用来对数据进行操作 IO是我们实现的目的 xff0c 实现这个目
  • STM32F407的TCP编程出现客户端无法连接上服务器,DHCP获取IP失败,服务器重启客户端无法自动重连问题解决方案

    单写一篇文章记录这些问题 xff0c 因为有的问题实在是困扰了我太久太久了 xff0c 终于解决了 xff01 xff01 xff01 1 STM32F407的TCP编程 xff0c TCP SERVER测试完全正常 xff0c TCP C
  • SQL练习汇总(查询“01“课程比“02“课程成绩高的学生的信息及课程分)

    1 学生表 Student SID Sname Sage Ssex SID 学生编号 Sname 学生姓名 Sage 年龄 Ssex 学生性别 编号 姓名 年龄 性别 1 赵雷 20 男 2 钱电 20 男 3 孙风 21 男 4 吴兰 1
  • JDBC编程,SQL注入与连接池

    JDBC概念 JDBC Java Data Base Conection 是java中提供的一套标准的应用编程接口 xff0c 用来连接Java编程语言和数据库 JDBC常用组件 xff1a DriverManger xff1a 数据库驱动

随机推荐

  • Pytorch搭建基于SRCNN图像超分辨率重建模型及论文总结

    SRCNN xff08 Super Resolution Convolutional Neural Network xff09 论文出处 xff1a Learning a Deep Convolutional Network for Ima
  • 技术领域的面试总结

    在当今互联网中 xff0c 虽然互联网行业从业者众多 xff0c 不断崛起的互联网公司也会很多 xff0c 仍然是很多同学想要进入的企业 那么本篇文章将会为大家很直白的讲解面试流程以及侧重点 仔细阅读本篇文章 xff0c 绝对会有所收获 x
  • Mybatis基于XML与基于注解实现CRUD

    数据库 实体类Student package com pojo Description Created by Resumebb Date 2021 3 26 public class Student 和数据库中的STudent表对应 pri
  • Spring-IOC容器进行对象管理

    目录 IOC概念 IOC思想 Spring管理对象 集成依赖 spring的配置文件 xff08 Applicationcontext xml xff09 创建实体类User Spring对Bean的实例化方式 基于配置形式 1 通过无参构
  • Spring-AOP原理及实现

    Spring AOP AOP Aspect Oriented Programing 面向切面编程 xff1a 扩展功能不通过修改源代码实现 AOP采用横向抽取机制 xff0c 取代传统纵向继承体系实现响应的功能 xff08 性能监控 事务
  • Spring&Mybatis整合及Spring中JDBCTemplate的使用

    Spring和Mybatis整合 在mybatis中 xff0c 操作数据库需要获取到SQLSession对象 xff0c 而该对象的实例过程在mybatis是通过SQLSessionFactoryBuilder读取全局配置文件来实例化一个
  • SpringMVC设计模式

    什么是MVC MVC是模型 Model 视图 View 控制器 Controller 的简写 xff0c 是一种软件设计规范 是将业务逻辑 数据 显示分离的方法来组织代码 MVC主要作用是降低了视图与业务逻辑间的双向偶合 MVC不是一种设计
  • SSM框架整合

    整合思路 主要分为Controller xff0c service层 xff0c dao层 整合dao mybatis和spring的整合 xff0c 通过spring来管理mapper接口 xff0c 数据源 xff0c 使用mapper
  • SSM框架实战-搭建自己的个人博客1-基础架构搭建

    前言 本系列文章主要通过从零开始搭建自己的个人博客 xff0c 来加深对SSM框架的学习与使用 xff0c 了解一个系统从提出到设计 到开发 到测试 部署运行的过程 xff0c 并记录在搭建过程中的学习心得 遇见的错误及解决方式 代码放在g
  • SSM框架实战-搭建自己的个人博客2-UEditor编辑器的使用

    目录 UEditor 博客内容提交与展示功能测试 Controller开发 新增博客页面add ueditor jsp 博客详情界面detail jsp 博客新增和展示详情功能开发 博客存储 博客标题开发 标签POJO类 TagMapper
  • SSM框架实战-搭建自己的个人博客3-登录实现及前端界面设计

    目录 后台登录功能 前端页面 后端开发 前端界面设计 详情首页 js脚本 SSM整体设计 Dao层 Service层 Mapper xml Controller 子博文界面 部署至服务器 后台登录功能 登录页面 xff1a 用户名和密码 通
  • 超分辨率重建-PNSR与SSIM的计算(RGB、YUV和YCbCr互转)

    RGB YUV和YCbCr 自己复现的网络跑出来的模型进行预测的时候 xff0c 不知道为啥算出来的结果比论文中要低好多 不论scale factor为多少 xff0c 我算出来的结果均普遍低于论文中给出的 xff0c PSNR大概低个1
  • 如何写简历

    注意点 xff1a 篇幅 校招一页 社招二页 谨慎使用精通 精通 gt 熟悉 xff08 推荐使用 xff09 gt 掌握 xff08 推荐使用 xff09 gt 了解 xff08 推荐使用 xff09 拿不准的不要写在简历上 突出自己技能
  • SSM框架实战-搭建自己的个人博客4-文章管理与展示

    实现功能 主要实现上图所示的功能 xff0c 从数据库中查询到所有文章数据 xff0c 并进行显示如标题 xff0c 栏目等信息 xff0c 可以通过分类查询文章 xff0c 通过标签查询文章 xff0c 也可以通过搜索进行模糊查询 xff
  • Pytorch加载与保存模型(利用pth的参数更新h5预训练模型)

    前言 以前用Keras用惯了 xff0c fit和fit generator真的太好使了 xff0c 模型断电保存搞个checkpoint回调函数就行了 近期使用pytorch进行训练 xff0c 苦于没有类似的回调函数 xff0c 写完网
  • 如何用pyplot优雅的绘制loss,acc,miou,psnr变化曲线

    前言 TensorFlowBoard过于强大 xff0c 导致我对它依赖性很强 xff0c 今年转手使用pytorch进行开发 xff0c 本以为没了TensorFlowBoard xff0c 后来发现人家Tensorflow封装了个Ten
  • Pytorch实现CA,SA,SE注意力机制

    通道注意力CA class ChannelAttention nn Module def init self in planes ratio 61 16 super ChannelAttention self init self avg p
  • Python使用OpenCV按自定义帧率提取视频帧并保存

    在做室外语义分割 视觉导航与定位的时候 xff0c 通常会用对一个连续的视频帧进行测试 xff0c 除去常用数据集外 xff0c 也经常会自己制作一些数据集 xff0c 这个工具类可以按需求对视频进行分帧提取 xff0c 封装好了直接可以使
  • 悲观锁与乐观锁详解

    悲观锁 悲观锁顾名思义是从悲观的角度去思考问题 xff0c 解决问题 它总是会假设当前情况是最坏的情况 xff0c 在每次去拿数据的时候 xff0c 都会认为数据会被别人改变 xff0c 因此在每次进行拿数据操作的时候都会加锁 xff0c
  • 亚像素卷积网络(ESPCN)学习与Pytorch复现

    论文内容 论文地址 xff1a Real Time Single Image and Video Super Resolution Using an Efficient Sub Pixel Convolutional Neural Netw