tensorflow 1.01中GAN(生成对抗网络)手写字体生成例子(MINST)的测试

2023-10-29

为了更好地掌握GAN的例子,从网上找了段代码进行跑了下,测试了效果。具体过程如下:


代码文件如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from skimage.io import imsave
import os
import shutil




img_height = 28
img_width = 28
img_size = img_height * img_width


to_train = True
to_restore = False
output_path = "output"


# 总迭代次数500
max_epoch = 500


h1_size = 150
h2_size = 300
z_size = 100
batch_size = 256


# generate (model 1)
def build_generator(z_prior):
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
    h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
    w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    x_generate = tf.nn.tanh(h3)
    g_params = [w1, b1, w2, b2, w3, b3]
    return x_generate, g_params


# discriminator (model 2)
def build_discriminator(x_data, x_generated, keep_prob):
    # tf.concat
    x_in = tf.concat([x_data, x_generated],0)
    w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
    d_params = [w1, b1, w2, b2, w3, b3]
    return y_data, y_generated, d_params



def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = (res) * 255
        img = img.astype(np.uint8)
        row = (i // grid_size[0]) * (img_h + grid_pad)
        col = (i % grid_size[1]) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    imsave(fname, img_grid)




def train():
    # load data(mnist手写数据集)
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    
    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(0, name="global_step", trainable=False)


    # 创建生成模型
    x_generated, g_params = build_generator(z_prior)
    # 创建判别模型
    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)


    # 损失函数的设置
    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = - tf.log(y_generated)


    optimizer = tf.train.AdamOptimizer(0.0001)


    # 两个模型的优化函数
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)


    init = tf.initialize_all_variables()


    saver = tf.train.Saver()
    # 启动默认图
    sess = tf.Session()
    # 初始化
    sess.run(init)


    if to_restore:
        chkpt_fname = tf.train.latest_checkpoint(output_path)
        saver.restore(sess, chkpt_fname)
    else:
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.mkdir(output_path)




    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)


    steps = 60000 / batch_size
    for i in range(sess.run(global_step), max_epoch):
        for j in np.arange(steps):
#         for j in range(steps):
            print("epoch:%s, iter:%s" % (i, j))
            # 每一步迭代,我们都会加载256个训练样本,然后执行一次train_step
            x_value, _ = mnist.train.next_batch(batch_size)
            x_value = 2 * x_value.astype(np.float32) - 1
            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            # 执行生成
            sess.run(d_trainer,
                     feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
            # 执行判别
            if j % 1 == 0:
                sess.run(g_trainer,
                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
        show_result(x_gen_val, "output/sample{0}.jpg".format(i))
        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
        show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))
        sess.run(tf.assign(global_step, i + 1))
        saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)




def test():
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    x_generated, _ = build_generator(z_prior)
    chkpt_fname = tf.train.latest_checkpoint(output_path)


    init = tf.initialize_all_variables()
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(init)
    saver.restore(sess, chkpt_fname)
    z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
    x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
    show_result(x_gen_val, "output/test_result.jpg")




if __name__ == '__main__':
    if to_train:
        train()
    else:
        test()


按照500次迭代,每次迭代产生一张手写体图片,然后进行判别反馈,这样持续下去,可以看到不同迭代次数的效果。

(第1张)



(第2张)



第10张



第24张


第50张


第140张



第256张



第500张


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow 1.01中GAN(生成对抗网络)手写字体生成例子(MINST)的测试 的相关文章

