文章目录
- 1 解码是什么意思
- 2 代码解读
- 3 生成网格中心 代码详解
- 4 按照网格格式生成先验框的宽高 代码详解
- 5 感谢链接
1 解码是什么意思
在利用YOLOv3网络结构提取到out0、out1、out2之后,不同尺度下每个网格点上均有先验框,网络训练过程会对先验框的参数进行调整,继而得到预测框;再将不同尺度下的预测框还原到原始输入图像上,同时得到该框内目标的预测结果(预测框位置、类别概率、置信度分数),这个过程称之为解码。
2 代码解读
注释主要以VOC数据集,YOLOv3 net最后一层输出进行解读。
import torch
import numpy as np
class DecodeBox():
    """Decode raw YOLOv3 head outputs (one tensor per scale) into boxes.

    For every grid cell of every scale the head predicts, per anchor, an
    (x, y) offset, a (w, h) log-scale factor, an objectness score and one
    score per class.  ``decode_box`` turns those raw values into box
    centers/sizes normalized by the feature-map size, together with the
    sigmoid-activated confidence and class probabilities.
    """

    def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
        super(DecodeBox, self).__init__()
        self.anchors = anchors              # (9, 2) anchor (w, h) pairs in input-image pixels
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes   # x, y, w, h, objectness + per-class scores
        self.input_shape = input_shape      # (input_h, input_w) fed to the network
        self.anchors_mask = anchors_mask    # which anchor rows belong to each scale

    def decode_box(self, inputs):
        """Decode each scale's raw head tensor.

        Parameters
        ----------
        inputs : sequence of tensors, each shaped (B, A*(5+C), H, W)
            where A = anchors per scale and C = number of classes.

        Returns
        -------
        list of tensors shaped (B, A*H*W, 5+C); box coordinates are
        normalized by the feature-map width/height.
        """
        outputs = []
        for scale_idx, feat in enumerate(inputs):
            batch_size = feat.size(0)
            feat_h = feat.size(2)
            feat_w = feat.size(3)

            # How many input-image pixels one feature-map cell covers.
            stride_h = self.input_shape[0] / feat_h
            stride_w = self.input_shape[1] / feat_w

            # Express this scale's anchors in feature-map units.
            scaled_anchors = [(aw / stride_w, ah / stride_h)
                              for aw, ah in self.anchors[self.anchors_mask[scale_idx]]]

            num_anchors = len(self.anchors_mask[scale_idx])
            # (B, A*(5+C), H, W) -> (B, A, H, W, 5+C)
            prediction = feat.view(batch_size, num_anchors, self.bbox_attrs,
                                   feat_h, feat_w).permute(0, 1, 3, 4, 2).contiguous()

            # Offsets and scores pass through a sigmoid; raw width/height
            # stay untouched because they feed an exp() below.
            x = torch.sigmoid(prediction[..., 0])
            y = torch.sigmoid(prediction[..., 1])
            w = prediction[..., 2]
            h = prediction[..., 3]
            conf = torch.sigmoid(prediction[..., 4])
            pred_cls = torch.sigmoid(prediction[..., 5:])

            FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
            LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

            # One (grid_x, grid_y) pair per cell, broadcast over batch and anchors.
            grid_x = torch.linspace(0, feat_w - 1, feat_w).repeat(feat_h, 1).repeat(
                batch_size * num_anchors, 1, 1).view(x.shape).type(FloatTensor)
            grid_y = torch.linspace(0, feat_h - 1, feat_h).repeat(feat_w, 1).t().repeat(
                batch_size * num_anchors, 1, 1).view(y.shape).type(FloatTensor)

            # Tile the anchor widths/heights into the same (B, A, H, W) layout.
            anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
            anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
            anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, feat_h * feat_w).view(w.shape)
            anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, feat_h * feat_w).view(h.shape)

            # center = cell top-left corner + sigmoid offset;
            # size   = anchor size * exp(raw prediction).
            pred_boxes = FloatTensor(prediction[..., :4].shape)
            pred_boxes[..., 0] = x.data + grid_x
            pred_boxes[..., 1] = y.data + grid_y
            pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
            pred_boxes[..., 3] = torch.exp(h.data) * anchor_h

            # Normalize boxes by the feature-map size, then concatenate
            # boxes, objectness and class scores into one tensor per scale.
            normalizer = torch.Tensor([feat_w, feat_h, feat_w, feat_h]).type(FloatTensor)
            output = torch.cat((pred_boxes.view(batch_size, -1, 4) / normalizer,
                                conf.view(batch_size, -1, 1),
                                pred_cls.view(batch_size, -1, self.num_classes)), -1)
            outputs.append(output.data)
        return outputs
if __name__ == '__main__':
    # Demo: decode the raw outputs of a YOLOv3 network trained on VOC
    # (20 classes, 416x416 input, the standard 9 COCO anchors in 3 groups).
    anchors = [10.0, 13.0, 16.0, 30.0, 33.0, 23.0, 30.0, 61.0, 62.0, 45.0, 59.0, 119.0, 116.0, 90.0, 156.0, 198.0, 373.0, 326.0]
    # Reshape the flat list into (9, 2) (width, height) pairs.
    anchors = np.array(anchors).reshape(-1,2)
    num_classes = 20
    # Largest anchors (indices 6-8) go to the coarsest 13x13 scale.
    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    input_shape = [416,416]
    bbox_util = DecodeBox(anchors, num_classes, (input_shape[0], input_shape[1]), anchors_mask)
    # NOTE(review): `YoloBody` and `images` are not defined in this file —
    # presumably they come from the surrounding project (the YOLOv3 model
    # class and a batch of preprocessed input images); confirm before running.
    net = YoloBody(anchors_mask, num_classes)
    outputs = net(images)
    outputs = bbox_util.decode_box(outputs)
