好像还挺好玩的GAN8——SRGAN实现图像的分辨率提升

2023-11-16

注意事项

该博客已经有重置版啦,重制版代码更清晰,效果更好一些:
https://blog.csdn.net/weixin_44791964/article/details/110630254

学习前言

SRGAN可以提升图像分辨率,俺很感兴趣,有必要了解一下。
在这里插入图片描述

什么是SRGAN

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。

文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感。

其中感知损失是利用卷积神经网络提取出的特征,通过比较生成图片经过卷积神经网络后的特征和目标图片经过卷积神经网络后的特征的差别,使生成图片和目标图片在语义和风格上更相似

对抗损失由GAN提供,根据图像是否可以欺骗过判别网络进行训练。

代码与训练数据的下载

这是我的github连接,代码可以在上面下载:
https://github.com/bubbliiiing/GAN-keras

这个是DIV高清图:
https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
大家也可以用其它的数据集进行训练:
人脸重建可以试试这些人脸数据集:
https://www.cnblogs.com/haiyang21/p/11208293.html

神经网络组成

1、生成网络

生成网络的构成如下图所示:
在这里插入图片描述
此图从左至右来看,我们可以知道:
SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数
2、然后经过B个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。

前两部分用于特征提取,第三部分用于提高分辨率。

def build_generator(self):

    def residual_block(layer_input, filters):
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
        d = BatchNormalization(momentum=0.8)(d)
        d = Activation('relu')(d)
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Add()([d, layer_input])
        return d

    def deconv2d(layer_input):
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
        u = Activation('relu')(u)
        return u

    img_lr = Input(shape=self.lr_shape)
    # 第一部分,低分辨率图像进入后会经过一个卷积+RELU函数
    c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
    c1 = Activation('relu')(c1)

    # 第二部分,经过16个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。
    r = residual_block(c1, 64)
    for _ in range(self.n_residual_blocks - 1):
        r = residual_block(r, 64)

    # 第三部分,上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。
    c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c2 = Add()([c2, c1])
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)
    gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

    return Model(img_lr, gen_hr)

2、判别网络

在这里插入图片描述
此图从左至右来看,我们可以知道:
SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。

def build_discriminator(self):

    def d_block(layer_input, filters, strides=1, bn=True):
        """Discriminator layer"""
        d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if bn:
            d = BatchNormalization(momentum=0.8)(d)
        return d
    # 由一堆的卷积+LeakyReLU+BatchNor构成
    d0 = Input(shape=self.hr_shape)

    d1 = d_block(d0, 64, bn=False)
    d2 = d_block(d1, 64, strides=2)
    d3 = d_block(d2, 64*2)
    d4 = d_block(d3, 64*2, strides=2)
    d5 = d_block(d4, 64*4)
    d6 = d_block(d5, 64*4, strides=2)
    d7 = d_block(d6, 64*8)
    d8 = d_block(d7, 64*8, strides=2)

    d9 = Dense(64*16)(d8)
    d10 = LeakyReLU(alpha=0.2)(d9)
    validity = Dense(1, activation='sigmoid')(d10)

    return Model(d0, validity)

训练思路

1、对判别模型进行训练

将真实的高分辨率图像和虚假的高分辨率图像传入判别模型中。
将真实的高分辨率图像的判别结果与1对比得到loss。
将虚假的高分辨率图像的判别结果与0对比得到loss。
利用得到的loss进行训练。

2、对生成模型进行训练

将低分辨率图像传入生成模型,得到高分辨率图像,利用该高分辨率图像获得判别结果与1进行对比得到loss。
将真实的高分辨率图像和虚假的高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss。

在这里插入图片描述

全部代码

1、data_loader全部代码

该部分用于对数据进行加载:

import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt

class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        
        path = glob('./datasets/%s/train/*' % (self.dataset_name))

        batch_images = np.random.choice(path, size=batch_size)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            h, w = self.img_res
            low_h, low_w = int(h / 4), int(w / 4)

            img_hr = scipy.misc.imresize(img, self.img_res)
            img_lr = scipy.misc.imresize(img, (low_h, low_w))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.

        return imgs_hr, imgs_lr


    def imread(self, path):
        return scipy.misc.imread(path, mode='RGB').astype(np.float)

