Pytorch Advanced(一) Generative Adversarial Networks

2023-11-03

生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了

参考

1、AI作家
2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。

那到底是怎么实现的呢?


GAN中有两大组成部分G和D

G是generator,生成器: 负责凭空捏造数据出来

D是discriminator,判别器: 负责判断数据是不是真数据

示例图如下:

给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。


G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。

下面就是一个训练的过程:

GAN在一轮反向传播中分为两步,先训练D在训练G。

训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。

训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准


GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。


具体如何实施呢?

import os
import torch
import torchvision
import torch.nn as nn 
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,),   # 3 for RGB channels
                                     std=(0.5,))])

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)

定义生成器和判别器:

生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。

判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())


# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

 重点看训练部分,我们到底是如何来训练GAN的。

判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。

G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

训练完了怎么用?

只要用我们的生成器就可以随意生成了。

import matplotlib.pyplot as plt
z = torch.randn(1,latent_size).to(device)
output = G(z)
plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray') 
plt.show()

 下面就是随机生成的图像了!

  

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch Advanced(一) Generative Adversarial Networks 的相关文章

  • 打印 scrapy 请求的“响应”

    我正在尝试学习 scrapy 在遵循教程的同时 我正在尝试进行细微的调整 我想简单地从请求中获取响应内容 然后我会将响应传递到教程代码中 但我无法发出请求并获取响应内容 建议就好 from scrapy http import Respon
  • 为什么我不能导入 geopandas?

    我唯一的代码行是 import geopandas 它给了我错误 OSError Could not find libspatialindex c library file 以前有人遇到过这个吗 我的脚本运行得很好 直到出现此错误 请注意
  • 在 Python 中使用 XPath 和 LXML

    我有一个 python 脚本 用于解析 XML 并将某些感兴趣的元素导出到 csv 文件中 我现在尝试更改脚本以允许根据条件过滤 XML 文件 等效的 XPath 查询将是 DC Events Confirmation contains T
  • 如何用 python 和 sympy 解决多元不等式?

    我对使用 python 和 Sympy 还很陌生 并且遇到了使用 sympy 解决多元不等式的问题 假设我的文件中有很多函数 如下所示 cst sqrt x 2 cst exp sqrt cst x 1 4 log log sqrt cst
  • 获取单个方程的脚本

    在文本文件中输入 a 2 8 b 3 9 c 4 8 d 5 9 e a b f c d g 0 6 h 1 7 i e g j f h output i j 期望的输出 输出 2 8 3 9 0 6 4 8 5 9 1 7 如果输入文件名
  • 如何自动替换多个文件的文本内容中的字符?

    我有一个文件夹 myfolder包含许多乳胶表 我需要替换其中每个字符 即替换任何minus sign by an en dash 只是为了确定 我们正在替换连字符INSIDE该文件夹中的所有 tex 文件 我不关心 tex 文件名 手动执
  • NLTK、搭配问题:需要解包的值太多(预期为 2)

    我尝试使用 NLTK 检索搭配 但出现错误 我使用内置的古腾堡语料库 I wrote alice nltk corpus gutenberg fileids 7 al nltk corpus gutenberg words alice al
  • 使用正则表达式解析 Snort 警报文件

    我正在尝试使用 Python 中的正则表达式从 snort 警报文件中解析出源 目标 IP 和端口 和时间戳 示例如下 03 09 14 10 43 323717 1 2008015 9 ET MALWARE User Agent Win9
  • VSCode pytest 测试发现失败

    Pytest 测试发现失败 用户界面指出 Test discovery error please check the configuration settings for the tests 输出窗口显示 Test Discovery fa
  • Python int 太大,无法放入 SQLite

    我收到错误 OverflowError Python int 太大 无法转换为 SQLite INTEGER 来自以下代码块 该文件约25GB 因此必须分部分读取 length 6128765 Works on partitions of
  • 在pycharm中调试python代码

    这个问题类似于this https stackoverflow com questions 10240018 how to use pycharm to debug python script一 我正在尝试调试pyethapp https
  • 使用 lambda 函数更改属性值

    我可以使用 lambda 函数循环遍历类对象列表并更改属性值 对于所有对象或满足特定条件的对象 吗 class Student object def init self name age self name name self age ag
  • 是否可以写一个负的python类型注释

    这可能听起来不合理 但现在我需要否定类型注释 我的意思是这样的 an int Not Iterable a string Iterable 这是因为我为一个函数编写了一个重载 而 mypy 不理解我 我的功能看起来像这样 overload
  • Pandas 在特定列将数据帧拆分为两个数据帧

    I have pandas我组成的 DataFrameconcat 一行由 96 个值组成 我想将 DataFrame 从值 72 中分离出来 这样 一行的前 72 个值存储在 Dataframe1 中 接下来的 24 个值存储在 Data
  • Google App Engine 中的自定义身份验证

    有谁知道或知道我可以在哪里学习如何使用 Python 和 Google App Engine 创建自定义身份验证流程 我不想使用 Google 帐户进行身份验证 并且希望能够创建自己的用户 如果不是专门针对 Google App Engin
  • 如何对字符串列表进行排序?

    在 Python 中创建按字母顺序排序的列表的最佳方法是什么 基本回答 mylist b C A mylist sort 这会修改您的原始列表 即就地排序 要获取列表的排序副本而不更改原始列表 请使用sorted http docs pyt
  • 使用 Keras 和 fit_generator 绘制 TensorBoard 分布和直方图

    我正在使用 Keras 使用 fit generator 函数训练 CNN 这似乎是一个已知问题 https github com fchollet keras issues 3358TensorBoard 在此设置中不显示直方图和分布 有
  • 如何使用 Django (Python) 登录表单?

    我在 Django 中构建了一个登录表单 现在我遇到了路由问题 当我选择登录按钮时 表单不会发送正确的遮阳篷 我认为前端的表单无法从 查看 py 文件 所以它不会发送任何 awnser 并且登录过程无法工作 该表单是一个简单的静态 html
  • 如何在SqlAlchemy中执行“左外连接”

    我需要执行这个查询 select field11 field12 from Table 1 t1 left outer join Table 2 t2 ON t2 tbl1 id t1 tbl1 id where t2 tbl2 id is
  • 将此 MATLAB 代码转换为 Python 时我做错了什么?

    我正在努力将生成波形的 MATLAB 代码转换为 Python 就上下文而言 这是原子力显微镜带激发响应的模拟 与代码错误无关 在 MATLAB 中从 r vec 生成的图形与我在 Python 中生成的图形不同 我是否正确地将 MATLA

