GAN的入门与实践

2023-11-19

作者:Double_V

  编辑:龚 赛 



PART

01 GAN 简介


引言


生成对抗网络(Generative Adversarial Nets,GAN)是由open ai研究员Good fellow在2014年提出的一种生成式模型,自从提出后在深度学习领域收到了广泛的关注和研究。目前,深度学习领域的图像生成,风格迁移,图像变换,图像描述,无监督学习,甚至强化学习领域都能看到GAN 的身影。GAN主要针对的是一种生成类问题。目前深度学习领域可以分为两大类,其中一个是检测识别,比如图像分类,目标识别等等,此类模型主要是VGG, GoogLenet,residual net等等,目前几乎所有的网络都是基于识别的;另一种是图像生成,即解决如何从一些数据里生成出图像的问题,生成类模型主要有深度信念网(DBN),变分自编码器(VAE)。而某种程度上,在生成能力上,GAN远远超过DBN、VAE。经过改进后的GAN足以生成以假乱真的图像。本文将首先介绍一些GAN 的原理和公式推导,另外会详细给出GAN生成图像的Tensorflow的实现,基于python语言。


PART

02 GAN 原理


生成类

GAN主要解决的是生成类问题,即如何从一段任意的随机数中生成图像。假设给定一段100维的向量X{x1, x2,…, x100 }作为网络的输入,其中x是产生的随机数,一般按照高斯分布或者均匀分布产生,GAN通过对抗训练的方式,可以生成清晰的图像,这个过程是通过GAN不断模拟训练集中图像的像素分布来实现的。看完下文GAN的原理后或许你会对这个过程有一个清晰的认识。


图1 


首先,附上一张GAN的网络流程图,如图1所示。不同于以往的判别网络模型,GAN包括两个网络模型,一个生成模型G(generator)和一个判别模型D(discriminator),其中D就是识别检测类模型中经常使用的网络。GAN的大概流程是,G以随机噪声作为输入,生成出一张图像G(z),暂且不管生成质量多好,然后D以G(z)和真实图像x作为输入,对G(z)和x做一个二分类,检测谁是真实图像谁是生成的假图像。D的输出是一个概率值,比如G(z)作为输入时D输出0.15,那么代表D认为G(z)有15%的概率是真图像。然后G和D会根据D输出的情况不断改进自己,G提高G(z)和x的相似度,尽可能的欺骗D,而D则会通过学习尽可能的不被G欺骗。二者相当于是做一个极大极小的博弈过程,称为零和博弈。可以用一个简单的例子描述他们之间的过程,我们把G想象成制造假币的团伙,视D为警察,G不断产生假币,而D任务就是从真钱币中分辨出G的假币,刚开始时,G没有经验,制造的假币太假,D很容易就能分辨出来,所以G不断改进自己的技术,产生的假币越来越真实,D可能就没有那么容易判别出真假了,所以D也根据自己的情况不断改进自己,经过很多次这样的循环之后,G产生的假币足以以假乱真了,D很难分出真假。对应到图像生成上,此时G足以生成出一般的分类神经网络分辨不出真假的图像了,G从而获得了生成图像的能力。

与传统神经网络训练不一样的且有趣的地方,就是训练生成器的方法不同,生成器参数的更新来自于D的反传梯度。生成器一心想要“骗过”判别器。使用博弈理论分析技术,可以证明这里面存在一种纳什均衡。



这里就是他们的损失函数定义,实际上是一个交叉熵,判别器的目的是尽可能的令D(x)接近1,令D(G(z))接近0,所以D主要是最大化上面的损失函数,G恰恰相反,他主要是最小化上述损失函数。

训练过程:


(图2)



图2展示了GAN训练的伪代码,首先在迭代次数范围内,首先对z和x采样一个批次,获得他们的数据分布,然后通过随机梯度下降的方法先对D做k次更新,之后对G做一次更新,这样做的主要目的是保证D一直有足够的能力去分辨真假。实际在代码中我们可能会多更新几次G只更新一次D,不然D学习的太好,会导致训练前期发生梯度消失的问题。