2、主函数全部代码

训练代码

from __future__ import print_function, division
import scipy

from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os

import keras.backend as K

class SRGAN():
    def __init__(self):
        # 低分辨率图的shape
        self.channels = 3
        self.lr_height = 128
        self.lr_width = 128
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        # 高分辨率图的shape
        self.hr_height = self.lr_height*4
        self.hr_width = self.lr_width*4
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)
        
        # 16个残差卷积块
        self.n_residual_blocks = 16
        # 优化器
        optimizer = Adam(0.0002, 0.5)
        # 创建VGG模型,该模型用于提取特征
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        
        # 数据集
        self.dataset_name = 'DIV'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))


        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)

        # 建立判别模型
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.discriminator.summary()
        # 建立生成模型
        self.generator = self.build_generator()
        self.generator.summary()

        # 将生成模型和判别模型结合。训练生成模型的时候不训练判别模型。
        img_lr = Input(shape=self.lr_shape)

        fake_hr = self.generator(img_lr)
        fake_features = self.vgg(fake_hr)

        self.discriminator.trainable = False
        validity = self.discriminator(fake_hr)
        self.combined = Model(img_lr, [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[5e-1, 1],
                              optimizer=optimizer)


    def build_vgg(self):
        # 建立VGG模型,只使用第9层的特征
        vgg = VGG19(weights="imagenet")
        vgg.outputs = [vgg.layers[9].output]

        img = Input(shape=self.hr_shape)
        img_features = vgg(img)

        return Model(img, img_features)

    def build_generator(self):

        def residual_block(layer_input, filters):
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        img_lr = Input(shape=self.lr_shape)
        # 第一部分,低分辨率图像进入后会经过一个卷积+RELU函数
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # 第二部分,经过16个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。
        r = residual_block(c1, 64)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, 64)

        # 第三部分,上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

        return Model(img_lr, gen_hr)

    def build_discriminator(self):

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d
        # 由一堆的卷积+LeakyReLU+BatchNor构成
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, 64, bn=False)
        d2 = d_block(d1, 64, strides=2)
        d3 = d_block(d2, 128)
        d4 = d_block(d3, 128, strides=2)
        d5 = d_block(d4, 256)
        d6 = d_block(d5, 256, strides=2)
        d7 = d_block(d6, 512)
        d8 = d_block(d7, 512, strides=2)

        d9 = Dense(64*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)
    def scheduler(self,models,epoch):
        # 学习率下降
        if epoch % 20000 == 0 and epoch != 0:
            for model in models:
                lr = K.get_value(model.optimizer.lr)
                K.set_value(model.optimizer.lr, lr * 0.5)
            print("lr changed to {}".format(lr * 0.5))

    def train(self, epochs ,init_epoch=0, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()
        if init_epoch!= 0:
            self.generator.load_weights("weights/%s/gen_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)
            self.discriminator.load_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)

        for epoch in range(init_epoch,epochs):
            self.scheduler([self.combined,self.discriminator],epoch)
            # ---------------------- #
            #  训练判别网络
            # ---------------------- #
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------- #
            #  训练生成网络
            # ---------------------- #
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            valid = np.ones((batch_size,) + self.disc_patch)

            image_features = self.vgg.predict(imgs_hr)

            g_loss = self.combined.train_on_batch(imgs_lr, [valid, image_features])
            print(d_loss,g_loss)
            elapsed_time = datetime.datetime.now() - start_time
            print ("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, feature loss: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_loss[1],
                                                                            g_loss[2],
                                                                            elapsed_time))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                # 保存
                if epoch % 500 == 0 and epoch != init_epoch:
                    os.makedirs('weights/%s' % self.dataset_name, exist_ok=True)
                    self.generator.save_weights("weights/%s/gen_epoch%d.h5" % (self.dataset_name, epoch))
                    self.discriminator.save_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, epoch))

    def sample_images(self, epoch):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()

        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close()

if __name__ == '__main__':
    gan = SRGAN()
    gan.train(epochs=60000,init_epoch = 0, batch_size=1, sample_interval=50)

