论文阅读:Improved Denoising Diffusion Probabilistic Models

2023-11-04

本文是对ddpm简单的修改,但是能提高ddpm的性能

论文下载地址:https://proceedings.mlr.press/v139/nichol21a.html

我们发现反向过程中可学习的方差允许一个数量级的采样,样本质量的差异可以忽略不计,这对于模型的实际部署很厉害。

  • 关于变分下界的优化

使用简单的重参数化技巧学习优化变分下界。反向过程的方差使用简单的重参数化技巧和一个混合的目标vlb函数。

这样的变化导致少的采样步数,但是发生了很小的质量改变

  • 在实际训练中,关于 μ ( x t , t ) \mu(x_t,t) μ(xt,t)的参数化的方式

  • 使用网络预测 x 0 x_0 x0,然后使用公式
    μ ˉ ( x t , x 0 ) : = a ˉ t − 1 β t 1 − α ˉ t + α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t \bar{\mu } (x_t,x_0):=\frac{\sqrt{\bar{a}_{t-1}}\beta_t}{1-\bar{\alpha} _t} +\frac{\sqrt{\alpha _t}(1-\bar{\alpha }_{t-1} ) }{1-\bar{\alpha } _t}x_t μˉ(xt,x0):=1αˉtaˉt1 βt+1αˉtαt (1αˉt1)xt可以得到均值。

  • 也可以使用网络预测 ε \varepsilon ε,然后使用

    x t = α ˉ t x 0 + ( 1 − α ˉ t ) I x_t=\sqrt{\bar{\alpha}}_tx_0+(\sqrt{1-\bar{\alpha}}_t)I xt=αˉ tx0+(1αˉ t)I

    μ ˉ ( x t , x 0 ) : = a ˉ t − 1 β t 1 − α ˉ t + α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t \bar{\mu } (x_t,x_0):=\frac{\sqrt{\bar{a}_{t-1}}\beta_t}{1-\bar{\alpha} _t} +\frac{\sqrt{\alpha _t}(1-\bar{\alpha }_{t-1} ) }{1-\bar{\alpha } _t}x_t μˉ(xt,x0):=1αˉtaˉt1 βt+1αˉtαt (1αˉt1)xt去生成均值 μ ˉ ( x t , x 0 ) \bar{\mu } (x_t,x_0) μˉ(xt,x0)

目前这种方法是最好的,尤其是结合重新加权的损失函数

L simple  = E t , x 0 , ϵ [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] L_{\text {simple }}=E_{t, x_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(x_{t}, t\right)\right\|^{2}\right] Lsimple =Et,x0,ϵ[ϵϵθ(xt,t)2]

原本的lvb损失
L v l b : = L 0 + L 1 + … + L T − 1 + L T L_{\mathrm{vlb}} :=L_{0}+L_{1}+\ldots+L_{T-1}+L_{T} Lvlb:=L0+L1++LT1+LT
L 0 : = − log ⁡ p θ ( x 0 ∣ x 1 ) L_{0} :=-\log p_{\theta}\left(x_{0} \mid x_{1}\right) L0:=logpθ(x0x1)

L t − 1 : = D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) L_{t-1}:=D_{K L}(q\left(x_{t-1} \mid x_{t}, x_{0}\right)|| p_{\theta}(x_{t-1}|x_{t})) Lt1:=DKL(q(xt1xt,x0)∣∣pθ(xt1xt))
L T : = D K L ( q ( x T ∣ x 0 ) ∥ p ( x T ) ) L_{T} :=D_{K L}\left(q\left(x_{T} \mid x_{0}\right) \| p\left(x_{T}\right)\right) LT:=DKL(q(xTx0)p(xT))

对数似然的改进

为研究不同流形上的作用,在image net64*64 上训练固定的模型结构使用固定的超参数。

  • 不同的实验设置对比
方法 ho(2020) our
损失函数及参数设置 L s i m p l e L_{simple} Lsimple & σ 2 = β t \sigma^2=\beta_t σ2=βt & T = 10000 T=10000 T=10000 L h y b r i d = L simple  + λ L v l b , λ = 0.001 L_{\mathrm{hybrid}}=L_{\text {simple }}+\lambda L_{\mathrm{vlb}},\lambda = 0.001 Lhybrid=Lsimple +λLvlb,λ=0.001
训练轮数 200k
T 1000 4000
数据集 image_net64*64
实验结果 3.99 3.77

