【CV with Pytorch】第 8 章 :图像超分辨率

2023-10-27

随着高分辨率图像捕获代理的出现,图像中捕获的信息是巨大的。技术已经从超高清转向 4K 和 8K 分辨率。如今,电影正在使用高分辨率帧;但是,在某些情况下,他们需要将低分辨率图像增强为高分辨率图像。想象这样一个场景,电影的主角正试图确定从一张超速行驶的汽车的照片中捕捉到的车牌。超分辨率现在可以帮助我们在不扭曲图像的情况下高度放大图像。该行业发生了一些有趣的进步,我们将通过一些例子来讨论这些进步。

图像中的现有信息不能从最初存在的任何信息中增加。在计算机科学中,我们有“垃圾输入,垃圾输出”,这是一个类似的概念。我们不能指望找到图像中不存在的东西。因此,在某种程度上,超分辨率似乎很牵强,而且受到信息论的限制。即便如此,积极的研究表明这个问题是可以解决的。

让我们深入研究手头的问题。到目前为止,我们已经处理了一种有监督的学习形式,其中总是有一个与基本事实相关的损失函数。该模型从定义的输入 (X) 和预期输出 (Y) 中学习。训练模型的全部本质是帮助将输入映射到输出。但这在无监督学习中不会发生。无监督方式帮助模型在没有映射输出的情况下学习输入数据中的模式。该模型学习数据中的模式并围绕它构建权重,然后识别数据中的异同。与监督学习方法不同,无监督学习没有纠正措施。ground truth 方面缺失了,但优化的概念仍然存在。

让我们深入探讨判别模型和生成模型的概念。在生成模型中,学习输入和输出的联合概率。数据的分布是学习的,通常是训练模型的更通用的方法。这些模型能够在输入空间中生成合成数据点。另一方面,判别模型专门创建从输入空间到输出的映射函数。生成模型的例子有线性判别分析、朴素贝叶斯和高斯模型。

我们为什么要介绍生成模型并讨论学习数据分布的思想?我们回过头来看逻辑,它可以帮助我们实现超分辨率。

  • 使用最近邻的概念放大图像

  • 双线性插值/双三次插值

  • 傅里叶变换

  • 神经网络

我们详细探讨了所有这些可能的方法。但在此之前,让我们探索用于放大低分辨率图像的基本技术,从最近邻缩放开始。图8-1a显示了一个基本图像,可以将其调整为更大的图像(参见图8-1b),但请记住图像中的信息保持不变。只是表示形式发生了变化。

图 8-1a 一张 3x3 的图片

图 8-1b 一个 3x3 的图像扩展到6x6

使用最近邻概念进行放大

需要更快解决方案更改的问题也需要更快的操作。我们知道使用卷积神经网络或任何接近神经网络的东西都需要大量计算,所以是时候使用一些简单的技术了。如果我们需要更快的技术,使用最近邻概念放大图像是最有竞争力的方法之一。

8-2a显示需要放大到 8x8 图像的 4x4 图像,如图8-2b所示。我们最初在图像上有 16 个像素,然后当它被拉伸到 64 个像素时,我们剩下 48 个需要填充的空缺。最近邻的概念可以从直线单元来理解。考虑一条直线数字线,从 0 开始到 4 结束。如果我们将它分成四个相等的部分或在本例中为像素,每个部分将获得 25% 的信息。现在,如果将同一条线拉长为 8,则单元的长度保持不变,但每个单元的权重变为 12.5%。但是,图像中携带的信息是相同的。

图 8-2a 示例图片

图 8-2b 放大的示例图像

可以使用以下公式将相同的概念用于空缺的最近邻居实现:

该公式为我们提供了放大图像像素的坐标值。

了解双线性放大

要了解双线性图像放大的概念,我们仍然必须介绍线性插值。插值建议放大的一维扩展。考虑这样一种情况:一条直线在两端用两个值(x 1和 x 2)标记,并且被赋予相同的值。如果我们必须插入位于两端之间的第三个值,我们如何进行?

该算法建议我们可以使用加权平均值的概念来对未知值进行插值。可以根据点 x 1和 x 2之间的比例距离找到权重。当我们在二维中调整图像大小时可以使用此逻辑。

要获得(宽度,高度)维度中一个坐标的值,我们可以计算出每个维度中的线性插值。这将本质上有助于图像在二维中的大小调整。

继续讨论图像放大中最期待的概念,即神经网络。放大的方法可能看起来非常粗鲁和残酷,没有任何技巧。使用一些久经考验的公式来产生价值的逻辑已经存在,并且被一遍又一遍地使用。在某些情况下,重复可以产生奇迹,但在此期间不会发生任何变化或学习。因此,我们进入神经网络的学习部分。在我们编写代码和使用模型架构之前,我们需要讨论基础模块——VAE 和GAN。

变分自动编码器

