本文相当于是对The Annotated Diffusion Model 的代码理解后加的注释,很详尽,具体有些公式图片不太好显示,在vx公众号“一蓑烟雨晴”回复“100”下载notebook版本的代码文件。
import math
from inspect import isfunction # inspect模块https://www.cnblogs.com/yaohong/p/8874154.html主要提供了四种用处:1.对是否是模块、框架、函数进行类型检查 2.获取源码 3.获取类或者函数的参数信息 4.解析堆栈
from functools import partial # 偏函数 https://www.runoob.com/w3cnote/python-partial.html
% matplotlib inline
import matplotlib. pyplot as plt
from tqdm. auto import tqdm # 进度条
from einops import rearrange # einops把张量的维度操作具象化,让开发者“想出即写出
import torch
from torch import nn, einsum # einsum很方便的实现复杂的张量操作 https://zhuanlan.zhihu.com/p/361209187
import torch. nn. functional as F
一些辅助性的小函数
def exists(x):
    """Return True when *x* is anything other than ``None`` (falsy values count)."""
    return not (x is None)
def default(val, d):
    """Return *val* unless it is None; otherwise fall back to *d*.

    If the fallback *d* is a plain Python function (as reported by
    ``inspect.isfunction``) it is called with no arguments and its result is
    returned, so costly defaults are only computed when needed. Any other
    fallback (including classes and builtins) is returned unchanged.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
class Residual(nn.Module):
    """Wrap a callable *fn* with an identity skip connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments go to the wrapped callable only;
        # the skip path always adds the unmodified input.
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x
def Upsample(dim):
    """Return a learnable 2x spatial upsampler that keeps ``dim`` channels.

    A ``ConvTranspose2d`` with kernel 4, stride 2 and padding 1 maps an
    ``H x W`` feature map to ``2H x 2W``.
    """
    return nn.ConvTranspose2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )
def Downsample(dim):
    """Return a learnable 2x spatial downsampler that keeps ``dim`` channels.

    A strided ``Conv2d`` (kernel 4, stride 2, padding 1) maps ``H x W`` to
    ``floor(H / 2) x floor(W / 2)`` for ``H, W >= 2``.
    """
    return nn.Conv2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of (diffusion) time steps.

    For a 1-D tensor of ``B`` time values, ``forward`` returns a ``(B, dim)``
    tensor. The first ``dim // 2`` columns are sines and the next ``dim // 2``
    are cosines of the time multiplied by geometrically spaced frequencies
    from 1 down to 1/10000. When ``dim`` is odd, a single zero column is
    appended so the output width always equals ``dim``.

    Example::

        layer = SinusoidalPositionEmbeddings(5)
        emb = layer(torch.tensor([1, 2]))  # shape (2, 5); row 0 encodes t=1, row 1 t=2

    Original transformer positional-encoding walkthrough:
    https://github.com/jalammar/jalammar.github.io/blob/master/notebookes/transformer/transformer_positional_encoding_graph.ipynb
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # With half_dim == 1 (dim 2 or 3) the spacing divisor would be 0;
        # a single frequency of 1 is the only sensible choice there.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only fill 2 * (dim // 2) columns; pad so callers
            # such as nn.Linear(dim, ...) receive exactly dim features.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
class Block(nn.Module):
    """3x3 conv -> GroupNorm -> optional affine modulation -> SiLU.

    ``scale_shift``, when given, is unpacked as ``(scale, shift)`` and applied
    after normalisation as ``x * (scale + 1) + shift``.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)  # GroupNorm: https://zhuanlan.zhihu.com/p/177853578
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        normed = self.norm(self.proj(x))
        if exists(scale_shift):
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift
        return self.act(normed)
class ResnetBlock(nn.Module):
    """Two-``Block`` residual unit with optional additive time conditioning.

    Reference: https://arxiv.org/abs/1512.03385

    When built with ``time_emb_dim``, a ``(B, time_emb_dim)`` embedding is
    projected (SiLU -> Linear) to ``dim_out`` values. These are added to the
    first block's output, broadcast over the spatial axes. If there is no MLP,
    or ``time_emb`` is None, the conditioning step is skipped. The input is
    added back at the end, through a 1x1 conv when ``dim != dim_out``
    (nn.Identity: https://blog.csdn.net/artistkeepmonkey/article/details/115067356).

    Example::

        block = ResnetBlock(8, 16, time_emb_dim=2, groups=8)
        out = block(torch.ones(1, 8, 64, 64), torch.ones(1, 2))  # (1, 16, 64, 64)
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = (
            nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)
        )

    def forward(self, x, time_emb=None):
        h = self.block1(x)
        conditioned = exists(self.mlp) and exists(time_emb)
        if conditioned:
            # (B, dim_out) -> (B, dim_out, 1, 1) so it broadcasts over H and W.
            cond = rearrange(self.mlp(time_emb), "b c -> b c 1 1")
            h = cond + h
        return self.block2(h) + self.res_conv(x)