平衡点存在的证明


在求平衡点之前,我们先做一个数学假设,即G固定情况下D的最优形式,然后根据D的最优形式再去观察G最小化损失函数的问题。

假设在G固定的条件下,并将损失函数化为如下简单形式:



D的目标就是最大化L,我们可以通过对L求导,并令导数为0,计算出L取最大值时y的取值如下:



所以,换为原来的式子D的最优解形式为:



到这里我们得出了结论,当G固定时,D的最优形式是上面形式。

接下来我们求一下D最优时,G最小化损失函数到什么形式才能达到二者相互博弈的平衡点。

带入到损失函数里面后,损失函数可以写为如下形式:



这时观察到,上面式子仍然是一个交叉熵也称KL散度的形式,KL散度通常用来衡量分布之间的距离,它是非对称的。同样还有另一个衡量数据分布距离的散度--JS散度,他们之间有如下关系。




不过JS散度有一个很重要的性质就是总是大于等于0的,当且仅当 P1=P2上面的式子取得最小值0,

所以我们可以将C(G)写成JS散度的形式:



也即是当且仅当Pg=Pdata时,C(G)取得最小值-log(4),也即是D最优时,G能将损失函数最小化到-log(4),最小点处Pg=Pdata。即真实数据的分布和生成数据的分布相等。

分析到这里,直观上也很好理解了,Pg=Pdata意味着此时D恰好等于0.5,就是D有一半的概率认为D(G(z))是真的数据,有一半概率认为是假的数据,这不就和猜硬币正反面一样嘛。也说明了此时G生成的数据足以以假乱真。

到这里,GAN的原理和数学推导就介绍完了,理论上说明了GAN只要循规蹈矩的训练,G就可以完美的模拟数据分布并生成真实的图像,但是我们做数学推导的时候为了证明方便做了一些假设,实际上并不是这样,GAN存在训练困难、梯度消失、模式崩溃的问题,这些问题在这里不做重点介绍。


PART

03 GAN 实现


代码演示


首先,建立一个train.py文件,在文件里建立一个名为Train的类,在类的初始化函数里进行一些初始化:



Self.build_model()函数用来存放构建流图部分的代码,下面会介绍,其他初始化的都是一些简单的参数。

下面先介绍生成器和判别器的网络:



生成器传进去三个参数,分别是名字,输入数据,和一个bool型状态变量reuse,用来表示生成器是否复用,reuse=True代表网络复用,False代表不复用。

生成器一共包括1个全连接层和4个转置卷积层,每一层后面都跟一个batchnorm层,激活函数都选择relu。其中fc(),deconv2d()函数和bn()函数都是我们封装好的函数,代表全链接层,转制卷积层,和归一化层,其形式如下:



全连接层fc的输入参数value指输入向量,output_shape指经过全连接层后输出的向量维度,比如我们生成器这里噪声向量维度是128,我们输出的是4*4*8*64维。



其中Ksize指卷积核的大小,outshape指输出的张量的shape,sted是一个bool类型的参数,表示用不同的方式初始化参数

bn()函数我是直接放在了train的类里面,其形式如下:



我们都希望权重都能初始化到一个比较好的数,所以这里我没有直接用固定方差的高斯分布去初始化权重,而是根据每一层的输入输出通道数量的不同计算出一个合适的方差去做初始化。同理,我们还封装了卷积操作,其形式如下:



好了,目前已经介绍了生成器的结构和一些基本函数,下面来介绍一下判别网络,其代码如下所示:



与生成器不同的是,我们使用leakrelu作为激活函数,



这些函数的定义都是放在了layer.py文件里,




这里有两个GAN可供选择,DCGAN 和WGAN-GP,他们唯一不同的地方是损失函数的计算不同,网络结构都是一样的,二者都是GAN的改进版,WGAN-GP效果好更好一些,这里我们使用WGAN-GP。DCGAN训练的时候容易遇到训练不稳定的问题。

 

