diffusion model(三)—— classifier guided diffusion model

2023-11-01

系列阅读

背景

对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,如何生成特定类别的图片呢?这就是classifier guide需要解决的问题。

方法大意

为了实现带类别标签 y y y的DM的推导,进行了以下定义
q ^ ( x 0 ) : = q ( x 0 ) q ^ ( y ∣ x 0 ) : = Know labels per sample q ^ ( x t + 1 ∣ x t , y ) : = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) : = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) (1) \begin{aligned} \hat{q}(x_0) &:= q(x_0) \\ \hat{q}(y|x_0) &:= \text{Know labels per sample} \\ \hat{q}(x_{t+1}|x_{t}, y) &:= q(x_{t+1}|x_t) \\ \hat{q}(x_{1:T}|x_0, y)&:= \prod \limits_{t=1}^T\hat{q}(x_t|x_{t-1}, y) \\ \end{aligned} \tag{1} q^(x0)q^(yx0)q^(xt+1xt,y)q^(x1:Tx0,y):=q(x0):=Know labels per sample:=q(xt+1xt):=t=1Tq^(xtxt1,y)(1)
虽然上式定义了以 y y y为条件的噪声过程 q ^ \hat{q} q^,但我们还可以证明当 q ^ \hat{q} q^不以 y y y为条件时的行为与 q q q完全相同,即
q ^ ( x t + 1 ∣ x t ) = ∫ y q ^ ( x t + 1 , y ∣ x t ) d y = ∫ y q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) d y = ∫ y q ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) d y = q ( x t + 1 ∣ x t ) ∫ y q ^ ( y ∣ x t ) d y = q ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t , y ) (2) \begin{aligned} \hat{q}(x_{t+1}|x_t) &= \int_y \hat{q}(x_{t+1}, y| x_t)dy \\ &= \int_y \hat{q}(x_{t+1}|x_t, y)\hat{q}(y|x_t)dy \\ &= \int_y q(x_{t+1}|x_t)\hat{q}(y|x_t)dy \\ &= q(x_{t+1}|x_t) \int_y \hat{q}(y|x_t)dy \\ &= q(x_{t+1}|x_t) \\ &= \hat{q}(x_{t+1}|x_t, y) \\ \end{aligned}\tag{2} q^(xt+1xt)=yq^(xt+1,yxt)dy=yq^(xt+1xt,y)q^(yxt)dy=yq(xt+1xt)q^(yxt)dy=q(xt+1xt)yq^(yxt)dy=q(xt+1xt)=q^(xt+1xt,y)(2)
同样的思路:
q ^ ( x 1 : T ∣ x 0 ) = ∫ y q ^ ( x 1 : T , y ∣ x 0 ) d y = ∫ y q ^ ( x 1 : T ∣ y , x 0 ) q ( y ∣ x 0 ) d y = ∫ y ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) ⏟ q ( x t ∣ x t − 1 ) q ( y ∣ x 0 ) d y = ∏ t = 1 T q ( x t ∣ x t − 1 ) ⏟ q ( x 1 : T ∣ x 0 ) ∫ y q ( y ∣ x 0 ) d y ⏟ = 1 = q ( x 1 : T ∣ x 0 ) (3) \begin{aligned} \hat{q}(x_{1:T}|x_0) &= \int_y \hat{q}(x_{1:T}, y|x_0) d_y \\ &= \int_y \hat{q}(x_{1:T}|y, x_0)q(y| x_0) d_y \\ &= \int_y \prod \limits_{t=1}^T \underbrace{ \hat{q}(x_t|x_{t-1}, y)}_{q(x_t|x_t-1)} q(y| x_0) d_y \\ &= \underbrace{\prod \limits_{t=1}^Tq(x_t|x_{t-1})}_{q(x_{1:T}|x_0)} \underbrace{\int_y q(y| x_0)d_y}_{=1} \\ &= q(x_{1:T}|x_0) \end{aligned}\tag{3} q^(x1:Tx0)=yq^(x1:T,yx0)dy=yq^(x1:Ty,x0)q(yx0)dy=yt=1Tq(xtxt1) q^(xtxt1,y)q(yx0)dy=q(x1:Tx0) t=1Tq(xtxt1)=1 yq(yx0)dy=q(x1:Tx0)(3)
根据上式同样可以推导出
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 , ⋯   , x t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) ⏟ q ( x 0 ) q ^ ( x 1 , ⋯   , x t ∣ x 0 ) ⏟ q ( x 1 : T ∣ x 0 ) d x 0 : t − 1 = q ( x t ) (4) \begin{aligned} \hat{q}(x_t) &= \int_{x_{0:t - 1}} \hat{q}(x_0, \cdots, x_t)dx_{0:t-1} \\ &= \int_{x_{0:t - 1}} \underbrace{\hat{q}(x_0)}_{q(x_0)} \underbrace{\hat{q}(x_1, \cdots, x_t|x_0)}_{q(x_{1:T}|x_0)}dx_{0:t-1} \\ &= q(x_t) \end{aligned} \tag{4} q^(xt)=x0:t1q^(x0,,xt)dx0:t1=x0:t1q(x0) q^(x0)q(x1:Tx0) q^(x1,,xtx0)dx0:t1=q(xt)(4)
由上述推导可见带条件的DM的前向过程与DDPM完全相同。并且根据贝叶斯公式,不带逆向过程也满足
p ^ ( x t ∣ x t + 1 ) = p ( x t ∣ x t + 1 ) (5) \hat{p}(x_t|x_{t+1}) = p(x_t|x_{t+1}) \tag{5} p^(xtxt+1)=p(xtxt+1)(5)
与此同时我们可以证明分类分布 q ^ ( y ∣ x t ) \hat{q}(y|x_t) q^(yxt)只和当前时刻的输入 x t x_t xt有关,与 x t + 1 x_{t+1} xt+1无关
q ^ ( y ∣ x t , x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) ⏞ q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) (6) \begin{aligned} \hat{q}(y|x_t, x_{t+1}) & = \frac{ \overbrace{ \hat{q}(x_{t+1}|x_t, y)}^{\hat{q}(x_{t+1}|x_t)} \hat{q}(y|x_t) } {\hat{q}(x_{t+1}|x_t )} \\ & = \hat{q}(y|x_t) \end{aligned} \tag{6} q^(yxt,xt+1)=q^(xt+1xt)q^(xt+1xt,y) q^(xt+1xt)q^(yxt)=q^(yxt)(6)

