DDPMs扩散模型Pytorch代码实现附详细注释

2023-11-14

本文相当于是对The Annotated Diffusion Model的代码理解后加的注释,很详尽,具体有些公式图片不太好显示,在vx公众号“一蓑烟雨晴”回复“100”下载notebook版本的代码文件。

import math
from inspect import isfunction # inspect模块https://www.cnblogs.com/yaohong/p/8874154.html主要提供了四种用处:1.对是否是模块、框架、函数进行类型检查 2.获取源码 3.获取类或者函数的参数信息 4.解析堆栈
from functools import partial # 偏函数 https://www.runoob.com/w3cnote/python-partial.html

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm # 进度条
from einops import rearrange # einops把张量的维度操作具象化,让开发者“想出即写出

import torch
from torch import nn, einsum # einsum很方便的实现复杂的张量操作 https://zhuanlan.zhihu.com/p/361209187
import torch.nn.functional as F

一些辅助性的小函数

# x是否为None,不是None则返回True,是None则返回False
def exists(x):
    return x is not None

# 如果val非None则返回val,否则(如果d为函数则返回d(),否则返回d)
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# 残差连接
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# 上采样
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

# 下采样
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
# 一种位置编码,前一半sin后一半cos
# eg:维数dim=5,time取1和2两个时间
# layer = SinusoidalPositionEmbeddings(5)
# embeddings = layer(torch.tensor([1,2]))
# return embeddings的形状是(2,5),第一行是t=1时的位置编码,第二行是t=2时的位置编码
# 额外连接(transformer原作位置编码实现):https://github.com/jalammar/jalammar.github.io/blob/master/notebookes/transformer/transformer_positional_encoding_graph.ipynb
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
# Block类,先卷积后GN归一化后siLU激活函数,若存在scale_shift则进行一定变换
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out) #GN归一化 https://zhuanlan.zhihu.com/p/177853578
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

#例:dim=8,dim_out=16,time_emb_dim=2, groups=8
#Block = ResnetBlock(8, 16, time_emb_dim=2, groups=8)
#a = torch.ones(1, 8, 64, 64)
#b = torch.ones(1, 2)
#result = Block(a, b)
class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""
    
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        # 如果time_emb_dim存在则有mlp层
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() #nn.Identity()有 https://blog.csdn.net/artistkeepmonkey/article/details/115067356

    def forward(self, x, time_emb=None):
        h = self.block1(x) # torch.Size([1, 16, 64, 64])

        if exists(self.mlp) and exists(time_emb):
            # time_emb为torch.Size([1, 2])
            time_emb = self.mlp(time_emb) # torch.Size([1, 16])
            # rearrange(time_emb, "b c -> b c 1 1")为torch.Size([1, 16, 1, 1])
            h = rearrange(time_emb, "b c -> b c 1 1") + h # torch.Size([1, 16, 64, 64])
            

        h = self.block2(h) # torch.Size([1, 16, 64, 64])
        return h + self.res_conv(x) # return最后补了残差连接 # torch.Size([1, 16, 64, 64])
    
# 可以参考class ResnetBlock进行理解
class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # 如果time_emb_dim存在则有mlp层
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(), # Gaussian Error Linear Unit
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)

Attention流程图
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AdG36fxR-1664237054138)(attachment:%E6%9C%AA%E5%91%BD%E5%90%8D%E6%96%87%E4%BB%B6%20%281%29.png)]

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)# qkv为一个元组,其中每一个元素的大小为torch.Size([b, hidden_dim, h, w])
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        ) # qkv中每个元素从torch.Size([b, hidden_dim, h, w])变为torch.Size([b, heads, dim_head, h*w])
        q = q * self.scale # q扩大dim_head**-0.5倍

        sim = einsum("b h d i, b h d j -> b h i j", q, k) # sim有torch.Size([b, heads, h*w, h*w])
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1) # attn有torch.Size([b, heads, h*w, h*w])

        out = einsum("b h i j, b h d j -> b h i d", attn, v) # [b, heads, h*w, h*w]和[b, heads, dim_head, h*w] 得 out为[b, heads, h*w, dim_head]
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) # 得out为[b, hidden_dim, h, w]
        return self.to_out(out) # 得 [b, dim, h, w]

