class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,            # custom transformer support
        context_dim=None,               # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        # Cross-attention conditioning requires both the spatial transformer
        # and a context dimension; each one alone is a configuration error.
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        # Exactly one of num_heads / num_head_channels must be specified.
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # MLP that computes the embedding of the current sampling timestep t.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # First convolution of the input (downsampling) path.
        # TimestepEmbedSequential is a wrapper that routes the time and/or text
        # conditioning into each layer according to the layer's type.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        # Loop over channel_mult (e.g. [1, 2, 4, 4]): each entry is the channel
        # expansion ratio of one UNet stage.
        for level, mult in enumerate(channel_mult):
            # Each stage repeats num_res_blocks times, adding a ResBlock and,
            # at the configured resolutions, an attention block.
            for _ in range(num_res_blocks):
                # The ResBlock adjusts the channel count of the noisy input and
                # fuses in the timestep features.
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch is now the output channel count of the ResBlock above.
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # SpatialTransformer: self-attention strengthening global
                    # features and cross-attention fusing the text features.
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample at the end of every stage except the last one.
            if level != len(channel_mult) - 1:
                out_ch = ch
                # Typically a plain Downsample module is used here.
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        ) if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # Book-keeping for the next stage.
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # Middle block: ResBlock + (Spatial)Attention + ResBlock.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Upsampling path of the UNet.
        self.output_blocks = nn.ModuleList([])
        # Iterate channel_mult in reverse; on the way up each stage repeats
        # num_res_blocks + 1 times (one extra for the skip from the downsample).
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # ich: channel count of the matching skip connection.
                ich = input_block_chans.pop()
                # First add a ResBlock that consumes the concatenated skip.
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # Then, at the configured resolutions, add attention /
                # SpatialTransformer.
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # Upsample when this is not the outermost stage (level != 0)
                # and we are at the last repetition of the stage.
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        ) if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Final normalization + convolution producing the output channels.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        # Compute the embedding of the current sampling timestep t.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # Downsampling path: fuse time features and (optionally) text features,
        # saving each activation for the skip connections.
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        # Middle block feature extraction.
        h = self.middle_block(h, emb, context)
        # Upsampling path: concatenate the matching skip connection each step.
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # Output head.
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
# NOTE(review): this method belongs to a different class (a LatentDiffusion-style
# wrapper — it reads self.first_stage_model and self.scale_factor, which UNetModel
# does not define); the enclosing class is not shown in this excerpt.
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """Decode latents z back to image space via the first-stage autoencoder."""
    if predict_cids:
        # z holds (log-)probabilities over codebook ids; pick the argmax id
        # per spatial location, then look the entries up in the VQ codebook.
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()
    # Undo the latent scaling applied at encode time.
    z = 1. / self.scale_factor * z
    # The input normally does not need to be split, so the else branch below
    # is taken directly when self.model is called on x_noisy.
    if hasattr(self, "split_input_params"):
        # Body elided in this excerpt (the '......' below is the article's
        # placeholder, not valid Python).
        ......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)
Text-to-image prediction (sampling) code
The complete prediction script is as follows:
import os
import random

import cv2
import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything

# Fixed: DDIMSampler was previously imported twice from ldm_hacked.
from ldm_hacked import DDIMSampler, create_model, load_state_dict

# ----------------------- #
#  Parameters
# ----------------------- #
# path of the model config
config_path = "model_data/sd_v15.yaml"
# path of the model weights
model_path = "model_data/v1-5-pruned-emaonly.safetensors"
# size of the generated images (input_shape)
input_shape = [512, 512]
# number of images generated per run
num_samples = 2
# number of DDIM sampling steps
ddim_steps = 20
# sampling seed; -1 means pick a random one
seed = 12345
# eta
eta = 0
# prompt
prompt = "a cat"
# positive (quality) prompt appended to the prompt
a_prompt = "best quality, extremely detailed"
# negative prompt
n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# classifier-free guidance scale for the positive/negative conditioning
scale = 9
# save_path
save_path = "imgs/outputs_imgs"

# ----------------------- #
#  Build the model
# ----------------------- #
model = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model = model.cuda()
ddim_sampler = DDIMSampler(model)

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #  Encode the prompts
    # ----------------------- #
    cond = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W = input_shape
    # Latent space is 4 channels at 1/8 spatial resolution.
    shape = (4, H // 8, W // 8)

    # ----------------------- #
    #  Sample
    # ----------------------- #
    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)

    # ----------------------- #
    #  Decode latents to images
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    # Map from [-1, 1] to [0, 255] uint8, channels-last.
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

    # ----------------------- #
    #  Save the images
    # ----------------------- #
    # Fixed: exist_ok=True replaces the race-prone exists()+makedirs() pair.
    os.makedirs(save_path, exist_ok=True)
    for index, image in enumerate(x_samples):
        # Swap channel order for cv2.imwrite (which expects BGR).
        cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))