AttGAN从paper到code理解

2023-11-19

AttGAN:Facial Attribute Editing by Only Changing What You Want(2017 CVPR)

文章简介
本文研究面部属性编辑任务,其目的是通过操作单个或多个感兴趣的属性(如头发颜色、表情、胡须和年龄)来编辑面部图像。

Dataset: CeleA
Contribution:

  1. 移除了严格的attribute-independent约束,仅需要通过attribute classification来保证正确地修改属性
  2. 整合了attribute classification constraint、reconstruction learning、adversarial learning,使得结果生成效果非常好
  3. 可以直接控制属性强度,从而可以自然地完成风格变换
    在这里插入图片描述

理解论文算法

A. 人脸属性编辑
以学习为基础的方法提出通过部署一个对抗属性损失和一个深度特征损失,来训练出深度特征识别属性转移模型。该模型可以增加或移除一个属性到一个人脸图像(或者将属性从图像中移除)。
属性编辑的能力是通过修改潜在表达去获得所期待的属性信息并解码它而获得的。也就是说这个属性编辑的能力来源于将解码模型增添属性的feature map,然后通过解码过程及鉴别器和calssification的损失,在训练中BP来优化的。这一部分我们在后面的代码中在详细了解。

B. 生成对抗网络(Genreative Adversarial Networks)
GAN的灵魂在于生成对抗,它的原理就是生成器G和鉴别器D的对抗。G包括encode和decode部分,G将输入图像压缩成高维特征后通过解码再形成假图,这个假图作为D的输入,D输出约接近1说明假图越像真图。所以就是在这个不断的生成对抗的过程中,G可以把假图变得越来越真。
在这里插入图片描述
从上图loss中,我们知道X是原图,D(x)表示鉴别器对原图的鉴定结果,Z表示X通过encode压缩成高维特征,G(z)指z由decode后生成的假图,D(G(z))表示用鉴别器去鉴别这个假图有多真。所以当 minmax条件成立,说明假图逼近真图。

ATTGAN
A. Testing Formulation
给一张带有n个二进制属性a=[a1,…,a2]的人脸图像 X a X^a Xa,编码器Genc将 X a X^a Xa转化为潜在表达,记为:
在这里插入图片描述
X a X^a Xa编辑为属性b的过程是通过解码z(以属性b为条件)来获得的。
在这里插入图片描述
在这里插入图片描述
test的过程如上,给定输入图像以及它的属性a,通过Genc变成z,再加入b属性(b属性可以通过a获得,原文中有13个特征,逐位取反既可获得b,可以生成13个b所以test出来可以有13张)
这里大家一定疑惑code中如何添加b的,下面大家看到code就明白了。

B. Training Formulation
属性编辑的问题可以定义为编码器和解码器的学习过程。这个过程是非监督的,因为我们并没有 X b X^b Xb的ground truth。一方面,在原图 X a X^a Xa上编辑,期望产生带有b属性的真实图像。为了达到这个目标,属性分类器被用来限制产生的 X b X^b Xb能够正确获得所期望的属性。另一方面,一个合格的属性编辑应该只改变想改变的属性,同时保持其他不变的细节。为了达到这个目的,reconstruction learning被引入
在这里插入图片描述
A. Attribute Classification Constraint.
正如上面提及的,生成图应该正确获得新属性b。因此,部署classifier C来限制它获得所期待的属性。
在这里插入图片描述
在这里插入图片描述
表示对第i个属性的预测,其实可以简单的看成二分类,loss就是交叉熵损失。

B. Reconstruction Loss.
为了完美的保留不改变的部分,作者提出Reconstruction Loss.
在这里插入图片描述

C. Adversarial Loss.
对抗损失同样还是为了生成图更加真实,它分为G和D两部分损失
在这里插入图片描述

总结,train部分训练了两块(G,D)。
G的损失函数如下:
在这里插入图片描述
D的损失函数如下:
在这里插入图片描述
网络结构:
在这里插入图片描述

