深度学习——深度生成模型(GAN,VAE)

2023-10-28

深度学习与PyTorch入门—对抗生成网络GAN理论讲解及项目实战_哔哩哔哩_bilibili

 

背景

生成模型

从某个分布中获取输入训练样本, 并学习表示该分布的模型

作用

  1. 产生真实数据, 艺术创作, 超分辨率图片

2.帮助我们发现数据集中的隐变量

3.异常检测

4.生成模型可以以多种方式被应用到强化学习中

5.进行隐变量表示的推理, 这些隐变量表示可以用做通用特征

变分自编码器VAE

AE与VAE

AE: 通过编码器输出单个值来刻画每个隐变量

VAE:将变量表示为可能的取值范围(概率分布)

编码器输出隐变量分布, 从这些分布中随机采样输入到解码器.

对于任何隐变量分布的采样, 希望解码器能够准确地重构输入数据

模型推导

  1. 理想: 根据训练数据X得到真实分布, 根据分布P(x)采样, 得到可能的X(包括数据外的), 即计算后验概率p(z|x)

2.设找到q(z|x)与p(z|x), 最小化两者之间的KL散度, 则问题转变成最大化下界, 等价于最小化模型的损失函数:

再参数化

隐变量不再从分布中采样, 而是对均值和方差进行放缩:

总结

  1. 是AE的概率版本, 是一种生成模型, 可以生成数据
  2. VAE相当于在AE重构损失的基础上, 加入一个KL散度的正则化项
  3. 优点:

     (1)通过最大化变分下界优化模型, 有比较漂亮的数学原理

     (2)通过重构可以实现无监督学习

     (3)利用参数化技巧实现端到端训练

     (4)通过推断出的隐变量可以获得比较好的可解释性特征

  1. 缺点:生成的图像和GAN相比有些模糊

生成对抗网络

采取博弈论方法, 通过2人博弈游戏从训练分布中生成数据.

具体通过两个神经网络相互竞争来构建一个生成模型

生成器

将随机噪声变换为模仿的”假”样本, 视图欺骗判别器

判别器

从真实样本和生成器生成的样本识别真实样本

损失函数

获得新样本

训练结束后, 去掉判别器, 可以用以生成新样本

优点

生成的结果很逼真, 是当前最好的结果

缺点

训练更难/ 更不稳定

当前研究热点

更好的损失函数, 更稳定的训练学习算法

大量的应用

VAE对比GAN

VAE:优化数据似然的变分下界.  数学原理比较漂亮. 隐变量表示有用, 能推理查询, 但生成结果质量不是最好

GAN:基于博弈论方法, 当前生成结果最佳, 但训练起来可能会复杂和很不稳定, 没有推理查询


一.GAN的基本要素

         1.真实数据集,初始化虚假数据集(噪音)

         2.生成器,鉴别器:

                 生成器

                          输入:原始数据的维数(一条数据)

                          输出:原始数据的维数(一条数据)

                                     除了最后一层都要经sigmoid()

                 鉴别器

                          输入:原始数据的维数(一个batch的数据)

                          输出:一维(以判别结果的真假)

                                     经sigmoid,在(0,1)范围(切确地说输入层,隐藏层,输出层都要经过)

          3.训练环(每一次epoch):(PS:这里冻结不冻结还是不理解,待补。。。)

                      每一次eopch里面,D和G都要训练,并且各自训练多次。(在代码中见到的情况是先D后G)

                  生成器训练周期:

                          真实数据训练:

                                          (1)判别器出来的结果向“1”靠近

                                          (2)反向传播

                          生成器产生数据训练:

                                           (1)冻结生成器并产生数据

                                           (2)判别器出来的结果向“0”靠近

                                           (3)反向传播

                  鉴别器训练周期:

                                 (1)不冻结生成器并产生数据(喂如D前转置了以下)

                                 (2)判别器出来的结果向“1”靠近

                                 (3)反向传播

二.GAN的损失函数(价值函数:从强化学习来)

                \large min_Gmax_D E _r_e_a_l[logD(x)]+E_f_a_k_e[log(1-D(z))]

解释:

            对于判别器来说:损失函数的值越大越好

                                         包括式子的两项                                        

                                          公式理解

                                                            希望:真判1,假判0,  经log都输出0(最大值:log在(0,1上))

            对于生成器来说:损失函数的值越小越好

                                         只包括式子的后一项

                                         公式理解   

                                                           希望: , 经上式---->负无穷(最小值)

