PyTorch和TensorFlow生成对抗网络学习MNIST数据集

2023-10-28

介绍

生成对抗网络(简称GAN)是最近开发的最受欢迎的机器学习算法之一。

对于人工智能(AI)领域的新手,我们可以简单地将机器学习(ML)描述为AI的子领域,它使用数据来“教”机器/程序如何执行新任务。 一个简单的例子就是使用一个人的脸部图像作为算法的输入,以便程序学会在任何给定的图片中识别同一个人(它可能也需要负样本)。 为此,我们可以将机器学习描述为应用数学优化,其中一种算法可以表示多维空间中的数据,然后学习区分新的多维矢量样本是否属于目标分布。

生成对抗网络的魔法

事实证明,他们在建模和生成高维数据方面非常成功,这就是为什么它们如此受欢迎。 然而,它们并不是生成模型的唯一类型,其他类型包括变分自动编码器(VAE),pixelCNN / pixelRNN和real NVP。 每个模型都有其自身的权衡。

一些与GAN最相关的利弊是:

  • 他们目前生成最清晰的图像
  • 它们易于训练(因为不需要统计推断),并且仅需要反向传播即可获得梯度
  • 由于不稳定的训练动态,GAN难以优化
  • 他们无法进行统计推断:GAN属于直接隐式密度模型。他们在没有明确定义概率分布函数的情况下对p(x)进行建模。

生成模型是了解当今围绕我们的大量数据的最有前途的方法之一。 根据OpenAI,能够创建数据的算法可能在本质上更好地理解世界。

生成模型可被认为比其鉴别器包含更多的信息,因为它们也可用于判别任务,例如分类或回归(目标是诸如ℝ的连续值)。 通过对联合概率分布函数进行统计推断,可以计算出此类任务大部分时间所需的条件概率分布函数 p ( y ∣ x ) p(y \mid x) p(yx)

尽管生成模型可用于分类和回归,但是与某些情况下的生成方法相比,完全鉴别方法通常在鉴别任务上更为成功。

案例

在几个用例中,生成模型可以应用于:

  • 生成逼真的艺术品样本(视频/图像/音频)
  • 使用时序数据进行仿真和计划
  • 统计推断
  • 也可用于生成可扩展小型数据集的输入

GAN概述

生成对抗网络由两个模型组成:

  • 第一个模型称为生成器,它旨在生成与预期相似的新数据。生成器可以与人类的赝品相提并论,后者可以伪造艺术品。
  • 第二种模型称为鉴别器。 该模型的目的是识别输入数据是由伪造者生成的“真实”(属于原始数据集)还是“伪造”(fake)。 在这种情况下,鉴别器类似于艺术专家,后者试图将艺术品视为真实或赝品。

GAN数学模型

训练GAN

由于使用神经网络对生成器和鉴别器进行建模,因此可以使用基于梯度的优化算法来训练GAN。 在我们的编码示例中,我们将使用随机梯度下降法,因为事实证明该梯度下降法在多个领域中均已成功完成。

训练GAN的基本步骤可以描述如下:

  1. 采样噪声集和实数集,每个集的大小为m
  2. 在此数据上训练鉴别器
  3. 采样大小为m的另一个噪声子集
  4. 在此数据上训练生成器
  5. 从第1步重复

编程GAN

首先导入必要库

pip install torchvision tensorboardx jupyter matplotlib numpy

导入以下依赖项

import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

为了记录进度,我们将导入一个我创建的附加文件,使我们可以在控制台/ Jupyter中可视化训练过程,同时将其存储在TensorBoard中,以供那些已经知道如何使用它的人使用。

from utils import Logger

您需要下载文件并将其放在GAN文件所在的文件夹中。您不必了解此文件中的代码,因为它仅用于可视化目的。

数据集

我们将在这里使用MNIST数据集,它由大约60.000个手写数字的黑白图像组成,每个图像的尺寸为28x28像素。 该数据集将根据一些有用的技巧进行预处理,这些技巧被证明对训练GAN很有用。

具体来说,介于[0,255]之间的输入值将在-1和1之间归一化。这意味着值0将被映射为-1,值255被映射为1,并且类似地,介于两者之间的所有值都将得到a。 值在[-1,1]范围内。

网络

