GAN的学习 - 基础知识了解

2023-11-04

20200818 -

0. 引言

最近看到了一些论文,是GAN在密码生成(PassGAN)、DGA检测(DeepDGA)这些论文,所以希望深入了解一下GAN的内容。之前的时候,只是知道GAN是什么东西,通过训练两个神经网络,然后相互促进来实现检测的目标,不过没有深入了解过。这里根据刚刚阅读的一篇文章[1]来记录下学习的内容。

大部分内容是文章[1]的原始内容,加上自己的理解。
(20200910 增加)
关于GAN的应用,主要是在生成模型中,一般而言,实际的数据的分布会非常复杂,无法直接模拟出相应的数学公式,所以可以采用GAN的形式来进行生成[4]。而对于GAN的研究方向,也主要来自两个方面:一是利用GAN解决领域问题,另一个是研究如何让GAN更加容易训练。[5]

1. GAN的结构

不管是什么网站,只要一介绍GAN就会告诉你:GAN由两个部分组成,一个部分是生成器,一个部分是判别器。那么具体,他们是怎么连接,生成器的输入是什么,判别器的输入又是什么呢。

1.1 判别模型与生成模型

在讲解GAN之前,需要对判别算法和生成算法有一个了解。这里首先来引用文章[1]中的一个结论:

  • 判别模型学习类之间的边界
  • 生成模型构造每个类的分布
    在以往的学习过程中,大多数都是在学习判别模型,因为平时做的主要工作就是分类过程,软件是否是恶意,域名是否是DGA算法生成等等,所以对于判别模型的理解更深入。但是,从前面的描述也可以知道,生成模型要做的工作是学习到底层每个类的分布,与判别模型恰恰相反,是从标签中学习到底层数据的分布。
    (这里有一个疑问,那么GAN输入的标签中要带有原始的Label,还是真假数据的标签呢)
    (20200910答,都可以,这种属于不同形式的GAN)

1.2 GAN如何工作

GAN(原始的)包含两个神经网络,一个叫做生成器,一个叫做判别器;生成器生成新的数据实例,而判别器用于评估生成的样本是真是假。如果是以MNIST数据集的图片为例,生成器通过生成新的图片样本,判别器要判定这个生成的样本是真实的还是伪造的。在文章[1]中列举步骤中有一点很关键。

  • 生成器输入一个随机的数值(在密码生成的文章中指出,一般使用均匀分布或者正态分布),然后返回一张图片
  • 生成的图片以及从真实数据集中提取出来的图片同时输入到判别器
  • 判别器对每个输入的图片返回一个概率,0代表是假的,1代表是真的。
    上面的步骤中,回答了两个问题:
    1)生成器的输入是什么,随机的数值,那么这个随机的数值是固定维度,还是单个呢?这个需要进一步后面看代码;(20200910答,这个输入是完全的随机分布,可以是均一分布,也可以是高斯分布)
    2)判别器的输入,包括图片和Label,从上面的描述上来看,这个Label应该是真假,那么这个时候,就需要解决另一个问题,他的类别怎么传递进去。
    (20200821更新,在文章[2]提到,训练判别器时,并且同时用标签告诉它这些样本分别来自哪里,还有,我还见过一个GAN的图,他好像将类别和真假同时训练的,所以我觉得可能这部分标签到底是什么应该还是跟他的需求相关,可能有的人就是需要真假)

生成器的目的就是生成新的数据实例,例如图片,然后能够骗过判别器。所以这里整体上有两个反馈循环:

  • 判别器利用标签来进行判断,例如图片是否是伪造的
  • 生成器利用判别器返回的数据,例如前面提到的概率来进行持续改进。

1.2.1 以MNIST数据为例

Minist的数据是手写的数字图片,使用该数据在GAN上进行实验,网络结构如下:

  • 判别器网路结构将是一个卷积神经网络,通过降采样,最终识别图片,最终将生成的图片标记为是否是伪造的,并输出一个概率。
  • 生成器是判别器反过来的过程,本来卷积神经网络是降采样,但是生成器输入一个随机噪声的向量,然后将其升采样变为一个图片。
    两个神经网络都尝试着获取最优的目标函数,具体的网络结构[1]如下。
    图片来自[1]

2. 具体的代码示例(生成MNIST数据集图片)