# 和class Attention几乎一致
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1) 
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
# 先norm后fn
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
class Unet(nn.Module):
    def __init__(
        self,
        dim, # 下例中,dim=image_size=28
        init_dim=None,# 默认为None,最终取dim // 3 * 2
        out_dim=None, # 默认为None,最终取channels
        dim_mults=(1,2,4,8),
        channels=3, # 通道数默认为3
        with_time_emb=True, # 是否使用embeddings
        resnet_block_groups=8, # 如果使用ResnetBlock,groups=resnet_block_groups
        use_convnext=True, # 是True使用ConvNextBlock,是Flase使用ResnetBlock
        convnext_mult=2, # 如果使用ConvNextBlock,mult=convnext_mult
    ):
        super().__init__()
        self.channels = channels
        
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # 从头到尾dim组成的列表
        in_out = list(zip(dims[:-1], dims[1:])) # dim对组成的列表
        # 使用ConvNextBlock或ResnetBlock
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
            
        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
            
        # layers
        self.downs = nn.ModuleList([]) # 初始化下采样网络列表
        self.ups = nn.ModuleList([]) # 初始化上采样网络列表
        num_resolutions = len(in_out) # dim对组成的列表的长度
        
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1) # 是否到了最后一对
            
            self.downs.append(
                nn.ModuleList(
                [
                    block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Downsample(dim_out) if not is_last else nn.Identity(),
                ]
                )
            )
        
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim,time_emb_dim=time_dim)
        
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )
            
        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )
        
    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

四种beta选择

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
import numpy as np
x = np.linspace(1,1001,1000) 
timesteps = 1000

fig, ax = plt.subplots() # 创建图实例
ax.plot(x, (cosine_beta_schedule(timesteps, s=0.008)/50).numpy(), label='cosine')
ax.plot(x, linear_beta_schedule(timesteps).numpy(), label='linear')
ax.plot(x, quadratic_beta_schedule(timesteps).numpy(), label='quadratic')
ax.plot(x, sigmoid_beta_schedule(timesteps).numpy(), label='sigmoid')
plt.legend()
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lISFybSq-1664237054141)(output_11_0.png)]

betas: β \beta β

alphas: α = 1 − β \alpha = 1-\beta α=1β

alphas_cumprod: α t ‾ = ∏ s = 1 t α s \overline{\alpha_t} = \prod_{s=1}^{t}\alpha_s αt=s=1tαs

alphas_cumprod_prev: α t − 1 ‾ \overline{\alpha_{t-1}} αt1

sqrt_recip_alphas: 1 / α t ‾ 1/\sqrt{\overline{\alpha_t}} 1/αt

sqrt_alphas_cumprod: α t ‾ \sqrt{\overline{\alpha_t}} αt

sqrt_one_minus_alphas_cumprod: 1 − α t ‾ \sqrt{1-\overline{\alpha_t}} 1αt

posterior_variance: β ∗ ( 1 − α t − 1 ‾ ) / ( 1 − α t ‾ ) \beta * (1-\overline{\alpha_{t-1}}) / (1-\overline{\alpha_{t}}) β(1αt1)/(1αt)

timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# sqrt_alphas_cumprod = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# x_start = torch.ones([1, 3, 8, 8])
# out = extract(a=sqrt_alphas_cumprod, t=torch.tensor([5]), x_shape=x_start.shape)
# print(out.shape)
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# 随便导入一个图片
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gw6bPY7E-1664237054143)(output_14_0.png)]

# 进行一些变化
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size), # 变为形状为128*128
    CenterCrop(image_size), # 中心裁剪
    ToTensor(), # turn into Numpy array of shape HWC, divide by 255
    Lambda(lambda t: (t * 2) - 1), # 变为[-1,1]范围
    
])

x_start = transform(image).unsqueeze(0)
x_start.shape
torch.Size([1, 3, 128, 128])
import numpy as np
reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])
# 处理后的图片
reverse_transform(x_start.squeeze())

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uxtXvcUu-1664237054146)(output_17_0.png)]

x t = α t ‾ x 0 + 1 − α t ‾ ϵ x_t = \sqrt{\overline{\alpha_t}}x_0+\sqrt{1-\overline{\alpha_t}}\epsilon xt=αt x0+1αt ϵ

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
def get_noisy_image(x_start, t):
  # add noise
  x_noisy = q_sample(x_start, t=t)

  # turn back into PIL image
  noisy_image = reverse_transform(x_noisy.squeeze())

  return noisy_image
# take time step
t = torch.tensor([40])

get_noisy_image(x_start, t)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-R5teNXT1-1664237054151)(output_21_0.png)]

import matplotlib.pyplot as plt

# use seed for reproducability
torch.manual_seed(0) # torch.manual_seed(0)

# pytorch官方的一个画图函数
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    plt.show()
# 观察结果多次前向传播后的图像
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Nd2MVf38-1664237054157)(output_23_0.png)]

