GAN+pytorch实现MNIST生成

2023-10-28

背景知识

代码实现

本文实现最简单的例子,利用GAN生成MNIST的数字,代码如下:

导入包

%matplotlib inline

import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision
from torchvision import models
from torchvision import transforms

# 如果有gpu就用gpu,如果没有就用cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

导入数据集

batch_size=32

# Compose定义了一系列transform,此操作相当于将多个transform一并执行
transform = transforms.Compose([
    transforms.ToTensor(),    
    # mnist是灰度图,此处只将一个通道标准化
    transforms.Normalize(mean=(0.5), 
                         std=(0.5))
    ])
                         
# 设定数据集
mnist_data = torchvision.datasets.MNIST("./mnist_data", train=True, download=True, transform=transform)

# 加载数据集,按照上述要求,shuffle本意为洗牌,这里指打乱顺序,很形象
dataloader = torch.utils.data.DataLoader(dataset=mnist_data,
                                         batch_size=batch_size,
                                         shuffle=True)
                                         

在线下载MNIST时如果下载速度特别慢可以更改源码,改为本地。

定义模型

image_size = 784
hidden_size = 256

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid() # sigmoid结果为(0,1)
)

# Generator
latent_size = 64 # latent_size,相当于初始噪声的维数
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh() # 转换至(-1,1)
)

# 放到gpu上计算(如果有的话)
D = D.to(device)
G = G.to(device)

# 定义损失函数、优化器、学习率
loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

开始训练

# 先定义一个梯度清零的函数,方便后续使用
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

# 迭代次数与计时
total_step = len(dataloader)
num_epochs = 200
start = time.perf_counter() # 开始时间

# 开始训练
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader): # 当前step
        batch_size = images.size(0) # 变成一维向量
        images = images.reshape(batch_size, image_size).to(device)
        
        # 定义真假label,用作评分
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        # 对D进行训练,D的损失函数包含两部分
        # 第一部分,D对真图的判断能力
        outputs = D(images) # 将真图送入D,输出(0,1),应该是越接近1越好
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs # 真图的分数,越大越好
        
        # 第二部分,D对假图的判断能力
        z = torch.randn(batch_size, latent_size).to(device) # 开始生成一组fake images即32*784的噪声经过G的假图
        fake_images = G(z)
        outputs = D(fake_images.detach()) # 将假图片给D,detach表示不作用于求grad
        d_loss_fake = loss_fn(outputs, fake_labels)
        fake_score = outputs # 假图的分数,越小越好
        
        # 开始优化discriminator
        d_loss = d_loss_real + d_loss_fake # 总的损失就是以上两部分相加,越小越好
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # 对G进行训练,G的损失函数包含一部分
        # 可以用前面的z,也可以新生成,因为模型没有改变,事实上是一样的
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs, real_labels) # G想骗过D,故让其越接近1越好
        
        # 开始优化generator
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        # 优化完成,下面进行一些反馈,展示学习进度
        if i % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}"
                  .format(epoch, num_epochs, i, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))

# 训练结束,跳出循环,检验成果
end = time.perf_counter() # 结束时间
total = end - start
minutes = total//60
seconds = total - minutes*60
print("利用GPU总用时:{:.2f}分钟{:.2f}秒".format(minutes, seconds))

结果

...上面结果省略

Epoch [199/200], Step [1600/1875], d_loss: 0.9140, g_loss: 1.4904, D(x): 0.64, D(G(z)): 0.25
Epoch [199/200], Step [1700/1875], d_loss: 0.7004, g_loss: 1.8600, D(x): 0.72, D(G(z)): 0.23
Epoch [199/200], Step [1800/1875], d_loss: 0.8012, g_loss: 1.5045, D(x): 0.72, D(G(z)): 0.27
利用GPU总用时:102.00分钟6.29

我的是自己电脑的GPU(NVIDIA 1050ti),经过102分钟,我们已经训练好了网络,接下来看看输出是否满足需求:

检验成果

# 向G输入一个噪声,观察生成的图片
z = torch.randn(1, latent_size).to(device)
fake_images = G(z).view(28, 28).data.cpu().numpy()
plt.imshow(fake_images, cmap = plt.cm.gray)

我们得到输出:

看起来还不错,再对比一下原图:

对比原图

plt.imshow(next(iter(dataloader))[0][0][0], cmap = plt.cm.gray)

原图如下:

还不错啦!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN+pytorch实现MNIST生成 的相关文章

  • torch.unique() 中的参数“dim”如何工作?

    我试图提取矩阵每一行中的唯一值并将它们返回到同一个矩阵中 重复值设置为 0 例如 我想转换 torch Tensor 1 2 3 4 3 3 4 1 6 3 5 3 5 4 to torch Tensor 1 2 3 4 0 0 0 1 6
  • MNIST、torchvision 中的输出和广播形状不匹配

    在 Torchvision 中使用 MNIST 数据集时出现以下错误 RuntimeError output with shape 1 28 28 doesn t match the broadcast shape 3 28 28 这是我的
  • PyTorch 中的截断反向传播(代码检查)

    我正在尝试在 PyTorch 中实现随时间截断的反向传播 对于以下简单情况K1 K2 我下面有一个实现可以产生合理的输出 但我只是想确保它是正确的 当我在网上查找 TBTT 的 PyTorch 示例时 它们在分离隐藏状态 将梯度归零以及这些
  • Cuda和pytorch内存使用情况

    我在用Cuda and Pytorch 1 4 0 当我尝试增加batch size 我遇到以下错误 CUDA out of memory Tried to allocate 20 00 MiB GPU 0 4 00 GiB total c
  • 二维数组的按行 numpy.isin [重复]

    这个问题在这里已经有答案了 我有两个数组 A np array 3 1 4 1 1 4 B np array 0 1 5 2 4 5 2 3 5 是否可以使用numpy isin二维数组按行排列 我想检查一下是否A i j is in B
  • RuntimeError:维度指定为 0 但张量没有维度

    我试图使用 MNIST 数据集实现简单的 NN 但我不断收到此错误 将 matplotlib pyplot 导入为 plt import torch from torchvision import models from torchvisi
  • 我可以使用逻辑索引或索引列表对张量进行切片吗?

    我正在尝试使用列上的逻辑索引对 PyTorch 张量进行切片 我想要与索引向量中的 1 值相对应的列 切片和逻辑索引都是可能的 但是它们可以一起吗 如果是这样 怎么办 我的尝试不断抛出无用的错误 类型错误 使用 ByteTensor 类型的
  • 查找张量中沿轴的非零元素的数量

    我想找到沿特定轴的张量中非零元素的数量 有没有 PyTorch 函数可以做到这一点 我尝试使用非零 http pytorch org docs master torch html highlight nonzero torch nonzer
  • 将 CNN Pytorch 中的预训练权重传递到 Tensorflow 中的 CNN

    我在 Pytorch 中针对 224x224 大小的图像和 4 个类别训练了这个网络 class CustomConvNet nn Module def init self num classes super CustomConvNet s
  • 如何避免 PyTorch 中的“CUDA 内存不足”

    我认为对于 GPU 内存较低的 PyTorch 用户来说 这是一个非常常见的消息 RuntimeError CUDA out of memory Tried to allocate X MiB GPU X X GiB total capac
  • pytorch grad 在 .backward() 之后为 None

    我刚刚安装火炬 1 0 0 on Python 3 7 2 macOS 并尝试tutorial https pytorch org tutorials beginner blitz autograd tutorial html sphx g
  • 在pytorch张量中过滤数据

    我有一个张量X like 0 1 0 5 1 0 0 1 2 0 我想实现一个名为的函数filter positive 它可以将正数据过滤成新的张量并返回原始张量的索引 例如 new tensor index filter positive
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • 如何在pytorch中查看DataLoader中的数据

    我在 Github 上的示例中看到类似以下内容 如何查看该数据的类型 形状和其他属性 train data MyDataset int 1e3 length 50 train iterator DataLoader train data b
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • BatchNorm 动量约定 PyTorch

    Is the 批归一化动量约定 http pytorch org docs master modules torch nn modules batchnorm html 默认 0 1 与其他库一样正确 例如Tensorflow默认情况下似乎
  • 如何有效地对一个数组中某个值在另一个数组中的位置出现的次数求和

    我正在寻找一种有效的 for 循环 避免解决方案来解决我遇到的数组相关问题 我想使用一个巨大的一维数组 A gt size 250 000 用于一维索引的 0 到 40 之间的值 以及用于第二维索引的具有 0 到 9995 之间的值的相同大
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 在Pytorch中计算欧几里得范数..理解和实现上的麻烦

    我见过另一个 StackOverflow 线程讨论计算欧几里德范数的各种实现 但我很难理解特定实现的原因 如何工作 该代码可以在 MMD 指标的实现中找到 https github com josipd torch two sample b