class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block with optional time conditioning.

    Reference: https://arxiv.org/abs/2201.03545

    Pipeline: 7x7 depthwise conv -> (+ projected time embedding) ->
    [GroupNorm] -> 3x3 conv to ``dim_out * mult`` -> GELU -> GroupNorm ->
    3x3 conv to ``dim_out``. A residual connection from the input is added,
    through a 1x1 conv when ``dim != dim_out``. The constructor mirrors
    ``ResnetBlock``'s keyword interface, so ``Unet`` can use either.

    Args:
        dim: input channels.
        dim_out: output channels.
        time_emb_dim: width of the time embedding. ``None`` builds no
            conditioning MLP, and ``time_emb`` is then ignored.
        mult: channel expansion factor of the inner conv.
        norm: apply the leading ``GroupNorm(1, dim)`` if True.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # GELU -> Linear maps a (B, time_emb_dim) embedding to one offset per input channel.
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if time_emb_dim is not None
            else None
        )
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)  # depthwise 7x7
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),  # Gaussian Error Linear Unit
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """Apply the block to ``x`` (B, dim, H, W).

        ``time_emb`` must be given when the block was built with
        ``time_emb_dim``; otherwise it is ignored.

        Raises:
            ValueError: the block has a time MLP but ``time_emb`` is None.
        """
        h = self.ds_conv(x)
        if self.mlp is not None:
            # A time-conditioned block must receive its conditioning signal;
            # silently skipping it would run the network unconditioned.
            if time_emb is None:
                raise ValueError("time embedding must be passed in")
            condition = self.mlp(time_emb)
            # (B, dim) -> (B, dim, 1, 1) so it broadcasts over H and W.
            h = h + condition[:, :, None, None]
        h = self.net(h)
        return h + self.res_conv(x)
Attention流程图:

![Attention流程图](attachment:%E6%9C%AA%E5%91%BD%E5%90%8D%E6%96%87%E4%BB%B6%20%281%29.png)
class Attention(nn.Module):
    """Multi-head softmax self-attention over all spatial positions.

    A single bias-free 1x1 conv produces queries, keys and values
    (``3 * heads * dim_head`` channels). Each head attends among all
    ``H * W`` positions. A 1x1 conv then maps the concatenated heads back to
    ``dim`` channels, so the output shape equals the input shape
    ``(B, dim, H, W)``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        # Three (B, heads*dim_head, H, W) tensors -> three (B, heads, dim_head, H*W).
        q, k, v = [
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        ]
        q = q * self.scale
        # Pairwise similarity between positions: (B, heads, H*W, H*W).
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtracting the (gradient-free) row max only improves numerical
        # stability; the softmax result is mathematically unchanged.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        # Weighted sum of values: (B, heads, H*W, dim_head).
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(out)
class LinearAttention(nn.Module):
    """Multi-head linear attention, with cost linear in the number of positions.

    Same projections as ``Attention``, but the attention step differs.
    Queries are softmax-normalised over their feature axis and keys over the
    position axis. Keys and values are contracted into a per-head
    ``(dim_head, dim_head)`` context matrix, which is then applied to the
    queries, so no ``(H*W) x (H*W)`` matrix is ever formed. The output
    projection is a 1x1 conv followed by ``GroupNorm(1, dim)``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q.softmax(dim=-2) * self.scale  # normalise over features
        k = k.softmax(dim=-1)  # normalise over positions
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width)
        return self.to_out(out)