深度学习领域最具革命性的改进之一是编码器-解码器架构。神经网络可以检索图像中存在的信息,并根据它们的理解重新创建图像。自动编码器架构是一种神经网络架构,可以学习数据中的模式并将其减少到更小的维度。这些尺寸再次可用于将图像重新创建回原始图像。重要的是要注意,从理论上讲,我们可以创建一个无损架构,完全重新创建图像。在现实中,这样的例子并不多见。

神经网络表示架构如图8-3所示,它演示了图像被互连层理解并转换为嵌入的情况。这种嵌入是根据已建立的模型架构来表示来自图像的信息。

图 8-3 编码器-解码器架构

网络的初始部分,通常称为编码器,学习输入图像中的数据分布或模式。它不仅要了解自己,配对解码器架构也需要解码嵌入。因此,特征提取和理解需要使解码器能够以非常小的损失从嵌入中破译有问题的原始图像。

另一方面,解码器端在嵌入层之后立即开始,并尝试以最小的信息损失将嵌入层转换为原始图像。这种信息压缩然后通过神经网络重新生成图像是自动编码器网络的概念。

传输信息时,传输带宽会影响图像的分辨率。压缩有助于将低分辨率图像传递到目的地。一旦图像到达目的地,解码器层就会开始行动,放大并恢复原始图像。

压缩和解压缩的概念可以进一步实现为放大图像。现在我们已经探索了编码器-解码器架构的基础知识,接下来将转向一个有趣的概念,称为变分编码器.

正如我们目前所见,传统的自动编码器架构使用来自输入的表征信息创建了一个潜在空间,以便解码器网络可以生成输出。但是想象一下,单个属性对离散值有贡献,并且在重新创建时将仅限于一个值。这个限制不会帮助模型从分布中生成新的东西,只会重复。如果我们想将潜在空间中的表示作为分布而不是离散值怎么办?我们可以做到,但是现在会出现两个不同的特征:

  • 随机过程

  • 确定性过程

它已经通过多个深度学习概念建立,无论何时网络接受训练,它都需要有一组可以学习和适应损失的过程。在给定模型参数的情况下,将有一个前向传播来计算损失(实际与输出之间的差异)。还会有反向传播,会根据预期输出和ground truth的差异产生的损失改变权重。因此,基本上我们只能在确定性的情况下训练网络。变分自编码器如图8-4所示.

图 8-4 VAE代表

从图8-4所示的代表性图像中,我们可以看到变分自动编码器如何尝试将数据分布映射到潜在空间。由 φ 参数化的编码器网络使用训练数据来学习从训练数据或 X 空间到潜在或 Z 空间的随机映射。

编码器或推理模型学习数据中的模式。可以证明,X空间的经验分布很复杂,而潜在空间很简单。由 θ 参数化的生成网络学习由 P(X|Z) 给出的分布。解码器部分从先验分布(通常是正态高斯分布)和确定性过程中学习。为了提出与自动编码器的区别,添加了一个额外的随机过程。图8-5显示了变分自动编码器网络的表示。它显示了添加到现有自动编码器架构中的随机性。

图 8-5 VAE 网络表示

因此,虽然早些时候我们更专注于寻找潜在空间的向量或离散值的嵌入,但现在我们将寻找均值和标准差的向量空间。

潜在分布为我们提供了过程的随机性。最终,我们将不得不反向传播来训练模型。为了克服这个训练问题,我们将均值视为固定向量。为了保持随机性并保持模型中注入的先验分布,我们将标准偏差视为受高斯先验分布的随机常数影响的固定向量。这个采样过程并不像看起来那么简单,因为我们的损失函数将是重建损失和另一个正则化损失。我们使用重新参数化技巧,其中 € 从先验正态高斯分布中采样,并按潜在分布的均值移动,然后按标准差缩放。公式将是:

Z = 平均值 + 标准差 * € ----- (i)

从标准随机节点,我们得到这个等式:

Z = Q(Z|X) 由 φ 参数化 ---------- (ii)

我们也可以图形化地可视化这个技巧,以便清除重新参数化的概念并将学习路径中的随机过程转换为确定性节点。

8-6a显示了由反向传播或本质上学习潜在空间的模型引起的问题。图8-6b显示了重新参数化的过程,其中反向传播可以通过实箭头线的通道进行。虚线箭头所示为随机过程,不妨碍训练过程,不直接参与反向传播。它没有学习任何东西,也没有根据损失函数调整权重。值得注意的是,过程类型如何随着随机过程从反向传播路径的偏移而在 Z 空间中发生变化。

图 8-6a VAE问题

图 8-6b 重新参数化技巧

公式 (i) 可以被认为是对图8-6b的粗略估计,而公式 (ii) 是对图8-6a的估计。

因此,我们已经建立了变分自动编码器的概念,它可以有多种用途。这个小的随机过程可以帮助生成从相同概率分布中提取的相似图像。有助于图像再生或图像生成,并且总是需要这两种类型。采样过程使生成器模型或解码器模型能够重新创建具有细微变化的相同分布的图像。在某些情况下,它有助于插入信号或图像。此插值概念可用于调整图像大小。现在我们简要介绍了变分自动编码器,在深入研究图像大小调整代码之前,让我们看看另一种生成形式的算法,称为生成对抗网络。