随机推荐

  • 蓝桥杯单片机学习日记4-串口接收与发送,解决串口引脚与按键引脚冲突

    此片文章用于记录蓝桥杯单片机的学习 串口的发送与接收较为简单 主要是字节和字符串的发送与接收 直接上程序 串口初始化 void UartInit void 9600bps 11 0592MHz SCON 0x50 8位数据 可变波特率 AU
  • MySQL数据库是非关系_MySQL(数据库)基础知识、关系型数据库yu非关系型数据库、连接认证...

    什么是数据库 数据库 Database 存储数据的仓库 高效地存储和处理数据的介质 介质主要是两种 磁盘和内存 数据库系统 DBS Database System 是一种虚拟系统 将多种内容关联起来的称呼 DBS DBMS DB DBMS
  • 汇编语言(王爽)第四版学习1

    第一章 机器语言 0 1 简单语句 mov ax bx 汇编语言组成 1 汇编指令 机器码的助记符 有对应的机器码 2 伪指令 没有对应的机器码 由编译器执行 计算机并不执行 3 其他符号 如 等 由编译器识别 没有对应的机器码 存储器 内
  • SpringCloud微服务---Nacos配置中心

    1 Nacos Config 服务配置 1 1 服务配置中心介绍 首先我们来看一下 微服务架构下关于配置文件的一些问题 1 配置文件相对分散 在一个微服务架构下 配置文件会随着微服务的增多变的越来越多 而且分散在各个微服务中 不好统一配置和
  • 007 数据结构_堆——“C”

    前言 本文将会向您介绍关于堆Heap的实现 具体步骤 tips 本文具体步骤的顺序并不是源代码的顺序 typedef int HPDataType typedef struct Heap HPDataType a int size int
  • 04Python爬虫:retrying模块

    代码 结果 None 转载于 https www cnblogs com jumpkin1122 p 11521013 html
  • 勇担重任从不放弃——一个阿里P7的内部求职故事

    Java开发程序员在互联网行业中名声在外 同时也意味着竞争特别激烈 当然 在众多从业者中 并不是每个人都可以经历从Java外包到成为阿里P7这样的成功故事 不过 这个同志通过自己坚定的信仰和勤奋的努力 不仅完成了自己的进步增值 而且分享了成
  • 泰勒图(Taylor Diagrams)和常用模型评价指标小结

    文章内容仅用于自己知识学习和分享 如有侵权 还请联系并删除 一 泰勒图 1 原理 1 1 定义 泰勒图 Taylor diagram 可以简单的理解为一种的可同时展示相关系数 their correlation 中心均方根误差 their
  • win11上的虚拟机安装Ubuntu 16.04和基础环境配置教程

    1 安装 VM 17 win11最好装VM16以后的 2 下载 ubuntu 的iso文件 可以在国内的镜像站下载更快 如下是阿里云的镜像站ubuntu 16 04 网址 https mirrors aliyun com oldubuntu
  • LeetCode 37 把数组排成最小的数

    示例 1 输入 10 2 输出 102 示例 2 输入 3 30 34 5 9 输出 3033459 提示 0 lt nums length lt 100 解题思路 此题求拼接起来的最小数字 本质上是一个排序问题 设数组 nums 中任意两
  • 利用Hu不变矩进行特征提取

    include stdafx h include
  • C# 串口CRC CCITT-FALSE 校验

    串口CRC CCITT FALSE 校验 public static bool CRC16 CCITT FALSE byte byteData C crc 16 CCITT FALSE 带判断校验的 bool flag false usho
  • 大数据从入门到精通(超详细版)之HDFS安装部署 , 跟着部署 , 真的有手就行 !

    前言 嗨 各位小伙伴 恭喜大家学习到这里 不知道关于大数据前面的知识遗忘程度怎么样了 又或者是对大数据后面的知识是否感兴趣 本文是 大数据从入门到精通 超详细版 的一部分 小伙伴们如果对此感谢兴趣的话 推荐大家按照大数据学习路径开始学习哦
  • 什么是link标签?

    什么是link标签 link标签通常放置在一个网页的头部标签head标签内的用于链接外部css文件 链接收藏夹图标 favicon ico 标签最常见的用途是链接外部样式表 外部资源 link实例 链接外部css样式时候link标签的内容结
  • Android性能优化之内存优化

    前言 成为一名优秀的Android开发 需要一份完备的知识体系 在这里 让我们一起成长为自己所想的那样 内存优化可以说是性能优化中最重要的优化点之一 可以说 如果你没有掌握系统的内存优化方案 就不能说你对Android的性能优化有过多的研究
  • OpenHarmony鸿蒙 润和Pegasus套件样例--智能安防

    润和Pegasus套件样例 智能安防 该样例展示OpenHarmony智能安防项目 当温度传感器超过设定值后 或者烟雾传感器检测到烟雾时 会触发蜂鸣器工作 同时通知到HarmonyOS手机上的APP 下载源码 建议将本教程的设备源码下载后
  • 小白入门——“贪吃蛇”的C语言实现(详细)

    C语言实现 编译环境VS 附 easyx图形化 文章末尾 效果图如下 有一些函数kbhit getch 在这表示为 kbhit与 getch 不同编译器原因 注意在Dev等集成开发软件下可能会CE o o 一 引言 作为一个小白 相信大家的
  • 个人工作失误复盘

    今天 同门突然指出了我在去年10月做一项代码测试工作时犯的错误 当时 我的任务是测试某论文中新发布的图像配准算法在我们的航拍图像配准任务上的效果 以便决定是否在其上进行改进 我按照readme文件中的指引下载了预训练权重 并按照项目代码中给
  • 接口自动化测试环境搭建(unittest+requests+HTMLTestRunner)

    该自动化测试框架基于python单元测试框架unittest 使用HTMLTestRunner来生成测试报告 使用Requests xlrd 和 xlwt等库 一 安装python运行环境 安装包官方下载地址 https www pytho
  • GAN+pytorch实现MNIST生成

    背景知识 GAN 原理可以在这里查看 GAN入门简介 pytorch 一个深度学习的框架 关于环境配置有问题 可以在这里查看 从零开始机器学习 代码实现 本文实现最简单的例子 利用GAN生成MNIST的数字 代码如下 导入包 matplot