class PreNorm(nn.Module):
    """Normalise with ``GroupNorm(1, dim)`` first, then apply the wrapped *fn*."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))
class Unet(nn.Module):
    """U-Net noise predictor assembled from the blocks defined above.

    Layout, as constructed in ``__init__``:

    * ``init_conv``: 7x7 conv from ``channels`` to ``init_dim`` channels.
    * ``downs``: one stage per entry of ``dim_mults``. Each stage is two conv
      blocks, a residual pre-norm ``LinearAttention``, and a ``Downsample``
      (an ``nn.Identity`` on the last stage).
    * bottleneck: block -> residual pre-norm full ``Attention`` -> block.
    * ``ups``: ``len(dim_mults) - 1`` stages. Each concatenates one stored
      encoder feature, then runs two conv blocks, a residual pre-norm
      ``LinearAttention`` and an ``Upsample``.
    * ``final_conv``: one conv block on ``dim`` channels, then a 1x1 conv to
      ``out_dim`` channels.

    Every conv block inside ``downs``/bottleneck/``ups`` also receives the
    time embedding ``t`` (``None`` when ``with_time_emb`` is False).

    NOTE(review): ``final_conv`` expects ``dim`` input channels, but the last
    decoder stage outputs ``dim * dim_mults[0]`` channels. The model therefore
    only fits together when ``dim_mults[0] == 1``; confirm before changing
    ``dim_mults``.
    NOTE(review): with ``use_convnext=False``, ``ResnetBlock`` normalises with
    ``GroupNorm(resnet_block_groups, ...)``. That requires every block output
    width (``dim * m`` for ``m`` in ``dim_mults``) to be divisible by
    ``resnet_block_groups``; e.g. dim=28 with groups=8 would not work.
    """

    def __init__(
        self,
        dim,  # base channel width; the training code in this notebook passes dim=image_size=28
        init_dim=None,  # channels after init_conv; defaults to dim // 3 * 2
        out_dim=None,  # output channels; defaults to channels
        dim_mults=(1, 2, 4, 8),  # per-stage multipliers: stage widths are dim * m
        channels=3,  # input image channels, 3 by default
        with_time_emb=True,  # build time_mlp and condition the blocks on time
        resnet_block_groups=8,  # GroupNorm groups passed to ResnetBlock
        use_convnext=True,  # True: ConvNextBlock, False: ResnetBlock
        convnext_mult=2,  # channel expansion factor passed to ConvNextBlock
    ):
        super().__init__()
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        # Channel widths from the input conv down to the deepest stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        # Consecutive (in, out) width pairs, one per encoder stage.
        in_out = list(zip(dims[:-1], dims[1:]))
        # Choose the conv block type; partial() pre-binds its extra option.
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        # Time embedding: sinusoidal features of width dim, then an MLP to time_dim = 4 * dim.
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
        # Encoder and decoder stage lists.
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)  # number of encoder stages
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The deepest stage keeps its resolution (no Downsample).
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )
        # Bottleneck at the deepest width, with full (quadratic) attention.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: this loop runs num_resolutions - 1 times, so ind never
            # reaches num_resolutions - 1 and is_last is always False. Every
            # decoder stage therefore upsamples, matching the
            # num_resolutions - 1 Downsample layers in the encoder.
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        # Input = current features concatenated with a skip feature of equal width.
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )
        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        """Run the U-Net on images ``x`` at time steps ``time``.

        Args:
            x: image batch with ``channels`` channels.
            time: time steps fed to ``time_mlp`` (ignored when
                ``with_time_emb`` is False). The sinusoidal embedding indexes
                it as ``time[:, None]``, so presumably it is a 1-D tensor with
                one entry per batch element.

        Returns:
            Tensor with ``out_dim`` channels; ``p_losses`` treats it as the
            predicted noise.
        """
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []
        # Encoder: keep each stage's output (before downsampling) for skip connections.
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        # Bottleneck.
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        # Decoder: pops num_resolutions - 1 of the num_resolutions stored
        # features, so the first (highest-resolution) encoder feature is unused.
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.final_conv(x)
四种beta选择
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine beta schedule from https://arxiv.org/abs/2102.09672.

    Evaluates f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2) on the grid
    t = 0..T. It normalises alpha_bar(t) = f(t) / f(0) and returns
    beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1) for t = 1..T, clipped to
    [0.0001, 0.9999]. The result has ``timesteps`` entries.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    ratios = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - ratios, 0.0001, 0.9999)
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas spaced linearly from ``beta_start`` to ``beta_end``.

    The defaults (1e-4 to 0.02) are the values this notebook uses. They are
    parameters so that other ranges can be tried without editing the function.
    """
    return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas whose square roots are linear in t.

    The betas run from ``beta_start`` to ``beta_end``: the endpoints' square
    roots are spaced linearly and the results squared, so the betas grow
    quadratically.
    """
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas following a sigmoid over [-6, 6].

    The sigmoid is rescaled into ``(beta_start, beta_end)``. Because
    sigmoid(+-6) is not exactly 0 or 1, the endpoints come close to the
    limits but do not reach them exactly.
    """
    x = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(x) * (beta_end - beta_start) + beta_start
