GAN数学原理及代码实现

2023-11-19

GAN

generator 和discriminator

生成式对抗网络(Generative Adversarial Networks, GAN),包括生成器 generator 和判别器 discriminator。

生成器(generator)是一个神经网络,根据不同的输入向量可以生成不同特征的图像或者语音等。

判别器 (discriminator)也是一个神经网络,其输入是 generator 的输出,输出为一个标量。discriminator 用于判别 generator 的输出和真实数据的相似情况。输出的标量值越大,对应的 generator 生成的图片更加真实。
在这里插入图片描述
最开始,generator 产生了一堆近似噪声的东西,训练 discriminator 并固定 generator 使得其能够判别 generator 产生的和数据集中图片的区别。

接下来固定 discriminator,来训练 generator 使得 generator 生成的图像可以骗过 discriminator。

如此循环迭代,直到 generator 的图像满意为止。

GAN 中的训练算法可以表述为如下:
在这里插入图片描述
上图中,蓝色的框表示 discriminator 的训练,红色框表示 generator 的训练。输入到 generator 中 的向量 z z z$可以从一个分布中随机采样得到。它和数据集中的数据 x x x 并没有直接的关系, x ~ \widetilde{x} x 表示生成的数据。
下面重点来探讨两个 V ~ \widetilde{V} V 的函数。需要调整参数使得 V ~ \widetilde{V} V 最大,因此使用的是梯度上升的方法而不是梯度下降,所以参数更新那里会有一个负号的差距。

  1. 对于判别器,当判别器的输入为数据集中的真实图片,那么 D ( x i ) D(x^i) D(xi) 越大,对应的 V ~ \widetilde{V} V 将增大,而对于生成器产生的数据 x ~ \widetilde{x} x 越小,对应的 V ~ \widetilde{V} V 增大。这样判别器优化的最终结果是对于数据集中的数据,会给出一个很高的分数(接近 1),而对于生成器生成的数据,则给出很低的分数。
  2. 对于生成器,优化目标依然是使得 V ~ \widetilde{V} V 最大,但是需要固定住判别器。这样的训练的最终结果是生成器骗过判别器,使得生成的数据经过判别器输出接近 1。

对于具体的实现,一般会将生成器和判别器放进一个网络中去,对于生成器的训练,冻结判别器;而对于判别器的训练,则冻结生成器。

一般来说,会将判别器的输出通过 sigmoid 归一化,并且一次生成器的训练不会迭代太多次,这样生成的数据不会有太大的变化。

仅使用generator来生成图片

generator 缺乏全局观。以图像数据为例,输出的每一个神经元对应一个像素值,但是更重要的是,输出的像素对应的值要构成一幅有意义的图像。通常的神经网络结构,同一层输出神经元之间是相互独立的,这样会导致输出图像的不和谐。
在这里插入图片描述

仅使用discriminator来生成图片

相比较 generator 直接生成,discriminator 生成很容易就有全局观念,但是也会有新的问题。主要如下:

  1. 利用 discriminator 生成需要求解
    在这里插入图片描述
    这需要很大的搜索量,对于图像而言,遍历所有的图像是几乎不可能完成的任务。
  2. 怎样获得一个训练良好的 discriminator。训练一个 discriminator,需要标注什么样的图像是好的,什么样的图像是不好的,而好与不好很难去界定。

条件GAN

条件 GAN 是对生成器和判别器分别添加一个对应的条件,以文本转图像为例,生成器 generator 希望生成和文本相关的图像,就需要将原始的文本作为条件,和一个随机向量一起输入到 generator 中,这告诉 generator,需要根据文本来生成对应的图像。

对于文本转图像的判别器 discriminator,需要将源文本输入到 discriminator 中作为判别的一个条件,用来告诉 discriminator,generator 不仅要生成合理的图像,生成的图像还要和文本对应。

如果不加条件,那么可能就会生成和文本描述无关的内容。
在这里插入图片描述

