CGAN原理及tensorflow代码

2023-11-02

1.首先说明一下CGAN的意义

GAN的原始模型有很多可以改进的缺点,首当其中就是“模型不可控”。从上面对GAN的介绍能够看出,模型以一个随机噪声为输入。显然,我们很难对输出的结构进行控制。例如,使用纯粹的GAN,我们可以训练出一个生成器:输入随机噪声,产生一张写着0-9某一个数字的图片。然而,在现实应用中,我们往往想要生成“指定”的一张图片。

2.直观解决方案

在GAN上增加一个额外的输入。也就是说,以前我们的生成模型是p_g(x),现在,我们的生成模型是在一个条件c的控制下产生p_g(x|c)。而这个c就是我们用来控制模型的额外的输入。

c可以是表示我们意图的一串编码,例如我们想要做0-9的手写数字生成,则c可以是一个10维的one-hot向量。则在训练过程中,我们将这些label加入到训练数据中,从而得到一个按照我们需求产生图片的生成器。

这就是Conditional Generative Adversarial Nets最基本的想法。这里要注意的是,这个c不但附加在了生成器上,同时也附加在了判别器上,相当于给了判别器一个额外的信息:现在这个图片是以条件c生成的?还是以条件c控制下的真正的图片?

3.训练目标

原文中有这样一张图,在其他博客中也常见到

对于GAN来说,我们训练的目标是:

\mathop{\min}_{G}\mathop{\max}_{D}V(D,G)=\mathbb{E}_{\boldsymbol{x}\sim p_{\text{data}}}\left[\log D(\boldsymbol{x})\right]+\mathbb{E}_{\boldsymbol{z}\sim p_z(\boldsymbol{z})}\left[\log(1-D(G(\boldsymbol{z})))\right].

而对于Conditional的GAN来说,训练目标只需要变成:

\mathop{\min}_{G}\mathop{\max}_{D}V(D,G)=\mathbb{E}_{\boldsymbol{x}\sim p_{\text{data}}}\left[\log D(\boldsymbol{x}|\boldsymbol{y})\right]+\mathbb{E}_{\boldsymbol{z}\sim p_z(\boldsymbol{z})}\left[\log(1-D(G(\boldsymbol{z}|\boldsymbol{y})|\boldsymbol{y}))\right].

(原文中的公式有误,后面一项的判别器D中忘了加以y为条件的概率)

其实这个改动形象一些表示就是将原来只接受一个输入z的生成器变成接受两个输入(z和y),将原来只接受一个输入x的判别器变成接受两个输入(x和y)。
 

CGAN代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

#数据输入
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128

#返回随机值
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

#X代表输入图片,应该是28*28,但是这里没有使用CNN,y是相应的label
""" Discriminator Net model """
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
#权重,CGAN的输入是将图片输入与label concat起来,所以权重维度为784+10
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
#第二层有h_dim个节点
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

#D网络,这里是一个简单的神经网络,x是输入图片向量,y是相应的label
def discriminator(x, y):
    inputs = tf.concat(axis=1, values=[x, y])
    D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)

    return D_prob, D_logit

#G网络参数,输入维度为Z_dim+y_dim,中间层有h_dim个节点,输出X_dim的数据
""" Generator Net model """
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
#权重
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

#G网络
def generator(z, y):
    inputs = tf.concat(axis=1, values=[z, y])
    G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob

#噪声产生的函数
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])


def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

#生成网络,基本和GAN一致
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
#优化式
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
#训练
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)


sess = tf.Session()
sess.run(tf.global_variables_initializer())
#输出图片在out文件夹
if not os.path.exists('out/'):
    os.makedirs('out/')

i = 0

for it in range(1000000):
    if it % 1000 == 0:
        #n_sample 是G网络测试用的Batchsize,为16,所以输出的png图有16张
        n_sample = 16

        Z_sample = sample_Z(n_sample, Z_dim)#输入的噪声,尺寸为batchsize*noise维度
        y_sample = np.zeros(shape=[n_sample, y_dim])#输入的label,尺寸为batchsize*label维度
        y_sample[:, 7] = 1 #输出7

        samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})#G网络的输入

        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')#输出生成的图片
        i += 1
        plt.close(fig)
    #mb_size是网络训练时用的Batchsize,为100
    X_mb, y_mb = mnist.train.next_batch(mb_size)
    #Z_dim是noise的维度,为100
    Z_sample = sample_Z(mb_size, Z_dim)
    #交替最小化训练
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})
    #输出训练时的参数
    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

