人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

2023-11-03

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将展示训练过程中的损失值和准确率。

文章目录:

  1. DCGAN模型简介
  2. DCGAN模型原理
  3. 使用PyTorch搭建DCGAN模型
  4. 数据样例
  5. 训练模型
  6. 测试模型
  7. 总结

1. DCGAN模型简介

DCGAN全称:Deep Convolutional Generative Adversarial Networks,它是一种生成对抗网络(GAN)的变体,它使用卷积神经网络(CNN)作为生成器和判别器。DCGAN在图像生成任务中表现出色,能够生成具有高分辨率和清晰度的图像。

2. DCGAN模型原理

DCGAN模型由两个部分组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成图像,而判别器负责判断图像是否为真实图像。在训练过程中,生成器和判别器相互竞争,生成器试图生成越来越逼真的图像,而判别器试图更准确地识别生成的图像是否为真实图像。这个过程持续进行,直到生成器生成的图像足够逼真,以至于判别器无法区分生成的图像和真实图像。

DCGAN模型的数学原理表示:

生成器(Generator):

G ( z ) = x G(z) = x G(z)=x

其中, z z z是输入的随机噪声向量, x x x是生成的图像。

判别器(Discriminator):

D ( x ) = y D(x) = y D(x)=y

其中, x x x是输入的图像, y y y是判别器对图像的判断结果,表示图像是否为真实图像。

GAN的损失函数:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G) = \mathbb{E}{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1-D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中, p d a t a ( x ) p_{data}(x) pdata(x)表示真实数据的分, p z ( z ) p_z(z) pz(z)表示噪声向量的分布, D ( x ) D(x) D(x)表示判别器对图像 x x x的判断结果, G ( z ) G(z) G(z)表示生成器生成的图像, log ⁡ D ( x ) \log D(x) logD(x)表示判别器将真实图像判断为真实图像的概率, log ⁡ ( 1 − D ( G ( z ) ) ) \log(1-D(G(z))) log(1D(G(z)))表示判别器将生成图像判断为真实图像的概率。

在这里插入图片描述

3. 使用PyTorch搭建DCGAN模型

首先,我们需要导入所需的库:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.autograd import Variable

接下来,我们定义生成器和判别器的网络结构:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入是一个100维的向量
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 输出为(512, 4, 4)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 输出为(256, 8, 8)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 输出为(128, 16, 16)
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # 输出为(3, 32, 32)
        )

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入为(3, 32, 32)
            nn.Conv2d(3, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出为(128, 16, 16)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出为(256, 8, 8)
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出为(512, 4, 4)
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)

4. 数据样例

我们将使用CIFAR-10数据集进行训练。首先,我们需要对数据进行预处理:

if __name__ =="__main__":
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    trainset = dset.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

5. 训练模型

接下来,我们将训练DCGAN模型:

# 初始化生成器和判别器
netG = Generator()
netD = Discriminator()

# 设置损失函数和优化器
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 训练模型
num_epochs = 10

for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        # 更新判别器
        netD.zero_grad()
        real, _ = data
        batch_size = real.size(0)
        label = torch.full((batch_size,), 1)
        output = netD(real)
        errD_real = criterion(output, label)
        errD_real.backward()
        noise = torch.randn(batch_size, 100, 1, 1)
        fake = netG(noise)
        label.fill_(0)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizerD.step()

        # 更新生成器
        netG.zero_grad()
        label.fill_(1)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        if i%5==0:
           # 打印损失值
           print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, num_epochs, i, len(trainloader), errD.item(), errG.item()))

6. 测试模型

训练完成后,我们可以使用生成器生成一些图像进行测试:

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

noise = torch.randn(64, 100, 1, 1)
fake = netG(noise)
imshow(torchvision.utils.make_grid(fake.detach()))

7. 总结