基于条件的去噪过程

将带类别信息的去噪过程定义为 p ^ ( x t ∣ x t + 1 , y ) \hat{p}(x_t|x_{t+1}, y) p^(xtxt+1,y)

p ^ ( x t ∣ x t + 1 , y ) = p ^ ( x t , x t + 1 , y ) p ^ ( y ∣ x t + 1 ) p ^ ( x t + 1 ) = p ^ ( x t , y ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) = p ^ ( y ∣ x t , x t + 1 ) ⏞ p ^ ( y ∣ x t ) p ^ ( x t ∣ x t + 1 ) ⏞ p ( x t ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) = p ^ ( y ∣ x t ) p ( x t ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) (7) \begin{aligned} \hat{p} (x_t| x_{t+1}, y) & = \frac{\hat{p} (x_t, x_{t+1}, y) }{\hat{p} (y|x_{t+1}) \hat{p} (x_{t+1}) } \\ & = \frac{\hat{p} (x_t, y | x_{t+1}) }{\hat{p} (y|x_{t+1}) } \\ & = \frac{\overbrace{\hat{p} (y|x_t, x_{t+1})}^{\hat{p}(y|x_t)} \overbrace{\hat{p}(x_t | x_{t+1})}^{p(x_t|x_{t+1})} }{\hat{p} (y|x_{t+1}) } \\ & = \frac{\hat{p} (y|x_t) p(x_t | x_{t+1}) }{\hat{p} (y|x_{t+1}) } \end{aligned} \tag{7} p^(xtxt+1,y)=p^(yxt+1)p^(xt+1)p^(xt,xt+1,y)=p^(yxt+1)p^(xt,yxt+1)=p^(yxt+1)p^(yxt,xt+1) p^(yxt)p^(xtxt+1) p(xtxt+1)=p^(yxt+1)p^(yxt)p(xtxt+1)(7)
由于 x t + 1 x_{t+1} xt+1是已知的, p ^ ( y ∣ x t + 1 ) \hat{p} (y|x_{t+1}) p^(yxt+1)这个概率分布与 x t x_t xt无关,可以将 p ^ ( y ∣ x t + 1 ) \hat{p} (y|x_{t+1}) p^(yxt+1)视为常数 Z Z Z。此时上式可以表述为
p ^ ( x t ∣ x t + 1 , y ) = Z p ^ ( y ∣ x t ) p ( x t ∣ x t + 1 ) (8) \hat{p} (x_t| x_{t+1}, y) = Z \hat{p} (y|x_t) p(x_t | x_{t+1}) \tag{8} p^(xtxt+1,y)=Zp^(yxt)p(xtxt+1)(8)
上式的右边第二项 p ^ ( y ∣ x t ) \hat{p} (y|x_t) p^(yxt)很容易得到,我们可以根据 x t , y x_t, y xt,y的pair对训练一个分类模型 p ^ ϕ ( y ∣ x t ) \hat{p}_\phi(y|x_t) p^ϕ(yxt)