随机推荐

  • 分享一个很容易实现的某大学的结构光源码【DIY自己的三维扫描仪】

    这个分享是一个大学做的结构光的代码 用一个usb相机 再加一个投影仪 完全按照说明配置opencv QT 还有一些库 只要配置好了 很容易跑通 代码和UI也很好 也可以优化成自己想要的那种 里面资料很全 非常适合不同高度的人来学习 看完觉得
  • ES6之map()方法

    map 方法 map 映射 即原数组映射成一个新的数组 map方法接受一个新参数 这个参数就是将原数组变成新数组的映射关系 function myfun 1 arr var array arr map item gt array push
  • ubuntu18.04LTS安装2080ti显卡驱动

    NVIDIA驱动安装 首先需要去nvidia官网下载对应的驱动 run文件 我下载的最新驱动450 57 一 禁用开源驱动 sudo gedit etc modprobe d blacklist conf 在文件最后添加两行 保存并关闭文件
  • 国内90%以上的 iOS 开发者,对 APNs 的认识都是错的

    2016 04 26 06 39 编辑 cocopeng 分类 iOS开发 来源 iOS程序犭袁的简书 本文为投稿文章 作者 iOS程序犭袁 博客 前言 APNs 协议在近两年的 WWDC 上改过两次 2015年12月17日更是推出了革命性
  • 报错unexpected ‘list‘ (T_LIST), expecting identifier (T_STRING)

    报错unexpected list T LIST expecting identifier T STRING 应该是控制器方法名称与系统内置函数名重复 修改方法名后问题暂时解决
  • 关于Unity3d中OnGUI的用法-显示对话框(刚开始学持续更新)

    系统调用 OnGUI 来渲染和处理 GUI 事件 这意味着每帧可能会多次调用 OnGUI 实现 每个事件调用一次 有关 GUI 事件的更多信息 请参阅 Event 参考 如果 MonoBehaviour 的 enabled 属性设置为 fa
  • Python 中的__main__和__name__

    用 C 族语言 C C Java C 等 编写的程序 需要main 功能来指示执行的起点 另一方面 在 Python 中 没有main 函数的概念 因为它是一种基于解释器的语言 同样可以在交互 Shell中使用 扩展名为 py的 Pytho
  • 【华为OD机试 2023】快递投放问题(C++ Java JavaScript Python)

    华为od机试题库 华为OD机试2022 2023 C Java JS Py https blog csdn net banxia frontend category 12225173 html 华为OD机试2023最新题库 更新中 C Ja
  • css实现下拉菜单

    这次css实现下拉菜单是仿照小米官网的一个小效果 如下 css实现下拉菜单 主要用到的知识点有用到伪元素来实现小箭头的点缀 还有transition属性实现下拉菜单过渡出现 不是直接崩出来的那种 提高用户体验 可以看到这个二维码出来的时候还
  • java什么是面向过程_Java 基础(一) -- 面向对象

    面向过程和面向对象的区别 什么是面向过程 pop 面向过程 Process oriented programming 是一种以事件为中心的编程思想 就是分析出解决问题所需要的步骤 然后用函数把这些步骤全部实现 然后按照顺序依次调用 什么是面
  • [NOIP1998 普及组]幂次方

    NOIP1998 普及组 幂次方 题目描述 任何一个正整数都可以用 2 2 2 的幂次方表示 例如 137 27 23 2 0 同时约定方次用括号来表示 即 a b a b
  • 智能小车运行及测速原理

    光电码盘测速原理 如何求解小车速度参数 大小与方向 测量速度方向的方法 根据A B两相脉冲的超前滞后关系确定电机旋转方向 假定A相超前于B相时 为电机正方向 则当A相滞后于B相 当前电机为反向旋转 普通测量速度大小的方法 单位时间内采集的脉
  • MySQL8.0修改用户密码验证

    问题 MySQL升级到8 0 客户端或者连接器没有升级到8 0 连接时出现吧报错 Authentication plugin caching sha2 password is not supported 查看当前用户信息 mysql gt
  • vue-axios框架详解

    axios框架详解 网络请求模块的选择 axios 选择什么网络模块 vue中发送网络请求有非常多的方式 那么在开发中如何悬着呢 选择一 传统的Ajax是基于XMLHttpRequest XHR 为什么不用Ajax呢 一 配置和调用方式非常
  • 使用vs2013开发过程中,调试时很慢的解决办法

    网上看到的解决办法 项目 配置属性 C C 代码生成 启用最小重新生成 Yes GM 项目 配置属性 C C 常规 调试信息格式 程序数据库 Zi 项目 配置属性 连接器 常规 启用增量链接 是 自己发现的问题 我的问题是在开发过程中 直接
  • SSL协议

    参考链接 https blog csdn net qq 38265137 article details 90112705 1 HTTP SSL HTTPS 非对称加密计算量很大 效率不如对称加密 我们打开网页最注重的是啥 是速度这点SSL
  • C++中运算符 &和&&、

    简介 是逻辑与运算符 是逻辑或运算符 都是逻辑运算符 两边只能是bool类型 与 既可以进行逻辑运算 又可以进行位运算 两边既可以是bool类型 又可以是数值类型 区别 if A B 如果 A 为 false 整个表达式就为 false 不
  • type_traits 类型萃取

    一 c traits traits是c 模板编程中使用的一种技术 主要功能 把功能相同而参数不同的函数抽象出来 通过traits将不同的参数的相同属性提取出来 在函数中利用这些用traits提取的属性 使得函数对不同的参数表现一致 trai
  • 解决Smack 提示“ Connection is not authenticated”

    在获取VCard 电子卡 信息的时候 我百度了一下 大部分的写法 如下 获取用户的vcard信息 param connection param user return throws XMPPException public static V
  • tensorflow 1.01中GAN(生成对抗网络)手写字体生成例子(MINST)的测试

    为了更好地掌握GAN的例子 从网上找了段代码进行跑了下 测试了效果 具体过程如下 代码文件如下 import tensorflow as tf from tensorflow examples tutorials mnist import