code

    def trainG(self, img_a, att_a, att_a_, att_b, att_b_):
        for p in self.D.parameters():
            p.requires_grad = False
        
        zs_a = self.G(img_a, mode='enc')
        img_fake = self.G(zs_a, att_b_, mode='dec')
        img_recon = self.G(zs_a, att_a_, mode='dec')
        d_fake, dc_fake = self.D(img_fake)
        
        if self.mode == 'wgan':
            gf_loss = -d_fake.mean()
        if self.mode == 'lsgan':  # mean_squared_error
            gf_loss = F.mse_loss(d_fake, torch.ones_like(d_fake))
        if self.mode == 'dcgan':  # sigmoid_cross_entropy
            gf_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
        gc_loss = F.binary_cross_entropy_with_logits(dc_fake, att_b)
        gr_loss = F.l1_loss(img_recon, img_a)
        g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss
        
        self.optim_G.zero_grad()
        g_loss.backward()
        self.optim_G.step()
        
        errG = {
            'g_loss': g_loss.item(), 'gf_loss': gf_loss.item(),
            'gc_loss': gc_loss.item(), 'gr_loss': gr_loss.item()
        }
        return errG

trainG就是对G这一部分进行训练,从p.requires_grad = False可以找到,虽然G的loss用到了D的部分,但是并不对D进行BP,两部分是分开训练的。
self.G是Genc,将图像压缩成高维特征,这一块代码比较简单,我们从上面的网络结构就可以知道Genc是个啥了。 img_fake就是生成的带有b属性的假图,img_recon是生成的带有a属性的假图。

class Discriminators(nn.Module):
    # No instancenorm in fcs in source code, which is different from paper.
    def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu',
                 fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128):
        super(Discriminators, self).__init__()
        self.f_size = img_size // 2**n_layers
        
        layers = []
        n_in = 3
        for i in range(n_layers):
            n_out = min(dim * 2**i, MAX_DIM)
            layers += [Conv2dBlock(
                n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=norm_fn, acti_fn=acti_fn
            )]
            n_in = n_out
        self.conv = nn.Sequential(*layers)
        self.fc_adv = nn.Sequential(
            LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn),
            LinearBlock(fc_dim, 1, 'none', 'none')
        )
        self.fc_cls = nn.Sequential(
            LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn),
            LinearBlock(fc_dim, 13, 'none', 'none')
        )
    
    def forward(self, x):
        h = self.conv(x)
        h = h.view(h.size(0), -1)
        return self.fc_adv(h), self.fc_cls(h)

d_fake, dc_fake = self.D(img_fake),这里注意。我们看到这个D里有两个输出,一个是将假图压缩成1个pixel,用来判断真假,真就是1,假是0。另一个是把假图压缩成13个pixels用来编辑属性的。

g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss
gf_loss 就是对抗生成损失,gc_loss是 Attribute Classification Constraint的损失,gr_loss是reconstruction loss,正好对应了论文中的结论。
其他部分就不详述了,想交流的可以留言。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

AttGAN从paper到code理解 的相关文章