上式的右边第三项 p ( x t ∣ x t + 1 ) p(x_t | x_{t+1}) p(xtxt+1)在DDPM中也能够通过一个neural network进行估计 p ( x t ∣ x t + 1 ) ≈ p θ ( x t ∣ x t + 1 ) p(x_t | x_{t+1}) \approx p_\theta(x_t|x_{t+1}) p(xtxt+1)pθ(xtxt+1)

故采样分布
p ^ ( x t ∣ x t + 1 , y ) ≈ p ^ ϕ , θ ( x t ∣ x t + 1 , y ) = Z p ^ ϕ ( y ∣ x t ) p θ ( x t ∣ x t + 1 ) (9) \begin{aligned} \hat{p} (x_t| x_{t+1}, y) &\approx \hat{p}_{\phi, \theta} (x_t| x_{t+1}, y) \\ &= Z \hat{p}_{\phi} (y|x_t) p_{\theta}(x_t | x_{t+1}) \end{aligned} \tag{9} p^(xtxt+1,y)p^ϕ,θ(xtxt+1,y)=Zp^ϕ(yxt)pθ(xtxt+1)(9)
下面来看有了上面这个式子如何进行采样

直接对上面的式子进行采样是很难解决的。论文参考文献1将上式近似为perturbed Gaussian distribution。

根据前文DM的推导可知 p θ ( x t ∣ x t + 1 ) = N ( μ , Σ ) = 1 2 π Σ exp ⁡ ( − ( x − μ ) 2 2 Σ ) p_{\theta}(x_t | x_{t+1}) = \mathcal{N}(\mu, \Sigma)=\frac{1}{\sqrt{2\pi} \sqrt{\Sigma} } \exp \left ({- \frac{(x - \mu)^2}{2\Sigma}} \right) pθ(xtxt+1)=N(μ,Σ)=2π Σ 1exp((xμ)2) ,对其取对数
log ⁡ p θ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C (10) \log p_{\theta}(x_t|x_{t+1}) = - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + C \tag{10} logpθ(xtxt+1)=21(xtμ)TΣ1(xtμ)+C(10)
对于 log ⁡ p ^ ϕ ( y ∣ x t ) \log \hat{p}_{\phi} (y|x_t) logp^ϕ(yxt) 作者假设其curvature比 Σ − 1 \Sigma^{-1} Σ1低。这个假设是合理的,对于当diffusion steps足够大时, ∥ Σ ∥ → 0 \parallel \Sigma \parallel \rightarrow 0 Σ∥→0。在该情况下,对 log ⁡ p ^ ϕ ( y ∣ x t ) \log\hat{p}_{\phi} (y|x_t) logp^ϕ(yxt) x t = μ x_t = \mu xt=μ处进行泰勒展开
log ⁡ p ^ ϕ ( y ∣ x t ) ≈ log ⁡ p ^ ϕ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ = ( x t − μ ) g + C 1 where:  g = ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ , C 1  is a contant. (11) \begin{aligned} \log \hat{p}_{\phi} (y|x_t) & \approx \log \hat{p}_{\phi} (y|x_t) | _{x_t = \mu} + (x_t - \mu) \nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu} \\ &= (x_t - \mu) g + C_1 \\ \text{where: } g &= \nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu}, C_1\text{ is a contant.} \end{aligned} \tag{11} logp^ϕ(yxt)where: glogp^ϕ(yxt)xt=μ+(xtμ)xtlogpϕ(yxt)xt=μ=(xtμ)g+C1=xtlogpϕ(yxt)xt=μ,C1 is a contant.(11)