判别器的两种设计结构

条件 GAN 的 discriminator 的设计大致思路有以下两种:

  • 将 generator 生成的数据 x 和一个条件 c 分别通过一个网络,进行编码,然后通过一个网络生成一个总的分数;
  • 分别对 generator 生成的数据质量,和条件符合情况进行打分;
    在这里插入图片描述

条件 GAN 的训练

在文本转图像的场景下,c 一般为文本,或者文本的编码,x 为数据集中该文本对应描述的图像。
在这里插入图片描述
discriminator 的训练过程,相比较原始 GAN 的训练,还需要负样本,就是可以从数据集中拿出图像和描述不匹配的样本来协助训练。

GAN 背后的数学理论

KL 散度

KL 散度又称为相对熵,信息散度,信息增益。

KL 散度是是两个概率分布 P 和 Q 差别的非对称性的度量,即分布 P 和 分布 Q 的距离。

典型情况下,P 表示数据的真实分布,Q 表示数据的理论分布,模型分布,或 P 的近似分布。
在这里插入图片描述

JS 散度

JS 是 Jensen-Shannon 的缩写,JS 散度度量了两个概率分布的相似度,基于 KL 散度的变体,解决了 KL 散度非对称的问题。一般地,JS 散度是对称的,其取值是 0 到 1 之间。
在这里插入图片描述

生成器和判别器的数学解释

一幅图像可以看成是高维空间中的一个点,比如 64 × 64 64 \times 64 64×64 的一个灰度图像,那么可以认为该图像是 64 × 64 64 \times 64 64×64 维度空间内的一个点。

特定的图像组成特定的分布,比如所有的人脸图像可能在对应的 64 × 64 64 \times 64 64×64 维空间构成一个分布,只有该分布内的图像看起来像人脸。

生成器 generator 将一个确定分布的输入转换为一个未知分布的一个采样点,这个分布直观上来看,以图像为例,就是 generator 产生的一个图像。而 GAN 的训练目标就是使得 generator 输出的分布尽可能和数据集的分布接近。

一般来说,衡量 generator 产生的分布和对应的目标分布之间的距离的方法是 KL 散度或者 JS 散度,但是对于 GAN,并不知道目标分布的表达式,generator 由于是个神经网络,也不知道 generator 生成的分布的表达式,所以如何衡量两个分布之间的距离是一个问题。

在 GAN 中,discriminator 就成了衡量分布之间差距的工具。事实上,经典 GAN 的损失函数实际上就是在衡量两个分布的 JS 散度。

generator,可以表述为
在这里插入图片描述
本质上是衡量两个分布 JS 散度的过程,可通过公式推导证明。

生成器和判别器的对抗训练过程

GAN 的训练过程是解一个 min-max 的问题,最终的目标是最小化 generator 生成的数据和真实数据的分布差距。

其中 max 的过程是计算 JS 散度的过程,而 min 的过程是使得散度最小化的过程。

对于 generator 的更新,仅需要少量的迭代,而对 discriminator 的训练则应该尽可能多一些,这是因为固定 discriminator 训练 generator 是减少 JS 散度的的过程,少量的 generator 更新可以让 discriminator 更好的度量 JS 散度。

GAN的训练技巧

GAN 的训练过程是生成器 generator 和判别器 discriminator 对抗平衡的过程,难点就在平衡问题上。

如果判别器性能过于强大,那么由于 sigmoid 激活函数的原因,几乎很少有梯度信息传递到生成器。

判别器 discriminator 的训练其实是衡量两个分布 JS 散度的过程。JS 散度有一个缺陷,当两个分布的采样数据不重合的时候,其值恒为 log2,这样当生成器生成的数据和真实采用的数据无重叠时就很难更新。

一个朴素的解决办法就是不要把判别器训练的太好,这样使得 sigmoid 有足够的梯度返回给生成器。

LSGAN

LSGAN 是一种简单的改进 GAN 的方法,去掉 sigmoid,直接用 linear 来输出。