本文详细介绍了DCGAN模型的原理,并使用PyTorch搭建了一个简单的DCGAN模型。我们提供了模型代码,并使用CIFAR-10数据集进行训练和测试。最后,我们展示了训练过程中的损失值和生成的图像。希望本文能帮助您更好地理解DCGAN模型,并在实际项目中应用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用 的相关文章

  • 从打包序列中获取每个序列的最后一项

    我试图通过 GRU 放置打包和填充的序列 并检索每个序列最后一项的输出 当然我的意思不是 1项目 但实际上是最后一个 未填充的项目 我们预先知道序列的长度 因此应该很容易为每个序列提取length 1 item 我尝试了以下方法 impor
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • 使 CUDA 内存不足

    我正在尝试训练网络 但我明白了 我将批量大小设置为 300 并收到此错误 但即使我将其减少到 100 我仍然收到此错误 更令人沮丧的是 在 1200 个图像上运行 10 epoch 大约需要 40 分钟 有什么建议吗 错了 我怎样才能加快这
  • 为什么 pytorch matmul 在 cpu 和 gpu 上执行时得到不同的结果?

    我试图找出 numpy pytorch gpu cpu float16 float32 数字之间的舍入差异 而我发现的内容让我感到困惑 基本版本是 a torch rand 3 4 dtype torch float32 b torch r
  • Pytorch“展开”等价于 Tensorflow [重复]

    这个问题在这里已经有答案了 假设我有大小为 50 50 的灰度图像 在本例中批量大小为 2 并且我使用 Pytorch Unfold 函数 如下所示 import numpy as np from torch import nn from
  • PyTorch 中的交叉熵

    交叉熵公式 但为什么下面给出loss 0 7437代替loss 0 since 1 log 1 0 import torch import torch nn as nn from torch autograd import Variable
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • 在Pytorch中计算欧几里得范数..理解和实现上的麻烦

    我见过另一个 StackOverflow 线程讨论计算欧几里德范数的各种实现 但我很难理解特定实现的原因 如何工作 该代码可以在 MMD 指标的实现中找到 https github com josipd torch two sample b
  • Pytorch 与 joblib 的 autograd 问题

    将 pytorch 的 autograd 与 joblib 混合似乎存在问题 我需要并行获取大量样本的梯度 Joblib 与 pytorch 的其他方面配合良好 但是 与 autograd 混合时会出现错误 我做了一个非常小的例子 显示串行
  • TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

    我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型 我一直在遵循PyTorch 实现 https github com leoxiaobin deep high resolution net pytorch blob 1ee
  • PyTorch 给出 cuda 运行时错误

    我对我的代码做了一些小小的修改 以便它不使用 DataParallel and DistributedDataParallel 代码如下 import argparse import os import shutil import time
  • 如何在不安装pytorch的情况下使用pytorch预训练模型?

    我只想在 pytorch 中使用预先训练的模型 而不安装整个包 我可以从 pytorch 复制模型模块吗 恐怕你不能这样做 为了运行模型 你不仅需要经过训练的权重 pth tar 文件 还需要网络的 结构 即层 它们如何相互连接等 该网络结
  • PyTorch DataLoader 对并行运行的批次使用相同的随机种子

    有一个bug https tanelp github io posts a bug that plagues thousands of open source ml projects 在 PyTorch Numpy 中 当并行加载批次时Da
  • 尝试将 cuda 与 pytorch 一起使用时出现运行时错误 999

    我为我的 Geforce 2080 ti 安装了 Cuda 10 1 和最新的 Nvidia 驱动程序 我尝试运行一个基本脚本来测试 pytorch 是否正常工作 但出现以下错误 RuntimeError cuda runtime erro
  • 如何解决错误:PyTorch 中预期输入批量大小与目标批量大小不匹配?

    我尝试通过 PyTorch 在 CIFAR10 数据集上创建逻辑模型 但是我收到错误 ValueError 预期输入batch size 900 与目标batch size 300 匹配 我认为正在发生的事情是 3 100 是 300 所以
  • 如何让火车装载机使用特定数量的图像?

    假设我正在使用以下调用 trainset torchvision datasets ImageFolder root imgs transform transform trainloader torch utils data DataLoa
  • 将 Pytorch 模型 .pth 转换为 onnx 模型

    我有一个预训练的模型 其格式为 pth 扩展名 我想将其转换为 Tensorflow protobuf 但我没有找到任何方法来做到这一点 我见过 onnx 可以将模型从 pytorch 转换为 onnx 然后从 onnx 转换为 Tenso
  • PyTorch 中的后向函数

    我对 pytorch 的后向功能有一些疑问 我认为我没有得到正确的输出 import numpy as np import torch from torch autograd import Variable a Variable torch
  • 如何使用 PyTorch 沿特定维度进行热编码?

    我有一个大小的张量 3 15 136 where 3 is batch size 15 sequence length and 136 is tokens 我想使用中的概率来单热我的张量tokens维度 136 为此 我想提取序列长度中每个
  • PyTorch:加速数据加载

    我正在使用 dendnet121 从 Kaggle 数据集进行猫 狗检测 我启用了cuda 看起来训练速度非常快 然而 数据加载 或者可能是处理 似乎非常慢 有一些方法可以加快速度吗 我尝试玩女巫批量大小 但没有提供太多帮助 我还将 num