随机推荐

  • socket原理以及socket的简单实现

    目录 一 socket学前基础 TCP的三次握手和四次挥手 二 为什么要使用socket 三 什么是socket 四 socket的简单代码实现 服务端 客户端 一 socket学前基础 TCP的三次握手和四次挥手 1 服务端和客户端如果想
  • python怎么绘制渐变图_有没有一种使用Python生成渐变位图的简单方法?

    实现这一点的一种方法是使用matplotlib 正如您在标记中建议的那样 为了做到这一点 我会的使用numpy创建一个NxN数组来表示image gradient 在 创建一个figure 其大小以英寸为单位与圆的半径 image circ
  • 浏览器console几种报错类型

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 1 SyntaxError 语法错误 2 TypeError 类型错误 通常是 is not a function 即 不是一个函数 3 ReferenceError 引用
  • 开启电脑虚拟化功能

    一 查看笔记本是否支持虚拟化 打开任务管理器 同时摁住ctrl alt del这个三个健 选择任务管理器 查看是否开启虚拟机 如果未开启 一 进入BIOS 参考以下按键 开机时按住对应的键进入BIOS 组装机以主板分 华硕按F8 Intel
  • 最强自动化测试框架Playwright-(2)元素定位

    元素定位 定位器是playwright自动等待和重试功能的核心部分 简而言之 定位器表示一种随时在页面上查找元素的方法 Locators Playwright Python 如下这些是推荐的 page get by role 按显式和隐式辅
  • 关于window.open()方法 返回的的打开的新窗口的对象

    关于window open 方法 返回的的打开的新窗口的对象
  • Dump libasound 音频数据

    QNX有如下两种方法dump pcm数据 可以录声卡之前的数据 1 QNX自带的pcm logger工具 工具位置 qnx qnx sdp target qnx7 aarch64le usr bin pcm logger 打开pcm dum
  • 计算机网络-运输层

    To 个人主页 关注不迷路 运输层 重要概念 运输层为相互通信的应用进程提供逻辑通信 端口和套接字的意义 无连接的 UDP 的特点 面向连接的 TCP 的特点 在不可靠的网络上实现可靠传输的工作原理 停止等待协议和 ARQ 协议 TCP 滑
  • 蓝桥杯:外卖店优先级(map排序算法) Java

    分析 发现只是输入两种数据 则可以考虑用map 经过分析发现 可以用店家编号来表示map的第一个参数Integer 第二个参数因为有可能有多个相同的时刻 所以用arraylist
  • 一, SpringCloud Alibaba-nacos注册中心

    1 nacos官网 https nacos io zh cn https nacos io zh cn docs what is nacos html https github com alibaba spring cloud alibab
  • c++拷贝与引用讲解

    目录 拷贝与引用 2 const限定符 3 const与指针 拷贝与引用 1 拷贝 即复制 在初始化变量时 初始值会被拷贝到新建的对象中 对象会开辟一块新的内存空间用来存储该变量 int a 10 int b a std cout lt l
  • 广播到底啥啊,arp广播原理

    1网络广播 网络广播是指一个节点同时向相同域中的其它所有节点传输数据包的过程 例如 有4台主机 分别为1号主机 2号主机 3号主机 4号主机 假如1号主机 要给4号主机发数据 如果是用广播传输方法的话 那么4台主机都会收到数据包 4台主机
  • ChatGpt 从入门到精通

    相关资源下载地址 基于ChatGPT的国际中文语法教学辅助应用的探讨 pdf 生成式人工智能技术对教育领域的影响 关于ChatGPT的专访 pdf 电子 从ChatGPT热议看大模型潜力 pdf 从图灵测试到ChatGPT 人机对话的里程碑
  • Python爬虫-使用Selenium模拟百度登录

    前言 前面我已经安装好了Selenium并模拟成功了一下打开百度页面并进行查询 让我这个python初学者信心倍增 今天再来试一试百度登录 正文 把打开百度的代码放到构造方法中 ps 那个文件目录是用于后面滑块验证图片保存的 def ini
  • linux最简单预览摄像头方法

    我只想要打开摄像头 想当然就是用ffplay centos如何安装ffplay 找了一通都是编译安装 编译安装也就算了 竟然没有生成ffplay 搜了一通解决ffmpeg编译安装没有生成ffplay的教程 累了 我到底在干什么 linux真
  • FatFs目录访问接口中文版

    我是阿荣 关注我 在技术路上一起精进 目录访问 f opendir 打开目录 函数原型 FRESULT f opendir DIR dp OUT Pointer to the directory object structure const
  • Vue-Axios的封装---登录注册---axios(二)

    Vue cli Axios的封装 简单的的登录与注册 第一种 逻辑数据未分离 注册 登录 用户页面获取用户数据信息以及注销 第二种 逻辑数据分离 token 封装Axios 为什么封装axios 实现 调用封装完毕的Axios 并在添加所需
  • vuforia sdk及案例 (第二章)

    有过上一章了现在去看下载部分 我事先下载好了 用的Android Q版本 开发软件版本是3 5 3的 然后我看了一去升级到了3 6 1 我是最新版本来做的 这是我下载的 导入工程案例 VuforiaSamples 8 6 10 出现问题一
  • IntelliJ IDEA开发工具的安装,scala插件安装

    IntelliJ IDEA开发工具安装 scala插件安装 1 IntelliJ IDEA开发工具下载 下载官方网址 https www jetbrains com idea download other html 我下的2021 3 2
  • Pytorch Advanced(一) Generative Adversarial Networks

    生成对抗神经网络GAN 发挥神经网络的想象力 可以说是十分厉害了 参考 1 AI作家 2 将模糊图变清晰 去雨 去雾 去抖动 去马赛克等 这需要AI具有 想象力 能脑补情节 3 进行数据增强 根据已有数据生成更多新数据供以feed 可以减缓