生成对抗网络GAN---生成mnist手写数字图像示例(附代码)

2023-05-16

Ian J. Goodfellow等人于2014年在论文Generative Adversarial Nets中提出了一个通过对抗过程估计生成模型的新框架。框架中同时训练两个模型:一个生成模型(generative model)G,用来捕获数据分布;一个判别模型(discriminative model)D,用来估计样本来自于训练数据的概率。G的训练过程是将D错误的概率最大化。可以证明在任意函数G和D的空间中,存在唯一的解决方案,使得G重现训练数据分布,而D=0.5。

生成对抗网络(GAN,Generative Adversarial Networks)的基本原理很简单:假设有两个网络,生成网络G和判别网络D。生成网络G接受一个随机的噪声z并生成图片,记为G(z);判别网络D的作用是判别一张图片x是否真实,对于输入x,D(x)是x为真实图片的概率。在训练过程中, 生成器努力让生成的图片更加真实从而使得判别器无法辨别图像的真假,而D的目标就是尽量把分辨出真实图片和生成网络G产出的图片,这个过程就类似于二人博弈,G和D构成了一个动态的“博弈过程”。随着时间的推移,生成器和判别器在不断地进行对抗,最终两个网络达到一个动态平衡:生成器生成的图像G(z)接近于真实图像分布,而判别器识别不出真假图像,即D(G(z))=0.5。最后,我们就可以得到一个生成网络G,用来生成图片。

对于GAN更加直观的理解:生成模型可以被看做是一个伪造团队,试图生产假币并且在不被发现的情况下使用, 而判别模型则类似于警察,尝试检查是否为假币。伪造团队的目的是生产出警察识别不出的假币,而警察则是想更加精确地识别出假币,因此在这个游戏中,两个团队因为各自目的而不断改进它们的方法直到伪造团队生产的假币警察分辨不出来。

 上面讲述生成对抗网络的基本原理, 为了能够更深此理解GAN,下面我们使用GAN来生成MNIST数据集。

import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt

BATCH_SIZE = 64
UNITS_SIZE = 128
LEARNING_RATE = 0.001
EPOCH = 300
SMOOTH = 0.1

mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)


# 生成模型
def generatorModel(noise_img, units_size, out_size, alpha=0.01):
    with tf.variable_scope('generator'):
        FC = tf.layers.dense(noise_img, units_size)
        reLu = tf.nn.leaky_relu(FC, alpha)
        drop = tf.layers.dropout(reLu, rate=0.2)
        logits = tf.layers.dense(drop, out_size)
        outputs = tf.tanh(logits)
        return logits, outputs

# 判别模型
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        FC = tf.layers.dense(images, units_size)
        reLu = tf.nn.leaky_relu(FC, alpha)
        logits = tf.layers.dense(reLu, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

# 损失函数
"""
判别器的目的是:
1. 对于真实图片,D要为其打上标签1
2. 对于生成图片,D要为其打上标签0
生成器的目的是:对于生成的图片,G希望D打上标签1
"""
def loss_function(real_logits, fake_logits, smooth):
    # 生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
                                                                    labels=tf.ones_like(fake_logits)*(1-smooth)))
    # 判别器识别生成器产出的图片,希望识别出来的标签为0
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
                                                                       labels=tf.zeros_like(fake_logits)))
    # 判别器判别真实图片,希望判别出来的标签为1
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,
                                                                       labels=tf.ones_like(real_logits)*(1-smooth)))
    # 判别器总loss
    D_loss = tf.add(fake_loss, real_loss)
    return G_loss, fake_loss, real_loss, D_loss

# 优化器
def optimizer(G_loss, D_loss, learning_rate):
    train_var = tf.trainable_variables()
    G_var = [var for var in train_var if var.name.startswith('generator')]
    D_var = [var for var in train_var if var.name.startswith('discriminator')]
    # 因为GAN中一共训练了两个网络,所以分别对G和D进行优化
    G_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(G_loss, var_list=G_var)
    D_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(D_loss, var_list=D_var)
    return G_optimizer, D_optimizer


