GAN原理及Pytorch框架实现GAN(比较容易理解)

2023-05-16

目录

1.初识GAN

什么是GAN?

GAN应用场景

2.GAN原理结构

(1)生成对抗网络子网络

(2)结构图

(1)生成器 

(2)判别器

(3)训练技巧 

3.GAN网络模型选择

(1)生成模型

(2)判别模型

4.GAN训练目标函数

(1)生成模型

(2)判别模型

5.训练算法

6.GAN代码实现

7.mainWindow窗口显示生成器生成的图片

拓展


1.初识GAN

  • 什么是GAN?

    • GAN(Generative Adversarial Networks):生成对抗网络;
    • GAN是当前人工智能领域最为重要的研究热点之一,并且应用非常的广泛;
  • 2014年,Universite de Montreal 大学Yoshua Bengio(2018年图灵奖获得者)的学生Ian Goodfellow提出 生成对抗网络(Generative adj-terminal networks,简称 GAN),从而开辟了深度学习最赤手可热的研究方向。
  • 从2014-2019年,GAN的研究稳步推进,研究捷报频传,最新的GAN算法在图片生成上的效果甚至达到了肉眼很难分辨的程度。由于GAN的发明,Ian Goodfello荣获GAN之父称号,并获得了2017年麻省理工大学科技评论颁奖的35 Innovators Goodfellow奖项。
  • 该方法利用了两个网络,一个称为生成网络,另一个称为鉴别网络,可用于以音频、视频和文本的形式产生不同寻常的创造性输出。他的这项研究,在人工智能文献中被广泛引用。
  • GAN应用场景

    • 图像编辑:给定一张图像,可以在该图像的基础之上生成各种各样的图像;
    • 恶意攻击检测:深度学习生成的模型是可以被黑客攻击,利用甚至控制的。为了对抗这样的逆向攻击(adversarial attacks),可以训练对抗神经网络去生成更多的虚假训练数据作为假想敌,让模型在演习中去识别出这些虚假数据,GAN生成的虚假数据让正在做分类的模型更加稳健;
    • 数据生成:例如医疗领域,缺少训练数据是应用深度学习的最大障碍。数据增强的传统做法是将原图像拉伸旋转剪切,但这毕竟还是原来的图像,通过使用GAN,能够生成更多类似的数据;
    • 注意力预测:人类在看一张图片时,往往只关注特定的部分,而通过GAN模型,可以预测出人类关心的区域在哪里。
    • 三位结构生成:pix2vox是一个基于GAN的开源工具,能够根据手绘的二维图片,生成对应的三维结构,不止有对应的形状,还会生成对应的颜色,有了这样的工具,就能降低3D建模的门槛,从而让3D打印更容易的落地。

2.GAN原理结构

提示:下面的原理解释可能对于读者来说比较枯燥无味,但是还是希望读者可以坚持看完原理,因为只有这样你才能真正的理解GAN的实现原理。

(1)生成对抗网络子网络

 GAN包含:生成网络(Generator Network)和判别网络(Discriminator Network),其中生成网络Gen负责学习样本的真实分布,判别网络Dis负责将对生成网络生成的样本和真实样本分别进行判别。

(2)结构图

(1)生成器 

生成模型以随机噪声(Random noise)或者类似的控制变量作为输入,生成器一般使用多层的神经网络实现,其输出为生成的样本,也就是一张假的图片(fake image);这样样本和真实给定的样本一起给判别器模型训练。

"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/24 14:21
"""
import torch
import numpy as np
#对于生成器,输入的为正态分布随机数
#输出为: [1,28,28]图片

class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=100,out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256,out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512,out_features=784),
            torch.nn.Tanh()#对于生成器使用tanh激活函数更好
        )
    def forward(self,input):
        x = self.fc(input)
        img = x.view(-1,28,28)
        return img

(2)判别器

判别器模型是一个二分类器,判别一个样本是真实的样本还是生成器生成的样本,一般也是使用神经网络实现。

"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/24 14:21
"""
import torch
import numpy as np

#判别器的输入为一张图片
#输出为二分类的概率值
#判别器对log(1 - D(G(z)))的判别作为生成器的损失值

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features=784,out_features=512),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=512,out_features=256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=256,out_features=1),
            torch.nn.Sigmoid()
        )
    def forward(self,input):
        x = input.view(-1,784)
        x = self.fc(x)
        return x

(3)训练技巧 

  • 对于生成模型:训练目标是让生成的数据尽可能的与真实数据相似,最小化判别模型的判别准确率。
  • 对于判别模型:训练目标是最大化判别准确率,即区分样本是真实样本还是生成器生成的样本。