学习 Σ θ ( x t , t ) \Sigma_{\theta}\left(x_{t}, t\right) Σθ(xt,t)

  • 在无限步长的扩散的过程条件下,方差的作用远没有均值对实验结果的影响大。或者说,方差几乎不发挥作用。

  • 实验中,我们发现扩散过程前几步对整个扩散过程很重要。于是,通过使用更好的 Σ θ ( x t , t ) \Sigma_{\theta}\left(x_{t}, t\right) Σθ(xt,t)可以很大程度上提高对数释然。

  • 合理的 Σ θ ( x t , t ) \Sigma_{\theta}\left(x_{t}, t\right) Σθ(xt,t)的范围很小,对于神经网络去寻找一个合理的 Σ θ ( x t , t ) \Sigma_{\theta}\left(x_{t}, t\right) Σθ(xt,t)不是容易的,

  • 我们发现更好的参数化 Σ θ ( x t , t ) \Sigma_{\theta}\left(x_{t}, t\right) Σθ(xt,t)是差值 β ˉ t \bar{\beta}_t βˉt β t \beta_t βt在对数域。我们的模型输出 v v v,每个维度包含一个分量,然后将这个输出转换为如下方差:
    Σ θ ( x t , t ) = exp ⁡ ( v log ⁡ β t + ( 1 − v ) log ⁡ β ~ t ) \Sigma_{\theta}\left(x_{t}, t\right)=\exp \left(v \log \beta_{t}+(1-v) \log \tilde{\beta}_{t}\right) Σθ(xt,t)=exp(vlogβt+(1v)logβ~t)

  • t r i c k {\color{Red}trick } trick
    stop- gradient

改善噪声机制

线性噪声机制对于高分辨率图像很好,在分辨率小的图像上结果次优。ddpm中的前向加噪过程对采样过程没有太大的贡献。

  • 原本的加噪机制:

q ( x 1 , … , x T ∣ x 0 ) : = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(x_{1}, \dots, x_{T} | x_{0}) :=\prod_{t=1}^{T} q\left(x_{t} \mid x_{t-1}\right) q(x1,,xTx0):=t=1Tq(xtxt1)

q ( x t ∣ x t − 1 ) : = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_{t}| x_{t-1}):=\mathcal{N}\left(x_{t} ; \sqrt{1-\beta_{t}} x_{t-1}, \beta_{t} \mathbf{I}\right) q(xtxt1):=N(xt;1βt xt1,βtI)

对任何时刻的加噪样本:

q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) x t = α ˉ t x 0 + 1 − α ˉ t ϵ \begin{aligned} q\left(x_{t} \mid x_{0}\right) &=\mathcal{N}\left(x_{t} ; \sqrt{\bar{\alpha}_{t}} x_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right) \\ x_{t} &=\sqrt{\bar{\alpha}_{t}} x_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon \end{aligned} q(xtx0)xt=N(xt;αˉt x0,(1αˉt)I)=αˉt x0+1αˉt ϵ

  • 加噪机制的改进:

    α ˉ t = f ( t ) f ( 0 ) , f ( t ) = cos ⁡ ( t / T + s 1 + s ⋅ π 2 ) 2 \bar{\alpha}_{t}=\frac{f(t)}{f(0)}, \quad f(t)=\cos \left(\frac{t / T+s}{1+s} \cdot \frac{\pi}{2}\right)^{2} αˉt=f(0)f(t),f(t)=cos(1+st/T+s2π)2
    其中:
    β t = 1 − α ˉ t α ˉ t − 1 \beta_{t}=1-\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-1}} βt=1αˉt1αˉt

    • 为了防止开始时 β t , t = 0 \beta_t,t=0 βt,t=0太小,使得网络在预测噪声的时候很困难。我们让 β 0 \sqrt{\beta_0} β0 =1/127.5=0.008。
    • 在实际中使用 c o s 2 cos^2 cos2
    • 不同的 L v l b L_{vlb} Lvlb有不同边际。采样t均匀的会引起没必要的噪声在 L v l b L_{vlb} Lvlb中。我们使用重要性采样:
      L v l b = E t ∼ p t [ L t p t ] , where  p t ∝ E [ L t 2 ]  and  ∑ p t = 1 L_{\mathrm{vlb}}=E_{t \sim p_{t}}\left[\frac{L_{t}}{p_{t}}\right] \text {, where } p_{t} \propto \sqrt{E\left[L_{t}^{2}\right]} \text { and } \sum p_{t}=1 Lvlb=Etpt[ptLt], where ptE[Lt2]  and pt=1,由于
      E [ L t 2 ] E\left[L_{t}^{2}\right] E[Lt2]是未知的,可能在整个训练过程中发生变化。于是,我们的每个损失保持10个值的历史,并在训练期间动态更新。训练开始,我们均匀的采样十个样本, t ∈ [ 0 , T − 1 ] t \in[0, T-1] t[0,T1]