WGAN

WGAN 是 GAN 训练优化非常重要的技术, wasserstein 距离也可以用来描述分布之间的差距,但是克服了JS 散度的缺陷。
为了衡量 P G P_G PG P d a t a P_{data} Pdata之间的 wessertein 距离, V ( G , D ) V(G, D) V(G,D) 修改为:
V ( G , D ) = m a x D ∈ 1 − L i p s c h i t z { E x → P d a t a [ D ( x ) ] − E x → P G [ D ( x ) ] } V(G, D) = \mathop{max}\limits_{D \in 1-Lipschitz} \{E_{x \to P_{data}}[D(x)] - E_{x \to P_G}[D(x)] \} V(G,D)=D1Lipschitzmax{ExPdata[D(x)]ExPG[D(x)]}
x 从 P d a t a P_{data} Pdata 取出来时, D ( x ) D(x) D(x)越大越好;相反当 x 从 P G P_G PG 取出来时, D ( x ) D(x) D(x)越小越好。

判别器 D 必须是 1-Lipschitz 函数,即这个判别器函数要比较平滑。
所谓的 Lipschitz 函数是指满足如下要求的函数:
∣ ∣ f ( x 1 ) − f ( x 2 ) ∣ ∣ ≤ K ∣ ∣ x 1 − x 2 ∣ ∣ ||f(x_1) - f(x_2)|| \leq K ||x_1 - x_2|| ∣∣f(x1)f(x2)∣∣K∣∣x1x2∣∣
令 K = 1,即为 1 − L i p s c h i t z 1-Lipschitz 1Lipschitz,表达式为:
∣ ∣ f ( x 1 ) − f ( x 2 ) ∣ ∣ ≤ ∣ ∣ x 1 − x 2 ∣ ∣ ||f(x_1) - f(x_2)|| \leq ||x_1 - x_2|| ∣∣f(x1)f(x2)∣∣∣∣x1x2∣∣
这样很显然就限制了函数的变化率,让函数变得更加平滑。

在最原始的 WGAN 论文中,通过限制住判别器的网络的权重的大小来解决这样的问题。
在这里插入图片描述
WGAN 相较于原始 GAN 的改动主要如下:

  1. 去掉判别器 discriminator 输出的 sigmoid 激活函数;
  2. 对于判别器的损失函数,修改为 V ~ = 1 m ∑ i = 1 m D ( x i ) − 1 m ∑ i = 1 m D ( x ~ i ) \widetilde{V} = \frac{1}{m} \sum_{i=1}^m D(x^i) - \frac{1}{m} \sum_{i=1}^m D(\widetilde{x}^i) V =m1i=1mD(xi)m1i=1mD(x i),求梯度上升;
  3. 对于判别器,添加 weight clipping,即限制网络权重的范围,以及 gradient penalty 等方法;
  4. 对于生成器的损失函数,修改为 V ~ = − 1 m ∑ i = 1 m D ( G ( z i ) ) \widetilde{V} = - \frac{1}{m} \sum_{i=1}^m D(G(z^i)) V =m1i=1mD(G(zi)),求梯度下降。

GAN生成卡通头像

模型架构:
生成器是由简单的转置卷积 + BatchNorm + LeakyReLU激活函数组成;

判别器由普通的 conv+batchnorm+leaky_relu 结构的模块来构建。
在这里插入图片描述
仅使用预训练生成器生成图片效果:
在这里插入图片描述

损失函数与训练

将判别器看成是一个二分类的分类器,分类一张图片是不是真实的图像,使用 nn.BCELoss() 接口。

BCE 是二分类交叉熵(binary cross entropy)的缩写,简单来说就是计算二分类问题的交叉熵。

