AIGC基础:从VAE到DDPM原理、代码详解

2023-10-26

1853ee0fe1c530b8d0faee1559dfdca6.gif

©作者 | 王建周

单位 | 来也科技AI团队负责人

研究方向 | 分布式系统、CV、NLP

90e7d398e248f1c5de184d4b2cd38c80.png


前言

AIGC 目前是一个非常火热的方向,DALLE-2,ImageGen,Stable Diffusion 的图像在以假乱真的前提下,又有着脑洞大开的艺术性,以下是用开源的 Stable Diffusion 生成的一些图片。

f0a51dc416c97b83647c7b64c4ba2d1a.jpeg

这些模型后边都使用了 Diffusion Model 的技术,但是缺乏相关背景知识去单纯学习 Diffusion Model 门槛会比较高,不过沿着 AE、VAE、CVAE、DDPM 这一系列的生成模型的路线、循序学习会更好的理解和掌握,本文将从原理、数学推导、代码详细讲述这些模型。

71f588c8f579038d982c208ad29cd522.png

AE (AutoEncoder)

AE 模型作用是提取数据的核心特征(Latent Attributes),如果通过提取的低维特征可以完美复原原始数据,那么说明这个特征是可以作为原始数据非常优秀的表征。

AE 模型的结构如下图:

65dc58cfe52c82320401b6b5abfc7ad1.png

训练数据通过 Encoder 得到 Latent,Latent 再通过 Decoder 得到重建数据,通过重建数据和训练的数据差异来构造训练 Loss,代码如下(本文所有的场景都是 mnist,编码器和解码器都用了最基本的卷积网络):

class DownConvLayer(tf.keras.layers.Layer):
    def __init__(self, dim):
        super(DownConvLayer, self).__init__()
        self.conv = tf.keras.layers.Conv2D(dim, 3, activation=tf.keras.layers.ReLU(), use_bias=False, padding='same')
        self.pool = tf.keras.layers.MaxPool2D(2)

    def call(self, x, training=False, **kwargs):
        x = self.conv(x)
        x = self.pool(x)
        return x


class UpConvLayer(tf.keras.layers.Layer):
    def __init__(self, dim):
        super(UpConvLayer, self).__init__()
        self.conv = tf.keras.layers.Conv2D(dim, 3, activation=tf.keras.layers.ReLU(), use_bias=False, padding='same')
        # 通过UpSampling2D上采样
        self.pool = tf.keras.layers.UpSampling2D(2)

    def call(self, x, training=False, **kwargs):
        x = self.conv(x)
        x = self.pool(x)
        return x

# 示例代码都是通过非常简单的卷积操作实现编码器和解码器
class Encoder(tf.keras.layers.Layer):
    def __init__(self, dim, layer_num=3):
        super(Encoder, self).__init__()
        self.convs = [DownConvLayer(dim) for _ in range(layer_num)]

    def call(self, x, training=False, **kwargs):
        for conv in self.convs:
            x = conv(x, training)
        return x


class Decoder(tf.keras.layers.Layer):
    def __init__(self, dim, layer_num=3):
        super(Decoder, self).__init__()
        self.convs = [UpConvLayer(dim) for _ in range(layer_num)]
        self.final_conv = tf.keras.layers.Conv2D(1, 3, strides=1)

    def call(self, x, training=False, **kwargs):
        for conv in self.convs:
            x = conv(x, training)
        # 将图像转成和输入图像shape一致
        reconstruct = self.final_conv(x)
        return reconstruct


class AutoEncoderModel(tf.keras.Model):
    def __init__(self):
        super(AutoEncoderModel, self).__init__()
        self.encoder = Encoder(64, layer_num=3)
        self.decoder = Decoder(64, layer_num=3)

    def call(self, inputs, training=None, mask=None):
        image = inputs[0]
        # 得到图像的特征表示
        latent = self.encoder(image, training)
        # 通过特征重建图像
        reconstruct_img = self.decoder(latent, training)
        return reconstruct_img

    @tf.function
    def train_step(self, data):
        img = data["image"]
        with tf.GradientTape() as tape:
            reconstruct_img = self((img,), True)
        trainable_vars = self.trainable_variables
        # 利用l2 loss 来判断重建图片和原始图像的一致性
        l2_loss = (reconstruct_img - img) ** 2
        l2_loss = tf.reduce_mean(tf.reduce_sum(
            l2_loss, axis=(1, 2, 3)
        ))
        gradients = tape.gradient(l2_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"l2_loss": l2_loss}

通过 AE 模型可以看到,只要有有效的数据的 Latent Attribute 表示,那么就可以通过 Decoder 来生成新数据,但是在 AE 模型中,Latent 是通过已有数据生成的,所以没法生成已有数据外的新数据。

所以我们设想,是不是可以假设 Latent 符合一定分布规律,只要通过有限参数能够描述这个分布,那么就可以通过这个分布得到不在训练数据中的新 Latent,利用这个新 Latent 就能生成全新数据,基于这个思路,有了 VAE(Variational AutoEncoder 变分自编码器)。

34cddec10b3adcd4a9f0b20e1b63b7a5.png

VAE

VAE 中假设 Latent Attributes (公式中用 z)符合正态分布,也就是通过训练数据得到的 z 满足以下条件:

5d3d6732944e5c9f97fa0110bbb9314b.png

因为 z 是向量,所 都是向量,分别为正态分布的均值和方差。有了学习得到正态分布的参数 ,那么就可以从这个正态分布中采样新的 z,新的 z 通过解码器得到新的数据。