import numpy as np

# Compare the four beta schedules over 1000 steps.
timesteps = 1000
# One x value per step: 1, 2, ..., timesteps.
x = np.arange(1, timesteps + 1)
fig, ax = plt.subplots()  # create the figure
# The clipped cosine betas reach 0.9999 at the final steps, which would dwarf
# the other curves, so they are drawn divided by 50; the legend says so.
ax.plot(x, (cosine_beta_schedule(timesteps, s=0.008) / 50).numpy(), label='cosine / 50')
ax.plot(x, linear_beta_schedule(timesteps).numpy(), label='linear')
ax.plot(x, quadratic_beta_schedule(timesteps).numpy(), label='quadratic')
ax.plot(x, sigmoid_beta_schedule(timesteps).numpy(), label='sigmoid')
plt.legend()
plt.show()
![四种 beta 调度的曲线](output_11_0.png)
betas: $\beta_t$
alphas: $\alpha_t = 1-\beta_t$
alphas_cumprod: $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$
alphas_cumprod_prev: $\bar{\alpha}_{t-1}$(代码中在最前面补 1,即约定 $\bar{\alpha}_0 = 1$)
sqrt_recip_alphas: $1/\sqrt{\alpha_t}$(注意:这里是单步的 $\alpha_t$,不是累乘 $\bar{\alpha}_t$,对应代码 `torch.sqrt(1.0 / alphas)`)
sqrt_alphas_cumprod: $\sqrt{\bar{\alpha}_t}$
sqrt_one_minus_alphas_cumprod: $\sqrt{1-\bar{\alpha}_t}$
posterior_variance: $\tilde{\beta}_t = \beta_t\,(1-\bar{\alpha}_{t-1})/(1-\bar{\alpha}_t)$
timesteps = 200

# Linear beta schedule and every per-step quantity derived from it.
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t
# alpha_bar_{t-1}: drop the last entry and prepend alpha_bar_0 = 1.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)  # 1 / sqrt(alpha_t), single step

# Closed-form forward process q(x_t | x_0).
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

# Variance of the posterior q(x_{t-1} | x_t, x_0).
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
def extract(a, t, x_shape):
    """Pick ``a[t[i]]`` for every batch element, shaped for broadcasting.

    Args:
        a: 1-D tensor of per-timestep constants (e.g. ``sqrt_alphas_cumprod``).
        t: 1-D integer tensor of time indices, one per batch element.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
        trailing singleton dims, placed on ``t``'s device.

    Example::

        extract(torch.arange(10.0), torch.tensor([5]), (1, 3, 8, 8))
        # -> tensor([[[[5.]]]]), shape (1, 1, 1, 1)
    """
    batch_size = t.shape[0]
    # Gather on a's own device: moving t to the CPU crashed whenever the
    # schedule tensor had been placed on the GPU.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# Load an example image (a COCO val2017 photo) to visualise the forward process.
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Fail fast on network/HTTP errors instead of hanging or handing PIL an error
# page, and close the connection once the pixels have been read.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    image.load()  # force the pixel read before the stream is closed
image
![示例图片](output_14_0.png)
# Preprocess: resize, center-crop to a square, and scale pixels to [-1, 1].
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),  # shorter side -> 128 px, aspect ratio kept
    CenterCrop(image_size),  # central 128 x 128 crop
    ToTensor(),  # PIL (H, W, C) uint8 -> float tensor (C, H, W) in [0, 1]
    Lambda(lambda t: (t * 2) - 1),  # [0, 1] -> [-1, 1]
])

x_start = transform(image).unsqueeze(0)  # add a batch dimension
x_start.shape
# Output: torch.Size([1, 3, 128, 128])
import numpy as np

# Inverse of the preprocessing: tensor (C, H, W) in [-1, 1] -> PIL image.
reverse_transform = Compose([
    # Noisy images from q_sample fall outside [-1, 1]; clamp so that the
    # uint8 cast below cannot overflow and produce garbage pixels.
    Lambda(lambda t: t.clamp(-1, 1)),
    Lambda(lambda t: (t + 1) / 2),
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),
    # detach/cpu so tensors that require grad or live on the GPU also work.
    Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    ToPILImage(),
])

