【GAN】基础原理讲解及代码实践

2023-11-18

首先什么是GAN:

 

 

 

 

 

 

 

 

GAN的模型结构

 

设计GAN模型的关键:

 

 

 GAN的算法原理:

 

 

 

这里输入噪声的随机性就可以带来生成图像的多样性

 

 

 

 

 GAN公式讲解:

 

 

 

 D(x)表示判别器对真实图片的判别,取对数函数后我们希望其值趋于0,也就是D(x)趋于1,也就是放大损失。

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 GAN代码实践(基于jupyter,顺序执行即可):

导包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

torch.__version__

数据准备

# 对数据做归一化 (-1, 1)对gan的输入数据全部规范化到(-1,1)之间
transform = transforms.Compose([   #transform做变形
    transforms.ToTensor(),         # ToTensor会将图像像素值转换为0-1; channel, high, witch,
    transforms.Normalize(0.5, 0.5) #然后我们通过均值为0.5,方差为0.5将数据规范化到(-1,1)
])


train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      transform=transform,
                                      download=True)#定义MNIST数据集


dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)#加载数据集,打乱,batch_size设置为64
#%%
imgs, _ = next(iter(dataloader))#加载一个批次的图片(64张)
#%%
imgs.shape

 

定义生成器

# 输入是长度为 100 的 噪声(符合正态分布的随机数)
# 输出为(1, 28, 28)的图片
#linear 1 :   100----256
#linear 2:    256----512
#linear 2:    512----28*28
#reshape:     28*28----(1, 28, 28)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(100, 256),
                                  nn.ReLU(),
                                  nn.Linear(256, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, 28*28),
                                  nn.Tanh()                     # 对于-1, 1之间的数据分布,Tanh效果最好。输出的取值范围是-1,1之间
        )
    def forward(self, x):              # 前向传播,x 表示长度为100 的noise输入
        img = self.main(x)#将x输入到main模型中 得到img
        img = img.view(-1, 28, 28)#通过view函数reshape成(-1,28,28,1)
        return img

 

定义判别器

## 输入为(1, 28, 28)的图片  输出为二分类的概率值,输出使用sigmoid激活 0-1
# BCEloss计算交叉熵损失

# nn.LeakyReLU   f(x) : x>0 输出 x, 如果x<0 ,输出 a*x  a表示一个很小的斜率,比如0.1
# 判别器中一般推荐使用 LeakyReLU,RELU激活函数在小于0没有任何梯度,会非常难以训练


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()#继承父类的属性
        self.main = nn.Sequential(
                                  nn.Linear(28*28, 512),#输入一张图片28,8,然后展平成28*28,再卷积到256
                                  nn.LeakyReLU(),
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(),
                                  nn.Linear(256, 1),
                                  nn.Sigmoid()
        )
    def forward(self, x):#x输入的是28,28的图片
        x = x.view(-1, 28*28)#展平
        x = self.main(x)
        return x

初始化模型、优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'#默认使用cuda,否则cpu
#%%
gen = Generator().to(device)#初始化Generator模型
dis = Discriminator().to(device)#初始化Discriminator模型
#%%
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)#定义优化器,学习率
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)
#%%
loss_fn = torch.nn.BCELoss()#二分类判别模型

绘图函数

def gen_img_plot(model, test_input):#每次都给一个同样的test_input正态分布随机数
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())#detach用来截断梯度,放到cpu上,转换为numpy,squeeze用于去掉维度为一的值,鲁棒性更高===>28*28的数组
    fig = plt.figure(figsize=(4, 4))#绘制16张图片
    for i in range(16):#循环
        plt.subplot(4, 4, i+1)#四行四列的第一张
        plt.imshow((prediction[i] + 1)/2)#转换成0,1之间的数值(预测的结果恢复到0,1之间
        plt.axis('off')#关闭
    plt.show()
#%%
test_input = torch.randn(16, 100, device=device)#生成长度为100的一个批次16张的随机噪声输入

 

GAN的训练

D_loss = []
G_loss = []#定义空列表用来放两个模型生成的loss
#%%
# 训练循环
for epoch in range(20):#训练20轮
    d_epoch_loss = 0
    g_epoch_loss = 0#初始化损失函数为0
    count = len(dataloader)#返回批次数,len(dataset)返回样本数
    for step, (img, _) in enumerate(dataloader):#_表示标签,这里生成模型用不到,enumerate用于对dataloader迭代
        img = img.to(device)#将照片上传到设备上
        size = img.size(0)#获批次大小根据这个大小来输入我们随机噪声的输入大小
        random_noise = torch.randn(size, 100, device=device)#生成噪声随机数,大小个数是size
        
        d_optim.zero_grad()#将梯度归0
        
        real_output = dis(img)      # 判别器输入真实的图片,real_output对真实图片的预测结果 真实图片为1,假图片为0
        d_real_loss = loss_fn(real_output, 
                              torch.ones_like(real_output))      # 得到判别器在真实图像上的损失  ones_like:全1数组
        d_real_loss.backward()#反向传播,计算梯度
        
        gen_img = gen(random_noise)
        # 判别器输入生成的图片,fake_output对生成图片的预测
        fake_output = dis(gen_img.detach()) #这里阶段梯度是因为,这里通过对判别器输入生成图片去计算损失是用来优化判别器的。对生成器的参数暂时不做优化。所以梯度不用再传递到生成器模型当中了,我们希望fake_output被判定为0
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output))      # 得到判别器在生成图像上的损失,zeros_like:全0数组
        d_fake_loss.backward()#同样计算梯度
        #以上是用来优化判别器
        d_loss = d_real_loss + d_fake_loss#判别器的总损失(两部分)
        d_optim.step()#进行优化
        
        g_optim.zero_grad()#梯度归零
        fake_output = dis(gen_img)#将生成图片放到判别器当中--不要梯度截断
        g_loss = loss_fn(fake_output, #我们这里就希望fake_output被判定为1用来优化生成器
                         torch.ones_like(fake_output))      # 生成器的损失
        g_loss.backward()#计算梯度
        g_optim.step()#权重优化
        
        with torch.no_grad():#两个模型的损失函数做累加(不需要计算梯度)---每个批次累加==一个epoch
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
            
    with torch.no_grad():#得到平均loss
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss.item())
        G_loss.append(g_epoch_loss.item())#这样列表当中会保存每个epoch的平均loss
        print('Epoch:', epoch)#打印当前epoch
        gen_img_plot(gen, test_input)#绘图