到这里我们已经介绍完了所有的初始化过程,接下来就是训练数据的提取和网络的训练部分了,训练数据我们使用cele名人数据集,一共20万张图像左右,数据集里的图像size并不是很一致,我们可以使用一小段代码把图像的人脸截取下来,并resize到64*64大小。

代码如下:



把数据集下载下来后解压到img_align_celeba文件夹里面,然后运行face_detec.py就可以了,截取下来的图像会放到64_crop文件夹里,本来有20万张图像的,截取过后就剩15万张了。

 

下面就是训练部分了,首先是读取数据,load_data()函数每次会读取一个batch_size的数据作为网络的输入,在训练过程中,我们选择训练一次D训练两次G,而不是训练多次D之后训练一次G,不然容易发生训练不稳定的问题,因为D总是学的太好,很容易就判别出真假,所以导致G不论怎么改进都没有用,有些太打击G的造假积极性了。



Plot()函数会每训练100步后绘出网络loss的变化图像,是另外封装的函数

同时我们选择每训练400步生成一张图像,看一下生成器的效果。

load_data()函数我们并没有使用队列或者转化为record文件读取,这样的方式肯定会快一些,读取图像我们使用scipy.misc 来读取,

具体是import scipy.misc as scm



可以看到,我们首先对所有的图像做一个排序,返回一个列表,列表里存放的是每个图像的位置索引,这样做就是每次将一个batch_size的数据读到了内存里,读取的数据做了一个归一化操作,我们选择归一化到[-0.5,+0.5]。

 

接下来就是展示结果的时候了,其中训练过程loss的变化如下所示:




由图可见,经过一次比较大的震荡之后,网络就收敛的比较好了。

接下来是展示生成结果了:

我测试的时候设置了bach_size是16:

训练1epoch的时候是这样子的:



训练一段时间后:



再往后训练效果看上去反而差了一些,而且明显没有学习到眼镜的特征(最后一行第二个)估计是数据集里眼镜比较少,GAN学习不到足够的特征,眼睛鼻子嘴巴学习的还是很好的。



训练失败的结果:




PART

0结束


总结


下面谈一谈我训练GAN的感受,GAN是在是太难训练了,即使是使用WGAN,WGAN-GP,还是遇到了训练困难的问题,以上这些结果都是我做了好几次实验得出来的结果,有些实验中间得到的生成结果其实是惨不忍睹的,就像是下面这样,我总结了一部分原因,一个原因是网络结构太简单,我本次使用的网络是几年前流行的DCGAN的网络结构,有很大的改进空间,现在基本上用的不多了,我也试了BEGAN,不得不说BEGAN是真好训练,只要写好代码就让他自己跑去吧,基本上不会出问题,而且效果还很好;另一个原因是优化器的选择和学习率等超参数的设置。设置好的超参数对GAN的训练是很有帮助的,至于优化器,尽量不要选择SGD,因为GAN的平衡点是一个鞍点,鞍点附近梯度几乎为0,使用梯度的优化方法很难收敛到最优点,另外就是SGD训练震荡,很容易引起训练不稳定。理论上是这样,实际的问题比这复杂的多。




1.全面直观认识深度神经网络

2.机器学习实战——LBP特征提取

3.RNN入门与实践

4.Logistic回归实战篇之预测病马死亡率(三)




扫描个人微信号,

拉你进机器学习大牛群。

福利满满,名额已不多…

80%的AI从业者已关注我们微信公众号

       

       




