生成专题3

2023-11-09

  • 文章转自微信公众号:机器学习炼丹术
  • 作者:陈亦新(欢迎交流共同进步)
  • 联系方式:微信cyx645016617
  • 学习论文:Analyzing and Improving the Image Quality of StyleGAN


3.1 AdaIN

StyleGAN第一个版本提出了通过AdaIN模块来实现生成,这个模块非常好也非常妙。

图片中的latent Code W是一个一维向量。然后Adaptive Instance Norm其实是基于Instance Norm修改的。Instance Norm当中,包含了2个可学习参数,shift和scale。而AdaIN就是让这两个可学习参数是从W向量经过全连接层直接计算出来的。因为shift scale会影响生成的图片,所以这样可以让生成的图片收到latent code W的控制,从而实现生成的可控。

3.2 AdaIN的问题

研究人员发现,StyleGAN生成的图片中,大概率存在一些水滴样子的补丁。

研究人员说:We pinpoint the problem to the AdaIN operation that normalizes the mean and variance of each feature map separately, thereby potentially destroying any information found in the magnitudes of the features relative to each other.

导致水珠的原因是AdaIN操作,AdaIN对每一个feature map的通道进行归一化,这样可能破坏掉feature之间的信息。当然实验证明发现,去除AdaIN的归一化操作后,水珠就消失了。

我们来看StyleGAN2是如何改进AdaIN模块的:

  • 图a是原始的styleGAN1的结构图;
  • 图b把AdaIN拆分成了Norm mean/std和Mod mean/std两部分,Norm是做的归一化操作,而Mod则是从latent code计算shift和scale参数的步骤;
  • 图c,现在我们修改一下模型,我们去除对于mean的norm和mod的操作,只留下对方差的操作。
  • 图d则是在c的基础上,进一步提出了weight demodulation的操作。

3.3 weight demodulation

虽然我们修改了网络结构,去除了水滴问题,但是styleGAN的目的是对特征实现可控的精细的融合。

StyleGAN2说,style modulation可能会放大某些特征的影像,所以style mixing的话,我们必须明确的消除这种影像,否则后续层的特征无法有效的控制图像。如果他们想要牺牲scale-specific的控制能力,他们可以简单的移除normalization,就可以去除掉水滴伪影,这还可以使得FID有着微弱的提高。现在他们提出了一个更好的替代品,移除伪影的同时,保留完全的可控性。这个就是weight demodulation。

我们继续看这个图c:

里面包含三个style block,每一个block包含modulation(Mod),convolution and normalization。

modulation可以影响着卷积层的输入特征图。所以,其实Mod和卷积是可以继续宁融合的。比方说,input先被Mod放大了3倍,然后在进行卷积,这个等价于input直接被放大了3倍的卷积核进行卷积。Modulation和卷积都是在通道维度进行操作。所以有如下公式:

W i j k ′ = s i ⋅ w i j k W'_{ijk}=s_i \cdot w_{ijk} Wijk=siwijk

接下来的norm部分也做了修改:

w i j k ′ ′ = w i j k ′ ∑ i , k w ′ i j k 2 + ϵ w''_{ijk}=\frac{w'_{ijk}}{\sqrt{\sum_{i,k}{{w'}_{ijk}^2+\epsilon}}} wijk=i,kwijk2+ϵ wijk

这里替换了对特征图做归一化,而是去卷积的参数做了一个归一化,先前有研究提出,这样会有助于GAN的训练。

至此,我们发现,Mod和norm部分的操作,其实都可以融合到卷积核上。

3.4 代码学习

class GeneratorBlock(nn.Module):
    def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(input_channels, filters, 3)
        
        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise):
        if exists(self.upsample):
            x = self.upsample(x)

        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb

可以发现,这个噪音也会经过Linear层的简单变换,然后里面加入了残差。为什么要输出rgb图像呢?这个会放在下次,或者下下次的内容。styleGAN1是需要用progressive growing的策略的,而StyleGAN2使用新的架构,解决了这种繁琐的训练方式。下次讲styleGAN2的lazy path length regularization,下下次讲这个No progressive growing。

回到代码部分,发现我们讲到的AdaIN的改进,应该在Conv2DMod模块当中:

class Conv2DMod(nn.Module):
    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps = 1e-8, **kwargs):
        super().__init__()
        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        self.eps = eps
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def _get_same_padding(self, size, kernel, dilation, stride):
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y):
        b, c, h, w = x.shape

        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * d

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x

代码剖析:

y是style code经过全连接层得到的scale参数,假设batch size是16,输入特征图的通道数为256。所以w1.shape=[16,1,256,1,1];

w2是卷积层的weight,w2.shape=[1,out_chan, in_chan, kernel, kernel]

这里为什么要为w1加1呢?说实话,我觉得加不加都无所谓,因为之前的全连接层也有bias,所以无所谓的。

torch.rsqrt就是取平方根后取倒数。weight先求平方,然后对234维度求和,那么就保留了batch维度和输出通道维度。这个运算过程和论文中的weight demodulation是一致的

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

生成专题3 的相关文章