总结

  • 损失函数的改进
  • 噪声机制使用cos
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

论文阅读:Improved Denoising Diffusion Probabilistic Models 的相关文章

  • nps auth_key未授权访问漏洞

    一 漏洞简介 nps是一款轻量级 高性能 功能强大的内网穿透代理服务器 目前支持tcp udp流量转发 可支持任何tcp udp上层协议 访问内网网站 本地支付接口调试 ssh访问 远程桌面 内网dns解析等等 此外还支持内网http代理
  • 微信常见错误码及解决方案

    40001 获取access token时AppSecret错误 或者access token无效 这个错误代码表示您的访问令牌 access token 已经过期或者无效 需要重新获取 40003 openid错误 openid是微信公众
  • devServer-host解析

    devServer的其他配置 host解析 host设置主机地址 默认值是localhost 如果希望其他地方也可以访问 可以设置为 0 0 0 0 localhost 和 0 0 0 0 的区别 localhost 本质上是一个域名 通常
  • C语言实现队列

    文章目录 一 什么是队列 二 队列的实现 2 1 队列的结构 2 2 队列的几个功能 2 2 1 初始化队列 2 2 2 队列判空 2 2 3 队尾入队列 2 2 4 队头出队列 2 2 5 获取队列头部元素 2 2 6 获取队列队尾元素