# 三种损失函数有l1,l2和huber,默认为l1
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss
from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128
from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations (e.g. using torchvision)
transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

# 得到变换之后的数据集
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

$
\begin{aligned}
\tilde{\boldsymbol{\mu}}{t} &=\frac{1}{\sqrt{\alpha{t}}}\left(\mathbf{x}{t}-\frac{\beta{t}}{\sqrt{1-\bar{\alpha}{t}}} \mathbf{z}{t}\right)
\end{aligned}
$
其中, z t z_t zt由model(x,t)得

@torch.no_grad()
def p_sample(model, x, t, t_index):
    # p_sample(model, img, torch.full((b,),i,device=device,dtype=torch.long),i)
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)

    if t_index == 0:
        return model_mean
    else: # 加一定的噪声
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise

    
@torch.no_grad()
def p_sample_loop(model, shape):
    # 从噪声中逐步采样
    device = next(model.parameters()).device
    
    b = shape[0]
    
    img = torch.randn(shape, device=device)
    imgs = []
    
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step',total=timesteps):
        img = p_sample(model, img, torch.full((b,),i,device=device,dtype=torch.long),i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model,shape=(batch_size,channels, image_size, image_size))
from pathlib import Path
# 例如num = 10, divisor = 3,得[3,3,3,1]
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
results_folder = Path("./results")
results_folder.mkdir(exist_ok = True) # https://zhuanlan.zhihu.com/p/317254621
save_and_sample_every = 1000
0
from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels = channels,
    dim_mults=(1,2,4)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)
from torchvision.utils import save_image

epochs = 5

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad() # 优化器数值清零

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long() # 随机取t

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()
      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")