log ⁡ ( p ^ ϕ ( y ∣ x t ) p θ ( x t ∣ x t + 1 ) ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C 3 = log ⁡ p ( z ) + C 4 , z ∼ N ( μ + Σ g , Σ ) (12) \begin{aligned} \log (\hat{p}_{\phi} (y|x_t) p_{\theta}(x_t | x_{t+1})) & = - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2 \\ & = - \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ & = - \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + C_3 \\ & = \log p(z) + C_4, z \sim \mathcal{N}(\mu + \Sigma g, \Sigma) \end{aligned} \tag{12} log(p^ϕ(yxt)pθ(xtxt+1))=21(xtμ)TΣ1(xtμ)+(xtμ)g+C2=21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C2=21(xtμΣg)TΣ1(xtμΣg)+C3=logp(z)+C4,zN(μ+Σg,Σ)(12)

(附录给出了验证性证明)

通过上述推导,我们得到了带类别条件的采样过程也可以用高斯分布来近似,只是均值需要加上 Σ g \Sigma g Σg。具体的算法如下
在这里插入图片描述

代码实现

p_mean_var_ddpm是DDPM对高斯分布均值、方差的计算函数

p_mean_var_ddpm_with_classifier是引入类别控制后的对高斯分布均值、方差的计算函数

有了均值方差就可以进行采样了

def p_mean_var_ddpm(self, noise_model, x, t):
    """
    Math:
    \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t -
        \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}f_\theta(x_t, t) \tag{30}
    """
    betas_t = extract(self.betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        self.sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
    model_mean_t = sqrt_recip_alphas_t * (
        x - betas_t * noise_model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = extract(self.posterior_variance, t, x.shape)
    return model_mean_t, posterior_variance_t

  
def p_mean_var_ddpm_with_classifier(classifier, noise_model, x, t, y=None, cfs=1):
    def cond_fn(x: torch.Tensor, t: torch.Tensor, y: torch.Tensor): 
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return torch.autograd.grad(selected.sum(), x_in)[0].float()   # gradient descend
    grad = cond_fn(x_temp, t, y=y) * cfs 
    model_mean_t, posterior_variance_t = p_mean_var_ddpm(noise_model, x, t)
    new_mean = model_mean_t + posterior_variance_t * grad
    return new_mean, posterior_variance_t
DDIM 中基于条件的去噪过程

上述条件抽样推导仅对随机扩散采样过程有效,不能应用于DDIM2等确定性采样方法(因为DDIM中设定了方差为0,故无法推导出式19)。为此,作者在研究中采用score-based的思路,参考了Song等人[^ 3]的方法,并利用了扩散模型和score matching之间的联系3

首先根据贝叶斯公式
p ( x t ∣ y ) = p ( y ∣ x t ) p ( x t ) p ( y ) ⇒ log ⁡ p ( x t ∣ y ) = log ⁡ p ( y ∣ x t ) + log ⁡ p ( x t ) − log ⁡ p ( y ) ⇒ 对 x t 求导 ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) − ∇ x t log ⁡ p ( y ) ⏟ = 0 ⇒ ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) (13) \begin{aligned} p (x_t| y) & = \frac{p (y|x_t) p(x_t) }{p (y) } \\ \Rightarrow \log{p (x_t| y) } &= \log{p (y|x_t)} + \log{p(x_t)} - \log{p (y) } \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (x_t|y)} &= \nabla_{x_t}\log{p (y|x_t)} + \nabla_{x_t}\log{p(x_t)} - \underbrace{\nabla_{x_t}\log{p(y) }}_{=0} \\ \Rightarrow \nabla_{x_t}\log{p(x_t| y)} &= \nabla_{x_t}\log{p(y|x_t)} + \nabla_{x_t}\log{p(x_t)} \\ \end{aligned} \tag{13} p(xty)logp(xty)xt求导xtlogp(xty)xtlogp(xty)=p(y)p(yxt)p(xt)=logp(yxt)+logp(xt)logp(y)=xtlogp(yxt)+xtlogp(xt)=0 xtlogp(y)=xtlogp(yxt)+xtlogp(xt)(13)
具体来说,如果我们有一个模型 ϵ θ ( x t ) \epsilon_\theta(x_t) ϵθ(xt)来预测添加到样本中的噪声,那么可以利用它来推导出一个score function:
∇ x t log ⁡ p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) (14) \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t) \tag{14} xtlogpθ(xt)=1αt 1ϵθ(xt)(14)
代入式(20)得
∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) − 1 1 − α ‾ t ϵ θ ( x t ) ⇒ 1 − α ‾ t ∇ x t log ⁡ p ( x t ∣ y ) = 1 − α ‾ t ∇ x t log ⁡ p ( y ∣ x t ) − ϵ θ ( x t ) (15) \begin{aligned} \nabla_{x_t}\log{p(x_t| y)} &= \nabla_{x_t}\log{p(y|x_t)} - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t) \\ \Rightarrow \sqrt{1 - \overline{\alpha}_t} \nabla_{x_t}\log{p(x_t| y)} &= \sqrt{1 - \overline{\alpha}_t} \nabla_{x_t}\log{p(y|x_t)} - \epsilon_\theta(x_t) \end{aligned} \tag{15} xtlogp(xty)1αt xtlogp(xty)=xtlogp(yxt)1αt 1ϵθ(xt)=1αt xtlogp(yxt)ϵθ(xt)(15)
定义在条件 y y y下的估计噪声 ϵ ^ ( x t ∣ y ) \hat{\epsilon}(x_t|y) ϵ^(xty)为:
ϵ ^ ( x t ∣ y ) : = ϵ θ ( x t ) − 1 − α ‾ t ∇ x t log ⁡ p ϕ ( y ∣ x t ) (16) \hat{\epsilon}(x_t|y) := \epsilon_\theta(x_t) - \sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{16} ϵ^(xty):=ϵθ(xt)1αt xtlogpϕ(yxt)(16)
只需将DDIM中的$ \epsilon_\theta(x_t) 替换为 替换为 替换为\hat{\epsilon}(x_t|y)$就得到了基于条件的去噪过程。