接下来,我们将从鉴别器开始定义神经网络。 该网络将以扁平化的图像作为输入,并返回其属于真实数据集或合成数据集的概率。 每个图像的输入大小将为28x28 = 784。 关于该网络的结构,它将具有三个隐藏层,每个隐藏层后面是Leaky-ReLU非线性和一个Dropout层,以防止过度拟合。将Sigmoid / Logistic函数应用于实值输出,以获取开放范围(0,1)中的值:

我们还需要一些其他功能,这些功能允许我们将扁平化的图像转换为二维表示,而另一种则相反。

另一方面,生成网络将潜变量向量作为输入,并返回784值向量,该向量对应于扁平化的28x28图像。 请记住,该网络的目的是学习如何创建手写数字的无法区别的图像,这就是为什么其输出本身就是新图像的原因。

该网络将具有三个隐藏层,每个隐藏层之后是Leaky-ReLU非线性。 输出层将具有TanH激活函数,该函数将结果值映射到(-1,1)范围内,该范围与我们预处理的MNIST图像所界定的范围相同。

我们还需要一些其他功能,以允许我们创建随机噪声。随机噪声将从此链接中提出的均值0和方差1的正态分布中采样。

def noise(size):
    '''
    Generates a 1-d vector of gaussian sampled random values
    '''
    n = Variable(torch.randn(size, 100))
    return n

优化

结果

最初生成的图像是纯噪声:

但是后来他们改进了,

在获得不错的合成图像之前,

也可以可视化学习过程。 正如您在下图中所看到的,开始时鉴别器错误非常高,因为它不知道如何正确地将图像分类为真实还是伪造。 当鉴别器变得更好并且其误差在步骤5k减小到约0.5时,生成器误差增加,证明了鉴别器的性能优于生成器,并且可以正确地对假样本进行分类。 随着时间的流逝和训练的继续,生成器误差会降低,这意味着生成的图像越来越好。 随着生成器的改进,鉴别器的误差也会增加,因为合成图像每次都变得越来越逼真。

生成器随时间的错误

鉴别器随时间的错误

我已经介绍了生成对抗网络。 我们首先了解它们是哪种算法,以及为什么它们如此重要。 接下来,我们探索了符合GAN的部分以及它们如何协同工作。 最终,我们通过编程并使用GAN的完全有效的实现进行编程,从而将理论与实践联系起来,该实现学会了创建MNIST数据集的综合示例。

本文源码

详情参阅 - 亚图跨际

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch和TensorFlow生成对抗网络学习MNIST数据集 的相关文章