生成对抗网络

2014 年,Ian Goodfellow 将生成对抗网络引入深度学习。该网络能够创建与原始样本非常接近的更新样本。它们还广泛用于图像中的风格转换。

该网络是两个模型的组合——生成器模型和鉴别器模型。结合起来,这些模型形成了一种监督学习形式。

  • 生成器:该模型尝试根据域或问题集生成样本。这些最好是来自固定分布的样本。生成器接受随机输入(在大多数情况下,高斯分布用于帮助它处理输入)。在训练过程中,这些随机或无意义的点将被视为来自域分布。生成器应该能够从输入数据分布中生成表示。正如我们之前看到的,数据分布很复杂,编码器试图映射到一个更简单但高度压缩的信息块。这个空间通常被称为潜在空间,自动编码器的生成器块从中生成输出。这些模型可以理解数据分布的复杂性并创建一个表示,从中获取样本。这可以而且应该能够欺骗鉴别器或分类器。

  • 鉴别器:一旦生成器模型创建了它认为与原始数据分布非常相似的假样本,它们就会被传递给鉴别器模型进行验证和分类。它本质上是一个分类模型。它的工作是对生成器生成的图像进行分类,是假的还是真的。分类器区分真假图像。

我们已经确定必须同时训练生成器和鉴别器。这被称为生成对抗网络,因为生成模型和判别模型相互对抗。他们试图在一场零和博弈中让彼此变得更好。理论上,一个人不会打败另一个人。我们让生成器网络尝试尽可能逼真地创建假图像,以便鉴别器无法将其识别为假图像。另一方面,鉴别器模型正在尝试进行训练,以便图像中的任何错误都会被它捕获。在一个完美的世界中,生成器最终会生成鉴别器无法识别为假或真(50% 假/真)的图像。最终,生成器从网络中移除并用于其他目的。

模型代码

我们已经讨论了生成对抗网络背后的基本概念。这导致了它的众多用途之一,即超分辨率。它有各种应用,其中样式转换、图像生成和超分辨率很少。处理超分辨率的模型是SRGAN。它的前身之一,称为 SRResNet,在 SSIM 和 PSNR 方面取得了不错的结果。

让我们看看通常在超分辨率问题中确定的指标:

  • 结构相似性指数 (SSIM):该指标试图为由于从源图像到目标图像的变化而发生的退化量赋值。它检查图像各部分之间的感知相似性。它基于所选窗口的平均值和标准偏差。

  • 峰值信噪比:这是另一个重要的指标,用于衡量图像的重建损失或与原始图像的变化图像。它充其量可以通过均方误差计算来定义。它可以通过取以 10 为底的对数刻度来形成。

  • 平均意见得分 (MOS):这是由单个数字定义的,范围为 1 到 5。1 是最低的感知质量,5 是最高的感知质量。

现在我们已经了解了定义和衡量差异的指标,让我们看看我们将要开发代码的数据。

我们将使用 DIV2K 数据集,它有 1000 张高清图像,在训练、验证和测试数据方面按 800-100-100 划分。该数据可以从 CVPR 2017 中介绍的源论文中下载,网址为https://data.vision.ee.ethz.ch/cvl/DIV2K/

代码的设置需要遵循应用程序的标准构建。通常这意味着需要有一个模型文件、一些实用程序脚本、一个训练文件和一个验证文件。在少数情况下,此模型需要是托管在服务器中的应用程序,并且还需要有一个安装文件。一步一个脚印,我们可以从模型文件开始。

模型开发

该代码库具有生成模型块、判别模型块、残差块和内容损失计算块。

Imports

我们将为整个代码块使用Torch 框架。如果开发是在本地环境中完成的,那么我们必须确保 Torch 及其依赖项已安装在环境中并且可以正常工作。Torch 和 TorchVision 是需要设置的两个重要包。如果提供带有 CUDA 内核的 GPU,我们应该安装最新的 CUDA 包以帮助 PyTorch 利用并行 GPU 内核进行计算。对于模型脚本,我们导入了 Torch 和 TorchVision 相关的函数。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

接下来,我们定义生成器类来帮助重新生成图像。

class Generator(nn.Module):
    ## 定义生成器模型
    ## 扩展类
    ## 初始化顺序 - 网络期望 64x3
    def __init__(self) -> None:
        super(Generator, self).__init__()
        self.convolutional_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU()
        )
        ## 添加 16 个 resnet 转换块
        res_trunk = []
        for _ in range(16):
            res_trunk.append(ResidualConvBlock(64))
        self.res_trunk = nn.Sequential(*res_trunk)
        self.convolutional_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64)
        )
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.convolutional_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))
        self._initialize_weights()
    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
    def _forward_impl(self, x: Tensor) -> Tensor:
        ## 定义前向传播 -> 3 个卷积块
        out1 = self.convolutional_block1(x)
        out = self.res_trunk(out1)
        out2 = self.convolutional_block2(out)
        output = out1 + out2
        output = self.upsampling(output)
        output = self.convolutional_block3(output)
        return output
    def _initialize_weights(self) -> None:
        ## 初始化权重
        ## 添加批量标准化的规定
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                m.weight.data *= 0.1
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                m.weight.data *= 0.1