实现效果在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

好像还挺好玩的GAN8——SRGAN实现图像的分辨率提升 的相关文章

  • C++设计模式---组合模式

    文章目录 使用场景 组合模式的定义 安全组合模式 使用场景 组合模式和类与类之间的组合是不同的概念 组合模式主要用来处理树形结构的数据 如果要表达的数据不是树形结构 就不太适合组合模式 比如我们有一个目录结构 这个目录我们把它绘制成树形结构
  • python中sha1 md5等用法

    import hashlib sha1 hashlib sha1 sha1 update a encode utf 8 sha1 update b encode utf 8 sha1 update c encode utf 8 等同于 sh
  • Linux下more命令高级用法

    我们在 Linux 环境下工作时 每天肯定会跟各种各样的文本文件打交道 这些文件 有时候会非常长 无法在一屏的空间内显示完全 所以 在查看这种文件时 我们需要分页显示 这时 我们就可以使用 more 命令 more 命令基本用法 more
  • 上传文件至svn

    1 软件管理搜索下载tortoiseSVN 2 在D盘新建一个文件夹 重命名 最好与SVN上要上传的目录名字保持一致 自己取也可以 3 选中文件夹右键点击SVN checkout 弹出框标红的填写XXSVN上的地址 拉取远程仓库的文件 完成

