PyTorch训练深度卷积生成对抗网络DCGAN

2023-10-27

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ), # 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32
            nn.ConvTranspose2d(
                features_g*2, channels_img, kernel_size=4, stride=2, padding=1,
            ),
            nn.Tanh(),

        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.gen(x)


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("success")
    
if __name__ == "__main__":
    test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )

# comment mnist above and uncomment below if train on CelebA

# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch训练深度卷积生成对抗网络DCGAN 的相关文章

随机推荐

  • mysql数据库升级-MySQL 5.7.25主备架构小版本In-Place升级思路

    一 描述 漏扫发现MySQL有低风险漏洞 自己写方案 自己做测试 自己升级 版本 MySQL 5 7 25 升级到MySQL 5 7 28最新版本 架构 主从架构 二 升级流程 1 下载最新版数据库软件MySQL 5 7 28 2 上传到指
  • 萌新的Arduino大作业

    全自动收 晾衣服机 备注 本人因学校社团假期作业要求 用Arduino IDE编写并模拟实现了一个全自动 收 凉衣服的机器 由于硬件条件不足只能模拟 本人也是萌新一枚 希望观看的 大佬们不喜勿喷 有发现做错的话欢迎在评论区讨论 如果对你有帮
  • 西门子编程基础学习分享(3)-数据类型详述

    1200PLC的数据类型详述 前文所提到的数据类型用于描述数据的长度以及属性 即为指定数据元素的大小以及如何解释数据 每个指令至少支持一种数据类型 因而指令上使用的操作数的数据类型必须与指令所支持的数据类型一致 所以在设计程序 建立变量时需
  • Uva 540 Team Queue

    有t个团体的人正在排一个长队 每次新来一个人时 如果这个成员所在的团体已经有人在排队了 那么他就加到最后一个队友身后 如果整个大队列中没有他的团体 那么他就要排在整个大队列的最后 输入每个团队的人数 每个人的编号 要求支持下面的操作 前两种
  • 【订单服务】库存解锁和关单

    消息队列流程图 监听库存解锁 下单成功 库存锁定成功 接下来的业务调用失败 导致订单回滚 之前锁定的库存就要自动解锁 配置队列和交换机 Configuration public class MyRabbitConfig 使用json序列化机
  • 失业在家靠做PPT日赚800-1000元,有一门副业真的很重要!

    下班做PPT 半年挣8万是什么感觉 你好 我是佳佳 一个用PPT兼职挣钱的宝妈 我现在每天抽2个小时 坐在电脑前 把各种素材像拼图一样拼接一下 像这样 然后把成稿投稿到设计平台 就能挣到钱 你是不是觉得 我是个职业设计师 挺厉害的 不是的
  • NLP(十五)让模型来告诉你文本中的时间

    背景介绍 在文章NLP入门 十一 从文本中提取时间 中 笔者演示了如何利用分词 词性标注的方法从文本中获取时间 当时的想法比较简单快捷 只是利用了词性标注这个功能而已 因此 在某些地方 时间的识别效果并不太好 比如以下的两个例子 原文1 苏
  • python递归实现字符串逆反

    def main string input Enter a string string1 reverse string print string s reverse format is string1 def reverse string
  • YOLOV7学习记录之训练过程

    在前面学习YOLOV7的过程中 我们已经学习了其网络结构 然而实际上YOLOV7项目的难点并不在于其网络模型而是在于其损失函数的设计 即如何才能训练出来合适的bbox 神经网络模型都有训练和测试 推理 过程 在YOLOV7的训练过程中 包含
  • Java学习笔记:Java中的加号“+”

    在今晚学习Java时惊奇地发现Java中有 System out println 赋值后c的值为 c 这样的与c语言不同的语法 本着打破砂锅问到底 xue dao si 的精神 稍微整理了一下 下面是整理出来的Java中加号 的用法 算术运
  • mysql字段使用非int做主键,查询时候使用整型和字符串做查询条件的区别

    where条件key是整型的时候也可以找到记录 但是效率慢 不会使用索引 使用字符串的时候会使用主键索引会很快
  • ionic入门教程第十五课-ionic性能优化之图片延时加载

    周五的时候有个朋友让我写一个关于图片延时加载的教程 直到今天才有空编辑 这阶段真的是很忙 公众号都变成僵尸号了 实在是对不起大家 有人喜欢我的教程 可能我总习惯了用比较简单容易理解的方式去描述这些东西 别的就不多说了 大家遇到什么问题 可以
  • 100天精通Python(基础篇)——第23天:while循环 :99乘法表

    i 0 while i lt 10 print 我喜欢你 i 1 print endl i 0 sum 0 while i lt 101 i 1 sum i print f sum sum import random num random
  • django1.10 静态文件配置

    settings配置 网站引用静态文件时都会加上该地址 如 http www xxx com static css mini css STATIC URL static 静态文件根目录 执行命令 python manage py colle
  • PostgreSQL 服务启动不了问题

    配置了postgresql数据的配置文件 pg hba conf后 重记一下服务 结果启动不了 提 示错误 root instance 609xznso run systemctl start postgresql 11 Job for p
  • C++11 function、bind、可变参数模板

    在设计回调函数的时候 无可避免地会接触到可回调对象 在C 11中 提供了std function和 std bind两个方法来对可回调对象进行统一和封装 C 语言中有几种可调用对象 函数 函数指针 lambda表达式 bind创建的对象以及
  • Hibernate的加载方式——GET与LOAD的对比

    在Hibernate框架中 最常用到的加载方式就非Get和Load莫属了 然而Get和Load在加载方式上边还有很多的不同 下面让我们来分析一下他们的不同之处 区别 从返回的结果上来看 get load在检索到数据的时候 会返回对象 代理对
  • firefox火狐书签windows和ubuntu无法同步问题

    装了ubuntu后发现firefox的书签没法同步 最终发现问题的原因 firefox有个全球服务和本地服务 ubuntu下的firefox默认是全球服务的 而windows下的firefox默认是本地服务的 这样相当于两个系统下默认的存储
  • 【生信】初探基因定位和全基因组关联分析

    初探QTL和GWAS 文章目录 初探QTL和GWAS 实验目的 实验内容 实验题目 第一题 玉米MAGIC群体的QTL分析 第二题 TASSEL自带数据集的关联分析 实验过程 玉米MAGIC群体的QTL分析 包含的数据 绘制LOD曲线 株高
  • PyTorch训练深度卷积生成对抗网络DCGAN

    文章目录 DCGAN介绍 代码 结果 参考 DCGAN介绍 将CNN和GAN结合起来 把监督学习和无监督学习结合起来 具体解释可以参见 深度卷积对抗生成网络 DCGAN DCGAN的生成器结构 图片来源 https arxiv org ab