# The preprocessed image, converted back for display.
reverse_transform(x_start.squeeze())
![预处理后的图片](output_17_0.png)
$x_t = \sqrt{\bar{\alpha}_t}\,x_0+\sqrt{1-\bar{\alpha}_t}\,\epsilon,\quad \epsilon \sim \mathcal{N}(0, I)$
def q_sample(x_start, t, noise=None):
    """Forward diffusion in one step (closed form of q(x_t | x_0)).

    Returns ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``,
    where ``noise`` defaults to fresh standard-normal noise shaped like
    ``x_start``.
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    shape = x_start.shape
    signal_scale = extract(sqrt_alphas_cumprod, t, shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_scale * x_start + noise_scale * noise
def get_noisy_image(x_start, t):
    """Diffuse ``x_start`` to step ``t`` with fresh noise and return a PIL image."""
    return reverse_transform(q_sample(x_start, t=t).squeeze())
# Inspect a single forward-diffusion step: t = 40 (a 1-element int64 tensor).
t = torch.full((1,), 40, dtype=torch.long)
get_noisy_image(x_start, t=t)
![t=40 时的加噪图片](output_21_0.png)
import matplotlib. pyplot as plt
# Fix the global RNG so the noisy images below are reproducible.
torch.random.manual_seed(0)
# Plotting helper adapted from the official torchvision examples.
# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, *, figsize=None, orig_img=None, **imshow_kwargs):
    """Display images in a grid with one subplot per image.

    Args:
        imgs: a list of images (one row) or a list of such lists (several
            rows). Each image must be something ``np.asarray`` turns into an
            array ``imshow`` can draw, e.g. a PIL image.
        with_orig: prepend ``orig_img`` to every row and title the first
            cell 'Original image'.
        row_title: optional sequence of y-axis labels, one per row.
        figsize: figure size in inches; defaults to 2 inches per grid cell.
        orig_img: image used by ``with_orig``; defaults to the module-level
            ``image`` loaded earlier in this notebook.
        **imshow_kwargs: passed through to ``Axes.imshow``.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    if figsize is None:
        # A fixed 200 x 200 inch canvas (20000 px square at 100 dpi) wastes a
        # great deal of memory; size the figure to the grid instead.
        figsize = (2 * num_cols, 2 * num_rows)
    if with_orig and orig_img is None:
        orig_img = image
    fig, axs = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    plt.show()