在这里插入图片描述

代码上也很直观

def p_sample_ddim(self, model, x, t):
    """
    x_{t-1} &=  \sqrt{\overline{\alpha}_{t-1}} \frac{x_t - \sqrt{1 - \overline{\alpha}_{t}}\boldsymbol{\epsilon}_\theta(x_t, t)}
        {\sqrt{\overline{\alpha}_{t}}} +  \sqrt{1 - \overline{\alpha}_{t-1} } \boldsymbol{\epsilon}_\theta(x_t, t)
    """
    sqrt_alphas_cumprod_prev_t = extract(self.sqrt_alphas_cumprod_prev, t, x.shape) 
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_prev_t = extract(self.sqrt_one_minus_alphas_cumprod_prev, t, x.shape) 
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) 
    pred_noise = model(x, t)
    pred_x0 = sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
    x0_direction = sqrt_one_minus_alphas_cumprod_prev_t * pred_noise 
    return pred_x0 + x0_direction
  
  
def p_sample_with_classifier(self, model, x, t, t_index, y=None, **kwargs):
    if y is None:
        return self.p_sample_ddim(model, x, t, t_index=t_index)
    cfs = kwargs.get("cfs", 1) 
    sqrt_alphas_cumprod_prev_t = extract(self.sqrt_alphas_cumprod_prev, t, x.shape) 
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_prev_t = extract(self.sqrt_one_minus_alphas_cumprod_prev, t, x.shape) 
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) 
    pred_noise = model(x, t)
    score = self.cond_fn(x, t, y=y) * cfs
    pred_noise = pred_noise - sqrt_one_minus_alphas_cumprod_t * score  # update noise 
    pred_x0 = sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
    x0_direction = sqrt_one_minus_alphas_cumprod_prev_t * pred_noise 
    return pred_x0 + x0_direction

一些细节

classifier的训练

classifier的训练与扩散模型的训练可以是独立的。在训练classifier的时候可以噪声预测模型(Unet)的encode部分作为主干,在后面接了一个分类层。并且需要与相应的扩散模型相同的噪声分布对classifier进行训练。训练数据集如 [ ( x 1 t , t , y 1 ) , ( x 2 t , t , y 2 ) , . . . , ( x N t , t , y N ) ] [(x_1^t,t, y_1), (x_2^t,t, y_2), ..., (x_N^t,t, y_N)] [(x1t,t,y1),(x2t,t,y2),...,(xNt,t,yN)] t t t是对时间步的采样, x t x^t xt x x x在时间步 t t t的输出。训练完成后,采用上面的算法集成到采样过程中。

gradient score的作用

在上面的采样算法我们看到有一个gradient scale s s s来对梯度进行拉伸。

实验视角

一般来说当 s = 1 s=1 s=1时,大约能保证生成的图片50%是想要的类别4,随着 s s s的增大,这个比例也能够增加。如下图,当 s s s增加到10,此时生成的图片都是期望的类别。因此 s s s也称之为guidance scale。
在这里插入图片描述

其实理解这个scale还有另一个视角

