pytorch学习之Condition GAN与代码的部分解析

2023-11-14

1.首先,GAN网络是有生成器和判别器,比如可以生成新的图像,而CGAN则是添加了条件,生成有限制的图像,比如生成带微笑的人脸。CGAN的架构如下:

                    

2.主要部分的代码

  • 定义判别器:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.label_emb = nn.Embedding(10,10) 
        #Embedding类返回的是一个形状为[每句词个数, 词维度]的矩阵。


        self.model = nn.Sequential(
            nn.Linear(794,1024),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024,512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256,1),
            nn.Sigmoid()
        )

    def forward(self,x,labels):
        x = x.view(x.size(0),784)
        c = self.label_emb(labels)
        x = torch.cat([x,c],1)
        out = self.model(x)
        return out.squeeze()
        #可以删除数组形状中的单维度条目,即把shape中为1的维度去掉,但是对非单维的维度不起作用。
  • 定义生成器:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.label_emb = nn.Embedding(10,10)

        self.model = nn.Sequential(
            nn.Linear(110,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    def forward(self,z,labels):
        z = z.view(z.size(0),100)
        c = self.label_emb(labels)
        x = torch.cat([z,c],1)
        out = self.model(x)
        return out.view(x.size(0),28,28)
  • 训练判别器:
for epoch in range(num_epochs):
    for i,(images,labels) in enumerate(data_loader):
        step = epoch*len(data_loader)+i+1
        images = images.to(device)
        labels = labels.to(device)
        # 定义图像是真或假的标签
        real_labels = torch.ones(batch_size).to(device)  #真标签全是1
        fake_labels = torch.randint(0,10,(batch_size,)).to(device) 
        #返回均匀分布的[0,10]之间的整数随机值

        #训练判别器                                    
     

        # 定义判断器对真图片的损失函数
        real_validity = D(images,labels)
        d_loss_real = criterion(real_validity,real_labels)  #损失比较,与1
        real_score = real_validity   #判别器生成的值
        # 定义判别器对假图片(即由潜在空间点生成的图片)的损失函数
        z = torch.randn(batch_size,100).to(device)
        #创建batch_size行100列的随机数的tensor,随机值的分布式均值为0,方差为1
        fake_labels = torch.randint(0, 10, (batch_size,)).to(device)
        #创建batch_size行列不指定的随机整数的tensor,随机值的区间是[low, high)[0,10]
        fake_images = G(z,fake_labels)
        fake_validity = D(fake_images,fake_labels)

        d_loss_fake = criterion(fake_validity, torch.zeros(batch_size).to(device)) 
        #损失比较,与0
        fake_score = fake_images   #生成器生成的值
        d_loss= d_loss_fake + d_loss_real

        # 对生成器、判别器的梯度清零
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
  • 训练生成器:
  
        # 训练生成器
     

        # 定义生成器对假图片的损失函数,这里我们要求
        # 判别器生成的图片越来越像真图片,故损失函数中
        # 的标签改为真图片的标签,即希望生成的假图片,
        # 越来越靠近真图片

        z = torch.randn(batch_size, 100).to(device)
        fake_images = G(z, fake_labels)
        validity = D(fake_images, fake_labels)
        g_loss = criterion(validity, torch.ones(batch_size).to(device))  #标签为1

        # 对生成器、判别器的梯度清零
        # 进行反向传播及运行生成器的优化器
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

3.具体的整个代码:https://github.com/maojinjiayou/CGAN

4.一些注意的点:

  • import os  简单来说,是对文件进行操作,比如说列出当前文件夹里面有啥文件、删除文件.
  • 安装tensorboardX To install this package with conda run one of the following:
conda install -c conda-forge tensorboardx 
conda install -c conda-forge/label/gcc7 tensorboardx 
conda install -c conda-forge/label/cf201901 tensorboardx
  • squeeze()函数
    可以删除数组形状中的单维度条目,即把shape中为1的维度去掉,但是对非单维的维度不起作用。
    
  • nn.Embedding(10,10)

    其中有一个nn.Embedding(vocab_size,embed_dim)类,它是Module类的子类,这里它接受最重要的两个初始化参数:词汇量大小,每个词汇向量表示的向量维度。Embedding类返回的是一个形状为[每句词个数, 词维度]的矩阵。

import torch
import torch.nn as nn
embedding=nn.Embedding(10,3)
input=torch.LongTensor([[1,2,4,5],[4,3,2,9]])
embedding(input)
tensor([[[ 0.8052, -0.1044, -0.6971],
         [ 1.3792, -0.1265, -1.1444],
         [ 1.4152, -0.1551, -1.2433],
         [ 0.7060, -1.0585,  0.5130]],

        [[ 1.4152, -0.1551, -1.2433],
         [-0.9881, -0.1601,  0.6339],
         [ 1.3792, -0.1265, -1.1444],
         [-1.1703,  1.8496,  0.8113]]], grad_fn=<EmbeddingBackward>)

第一个参数是字的总数,第二个参数是字的向量表示的维度。我们的输入input是两个句子,每个句子都是由四个字组成的,使用每个字的索引来表示,于是使用nn.Embedding对输入进行编码,每个字都会编码成长度为3的向量。

参考文章:https://www.cnblogs.com/xiximayou/p/13343608.html

  • nn.Sequential()

模块将按照构造函数中传递的顺序添加到模块中。另外,也可以传入一个有序模块。

  • nn.LeakyReLU(0.2,inplace=True)  是否需要覆盖之前的值。
  • torch.randint(0,10,(batch_size,)) 
    创建batch_size行列不指定的随机整数的tensor,随机值的区间是[low, high)[0,10]
  • torch.randn([3,4]) 

       创建3行4列的随机数的tensor,随机值的分布式均值为0,方差为1

  • make_grid(images, nrow=10, normalize=True)
    make_grid用于把几个图像按照网格排列的方式绘制出来,每行的图片数量为10,
    normalize如果为True,则把图像的像素值通过range指定的最大值和最小值归一化到0-1。
  • fig, ax = plt.subplots(figsize=(10,10))
   fig代表绘图窗口(Figure);ax代表这个绘图窗口上的坐标系(axis),一般会继续对ax进行操作。
   figsize表示figure 的大小为宽、长(单位为inch)
  • ax.imshow(grid.permute(1, 2, 0).detach().cpu().numpy(), cmap='binary')
  grid.permute(1, 2, 0)将tensor的维度换位,原来的顺序是(0,1,2)
  当使用detach()分离tensor但是没有更改这个tensor时,并不会影响backward()
  cmap='binary' 显示设置,两端发散的色图 colormaps
  • img = 0.5 * img + 0.5  
    还原图像,反归一化​​​​​​
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch学习之Condition GAN与代码的部分解析 的相关文章

  • Road Construction POJ - 3352(tarjan双连通缩点模板)

    题目描述 给一个无向连通图 至少添加几条边使得去掉图中任意一条边不改变图的连通性 即使得它变为边双连通图 include
  • CH3___Debugging C++ Programs

    3 1 Syntax and semantic errors Modern compilers have been getting better at detecting certain types of common semantic e
  • Linux下yum命令及软件的安装

    yum命令 1 yum install softwarename 安装 2 yum remove softwarename 卸载 安装dhcp及卸载 mkdir iso 建立目录 mv home kiosk Desktop iso iso
  • tcp 是一个安全的网络协议

    1 tcp 是一个安全的网络协议 确定双方的收发能力之后 才会真正传输数据 2 tcp 建立起一个连接 比较消耗成本 所以比较平稳 安全 3 3次握手 发起连接 双方确认 确认双方的收发能力 客户端告诉服务器i我要创建连接i 一次 服务器告
  • 出栈的合法性检测

    对于一个给定的入栈顺序 可能的出栈顺序会有很多 但是肯定都要遵循栈 后进先出 的特点 那么怎么进行合法性检测呢 算法思想如下 定义变量InIndex标记入栈序列的当前位置 定义OutIndex标记出栈序列的当前位置 对InIndex和Out
  • 利用纯净语音和噪声合成不同信噪比的训练数据

    如题 这应该算是我前往语音这座大山的第一步 在此做出记录 一 工作背景 由于需要进行单通道降噪的实验 但是现在只有纯净语音和噪声数据 而在阅读文章的过程中 大家并没有将这个细小的内容写道论文中 的确也不应该 做出来之后确实感觉蛮简单的 所以
  • python爬虫ip被封怎么办?

    用python写的爬虫 设置了headers 包括host和useragent 设置了cookies 访问的结果是 访问过于频繁 请输入验证码 但是用浏览器访问怎么刷新都没有问题 这个时候大致可以判定你被反爬虫锁定 那怎样解决 你可能不太了

随机推荐

  • Python无法识别csv文件

    我的报错 utf 8 codec can t decode byte 0xc9 in position 84 invalid continuation byte 大概意思是utf 8无法识别文件里的一些信息 后面将encoding里面改成下
  • 【Java编程】图书管理系统

    图书管理系统 我们用一个列表存放书籍信息 private static List
  • paxos算法_共识算法(8) —— PBFT 算法详解

    本文翻译自 伊利诺伊大学厄巴纳 香槟分校助理教授 Ling Ren 开设的讨论课 CS598 Consensus Algorithm 参考论文 PBFT 原论文 1999 pmg csail mit edu 前言 上一节中我们介绍了经典的P
  • IDEA项目初次上传到git(超简单)

    IDEA上传到git 1 右键项目 打开 终端 2 在打开的终端输入 git init 3 右键项目 选择 git gt 添加 add 4 右键项目 选择 git gt 提交 commit 输入 init 点击 提交并推送 commit a
  • Ehoney开源欺骗防御系统

    一 特点 支持丰富的蜜罐类型 通用蜜罐 SSH 蜜罐 Http蜜罐 Redis蜜罐 Telnet蜜罐 Mysql蜜罐 RDP 蜜罐 IOT蜜罐 RTSP 蜜罐 工控蜜罐 ModBus 蜜罐 基于云原生技术 基于k3s打造saas平台欺骗防御
  • 创建一们计算机语言_建立自己的计算机语言

    创建一们计算机语言 只需编码 如果您想构建自己的计算机语言 但又不知道该如何开始 或者您认为自己没有时间和技能来做到这一点 那么请看鲍勃 尼斯特罗姆 Bob Nystrom 的 技巧翻译 一书 即从刮 从一开始到成熟的面向对象的东西就是这样
  • JDBC工具类——JdbcUtils

    JdbcUtils 一 JDBC的工具类 二 JdbcUtils工具类的组成 1 类加载时加载驱动 2 连接池 db properties 3 ThreadLocal控制事务 4 dbcp连接池提高资源利用率 三 JDBC工具类的实例演变
  • Ubuntu opencv的搭建

    打开终端 apt install cmake 依次输入以下的命令 sudo apt get install cmake git libgtk2 0 dev pkg config libavcodec dev libavformat dev
  • Linux字符集的查看及修改

    一 查看字符集 字符集在系统中体现形式是一个环境变量 以CentOS6 5为例 其查看当前终端使用字符集的方式可以有以下几种方式 第一种 root Testa www tmp echo LANG zh CN UTF 8 第二种 root T
  • Nvidia显卡硬件编解码能力表 官方链接

    记录用 便于快速查找 从表中得知 1070支持 H265 10bit 硬件编码 似乎不错 官方链接 https developer nvidia com video encode and decode gpu support matrix
  • C++虚函数表地址偏移

    include
  • 架构图以及vue的简介

    架构图 前后端分离总架构图 前端架构设计图 MVVM 架构模式 MVVM 的简介 MVVM 由 Model View ViewModel 三部分构成 Model 层代表数据模型 也可以在Model中定义数据修改和操作的业务逻辑 View 代
  • 公务员和事业单位的差别有多大?

    公务员和事业单位是两种不同的就业形式和组织类型 它们在以下几个方面存在一些差别 1 归属关系 公务员属于政府部门的编制人员 直接依附于政府机构 而事业单位是独立法人实体 独立承担法人责任 不隶属于政府机构 2 支付方式 公务员工资由政府财政
  • 算法训练 P0505

    标题 include
  • 基于GBDT+LR模型的深度学习推荐算法

    GBDT LR算法最早是由Facebook在2014年提出的一个推荐算法 该算法分两部分构成 第一部分是GBDT 另一部分是LR 下面先介绍GBDT算法 然后介绍如何将GBDT和LR算法融合 1 1 GBDT算法 GBDT的全称是 Grad
  • flutter GridView和Wrap

    GridView有2种gridDelegate 记录小嵌套冲突的问题 SingleChildScrollView ListView GrilView嵌套问题解决 子布局添加属性 physics NeverScrollableScrollPh
  • Windows 10 Office文件图标异常处理(Word

    1 我们经常会遇到office重新安装完成后 或者换了版本后 前期做好的excel ppt word文件可以正常打开 但图标显示为白色或者异常 备注 如果不能正常打开 则是office程序没有关联到 只需要选中需打开文件 右键 更改 里面找
  • MODBUS TCP协议实例数据帧详细分析

    MODBUS TCP协议实例数据帧详细分析 1 简介 2 ModbusTCP数据帧 2 1 报文头MBAP 2 2 帧结构PDU 3 ADU详细结构 3 1 0x01 读线圈 3 2 0x02 读离散量输入 3 3 0x03 读保持寄存器
  • 达梦数据库,大小写敏感这个参数怎么设置

    达梦数据库 大小写敏感这个参数怎么设置 1 1 现象描述 达梦在安装完软件后 需要初始化数据库实例 其他大部分数据库 也是同样的操作 但是 达梦在初始化数据库实例前 有几个需要特别注意的参数 这几个参数一定要特别关注 因为如果设置错了 是不
  • pytorch学习之Condition GAN与代码的部分解析

    1 首先 GAN网络是有生成器和判别器 比如可以生成新的图像 而CGAN则是添加了条件 生成有限制的图像 比如生成带微笑的人脸 CGAN的架构如下 2 主要部分的代码 定义判别器 class Discriminator nn Module