此代码片段定义了能够重新生成图像的卷积块的类。重要的是,这个代码块由三个卷积块和一个上采样块组成。第一个卷积块之后是一个残差块,它作为整个生成器网络的主干。接下来是第二个卷积块。放大块由一对卷积层和像素混洗组成。最终,添加最终的卷积块以生成输出。该块提供批归一化层和 3x3 卷积层的组合。

正向传递有助于在函数正向实现中构建顺序模型。还有另一个函数来初始化权重。介绍完基本的生成器类后,我们将转到鉴别器的下一个类。

鉴别器模块在八层卷积之后扩展了标准的nn.module 。他们在每一层之后使用批量归一化来深入运行。模型结构使用 leaky ReLU作为激活函数。该模型以torch.flatten层结束,这有助于它对事物进行分类。

class Discriminator(nn.Module):
    ## 定义鉴别器
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
    def forward(self, x: Tensor) -> Tensor:
        ## 定义正向传播
        output = self.features(x)
        output = torch.flatten(output, 1)
        output = self.classifier(output)
        return output
The model establishes the discriminator class in the architecture. Let’s look at the ContentLoss class.
class ContentLoss(nn.Module):
    ## 定义内容丢失类
    ## 特征提取器 - 直到 36
    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
    ## 使用预训练的 VGG 模型提取特征
        vgg19_model = models.vgg19(pretrained=True, num_classes=1000).eval()
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:36])
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        hr = (hr - self.mean) / self.std
        sr = (sr - self.mean) / self.std
        mse_loss = F.mse_loss(self.feature_extractor(sr), self.feature_extractor(hr))
        return mse_loss

该课程使用预训练的 VGG 网络来提取特征以计算内容损失。在此之后,我们看一下残差卷积块。

class ResidualConvBlock(nn.Module):
    ## 获取残差块
    def __init__(self, channels: int) -> None:
        super(ResidualConvBlock, self).__init__()
        self.rc_block = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels)
        )
    def forward(self, x: Tensor) -> Tensor:
        identity = x
        output = self.rc_block(x)
        output = output + identity
        return output

模型脚本到此结束。在此之后,我们将查看一些辅助函数,从创建数据集开始。

def main():
    r""" 训练和测试 """
    image_list = os.listdir(os.path.join("train", "input"))
    test_img_list = random.sample(image_list,
                                     int(len(image_list) / 10))
    ## 遍历测试文件
    for test_img_file in test_img_list:
        filename = os.path.join("train", "input", test_img_file)
        logger.info(f"Process: `{filename}`.")
        shutil.move(os.path.join("train", "input", test_img_file),
                    os.path.join("test", "input", test_img_file))
        shutil.move(os.path.join("train", "target", test_img_file),
                    os.path.join("test", "target", test_img_file))

该功能有助于定义训练测试分离并定位它以供训练工作进行。在线的另一个重要功能是裁剪功能,我们接下来可以检查它。它有助于返回裁剪后的图像。

def crop_image(img, crop_sizes: int):
    assert img.size[0] == img.size[1]
    crop_num = img.size[0] // crop_sizes
    box_list = []
    for width_index in range(0, crop_num):
        for height_index in range(0, crop_num):
            box_info = ( (height_index + 0)*crop_sizes,(width_index + 0) * crop_sizes,
                   (height_index + 1)*crop_sizes,(width_index + 1) * crop_sizes)
            box_list.append(box_info)
    cropped_images = [img.crop(box_info) for box_info in box_list]
    return cropped_images

下一个要处理的重要功能之一是数据集类。数据集类根据配置和可用性向训练函数提供批量信息。

class BaseDataset(Dataset):
    ## 基础数据集类从 pytorch 扩展数据集类
    ## 应用随机裁剪、旋转等增强技术
    ## 水平翻转和张量
    ## 调整大小和中心裁剪也被使用
    ## 最终转换为张量
    def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(BaseDataset, self).__init__()
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
        lr_img_size = (image_size // upscale_factor, image_size // upscale_factor)
        hr_img_size = (image_size, image_size)
        if mode == "train":
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(hr_img_size),
                transforms.RandomRotation(90),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor()
            ])
        else:
            self.hr_transforms = transforms.Compose([
                transforms.CenterCrop(hr_img_size),
                transforms.ToTensor()
            ])
        self.lr_transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(lr_img_size, interpolation=IMode.BICUBIC),
            transforms.ToTensor()
        ])
    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        hr = Image.open(self.filenames[index])
        temp_lr = self.lr_transforms(hr)
        temp_hr = self.hr_transforms(hr)

数据集基类提供随机裁剪、中心裁剪、随机旋转、水平翻转、调整大小等增强功能。最终,它将数据转换为 PyTorch 框架的张量。它还具有长度和获取项目功能。