生成效果如下:

为了方便理解,本文只用了最简单的神经网络,有时间会使用CNN重写该网络。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CGAN原理及tensorflow代码 的相关文章

  • 【Paper】2017_事件触发机制下的多智能体领导跟随一致性

    黄红伟 黄天民 事件触发机制下的多智能体领导跟随一致性 J 计算机工程与应用 2017 53 6 29 33 文章目录 2 预备知识及问题描述2 1 代数图论2 2 领导跟随一致性 3 主要结果3 1 集中式事件触发机制下的一致性对应程序
  • 【Paper】2019_DoS/数据注入攻击下基于一致性的信息物理系统安全性研究_曹雄

    曹雄 DoS 数据注入攻击下基于一致性的信息物理系统安全性研究 D 天津大学 2019 DOI 10 27356 d cnki gtjdu 2019 003044 文章目录 第2章 拒绝服务攻击下多智能体系统安全性研究2 1 问题描述2 1
  • 【Paper】2018_Consensus of leader-following multiagent systems: A distributed event-triggered impulsiv

    Tan X Cao J Li X Consensus of leader following multiagent systems A distributed event triggered impulsive control strate
  • Paper review: Dynamic Routing Between Capsules

    Paper review Dynamic Routing Between Capsules 基本信息主要内容摘要基本思想神经科学设想routing by agreement卷积胶囊 算法和网络算法细节网络结构网络主体用重构来做正则化方法 实
  • [转载][paper]Threat of Adversarial Attacks on Deep Learning in Computer Vision: A Survey

    文章目录 摘要 深度学习是当前人工智能崛起的核心 在计算机视觉领域 xff0c 它已经成为从自动驾驶汽车到监控和安全等各种应用的主力 虽然深度神经网络在解决复杂问题方面取得了惊人的成功 通常超出了人类的能力 xff0c 但最近的研究表明 x
  • 写在Paper Reading之前

    写在Paper Reading 之前 2016年第一篇文章 xff0c 就以paper reading开头 xff0c 这段时间最少写五篇 xff0c 达到申请专栏的条件 通过申请专栏 xff0c 也可以达到监督作用 xff0c 催促自己多
  • stm32驱动微雪墨水屏1.54inch e-Paper V2

    我一起驱动墨水屏 一 墨水屏相关基础 xff08 摘自微雪官方 xff09 二 干起来PART2 配置I OPART2 底层硬件接口必要的调用函数PART3 功能函数PART4 应用函数 三 应用注意 代码下载 xff1a https do
  • 【Paper】Learning to Resize Images for Computer Vision Tasks

    From 别魔改网络了 xff0c Google研究员 xff1a 模型精度不高 xff0c 是因为你的Resize方法不够好 xff01 知乎 zhihu com paper 2103 09950v2 pdf arxiv org code
  • [paper] Hourglass

    Stacked Hourglass Networks for Human Pose Estimation Abstract Hourglass Net是一个进行人体位姿估计的卷积神经网络 也可以用在人脸关键点检测等领域 它结合了身体上的空间
  • 【图像处理】【图像去模糊】 总结

    本人最近由于做相关去卷积工作 查阅了上百篇文献 发现在这个领域 可能也是 水太深 了 并没有一篇较好的综述 现在做以下总结 只对高斯与散焦模糊的非盲去卷积领域 对于运动模糊并未做总结 但实际上除了点扩散函数的估计有区别 实际上这三类去模糊甚
  • 论文分享-Heterogeneity-Aware Cluster Scheduling Policies for Deep Learning Workloads

    前言 这篇文章是由斯坦福大学和微软研究院共同合作的 于2020年11月发表于系统类顶级会议OSDI 主要研究了不同异构硬件资源对深度学习负载的影响和集群调度策略的设计 1 摘要 专门的加速器 如gpu TPUs fpga和定制asic 已经
  • 神经辐射场 (NeRF) 概念

    神经辐射场 NeRF 概念 理论介绍 NeRF模型以其基本形式将三维场景表示为由神经网络近似的辐射场 辐射场描述了场景中每个点和每个观看方向的颜色和体积密度 这写为 F x
  • 论EI、SCI和ISTP检索论文的收录号和期刊号查询方法

    http www scitsg com Article 134240802101541 aspx 需要申请博士后进站和国家自然科学基金的朋友都知道申请博士后进站和国家自然科学基金需要填写很多申请表格 其中就需要填写所发表的EI SCI和IS
  • CGAN原理及tensorflow代码

    1 首先说明一下CGAN的意义 GAN的原始模型有很多可以改进的缺点 首当其中就是 模型不可控 从上面对GAN的介绍能够看出 模型以一个随机噪声为输入 显然 我们很难对输出的结构进行控制 例如 使用纯粹的GAN 我们可以训练出一个生成器 输
  • Chapter 2 Trajectory Indexing and Retrival

    This 26 pages paper is a bit short as a survey but a little too long for me the first day to write a papaer analysis But
  • CVPR 2017论文

    近期在看CVPR2017的文章 顺便就把CVPR2017整理一下 分享给大家 更多的 Computer Vision的文章可以访问Computer Vision Foundation open access CVPapers Machine
  • 论文写作的基本套路

    最近在写论文 写好之后给大神师兄看了看 提出了一些意见 按照师兄的意见整理出来 以供以后写作参考 博主是写的英文期刊论文 一 Abstract 一篇论文的精华都在abstract中 一片论文是否能够抓住审稿人的眼球 让审稿人有兴趣读下去 a
  • 如何写好一篇高质量的IEEE/ACM Transaction级别的计算机科学论文?

    http www zhihu com question 22790506 answer 81787300 f3fb8ead20 ea27429f8cbe31fd9183a68ccb41caa7 from timeline isappinst
  • 随笔:vscode-latex中文配置

    vscode用的久了 感觉确实比texstudio好用 question 1 vscode latex中文配置 vscode安装LaTeX Workshop Extension 默认latexmk就已经可以满足写英文paper的要求了 因为
  • 2019 SIGGRAPH paper

    20190704 Image Science 1 Hyperparameter Optimization in Black box Image Processing using Differentiable Proxies 基于可微代理的黑