所以在训练过程中需要同时优化两点:

1. 重建的数据和训练数据差异足够小,也就是生成 x 的对数似然越高,一般依然用 L2 或者 L1 loss;

2.  定义的正态分布需要和标准正态分布的一致,这里用了 KL 散度来约束两个分布一致;

Loss 公式定义如下,其中 和 为生成分布, 为编码分布, 为从正态分布中采样的先验分布:

699a9df8da8e6e720a35a45e678d1320.png

Loss 的证明如下:

cdada97851d0e6f53683bc15fc7af33c.png

bd6532ca1e58fc1a53761287ce627055.png

因为我们的目标是最大化对数似然生成分布 ,也就是最小化负的公式 15,也就是公式 1 的 Loss。

所以 VAE 的结构如下:

35351d5453697a8de83ca3f49ac267f2.png

注意的是在上图中有一个采样 z 的操作,这个操作不可导导致无法对进行优化,所以为了反向传播优化,用到重参数的技巧,也就是将 z 表示成 的数学组合方式且该组合方式可导,组合公式如下:

ee97b31ed13f478eef2a89c890453526.png

可以证明重参数后的模型 f 输出期望是不变的(z 是连续分布)。

fe4466f1b1170aa333b4b1b4cf3365af.png

在计算 定义的正态分布和 定义的正态分布的 KL 散度时,用了数学推导进行简化。

5060943d547f7792a780c5b1b77eef50.png

对公式 28 的 log 部分继续简化:

274dd10296e8cc0b6271d05ffc099ee1.png

令:

3f6c2a129e0f14106260f7bfc954e9e3.png

将公式 32 和 33 带入公式 28 得到:

1b31ee5e44d4377c5cadf8c63ef73e0c.png

因为:

6dfb835479ed6a71a00841705d8a86ea.png

ad568a04a9a468cde16842b35e48b768.png

将公式 37、38、45 带入公式 34 得到最终的 KL 散度 Loss 公式:

b9c5cfff133845cb13766f34cb45df92.png

因为 非负,所以我们通过神经网络来学习 。

有了前边的铺垫,所以 VAE 的实现上也比较简单,代码如下:

class VAEModel(tf.keras.Model):
    def __init__(self, inference=False):
        super(VAEModel, self).__init__()
        self.inference = inference
        self.encoder = Encoder(64, layer_num=3)
        self.decoder = Decoder(64, layer_num=3)
        # mnist 的size是28,这里为了简单对齐大小,缩放成了32
        self.img_size = 32
        # z的维度
        self.latent_dim = 64
        # 通过全连接来学习隐特征z正态分布的均值
        self.z_mean_mlp = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(self.latent_dim * 2, activation="relu"),
                tf.keras.layers.Dense(self.latent_dim, use_bias=False),
            ]
        )
        # 通过全连接来学习隐特征z正态分布的方差的对数log(o^2)
        self.z_log_var_mlp = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(self.latent_dim * 2, activation="relu"),
                tf.keras.layers.Dense(self.latent_dim, use_bias=False),
            ]
        )
        # 通过全连接将z 缩放成上采样输入适配的shape
        self.decoder_input_size = [int(self.img_size / (2 ** 3)), 64]
        self.decoder_dense = tf.keras.layers.Dense(
            self.decoder_input_size[0] * self.decoder_input_size[0] * self.decoder_input_size[1],
            activation="relu")

    def sample_latent(self, bs, image):
        # 推理阶段的z直接可以从标准正态分布中采样,因为训练的decoder已经可以从标准高斯分布生成新的图片了
        if self.inference:
            z = tf.keras.backend.random_normal(shape=(bs, self.latent_dim))
            z_mean, z_log_var = None, None
        else:
            x = image
            x = self.encoder(x)
            x = tf.keras.layers.Flatten()(x)
            z_mean = self.z_mean_mlp(x)
            z_log_var = self.z_log_var_mlp(x)
            epsilon = tf.keras.backend.random_normal(shape=(bs, self.latent_dim))
            '''
            实现重参数采样公式17
            u + exp(0.5*log(o^2))*e
            =u +exp(0.5*2*log(o))*e
            =u + exp(log(o))*e
            =u + o*e
            '''
            z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
        return z, z_mean, z_log_var

    def call(self, inputs, training=None, mask=None):
        # 推理生成图片时,image为None
        bs, image = inputs[0], inputs[1]
        z, z_mean, z_log_var = self.sample_latent(bs, image)
        latent = self.decoder_dense(z)
        latent = tf.reshape(latent,
                            [-1, self.decoder_input_size[0], self.decoder_input_size[0], self.decoder_input_size[1]])
        # 通过z重建图像
        reconstruct_img = self.decoder(latent, training)
        return reconstruct_img, z_mean, z_log_var

    def compute_loss(self, reconstruct_img, z_mean, z_log_var, img):
        # 利用l2 loss 来判断重建图片和原始图像的一致性
        l2_loss = (reconstruct_img - img) ** 2
        l2_loss = tf.reduce_mean(tf.reduce_sum(
            l2_loss, axis=(1, 2, 3)
        ))
        # 实现公式48
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = kl_loss + l2_loss
        return {"l2_loss": l2_loss, "total_loss": total_loss, "kl_loss": kl_loss}

    @tf.function
    def forward(self, data, training):
        img = data["img_data"]
        bs = tf.shape(img)[0]
        reconstruct_img, z_mean, z_log_var = self((bs, img), training)
        return self.compute_loss(reconstruct_img, z_mean, z_log_var, img)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            result = self.forward(data, True)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(result["total_loss"], trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return result

生成的图片效果如下:

f7b38953edd1f29c1ec4396cd968e343.png

在我们大多数生成场景,都需要带有控制条件,比如我们在生产手写数字的时候,我们需要明确的告诉模型,生成数字 0 的图片,基于这个需求,有了 Conditional Variational AutoEncoder(CVAE)。

b802703bd35952d75dec4b11ccf527ae.png


CVAE

CVAE 的改进思路比较简单,就是训练阶段的 z 同时由 x 和控制条件 y 决定,同时生成的 x 也是由 y 和 z 同时决定,Loss 如下:

d4d0900edc8c172db3be4159e285a0ed.png

而  q(z|y) 我们仍然期望符合标准正态分布,对 VAE 代码改动非常少,简单的实现方法就是对条件 y 有一个 embedding 表示,这个 embedding 表示参与到 encoder 和 decoder 的训练,代码如下:

class CVAEModel(VAEModel):
    def __init__(self, inference=False):
        super(CVAEModel, self).__init__(inference=inference)
        # 定义label的Embedding
        self.label_dim = 128
        self.label_embedding = tf.Variable(
            initial_value=tf.keras.initializers.HeNormal()(shape=[10, self.label_dim]),
            trainable=True,
        )
        self.encoder_y_dense = tf.keras.layers.Dense(self.img_size * self.img_size, activation="relu")
        self.decoder_y_dense = tf.keras.layers.Dense(
            self.decoder_input_size[0] * self.decoder_input_size[0] * self.decoder_input_size[1], activation="relu")

    def call(self, inputs, training=None, mask=None):
        # 推理生成图片时,image为None
        bs, image, label = inputs[0], inputs[1], inputs[2]
        label_emb = tf.nn.embedding_lookup(self.label_embedding, label)
        label_emb = tf.reshape(label_emb, [-1, self.label_dim])
        if not self.inference:
            # 训练阶段将条件label的embedding拼接到图片上作为encoder的输入
            encoder_y = self.encoder_y_dense(label_emb)
            encoder_y = tf.reshape(encoder_y, [-1, self.img_size, self.img_size, 1])
            image = tf.concat([encoder_y, image], axis=-1)
        z, z_mean, z_log_var = self.sample_latent(bs, image)
        latent = self.decoder_dense(z)
        # 将条件label的embedding拼接到z上作为decoder的输入
        decoder_y = self.decoder_y_dense(label_emb)
        latent = tf.concat([latent, decoder_y], axis=-1)
        latent = tf.reshape(latent,
                            [-1, self.decoder_input_size[0], self.decoder_input_size[0],
                             self.decoder_input_size[1] * 2])
        # 通过特征重建图像
        reconstruct_img = self.decoder(latent, training)
        return reconstruct_img, z_mean, z_log_var

    @tf.function
    def forward(self, data, training):
        img = data["img_data"]
        label = data["label"]
        bs = tf.shape(img)[0]
        reconstruct_img, z_mean, z_log_var = self((bs, img, label), training)
        return self.compute_loss(reconstruct_img, z_mean, z_log_var, img)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            result = self.forward(data, True)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(result["total_loss"], trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return result

生成 0~9 的图片效果如下:

3b063a286adbbbddcdd8fb7262bc0b59.png

从 VAE 的原理可以看到,我们做了假设 ,但是在大多数场景,这个假设过于严苛,很难保证数据特征符合基本的正态分布(严格意义上也做不到,严格分布的话说明特征就是高斯噪声了),因为这个缺陷,所以基本的 VAE 生成的图像细节不够,边缘偏模糊。

为了解决这些问题,又出现 DDPM(Denoising Diffusion Probabilistic Model),因为 DDPM 相比 GAN,更容易训练(GAN 需要交替训练,而且容易出现模式崩塌,可以参考我们以前的文章),此外 DDPM 的多样性相比 GAN 更好(GAN 因为生成的图像要“欺骗”过鉴别器,所以生成的图像和训练集合的真实图像类似),所以最近 DDPM 成为最受欢迎的生成模型。

3f85185342c2b8550d4f447ed5e2896b.png

DDPM

DDPM 启发点来自非平衡热力学,系统和环境之间有着物质和能量交换,比如在一个盛水的容器中滴入一滴墨水,最终墨水会均匀的扩散到水中,但是如果扩散的每一步足够小,那么这一步就可逆。

所以主要流程上分两个阶段,前向加噪和反向去噪,原始数据为 ,每一步添加足够小的高斯噪声,经过足够的 step T 后,最终数据 会变成标准的高斯噪声(下图的 q),因为前向加噪上是可行的,所以我们假设反向去噪也是可行的,可以逐步的从噪声中一点点的恢复数据的有用信息(下图的 p)直到为 ,下边将详细介绍两部分。

d95a1cd0147be75e6fe903a22922f30b.png

1. 前向加噪

假设前向加噪过程每一步添加噪声的过程符合以下高斯分布,且整个过程满足马尔科夫链,即以下公式:

3ea959fdd6cd7ea2d99d80b7fba4ccbe.png

根据上文提到的重参数技巧,公式 50 可以写成(为了方便,写成标量形式):

c15b2a94753e225d22a7b134be8666b0.png

其中 ,所以公式 52 可以理解为向原始的数据设中加非常小的高斯噪音,并且随着t变大加的噪音逐渐变大,为了方便公式推导,令:

7033f377a13b5f220077482de9d3fee3.png

因为:

ee2e4d8431c142728d0b1fe81b25dd61.png

根据正态分布的求和计算公式以及重参数技巧:

58d04d7258e78fc7348dbc6e7b1ef4a7.png

令 ,将公式 63 带入 57 并推导到一般形式,得到如下前向公式:

665513ca1e2546097eb0ba6d587c89cc.png

公式 64 就是正向过程的最终公式,可以看到正向过程是不存在任何网络参数的,而且对于给定的 t,无需迭代,通过表达式可以直接计算得到 。

2. 反向去噪

反向去噪期望从标准的高斯分布噪声 逐步的消除噪音,每次只恢复目标数据的一点点,最终生成目标数据 ,假设的反向去噪也是符合高斯分布和马尔科夫链,可以用以下数学公式描述:

748266913cd17d09102d6e4fa8a06084.png

因为 中 是依赖 的,所以单纯的 是无法计算的,所以我们需要转而计算 (上图的粉色路径),前者有带学习的参数,我们假设:

1e4521151d59d3aab4adb628d2aae273.png

接下来的目标是需要写出 的表达式,主要是利用条件概率和贝叶斯公式(为了简化都用标量的形式)。

dd8c8bf6b1ec3b4426b5e2d92c9bcd38.png

带入各自的表达式:

0906fe92e06ff1a01b81009e27701518.png

得到:

231f394ad7dfb78957f4b486f0039633.png

对比正态分布公式:

0acb7283603d67e7c2078cb1fc432474.png

可以得到我们需要的 的表达式:

1c152143c254add7eca3890744658ec9.png

接下来我们需要推导下优化的目标,根据前边公式 10 的推导有以下:

5545e6975bf24c7c2f531f72cd64699c.png

因为:

58a84ca435dc70adbfa4e0924ff2214c.png

公式 101 代入公式 96 得到:

c478f853aaafd717a8e8fff90fb6f537.png

对 112 的 继续推导:

10a95192432a83eb2c4b78795b3a5c2b.png

其中 中 为直接计算出来等于常数,所以 为常数;而 为 的 t=1 的特殊表达式,故可以合并到 ,所以从公式 95 可以看出,我们最大化的对数似然 ,等价最小化公式 118,而根据公式 47,两个正态分布的 KL 散度等于:

25ee480003dbe75601e5236a8b26a71c.png

如果上述的 KL 离散度最小,我们希望 逼近 ,根据前边公式93的推导,我们知道:

ae0efa16696116f41ce26c77324904ee.png

根据这个公式,对于已知 的情况下,如果能预测出 ,就可以解决我们的问题,启发我们设计以下目标:

6e0b13a7b2eb55dd02386a14f9086606.png

所以 KL 散度(公式 122)变成以下公式:

986781e47b55dbef19063da3319f0c7f.png

前边的公式 130 的常数 在训练过程可以认为被合并到学习率,所以可以被略掉,所以我们最终的优化目标 Loss 为以下:

a40879ef543403eedde9ff620cebb0be.png

所以训练过程如下:

90a5471c05b5456bd6d6ec6b4184e394.png

从公式 66 和 126,以及重参数技巧可以得知:

adf9791f9798355a5787b2c625b84906.png

所以等待训练完成得到 后,循环执行公式 132 就得到了最终的目标数据 ,过程如下:

1a5a57e65154373ce9ee56216b3f7204.png

经过前边较多的公式推导,最终得到 DDPM 的训练和生成过程确非常简单,从前边能看到希望网络 输入输出 shape 一致,所以常见的 DDPM 都是用 unet 来实现(下图,核心是四点:下采样、上采样、上下采样的特征拼接),在代码上我们做了部分优化。

1. 为了简化代码,我们去掉常见实现方式的 self-attention;

2. 一般时间步 t 也会采用 transformer 中基本的 sincos 的 position 编码,为了简化编码,我们的时间编码直接采用可以学习网络并只加入 Unet 的编码阶段,解码阶段不加入;

3. 相比前边的 VAE 代码,这里的代码相对复杂,卷积模块采用 Resnet 的残差处理方式(经过实验,前边 VAE 基本的编码器和解码器过于简单,没法收敛);

4. 参照官方,用 group norm 代替 batch norm。

d7edfc9d88ab3c114753695fc7171ca8.png

class ConvResidualLayer(tf.keras.layers.Layer):
    def __init__(self, filter_num):
        super(ConvResidualLayer, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filter_num, kernel_size=1, padding='same')
        # import tensorflow_addons as tfa
        self.gn1 = tfa.layers.GroupNormalization(8)
        self.conv2 = tf.keras.layers.Conv2D(filter_num, kernel_size=3, padding='same')
        self.gn2 = tfa.layers.GroupNormalization(8)
        self.act2 = tf.keras.activations.swish

    def call(self, inputs, training=False, *args, **kwargs):
        residual = self.conv1(inputs)
        x = self.gn1(residual)
        x = tf.nn.swish(x)
        x = self.conv2(x)
        x = self.gn2(x)
        x = tf.nn.swish(x)
        out = x + residual
        return out / 1.44

class SimpleDDPMModel(tf.keras.Model):
    def __init__(self, max_time_step=100):
        super(SimpleDDPMModel, self).__init__()
        # 定义ddpm 前向过程的一些参数
        self.max_time_step = max_time_step
        # 采用numpy 的float64,避免连乘的精度失准
        betas = np.linspace(1e-4, 0.02, max_time_step, dtype=np.float64)
        alphas = 1.0 - betas
        alphas_bar = np.cumprod(alphas, axis=0)
        betas_bar = 1.0 - alphas_bar
        self.betas, self.alphas, self.alphas_bar, self.betas_bar = tuple(
            map(
                lambda x: tf.constant(x, tf.float32),
                [betas, alphas, alphas_bar, betas_bar]
            )
        )
        filter_nums = [64, 128, 256]
        self.encoders = [tf.keras.Sequential([
            ConvResidualLayer(num),
            tf.keras.layers.MaxPool2D(2)
        ]) for num in filter_nums]
        self.mid_conv = ConvResidualLayer(filter_nums[-1])
        self.decoders = [tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(num, 3, strides=2, padding="same"),
            ConvResidualLayer(num),
            ConvResidualLayer(num),
        ]) for num in reversed(filter_nums)]
        self.final_conv = tf.keras.Sequential(
            [
                ConvResidualLayer(64),
                tf.keras.layers.Conv2D(1, 3, padding="same")
            ]
        )
        self.img_size = 32
        self.time_embeddings = [
            tf.keras.Sequential(
                [
                    tf.keras.layers.Dense(num, activation=tf.keras.layers.LeakyReLU()),
                    tf.keras.layers.Dense(num)
                ]
            )
            for num in filter_nums]

    # 实现公式 64 从原始数据生成噪音图像
    def q_noisy_sample(self, x_0, t, noisy):
        alpha_bar, beta_bar = self.extract([self.alphas_bar, self.betas_bar], t)
        sqrt_alpha_bar, sqrt_beta_bar = tf.sqrt(alpha_bar), tf.sqrt(beta_bar)
        return sqrt_alpha_bar * x_0 + sqrt_beta_bar * noisy

    def extract(self, sources, t):
        bs = tf.shape(t)[0]
        targets = [tf.gather(source, t) for i, source in enumerate(sources)]
        return tuple(map(lambda x: tf.reshape(x, [bs, 1, 1, 1]), targets))

    # 实现公式 131,从噪声数据恢复上一步的数据
    def p_real_sample(self, x_t, t, pred_noisy):
        alpha, beta, beta_bar = self.extract([self.alphas, self.betas, self.betas_bar], t)
        noisy = tf.random.normal(shape=tf.shape(x_t))
        # 这里的噪声系数和beta取值一样,也可以满足越靠近0,噪声越小
        noisy_weight = tf.sqrt(beta)
        # 当t==0 时,不加入随机噪声
        bs = tf.shape(x_t)[0]
        noisy_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [bs, 1, 1, 1]
        )
        noisy_weight *= noisy_mask
        x_t_1 = (x_t - beta * pred_noisy / tf.sqrt(beta_bar)) / tf.sqrt(alpha) + noisy * noisy_weight
        return x_t_1

    # unet 的下采样
    def encoder(self, noisy_img, t, data, training):
        xs = []
        for idx, conv in enumerate(self.encoders):
            noisy_img = conv(noisy_img)
            t = tf.cast(t, tf.float32)
            time_embedding = self.time_embeddings[idx](t)
            time_embedding = tf.reshape(time_embedding, [-1, 1, 1, tf.shape(time_embedding)[-1]])
            # time embedding 直接相加
            noisy_img += time_embedding
            xs.append(noisy_img)
        return xs

    # unet的上采样
    def decoder(self, noisy_img, xs, training):
        xs.reverse()
        for idx, conv in enumerate(self.decoders):
            noisy_img = conv(tf.concat([xs[idx], noisy_img], axis=-1))
        return noisy_img

    @tf.function
    def pred_noisy(self, data, training):
        img = data["img_data"]
        bs = tf.shape(img)[0]
        noisy = tf.random.normal(shape=tf.shape(img))
        t = data.get("t", None)
        # 在训练阶段t为空,随机生成成t
        if t is None:
            t = tf.random.uniform(shape=[bs, 1], minval=0, maxval=self.max_time_step, dtype=tf.int32)
            noisy_img = self.q_noisy_sample(img, t, noisy)
        else:
            noisy_img = img
        xs = self.encoder(noisy_img, t, data, training)
        x = self.mid_conv(xs[-1])
        x = self.decoder(x, xs, training)
        pred_noisy = self.final_conv(x)
        return {
            "pred_noisy": pred_noisy, "noisy": noisy,
            "loss": tf.reduce_mean(tf.reduce_sum((pred_noisy - noisy) ** 2, axis=(1, 2, 3)), axis=-1)
        }

    # 生成图片
    def call(self, inputs, training=None, mask=None):
        bs = inputs[0]
        x_t = tf.random.normal(shape=[bs, self.img_size, self.img_size, 1])
        for i in reversed(range(0, self.max_time_step)):
            t = tf.reshape(tf.repeat(i, bs), [bs, 1])
            p = self.pred_noisy({"img_data": x_t, "t": t}, False)
            x_t = self.p_real_sample(x_t, t, p["pred_noisy"])
        return x_t

    def train_step(self, data):
        with tf.GradientTape() as tape:
            result = self.pred_noisy(data, True)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(result["loss"], trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"loss": result["loss"]}

    def test_step(self, data):
        result = self.pred_noisy(data, False)
        return {"loss": result["loss"]}

