GAN生成MNIST数据-PyTorch

2023-11-10

摘抄别处,供自己学习用

直接上代码,代码如下:

# coding=utf-8
import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os
from torch.utils.data import DataLoader

# 创建文件夹
if not os.path.exists('./img'):
    os.mkdir('./img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  # Clamp函数可以将随机变化的数值限制在一个给定的区间[min, max]内:
    out = out.reshape(-1, 1, 28, 28)  # 将一行再次拼接多行 再次形成图片
    return out


batch_size = 128
num_epoch = 1000
z_dimension = 50
# 图像预处理
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.5,))
])
# mnist dataset mnist数据集下载,没有下载的将download改成True
mnist = datasets.MNIST(
    root='./mnist/', train=True, transform=img_transform, download=True
)
# data loader 数据载入
dataloader = DataLoader( dataset=mnist, batch_size=batch_size, shuffle=True
)


# 定义判别器  #####Discriminator######使用多层网络来作为判别器
# 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 512),  # 输入特征数为784,输出为512
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),  # 进行非线性映射
            nn.Linear(512, 256),  # 进行一个线性映射
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 也是一个激活函数,二分类问题中,
        )

    def forward(self, x):
        x = self.dis(x)
        return x


####### 定义生成器 Generator #####
# 输入一个50维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,
# 然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,
# 然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布
# 能够在-1~1之间。
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(50, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x

# 创建对象
D = discriminator()
G = generator()

#########判别器训练train#####################
# 分为两部分:1、真的图像判别为真;2、假的图像判别为假
# 此过程中,生成器参数不断更新
# 首先需要定义loss的度量方式  (二分类的交叉熵)
# 其次定义 优化函数,优化函数的学习率为0.0003
loss_function = nn.BCELoss()  # 是单目标二分类交叉熵函数
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
#####################进入训练##判别器的判断过程#####################
for epoch in range(num_epoch):  # 进行多个epoch的训练
    for i, (img, label) in enumerate(dataloader):
        num_img = img.size(0)

        # reshape()函数作用是将一个多行的Tensor,拼接成一行
        # 第一个参数是要拼接的tensor,第二个参数是-1
        # =========================训练判别器=====================
        img = img.reshape(num_img, -1)  # 将图片展开为28*28=784
        real_img = Variable(img)  # 将tensor变成Variable放入计算图中
        real_label = Variable(torch.ones(num_img))  # 定义真实的图片label为1
        fake_label = Variable(torch.zeros(num_img))  # 定义假的图片的label为0

        # 计算真实图片的损失
        real_out = D(real_img)  # 将真实图片放入判别器中    # 得到真实图片的判别值,输出的值越接近1越好
        real_out = real_out.squeeze()  # (128,1) ---> (128,)
        d_loss_real = loss_function(real_out, real_label)  # 得到真实图片的loss
        real_scores = real_out
        # 计算假的图片的损失
        z = Variable(torch.randn(num_img, z_dimension))  # 随机生成一些噪声
        fake_img = G(z)  # 随机噪声放入生成网络中,生成num_img张假的图片  128*784
        fake_out = D(fake_img)  # 判别器判断假的图片  # 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
        fake_out = fake_out.squeeze() # (128,1) -> (128,)
        d_loss_fake = loss_function(fake_out, fake_label)  # 得到假的图片的loss
        fake_scores = fake_out
        # 损失函数和优化
        d_loss = d_loss_real + d_loss_fake  # 损失包括判真损失和判假损失
        d_optimizer.zero_grad()  # 在反向传播之前,先将梯度归0
        d_loss.backward()  # 将误差反向传播
        d_optimizer.step()  # 更新参数

        # ==================训练生成器============================
        ################################生成网络的训练###############################
        # 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        # 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        # 反向传播更新的参数是生成网络里面的参数,
        # 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的
        # 这样就达到了对抗的目的
        # 计算假的图片的损失
        z = Variable(torch.randn(num_img, z_dimension))  # 得到随机噪声
        fake_img = G(z)  # 随机噪声输入到生成器中,得到num_img副假的图片

        output = D(fake_img)  # 经过判别器得到的结果
        output  = output.squeeze()
        g_loss = loss_function(output, real_label)  # 得到的假的图片与真实的图片的label的loss
        g_optimizer.zero_grad()  # 梯度归0
        g_loss.backward()  # 进行反向传播
        g_optimizer.step()  # .step()一般用在反向传播后面,用于更新生成网络的参数
        # 打印中间的损失
        if (i + 1) % 100 == 0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D-real-scores: {:.6f},D-fake-scores: {:.6f}'.format(
                epoch, num_epoch, d_loss.data.item(), g_loss.data.item(),
                real_scores.data.mean(), fake_scores.data.mean()  # 打印的是真实图片的损失均值
            ))
        if (epoch+1) % 2 == 0:
            # real_images = to_img(real_img.data)
            # save_image(real_images, './img/real_images--{}.png'.format(epoch + 1))
            fake_images = to_img(fake_img.data)
            # save_image(fake_images, './img/fake_images-{}-{}.png'.format(i, epoch + 1))
# 保存模型
# torch.save(G.state_dict(), './generator.pth')
# torch.save(D.state_dict(), './discriminator.pth')



本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN生成MNIST数据-PyTorch 的相关文章