三.公式和代码的转换技巧

           1.损失函数的max,min,“+”,“-”等等:基本只体现在,代码的criterion()函数里。在代码里面的思路很简单:criterion()就是为了度量你想要的真实训练出来的差距;反向传播的过程就是改变网络的参数,使criterion()的值越来越小(即网络训练出来的会更加靠近你想要的)。

            PS:所以看文章的时候要小心,损失函数等max,还是min(老师和师兄都落坑了啊)。

           2.网络的结构框图:基本只体现在代码的训练周期里。

           3.看代码,运行代码真的是和刷题一样重要!!!论文摘要就像教科书里的重点,论文文本基本是玄学。只看课本文字不刷题,想想后果就知道,基本学不会本质的东西并且还浪费时间。

四.文章所辅助的代码

#!/usr/bin/env python

# Generative Adversarial Networks (GAN) example in PyTorch. Tested with PyTorch 0.4.1, Python 3.6.7 (Nov 2018)
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9
####################################################导入包################################################################
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

matplotlib_is_available = True
try:
  from matplotlib import pyplot as plt
except ImportError:
  print("Will skip plotting; matplotlib is not available.")
  matplotlib_is_available = False
####################################################导入包################################################################

####################################################获取数据################################################################
# 使用第二张GPU卡
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


# Data params
data_mean = 4
data_stddev = 1.25

# ### Uncomment only one of these to define what data is actually sent to the Discriminator
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))

# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

####################################################获取数据################################################################

# ##### MODELS: Generator model and discriminator model

####################################################生成器和判别器################################################################
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))
    ####################################################生成器和判别器################################################################

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)  #生成的高斯分布求均值
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)  #生成的高斯分布求标准差元素与
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))  #一个向量,有四个元素,如代码
    return final

def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)

def train():
    # Model parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    ####################################################训练################################################################
    for epoch in range(num_epochs):
        for d_index in range(d_steps):  #在每一个训练期里面,D训练20次
            # 1. Train D on real+fake
            D.zero_grad()  #将梯度都置为0,注意是梯度不是网络参数

            #  1A: Train D on real
            d_real_data = Variable(d_sampler(d_input_size))  #产生高斯分布的数据,是随机散乱排列的
            d_real_decision = D(preprocess(d_real_data))  #这里直接调用forward()??
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = true
            d_real_error.backward() # compute/store gradients, but don't change params

            #  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]  #每一次的损失

        for g_index in range(g_steps):  #在每一个训练器里面G训练20次
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))  #随机产生(0,1)间的数,注意不是高斯分布
            g_fake_data = G(gen_input)  #不明感觉输出维度不对
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters
            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))
    ####################################################训练################################################################

    ####################################################画图################################################################
    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)  #G产生出来的数据
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)  #画直方图,参数一:输入数据,参数二:条的个数
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)  #生成网格
        plt.show()
####################################################画图################################################################

train()

结果:

代码输出图像近似接近于高斯分布。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习——深度生成模型(GAN,VAE) 的相关文章