(20200821 - 学习增加)
今天看到一些文章和实例,在构造模型时,判别器的模型中还加入了数据的类别,就不仅仅是图片是真的还是假的这种标签。不过,我个人感觉,实例的具体类别标签,应该也是能加入的,可能还是要看具体的应用。下面的例子就是仅仅判定是否是真的假的过程。
文章[2]通过完整的例子,从如何定义生成器到最后怎么评估GAN的性能的一篇文章,可能这篇文章的内容也不一定准确,但是先看看这个例子具体讲了什么。具体目录(仅仅列举有用的)如下:

  1. 如何定义和训练判别器模型
  2. 如何定义和使用生成器模型
  3. 如何训练生成器模型
  4. 如何评估GAN模型的性能
  5. 如何利用最后的生成器模型来生成图片

在[3]的开头,其提到GAN,生成对抗网络是一种训练生成模型的网咯结构。如果是这样说的话,那就要具体考虑GAN的应用到底是什么了,你看这里文章[3]就是使用最后的生成器来进行图片生成。
(20200910 - 关于GAN的应用见文章[5],本质上我一直在强调这个事情,是我还没有完全将生成模型这个概念的应用给理解)

2.1 如何定义和训练判别器模型

因为处理的是minist的图片数据集,只需要加载keras库的数据即可;然后在经过一些处理,比如将数据维度扩充为三维(添加图片隧道维度),然后正则化为[0,1]范围内。还有一些其他不同的点,这里列举一下。

  • 在卷积神经网络中,没有使用池化层
  • 没有遍历整个数据集来进行训练,而是自己生成了相应的随机实例,然后输入
    下面是比较重要的内容。
    在这个过程中,就已经对GAN的判别器进行了训练,同时训练的过程中,输入的数据标签是“真”和“假”,同时因为,现在还没有相应的生成器,随意这里只是完全随机生成的一个向量,然后处理为相应维度。

那么这里总结一下大致逻辑。
1)构造一个两层的卷积神经网络的二分类网络(没有使用池化层,不过我感觉这里应该没有问题把,使用了也一样应该)
2)输入的数据,一部分类别为真的,就是从实际MNIST数据随机取出的数据,假的是自己随机生成的向量,然后转化为相应的维度
3)利用上述数据进行训练(20200822 注意这里的训练仅仅是为了进行测试示例,在最后的时候还要配合生成器来工作)

那么,在这个步骤结束之后,就能够得到一个输入是图片,然后输出这个图片是否是真假的判别器神经网络了。具体代码参考文章[3]。

2.2 如何定义和使用生成器模型

在这部分内容中,主要就是利用卷积层中的升采样来还原之前的图片大小,同时我还看到过有些直接使用了那种升维度的方式就是直接使用了dense层,所以从这个角度来看,其实内部是什么东西属于细节的地方(使用不同的模型可能导致不同的效果而已),但是并不影响整体的工作。
同时这里需要注意的是,定义了生成器模型之后,并没有使用compile方法来编译这个模型,这部分也是因为后面要将这部分模型与判别器连接起来,从而在这个混合模型中训练。

2.3 如何训练生成器模型

这部分也是耗费我耗费时间最多的部分,看了不少文章的方法,看到他们的做法都是差不多的。
首先要明白,训练生成器的过程是利用判别器的信息来反馈到生成器中,然后进行相应的权值更新。这部分内容比较关键。为了在这个过程中保证判别器的权值不会更新,就需要设置判别器的模型部分trainable=False,而且要注意这个东西的设置地点。这个部分的内容在另外一篇文章《GAN的学习 - 训练过程(冻结判别器)》中进行具体说明,这里只要知道,为了训练生成器模型,要将两个模型连接在一起,同时设置判别器不可训练。下面是设置GAN的模型。

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
	# make weights in the discriminator not trainable
	d_model.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(g_model)
	# add the discriminator
	model.add(d_model)
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt)
	return model

下图是GAN的原始训练过程。
GAN的原始训练过程

2.4 如何评估GAN模型的性能

在原文[3]中指出,GAN的性能是很难评估的,在本次实验中,因为使用的是MNIST图片,其图片内容比较简单,人眼就可以看出来,所以能够用主观的方式实现性能评估。但是,在其他任务中,是很难实现有针对性的性能评估的。原文内容如下:

Generally, there are no objective ways to evaluate the performance of a GAN model.
We cannot calculate this objective error score for generated images. It might be possible in the case of MNIST images because the images are so well constrained, but in general, it is not possible (yet).

针对MNIST数据集,他的评估方法包括:
1)周期性的检查判别器在区分真图片和假图片的性能
2)周期性的将生成器所生成的图片保存到硬盘,用于人眼主观检查
3)周期性的保存生成器模型