s ∇ x t log ⁡ ( p ϕ ( y ∣ x t ) ) = ∇ x t log ⁡ ( p ϕ ( y ∣ x t ) s ) s\nabla_{x_t} \log (p_\phi(y|x_t)) = \nabla_{x_t} \log (p_\phi(y|x_t)^s) sxtlog(pϕ(yxt))=xtlog(pϕ(yxt)s),当 s > 1 s>1 s>1他相当于对分布 p ϕ ( y ∣ x t ) p_\phi(y|x_t) pϕ(yxt)进行了一个指数拉升,从而带来更大的梯度更新收益。

根据DM的采样过程,当没有classifier guided时,在时刻 t t t,的采样过程应当是
x t − 1 = μ θ ( x t , t ) + σ ( t ) ϵ , 其中 ϵ ∈ N ( ϵ ; 0 , I ) = 1 α t ( x t − 1 − α t 1 − α ‾ t ϵ θ ( x t , t ) ) ⏟ μ θ ( x t , t ) + σ ( t ) ϵ (17) \begin{aligned} x_{t-1} &= \mu_{\theta}(x_t, t) + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \\ & = \underbrace{\frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t }{\sqrt{1 - \overline{\alpha}_t}}\epsilon_\theta(x_t, t))}_{\mu_\theta(x_t, t)} + \sigma(t) \epsilon \end{aligned} \tag{17} xt1=μθ(xt,t)+σ(t)ϵ,其中ϵN(ϵ;0,I)=μθ(xt,t) αt 1(xt1αt 1αtϵθ(xt,t))+σ(t)ϵ(17)
当加了classifier guided相当于将 μ θ ( x t , t ) \mu_{\theta}(x_t, t) μθ(xt,t)向预测类别为 y y y的方向更新了一小步。 s s s是控制更新的幅值。
x t − 1 = μ θ ( x t , t ) + s ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ θ ( x t , t ) + σ ( t ) ϵ , 其中 ϵ ∈ N ( ϵ ; 0 , I ) \begin{align} x_{t-1} &=& \mu_{\theta}(x_t, t) + s\nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu_{\theta}(x_t, t)} + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \tag{18} \end{align} xt1=μθ(xt,t)+sxtlogpϕ(yxt)xt=μθ(xt,t)+σ(t)ϵ,其中ϵN(ϵ;0,I)(18)

参考文献

附录