随机推荐

  • C++学习心得

    C 学习心得 一周的C 学习结束了 从C 的简介 各种专业术语的介绍到最后的标准模板库 对于这个c的加强版的语言有了一定的认识理解 但是由于6天时间学完了全部 而且由于疫情在家里上了两天网课 对于一些运用层次还不是很熟悉 学的重点放在了面向
  • 浅谈CSS中/deep/ >>> ::v-deep属性 进行样式穿透

    背景 在开发vue项目中 引入第三方ui组件库 只需要在当前页面修改第三方组件库的样式以做到不污染全局样式 通过在样式标签上使用 scoped 达到样式只制作用到本页面 但是此时再修改组件样式不起作用 scoped的作用及实现原理 作用 当
  • hook方法

    dl iterate phdr
  • 【Mysql】初探表连接的原理

    在我们的工作学习中肯定都用到过表连接的操作 不同连接写法在执行效率上会有不小的区别 要想写出高效的表连接语句 还是需要我们知晓表连接的原理 什么是连接 连接的本质就是要连接在一起的表中符合条件的结果集组合在一起 然后返回给用户的过程 准备
  • SSM实战开发:构建强大的Java Web应用

    SSM实战开发 构建强大的Java Web应用 本文介绍如何使用SSM框架 Spring SpringMVC MyBatis 进行实战开发 构建一个强大的Java Web应用 通过该实例 你将学会SSM框架的整合 配置和使用 以及常见的We
  • 数据标注工具大汇总

    图片 拉框 labelimg 已经安装 bbox label tool LabelBoundingBox Yolo mark FastAnnotationTool od annotation RectLabel cvat VoTT VIA
  • upload.addEventListener is not a function报错

    原因 Mock js重写了XMLHttpRequest 导致了原生XMLHttpRequest被mockjs覆盖找不到相应的方法 场景 vite plugin mock vue3 element UI upload组件 解决办法 更改配置项
  • centos7 pip3 安装python模块包报错解决

    centos7 pip3 安装python模块包报错 bash usr local bin pip3 usr local bin python3 6 坏的解释器 没有那个文件或目录 root localhost Python pip3 in
  • 关于蚁剑的安装和使用

    下载地址 加载器 https github com AntSwordProject AntSword Loader 核心源码 https github com AntSwordProject antSword 加载器中的是exe文件 因为源
  • PyQt5+VTK环境搭建

    PyQt5 VTK环境搭建 VTK 简介及安装 VTK 介绍 VTK 在 Python 环境下安装 方法一 安装 anaconda 使用 conda install 安装 适用于 python3 适用于 python 2 方法二 镜像安装
  • 第一章遇见的问题(题目是原创,答案转载收集互联网)

    1 PCTSTR和LPCTSTR 在ANSI编译方式下 PCTSTR和LPCTSTR等价于LPCSTR 在Unicode下等价于LPCWSTR 2 LPVOID WINAPI LocalLock in HLOCAL hMem 功能 锁定一个
  • IAR个人常用配置

    IAR个人常用配置 文章目录 IAR个人常用配置 1 设置 2 设置tab和indent为4空格 3 设置编码为UTF 8 4 自动缩进设置 5 修改背景颜色和字体 6 修改全局搜索快捷键 1 设置 Tools gt Options 2 设
  • 网红漏洞“致远OA系统上的GetShell漏洞”详解

    概述 腾讯御界高级威胁检测系统近期监测到 致远OA系统上的 GetShell漏洞 在网上被频繁利用攻击政企客户 对于存在漏洞的OA系统 攻击者无需任何权限 即可向服务器上传webshell 腾讯驻场工程师通过御界高级威胁检测系统告警通知及时
  • Flutter Error: The method ‘toInt‘ isn‘t defined for the class ‘Decimal‘

    1 运行项目报错 2 错误原因分析 从错误日志可以看出 是common utils插件中的decimal 2 0 0依赖库报错了 猜测可能是decimal升级版本了导致不兼容造成的 打开https pub flutter io cn 搜索d
  • Windows安装frida

    一 正常步骤 cmd中 pip3 install frida i https pypi mirrors ustc edu cn simple 上面失败用这个 pip install frida i http mirrors aliyun c
  • linux 查看及修改字符集

    一 查看当前linux系统的字符集方法 1 1 locale 1 2 echo LANG 1 3 env grep LANG 二 查看当前系统支持的字符集 root localhost locale a 三 修改系统字符集 3 1 临时生效
  • vue中使用bus总线在非父子组件之间传值

    使用bus总 线可以在 兄弟 父子 祖先和后代 组件之间传值 原理 在Vue原型中 创建一个bus属性 让每一个组件 实例 都具有这个属性 这里自行引入 vue
  • Idea 发布最适合程序员的字体!

    作为 编译期界的大佬 JetBrains公司一直致力于提供更好的编码环境 前两天 JetBrain推出了一个新的字体 JetBrain Mono 号称是最适合程序员的编码的字体 我赶紧尝了尝鲜 体验了一天之后发现确实好看 因此推荐给大家 首
  • ABB MPRC086444-005数字输入模块

    ABB MPRC086444 005 是一款数字输入模块 通常用于工业自动化和控制系统中 用于接收和处理数字信号 以下是这种类型的数字输入模块通常可能具备的一般功能和特点 数字输入接口 MPRC086444 005 模块通常配备多个数字输入
  • AttGAN从paper到code理解

    AttGAN Facial Attribute Editing by Only Changing What You Want 2017 CVPR 文章简介 本文研究面部属性编辑任务 其目的是通过操作单个或多个感兴趣的属性 如头发颜色 表情