本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN的入门与实践 的相关文章

  • 【语义分割】--SegNet理解

    原文地址 SegNet 复现详解 http mi eng cam ac uk projects segnet tutorial html 实现代码 github TensorFlow 简介 SegNet是Cambridge提出旨在解决自动驾
  • mybatis自动生成@Table、@Column、@Id注解

    在pom xml中添加如下插件以及插件相关的依赖
  • kvm之多网卡队列开启设置

    背景 目前基于dpdk数据平面开发套件的应用越来越多 而dpdk对于上层应用运行时 服务的进程数多于1时 要求网卡支持多队列 否则项nginx这种多进程应用程序只能再默认配置下运行 只能启动一个worker 要求 在kvm虚机中将网卡设置支
  • 模拟器提示关闭 hyper-V,但 hyper-V实际上并没有开启

    这个问题是windows系统问题导致无法使用BlueStacks 按下win R键打开执行窗口 输入regedit命令 打开注册表找到位置 HKEY LOCAL MACHINE SYSTEM CurrentControlSet Contro
  • 语义分割该如何走下去?

    作者 立夏之光 链接 https www zhihu com question 390783647 answer 1223902660 来源 知乎 著作权归作者所有 商业转载请联系作者获得授权 非商业转载请注明出处 做过一年多语义分割 来回
  • Java实现MySQL图片存取

    Reference Java实现对Mysql的图片存取操作 java 字节流读取图片 字符流读取 二进制读取 mysql BLOB字段类型用法介绍 Notes Java对图片的读写就跟其它文件一样的 但要用字节流而不用字符流 MySQL中各
  • 《Ansible语法篇:剧本对象关键字之when》

    一 前言 在ansible playbook中 也可以像其他编程语言一样进行条件判断 循环等流程控制 除此之外 还可以控制task的执行结果 在ansible中 可以通过when语句来执行条件判断 只有符合条件 才会执行对应的task wh
  • 【计算机视觉】华为天才少年谢凌曦:关于视觉识别领域发展的个人观点!

    文章目录 一 前言 二 CV的三大基本困难和对应研究方向 三 以下简要分析各个研究方向 3 1 方向1a 神经网络架构设计 3 2 方向1b 视觉预训练 3 3 方向2 模型微调和终身学习 3 4 方向3 无限细粒度视觉识别任务 四 在上述
  • EasyCode代码模板-适用于mybatis-plus 的项目中

    下面的模板适用于mybatis plus 的项目中 pojo类 面的模板适用于mybatis plus 的项目中 导入宏定义 define vm 保存文件 宏定义 save pojo java 包路径 宏定义 setPackageSuffi
  • Android10.0 os定制化系列讲解导读

    一 前言 本专栏主要是作者本人在10 0 frameworks定制化实战功能系列的解读 在从事几年的frameworks定制化功能的经验的积累 开发过平板 广告机 会议机 车机等一系列系统上层定制的功能性开发 写博客的目的 一方面是整理自己
  • Centos 8 替换镜像源

    1 替换 1 1 备份 mkdir etc yum repos d bak mv etc yum repos d etc yum repos d bak 1 2 下载 curl o etc yum repos d CentOS Base r
  • 【100%通过率 】【华为OD机试真题c++ /python】寻找符合要求的最长子串【 2022 Q4 A卷

    华为OD机试 真题 点这里 华为OD机试 真题考点分类 点这里 知识点双指针 时间限制 1s 空间限制 256MB 限定语言 不限 题目描述 给定一个字符串 s 找出这样一个子串 1 该子串中的任意一个字符最多出现2次 2 该子串不包含指定
  • SQL入门书籍内容汇总

    转头一晃 SQL入门书籍看完了 并且在画思维导图和整理笔记的过程中又一次加深了印象 不过也仅仅停留在课本层面上 不进行实际操作终将不知道你有没有学会如何运用 当然肯定不会自己创建一些数据库了 这个事不用质疑的 不过却可以读懂里面的一些用法
  • 【RK3399】I3399烧写Android系统详解

    00 目录 文章目录 00 目录 01 驱动安装 02 Android镜像文件烧写 03 问题讨论 04 附录 01 驱动安装 1 1 没有安装驱动的时候 显示感叹号 1 2 解压DriverAssitant v5 1 1 zip 1 3
  • 展锐8541芯片CPU推理MNN模型加速的若干问题

    一 在只有CPU的嵌入式设备上部署AI模型时 可以采取以下方法来提高模型的运行速度 1 量化模型 将浮点数模型转换为定点数模型 可以减少模型的计算和存储需求 从而提高模型的运行速度 2 剪枝模型 通过删除模型中不必要的连接和神经元 可以减少
  • 如何安装虚拟机?安装虚拟机的详细步骤

    1 下载虚拟机软件 首先 在官方网站上下载需要的虚拟机软件 如VMware VirtualBox等 注意软件版本的兼容性 最好选择最新版本 2 安装虚拟机软件 下载完成后 双击安装文件 按照提示完成安装 期间需要设置虚拟机软件的安装路径等信
  • C语言实现推箱子小游戏

    一 设计目的 用简单的C语言知识制作的推箱子游戏 通过上下左右键将所有箱子移动到目标位置 2 让我们更好地了解和巩固C语言知识 并实际运用 同时运用一些不太常见的知识点 二 功能描述 1 模块功能 本程序可分为初始界面 进行游戏 判定通关三
  • C++ list——push_back()与insert()

    push back 是把插入元素直接放入链表结尾 不多表述 insert 是把元素插入指定位置 摘自MSDN IDE VS2012 Parameters Parameter D
  • 使用UncaughtExceptionHandler进行未知异常得捕获

    UncaughtExceptionHandler UncaughtExceptionHandler使用场景 Thread类源代码 UncaughtExceptionHandler使用代码 UncaughtExceptionHandler使用
  • 场景法

    场景法 通过运用场景来对系统的功能点或业务流程的描述 从而提高测试效果 场景法一般包含基本流和备用流 从一个流程开始 通过描述经过的路径来确定的过程 经过遍历所有的基本流和备用流来完成整个场景 基本流 通过一个正确的事件流实现正确流程 备选