生成的图片如下:

82dd9ab7cd239c96bb1a3a9fe296b5b1.png

类似 CVAE,使用 DDPM 的时候,我们依然希望可以通过条件控制生成,如前边提到的 DALLE-2,Stable Diffusion 都是通过条件(文本 prompt)来控制生成的图像,为了实现这个目的,就需要采用 Conditional Diffusion Model。

b084f2a88c30ac94c8788e82deb7fc1f.png

Conditional Diffusion Model

目前最主要使用的 Conditional Diffusion Model 主要有两种实现方式,Classifier-guidance 和 Classifier-free,从名字也可以看出,前者需要一个分类器模型,后者无需分类器模型,下边讲简单推导两种的实现方案,并给出  Classifier-free Diffusion Model 的实现代码。

1. Classifier-guidance

参考前边的推导公式在无条件的模型下,我们需要优化;而在控制条件 y 下,我们需要优化的是,可以用贝叶斯进行以下的公式推导:

5072c37614a3bc9d628aea3d28b12a50.png

从以下公式推导可以看出,我们需要一个分类模型,这个分类模型可以对前向过程融入噪音的数据很好的分类,在扩散模型求梯度的阶段,融入这个分类模型对当前噪音数据的梯度即可。

2. Classifier-free

通过 classifier-guidance 的公式证明,我们很容易得到以下的公式推导:

5b965b0d4f0acfef9f87a8fef8737db8.png

取值 0~1 之间,从公式 140 可以看出,只要我们在模型输入上,采样性的融入 y 就可以达到目标,所以在前边的 DDPM 代码上改动比较简单,我们对 0~9 这 10 个数字学习一个 embedding 表示,然后采样性的加入 unet 的 encoder 的阶段,代码如下:

class SimpleCDDPMModel(SimpleDDPMModel):
    def __init__(self, max_time_step=100, label_num=10):
        super(SimpleCDDPMModel, self).__init__(max_time_step=max_time_step)
        # condition 的embedding和time step的一致
        self.condition_embedding = [
            tf.keras.Sequential(
                [
                    tf.keras.layers.Embedding(label_num, num),
                    tf.keras.layers.Dense(num)
                ]
            )
            for num in self.filter_nums]

    # unet 的下采样
    def encoder(self, noisy_img, t, data, training):
        xs = []
        mask = tf.random.uniform(shape=(), minval=0.0, maxval=1.0, dtype=tf.float32)
        for idx, conv in enumerate(self.encoders):
            noisy_img = conv(noisy_img)
            t = tf.cast(t, tf.float32)
            time_embedding = self.time_embeddings[idx](t)
            time_embedding = tf.reshape(time_embedding, [-1, 1, 1, tf.shape(time_embedding)[-1]])
            # time embedding 直接相加
            noisy_img += time_embedding
            # 获取 condition 的embedding
            condition_embedding = self.condition_embedding[idx](data["label"])
            condition_embedding = tf.reshape(condition_embedding, [-1, 1, 1, tf.shape(condition_embedding)[-1]])
            # 训练阶段一定的概率下加入condition,推理阶段全部加入
            if training:
                if mask < 0.15:
                    condition_embedding = tf.zeros_like(condition_embedding)
            noisy_img += condition_embedding
            xs.append(noisy_img)
        return xs

    # 生成图片
    def call(self, inputs, training=None, mask=None):
        bs = inputs[0]
        label = tf.reshape(tf.repeat(inputs[1], bs), [-1, 1])
        x_t = tf.random.normal(shape=[bs, self.img_size, self.img_size, 1])
        for i in reversed(range(0, self.max_time_step)):
            t = tf.reshape(tf.repeat(i, bs), [bs, 1])
            p = self.pred_noisy({"img_data": x_t, "t": t, "label": label}, False)
            x_t = self.p_real_sample(x_t, t, p["pred_noisy"])
        return x_t