其计算方式如下:
l n = − w n [ y n ∗ l o g ( x n ) + ( 1 − y n ) l o g ( 1 − x n ) ] l_n = -w_n [y_n * log(x_n) + (1 - y_n)log(1 - x_n)] ln=wn[ynlog(xn)+(1yn)log(1xn)]
其中 n 表示一条数据条目,一个 batch 由 N 个数据条目组成。 w n w_n wn表示第 n 条数据的损失函数权重,一般不考虑,这已经部分的实现了 WGAN 。

训练分为生成器 generator 和判别器 discriminator 的训练。
对于经典的 GAN 模型,训练过程首先固定生成器,训练判别器 k 次;

之后会固定判别器,训练生成器,一般而言生成器仅迭代一次。

image-to-image模型搭建

Image-to-Image Translation with Conditional Adversarial Network
https://arxiv.org/abs/1611.07004
将抽象的简单的房屋示意图转换为真实的房屋图像
在这里插入图片描述
生成器搭建:UNet,是一种 U 形的网络,通过卷积下采样和反卷积的上采样形成 U 形的结构,同时增加层之间的 skip 操作。
在这里插入图片描述
为了搭建模型的方便,首先对常用的 conv-batchnorm-relu 以及 transconv-batchnorm-relu 做了简单的一步封装,即下方的 ConvBnReLU 类和 TransConvBnReLU 类。

之后是生成器类(Pix2PixGenerator),依照 UNet 的结构,将生成器分为下采样层和上采样层,上下采样层之间的连接,即 skip 是通过 concatenate 操作来实现的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN数学原理及代码实现 的相关文章

