This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. …
This method registers a variable that is not a trainable parameter of the model. In other words, it is used to register constants such as $\beta_t$, which are not touched by backpropagation. How are they used? Simply access them like ordinary member variables, e.g. self.oneover_sqrta and self.mab_over_sqrtmab:
x_i = (
    self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
    + self.sqrt_beta_t[i] * z
)
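For reference, here is a minimal sketch (not the repo's exact code; the class name and the linear β schedule are assumptions) of how such schedule constants could be precomputed once and registered as buffers, so they follow .to(device) and state_dict without receiving gradients:

import torch
import torch.nn as nn

class DDPMSchedule(nn.Module):
    def __init__(self, n_T=400, beta1=1e-4, beta2=0.02):
        super().__init__()
        beta_t = torch.linspace(beta1, beta2, n_T + 1)   # assumed linear schedule
        alpha_t = 1.0 - beta_t
        alphabar_t = torch.cumprod(alpha_t, dim=0)       # cumulative product of alpha
        # buffers move with the module but are never updated by the optimizer
        self.register_buffer("sqrt_beta_t", beta_t.sqrt())
        self.register_buffer("oneover_sqrta", 1.0 / alpha_t.sqrt())
        self.register_buffer("sqrtab", alphabar_t.sqrt())
        self.register_buffer("sqrtmab", (1.0 - alphabar_t).sqrt())
        self.register_buffer("mab_over_sqrtmab", (1.0 - alpha_t) / (1.0 - alphabar_t).sqrt())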
def forward(self, x, c):
    """
    this method is used in training, so samples t and noise randomly
    """
    _ts = torch.randint(1, self.n_T, (x.shape[0],)).to(self.device)  # t ~ Uniform(0, n_T)
    noise = torch.randn_like(x)  # eps ~ N(0, 1)
    x_t = (
        self.sqrtab[_ts, None, None, None] * x
        + self.sqrtmab[_ts, None, None, None] * noise
    )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
    # We should predict the "error term" from this x_t. Loss is what we return.

    # dropout context with some probability
    context_mask = torch.bernoulli(torch.zeros_like(c) + self.drop_prob).to(self.device)

    # return MSE between added noise, and our predicted noise
    return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))
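To see how this forward is consumed, here is a hypothetical training step (a sketch, assuming ddpm is an instance of the DDPM module above; the dataset and optimizer here are dummy stand-ins, not repo code):

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
dataloader = DataLoader(data, batch_size=8)
optimizer = torch.optim.Adam(ddpm.parameters(), lr=1e-4)

for x, c in dataloader:
    optimizer.zero_grad()
    loss = ddpm(x, c)    # forward() samples t and noise internally, returns the MSE
    loss.backward()
    optimizer.step()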
_ts holds the sampled timesteps t. The three arguments of torch.randint are low, high, and shape; since randint samples from [low, high), we draw batch_size values of t from [1, n_T) and then compute $x_t$. In other words, the model receives batch_size images; for each image we pick a random t, ask the model to predict the noise at that timestep, and backpropagate. noise is Gaussian noise with exactly the same shape as the images. How should we read self.sqrtab[_ts, None, None, None]? As explained in "Pytorch中[:,None]的用法解析", indexing with [None] inserts a new dimension, for example:
res = torch.randn((3, 4))
print(res.shape)  # torch.Size([3, 4])
res = res[:, :, None]
print(res.shape)  # torch.Size([3, 4, 1])
The variable x_t here is the paper's $x_t$ at timestep t, and the coefficients indexed with [_ts, None, None, None] are expanded to shape [batch_size, 1, 1, 1] so that each sample's coefficient broadcasts over its whole image.
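A quick demonstration of that broadcast (the shapes are illustrative, not from the repo):

import torch

sqrtab = torch.rand(401)              # one coefficient per timestep
_ts = torch.randint(1, 400, (8,))     # 8 sampled timesteps, one per image
coef = sqrtab[_ts, None, None, None]  # shape [8, 1, 1, 1]
x = torch.randn(8, 1, 28, 28)         # a batch of MNIST-sized images
print((coef * x).shape)               # torch.Size([8, 1, 28, 28])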
The computation of x_t matches the formula from the paper:
$$x_{t}=\sqrt{\bar{\alpha}_{t}}\, x_{0}+\sqrt{1-\bar{\alpha}_{t}}\, \bar{z}_{t}$$
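As a numerical sanity check of this closed form (a sketch assuming a linear β schedule, not repo code), diffusing a constant image should reproduce the predicted per-pixel mean and standard deviation:

import torch

beta = torch.linspace(1e-4, 0.02, 401)
abar = torch.cumprod(1.0 - beta, dim=0)
t = 200
x0 = torch.ones(10000, 1, 4, 4)  # many copies of a constant "image"
x_t = abar[t].sqrt() * x0 + (1.0 - abar[t]).sqrt() * torch.randn_like(x0)
print(x_t.mean().item(), abar[t].sqrt().item())          # empirical vs predicted mean
print(x_t.std().item(), (1.0 - abar[t]).sqrt().item())   # empirical vs predicted std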
The line self.nn_model(x_t, c, _ts / self.n_T, context_mask) shows that the U-Net takes as input the noised image $x_t$, the class label $c$, _ts / self.n_T (the current timestep t as a fraction of the total T), and context_mask (which indicates whether each sample's label should be masked out).
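For concreteness, the tensors going into that call might look like this for MNIST (a hedged sketch; the shapes and the 10% dropout rate are assumptions):

import torch

B, n_T = 8, 400
x_t = torch.randn(B, 1, 28, 28)                         # noised images
c = torch.randint(0, 10, (B,))                          # integer class labels
_ts = torch.randint(1, n_T, (B,))                       # sampled timesteps
context_mask = torch.bernoulli(torch.full((B,), 0.1))   # 1 = drop this sample's label
# eps_pred = nn_model(x_t, c, _ts / n_T, context_mask)  # would predict the noise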
def forward(self, x, c, t, context_mask):
    # x is (noisy) image, c is context label, t is timestep,
    # context_mask says which samples to block the context on

    x = self.init_conv(x)
    down1 = self.down1(x)
    down2 = self.down2(down1)
    hiddenvec = self.to_vec(down2)

    # convert context to one hot embedding
    c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

    # mask out context if context_mask == 1
    context_mask = context_mask[:, None]
    context_mask = context_mask.repeat(1, self.n_classes)
    context_mask = (-1 * (1 - context_mask))  # need to flip 0 <-> 1
    c = c * context_mask

    # embed context, time step
    cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
    temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
    cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
    temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

    # could concatenate the context embedding here instead of adaGN
    # hiddenvec = torch.cat((hiddenvec, temb1, cemb1), 1)

    up1 = self.up0(hiddenvec)
    # up2 = self.up1(up1, down2)  # if want to avoid add and multiply embeddings
    up2 = self.up1(cemb1 * up1 + temb1, down2)  # add and multiply embeddings
    up3 = self.up2(cemb2 * up2 + temb2, down1)
    out = self.out(torch.cat((up3, x), 1))
    return out
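The cemb1 * up1 + temb1 pattern is a scale-and-shift (FiLM-style) conditioning: the context embedding scales the feature map while the time embedding shifts it. A standalone illustration (the shapes here are assumptions):

import torch

feat = torch.randn(8, 256, 7, 7)    # a decoder feature map
cemb = torch.randn(8, 256, 1, 1)    # context embedding, broadcasts over H and W
temb = torch.randn(8, 256, 1, 1)    # time embedding
cond = cemb * feat + temb           # conditioned features, still [8, 256, 7, 7]
print(cond.shape)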
Information vectors

The U-Net builds an embedding vector for each time fraction t/T and each label c.

First, an embedding module is defined: a small fully connected network that maps a vector of the input dimension to a vector of the embedding dimension.
class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)
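A hypothetical usage of EmbedFC matching the view(-1, n_feat * 2, 1, 1) calls in the U-Net above (n_feat = 128 is an assumption):

import torch

n_feat = 128
timeembed1 = EmbedFC(1, 2 * n_feat)       # embeds the scalar t / n_T
contextembed1 = EmbedFC(10, 2 * n_feat)   # embeds a 10-class one-hot label
t = torch.rand(8, 1)                      # t / n_T for a batch of 8
c = torch.zeros(8, 10)                    # (masked) one-hot labels
temb1 = timeembed1(t).view(-1, 2 * n_feat, 1, 1)     # [8, 256, 1, 1]
cemb1 = contextembed1(c).view(-1, 2 * n_feat, 1, 1)  # [8, 256, 1, 1]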
Looking again at how the context is masked during training:

# mask out context if context_mask == 1
context_mask = context_mask[:, None]
context_mask = context_mask.repeat(1, self.n_classes)
context_mask = (-1 * (1 - context_mask))  # need to flip 0 <-> 1
c = c * context_mask
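A worked example of the flip (a sketch). Note that besides zeroing out dropped labels, kept labels end up multiplied by -1; this is harmless because the same transformation is applied at sampling time, so the network consistently sees the negated one-hot:

import torch

m = torch.tensor([0., 1., 0.])    # 1 = drop this sample's context
m = -1 * (1 - m)                  # kept samples -> -1, dropped -> 0
c = torch.tensor([[1., 0.], [0., 1.], [0., 1.]])
print(c * m[:, None])             # dropped row zeroed, kept rows negated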
# don't drop context at test time
context_mask = torch.zeros_like(c_i).to(device)

# double the batch
c_i = c_i.repeat(2)
context_mask = context_mask.repeat(2)
context_mask[n_sample:] = 1.  # makes second half of batch context free
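This doubled batch is what enables classifier-free guidance: one forward pass yields both the conditional and the unconditional noise predictions, which are then combined with a guidance weight before the denoising step below. A sketch of the usual combination (guide_w and the slicing convention are assumptions based on the masking above):

eps = self.nn_model(x_i, c_i, t_is, context_mask)  # one pass over the doubled batch
eps_cond = eps[:n_sample]      # first half: labels kept
eps_uncond = eps[n_sample:]    # second half: labels masked out
eps = (1 + guide_w) * eps_cond - guide_w * eps_uncond  # classifier-free guidance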
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
x_i = (
    self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
    + self.sqrt_beta_t[i] * z
)
This update is a sample from a Gaussian distribution: z is drawn from the standard Gaussian, the first term gives the mean, and the coefficient of z is the standard deviation. The mean clearly follows the formula below:
$$\tilde{\boldsymbol{\mu}}_{t}=\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{t}\right)$$
As for the noise scale, the code appears to use the simplified version $\sqrt{\beta_t}$ as the standard deviation.
It apparently does not follow the full posterior formula:
$$\frac{1}{\sigma^{2}}=\frac{1}{\tilde{\beta}_{t}}=\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right); \quad \tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}} \cdot \beta_{t}$$
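In practice the two choices barely differ. As a quick numerical check (a sketch assuming the linear schedule used earlier; the DDPM paper reports that both choices of σ_t work similarly), the simplified $\sqrt{\beta_t}$ and the full $\sqrt{\tilde{\beta}_t}$ are nearly identical except at very small t:

import torch

beta = torch.linspace(1e-4, 0.02, 401)
abar = torch.cumprod(1 - beta, dim=0)
t = 200
tilde_beta = (1 - abar[t - 1]) / (1 - abar[t]) * beta[t]
print(beta[t].sqrt().item())     # simplified sigma_t, ~0.100
print(tilde_beta.sqrt().item())  # full posterior sigma_t, ~0.100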