AntiAliasInterpolation2d代码解读
注记
最近在阅读一些视频驱动相关的代码时,经常见到一种特殊的下采样方法:先用高斯核做平滑(抗锯齿),再按固定步长抽取像素,故在这里记录一下。
class AntiAliasInterpolation2d(nn.Module):
    """Band-limited downsampling: Gaussian blur followed by strided slicing.

    A fixed (non-trainable) Gaussian kernel is applied as a depthwise
    convolution, one copy of the kernel per channel, and the blurred result
    is subsampled by taking every ``int(1 / scale)``-th pixel along both
    spatial axes.

    Args:
        channels: Number of channels of the input; also used as the number
            of convolution groups.
        scale: Downsampling ratio, expected to be ``<= 1`` (e.g. ``0.5``).
            When ``scale == 1.0`` ``forward`` returns its input unchanged.
    """

    def __init__(self, channels: int, scale: float) -> None:
        super().__init__()

        # Worked example for scale = 0.5 is given in the trailing comments.
        sigma = (1 / scale - 1) / 2                # 0.5
        kernel_size = 2 * round(sigma * 4) + 1     # 5 (always odd)

        # Left/top and right/bottom padding so the convolution keeps H and W.
        self.ka = kernel_size // 2                 # 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka  # 2

        kernel_size = [kernel_size, kernel_size]   # [5, 5]
        sigma = [sigma, sigma]                     # [0.5, 0.5]

        # The 2-D Gaussian is built as the product of two 1-D Gaussians,
        # one per axis, each evaluated on its own coordinate grid.
        kernel = 1
        meshgrids = torch.meshgrid(                # two grids, each [5, 5]
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2                  # 2, the kernel centre
            # Unnormalised Gaussian along this axis. For scale == 1 sigma is
            # 0 and this is 0 / 0; forward() never uses the kernel then.
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Normalise so the kernel entries sum to 1.
        kernel = kernel / torch.sum(kernel)        # [5, 5]

        # Reshape to a depthwise conv weight: [channels, 1, kH, kW].
        # The second dimension is in_channels / groups, which is 1 here.
        kernel = kernel.view(1, 1, *kernel.size())                    # [1, 1, 5, 5]
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))   # [channels, 1, 5, 5]

        # Registered as a buffer: stored with the module but not a parameter,
        # so it is not updated by training.
        self.register_buffer("weight", kernel)
        self.groups = channels                     # groups == channels -> depthwise
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)        # stride used for subsampling

    def forward(self, input):
        """Blur ``input`` with the Gaussian kernel and subsample it.

        Args:
            input: Tensor indexed as ``[batch, channels, H, W]`` (the code
                slices the last two axes and convolves with ``groups``
                equal to ``channels``).

        Returns:
            ``input`` itself when ``scale == 1.0``; otherwise the blurred
            tensor sampled every ``int_inv_scale`` pixels along H and W.
        """
        if self.scale == 1.0:
            return input

        # Pad (left, right, top, bottom) so the convolution preserves H and W.
        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        # Depthwise Gaussian blur.
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        # Subsample with stride int(1 / scale) along both spatial axes.
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
        return out