随机推荐

  • Python+Flask(2)--通过flask paginate解决列表分页问题

    先看最终实现效果 实现主要步骤及重要代码如下 1 列表需要用到的数据源及内容自己随便建立 我这边用新闻资讯数据测试 CREATE TABLE article aid int 11 NOT NULL AUTO INCREMENT cat id
  • 【第一章】专栏介绍

    版本 修改时间 初稿 2023 03 26 补充 考研和就业的选择 2023 04 04 自我介绍 你好 我曾经是一名普通一本学生 专业是电子信息工程专业 从大二就开始独自一人自学后端开发 大三后面三年大部分时间都在图书馆或者实验室学习 在
  • 如何上传大文件(4GB)到虚拟机

    使用xhell上传大文件会报文件过大的异常 解决方案 可以使用 Everything 工具 实现快速便捷传送大文件到虚拟机 1 百度搜索Everything 进入官网下载 这个程序体量非常小 可以放心下载 2 下载完成后打开 点击工具一栏
  • linux下MySql服务器的安装(yum安装OK)

    root test219 mysql mysql V mysql Ver 14 14 Distrib 5 5 11 for Linux x86 64 using readline 5 1 mysql5 5在linux服务器上的安装 mysq
  • 2021.11.12总结

    把入门3循环结构的题大致写完了
  • 树莓派4B之Windows XP系统安装游戏(一)

    上一篇博文 树莓派4B安装windows xp windows 95 windows xp windows 95 for raspberry pi 4B 下一篇博文 树莓派4B之Windows XP系统安装游戏 二 目录 一 模拟器 游戏下
  • AI新手必看:如何区分参数和超参数

    相信所有人刚开始应用机器学习时 都会被两个术语混淆 计算机学科里有太多的术语 而且许多术语的使用并不一致 哪怕是相同的术语 不同学科的人理解一定有所不同 比如说 模型参数 model parameter 和 模型超参数 model Hype
  • 华为nova6se怎么升级鸿蒙,华为EMUI11支持哪些手机

    华为EMUI11适配机型有什么 首批支持EMUI11 更新的机型有 P40 系列 Mate30 系列 MatePad Pro系列等 10 款机型 先了解更多EMUI11适配机型相关内容的小伙伴下面和小编一起来看看吧 华为EMUI11适配机型
  • VC++ 图像颜色调节

    1 BMP图片在GDI方式下贴图 32位位图 半透明像素会显示黑色或白底 像素处理代码 void CrossImage CImage img if img IsNull return 确认该图像包含Alpha通道 if img GetBPP
  • JAVA的图形用户界面布局GUI入门(上)

    java的GUI企业里面用的比较少 现在主流的UI都使用HTML5 开发 Java提供了三个主要包 做GUI开发 java awt 包 主要提供字体 布局管理器 javax swing 包 商业开发常用 主要提供各种组件 窗口 按钮 文本框
  • 神经网络学习之一——M-P模型

    神经网络学习之一 M P模型 M P模型是什么 M P模型是于1943年美国神经生理学家沃伦 麦卡洛克 Warren McCuloch 和数学家沃尔特 皮茨 Walter Pitts 提出 是首个通过模仿神经元而形成的模型 结构图如下所示
  • 主机地址变更后,dubbo请求时依旧会寻址旧IP的问题

    机房迁移 导致测试服务器IP变更 比原于IP为192 168 1 105变更为10 1 9 120 服务源码未做任何变更 启动服务时依旧是旧地址请求 此问题由dubbo本地注册中心的缓存所致 清理掉即可 位置一般在于 用户目录 dubbo目
  • Redis(一)常见命令使用

    常见文件名 Redis cli使用命令 1 启动Redis 2 连接Redis 3 停止Redis 4 发送命令 1 redis cli带参数运行 如 2 redis cli不带参数运行 如 5 测试连通性 key操作命令 获取所有键 查询
  • PostgreSQL系列3:PostgreSQL导入导出SQL

    启动数据库 pg ctl D data db pgsql data l data db pgsql logs pgsql log start 关闭数据库 pg ctl D data db pgsql data stop 使用pgsql客户端
  • R语言实战学习--回归

    文章目录 普通最小二乘回归 OLS 简单线性回归 多项式回归 多元线性回归 回归诊断 标准方法 QQ图正态性检验 残差图 误差的独立性 成分残差图 偏残差图 线性 同方差性 线性模型假设综合验证 异常观测值 高杠杆值 强影响点 变量添加图
  • 爬虫基础————ip地址和url详解

    学习慕课网bobby老师的课程从零起步 系统入门Python爬虫工程师时做的笔记 有兴趣的同学可以去慕课网观看视频 1 ip地址 整个网络传输可以比作快递 数据就是快递包裹 会经过一系列中转站 分包捡包等操作 最后才送到客户手中 Ip地址就
  • Python程序:输出杨辉三角的几种办法

    文章目录 一 问题描述 二 问题分析 三 第一种方法 1 具体代码 2 运行结果 3 程序的改进 四 第二种方法 1 具体代码 2 运行结果 五 总结分析 一 问题描述 给定一个非负整数 n 生成 杨辉三角 的前 n行 在 杨辉三角 中 每
  • 【文献调研】多任务学习-Part1

    基于数据增强和多任务学习的突发公共卫生时间谣言识别研究 摘要 Motivation 通过引入多任务学习模型和数据增强方法 解决突发公共卫生事件情景下谣言识别任务数据不平衡且带标签数据量少的问题 Methods 首先提取突发公共卫生事件谣言文
  • 《Learning Spark》第八章:调优及调试spark应用

    2020 07 05 引言 我记得当时我就是因为使用hadoop太过费劲了 才上手的spark 然后因为自己的机器性能不行 又一点一点调优 当时调优的过程 主要是从底层的结构上来进行调优 主要就是那些worker数量以及内存大小等等 但是对
  • PyTorch和TensorFlow生成对抗网络学习MNIST数据集

    介绍 生成对抗网络 简称GAN 是最近开发的最受欢迎的机器学习算法之一 对于人工智能 AI 领域的新手 我们可以简单地将机器学习 ML 描述为AI的子领域 它使用数据来 教 机器 程序如何执行新任务 一个简单的例子就是使用一个人的脸部图像作