随机推荐

  • QCM2290平台XBL阶段I2C使能并点亮LED

    描述 qcm2290平台上 在低压充电阶段 XBL 需要驱动LED灯芯片 提供接口供充电模块调用 显示充电指示灯功能 说明 LED控制芯片是I2C接口 我只需要提供接口即可 我这边实现了在开机时led灯闪烁 在充电相关PmicLib目录下添
  • 在webpack的less中使用绝对路径import

    假设项目目录结构如下 webpack中 resolve modulesDirectories path join dirname node modules path join dirname src 在a less中写上 import st
  • 干掉鲁大师监控,Windows免费监控软件

    大家好 今天我找到了一款在电脑上可以实时在任务栏显示实时网速的免费开源的小插件 非常的好用 而且呢它竟然还能实时的显示显卡和CPU的温度和占用一个百分比 让你对你的电脑性能了如指掌 一点也不逊色于收费的鲁大师桌面监控程序 拿到我的电脑上面去
  • 软件测试用例所有疑问,只需这篇就够了

    1 测试用例是什么 答 测试用例的设计就是如何覆盖所有软件表现出来的状态 即在满足输入 输出的一组条件下 软件运行是一系列有次序的 受控制的状态变化过程 2 设计用例是否有必要 答 如果不记下来 很可能到执行的时候测试点就遗漏了 另外也不便
  • 手撕源码之代码手写mvc

    1 首先附上代码地址 https gitee com cqut lin hand tear source code 实现思路 Spring主要也是通过DispatcherServlet实现了Servlet这个接口 又叫前端控制器 来自前端的
  • 51单片机的波特率

    最近使用51单片机的时候 设置串口的波特率 需要多种 固先记下来 晶振更改的时候可以通过excel中的改动来调整 excel在126中email的网盘中 51单片机的波特率 et 下面列表是基于定时器2的方式2 自动重装的方式 晶振 11
  • 安卓Activity跳转的几种方式

    本文转载于http blog sina com cn s blog 5140274d0100q4j7 html 本人仅作为学习交流之用 请大家尊重原创 第一种方式 用action来跳转 使用Action跳转 如果有一个程序的 Android
  • Java从小白到大牛第1篇 Java基础-关东升-专题视频课程

    Java从小白到大牛第1篇 Java基础 3042人已学习 课程介绍 本视频是智捷课堂推出的一套 Java语言学习立体教程 的视频第一部分 读者以及观看群是初级小白 通过本视频的学习能够成为Java大牛 本主要内容包括 Java语法基础 J
  • 波特率_通信基本概念扫盲(波特率与带宽的关系)

    在工作和学习中 通常会遇到一些比较基础的技术性问题 比如波特率为B的信号 它的频谱宽度是多少 说这个问题基础 但答案并不简单 今天分享的一些基本概念 就是希望能解答上述的问题 1 信号的快慢 表示信号快慢通常会用速率相关的参数 比如 码元速
  • 计算机主机内部结构连接,电脑主机内部结构图详解

    电脑主机内部结构分为多种硬件组合而成 硬件可以理解为看得到摸得着的东西 计算机硬件通常包括主板 CPU 内存 硬盘 光驱 电源 以及其他输入输出控制器和接口 如 USB 控制器 显卡 网卡 声卡等等 位于主机箱内的通常称为内设 而位于主机箱
  • android设备SD卡文件扫描与同步(暂备份)

    package com owo contentresolvermedia import java io File import java util ArrayList import android app Activity import a
  • 同一页面、不同页面监听localStorage变化

    当同源页面的某个页面修改了localStorage 其余的同源页面只要注册了storage事件 就会触发 所以 localStorage 的例子运行需要如下条件 同一浏览器打开了两个同源页面 其中一个网页修改了 localStorage 另
  • 简单易懂的隐马尔可夫模型(HMM)讲解

    学习目标 了解什么是马尔科夫链 知道什么是HMM模型 知道前向后向算法评估观察序列概率 知道维特比算法解码隐藏状态序列 了解鲍姆 韦尔奇算法 知道HMM模型API的使用 一 马尔科夫链 在机器学习算法中 马尔可夫链 Markov chain
  • Top-1错误率、Top-5错误率等常见的模型算法评估指标解析

    Top 1 错误率 指预测输出的概率最高的类别与人工标注的类别相符的准确率 就是你预测的label取最后概率向量里面最大的那一个作为预测结果 如过你的预测结果中概率最大的那个分类正确 则预测正确 否则预测错误 比如预测100张图像的类别 每
  • Spring Cloud Alibaba和Spring Cloud的区别

    目录 Spring Cloud Netflix 和 Spring Cloud 是什么关系 为什么有了Spring Cloud又出来个Spring Cloud Alibaba呢 Spring Cloud Alibaba都有哪些功能呢 Clou
  • JAVA——注解和反射

    注解的理解 引用b乎大佬的比喻 注解就像一张标签 给人贴标签是一种行为 会使一个人身上 的特性只有一部分被放大出来 但是换个角度 标签就是对事物行为的某些角度的评价与解释 从代码的角度上看 注解就是对于代码中需要拥有某些特别意义的功能的部分
  • 计算个人所得税

    输入一个职工的月薪salary 输出应交的个人所得税tax 保留2位小数 tax rate salary 850 当 salary lt 850 时 rate 0 0 当 850 lt salary lt 1350 时 rate 0 05
  • centos 6 yum源不可用安装报YumRepo Error: All mirror URLs are not using ftp, http[s] or file

    项目场景 centos6 5 使用yum安装资源时 报如下错误 1 YumRepo Error All mirror URLs are not using ftp http s or file 解决方案 修改 etc yum repos d
  • Spring Data 与MongoDB 集成四:操作篇(查询)

    本文转载至 http blog csdn net congcong68 article details 47183209 一 简介 spring Data MongoDB提供了org springframework data mongodb
  • 论文阅读:Improved Denoising Diffusion Probabilistic Models

    本文是对ddpm简单的修改 但是能提高ddpm的性能 论文下载地址 https proceedings mlr press v139 nichol21a html 我们发现反向过程中可学习的方差允许一个数量级的采样 样本质量的差异可以忽略不