在开发所需的所有重要功能之后,我们归结为训练序列。训练序列训练生成器。代码块如下:

def train_generator(train_dataloader, epochs) -> None:
    ## 从训练生成器开始
    ## 定义数据加载器
    ## 定义损失函数
    batch_count = len(train_dataloader)
    ## 开始训练生成器块
    generator.train()
    for index, (lr, hr) in enumerate(train_dataloader):
        ## 获取 hr 到 cuda 或 cpu
        hr = hr.to(device)
        ## 获取 lr 到 cuda 或 cpu
        lr = lr.to(device)
        ## 将生成器初始化为零梯度以避免梯度累积
        ## 仅在时基模型的情况下建议累加
        generator.zero_grad()
        sr = generator(lr)
        ## 定义像素损失
        pixel_losses = pixel_criterion(sr, hr)
        ## 从优化器获取阶跃函数
        pixel_losses.backward()
        ## 生成器的 adam 优化器
        p_optimizer.step()
        iteration = index + epochs * batch_count + 1
        writer.add_scalar(" computing train generator Loss", pixel_losses.item(), iteration)

同样,对抗块的训练如下。

def train_adversarial(train_dataloader, epoch) -> None:
    ## 用于训练对抗网络
    batches = len(train_dataloader)
    ## 训练判别器和生成器
    discriminator.train()
    generator.train()
    for index, (lr, hr) in enumerate(train_dataloader):
        hr = hr.to(device)
        lr = lr.to(device)
        label_size = lr.size(0)
        fake_label = torch.full([label_size, 1], 0.0, dtype=lr.dtype, device=device)
        real_label = torch.full([label_size, 1], 1.0, dtype=lr.dtype, device=device)
        ## 初始化零梯度因为我们想避免梯度累积
        discriminator.zero_grad()
        output_dis = discriminator(hr)
        dis_loss_hr = adversarial_criterion(output_dis, real_label)
        dis_loss_hr.backward()
        dis_hr = output_dis.mean().item()
        sr = generator(lr)
        output_dis = discriminator(sr.detach())
        dis_loss_sr = adversarial_criterion(output_dis, fake_label)
        dis_loss_sr.backward()
        dis_sr1 = output_dis.mean().item()
        dis_loss = dis_loss_hr + dis_loss_sr
        d_optimizer.step()
        generator.zero_grad()
        output = discriminator(sr)
        pixel_loss = pixel_weight * pixel_criterion(sr, hr.detach())
        perceptual_loss = content_weight * content_criterion(sr, hr.detach())
        adversarial_loss = adversarial_weight * adversarial_criterion(output, real_label)
        gen_loss = pixel_loss + perceptual_loss + adversarial_loss
        gen_loss.backward()
        g_optimizer.step()
        dis_sr2 = output.mean().item()
        iteration = index + epoch * batches + 1
        writer.add_scalar("Train_Adversarial/D_Loss", dis_loss.item(), iteration)
        writer.add_scalar("Train_Adversarial/G_Loss", gen_loss.item(), iteration)
        writer.add_scalar("Train_Adversarial/D_HR", dis_hr, iteration)
        writer.add_scalar("Train_Adversarial/D_SR1", dis_sr1, iteration)
        writer.add_scalar("Train_Adversarial/D_SR2", dis_sr2, iteration)

最终,我们将研究一个验证块,将生成器和对抗网络放在一起。

以下代码将所有内容放在主函数中并运行整个训练序列:

如果恢复:

## 用于恢复训练

如果 resume_p_weight != "":

generator.load_state_dict(torch.load(resume_p_weight))

别的:

鉴别器.load_state_dict(torch.load(resume_d_weight))

generator.load_state_dict(torch.load(resume_g_weight))

best_psnr_val = 0.0

对于范围内的纪元(start_p_epoch,p_epochs):

train_generator(train_dataloader,纪元)

psnr_val = validate(valid_dataloader, epoch, "generator")

best_condition = psnr_val > best_psnr_val

best_psnr_val = max(psnr_val, best_psnr_val)

torch.save(generator.state_dict(), os.path.join(exp_dir1, f"p_epoch{epoch + 1}.pth"))

如果 best_condition:

torch.save(generator.state_dict(), os.path.join(exp_dir2, "p-best.pth"))

## 保存最佳模型

torch.save(generator.state_dict(), os.path.join(exp_dir2, "p-last.pth"))

best_psnr_val = 0.0

generator.load_state_dict(torch.load(os.path.join(exp_dir2, "p-best.pth")))

对于范围内的纪元(start_epoch,纪元):

train_adversarial(train_dataloader,纪元)

psnr_val = validate(valid_dataloader, epoch, "adversarial")

best_condition = psnr_val > best_psnr_val

best_psnr_val = max(psnr_val, best_psnr_val)

torch.save(discriminator.state_dict(), os.path.join(exp_dir1, f"d_epoch{epoch + 1}.pth"))

torch.save(generator.state_dict(), os.path.join(exp_dir1, f"g_epoch{epoch + 1}.pth"))

如果 best_condition:

torch.save(discriminator.state_dict(), os.path.join(exp_dir2, "d-best.pth"))

torch.save(generator.state_dict(), os.path.join(exp_dir2, "g-best.pth"))

d_scheduler.step()

g_scheduler.step()

torch.save(discriminator.state_dict(), os.path.join(exp_dir2, "d-last.pth"))

torch.save(generator.state_dict(), os.path.join(exp_dir2, "g-last.pth"))

有了这个,我们总结了代码并可以看看如何运行它。代码块应如图8-7所示完成后,我们可以转到下一部分,其中包括运行应用程序。

图 8-7

代码开发模板

运行应用程序

要运行该应用程序,我们需要首先将数据集下载到正确的目录,或者通过配置脚本将数据目录映射到训练函数。配置脚本很重要,因为它将所有脚本和位置绑定在一起。它帮助应用程序了解需要什么。

要下载数据,我们可以使用 bash 访问下载脚本。

!bash ./data/download_dataset.sh

安装后,我们只需运行训练脚本。

!蟒蛇火车.py

一旦生成器训练完成,对抗训练就会开始。我们可以快速浏览一下时代的样子。

训练纪元[0016/0020](00010/00050) 损失:0.008974。

训练纪元[0016/0020](00020/00050) 损失:0.009684。

训练纪元[0016/0020](00030/00050) 损失:0.004455。

训练纪元[0016/0020](00040/00050) 损失:0.008851。

训练纪元[0016/0020](00050/00050) 损失:0.008883。

有效阶段:生成器 Epoch[0016] 平均 PSNR:21.19。

训练纪元[0017/0020](00010/00050) 损失:0.005397。

训练纪元[0017/0020](00020/00050) 损失:0.006351。

训练纪元[0017/0020](00030/00050) 损失:0.007704。

训练纪元[0017/0020](00040/00050) 损失:0.007926。

训练纪元[0017/0020](00050/00050) 损失:0.005559。

有效阶段:生成器 Epoch[0017] 平均 PSNR:21.37。

训练纪元[0018/0020](00010/00050) 损失:0.006054。

训练纪元[0018/0020](00020/00050) 损失:0.008028。

训练纪元[0018/0020](00030/00050) 损失:0.006164。

训练纪元[0018/0020](00040/00050) 损失:0.006737。

训练纪元[0018/0020](00050/00050) 损失:0.007716。

有效阶段:生成器 Epoch[0018] 平均 PSNR:21.36。

训练纪元[0019/0020](00010/00050) 损失:0.009527。

训练纪元[0019/0020](00020/00050) 损失:0.004672。

训练纪元[0019/0020](00030/00050) 损失:0.004574。

训练纪元[0019/0020](00040/00050) 损失:0.005196。

训练纪元[0019/0020](00050/00050) 损失:0.007712。

有效阶段:生成器 Epoch[0019] 平均 PSNR:21.64。

训练纪元[0020/0020](00010/00050) 损失:0.006843。

训练纪元[0020/0020](00020/00050) 损失:0.007701。

训练纪元[0020/0020](00030/00050) 损失:0.005366。

训练纪元[0020/0020](00040/00050) 损失:0.004797。

训练纪元[0020/0020](00050/00050) 损失:0.008607。

有效阶段:生成器 Epoch[0020] 平均 PSNR:21.53。

训练阶段:adversarial Epoch[0001/0005](00010/00050) D Loss:0.051520 G Loss:0.574723 D(HR):0.970196 D(SR1)/D(SR2):0.019971/0.003046.

训练阶段:adversarial Epoch[0001/0005](00020/00050) D Loss:0.001356 G Loss:0.528222 D(HR):0.998656 D(SR1)/D(SR2):0.000007/0.000005。

训练阶段:adversarial Epoch[0001/0005](00030/00050) D Loss:0.004768 G Loss:0.574079 D(HR):0.999959 D(SR1)/D(SR2):0.004646/0.000619。

训练阶段:adversarial Epoch[0001/0005](00040/00050) D Loss:0.000339 G Loss:0.557449 D(HR):0.999820 D(SR1)/D(SR2):0.000159/0.000527。

训练阶段:adversarial Epoch[0001/0005](00050/00050) D Loss:0.009615 G Loss:0.531170 D(HR):0.990858 D(SR1)/D(SR2):0.000000/0.000000。

有效阶段:对抗性 Epoch[0001] 平均 PSNR:11.47。

训练阶段:adversarial Epoch[0002/0005](00010/00050) D Loss:0.000002 G Loss:0.488294 D(HR):0.999998 D(SR1)/D(SR2):0.000000/0.000000。

训练阶段:adversarial Epoch[0002/0005](00020/00050) D Loss:0.114398 G Loss:0.568630 D(HR):0.947419 D(SR1)/D(SR2):0.000000/0.000000。

训练阶段:adversarial Epoch[0002/0005](00030/00050) D Loss:3.704494 G Loss:0.580344 D(HR):0.230086 D(SR1)/D(SR2):0.000000/0.000000。