3 生成网格中心 代码详解
先验框中心=网格左上角,下面这行代码到底如何理解呢?
grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
以宽为5,高为5, batch_size为1为例,详细解读见下方代码及输出。
import torch
def make_grid_x(input_width, input_height, batch_size, num_anchors=3):
    """Build the grid of cell x-coordinates exactly as DecodeBox does.

    Returns a float tensor of shape
    (batch_size, num_anchors, input_height, input_width) in which every
    row is 0, 1, ..., input_width-1: the x-coordinate of each grid
    cell's top-left corner, replicated per image and per anchor.
    """
    row = torch.linspace(0, input_width - 1, input_width)      # (W,)
    plane = row.repeat(input_height, 1)                        # (H, W)
    stacked = plane.repeat(batch_size * num_anchors, 1, 1)     # (B*A, H, W)
    return stacked.view(batch_size, num_anchors, input_height,
                        input_width).type(torch.FloatTensor)


if __name__ == "__main__":
    # Step-by-step walkthrough with W = H = 5 and batch_size = 1.
    input_width = 5
    input_height = 5
    batch_size = 1
    anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]

    a = torch.linspace(0, input_width - 1, input_width)
    print(a)
    # tensor([0., 1., 2., 3., 4.])

    b = a.repeat(input_height, 1)
    print(b)
    # (5, 5): the row above stacked 5 times — one x-coordinate per column.

    c = b.repeat(batch_size * 3, 1, 1)
    print(c)
    # (3, 5, 5): one 5x5 plane per anchor of this scale.

    d = c.view(batch_size, 3, input_height, input_width)
    print(d)
    # (1, 3, 5, 5): final layout matching the prediction tensor.

    # BUGFIX: the original snippet did `d.type(FloatTensor)` but
    # `FloatTensor` was never defined here (NameError at runtime);
    # use torch.FloatTensor directly.
    e = d.type(torch.FloatTensor)
4 按照网格格式生成先验框的宽高 代码详解
按照网格格式生成先验框的宽高,其代码如下:
anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)
对于上面这四行代码,我们以最小特征层为例,详细理解:
import torch
def tile_anchor_sizes(scaled_anchors, batch_size, input_height, input_width):
    """Tile per-anchor (w, h) pairs into (B, A, H, W) grids, as DecodeBox does.

    Parameters
    ----------
    scaled_anchors : list of (width, height) tuples, already divided by
        the scale's stride (i.e. in feature-map units).
    batch_size, input_height, input_width : target layout dimensions.

    Returns
    -------
    (anchor_w, anchor_h) : two float tensors, each of shape
        (batch_size, len(scaled_anchors), input_height, input_width),
        holding the anchor width / height for every cell of every image.
    """
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    table = FloatTensor(scaled_anchors)           # (A, 2) anchor table
    num_anchors = len(scaled_anchors)
    tiled = []
    for col in (0, 1):                            # column 0 = widths, 1 = heights
        v = table.index_select(1, LongTensor([col]))      # (A, 1)
        v = v.repeat(batch_size, 1)                       # (B*A, 1): one copy per image
        v = v.repeat(1, 1, input_height * input_width)    # broadcast over every cell
        tiled.append(v.view(batch_size, num_anchors, input_height, input_width))
    return tiled[0], tiled[1]


if __name__ == "__main__":
    # Anchors of the smallest (13x13) feature map, already divided by stride 32.
    scaled_anchors = [(3.625,2.8125), (4.875,6.1875), (11.65625, 10.1875)]
    x_is_cuda = False
    FloatTensor = torch.cuda.FloatTensor if x_is_cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if x_is_cuda else torch.LongTensor

    a = LongTensor([0])
    print(a)
    b = FloatTensor(scaled_anchors)
    print(b)
    # tensor([[ 3.6250,  2.8125],
    #         [ 4.8750,  6.1875],
    #         [11.6562, 10.1875]])

    anchor_w = b.index_select(1, a)
    print(anchor_w)
    # column 0 — the three anchor widths: 3.6250 / 4.8750 / 11.6562

    anchor_h = b.index_select(1, LongTensor([1]))
    # BUGFIX: the original documented anchor_h's expected output but never
    # printed it; print it so the demo matches its own annotation.
    print(anchor_h)
    # column 1 — the three anchor heights: 2.8125 / 6.1875 / 10.1875

    batch_size = 1
    input_height = 13
    input_width = 13

    c = anchor_w.repeat(batch_size, 1)
    print(c)
    # Same three widths; with batch_size = 2 the column would repeat once
    # per image — every image gets its own copy of the anchor widths.

    d = c.repeat(1, 1, input_height * input_width)
    print(d.shape)
    anchor_w = d.view(1,3,13,13)
    print(anchor_w.shape)
5 感谢链接
https://www.bilibili.com/video/BV1Hp4y1y788?p=6&spm_id_from=pageDriver
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)