# 训练
def train(mnist):
    image_size = mnist.train.images[0].shape[0]
    real_images = tf.placeholder(tf.float32, [None, image_size])
    fake_images = tf.placeholder(tf.float32, [None, image_size])

    #调用生成模型生成图像G_output
    G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)
    # D对真实图像的判别
    real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)
    # D对G生成图像的判别
    fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)
    # 计算损失函数
    G_loss, real_loss, fake_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)
    # 优化
    G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)

    saver = tf.train.Saver()
    step = 0
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for epoch in range(EPOCH):
            for batch_i in range(mnist.train.num_examples // BATCH_SIZE):
                batch_image, _ = mnist.train.next_batch(BATCH_SIZE)
                # 对图像像素进行scale,tanh的输出结果为(-1,1)
                batch_image = batch_image * 2 -1
                # 生成模型的输入噪声
                noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))
                #
                session.run(G_optimizer, feed_dict={fake_images:noise_image})
                session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})
                step = step + 1
            # 判别器D的损失
            loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images:noise_image})
            # D对真实图片
            loss_real =session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # D对生成图片
            loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # 生成模型G的损失
            loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})
            print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G', loss_G)
            model_path = os.getcwd() + os.sep + "mnist.model"
            saver.save(session, model_path, global_step=step)

def main(argv=None):
    train(mnist)

if __name__ == '__main__':
    tf.app.run()

上述是训练模型,下面是测试模型,依据训练模型训练的参数。generatorImage函数生成手写字体图片, 在这里显示了25张图片。 生成图像如下图1所示,还能够大略猜出生成的图片中的数字。

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import pickle
import mnist_GAN

UNITS_SIZE = mnist_GAN.UNITS_SIZE

def generatorImage(image_size):
    sample_images = tf.placeholder(tf.float32, [None, image_size])
    G_logits, G_output = mnist_GAN.generatorModel(sample_images, UNITS_SIZE, image_size)
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver.restore(session, tf.train.latest_checkpoint('.'))
        sample_noise = np.random.uniform(-1, 1, size=(25, image_size))
        samples = session.run(G_output, feed_dict={sample_images:sample_noise})
    with open('samples.pkl', 'wb') as f:
        pickle.dump(samples, f)

def show():
    with open('samples.pkl', 'rb') as f:
        samples = pickle.load(f)
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    for ax, image in zip(axes.flatten(), samples):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(image.reshape((28, 28)), cmap='Greys_r')
    plt.show()

def main(argv=None):
    image_size = mnist_GAN.mnist.train.images[0].shape[0]
    generatorImage(image_size)
    show()

if __name__ == '__main__':
    tf.app.run()
图1. 生成图片展示

上述基于MNIST数据集构造了一个简单的GAN模型,对于生成模型和判别模型,仅仅使用了简单的神经网络,对于图像的处理,卷积神经网络更胜一筹,如果将生成模型和判别模型改为深度卷积网络,那么生成更加清晰的图片。 而且目前也有各种GAN变体,后续慢慢整理。

参考博客:基于GAN的mnist训练集图片生成神经网络实现_gan训练集_lpty的博客-CSDN博客

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