随机推荐

  • OpenCSV web下载csv文件demo

    OpenCSV web下载csv文件demo pom xml
  • 嵌入式Linux&Android开发-LCD屏幕调试

    目录 一 简介 二 开发流程 三 硬件说明 四 电子特性 五 关注启动时序 六 关注引脚 七 屏参适配 7 1 DTS 驱动配置 7 2 屏参配置 案例一 7 3屏参配置 案例二 7 4 屏参配置 案例三 7 5 屏参配置 案例四 7 6
  • 单元测试、集成测试、系统测试、验收测试

    本文是按照开发阶段划分测试技术 单元测试 单元测试是对软件组成单元进行测试 目的是检验软件基本组成单元的正确性 测试对象是软件设计的最小单位 模块 又称为模块测试 单元测试的实质是代码测代码 测试阶段 编码后或者编码前 TDD 编码前属于测
  • 树莓派笔记4:树莓派游戏机

    这次记录比较轻松的内容 将树莓派做成 游戏主机 当然这个主机只是具备模拟器功能而已 可以模拟街机 FC等平台上的游戏 最早要在树莓派上玩模拟器游戏需要手动安装和配置不同的模拟器 而现在国外很多爱好者专门制作了定制化的系统 直接把系统烧到树莓
  • latex插图\begin{minipage}强制左移\hspace命令

    事情是这样的 我在latex中插图 上面一张图是排列整整齐齐的图片 下面一张图就是我绘制的概率密度图 在使用latex插图的时候 因为概率密度图的纵坐标是有title的 所以会显得不整齐 如下图所示 在includegraphics前面添加
  • Inkscape 捕捉图标翻译

  • Docker Portainer 安装与报错处理

    安装docker 管理器 Portainer 最近在看spring cloud alibaba的时候 觉得docker是肯定要用的 然后找了个管理的docker的东东 比较方便的查询docker的情况 直接看操作吧 root localho
  • 分布式锁之redis实现

    docker安装redis 拉取镜像 docker pull redis 6 2 6 查看镜像 启动容器并挂载目录 需要挂在的data和redis conf自行创建即可 docker run restart always d v usr l
  • python字符串的常用方法(3-2)

    目录 一 字符串find 和index 获取某个值的位置方法 二 字符串strip lstrip rstrip左右去空格方法 三 字符串的replace 替换方法 四 字符串bool集合 一 字符串find 和index 获取某个值的位置方
  • vue项目通过directives指令实现vue实现盒子的移动;vue拖拽盒子;vue移动;

    vue项目 点击拖拽盒子 移动盒子 代码可直接复制 注意需要在移动的盒子上添加 v 指令 注意采用固定定位
  • 轻量级调试器神器 - mimikatz

    昨天有朋友发了个法国佬写的神器叫 mimikatz 让我们看下 神器下载地址 mimikatz trunk zip 还有一篇用这个神器直接从 lsass exe 里获取windows处于active状态账号明文密码的文章 http pent
  • 网络与信息安全应急处置预案

    分享一下我老师大神的人工智能教程 零基础 通俗易懂 http blog csdn net jiangjunshow 也欢迎大家转载本篇文章 分享知识 造福人民 实现我们中华民族伟大复兴 为加强北海市电子政务系统的安全 管理 形成科学有效 反
  • jpa自增id(@GeneratedValue和@GenericGenerator)

    一 JPA通用策略生成器 通过annotation来映射hibernate实体的 基于annotation的hibernate主键标识为 Id 其生成规则由 GeneratedValue设定的 这里的 id和 GeneratedValue都
  • Qt应用程序嵌入浏览器的常用方法

    1 使用QAxObject嵌入微软ActiveX软件 使用QAxObject需要包含Qt模块 QT axcontainer 注 1 此方式只针对微软的组件才有效 不可以用来加载第三方的应用程序 2 获取该组件的相关的API接口文档可以采用以
  • 多线程案例【二】

    目录 定时器 标准库中的定时器 实现定时器 线程池 Java标准库的线程池 实现线程池 定时器 定时器像是一个闹钟 在一定时间之后 被唤醒并执行某个之前设定好的任务 之前学习的 join 指定超时时间 sleep 休眠指定时间 都是基于系统
  • Vue3基础

    1 setup函数 setup 函数是组件逻辑的地方 它在组件实例被创建时 初始化 props 之后调用 2 ref ref 主要是用来包装原始类型的数据 为什么要包装对象 我们知道在 JavaScript 中 原始值类型如 string
  • 决策树实例(工资预测)【机器学习算法一决策树与随机森林3】

    数据集adult data下载地址 http archive ics uci edu ml machine learning databases adult 下载后将其重命名为adult csv 打开后可看到如下样子 数据集描述如下 属性如
  • Illegal base64 character 20

    1 问题 RSA 解密报错 Illegal base64 character 20 2 分析 如果是 url 地址栏传参 只需要UrlDecode 一次 如果开发平台默认 UrlDecode 程序就不用再次 UrlDecode 否则 bas
  • 互联网摸鱼日报(2023-09-05)

    互联网摸鱼日报 2023 09 05 36氪新闻 蔚小理上半年比拼 谁拿住了不下牌桌的筹码 一杯酱香拿铁 年轻人就能爱上茅台 关于瑞幸酱香拿铁的一些疑问 为什么不直接滴酒 是科技与狠活吗 小红书关停自营电商业务 本硕加入抢单 千万外卖员 卷
  • 生成专题3

    文章转自微信公众号 机器学习炼丹术 作者 陈亦新 欢迎交流共同进步 联系方式 微信cyx645016617 学习论文 Analyzing and Improving the Image Quality of StyleGAN 文章目录 3