随机推荐

  • chatgpt赋能python:如何用Python实现抢购?

    如何用Python实现抢购 Python是一种灵活多样的编程语言 可以用它来完成各种任务 其中之一就是抢购 在电商大促销的节日 抢购商品通常需要竞争非常激烈 但是使用Python编写抢购脚本可以让您获得更高的成功率 以下是一些建议 通过Py
  • ①GD32Keil编译环境搭建及编译Demo

    进入 兆易官网 下载对应芯片的演示套件 下载解压后文件内容如下 安装keil5 我的keil5 下的pack包 打开demo包下的一个例程 提示如下
  • vue-router 2.0 常用基础知识点之router.push()

    router push location 除了使用
  • 红帽rhce认证考试科目有哪些?

    红帽RHCE认证考试主要考察的科目包括 RH124 主要考察Linux基础 文件和目录管理 用户和组管理 文件和目录权限管理 进程管理 系统服务 网络配置 日志分析等内容 RH134 主要考察自动化安装 文件编辑 任务计划 系统进程优先级管
  • 在Linux中安装Cmake过程中,遇到有关于openSSL的问题

    在Linux中安装Cmake过程中 遇到有关于openSSL的问题 以下是个人在安装cmake的过程中遇到的一些问题 如有什么错误之处欢迎各位大佬留言 共同进步 提示 Could Not Find OpenSSL try to set th
  • 【CV中的Attention机制】模块梳理合集

    文章目录 0 总述 1 SENet CVPR18 2 SKNet CVPR19 3 CBAM ECCV18 BAM BMVC18 scSE MICCAI18 4 Non Local Network CVPR19 5 GCNet ICCVW1
  • 手把手教你爬取并下载英雄联盟所有英雄皮肤高清大图

    利用requests和urlretrieve爬取并下载英雄联盟所有英雄皮肤高清大图 不知道屏幕前的你是不是也是一名loler 最近学习爬虫 印象中以前看过爬取英雄联盟的帖子 所以也就自己试了试 结果很是满意 先上效果图 亲女儿拉克丝 下面开
  • IO多路复用--[select

    因为在简历上写了netty的项目 因此还是将网络底层的那点东西搞清楚 首先希望明确的是 BIO NIO IO多路复用这是不同的东西 我会在本文中详细讲出来 本文参考资料 JAVA IO模型 IO多路复用 select poll epoll介
  • SpringBoot2.2.X整合ElasricSearch7.8

    这里默认大家已经掌握es基础语法 es版本为7 8 pom
  • pikachu靶场&RCE&文件包含&上传下载&越权(四)

    文章目录 RCE 概述 RCE PING RCE EVEL File Inclusion 文件包含漏洞 概述 文件下载漏洞 概念 文件上传漏洞 概述 前端页面检查 client check MIME TYPE漏洞 getimagesize
  • openwrt恢复出厂设置有两种方法

    1 输入以下指令 firstboot mtd r erase rootfs data 2 输入以下指令 mount root firstboot reboot f
  • spring Bean 生命周期BeanNameAware, BeanFactoryAware, ApplicationContextAware, InitializingBean接口详解

    继续接着上一篇完成后续接口的解析 还是借用上一篇引用大佬的文章 https www jianshu com p 1dec08d290c1 第二篇spring Bean 生命周期及BeanPostProcessor和Instantiation
  • Oracle dba_ts_quotas

    修改用户表空间配额 用户bpx1默认表空间是bpx1 select default tablespace from dba users where username in BPX1 SQL gt alter user bpx1 quota
  • 人工智能巨头碰撞——埃隆·马斯克推出xAI挑战OpenAI的统治地位

    目录 前言 XAI 的推出 什么是XAI 它将聚焦于什么 一 反AI斗士 马斯克进军AI 你怎么看 二 回顾上半年的 百模大战 中国的AI产业怎么样了 三 AI大模型这把火 还能怎么烧 其它资料下载 前言 北京时间7月13日凌晨 马斯克在T
  • R语言中用于计算Rsquare的包rsq

    文章目录 理论介绍 线性模型情形 广义线性模型情形 函数介绍 rsq 的介绍 实例 rsq partial 的介绍 相关文献 pcor 函数介绍 vresidual 函数介绍 实例 理论介绍 线性模型情形 R squared 值范围 0 1
  • 软件工程期末试题及答案(史上最全)

    软件工程期末试题及答案 文章目录 软件工程期末试题及答案 一 填空题 二 选择题 三 判断题 四 简答题 五 分析题 六 画图题 一 填空题 在信息处理和计算机领域内 一般认为软件是 文库 程序 文档 和 数据 数据流图的基本组成部分有 数
  • 双机热备VRRP协议介绍及其工作原理

    VRRP协议 为了更好的解决由于网关故障引起的网络中断问题 网络开发者提出了VRRP协议 VRRP协议是一种容错协议 他保证当链路中路由器出现故障的时候 由备份路由器自动替代路由器进行工作 从而保证网络通信的持续性和可靠性 虚拟路由器 由一
  • eNSP:ospf相关实验

    一 实验要求 二 实验步骤 1 建设如下图拓扑并划分网段 2 配置R1 R4的接口和回环地址 R1 r1 int g0 0 0 r1 GigabitEthernet0 0 0 ip add 192 168 1 1 27 r1 int l 0
  • ThinkPHP 关闭调试模式

    ThinkPHP有专门为开发过程而设置的调试模式 开启调试模式后 会牺牲一定的执行效率 但方便了不少 同时除错功能也非常值得 开启调试模式 config app php 文件 return 应用名称 app name gt 应用地址 app
  • GAN的入门与实践

    作者 Double V 编辑 龚 赛 PART 01 GAN 简介 引言 生成对抗网络 Generative Adversarial Nets GAN 是由open ai研究员Good fellow在2014年提出的一种生成式模型 自从提出