2.5 如何利用最后的生成器模型来生成图片

这部分相对来说比较简单,因为前文中提到,在GAN的训练过程中,已经周期性的保存了生成器的模型,那么就可以使用这部分保存的模型来进行图片的生成。
同时需要注意的是,这部分使用生成器进行图片生成时,需要一个随机的种子来驱动。

def generate_latent_points(latent_dim, n_samples):
	# generate points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	x_input = x_input.reshape(n_samples, latent_dim)
	return x_input

# load model
model = load_model('generator_model_100.h5')
# generate images
latent_points = generate_latent_points(100, 25)
# generate images
X = model.predict(latent_points)

3. 小节

通过文章[3],能够大致了解了GAN模型的构造及训练过程,其中比较关键的就是训练过程(关于具体的训练过程中的设置,见另外一篇文章);还有他这里提到了 一个比较关键性的问题,无法评估GAN的性能,这里能够评估是因为处理的数据集比较简单,同时是简单的数字图片所以可以人为的来进行评估,但是其他任务很难实现一个比较好的目标函数。

4. 后记

文章后面还包含了一些GAN与AE和VAE的区别,还有一些训练的技巧,同时还有一个GAN训练MNIST数据的代码。不过,我还是没有理解这个东西他在一些地方的实际应用是什么,最后时根据需求只要判别器或者生成器吗?这部分还是需要进一步来了解。

参考

[1]A Beginner’s Guide to Generative Adversarial Networks (GANs)
[2]教程 | 详解如何使用Keras实现Wassertein GAN
[3]How to Develop a GAN for Generating MNIST Handwritten Digits
[4]Generative Adversarial Nets in TensorFlow
[5]Wasserstein GAN implementation in TensorFlow and Pytorch

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN的学习 - 基础知识了解 的相关文章

  • numba 函数何时编译?

    我正在研究这个例子 http numba pydata org numba doc 0 15 1 examples html multi threading http numba pydata org numba doc 0 15 1 ex
  • Python中Decimal类型的澄清

    每个人都知道 或者至少 每个程序员都应该知道 http docs oracle com cd E19957 01 806 3568 ncg goldberg html 即使用float类型可能会导致精度错误 然而 在某些情况下 精确的解决方
  • 使用 python 进行串行数据记录

    Intro 我需要编写一个小程序来实时读取串行数据并将其写入文本文件 我在读取数据方面取得了一些进展 但尚未成功地将这些信息存储在新文件中 这是我的代码 from future import print function import se
  • Python逻辑运算符优先级[重复]

    这个问题在这里已经有答案了 哪个运算符优先4 gt 5 or 3 lt 4 and 9 gt 8 这会被评估为真还是假 我知道该声明3 gt 4 or 2 lt 3 and 9 gt 10 显然应该评估为 false 但我不太确定 pyth
  • Django 模型在模板中不可迭代

    我试图迭代模型以获取列表中的第一个图像 但它给了我错误 即模型不可迭代 以下是我的模型和模板的代码 我只需要获取与单个产品相关的列表中的第一个图像 模型 py class Product models Model title models
  • if 语句未命中中的 continue 断点

    在下面的代码中 两者a and b是生成器函数的输出 并且可以评估为None或者有一个值 def testBehaviour self a None b 5 while True if not a or not b continue pri
  • 如何在 pytest 中将单元测试和集成测试分开

    根据维基百科 https en wikipedia org wiki Unit testing Description和各种articles https techbeacon com devops 6 best practices inte
  • 以同步方式使用 FastAPI,如何获取 POST 请求的原始正文?

    在中使用 FastAPIsync not async模式 我希望能够接收 POST 请求的原始 未更改的正文 我能找到的所有例子都显示async代码 当我以正常同步方式尝试时 request body 显示为协程对象 当我通过发布一些内容来
  • 从 python 发起 SSH 隧道时出现问题

    目标是在卫星服务器和集中式注册数据库之间建立 n 个 ssh 隧道 我已经在我的服务器之间设置了公钥身份验证 因此它们只需直接登录而无需密码提示 怎么办 我试过帕拉米科 它看起来不错 但仅仅建立一个基本的隧道就变得相当复杂 尽管代码示例将受
  • 如何解决使用 Spark 从 S3 重新分区大量数据时从内存中逐出缓存的表分区元数据的问题?

    在尝试从 S3 重新分区数据帧时 我收到一个一般错误 Caused by org apache spark SparkException Job aborted due to stage failure Task 33 in stage 1
  • 奇怪的 MySQL Python mod_wsgi 无法连接到 'localhost' (49) 上的 MySQL 服务器问题

    StackOverflow上也有类似的问题 但我还没有发现完全相同的情况 这是在使用 MySQL 的 OS X Leopard 机器上 一些起始信息 MySQL Server version 5 1 30 Apache 2 2 13 Uni
  • 在 pytube3 中获取 youtube 视频的标题?

    我正在尝试构建一个应用程序来使用 python 下载 YouTube 视频pytube3 但我无法检索视频的标题 这是我的代码 from pytube import YouTube yt YouTube link print yt titl
  • Python 将日志滚动到变量

    我有一个使用多线程并在服务器后台运行的应用程序 为了无需登录服务器即可监控应用程序 我决定包括Bottle http bottlepy org为了响应一些HTTP端点并报告状态 执行远程关闭等 我还想添加一种查阅日志文件的方法 我可以使用以
  • 无法在 osx-arm64 上安装 Python 3.7

    我正在尝试使用 Conda 创建一个带有 Python 3 7 的新环境 例如 conda create n qnn python 3 7 我收到以下错误 Collecting package metadata current repoda
  • 使用yield 进行字典理解

    作为一个人为的例子 myset set a b c d mydict item yield join item s for item in myset and list mydict gives as cs bs ds a None b N
  • Tkinter - 浮动窗口 - 调整大小

    灵感来自this https stackoverflow com a 22424245 13629335问题 我想为我的根窗口编写自己的调整大小函数 但我刚刚注意到我的代码显示了一些性能问题 如果你快速调整它的大小 你会发现窗口没有像我希望
  • Ubuntu 上的 Python 2.7

    我是 Python 新手 正在 Linux 机器 Ubuntu 10 10 上工作 它正在运行 python 2 6 但我想运行 2 7 因为它有我想使用的功能 有人敦促我不要安装 2 7 并将其设置为我的默认 python 我的问题是 如
  • 在Python中按属性获取对象列表中的索引

    我有具有属性 id 的对象列表 我想找到具有特定 id 的对象的索引 我写了这样的东西 index 1 for i in range len my list if my list i id specific id index i break
  • 如何读取Python字节码?

    我很难理解 Python 的字节码及其dis module import dis def func x 1 dis dis func 上述代码在解释器中输入时会产生以下输出 0 LOAD CONST 1 1 3 STORE FAST 0 x
  • 您可以使用关键字参数而不提供默认值吗?

    我习惯于在 Python 中使用这样的函数 方法定义 def my function arg1 None arg2 default do stuff here 如果我不供应arg1 or arg2 那么默认值None or default