训练阶段:adversarial Epoch[0002/0005](00040/00050) D Loss:0.000804 G Loss:0.557581 D(HR):0.999662 D(SR1)/D(SR2):0.000464/0.000324.

训练阶段:adversarial Epoch[0002/0005](00050/00050) D Loss:0.001132 G Loss:0.459117 D(HR):0.999191 D(SR1)/D(SR2):0.000317/0.000301。

有效阶段:对抗性 Epoch[0002] 平均 PSNR:12.48。

训练阶段:adversarial Epoch[0003/0005](00010/00050) D Loss:0.000187 G Loss:0.488436 D(HR):0.999847 D(SR1)/D(SR2):0.000033/0.000032。

训练阶段:adversarial Epoch[0003/0005](00020/00050) D Loss:0.001537 G Loss:0.444651 D(HR):0.999899 D(SR1)/D(SR2):0.001425/0.001385。

训练阶段:adversarial Epoch[0003/0005](00030/00050) D Loss:0.000169 G Loss:0.493448 D(HR):0.999877 D(SR1)/D(SR2):0.000046/0.000041。

训练阶段:adversarial Epoch[0003/0005](00040/00050) D Loss:0.000285 G Loss:0.465992 D(HR):0.999925 D(SR1)/D(SR2):0.000210/0.000202。

训练阶段:adversarial Epoch[0003/0005](00050/00050) D Loss:0.000720 G Loss:0.567912 D(HR):0.999978 D(SR1)/D(SR2):0.000695/0.000668。

有效阶段:对抗性 Epoch[0003] 平均 PSNR:13.09。

训练阶段:adversarial Epoch[0004/0005](00010/00050) D Loss:0.000293 G Loss:0.479247 D(HR):0.999786 D(SR1)/D(SR2):0.000079/0.000076。

训练阶段:adversarial Epoch[0004/0005](00020/00050) D Loss:0.000064 G Loss:0.492225 D(HR):0.999978 D(SR1)/D(SR2):0.000042/0.000041。

训练阶段:adversarial Epoch[0004/0005](00030/00050) D Loss:0.000030 G Loss:0.444387 D(HR):0.999984 D(SR1)/D(SR2):0.000014/0.000014。

训练阶段:adversarial Epoch[0004/0005](00040/00050) D Loss:0.000108 G Loss:0.387137 D(HR):0.999918 D(SR1)/D(SR2):0.000025/0.000025。

训练阶段:对抗性 Epoch[0004/0005](00050/00050) D 损失:0.000224 G 损失:0.513328 D(HR):0.999825 D(SR1)/D(SR2):0.000049/0.000048。

有效阶段:对抗性 Epoch[0004] 平均 PSNR:13.29。

在这个训练集上,我们使用可配置的时期和其他训练参数,这些都在配置文件中可用。一旦模型准备好下载,我们就可以用它来将图像放大四倍。我们可以在训练时配置放大因子。这样,我们的培训过程就结束了。

概括

本章从与放大图像相关的问题开始,讨论了如何进行放大。我们讨论了当前可用的各种方法和建模技术的优点。讨论并实施了 SRGAN 等最先进的算法。我们还经历了建立项目的培训过程。本章讨论了我们如何使用卷积模型结合生成模型通过某种因素来放大图像。超分辨率是一个不断发展的领域并且被广泛使用,例如从交通摄像头检测车牌或增强旧照片。这是计算机视觉中一个非常重要的领域,并进行了多年的研究。

在接下来的章节中,我们将从静止图像的概念转向移动图像,也称为视频。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【CV with Pytorch】第 8 章 :图像超分辨率 的相关文章