随机推荐

  • 【单片机毕业设计】【mcuclub-jj-045】基于单片机的红外遥控器的设计

    最近设计了一个项目基于单片机的红外遥控器系统 与大家分享一下 一 基本介绍 项目名 红外遥控器 项目编号 mcuclub jj 045 单片机类型 STC89C52 STM32F103C8T6 具体功能 1 从机利用4 4键盘通过红外发射管
  • level7 项目实战:基于Linux的Flappy bird游戏开发

    目录 1 Flappy bird项目介绍 项目介绍 功能总结 项目框图 Ncurses库安装 Ncurses库函数介绍 2 信号机制详解 相关函数介绍 3 项目实现 1 Flappy bird项目介绍 项目介绍 目标 借助Ncurses库
  • Java--ArrayList遍历的三种方法

    Java遍历主要有以下几种 分别是利用for循环 或者for each 把链表变为数组进行遍历 利用迭代 IntIterator 遍历 下面我们分别进行学习 For循环 import java util ArrayList import j
  • ONVIF测试工具 ONVIF Device Test Tool的使用

    ONVIF测试工具 ONVIF Device Test Tool的使用 双击 打开软件 选择当前网络 点击 Discover Devices 进行搜索 可以看到搜索到一个设备
  • 使用OpenWRT配置SFTP远程文件传输,安全高效的文件传输方法

    文章目录 前言 1 openssh sftp server 安装 2 安装cpolar工具 3 配置SFTP远程访问 4 固定远程连接地址 前言 本次教程我们将在OpenWRT上安装SFTP服务 并结合cpolar内网穿透 创建安全隧道映射
  • ip代理

    为什么会出现IP被封 网站为了防止被爬取 会有反爬机制 对于同一个IP地址的大量同类型的访问 会封锁IP 过一段时间后 才能继续访问 如何应对IP被封的问题 有几种套路 修改请求头 模拟浏览器 而不是代码去直接访问 去访问 采用代理IP并轮
  • PC中自带计算器使用说明

    Backspace 删除当前显示数字的最后一位 CE 清除显示数字 C 清除当前的计算 MC 清除内存中的所有数字 MR 重调用存内存中的数字 该数字保留在内存中 MS 将显示数字保存在内存中 M 将显示的数字与内存中已有的任何数字相加 但
  • 记录一次线上OOM问题排查处理过程

    背景 项目为docker部署的springboot单体项目 非前后端分离 前端文件是集成在项目的类路径的resources路径下的 项目使用ruoyi vue版本做为开发原始代码 系统目前没什么用 主要是客户分公司在基础数据模块录入数据比较
  • 1001 害死人不偿命的(3n+1)猜想 PAT乙级真题 C++

    1001 害死人不偿命的 3n 1 猜想 卡拉兹 Callatz 猜想 对任何一个正整数 n 如果它是偶数 那么把它砍掉一半 如果它是奇数 那么把 3n 1 砍掉一半 这样一直反复砍下去 最后一定在某一步得到 n 1 卡拉兹在 1950 年
  • 简单实现动态代理(Proxy)

    前言 最近学习了Jdk的动态代理 然后自己也简单的手写了一个 思路 根据代理的接口 生成对应的Java代码文件 将生成的Java文件编译成class文件 利用URLClassLoader加载class到Jvm中 利用反射在new出这个对象
  • 小程序文字上下滚动轮播效果实现CSS

    wxml
  • CentOS7-查询可以远程登录的帐号信息

    查询可以远程登录的帐号信息 查询 etc shadow 文件 etc shadow 文件 用于存储 Linux 系统中用户的密码信息 又称为 影子文件 文件内容格式解析 用户名 加密密码 最后一次修改时间 最小修改时间间隔 密码有效期 密码
  • 谈谈初学者该怎么学电脑

    十五年前 一说电脑 就感觉是很高科技的东西 那时候一般只有计算机专业和相关行业的人才能够接触 随着信息和科技的发展 电脑已经渗入到各个行业和家庭 电脑不仅广泛用于各种工作 还普及到了家庭娱乐中 因此 掌握电脑不再仅仅是工作需要 而是一项基本
  • 超级无敌详细使用ubuntu搭建hadoop完全分布式集群

    一 软件准备 安装VMware 下载ubuntu镜像 阿里源ubuntu下载地址 选择自己适合的版本 以下我使用的是18 04 server版就是没有桌面的 安装桌面版如果自己电脑配置不行的话启动集群容易卡死 说明一下哈就是桌面版和服务器版
  • JSP输出HelloWorld和Servlet输出HelloWorld

    一 新建Web工程 1 更新插件以获取Dynamic Web Project Eclispe Help Install New Software 下拉选择后等一会 就会出现需要更新的东西如下图 下拉选择Web XML Java EE and
  • 手把手教你学Python之波士顿房价预测(scikit-learn的应用)

    目录 1 波士顿房价预测介绍 2 线性回归算法 3 调用scikit learn库实现房价预测 1 波士顿房价预测介绍 问题描述 波士顿房价数据集统计的是20世纪70年代中期波士顿郊区房价的中位数 统计了城镇人均犯罪率 不动产税等共计13个
  • sysbench 随机数随机算法详解

    https www percona com blog 2020 03 26 sysbench and the random distribution effect 随机算法 https www jianshu com p 30933e0be
  • MFC几个常用函数:OnCreate和OnInitialUpDate,GetActiveFrame和MDIGetActive,Invalidate、SetModifiedFlage、UpdateAll

    把用常用的都整理一下 不然好乱 一 OnCreate和OnInitialUpDate 参考 http www cnblogs com mingfei200169 articles 666567 html ONCREATE只是产生VIEW的基
  • 关于Unity ScriptableObject 的数据保存问题

    最近在开发的时候遇到的问题 在用ScriptableObject进行保存数据的时候 并不是所有的数据都能正常保存 这让人很是难受 所以我决定系统性地整理一下这个问题 注 建议大家将自己的Unity文件保存方式设置为Text而不是二进制 这样
  • 深度学习——深度生成模型(GAN,VAE)

    深度学习与PyTorch入门 对抗生成网络GAN理论讲解及项目实战 哔哩哔哩 bilibili 背景 生成模型 从某个分布中获取输入训练样本 并学习表示该分布的模型 作用 产生真实数据 艺术创作 超分辨率图片 2 帮助我们发现数据集中的隐变