好像还挺好玩的GAN重制版2——Keras搭建SRGAN平台进行图片超分辨率提升

2023-11-17

学习前言

我又死了我又死了我又死了!
在这里插入图片描述

源码下载地址

https://github.com/bubbliiiing/srgan-keras

喜欢的可以点个star噢。

网络构建

一、什么是SRGAN

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。
在这里插入图片描述
该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感

二、生成网络的构建

在这里插入图片描述
生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。

SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数
2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升

前两个部分用于特征提取,第三部分用于提高分辨率。

def residual_block(inputs, filters):
    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)

    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.Add()([x, inputs])
    return x

def deconv2d(inputs):
    x = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = SubpixelConv2D(scale=2)(x)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)
    return x

def build_generator(lr_shape, scale_factor, num_residual=16):
    #-----------------------------------#
    #   获得进行上采用的次数
    #-----------------------------------#
    upsample_block_num = int(math.log(scale_factor, 2))
    img_lr = layers.Input(shape=lr_shape)

    #--------------------------------------------------------#
    #   第一部分,低分辨率图像进入后会经过一个卷积+PRELU函数
    #--------------------------------------------------------#
    x = layers.Conv2D(64, kernel_size=9, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(img_lr)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)

    short_cut = x
    #-------------------------------------------------------------#
    #   第二部分,经过num_residual个残差网络结构。
    #   每个残差网络内部包含两个卷积+标准化+PRELU,还有一个残差边。
    #-------------------------------------------------------------#
    for _ in range(num_residual):
        x = residual_block(x, 64)

    x = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.Add()([x, short_cut])

    #-------------------------------------------------------------#
    #   第三部分,上采样部分,将长宽进行放大。
    #   两次上采样后,变为原来的4倍,实现提高分辨率。
    #-------------------------------------------------------------#
    for _ in range(upsample_block_num):
        x = deconv2d(x)

    gen_hr = layers.Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)

    return Model(img_lr, gen_hr)

三、判别网络的构建

在这里插入图片描述
判别网络的构成如上图所示:

SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

def d_block(inputs, filters, strides=1):
    x = layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    return x

def build_discriminator(hr_shape):
    inputs = layers.Input(shape=hr_shape)

    x = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = d_block(x, 64, strides=2)
    x = d_block(x, 128)
    x = d_block(x, 128, strides=2)
    x = d_block(x, 256)
    x = d_block(x, 256, strides=2)
    x = d_block(x, 512)
    x = d_block(x, 512, strides=2)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1024, kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    validity = layers.Dense(1, activation='sigmoid', kernel_initializer = random_normal(stddev=0.02))(x)
    return Model(inputs, validity)

训练思路

SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。

一、判别器的训练

训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签

因此判别器的训练步骤如下:

1、随机选取batch_size个真实高分辨率图片。
2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

二、生成器的训练

训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。
2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss

在这里插入图片描述

利用SRGAN生成图片

SRGAN的库整体结构如下:
在这里插入图片描述

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。
在这里插入图片描述

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。
在这里插入图片描述

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。
在这里插入图片描述
训练过程中,可在results文件夹内查看训练效果:
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

好像还挺好玩的GAN重制版2——Keras搭建SRGAN平台进行图片超分辨率提升 的相关文章

  • 听说,你想做大模型时代的应用层创业!

    亲爱的科技探险家们和代码魔法师们 未来的钟声已经敲响 预示着一场极度炫酷的虚拟现实游戏即将展开 从初期简单的智能识别 到设计师级别的图纸设计 生成式AI技术 Generative AI 以其独特理念和创新模式重塑了传统内容生产效率和交互模式
  • Unity手机端3档震动

    using System Collections using System Collections Generic using UnityEngine public class VibrateHelper MonoBehaviour sta