生成对抗网络GAN---生成mnist手写数字图像示例(附代码) 的相关文章

  • npm安装报错 rollbackFailedOptional verb npm-session 解决办法

    该问题一般情况是因为代理问题 xff0c npm代理和git代理都要设置 首先确认网络是否需要设置代理 如果是公司网络需要设置代理 xff0c 则设置npm代理和git代理 1 设置npm代理 1npm config set proxy a
  • Docker 安装C#编译环境

    Docker 是一个开源的应用容器引擎 xff0c 让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中 xff0c 然后发布到任何流行的 Linux或Windows操作系统的机器上 xff0c 也可以实现虚拟化 本文主要介绍Docke
  • 关于句柄中带命名空间对实际程序运行中的影响

    ROS头文件 include lt ros ros h gt 自定义msg产生的头文件 include lt topic demo gps h gt int main int argc char argv 用于解析ROS参数 xff0c 第
  • 部署Redis集群

    部署Redis集群 span class token comment 创建网卡 span span class token function docker span network create redis subnet span clas
  • ROS编译出现generate_messages() must be called after add_message_files()错误

    新人小白 xff0c 刚刚开始学ROS 编译出现这个错误 xff0c 搞了好久也没找到这个问题 xff0c 后来偶然发现了问题所在 CMake Error at opt ros kinetic share genmsg cmake genm
  • SLAM因子图构建笔记

    因子图简介 最近在读了Joan Sola所写的Course on SLAM中有关因子图部分的介绍后 xff0c 发现其中有关于因子图构建的思路觉得很有意思 xff0c 因此在这里记录一下 DBN网络 首先简单地介绍一下如何将一个SLAM问题
  • xmanager关闭linux命令,Linux下xmanager passive功能无法使用的解决技巧

    xmanager Passive可以在仅登陆ssh字符界面的情况下传输图形 xff0c 为很多开发者所喜爱 有一用户因需要调整了防火墙 xff0c xmanger passive功能便无法正常使用了 xff0c 我们该如何处理这个问题呢 到
  • liteos内核驱动和linux,移植RTOS必备基础知识

    1 基础知识 移植内核对技术的要求比较高 比较细 1 1 单片机相关的知识 栈的作用 加载地址 链接地址 重定位 几个简单的硬件知识 串口 定时器 中断的概念 1 2 Linux操作相关的知识 Linux常用命令 简单的脚本 xff1a 脚
  • matlab subs什么意思,什么是matlab subs函数?

    matlab中subs 是符号计算函数 xff0c 详细用法可以在Matlab的Command Windows输入 xff1a help subs subs 函数表示将符号表达式中的某些符号变量替换为指定的新的变量 xff0c 常用调用方式
  • 虚拟机linux装无线网卡驱动,linux无线网卡驱动安装

    环境 在笔记本里的虚拟机10 0版本 xff0c centos 6 5 无线网卡fast fw300um 第一步要查看芯片 lsusb 当你得到芯片之后接下来查看内核 xff0c 如果内核已经有芯片模块就不用再装了 xff0c 如果不支持的
  • 使用Altium Designer 20绘制双层板以及四层板

    直接入正题 1 按照正常的绘制双层板的方式新建工程文件 xff0c 加入原理图和PCB文件 xff08 如果会绘制双层板请直接看第二步 xff09 xff08 1 xff09 新建工程文件 xff08 2 xff09 选择工程类型 xff0
  • 1.1 Ubuntu18.04 ROS tcp/ip Server通信实现

    Ubuntu18 04 ROS tcp ip Server通信实现 此小节介绍tcp ip Server收发数据 xff0c 并将截取到底信息通过话题方式发布出去 下一节介绍Ubuntu18 04 ROS tcp ip client通信实现
  • 1.2 Ubuntu18.04 ROS tcp/ip Client通信实现

    Ubuntu18 04 ROS tcp ip Client通信实现 此小节介绍tcp ip Client收发数据 xff0c 测试平台为为Ubuntu18 04 与Windows系统上的网络调试助手进行通信测试 xff0c 调试助手采用的有
  • 使用Gazebo对PX4飞控进行SITL仿真

    在仿真之前 xff0c 首先需要搞清楚每个模块所代表的含义 xff0c 在这个操作中扮演什么角色 Gazebo xff1a 可以理解成对我们实际飞行物理环境的一个仿真 QGC xff1a 地面站 xff0c 不用多说 Firmware xf
  • GitLab统计代码量

    gitlab官方文档 xff1a https docs gitlab com ee api index html 1 生成密钥 登录gitlab xff0c 编辑个人资料 xff0c 设置访问令牌 2 获取当前用户所有可见的项目 接口地址
  • 【树莓派】(2)网络连接、IP设置、屏幕大小设置、VNC安装与配置

    目录 1 网络连接 1 1有线网连接 SSH协议 1 2 无线网连接 VNC 方法1 xff1a 不能联网 方法2 xff1a 能联网 2 VNC安装与配置 3 IP WiFi配置 4 屏幕大小 屏幕黑屏时间设置 1 网络连接 分为有屏幕和
  • Linux服务配置 配置VNC远程桌面

    一 VNC简介 VNC Virtual Network Console 是虚拟网络控制台的缩写 它 是一款优秀的远程控制工具软件 xff0c 由著名的 AT amp T 的欧洲研究实验室开发的 VNC 是在基于 UNIX 和 Linux 操
  • 异常检测 and 自编码器(2)

    文章目录 前言一 自编码器用于异常检测的网址推荐1 自编码器AutoEncoder解决异常检测问题2 基于自编码器的时间序列异常检测算法3 深度学习实现自编码器Autoencoder神经网络异常检测心电图ECG时间序列 总结 前言 上篇文章
  • python树莓派3控制蜂鸣器_树莓派3 modelB型 连接HC-SR501人体红外感应模块和蜂鸣器模块...

    连接前准备 树莓派3 modelB型一个 HC SR501传感器一只 低电平蜂鸣器模块 有源 即接上电就会响 xff0c 低电平触发 母对母杜邦线三根 实物图如下 xff1a 博主连接的不是特别美观 两个传感器的连接图分别如下 HC SR5
  • git submodule 使用教程

    1 submoude 介绍 xff08 1 xff09 项目很大参与开发人员多的时候 xff0c 需要将各个模块文件进行抽离单独管理 xff08 2 xff09 使用git submodule来对项目文件做成模块抽离 xff0c 抽离出来的

随机推荐