随机推荐

  • C++ 模板特化

    模板的特化 在使用模板时 可以实现一些与类型无关的代码 但对于一些特殊类型的可能会得到一些错误的结果 这时就一些需要特殊处理 对模板进行特化 在原模板类的基础上 针对特殊类型所进行特殊化的实现方式 模板特化又分为 函数模板特化 类模板特化
  • 关于Mysql线程的基本设置

    客户端发起连接到mysql server mysql server监听进程 监听到新的请求 然后mysql为其分配一个新的 thread 去处理此请求 从建立连接之开始 CPU要给它划分一定的thread stack 然后进行用户身份认证
  • 手把手教你部署AutoGPT,30分钟拥有自己的AI助手!

    如果不想往下看了 那就直接 点我 AutoGPT是由GPT 4驱动的开源应用程序 可以自主实现用户设定的任务目标 从AutoGPT开始 AI将可以自主地提出计划 然后执行计划 还具有互联网访问 长期和短期内存管理 用于文本生成的GPT 4实
  • std::packaged_task的简单使用

    std packaged task 包装一个可调用的对象 并且允许异步获取该可调用对象产生的结果 从包装可调用对象意义上来讲 std packaged task 与 std function 类似 只不过 std packaged task
  • 【Java】网络编程——多线程下载文件

    前言 多线程下载文件 比单线程要快 当然 线程不是越多越好 这和获取的源文件还有和网速有关 原理 在请求服务器的某个文件时 我们能得到这个文件的大小长度信息 我们就可以下载此长度的某一个片段 来达到多线程下载的目的 每条线程分别下载他们自己
  • docker使用(一)生成,启动,更新(容器暂停,删除,再生成)

    docker使用 一 编写一个 Dockerfile 构建镜像 构建失败 构建成功 运行镜像 运行成功 修改代码后再次构建 请不要直接进行构建 要将原有的旧容器删除或暂停 停止成功 删除成功 再次构建且构建成功 要创建一个镜像 你可以按照以
  • 最全前端性能优化总结

    最全前端性能优化总结 前端性能优化分两部分 一 加载性能优化 1 减少请求次数 为什么减少请求次数 减少请求次数方式 2 减少资源大小 减少资源大小方式 3 网络优化 其他 二 渲染性能优化 浏览器渲染过程 重排 重绘 渲染性能优化方式 三
  • GB28181状态信息报送解读及Android端国标设备接入技术实现

    今天主要聊聊GB T28181状态信息报送这块 先回顾下协议规范相关细节 然后再针对代码实现 做个简单的说明 状态消息报送基本要求 当源设备 包括网关 SIP设备 SIP客户端或联网系统 发现工作异常时 应立即向本 SIP监控域 的SIP服
  • Qume-KVM虚拟化

    Qume KVM虚拟化 文章目录 虚拟化概述 KVM概述 KVM虚拟化架构 Qume概述 部署Qume KVM KVM Web管理界面安装 Web管理界面 添加连接 新建存储池 新建镜像 新建网络 实例管理 虚拟化概述 什么是虚拟化 虚拟化
  • 用Python画出圣诞树,瞧瞧我这简易版的吧

    前言 嗨嗨 大家好 我是小圆 今天来实现一下 用python画出圣诞树 代码 模块 源码 点击领取即可 import turtle as t from turtle import import random as r import time
  • 32种针对硬件与固件的漏洞攻击

    2018年1月 全球计算机行业因为Meltdown以及Spectre这两个在处理器中存在的新型漏洞而受到威胁 这两个漏洞直接打破了分离内核以及用户内存的OS安全边界 这两个漏洞基于了现代CPU的预测执行功能 而缓解这两个漏洞带来的影响则需要
  • 最快方式 ESP-IDF 创建例子 教程

    需要条件 安装了 VSCODE 安装了插件 Espressif IDF工具 系统中安装了 ESP IDF 可使用离线包 或在线安装包 在插件中配置了 ESP IDF 可能需要在线更新一些东西 点击F1 输入 ESP 等待提示 出现提示后 选
  • 软件测试 接口测试 入门Jmeter 接口关联 提取器 断言 与fiddler配合使用 使Jmeter录制和创建脚本 操作数据库 持续集成测试

    文章目录 1 接口测试概述 1 1 什么是接口测试 1 2 接口分类 1 3 接口的设计风格分类 1 3 1 Soap架构 1 3 2 Rpc架构 1 3 3 RestFul架构 1 3 4 接口测试工具介绍 1 4 接口测试流程 2 Jm
  • 使用 Vue.js 结合bootstrap 实现的分页控件

    使用 vue js 结合 bootstrap 开发的分页控件 效果如下 实现代码 div class contai div
  • 毕业设计-基于卷积神经网络的花卉图片识别

    目录 前言 课题背景和意义 实现技术思路 一 LeNet 5 卷积神经网络模型 二 设计思路 三 实验及结果分析 四 总结 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要
  • vue项目使用externals优化打包体积

    查看打包体积 下载查看打包体积的插件 npm install webpack bundle analyzer save dev 在vue config js中配置 chainWebpack config gt 打包结果分析 if proce
  • prompt 综述

    动手点关注 干货不迷路 1 概述 1 1 基本概念 用一句话概括模板学习 即将原本的输入文本填入一个带有输入和输出槽位的模板 然后利用预训练语言模型预测整个句子 最终可以利用这个完整的句子导出最终需要的答案 模板学习最吸引人的关键在于其通过
  • Spring Boot 项目在本地可以成功访问但是在服务器上无法访问 Controller 方法解决办法

    这是一篇记录自己失了智的博客 晚上写了一个小 Demo 来测试在云服务器上同时运行两个 jar 包的情况 两个项目的端口分别为 8080 和 8880 以 8880 为端口的 Demo 在本地成功运行并且访问到了 Controller 中的
  • 2023华为OD机试真题【最大平分数组/动态规划】

    题目描述 给定一个数组nums 可以将元素分为若干个组 使得每组和相等 求出满足条件的所有分组中 最大的平分组个数 输入描述 第一行输入 m 接着输入m个数 表示此数组 数据范围 1 lt M lt 50 1 lt nums i lt 50
  • GAN数学原理及代码实现

    GAN generator 和discriminator 生成式对抗网络 Generative Adversarial Networks GAN 包括生成器 generator 和判别器 discriminator 生成器 generato