随机推荐

  • onnx的VS2022和QT部署中遇到的问题GetInputName()函数报错问题

    这个问题是因为onnxruntime在1 7版本改变了函数名的原因 GetInputName 改名成了GetInputNameAllocated 在修改函数名后需要做一些小调整如下 修改前 修改后 在qt中也遇到了这个问题 修改前 修改后
  • [深入理解NAND Flash (颗粒篇) ] QLC NAND 已来未热,是时候该拥抱了?

    前言 伴随着闪存芯片的发展趋势 现如今便宜 大容量的SSD基本上都需要上QLC闪存芯片了 一时间QLC有山雨欲来之势 大容量QLC SSD的普及似乎已经触手可及 虽然现在主流是 TLC NAND 第三代 但下一代 Q L C N A N D
  • iOS开发—RunLoop详解

    随着oc语言不断迭代 苹果的API也是逐步完善 RunLoop在实际开发中应用的越来越少 但是在面试中 假如面试官问你RunLoop的相关知识了解 那就相当于面试官在问你从事iOS开发工作的真实年限问题 那么下面我们就详细了解一下RunLo
  • Linux(阿里云)禅道部署

    开源版本下载地址 底部 禅道18 1 禅道开源项目管理软件 本人选择安 Linux一键安装 csdn下载链接 https www zentao net dl zentao 18 1 ZenTaoPMS 18 1 zbox 64 tar gz
  • Windows10 安装Geant4-支持Release/Debug版本

    1 预先下载的软件 数据包 1 安装CMake 2 安装Visual Studio 可在官网安装社区版 免费 3 下载官网 https geant4 web cern ch support download 中的Source File zi
  • 为了在 Windows 11 上启用 IE ,我撸了个修复工具

    网管小贾的博客 www sysadm cc Windows 11 正式版已于前不久官宣发布了 好不好用呢 我想八成的人都是冲着尝鲜去的 所以说好用的不少 说不好用的也是大有人在 对我们来说 不管是真的好用还是真的不好用 那完全是见仁见智的个
  • 生活中哪些地方运用计算机网络,计算机网络技术在生活中应用.doc

    计算机网络技术在生活中应用 计算机网络技术在生活中应用 摘 要 近年来 计算机网络技术得以飞速发展 也在很大程度上改变了人们的生活方式 它可以说是人类发展历程中的新突破 进入二十一世纪之后 社会逐渐向着网络化的方向发展 计算机网络技术逐渐成
  • 2021水流向何处

    只要房价不涨 不用担心钱被稀释 钱不值钱 说白了就是货更加值钱了 货变贵了 这个货可以是白菜萝卜 可以是汽车 也可以是房子 汽车等工业品明显是更加不值钱 变便宜了 白菜萝卜等需要大量纯粹劳动力的货 是变贵了 但是人民工资水平的上涨能够更上它
  • springboot+mybatis+redis+thymeleaf Web项目搭建 开箱即用

    手动搭建了一个springboot mybatis redis thymeleaf的Web后台项目 因此写篇博客记录下搭建的完整过程 文章最后有完整代码地址 首先简单介绍下用到的技术框架及用途 1 springboot框架 项目主体结构 2
  • 简单递归(最大公约数,阶乘)

    include
  • Centos6.8安装glib-2.32.1

    Centos6 8安装glib 2 32 1遇到的问题及解决方法 1 glib 2 32 1下载网址 http ftp gnome org pub gnome sources glib 2 32 glib 2 32 1 tar xz 2 执
  • OpenWrt系统安全改进<三> --- Web UI密码错误控制

    OpenWrt系统安全改进 lt 二 gt 中所做的尝试 是为了增强用户登录的鉴权机制 密码输错三次就禁用用户一段时间 PAM可以实现对用户登录的控制 但是进一步操作中发现WebUI的登录并没有支持PAM 前功尽弃 了解了一下OpenWrt
  • jmeter 安装部署

    1 软件安装 1 1 Windows安装 1 1 1 软件下载 进入官网 http jmeter apache org 直接下载zip包 下载后直接解压 eg我的解压路径如下 D Program Files apache jmeter 5
  • GitHub拉取报错remote: Support for password authentication was removed on August 13, 2021

    问题描述 今天从GitHub上拉取我自己的私有仓库 结果报错说自21年8月13日后不在支持用户名密码方式验证 如图所示 解决方案 通过查看别人博客原博主以及官网阅读 得知可以通过创建个人访问令牌 personal access token
  • 【附源码】Python小游戏 ——开心消消乐

    目录 前言 开发工具 环境搭建 效果展示 选择关卡首页 游戏界面 过关 代码展示 模块导入 主函数 声音类 树类 元素类 数组类 前言 今天主要是给大家拿牌一个小游戏 开心消消乐 看看有没有小伙伴能够通过呀 开发工具 Python版本 3
  • 网络无法访问互联网是什么原因

    很多用户在使用手机或电脑连接网络时 明明可以正常连接 但却无法访问互联网 网络无法访问互联网是什么意思 无法连接到互联网是指当前只可访问本地网络的资源 没办法正常上外网 访问网页 上 QQ 微信等 网络无法访问互联网是什么原因 网络无法访问
  • Python的Logging模块

    1 日志的相关概念 日志是指记录系统或应用程序运行状态 事件和错误信息的文件或数据 在计算机系统中 日志通常用于故障排除 性能分析 安全审计等方面 日志可以记录各种信息 如系统启动和关闭时间 应用程序的运行状态 用户登录和操作记录 网络通信
  • 6.英文字母排序 (20分)

    题目内容 编写一个程序 当输入不超过 个字符组成的英文文字时 计算机将这个句子中的字母按英文字典字母顺序重新排列 排列后的单词的长度要与原始句子中的长度相同 并且要求只对 到 的字母重新排列 其它字符保持原来的状态 输入描述 一个字符串 包
  • python安装程序已停止工作_python.exe已经停止工作

    昨天 我成功地将sip pyqt4和vtk 包括python的绑定 安装在64位windows7虚拟机上 在 但是 当我执行 import vtk 操作时 会弹出一个对话框 import vtk python exe已经停止工作 在 事件查
  • CGAN原理及tensorflow代码

    1 首先说明一下CGAN的意义 GAN的原始模型有很多可以改进的缺点 首当其中就是 模型不可控 从上面对GAN的介绍能够看出 模型以一个随机噪声为输入 显然 我们很难对输出的结构进行控制 例如 使用纯粹的GAN 我们可以训练出一个生成器 输