最终生成的图片如下:

e4d65d1f2fcd0896e1824ab3c7659817.png

outside_default.png

参考文献

outside_default.png

[1] https://www.jarvis73.com/2022/08/08/Diffusion-Model-1/

[2] https://blog.csdn.net/qihangran5467/article/details/118337892

[3] https://jaketae.github.io/study/vae/

[4] https://pyro.ai/examples/cvae.html

[5] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

[6] https://spaces.ac.cn/archives/9164

[7] https://zhuanlan.zhihu.com/p/575984592

[8] https://kxz18.github.io/2022/06/19/Diffusion/

[9] https://zhuanlan.zhihu.com/p/502668154

[10] https://xyfjason.top/2022/09/29/%E4%BB%8EVAE%E5%88%B0DDPM/

[11] https://arxiv.org/pdf/2208.11970.pdf

更多阅读

584c23e640e0f0c831e4bac3873bf1a3.png

c1b4eadf0ed058be0f7902370b7d434e.png

74d488bc615f055600ab74dea4d6abac.png

3d17e861282e2be67e1b76dbc1c58a79.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

AIGC基础:从VAE到DDPM原理、代码详解 的相关文章

  • 为什么 urllib2 出现 urllib2.HTTPError 而 urllib 没有错误?

    我有以下简单的代码 import urllib2 import sys sys path append BeautifulSoup BeautifulSoup 3 1 0 1 from BeautifulSoup import page h
  • Python lambda 函数没有在 for 循环中正确调用[重复]

    这个问题在这里已经有答案了 我正在尝试使用 Python 中的 Tkinter 制作一个计算器 我使用 for 循环来绘制按钮 并且尝试使用 lambda 函数 以便仅在按下按钮时调用按钮的操作 而不是在程序启动时立即调用 然而 当我尝试这
  • 缺少 python 配置

    我正在安装一个程序 需要安装 python config 唯一的问题是我目前没有 python config 而且我似乎不知道如何获取它 经过搜索后 我应该可以通过以下方式安装它 yum install python devel 然而 这样
  • 将新形状传递给“np.reshape”

    Within numpy ndarray reshape https docs scipy org doc numpy reference generated numpy ndarray reshape html the shape参数是一
  • 如何在cvxpy中编写多个约束?

    我想在 cvxpy 下的优化问题中添加许多约束 在 matlab 中 我可以通过添加一行 subject to 然后使用 for 循环来生成约束 我怎样才能在 cvxpy 中做同样的工作 因为 cvxpy 中没有 服从 概念 有什么建议吗
  • Django/gevent socket.IO 与 redis pubsub。我把东西放在哪里?

    我有一个独立的 python 脚本 它只是从 Twitter 的流 API 捕获数据 然后在收到每条消息时 使用 redis pubsub 将其发布到频道 tweets 这是该脚本 def main username username pa
  • 识别 Windows 版本

    我正在编写一个打印出详细 Windows 版本信息的函数 输出可能是这样的元组 32bit XP Professional SP3 English 它将支持 Windows XP 及更高版本 我一直坚持获取 Windows 版本 例如 专业
  • 从 SQL Server 中调用 Python 文件

    我的文件名中有 Python 脚本 C Python HL py 在此 Python 脚本中 有预测模型以及对 SQL 数据库中某些表的更新 我想将此文件称为 SQL 作业 我怎样才能做到这一点 这个问题不一样 如何在 SQL Server
  • 使用python同时播放两个正弦音

    我正在使用 python 来播放正弦音 音调基于计算机的内部时间 以分钟为单位 但我想根据秒同时播放一个音调 以获得和谐或双重的声音 这就是我到目前为止所拥有的 有人能指出我正确的方向吗 from struct import pack fr
  • 通过 Python 在 PostgreSQL 中的 unicode 字符串中是否允许空字节?

    unicode 字符串中是否允许空字节 我不问 utf8 我的意思是 unicode 字符串的高级对象表示 背景 我们通过 Python 在 PostgreSQL 中存储包含空字节的 unicode 字符串 如果我们再次读取字符串 字符串会
  • python中的语音识别持续时间设置问题

    我有一个 Wav 格式的音频文件 我想转录 我的代码是 import speech recognition as sr harvard sr AudioFile speech file wav with harvard as source
  • 无法将 python 数据框中的列类型从 object 转换为 str

    我已经下载了一个csv文件 然后将其读取到python dataframe 现在所有4列都有对象类型 我想将它们转换为str类型 现在dtypes的结果如下 Name object Position Title object Departm
  • Flask-httpauth: get_password 装饰器如何为 basic-auth 工作?

    我想知道有没有人用过这个烧瓶延伸 https github com miguelgrinberg flask httpauth简化 http basic auth 基本上我不明白这个example https github com migu
  • Python、cPickle、酸洗 lambda 函数

    我必须像这样腌制一组对象 import cPickle as pickle from numpy import sin cos array tmp lambda x sin x cos x test array tmp tmp tmp tm
  • 如何从列表中删除“\xe2”

    我是 python 新手 正在使用它在我的项目中使用 nltk 对从网页获得的原始数据进行单词标记后 我得到了一个包含 xe2 xe3 x98 等的列表 但是我不需要这些并想删除它们 我只是尝试过 if x in a and if a st
  • pip:证书失败,但curl 有效

    我们在客户端安装了根证书 https 连接适用于curl 但如果我们尝试使用pip 它失败 Could not fetch URL https installserver 40443 pypi simple pep8 There was a
  • 如何点击 Google Trends 中的“加载更多”按钮并通过 Selenium 和 Python 打印所有标题

    这次我想单击一个按钮来加载更多实时搜索 这是网站的链接 该按钮位于页面末尾 代码如下 div class feed load more button Load more div 由于涉及到一些 AngularJS 我不知道该怎么做 有什么提
  • 对 Python 的 id() 感到困惑[重复]

    这个问题在这里已经有答案了 我可以理解以下定义 每个对象都有一个身份 类型和值 对象的身份 一旦创建就永远不会改变 你可能会认为它是 对象在内存中的地址 这is操作员比较身份 两个物体 这id 函数返回一个代表其值的整数 身份 我假设上面的
  • 使用 PuLP 进行线性优化,变量附加条件

    我必须用 Pull 解决 Python 中的整数线性优化问题 我解决了基本问题 现在我必须添加额外的约束 有人可以帮助我用逻辑指示器添加条件吗 逻辑限制是 如果 A gt 20 则 B gt 5 这是我的代码 from pulp impor
  • 每行中最后一次出现 True 的索引

    我有一个二维数组 a False False False False False True True True True True True True True True True True True True True True True