可以发现,这个过程是矛盾的,因此:

  • 在训练的过程中采用交替优化的方式,每一次迭代时分为两个阶段:
    • 第一个阶段:首先固定判别模型,优化生成模型,使得生成的数据备判别模型判定为真样本的概率尽可能的高。
    • 第二个阶段:固定生成模型,优化盘被模型,提高判别模型的分类准确率。

提示:在训练过程中,生成器努力地让生成的图像更加的真实,而判别器则努力地识别生成器图片的真假,这是是一个相互博弈的过程,互相提升自己,也就是不断的进行对抗的过程。随着训练的进行,生成模型产生的样本和真实样本几乎没有什么差别,判别模型也无法准确的判别一个样本的真假,此时的分类错误率为0.5(那什均衡)

3.GAN网络模型选择

生成对抗网络是一个抽象的框架,并没有指定生成模型和判别模型具体为哪一种模型,可以是神经网络模型,也可以是卷积神经网络 模型或者其他的机器学习模型。

(1)生成模型

        在本文中,生气模型选择是神经网络模型。根据类型等输入变量来生成图像之类的样本数据,生成模型接收的输入是类别之类的隐变量和随机噪声,输出与训练样本相似的样本数据(比如图片之类的)。

(2)判别模型

        判别模型一般用分类问题的神经网络,用于区分样本的真假(给定的真实数据和生成器生成的数据),是一个二分类问题。

4.GAN训练目标函数

提示:在确定生成模型和判别模型之前,首先了解一下logistic回归模型:

logistic回归即对数概率回归,是一种二分类问题的分类算法,使用sigmoid函数估计出样本属于正样本的概率(关于细节推导,建议看《机器学习原理,算法与应用》)。

logistic回归似然函数:

  • 回归对数函数和生成对抗区别:
    • logistic回归在训练达到最优点处时,负样本的预测输出接近于0;
    • 生成对抗网络中判别模型对抗样本的输出概率值在最优点处接近于0.5,。 

(1)生成模型

(2)判别模型

5.训练算法

  

6.GAN代码实现

提示:代码放在了Github上,读者自行下载:https://github.com/KeepTryingTo/Pytorch-GAN

 

7.mainWindow窗口显示生成器生成的图片

提示:这里编写了一个显示生成器显示图片的程序(mainWindow.py),加载之前训练之后保存的生成器模型,之后可使用该模型进行随机生成数字图片,如下:

(1)运行mainWindow.py 初始界面如下

 

 (2)点击生成图片按钮,每一次的点击生成的数字都不是一样的。

 

 

拓展

pytorch中的detach作用

参考文章:

《TensorFlow深度学习》

《机器学习原理,算法与应用》

https://www.jiqizhixin.com/articles/2019-04-15-6