tensor([6, 9, 8, 3, 7])
# 展示从噪声生成图像的过程
import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DDPMs扩散模型Pytorch代码实现附详细注释 的相关文章

  • 从 SHAP 值中获取特征重要性

    我想要获得重要功能的数据框 通过下面的代码 我得到了 shap values 但我不确定这些值的含义是什么 在我的 df 中有 142 个特征和 67 个实验 但得到了一个带有 ca 的数组 2500 个值 explainer shap T
  • 检测到通过 ChromeDriver 启动的 Chrome 浏览器

    我正在尝试在 python 中使用 selenium chromedriver 来访问 www mouser co uk 网站 然而 从第一次拍摄开始 它就被检测为机器人 有人对此有解释吗 此后我使用的代码 options Options
  • 使用 Python 创建 MIDI

    本质上 我正在尝试从头开始创建 MIDI 并将它们放到网上 我对不同的语言持开放态度 但更喜欢使用Python 两种语言之一 如果这有什么区别的话 并且想知道我应该使用哪个库 提前致谢 看起来这就是您正在寻找的 适用于 Python 的简单
  • 如何使用 colorchecker 在 opencv 中进行颜色校准?

    我有数码相机获取的色彩检查器图像 我如何使用它来使用 opencv 校准图像 按照以下颜色检查器图像操作 您是想问如何进行颜色校准或如何使用 OpenCV 进行校准 为了进行颜色校准 您可以使用校准板的最后一行 灰色调 以下是您应该逐步进行
  • ctypes 错误:libdc1394 错误:无法初始化 libdc1394

    我正在尝试将程序编译为共享库 我可以使用 ctypes 在 Python 代码中使用该库 使用以下命令该库可以正常编译 g shared Wl soname mylib O3 o mylib so fPIC files pkg config
  • 在Python中如何获取字典的部分视图?

    是否有可能获得部分视图dict在Python中类似于pandasdf tail df head 说你有很长一段时间dict 而您只想检查某些元素 开头 结尾等 dict 就像是 dict head 3 To see the first 3
  • 无故运行测试时 PyCharm 抛出“AttributeError: 'module' object has no attribute”

    因此 我有一个 Django REST Framework 项目 有一天它无法在 PyCharm 中运行测试 从命令行我可以使用它们来运行它们paver or the manage py直接地 曾经有一段时间 当我们没有在文件顶部导入类的超
  • Pandas dataframe:每批行的操作

    我有一个熊猫数据框df我想计算每批行的一些统计信息 例如 假设我有一个batch size 200000 对于每批batch sizerows 我想要一列的唯一值的数量ID我的数据框 我怎样才能做这样的事情呢 这是我想要的一个例子 prin
  • 使用 NLTK 在 Python 中获取大量名词(或形容词);或 Python Mad Libs

    Like 这个问题 https stackoverflow com questions 7439555 noun adjective etc word lists or dictionaries common words 我有兴趣按词性获取
  • 小部件之间的自定义信号

    尝试将信号从一个 gtk EventBox 子级发送到另一个 在 init HeadMode 第 75 行 上出现错误 类型错误 未知信号名称 消息发送 why usr bin env python coding utf8 import p
  • 更改 x 轴比例

    我使用 Matlab 创建了这个图 使用 matplotlib x 轴绘制大数字 例如 100000 200000 300000 我想要 1 2 3 和 10 5 之类的值来指示它实际上是 100000 200000 300000 有没有一
  • 将 numpy 代码点数组与字符串相互转换

    我有一个很长的 unicode 字符串 alphabet range 0x0FFF mystr join chr random choice alphabet for in range 100 mystr re sub W mystr 我想
  • 在相同任务上,Keras 比 TensorFlow 慢

    我正在使用 Python 运行斩首 DCNN 本例中为 Inception V3 来获取图像特征 我使用的是 Anaconda Py3 6 和 Windows7 使用 TensorFlow 时 我将会话保存在变量中 感谢 jdehesa 并
  • Alembic:如何迁移模型中的自定义类型?

    My User模型是 class User UserMixin db Model tablename users noinspection PyShadowingBuiltins uuid Column uuid GUID default
  • 如何在 Django 中使用基于类的视图创建注册视图?

    当我开始使用 Django 时 我几乎使用 FBV 基于函数的视图 来处理所有事情 包括注册新用户 但当我更深入地研究项目时 我意识到基于类的视图通常更适合大型项目 因为它们更干净且可维护 但这并不是说 FBV 不是 无论如何 我将整个项目
  • 迭代列表的奇怪速度差异

    我创建了两个重复两个不同值的长列表 在第一个列表中 值交替出现 在第二个列表中 一个值出现在另一个值之前 a1 object object 10 6 a2 a1 2 a1 1 2 然后我迭代它们 不对它们执行任何操作 for in a1 p
  • Pandas 堆积条形图中元素的排序

    我正在尝试绘制有关某个地区 5 个地区的家庭在特定行业赚取的收入比例的信息 我使用 groupby 按地区对数据框中的信息进行排序 df df orig groupby District Portion of income value co
  • 从 python 检测 macOS 中的暗模式

    我正在编写一个 PyQt 应用程序 我必须添加一个补丁 以便在启用暗模式的 Macos 上可以读取字体 app QApplication Fix for the font colours on macos when running dark
  • bs4 `next_sibling` VS `find_next_sibling`

    我在使用时遇到困难next sibling 并且类似地与next element 如果用作属性 我不会得到任何返回 但如果用作find next sibling or find next 然后就可以了 来自doc https www cru
  • 操作错误:(sqlite3.OperationalError) SQL 变量太多,同时将 SQL 与数据帧一起使用

    我有一个熊猫数据框 如下所示 activity User Id 0 VIEWED MOVIE 158d292ec18a49 1 VIEWED MOVIE 158d292ec18a49 2 VIEWED MOVIE 158d292ec18a4