随机推荐

  • 获取本地硬盘信息

    using System using System Runtime InteropServices using System Text namespace driverId Serializable public struct HardDi
  • JS-语法进阶

    JS 语法进阶 三元运算符 类数组对象
  • 蓝桥杯 51单片机 AT24C02

    工作电压为1 8v 6v 第7引脚 WP 接地时允许正常读写 24C02设备地址包括固定部分和可编程部分 编程部分由A2 A1 A0三个硬件引脚来控制 设备地址最后一位用于设置数据传输方向 读 写 在IIC总线协议中 设备地址是起始信号后第
  • git分支管理策略

    1 总览 git 的分支整体预览图如下 从上图可以看到主要包含下面几个分支 master git默认主分支 这里不作操作 stable 稳定分支 替代master 主要用来版本发布 develop 日常开发分支 该分支正常保存了开发的最新代
  • 黑客自学路线

    谈起黑客 可能各位都会想到 盗号 其实不尽然 黑客是一群喜爱研究技术的群体 在黑客圈中 一般分为三大圈 娱乐圈 技术圈 职业圈 娱乐圈 主要是初中生和高中生较多 玩网恋 人气 空间 建站收徒玩赚钱 技术高的也是有的 只是很少见 技术圈 这个
  • Shader开发之三大着色器

    Shader开发之三大着色器 固定功能管线着色器Fixed Function Shaders 固定功能管线着色器的关键代码一般都在Pass的材质设置Material 和纹理设置SetTexture 部分 Shader Custom Vert
  • Anaconda3-5.1.0下载和安装

    下载安装anaconda的小插曲 1 在官网上找到windows的32位的下载 毕竟是八年前的老本了 另一个本装的64位 结果网站上出现问题 没有成功下载 2 万能的网络 终于找到可以下载的清华镜像地址 Index of anaconda
  • 如何阅读源代码

    我们在写程式时 有不少时间都是在看别人的代码 例如看小组的代码 看小组整合的守则 若一开始没规划怎么看 就会 噜看噜苦 台语 不管是参考也好 从开源抓下来研究也好 为了了解箇中含意 在有限的时间下 不免会对庞大的源代码解读感到压力 网路上有
  • Win11 安装Docker Desktop报错:Update the WSL kernel by running “wsl --update“ or follow instructions

    这个问题解决了一整个下午 看了无数的解决方案 最后找到了最有效的解决方案 总结如下 安装Docker Desktop之后 打开出现这样的问题 根据提示在powershell通过 wsl update 命令 出现 error 那么可以试试下面
  • 计算机视觉技术与应用综述

    引用自 无人系统之 眼 计算机视觉技术与应用浅析 张 丹 单海军 王 哲 吴陈炜 一 前言 近年来 人工智能和深度学习获得突破 成为了大众关注的焦点 如LeCun Y Bengio Y Hinton G等 1 提出的深度卷积网络在图像识别领
  • 一篇文章搞定Python多进程(这才是正确的Python多进程的打开方式)

    1 Python多进程模块 Python中的多进程是通过multiprocessing包来实现的 和多线程的threading Thread差不多 它可以利用multiprocessing Process对象来创建一个进程对象 这个进程对象
  • python3 [爬虫入门实战] 爬虫之selenium 模拟QQ登陆抓取好友说说内容(暂留)

    很遗憾 部分数据有些问题 不过还是可以进行爬取出来的 先贴上源代码 encoding utf8 from selenium import webdriver import re from bs4 import BeautifulSoup f
  • 二分字符串,没有连续的 1,使用递归思路,以及算法改进探讨

    今天聊一个递归解决二分字符串的问题 问题 给定正整数 N 计算所有长度为 N 但没有连续 1 的二分字符 比如 N 2 时 输出为 00 01 10 当 N 3 时 输出为 000 001 010 100 101 这个问题我在网上简单搜了一
  • linux 修改文件用户组和所有者

    目录 1 linux下修改文件用户组 2 linux下修改文件所有者 3 linux下同时修改文件所有者和用户组 1 linux下修改文件用户组 chgrp change group的简写 修改文件所属的用户组 chgrp 用户组名 文件名
  • (转) .net web项目的安装制作

    原 http blog csdn net houlinghouling archive 2005 06 17 396338 aspx 一 创建基本安装部署项目 1 在解决方案资源管理器 右击解决方案 添加 新建项目 安装部署项目 Web安装
  • 在loader中创建GDT,进入保护模式

    回顾 上一节实现了从BIOS中加载MBR MBR从磁盘2扇区读取loader加载到内存0x900处 但loader目前尚未实现任何功能 Q A Q1 loader在OS中主要做什么 答 创建一些系统数据结构 如GDT 页表等 打开进入保护模
  • 在csdn中复制的代码 去掉前面的行号

    在csdn中复制的代码会有行号 如下 1 2 3 4 5 6 解决方法 利用notepad 的替换功能 如下图一个个查找替换便可
  • stable diffusion实践操作-Controlnet

    本文专门开一节写提示词相关的内容 在看之前 可以同步关注 stable diffusion实践操作 文章目录 前言 1 ControlNet是什么 2 常用的模型 3 基本操作 openpose full 1 提示词 2 参数 控制效果参数
  • Thinkpad E580 硬件错误0187、2200、2201解决经历

    我的电脑是Thinkpad E580 最近电脑坏了 以下是具体情况 一天中午 我打开电脑 在屏幕显示完联想的logn之后 它出现了 我从未见过的我的电脑出现这样的情况 它也给我带来了生活上的不便以及精神和金钱上的损失 当然这是后话了 它长这
  • 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    大家好 我是微学AI 今天给大家介绍一下人工智能 pytorch 搭建模型11 pytorch搭建DCGAN模型 一种生成对抗网络GAN的变体实际应用 本文将具体介绍DCGAN模型的原理 并使用PyTorch搭建一个简单的DCGAN模型 我