随机推荐

  • Python(1)--Python安装

    本篇作为学习Python笔记 来记录学习过程 安装环境 windows10 官方下载地址 https www python org 有很多的版本 我这里选择了3 7 2 executable表示可执行版 需要安装后使用 embeddable
  • Python基础 NumPy数组相关概念及操作

    NumPy是Python的一种开源的数值计算扩展库 提供 数组支持以及相应的高效处理函数 它包含很多功能 如创建n维数组 矩阵 对数组进行函数运算 数值积分 线性代数计算 傅里叶变换和随机数产生等 Why NumPy 标准的Python用L
  • CentOS8基础篇2:文件系统

    一 文件系统概述 1 文件系统的基本概念 操作系统中负责管理和存储文件信息的软件机构称为文件管理系统 简称文件系统 它规定了文件的存储方式及文件索引方式等信息 文件系统主要由三部分组成 分别是与文件管理相关的软件 被管理的文件和实施文件管理
  • 神经网络中的神经元和激活函数详解

    在上一节 我们通过两个浅显易懂的例子表明 人工智能的根本目标就是在不同的数据集中找到他们的边界 依靠这条边界线 当有新的数据点到来时 只要判断这个点与边界线的相互位置就可以判断新数据点的归属 上一节我们举得例子中 数据集可以使用一条直线区分
  • jdk13快来了,jdk8的这几点应该看看!

    说明 jdk8虽然出现很久了 但是可能我们还是有很多人并不太熟悉 本文主要就是介绍说明一些jdk8相关的内容 主要会讲解 lambda表达式 方法引用 默认方法 Stream 用Optional取代null 新的日志和时间 Completa
  • 自定义view

    自定义View 有这一篇就够了 简书 jianshu com
  • STM32cubeProgrammer连接设置说明

    芯片型号 STM32F427 连接 connect Frequency设置为200 点击connection REG模块 随后device选STM32F427 peripheral选择GPIOD
  • android应用安装成功之后删除apk文件

    摘要 题目 正在运用开辟中碰到须要如许的需供 正在用户下载我们的运用装置以后删除装置包 办理 android会正在每一个中界操纵APK的举措以后收回体系级其余播送 过滤器称号 问题 在应用开发中遇到需要这样的需求 在用户下载我们的应用安装之
  • C语言学前班

    C 语言学前班 10分钟入门 10天练习 哪有那么难 根本用不着科班通过上课学几个月 程序 数据结构 算法 数据结构 容器来存储要进行各种操作的数据 算法 对各种数据进行各种操作 加减乘除 增删改查 判 判断 排 排序 复 复位 输出结果来
  • Some Android licenses not accepted. To resolve this, run: flutter doctor --android-licenses 解决方法

    mopondys iMac zyc flutter doctor Doctor summary to see all details run flutter doctor v Flutter Channel dev v1 16 2 on M
  • NVisionXR for ARCore内测版开放申请

    NVisionXR for ARCore引擎能够帮助开发者快速开发原生ARCore应用 只要你懂基本的Android开发 直接使用Android Studio 即可实现动画模型渲染 粒子特效 音视频播放 灯光渲染等功能 NVisionXR引
  • java线程池的使用

    线程池概述 线程池 Thread Pool 是一种基于池化思想管理线程的工具 使用线程池可以带来诸多好处 降低资源消耗 通过池化技术复用已创建的线程 减少线程创建和销毁的损耗 提高响应速度 任务到达时 特定情况下无需再创建线程 便于管理 j
  • hangfire+bootstrap ace 模板实现后台任务管理平台

    前言 前端时间刚开始接触Hangfire就翻译了一篇官方的教程 翻译 山寨 Hangfire Highlighter Tutorial 后来在工作中需要实现一个异步和定时执行的任务管理平台 就结合bootstrap ace模板和hangfi
  • echarts中多y轴图像(柱,折)

    先看看效果吧 var myChart echarts init document getElementById demo echarts zyyh 放入的id var colors e6bcff a3ffcd fefefe option c
  • C++之explicit的作用介绍

    1 C 中的关键字explicit主要是用来修饰类的构造函数 被修饰的构造函数的类 不能发生相应的隐式类型转换 只能以显示的方式进行类型转换 类构造函数默认情况下声明为隐式的即implicit 隐式转换即是可以由单个实参来调用的构造函数定义
  • 147. 精读《@types react 值得注意的 TS 技巧》

    1 引言 从 types react 源码中挖掘一些 Typescript 使用技巧吧 2 精读 泛型 extends 泛型可以指代可能的参数类型 但指代任意类型范围太模糊 当我们需要对参数类型加以限制 或者确定只处理某种类型参数时 就可以
  • 2022年江西省中职组“网络空间安全”赛项模块B-Web渗透测试

    2022年中职组山西省 网络空间安全 赛项 B 8 Web渗透测试任务书 B 8 Web渗透测试解析 不懂可以私信博主 一 竞赛时间 420分钟 共计7小时 吃饭一小时 二 竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 第 阶段
  • 【MySQL】数据库基本操作:创建删除数据库(Create/Drop),表增删改查

    数据库基本操作 1 启动服务 DOS命令 net start mysql 回车 2 登录MySQL数据库 mysql uroot proot 回车 3 查看MySQL中数据库 show databases 4 创建数据库 create da
  • 2023备战金三银四,Python自动化软件测试面试宝典合集(八)

    马上就又到了程序员们躁动不安 蠢蠢欲动的季节 这不 金三银四已然到了家门口 元宵节一过后台就有不少人问我 现在外边大厂面试都问啥 想去大厂又怕面试挂 面试应该怎么准备 测试开发前景如何 面试 一个程序员成长之路永恒绕不过的话题 每每到这个时
  • GAN生成MNIST数据-PyTorch

    摘抄别处 供自己学习用 直接上代码 代码如下 coding utf 8 import torch autograd import torch nn as nn from torch autograd import Variable from