随机推荐

  • php消息队列的应用

    欢迎加入 新群号码 99640845 最近打算开发一个新功能 计划应用消息队列 以前对消息队列都是简单的理论了解 真正应用之后把自己的感觉和一些理解整理下来 说正事分割线 具体的业务场景如下 用户下单 生成订单 支付 返回支付信息 就是正常
  • el-tab切换时echarts图表宽度变为100px

    由于el tabs切换的时候 不显示的tab内容默认通过display none 所以导致echarts图表为100px 解决办法 在图表上使用v if来解决
  • c++ 九九乘法表,倒计时,成语接龙等游戏源代码

    include
  • 数据结构之顺序表详解

    目录 前言 1 顺序表 2 顺序表及其功能实现 2 1 准备工作 2 2 顺序表结构的创建 2 3 顺序表的初始化 2 4 顺序表向后插入数据 2 5 打印函数的实现 2 5 顺序表从后删除数据 2 6 顺序表向前插入数据 2 7 顺序表从
  • chrome45以后的版本安装lodop后,仍提示未安装解决

    请先查看你chrome浏览器的版本 如果是45版本以前的版本 安装后仍提示 未安装 或 请升级 请参照本链接解决 http blog sina com cn s blog 721e77e50102vfjl html 以下是chrome版本4
  • 一键部署设计稿至线上 —— D2C国产神器

    微软近期推出了Power Apps 新功能 Express Design 只要上传一个草图或者是 Figma 文件 Express Design 都会在几秒钟之内用 AI 技术将其转化为一个应用程序 不写代码就能生成一个应用 一直以来是我们
  • 服务器上调试程序 pdb命令调试

    以前写python一直用pycharm 调试啥的比较方便 最近要在远程服务器上调试一些程序 只有一个控制台就可以用pdb进行调试了 常用的只有几个命令 break 或 b 设置断点 continue 或 c 继续执行程序 list 或 l
  • 输入一个字符串,判断其是否是回文。(回文:即正读和反读都一样,如abccba, abccba)

    输入一个字符串 判断其是否是回文 回文 即正读和反读都一样 如abccba abccba 这里讨了个巧用了strcmp函数 注 strcmp用法 字符串比较函数 一般形式为strcmp 字符串1 字符串2 比较规则 对两个字符串自左至右逐个
  • linux 系统留后门方法和清除日志

    1 setuid cp bin sh tmp sh chmod u s tmp sh 加上 suid 位到shell上 虽然很简单 但容易被发现 2 echo hack 0 0 bin csh gt gt etc passwd 即给系统增加
  • 剑指 Offer 62. 圆圈中最后剩下的数字(java+python)

    0 1 n 1这n个数字排成一个圆圈 从数字0开始 每次从这个圆圈里删除第m个数字 删除后从下一个数字开始计数 求出这个圆圈里剩下的最后一个数字 例如 0 1 2 3 4这5个数字组成一个圆圈 从数字0开始每次删除第3个数字 则删除的前4个
  • Qt QToolButton和QListWidget的使用

    1 本篇简介 本篇主要演示QListWidget的使用 还涉及工具箱 QToolBox 和工具按钮 QToolButton 的使用 还会通过Action创建工具按钮的下拉菜单和QListWidget的组件的快捷菜单 展示如下图 2 QLis
  • Redis之string类型的三大编码解读

    目录 string类型的三大编码 int 编码 embstr 编码 raw 编码 明明没有超过阈值 为什么变成raw 查看数据类型相关命令 redis看看类型 type key 看看编码 object encoding debug结构 de
  • 2021年wsl2中配置Ubuntu18.04+CUDA+Pytorch深度学习环境完全版

    2021年4月 wsl2 中配置深度学习环境完全版 windows10 RTX3090 wsl2 ubuntu18 04 cuda cudatoolkit11 0 cudnn11 0 gnome anaconda3 pycharm 写在前面
  • 使用tinyproxy简易搭建代理服务器

    需要 腾讯云服务器或阿里云服务器 虚拟机 步骤 第一步 在自己的云服务器上安装 tinyproxy 如果是 Ubuntu 就使用 apt install y tinyproxy 如果是 Centos 则使用 yum install y ti
  • 图像去模糊:MIMO-UNet 模型详解

    本内容主要介绍实现图像去模糊的 MIMO UNet 模型 论文 Rethinking Coarse to Fine Approach in Single Image Deblurring 代码 官方 https github com cho
  • opengles3.0 学习,顶点着色器(六)

    opengles3 0 学习 顶点着色器 六 顶点着色器输入包括 属性 用顶点数组提供的逐顶点数据 统一变量和统一变量缓冲区 顶点着色器使用的不变数据 采样器 代表顶点着色器使用的纹理的特殊统一变量类型 着色器程序 顶点着色器的源码 顶点着
  • 基于人脸的常见表情识别(3)——模型搭建、训练与测试

    基于人脸的常见表情识别 3 模型搭建 训练与测试 模型搭建与训练 1 数据接口准备 2 模型定义 3 模型训练 模型测试 本 Task 是 基于人脸的常见表情识别 训练营的第 3 课 如果你未学习前面的课程 请从 Task1 开始学习 本
  • 基于std::queue C++11 线程安全队列。

    网上看到的封装不错 记录一下 非原创 pragma once include
  • JAVA实现大文件多线程下载,提速30倍!(提供exe版)

    JAVA实现大文件多线程下载 提速30倍 前言 兄弟们看到这个标题可能会觉得是个标题党 为了解决疑虑 我们先来看下最终的测试结果 测试云盘下载的文件 46M 自己本地最大下载速度 2M 1 单线程下载 总耗时 603s 2 多线程下载 50
  • DDPMs扩散模型Pytorch代码实现附详细注释

    本文相当于是对The Annotated Diffusion Model的代码理解后加的注释 很详尽 具体有些公式图片不太好显示 在vx公众号 一蓑烟雨晴 回复 100 下载notebook版本的代码文件 import math from