随机推荐

  • 我在windows10下,使用msys64 mingw64终端

    系列文章目录 文章目录 系列文章目录 前言 一 MSYS2是什么 前言 msys2官网 MSYS2 Minimal SYStem 2 是一个MSYS的独立改写版本 主要用于 shell 命令行开发环境 同时它也是一个在Cygwin POSI
  • JavaDay06

    用户登录 提示用户输入用户名和密码 如果用户名和密码不是 admin 和 123 的话 就提示用户继续输入 最多输入五次 用户登录 提示用户输入用户名和密码 如果用户名和密码不是 admin 和 123 的话 就提示用户继续输入 最多输入五
  • 数据结构4-单链表的删除修改和查找

    1 单链表按照顺序插入节点 package com yin m3LinkedList public class SingleLinkedListDemo public static void main String args TODO Au
  • C语言(关于浮点数比较的学习)

    由于浮点数十进制转化成二进制的机制 会造成精度损失 因此在浮点数的比较中 无法直接令两个浮点数是否相等来判断两个浮点数 如 include
  • 上传报org.apache.tomcat.util.http.fileupload.impl.FileSizeLimitExceededException: The field file exceed

    错误如下 springBoot项目自带的tomcat对上传的文件大小有默认的限制 SpringBoot官方文档中展示 每个文件的配置最大为1Mb 单次请求的文件的总数不能大于10Mb 解决方法 Spring Boot 2 5 6 版本 在
  • 紫禁繁花服务器维护,各种坑的坑。

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 最开始玩的小主 建议开菜坑 会拉开一部分势力 前期略微有明显 比如你冲宫斗去 冲榜去来它是首选 特别是国力的234榜 攻略摘要 势力增加快 复仇积分多 宫斗提升 雨露增加快 办差收货多 势力提升
  • C++使用dll的一些探索

    一 动态链接库的加载方式 隐式加载又称载入时加载 指在主程序载入内存时搜索DLL 并将DLL载入内存 使用隐式加载时 使用者需要DLL链接库的 h文件 lib文件和 dll文件 lib文件包含DLL导出的函数声明和变量的符号名 dll 文件
  • 旧电脑改造nas黑群晖_黑群晖教程:旧电脑不吃灰,手把手教你变成千元顶级NAS...

    前言 如果有玩PCDIY 玩摄影 玩PT 那么一定有听说过NAS NAS中群晖的NAS又是使用体验最佳的 群晖NAS系统在功能上十分齐全 人机界面做的也较为出众 但可惜的是 机器本体价格相对来说高昂 很多人在看到售价后只能摇摇头作罢 黑群晖
  • altium designer执行DRC检查+消除绿色错误

    由原理图生成PCB以后 各种显示绿色 也即PCB报错 如下图 绿色的原因是DRC Design Rule Check 检查未通过 解决方法是正确设置规则 但是在此之前 为了观感 我们先掩耳盗铃一下 临时清除绿色 步骤是 菜单栏 gt 工具
  • 线性代数-向量,矩阵,线性变换

    一 向量 向量要求具有两个条件 长度 大小 方向 二维 三维 计算机中 向量可看做列表 图中第一个列表有两行 我们说它是二维向量 第二个列表有四行 我们说他是四维向量 向量的运算 向量加法 向量加法 将对应的行相加 将向量w的起点平移到向量
  • Java基础(2)面向对象的理解

    面向对象学习 面向对象与面向过程的区别 面向过程思想适合简单 不需要协作的任务 面向对象需要很多协作才能完成 面向对象就应运而生了 object 对象 instance 实例 都是解决问题的思维模式 都是代码组织的方式 解决简单问题可以使用
  • 通过文件夹文件获取文件夹大小

    思路就是便利文件夹下的每个文件 碰到子文件夹递归进去继续找文件 所有的文件大小累加起来 int GetFolderSize LPCTSTR szPath TCHAR szFileFilter 512 TCHAR szFilePath 512
  • mysql的sql语句没错但是报错,sql语句可以正常执行,但是报错:【merge sql error, dbType mysql, sql :】...

    错误信息如下 2017 09 06 19 03 41 186 ERROR method com alibaba druid filter stat StatFilter mergeSql StatFilter java 147 merge
  • 基于有道API的命令行词典(golang版)

    Godict 本项目地址 近期一直再使用golang语言开发一些工具 相关的后端技术链 golang orm postgresql gin jwt logrus 和对应前端的技术链 vue iview axios vue router 基本
  • matlab逆变器原理,MATLAB中的单相全桥逆变器电路建模与仿真

    电子技术设计和应用电子设计和应用电子技术O 3969 j issn 1000 0755 201 5 03 020 MATLAB中的单相全桥逆变器电路建模与仿真杨露容军刘凯周雷李仁贵 湖南工学院信息与通信工程学院 湖南岳阳 描述了全桥逆变器电
  • 什么是MapReduce,MapReduce的工作流程和原理是什么

    一 MapReduce的概念 MapReduce是一种编程模型 用于大规模数据集 大于1TB 的并行运算 概念 Map 映射 和 Reduce 归约 和它们的主要思想 都是从函数式编程语言里借来的 还有从矢量编程语言里借来的特性 它极大地方
  • 对拦截器的小小理解

    对于初学架构的 color red 小白 color 来讲 拦截器绝对是一把需要掌握的 color red 利器 color 那么自己从以下几个方面 谈谈对拦截器的小小思考 拦截器的方法在Action执行前或执行后自动执行 从而将通用的操作
  • 吐血解决磁盘占用率100%

    吐血解决磁盘占用率100 问题简述 解决步骤 吐血解决 磁盘利用率高的建议 问题简述 一次偶然使用电脑后 发现每次开机后 磁盘长时间占用率达到100 带来的影响是打开浏览器 打开本地电脑磁盘特别卡 解决步骤 1 尝试了网络上提供的绝大部分方
  • 常用的范数求导

    矢量范数的偏导数 L1范数不可微 但是存在次梯度 即是次微分的 L1范数的次梯度如下 x x 1 sign x begin equation begin aligned frac partial partial mathbf x mathb
  • 【CV with Pytorch】第 8 章 :图像超分辨率

    随着高分辨率图像捕获代理的出现 图像中捕获的信息是巨大的 技术已经从超高清转向 4K 和 8K 分辨率 如今 电影正在使用高分辨率帧 但是 在某些情况下 他们需要将低分辨率图像增强为高分辨率图像 想象这样一个场景 电影的主角正试图确定从一张