运行效果

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【GAN】基础原理讲解及代码实践 的相关文章

  • PyTorch:如何使用 DataLoaders 自定义数据集

    如何利用torch utils data Dataset and torch utils data DataLoader根据您自己的数据 不仅仅是torchvision datasets 有没有办法使用内置的DataLoaders他们使用的
  • 检查 PyTorch 张量在 epsilon 内是否相等

    如何检查两个 PyTorch 张量在语义上是否相等 考虑到浮点错误 我想知道元素是否仅相差一个小的 epsilon 值 在撰写本文时 这是最新稳定版本 0 4 1 中的一个未记录的函数 但文档位于master unstable branch
  • MNIST、torchvision 中的输出和广播形状不匹配

    在 Torchvision 中使用 MNIST 数据集时出现以下错误 RuntimeError output with shape 1 28 28 doesn t match the broadcast shape 3 28 28 这是我的
  • 推导 pytorch 网络的结构

    对于我的用例 我需要能够采用 pytorch 模块并解释模块中的层序列 以便我可以以某种文件格式在层之间创建 连接 现在假设我有一个简单的模块 如下所示 class mymodel nn Module def init self input
  • pytorch - “conv1d”在哪里实现?

    我想看看 conv1d 模块是如何实现的https pytorch org docs stable modules torch nn modules conv html Conv1d https pytorch org docs stabl
  • Pytorch 数据加载器:错误的文件描述符和 EOF > 0

    问题描述 在使用由自定义数据集制作的 Pytorch 数据加载器进行神经网络训练期间 我遇到了奇怪的行为 数据加载器设置为workers 4 pin memory False 大多数时候 训练都顺利完成 有时 训练会随机停止 并出现以下错误
  • 如何使用 torch.stack?

    我该如何使用torch stack将两个张量与形状堆叠a shape 2 3 4 and b shape 2 3 没有就地操作 堆叠需要相同数量的维度 一种方法是松开并堆叠 例如 a size 2 3 4 b size 2 3 b torc
  • 通过 Conda 安装 PyTorch

    目标 使用 pytorch 和 torchvision 创建 conda 环境 Anaconda 导航器 1 8 3 python 3 6 MacOS 10 13 4 我尝试过的 在Navigator中 创建了一个新环境 尝试安装 pyto
  • 预训练 Transformer 模型的配置更改

    我正在尝试为重整变压器实现一个分类头 分类头工作正常 但是当我尝试更改配置参数之一 config axis pos shape 即模型的序列长度参数时 它会抛出错误 Reformer embeddings position embeddin
  • 我可以使用逻辑索引或索引列表对张量进行切片吗?

    我正在尝试使用列上的逻辑索引对 PyTorch 张量进行切片 我想要与索引向量中的 1 值相对应的列 切片和逻辑索引都是可能的 但是它们可以一起吗 如果是这样 怎么办 我的尝试不断抛出无用的错误 类型错误 使用 ByteTensor 类型的
  • 在pytorch中使用tensorboard,但得到空白页面?

    我在pytorch 1 3 1中使用tensorboard 并且我在张量板的 pytorch 文档 https pytorch org docs stable tensorboard html 运行后tensorboard logdir r
  • 为什么 PyTorch nn.Module.cuda() 不将模块张量移动到 GPU,而仅将参数和缓冲区移动到 GPU?

    nn Module cuda 将所有模型参数和缓冲区移动到 GPU 但为什么不是模型成员张量呢 class ToyModule torch nn Module def init self gt None super ToyModule se
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • 从打包序列中获取每个序列的最后一项

    我试图通过 GRU 放置打包和填充的序列 并检索每个序列最后一项的输出 当然我的意思不是 1项目 但实际上是最后一个 未填充的项目 我们预先知道序列的长度 因此应该很容易为每个序列提取length 1 item 我尝试了以下方法 impor
  • Pytorch CUDA 错误:没有内核映像可用于在带有 cuda 11.1 的 RTX 3090 设备上执行

    如果我运行以下命令 import torch import sys print A sys version print B torch version print C torch cuda is available print D torc
  • Pytorch 损失为 nan

    我正在尝试用 pytorch 编写我的第一个神经网络 不幸的是 当我想要得到损失时遇到了问题 出现以下错误信息 RuntimeError Function LogSoftmaxBackward0 returned nan values in
  • 如何从已安装的云端硬盘文件夹中永久删除?

    我编写了一个脚本 在每次迭代后将我的模型和训练示例上传到 Google Drive 以防发生崩溃或任何阻止笔记本运行的情况 如下所示 drive path drive My Drive Colab Notebooks models if p
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1

随机推荐