随机推荐

  • 蚁群算法(Ant Colony Optimization,ACO)

    1 算法基本思想 在自然界中 蚂蚁群体在寻找食物的过程中 无论是蚂蚁与蚂蚁之间的协作还是蚂蚁与环境之间的交互均依赖于一种被称为信息素 Pheromone 的物质实现蚁群的间接通信 从而通过合作发现从蚁穴到食物源的最短路径 蚂蚁在寻找食物的过
  • 2019.9最新JRebel激活方式

    原文链接 最近JRebel离线方式到期 idea报无法激活JRebel了 找了很多以前的方式都无法生效 ip或域名都已经失效了 好在找到了大神有效的激活方式 以下是激活步骤 1 下载反向代理软件 下载地址 https github com
  • Linux安装python3

    1 获取安装包 第一种方式 通过官网下载 登录 https www python org downloads source 下载所需安装包并上传至服务器 第二种方式 通过命令行的下载工具 以python3 6 1为例 wget https
  • ViewModel 使用及原理解析

    本文是基于 androidx lifecycle lifecycle extensions 2 0 0 的源码进行分析 ViewModel旨在以生命周期意识的方式存储和管理用户界面相关的数据 它可以用来管理Activity和Fragment
  • WDA学习笔记(二)通过页面跳转理解WDA开发流程

    在进行开发之前先简单介绍一下WDA的控制器 WDA控制器包括组件控制器 定制控制器 视图控制器和窗口控制器 组件控制器 每个 Web Dynpro 组件只有一个组件控制器 该控制器是全局控制 器 对所有其它控制器可见 组件控制器可以控制整个
  • 服务器重装

    搜索自己的品牌看如何进入bios 这里是在最后进入前按del bios的命令 在boot里面主要有 Boot Settings Configuration 启动选项设定 Boot Device Priority 启动顺序设置 Hard Di
  • Vue项目Vite配置代理解决跨域问题

    Vite 一个Vue作者开发的Web开发工具 它具有快速的冷启动 及时的模块热更新 真正的按需加载 Vite基于浏览器原生 ES imports 的开发服务器 利用浏览器去解析 imports 在服务器端按需编译返回 完全跳过了打包这个概念
  • 一文读懂函数指针

    前言 本篇是关于函数指针的保姆级教程 一 函数指针的定义和修饰 函数指针广泛应用于嵌入式软件开发中 其常用的两个用途 调用函数和做函数的参数 void fptr 把函数的地址赋值给函数指针 一般采用如下形式 fptr Function 如果
  • 在CentOS 7中使用SAMBA部署文件共享服务

    SMB Server Message Block 服务信息块 又称CIFS Common Internet File System 通用Internet文件系统 是一种应用层网络传输协议 微软公司和英特尔公司于1987年共同制定了SMB 旨
  • 电源防反接小结

    前言 为了方便查看博客 特意申请了一个公众号 附上二维码 有兴趣的朋友可以关注 和我一起讨论学习 一起享受技术 一起成长 1 概述 电源的输入部分 为了防止误操作 将电源的正负极接反 对电路造成损坏 一般会对其进行防护 如采用保险丝 二极管
  • Cause: java.sql.SQLSyntaxErrorException: FUNCTION test5.count does not exist. Check the ‘Function Na

    解决方案 戴脑子 删空格
  • 虚拟机无法连接到图形服务器,详解VMware 当中出现:无法将 Ethernet0 连接到虚拟网络"VMnet8"的问题...

    此文 是通过查阅各位大神的经验总结得出的小小的结论 只是为了记录自己在学习过程中 遇到的问题而写 假若能帮到大家 十分荣幸 当vmvare出现 无法将 ethernet0 连接到虚拟网络 vmnet8 的问题 出现本问题的情况 是在存在主机
  • Unity ECS学习笔记(一)

    ECS架构概述 ECS术语 实体Entity 像容器一样 组件数据Component Data 要存储在实体中的数据 不包括处理 组件系统ComponentSystem 处理 组Group 组件系统运行所需的ComponentData列表
  • 交叉熵代价函数(cross-entropy cost function)

    1 从方差代价函数说起 代价函数经常用方差代价函数 即采用均方误差MSE 比如对于一个神经元 单输入单输出 sigmoid函数 定义其代价函数为 其中y是我们期望的输出 a为神经元的实际输出 a z where z wx b 在训练神经网络
  • IntelliJ IDEA 配置go语言环境(图文教程)

    首先确保你电脑安装了go并配置了环境变量 我这是Win10 golang版本1 17 3 cmd输入go version 回车 出现版本信息 说明本地环境没问题 配置 Idea 我的 idea版本2019 3 5发行的 选择File set
  • Jaspersoft Studio安装

    Jaspersoft Studio 一个专为JasperReport 报表引擎而开发的报表设计器 基于Eclipse实现 它能够创建图表 图片 子报表 交叉表等复杂的布局 并可以通过JDBC TableModels JavaBeans XM
  • unity鼠标动态点击事件

    unity鼠标动态点击事件 通过鼠标点击按钮来改变text中的内容 这个功能主要就是通过按钮来监听onclick这个组件 第一步 新建一个按钮 按钮自带一个text 所以就不用新建一个text了 第二步 新建一个脚本 using Unity
  • 【shell】linux输出重定向

    目录 即看即用 详细 知识铺垫 说明 shell 输出重定向2 1 即看即用 标准输出 ls thereisno 1 gt out txt 标准输出重定向 也可以不加1写成 ls thereisno gt out txt 标准错误 ls t
  • pyecharts生成并保存图片

    文章目录 一 安装snapshot selenium 二 采用chromediver进行图片操作 1 下载chromediver 2 将chromedriver exe复制到Chrome浏览器安装目录和Python根目录 3 配置PATH
  • GAN的学习 - 基础知识了解

    20200818 0 引言 最近看到了一些论文 是GAN在密码生成 PassGAN DGA检测 DeepDGA 这些论文 所以希望深入了解一下GAN的内容 之前的时候 只是知道GAN是什么东西 通过训练两个神经网络 然后相互促进来实现检测的