# Forward-diffuse the same image to increasingly late steps and show them side by side.
plot([
    get_noisy_image(x_start, torch.tensor([step]))
    for step in (0, 50, 100, 150, 199)
])
![不同时间步的加噪图片](output_23_0.png)
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    """Noise-prediction training loss for a batch of clean images.

    Diffuses ``x_start`` to steps ``t`` using ``noise`` (fresh Gaussian noise
    if omitted). It then asks ``denoise_model(x_noisy, t)`` to predict that
    noise and compares target and prediction according to ``loss_type``:
    ``"l1"`` (default), ``"l2"`` (MSE) or ``"huber"`` (smooth L1). Any other
    value raises ``NotImplementedError``, after the model has been run.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    predicted_noise = denoise_model(q_sample(x_start=x_start, t=t, noise=noise), t)
    criteria = (
        ('l1', F.l1_loss),
        ('l2', F.mse_loss),
        ("huber", F.smooth_l1_loss),
    )
    for name, criterion in criteria:
        if loss_type == name:
            return criterion(noise, predicted_noise)
    raise NotImplementedError()
from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128

from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations (e.g. using torchvision)
transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])


def apply_transforms(examples):
    """Turn a batch of rows' PIL ``image`` entries into ``pixel_values`` tensors in [-1, 1].

    Named so it does not shadow the ``torchvision.transforms`` module that the
    ``transform`` pipeline above is built from; with the old name, re-running
    this cell failed because ``transforms`` had become this function.
    """
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]
    return examples


# Apply the transform on access and drop the unused labels.
transformed_dataset = dataset.with_transform(apply_transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
$\tilde{\boldsymbol{\mu}}_{t}=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\,\mathbf{z}_{t}\right)$

其中,$\mathbf{z}_t$ 是模型预测的噪声,由 model(x, t) 得到。
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """One reverse-diffusion step: draw x_{t-1} given x_t.

    ``t`` holds the time step for every batch element (``p_sample_loop``
    passes ``torch.full((b,), i, ...)``); ``t_index`` is the same step as a
    Python int. The mean follows Equation 11 of the DDPM paper, using the
    model's noise prediction. Except at the final step (``t_index == 0``),
    Gaussian noise scaled by the posterior standard deviation is added
    (Algorithm 2, line 4). At the final step the mean is returned as-is.
    """
    beta = extract(betas, t, x.shape)
    sigma_bar = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    inv_sqrt_alpha = extract(sqrt_recip_alphas, t, x.shape)
    predicted_noise = model(x, t)
    model_mean = inv_sqrt_alpha * (x - beta * predicted_noise / sigma_bar)
    if t_index != 0:
        variance = extract(posterior_variance, t, x.shape)
        return model_mean + torch.sqrt(variance) * torch.randn_like(x)
    return model_mean
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse process from pure noise and keep every step.

    Starts from standard-normal noise of ``shape`` on the model's device and
    applies ``p_sample`` for t = timesteps-1, ..., 0. Returns a list of
    ``timesteps`` NumPy arrays (each of ``shape``, copied to the CPU),
    ordered from the first denoising step to the final sample.
    """
    device = next(model.parameters()).device
    batch = shape[0]
    img = torch.randn(shape, device=device)
    history = []
    steps = range(timesteps - 1, -1, -1)
    for i in tqdm(steps, desc='sampling loop time step', total=timesteps):
        step_t = torch.full((batch,), i, device=device, dtype=torch.long)
        img = p_sample(model, img, step_t, i)
        history.append(img.cpu().numpy())
    return history
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` square images with side length ``image_size``.

    Returns the whole denoising trajectory from ``p_sample_loop``: one NumPy
    array of shape ``(batch_size, channels, image_size, image_size)`` per
    time step, with the final sample last.
    """
    shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=shape)
from pathlib import Path
def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus a final partial chunk.

    Example: ``num_to_groups(10, 3) == [3, 3, 3, 1]``. No partial chunk is
    added when ``num`` is an exact multiple of ``divisor``.
    """
    full, rest = divmod(num, divisor)
    tail = [rest] if rest > 0 else []
    return [divisor] * full + tail
# Directory that receives the sample grids written during training
# (mkdir is a no-op if it already exists).
results_folder = Path("./results")
results_folder.mkdir(exist_ok=True)
# Write a sample grid every this many optimisation steps.
save_and_sample_every = 1_000
from torch. optim import Adam
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Noise-prediction U-Net for the 28x28 single-channel Fashion-MNIST images:
# base width = image size, three resolution stages. Module.to moves the
# parameters in place and returns the same module.
model = Unet(dim=image_size, channels=channels, dim_mults=(1, 2, 4)).to(device)

optimizer = Adam(model.parameters(), lr=1e-3)
from torchvision. utils import save_image
epochs = 5

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()  # clear gradients from the previous step

        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)

        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()

        loss = p_losses(model, batch, t, loss_type="huber")

        if step % 100 == 0:
            print("Loss:", loss.item())

        loss.backward()
        optimizer.step()

        # save generated images
        if step != 0 and step % save_and_sample_every == 0:
            milestone = step // save_and_sample_every
            batches = num_to_groups(4, batch_size)
            # `sample` requires image_size and returns the whole trajectory as
            # a list of NumPy arrays; keep only the final step as a tensor.
            all_images_list = [
                torch.from_numpy(
                    sample(model, image_size=image_size, batch_size=n, channels=channels)[-1]
                )
                for n in batches
            ]
            all_images = torch.cat(all_images_list, dim=0)
            all_images = (all_images + 1) * 0.5  # [-1, 1] -> [0, 1]
            save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow=6)
# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 5
# samples[-1] is the final step, shaped (batch, channels, H, W); imshow needs
# (H, W, channels), so transpose (a plain reshape is only right for 1 channel).
plt.imshow(samples[-1][random_index].transpose(1, 2, 0), cmap="gray")
# Animate how one sample emerges from noise over the reverse process.
import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for frame in samples:
    # Each frame is (batch, channels, H, W); imshow wants (H, W, channels).
    im = plt.imshow(frame[random_index].transpose(1, 2, 0), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
# Request the Pillow GIF writer explicitly; matplotlib's default movie writer
# is ffmpeg, which may not be installed.
animate.save('diffusion.gif', writer="pillow")
plt.show()