式12推导验证
− 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T − μ T − g T Σ T ) Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T − μ T − g T Σ T ) Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 − μ T Σ − 1 − g T Σ T Σ − 1 ⏟ g T ) ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 ( x t − μ − Σ g ) − μ T Σ − 1 ( x t − μ − Σ g ) − g T ( x t − μ − Σ g ) ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 ( x t − μ ) − μ T Σ − 1 ( x t − μ ) ) ⏟ ( x t − μ ) T Σ − 1 ( x t − μ ) − 1 2 ( − g T ( x t − μ − Σ g ) + ( − x t T Σ − 1 Σ g ) ⏟ − x t T g + μ T Σ − 1 Σ g ⏟ μ T g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 \begin{align*} &- \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} - \mu^T \Sigma^{-1} - \underbrace{g^T \Sigma^T \Sigma^{-1}}_{g^T} )(x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} (x_t - \mu - \Sigma g) - \mu^T \Sigma^{-1} (x_t - \mu - \Sigma g) - g^T (x_t - \mu - \Sigma g)) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} \underbrace{(x_t^T \Sigma^{-1} (x_t - \mu ) - \mu^T \Sigma^{-1} (x_t - \mu))}_{(x_t - \mu)^T \Sigma^{-1} (x_t - \mu)} - \frac{1}{2} ( - g^T (x_t - \mu - \Sigma g) + \underbrace{(- x_t^T \Sigma^{-1}\Sigma g)}_{-x_t^Tg} + \underbrace{\mu^T \Sigma^{-1}\Sigma g}_{\mu^Tg}) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2 \\ \end{align*} ======21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C221(xtTμTgTΣT)Σ1(xtμΣg)+21gTΣg+C221(xtTμTgTΣT)Σ1(xtμΣg)+21gTΣg+C221(xtTΣ1μTΣ1gT gTΣTΣ1)(xtμΣg)+21gTΣg+C221(xtTΣ1(xtμΣg)μTΣ1(xtμΣg)gT(xtμΣg))+21gTΣg+C221(xtμ)TΣ1(xtμ) (xtTΣ1(xtμ)μTΣ1(xtμ))21(gT(xtμΣg)+xtTg (xtTΣ1Σg)+μTg μTΣ1Σg)+21gTΣg+C221(xtμ)TΣ1(xtμ)+(xtμ)g+C2


  1. Deep unsupervised learning using nonequilibrium thermodynamics ↩︎

  2. [Denoising Diffusion Implicit Models (DDIM) Sampling](https://arxiv.org/abs/2010.02502) ↩︎

  3. Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. arXiv:arXiv:1907.05600, 2020. ↩︎

  4. Diffusion Models Beat GANs on Image Synthesis ↩︎

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

diffusion model(三)—— classifier guided diffusion model 的相关文章

  • XGBoost参数调优完全指南(附Python代码)

    XGBoost参数调优完全指南 附Python代码 原文地址 Complete Guide to Parameter Tuning in XGBoost with codes in Python 译注 文内提供的代码和运行结果有一定差异 可
  • 什么是HTTP协议和HTTPS协议,以及两者的区别

    HTTP协议 超文本传输协议 Hyper Text Transfer Protocol 是一个简单的请求 响应协议 它通常运行在TCP之上 它指定了客户端可能发送给服务器什么样的消息以及得到什么样的响应 请求和响应消息的头以形式给出 而消息
  • STM32 Flash详解

    本文将根据ST官方Flashprogramming manual 文档编号 PM0059 讲解STM32F207内部Flash编程 01 概述 这里的flash是指STM32F207内部集成的Flash Flash存储器有以下特点 最大1M
  • opencv入门Vec3f

    Vec3f表示的是3通道float类型的 Vect 就相当于3通道float类型的图像 这是其中一个具体化 解释可以从源代码中看出来 下面给出一个具体的例子 Vec3f point Vec3f 10 10 3 2 Float 3 compo
  • Fiddler+夜神模拟器对安卓app进行抓包,安卓9,安装Magisk和LSPosed

    效果图 安装教程 1 下载夜神模拟器 国际版 2 下载Fiddler 1 配置fiddler允许监听到https Tools gt Options gt HTTPS 2 配置fiddler允许远程连接 Tools gt Options gt
  • 超级好用的思维导图软件XMind

    超级好用的思维导图软件XMind 今天和大家分享一款我一直使用的思维导图制作软件XMind 关于 思维导图制作的软件和网站是非常非常多的 可以说上网一搜的话一大把 我推荐这款 XMind是我自己搜集整理各种信息以及自己的实际使用后感觉特别喜
  • PyCharm配置anaconda环境

    PyCharm配置anaconda环境 PyCharm是一款很好用很流行的python编辑器 Anaconda是专注于数据分析的Python发行版本 包含了conda Python等190多个科学包及其依赖项 Anaconda通过管理工具包

随机推荐

  • coherence

    coherence Coherence是 Oracle为了建立一种高可靠和高扩展 集群计算的一个关键部件 集群指的是多于一个应用服务器参与到运算里 Coherence的主要用途是共享一个应用的对象 主要是java对象 比如Web应用的一个会
  • MAC电脑出现 .bin/webpack-dev-server permission denied 错误解决方法

    以前同事 新买的mac电脑 拉取项目后 npm i 安装了所有依赖 但是执行npm run dev 报错 MAC电脑出现 bin webpack dev server permission denied 提示权限问题 这样解决 sudo n
  • DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍

    目录 1 概述 2 简洁高效且经济的 ChatGPT 训练与推理体验 使用 DeepSpeed Chat 的 RLHF 示例轻松训练你的第一个 类ChatGPT 模型 想尝试不同的模型大小和配置 轻松实现 利用 DeepSpeed Chat
  • 静态代码检查工具 - SourceInsight_Scan 使用指南

    静态代码检查工具 SourceInsight Scan 使用指南 静态代码检查是软件开发过程中非常重要的一环 它可以帮助开发人员发现潜在的代码问题 提高代码质量和可维护性 本文将介绍一款名为SourceInsight Scan的静态代码检查
  • 【100天精通python】Day22:字符串常用操作大全

    目录 专栏导读 一 字符串常用操作 1 拼接字符串 2 计算字符串长度 3 截取字符串 4 分割合并字符串 5 检索字符串 6 字母的大小写转换 7 去除字符串的空格和特殊字符 8 格式化字符串 二 字符串编码转换 2 1 使用encode
  • 管理与维护samba服务器,配置与管理samba服务器

    安装Samba服务器 环境 CentOS 8 boot 64 位 window 10 64 位 samba软件包 检查是否安装了samba软件包 rpm qa grep samba 没有安装软件包 则使用yum命令安装 yum clean
  • 怎样在VMware Workstation中安装Linux系统

    安装步骤 一 创建虚拟机 1 新建虚拟机 典型 下一步 2 稍后安装操作系统 下一步 3 Linux 下一步 4 客户机命名 必须是英文名 自定义浏览存放的位置 下一步 5 设置磁盘大小 拆分多个文件 下一步
  • 公链分层要怎么设计?

    比特币把人类在密码学和计算机工程中的应用融合到了一起 全球矿工 开发者们对比特币有着十分浓厚的兴趣 他们都在比特币提供的结构思考中完成创新和应用 比特币网络的改良的一个最好例子就是以太坊 而EOS等其余公链的出现则在更多的途径上为区块链技术
  • 基于PiggyMetrics微服务搭建的分布式系统

    1环境 2工具安装 3docker运行 1 docker安装完成后桌面出现Docker Quickstart Terminal和Kitematic Alpha两个快捷方式 2 修改VirtualBox中docker虚拟机的内存和处理器核数
  • Linux运维常用工具软件

    1 远程桌面连接 TigerVNC Xshell 2 FTP服务和客户端 FileZilla 将客户端的文件上传到服务器上 客户端可以使用免费的FileZilla Client 支持多线程上传文件 3 硬件检测 CPU Z CPU Z是一款
  • 2022-12-11 leetcode与蓝桥刷题情况

    一 leetcode题目 今天的leetcode是写的周赛题目 昨天状态不好 摸鱼一天 1 数组中最长的方波 题目描述 给你一个整数数组 nums 如果 nums 的子序列满足下述条件 则认为该子序列是一个 方波 子序列的长度至少为 2 并
  • fullcalendar v6的使用记录

    翻了翻百度 教程很多都过时了 方法都废弃了 这次是以V6版本CDN方式使用说明 文档地址 配置 var calendarConfig 加载 loading function load 可以控制页面的加载状态 app instance dat
  • ctfshow-萌新-web3( 利用intval函数的特性配合联合注入获取网站敏感信息)

    ctf show 萌新模块 web3关 此关卡考察的是 intval 函数的特性 以及SQL注入漏洞的利用 首先需要利用 intval 转换字符串的特性绕过校验 而后利用联合注入获取数据库中的敏感信息 从而获取flag 源码中过滤了or 加
  • Linux内核编译和安装

    下载对应的linux内核源码 The Linux Kernel Archives 2 传输到要安装内核的服务器系统中的对应根目录下 usr src scp p 端口号 源文件 username IP地址 例如 scp p 8090 linu
  • 验证实例的有效性与类型的判断

    我们常常会用到指针变量 指针只有赋值以相应的实例才有意义 怎么判断指针引用实例的有效性是我们经常面对的一个问题 我现在 只知道两种方法 列举如下 1 ASSERT VALID pMyObject ASSERT VALID 要判断的类必须是C
  • PTA-找完数(C语言)

    所谓完数就是该数恰好等于除自身外的因子之和 例如 6 1 2 3 其中1 2 3为6的因子 本题要求编写程序 找出任意两正整数m和n之间的所有完数 输入格式 输入在一行中给出2个正整数m和n 1
  • java field_Java Field.get()取得对象的Field属性值

    首页 gt 基础教程 gt 反射 gt Reflection API Java Field get 取得对象的Field属性值 定义 public Object get Object obj 1 如果字段不是静态字段的话 要传入反射类的对象
  • win10修改系统字体(替换OneNote中Calibri字体)

    微软的OneNote还是很好用的 但是字体问题一直是一个吐槽点 我自己就去微软官网吐槽了好几次 然而并没有什么用 我说设置默认字体为consolas完全无法生效 再次输入笔记时 中文自动改为微软雅黑 英文就是Calibri 他们回复我说确实
  • 递归实现逆序输出(C)

    一 概念 程序调用自身的编程技巧称为递归 recursion 递归做为一种算法在程序设计语言中广泛应用 一个过程或函数在其定义或说明中有直接或间接调用自身的一种方法 它通常把一个大型复杂的问题层层转化为一个与原问题相似的规模较小的问题来求解
  • diffusion model(三)—— classifier guided diffusion model

    系列阅读 diffusion model 一 DDPM技术小结 denoising diffusion probabilistic diffusion model 二 DDIM技术小结 diffusion model 三 classifie