随机推荐

  • elementui dialog组件固定高度

    弹窗高度过大 想设置个自适应的高度 固定头尾 deep el dialog margin 5vh auto important deep el dialog body height 70vh overflow auto margin hei
  • ChatGPT救命!4岁男孩3年求医17位专家无果,大模型精准揪出病因

    克雷西 萧箫 发自 凹非寺量子位 公众号 QbitAI 怪病 缠身3年求医无果 最终竟然被ChatGPT成功诊断 这是发生在一名4岁男孩身上的真实经历 某次运动后 他身体开始剧痛 母亲前后带她看了17名医生 从儿科 骨科到各种专家 先后进行
  • DeferredResult的使用场景及用法

    场景 假设我们现在要实现这样一个功能 浏览器要实时展示服务端计算出来的数据 一种可能的实现是 浏览器频繁 例如定时1秒 向服务端发起请求以获得服务端数据 但定时请求并不能 实时 反应服务端的数据变化情况 若定时周期为S 则数据延迟周期最大即
  • 小程序支付完整过程。足够详细!

    1 注册小程序 拿到App id 和 AppSecret 小程序密钥 取得商户的微信支付商户号 MCHID 和 微信支付密钥 APIKEY 2 流程 微信用户在微信端选商品下单 后台响应下单生成单号 产生第一次签名数据 提交微信统一支付接口
  • IDA工具介绍

    往期推荐 IDA工具安装 分享 ARM处理器寻址方式 ARM指令集 ARM汇编语言程序结构 昨天给大家概括性的了解了IDA工具 今天分享IDA工具的实际应用 一 IDA打开so文件 1 首先打开IDA工具 进入选项界面 直接选中Go选项 如
  • 初识 libcurl multi:实现一个 http 请求处理客户端,一个线程玩命发一个线程吐血收

    一 引言 最近在工作中 遇到了这么一个需求 我们希望拥有一个高性能的 http 请求处理客户端程序 这个客户端要求有这样的架构 它拥有两个线程 一个线程接收业务程序通过消息队列发来的批量的 http 请求信息 进行批量的 http 请求 另
  • 深入了解scratch中的“移动10步”和(你真的了解scratch吗?scratch初学者值得一看)

    scratch中的 移动10步 是scratch运动类积木中的第一个积木 也是大多数初学者使用scratch的时候用到的第一个积木命令 当我们运行 移动10步 积木时 小猫会向右移动10步 目测其实也就一点点距离 那么 移动10步 究竟在s
  • 正交多项式-勒让德多项式,两类切比雪夫多项式及零点,拉盖尔多项式,埃尔米特多项式

    1 正交多项式 设 n x 是 a
  • Linux 云服务器运维(操作及命令)

    1 什么是linux服务器load average Load是用来度量服务器工作量的大小 即计算机cpu任务执行队列的长度 值越大 表明包括正在运行和待运行的进程数越多 2 如何查看linux服务器负载 可以通过w top uptime p
  • Tomcat多实例部署

    文章目录 一 Tomcat多实例的操作步骤 1 关闭防火墙 将安装 Tomcat 所需软件包传到 opt目录下 2 安装JDK 3 安装 tomcat 4 配置 tomcat 环境变量 5 修改 tomcat2 中的 server xml
  • Qt设计模式与运行界面有偏差 Qt自适应高清屏

    原因Qt对高分辨率屏幕支持的问题 设置下属性 注意在应用程序实例之前设置 int main int argc char argv if QT VERSION gt QT VERSION CHECK 5 9 0 QApplication se
  • 安装NVIDIA CUDA失败最简单详细解决方法

    针对于这样的情况直接下载显卡驱动卸载工具 进入网站 点击下载 网页下拉 会出现如下图所示的内容 点击官方下载 下载软件 运行程序 下载完之后是一个压缩包的形式 解压缩 之后点击运行 如果不是最新版本会跳出如下弹出框 程序是否为最新版本没有影
  • 服务器怎么修改mac,服务器如何修改MAC地址

    服务器如何修改MAC地址 内容精选 换一换 如果要自定义裸金属服务器的DNS服务器信息 需要将裸金属服务器网络设置为静态IP 若将动态DHCP改为静态IP设置 IP和网关等网络信息必须和裸金属服务器下发时保持一致 否则可能会引起网络不通 以
  • Flask框架学习整理——从零开始入门Flask

    文章目录 Flask框架 一 简介 二 概要 三 知识点 附代码 1 Flask基础入门 1 路由route的创建 2 endpoint的作用 3 request对象的使用 4 请求钩子before after request 5 redi
  • Day6:浅谈useState

    目标 持续输出 每日分享关于web前端常见知识 面试题 性能优化 新技术等方面的内容 Day6 今日话题 谈谈react hooks中的useState 将从以下七个角度介绍 用法 参数 返回值 作用 工作原理 优缺点 注意点 用法 use
  • Map BSTMap映射代码(java)

    public interface Map
  • DDR3学习总结(一)

    简介 DDR3 SDRAM常 简称 DDR3 是当今较为常见的一种储存器 在计算机及嵌入式产品中得到广泛应用 特别是应用在涉及到大量数据交互的场合 比如电脑的内存条 对DDR3的读写操作大都借助IP核来完成 本次实验将采用 Xilinx公司
  • 报错 blocked by: [FORBIDDEN/12/index read-only / allow delete (api)];ElasticSearch 报错

    报错 Caused by org elasticsearch cluster block ClusterBlockException blocked by FORBIDDEN 12 index read only allow delete
  • Kali Linux 从入门到精通(一)-概论

    Kali Linux 从入门到精通 一 概论 欢迎关注 https github com Wheeeeeeeeels 基本介绍 1 安全目标 先于攻击者发现和防止漏洞出现 攻击型安全 防护型安全 2 渗透测试 尝试挫败安全防御机制 发现系统
  • 好像还挺好玩的GAN重制版2——Keras搭建SRGAN平台进行图片超分辨率提升

    好像还挺好玩的GAN重制版2 Keras搭建SRGAN平台进行图片超分辨率提升 学习前言 源码下载地址 网络构建 一 什么是SRGAN 二 生成网络的构建 三 判别网络的构建 训练思路 一 判别器的训练 二 生成器的训练 利用SRGAN生成