https://b23.tv/6P7M8mh

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN原理及Pytorch框架实现GAN(比较容易理解) 的相关文章

  • 云计算思维导图

    根据近期的云计算学习心得 xff0c 将云计算部分内容制作成思维导图 xff0c 方便于广大云计算学习者作为辅导讲义 xff01 思维导图内容主要包含 xff1a 1 云计算概述 2 云体系结构 3 网络资源 4 存储资源 5 硬件介绍 6
  • 路由器重温——串行链路链路层协议积累

    对于广域网接口来说 xff0c 主要的不同或者说主要的复杂性在于理解不同接口的物理特性以及链路层协议 xff0c 再上层基本都是 IP 协议 xff0c 基本上都是相同的 WAN口中的serial接口主要使用点对点的链路层协议有 xff0c
  • 路由器重温——PPPoE配置管理-2

    四 配置设备作为PPPoE服务器 路由器的PPPoE服务器功能可以配置在物理以太网接口或 PON 接口上 xff0c 也可配置在由 ADSL 接口生成的虚拟以太网接口上 1 配置虚拟模板接口 虚拟模板接口VT和以太网接口或PON接口绑定后
  • Python入门自学进阶——1--装饰器

    理解装饰器 xff0c 先要理解函数和高阶函数 首先要明白 xff0c 函数名就是一个变量 xff0c 如下图 xff0c 定义一个变量名和定义一个函数 xff0c 函数名与变量名是等价的 既然函数名就是一个变量名 xff0c 那么在定义函
  • Python入门自学进阶-Web框架——21、DjangoAdmin项目应用

    客户关系管理 以admin项目为基础 xff0c 扩展自己的项目 一 创建项目 二 配置数据库 xff0c 使用mysql数据库 xff1a 需要安全mysqlclient模块 xff1a pip install mysqlclient D
  • Python入门自学进阶-Web框架——33、瀑布流布局与组合查询

    一 瀑布流 xff0c 是指页面布局中 xff0c 在显示很多图片时 xff0c 图片及文字大小不相同 xff0c 导致页面排版不美观 如上图 xff0c 右边的布局 xff0c 因为第一行第一张图片过长 xff0c 第二行的第一张被挤到第
  • Python入门自学进阶-Web框架——34、富文本编辑器KindEditor、爬虫初步

    KindEditor 是一个轻量级的富文本编辑器 xff0c 应用于浏览器客户端 一 首先是下载 xff1a http kindeditor net down php xff0c 如下图 下载后是 解压缩后 xff1a 红框选中的都可以删除
  • Python入门自学进阶-Web框架——35、网络爬虫使用

    自动从网上抓取信息 xff0c 就是获取相应的网页 xff0c 对网页内容进行抽取整理 xff0c 获取有用的信息 xff0c 保存下来 要实现网上爬取信息 xff0c 关键是模拟浏览器动作 xff0c 实现自动向网址发送请求 xff0c
  • 6、spring的五种类型通知

    spring共提供了五种类型的通知 xff1a 通知类型接口描述Around 环绕通知org aopalliance intercept MethodInterceptor拦截对目标方法调用Before 前置通知org springfram
  • 路由器接口配置与管理——1

    路由器的接口相对于交换机来说最大的特点就是接口类型和配置更为复杂 xff0c 一般吧路由器上的接口分为三大类 xff1a 一类用于局域网的LAN接口 xff0c 一类用于广域网接入 互联的WAN接口 xff0c 最后一类可以应用于LAN组网
  • 路由配置与管理——静态路由配置与管理

    静态路由是一种最简单的路由 xff0c 需手工配置 xff0c 用一条指令指定静态路由的目的IP地址 子网掩码 下一跳IP地址 xff0c 或者出接口 优先级等主要参数值就可以了 还可根据实际需要配置静态路由与BFD或者NQA的联动 一 路
  • TCP实现局域网通信

    TCP实现局域网通信 TCP客户端通信步骤 xff1a 1 xff1a 创建套接字 sockfd 61 socket AF INET SOCK STREAM 0 2 xff1a 填写服务器结构体信息 span class token key
  • 路由策略和策略路由配置与管理-1

    路由策略和策略路由配置与管理 路由策略 与 策略路由 之间的区别就在于它们的主体 xff08 或者说 作用对象 xff09 不同 xff0c 前者的主体是 路由 xff0c 是对符合条件的路由 xff08 主要 xff09 通过修改路由属性
  • IP组播基础及工作原理——1

    IP组播在一些多用户定向发送的网络应用中使用非常普遍 xff0c 如远程多媒体会议 远程教学 视频点播 定向电子商务 xff0c 以及ISP的IPTV xff08 网络电视 xff09 等 学好IP组播基础知识及配置与管理方法 xff0c
  • IP组播配置与管理实战——1

    IGMP 配置与管理 IGMP xff08 InternetGroup Management Protocol xff0c 因特网组管理协议 xff09 是TCP IP 协议族 中负责IPv4组播成员管理 的协议 xff0c 需要在组播组成
  • Linux:利用返回值传出参数,地址传递,值传递,使用回调函数赋值几个例程。

    利用返回值传出参数 xff0c 地址传递 xff0c 值传递 使用回调函数赋值几个例程 代码 xff1a include lt stdlib h gt include lt stdio h gt include lt unistd h gt
  • 一步一步实现多尺度多角度的形状匹配算法(C++版本)

    前言 用过halcon形状匹配的都知道 xff0c 这个算子贼好用 xff0c 随便截一个ROI做模板就可以在搜索图像中匹配到相似的区域 xff0c 并且能输出搜索图像的位置 xff0c 匹配尺度 xff0c 匹配角度 现在我们就要利用op
  • 查看麒麟操作系统版本

    root 64 tbase01 nkvers Kylin Linux Version Release Kylin Linux Advanced Server release V10 Tercel Kernel 4 19 90 23 8 v2
  • Parallax-tolerant Image Stitching - 解决大视差图片拼接的方法

    Paper name Parallax tolerant Image Stitching Paper Reading Note URL http web cecs pdx edu fliu papers cvpr2014 stitching
  • Consistent Video Depth Estimation - 时序一致的视频深度估计算法

    Paper name Consistent Video Depth Estimation Paper Reading Note URL xff1a https arxiv org pdf 2004 15021 pdf 代码 URL xff1

随机推荐