随机推荐

  • 斐波那契查找详细注解版

    对于斐波那契数列 1 1 2 3 5 8 13 21 34 55 89 也可以从0开始 前后两个数字的比值随着数列的增加 越来越接近黄金比值0 618 比如这里的89 把它想象成整个有序表的元素个数 而89是由前面的两个斐波那契数34和55
  • Python中RotatingFileHandler、TimedRotatingFileHandler函数用法

    欢迎来到我的博客 作者 秋无之地 简介 CSDN爬虫 后端 大数据领域创作者 目前从事python爬虫 后端和大数据等相关工作 主要擅长领域有 爬虫 后端 大数据开发 数据分析等 欢迎小伙伴们点赞 收藏 留言 背景 在python开发过程中
  • Linux如何卸载软件

    Linux系统可以通过终端 Terminal 或图形界面 GUI 来卸载软件 终端方式可以使用apt get Ubuntu 或yum CentOS 命令来实现 而图形界面方式可以使用系统自带的软件管理器来实现 比如Ubuntu的Ubuntu
  • libev学习系列之二:libev下载

    libev学习系列之二 libev下载 版本说明 版本 作者 日期 备注 0 1 ZY 2019 5 31 初稿 目录 文章目录 libev学习系列之二 libev下载 版本说明 目录 官网 GitHub 我的某度网盘 官网 可以去官网下载
  • 【python练习题 03】高矮个子排队

    题目 现在有一队小朋友 他们高矮不同 我们以正整数数组表示这一队小朋友的身高 如数组 5 3 1 2 3 我们现在希望小朋友排队 以 高 矮 高 矮 顺序排列 每一个 高 位置的小朋友要比相邻的位置高或者相等 每一个 矮 位置的小朋友要比相
  • Date:January 29th Title: 集训Day2-小信小友打怪兽 题解

    时间 1s 空间 256M 题目描述 小信与小友一起组队打怪兽 有一个长度为n的怪兽序列 一些怪兽会对小信造成伤害 另一些不会 小友是大佬 所有怪兽都伤害不了他 小信与小友轮流打怪兽 小信先手 小友后手 他们需要按照顺序打怪兽 由于技能有冷
  • 小米盒子3s刷机为国际版系统android TV 8.0

    小米盒子3s刷机为国际版系统android TV 8 0 所需工具和软件 一个U盘 adb工具 使用adb工具 通过ip连接小米盒子 官方下载地址 点此进入 dump 16AB img MiBOX3S queenchristina r145
  • tera-PROMISE数据集

    tera Promise数据集 原网址 http openscience us repo 已经打不开 备份网址 https github com opensciences opensciences github io 来源 论文An Imp
  • 蓝桥杯历届试题——取球游戏(博弈论)

    取球游戏 今盒子里有n个小球 A B两人轮流从盒中取球 每个人都可以看到另一个人取了多少个 也可以看到盒中还剩下多少个 并且两人都很聪明 不会做出错误的判断 我们约定 每个人从盒子中取出的球的数目必须是 1 3 7或者8个 轮到某一方取球时
  • 文件对应的Content-Type类型

    https www cnblogs com liu heng p 7520564 html CONTENT TYPE load text html 123 application vnd lotus 1 2 3 3ds image x 3d
  • 海思Hi3559A平台移植 opencv4.0.0

    原文 https blog csdn net xclshwd article details 85257117 海思Hi3559A平台移植 opencv4 0 0 2018年12月26日 09 51 53 xclshwd 阅读数 370 版
  • Jetpack学习之WorkManager

    绝大部分应用程序都有在后台执行任务的需求 根据需求的不同 Android为后台任务提供了多种解决方案 如JobScheduler Loader Service等 WorkManager为应用程序中那些不需要及时完成的任务提供了一个统一的解决
  • 基于MATLAB的图像压缩感知

    一 课题背景 数据压缩技术是提高无线数据传输速度的有效措施之一 传统的数据压缩技术是基于奈奎斯特采样定律进行采样 并根据数据本身的特性降低其冗余度 从而达到压缩的目的 近年来出现的压缩感知理论 Compressed Sensing CS 则
  • 将参数字符串中的字符反向排列,不是逆序打印。

    小题分享 定义一个字符串 abcdef 封装一个函数使他反向排列 不是逆序打印 我们很容易就能想到一种方法 采用循环的方式互换首尾的元素 void reverse char str int left 0 int right strlen s
  • 廊坊师范学院IT提高班,你真正了解多少?

    最近在csdn博文中经常看到博友们问 什么是提高班 更有人对提高班怀有疑惑 or 不理解 廊坊师范学院信息技术提高班到底是怎样的一个地方 你对这个地方又有怎样的认识 你对这个地方是否怀有一份好奇心呢 让这篇文章解开你心中的某些疑惑吧 我一个
  • Node.js使用session或JWT机制登录验证教程

    Session实现代码 Session 对象存储特定用户会话所需的属性及配置信息 这样 当用户在应用程序的 Web 页之间跳转时 存储在 Session 对象中的变量将不会丢失 而是在整个用户会话中一直存在下去 当用户请求来自应用程序的 W
  • 小程序中实现点击切换不同组件的效果

    前言 小程序中实现点击切换不同页面的组件效果 实现效果 实现步骤 第一 分别建立三个页面的文件夹以及他们的相关文件 第二 index模块中 index wxml
  • 安装APK的两种方式

    我的新书 Android App开发入门与实战 已于2020年8月由人民邮电出版社出版 欢迎购买 点击进入详情 网络安装 一般通过网线连接到设备 通过网线进行apk的传输和安装 步骤如下 1 adb connect 目标设备ip和端口 2
  • C++中long是什么类型

    long long本质上还是整型 只不过是一种超长的整型 int型 32位整型 取值范围为 2 31 2 31 1 long 在32位系统是32位整型 取值范围为 2 31 2 31 1 在64位系统是64位整型 取值范围为 2 63 2
  • AIGC基础:从VAE到DDPM原理、代码详解

    作者 王建周 单位 来也科技AI团队负责人 研究方向 分布式系统 CV NLP 前言 AIGC 目前是一个非常火热的方向 DALLE 2 ImageGen Stable Diffusion 的图像在以假乱真的前提下 又有着脑洞大开的艺术性
Powered by Hwhale