随机推荐

  • 游戏贪吃蛇计分c语言,C语言实现贪吃蛇游戏(命令行)

    这是一个纯C语言写的贪吃蛇游戏 供大家参考 具体内容如下 include include include include include define SNAKE LENGTH 100 定义蛇的最大长度 define SCREEN WIDE
  • C++中的STL中map用法详解

    C 中的STL中map用法详解 Map是STL的一个关联容器 它提供一对一 其中第一个可以称为关键字 每个关键字只能在map中出现一次 第二个可能称为该关键字的值 的数据 处理能力 由于这个特性 它完成有可能在我们处理一对一数据的时候 在编
  • 我们规定对一个字符串的shift操作如下:

    shift ABCD 0 ABCD shift ABCD 1 DABC shift ABCD 2 CDAB 换言之 我们把最左侧的N个字符剪切下来 按序附加到了右侧 给定一个长为n的字符串 我们规定最多可以进行n次向左的循环shift操作
  • itextpdf、freemarker和flying-saucer-pdf实现PDF导出功能

    目录 目录 1 导入maven 2 代码结构 编辑 3 纯文本生成方式 JavaToPdfHtml template html simhei ttf 字体文件自行百度下载 4 基础上加了freemarker模板引擎 JavaToPdfHtm
  • k8s英伟达GPU插件(nvidia-device-plugin)

    安装方法 Installation Guide NVIDIA Cloud Native Technologies documentation 1 本地节点添加 NVIDIA 驱动程序 要求 NVIDIA drivers 384 81 先确保
  • Vue开发 常用方法总结

    nextTick this nextTick 将回调延迟到下次 DOM 更新循环之后执行 在修改数据之后立即使用它 然后等待 DOM 更新 使用场景 在一些情况下 变量进行了初始赋值或更新 但是DOM还未更新完成时 使用变量的值是不起作用的
  • tensorflow实践(一) 安装和调试

    人工智能 机器学习 监督学习 无监督学习 深度学习等等一大堆词语对于每个软件开发人员来说 是最近几年听的最多 也最让人觉的自己离IT未来有差距的词汇 虽然将来人工智能是否如现在预测的广泛的取代大部分行业 但是仅仅就各种才露尖尖角的各种场景如
  • [项目管理-30]:项目成员成熟度以及采取的不同的策略

    目录 前言 一 管理的误区 二 帮助部属的四种层次 2 1 四种类型 层次 2 2 四象限法差别化帮助 三 项目经理自身可以使用的权力 资源 3 1 专家权力 3 2 奖励权力 3 3 正式权力 3 4 参照性权力 3 5 惩罚性 强制性权
  • wps(word)复制过来的文字一行字数不足却自动换行解决办法

    问题描述 在wps或者word里粘贴网页上复制过来的文字常常会出现每行字数不足换行要求却提前换行的情况 如下图所示 问题分析 出现这种情况主要是从网页上复制过来的文字保留了原网站的段落标记 手动换行符 打开显示段落标记 出现向下的箭头符号就
  • 闲谈IPv6-源IP地址的选择(RFC3484读后感)

    杭州数月的连续淅淅沥沥的雨 让我感到舒适 但却不知湿了多少人的皮鞋 回想起2014年的一个周末从上海来杭州 我在思考一个关于IPv6的问题 但一切却不是因为IPv6而起 缘起 在多年以前 我被一个看似很简单的问题困扰了很久很久 问题是这样的
  • 【Zotero高效知识管理】(2)Zotero的安装、百度云存储配置及常用插件安装

    Zotero高效知识管理 专栏其他文章 Zotero文献管理软件的系统性教程 包括安装 全面的配置 基于众多插件的文献导入 管理 引用 笔记方法 Zotero高效知识管理 1 Zotero介绍 Zotero高效知识管理 3 Zotero的文
  • 王爽《汇编语言》第3版 实验4 详解 以及个人的一些小疑问

    实验四 1和2编程 向内存0 2000 23F依次传送数据063 3FH 为什么0 200和0020 0表示的是同一段内存地址 0000 X 16 0200 00200 assume cs codes codes segment mov a
  • android组件之DrawerLayout(抽屉导航)-- 侧滑菜单效果

    一 介绍 导航抽 屉显示 在屏幕的最左侧 默认情况下是隐藏的 当用户用手指从边缘向另一个滑动的时候 会出现一个隐藏的面板 当点击面板外部或者向原来的方向滑动的时候 抽屉导航就会消失了 好了 这个抽屉就是DrawerLayout 该类位于V4
  • MyBatis总结(六)--typeAliases属性介绍

    说明 typeAliases别名处理器 是为 Java 类型设置一个短的名字 可以方便我们 引用某个类 正常情况下不推荐使用该别名处理器 因为使用了别明处理器不方便直接观察到所对应的类 在项目维护起来不方便 1对单个类起别名 在mybati
  • windows10专业版使用远程桌面

    windows企业版远程桌面控制方法 服务器端 打开服务器开关 添加新用户 添加其他用户即可 3 点立即查找 添加要控制的用户 客户端 搜windows远程桌面 添加远程的ip地址即可 家庭想被控制的方法 windows10 远程桌面设置
  • Spring Cloud Sentinel(限流、熔断、降级)、SpringBoot整合Sentinel、Sentinel的使用-60

    文章目录 一 Sentinel简介 1 1 官方文档 1 2 项目地址 1 3 特征 1 4 Sentinel 分为两个部分 1 5 基本概念 1 6 主要作用 流量控制 熔断降级 系统负载保护 1 7 Hystrix 与 Sentinel
  • 机器学习——KNN实现

    一 KNN K近邻 概述 KNN一种基于距离的计算的分类和回归的方法 其主要过程为 计算训练样本和测试样本中每个样本点的距离 常见的距离度量有欧式距离 马氏距离等 对上面所有的距离值进行排序 升序 选前k个最小距离的样本 根据这k个样本的标
  • web信息收集----网站指纹识别

    文章目录 一 网站指纹 web指纹 二 CMS简介 三 指纹识别方法 3 1 在线网站识别 3 2 工具识别 3 3 手动识别 3 4 Wappalyzer插件识别 一 网站指纹 web指纹 Web指纹定义 Web指纹是一种对目标网站的识别
  • Stata改变变量label

    我们用dta格式数据时 label栏可能是无法识别的字符 其中一个原因是我们电脑安装的是简体中文版 但数据原来的label是繁体字 只要用 label var命令就可以更改了 具体用法 label var 变量名称 变量新label 如下所
  • 好像还挺好玩的GAN8——SRGAN实现图像的分辨率提升

    好像还挺好玩的GAN8 SRGAN实现图像的分辨率提升 注意事项 学习前言 什么是SRGAN 代码与训练数据的下载 神经网络组成 1 生成网络 2 判别网络 训练思路 1 对判别模型进行训练 2 对生成模型进行训练 全部代码 1 data