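The previous post, Faster RCNN Code Explained (Part 3): Overall Structure of Data Processing, covered how data processing is organized as a whole. This post turns to one detail of that processing: the complete story of the anchor, from how it is created to how it is labeled. The code lives in the assign_anchor function of ~/mx-rcnn/rcnn/io/rpn.py.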
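This part matters if you want to understand Faster RCNN in depth, because anchors are one of the core ideas of the algorithm. Specifically, this post answers the following questions: What is an anchor? How are anchors generated? How are anchor labels defined? How are the bbox (bounding box) regression targets defined? And how does a bbox differ from an anchor?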
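To make the first two questions concrete before reading the code, here is a minimal NumPy sketch of anchor generation. It is modeled on the generate_anchors routine from the original py-faster-rcnn code, which mx-rcnn's generate_anchors is assumed (not verified here) to follow; the names generate_anchors_sketch and tile_anchors are hypothetical. The sketch first builds the 9 base anchors that share one center, then places a copy of them at every feature-map position, feat_stride pixels apart in the input image, so that each RPN sliding-window position gets its own 9 anchors.

```python
import numpy as np


def generate_anchors_sketch(base_size=16, ratios=(0.5, 1, 2), scales=(8, 16, 32)):
    """Hypothetical re-implementation: return a (len(ratios) * len(scales), 4)
    float32 array of (x1, y1, x2, y2) anchors that all share one center.

    Uses the inclusive-pixel convention, i.e. width = x2 - x1 + 1.
    """
    # Reference box (0, 0, base_size - 1, base_size - 1): its center and area.
    ctr = 0.5 * (base_size - 1)
    area = float(base_size * base_size)
    anchors = []
    for r in ratios:  # r is the aspect ratio h / w
        # Change the aspect ratio while keeping the area roughly constant.
        w = np.round(np.sqrt(area / r))
        h = np.round(w * r)
        for s in scales:
            # Enlarge the box by the scale factor around the same center.
            ws, hs = w * s, h * s
            anchors.append([ctr - 0.5 * (ws - 1), ctr - 0.5 * (hs - 1),
                            ctr + 0.5 * (ws - 1), ctr + 0.5 * (hs - 1)])
    return np.array(anchors, dtype=np.float32)


def tile_anchors(base_anchors, feat_height, feat_width, feat_stride=16):
    """Hypothetical helper: place a copy of the base anchors at every
    feature-map position, shifted by feat_stride pixels in the input image.
    Returns a (feat_height * feat_width * A, 4) array."""
    shift_x = np.arange(feat_width) * feat_stride
    shift_y = np.arange(feat_height) * feat_stride
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)  # both (feat_height, feat_width)
    shifts = np.stack([shift_x.ravel(), shift_y.ravel(),
                       shift_x.ravel(), shift_y.ravel()], axis=1)  # (K, 4)
    # Broadcast (K, 1, 4) + (1, A, 4) -> (K, A, 4): A anchors per position.
    all_anchors = shifts[:, np.newaxis, :] + base_anchors[np.newaxis, :, :]
    return all_anchors.reshape(-1, 4)


if __name__ == "__main__":
    base = generate_anchors_sketch()
    print(base)  # 9 x 4: one row per (ratio, scale) pair
    # A 38 x 57 feature map, roughly what a 600 x 900 input gives at stride 16.
    all_anchors = tile_anchors(base, feat_height=38, feat_width=57)
    print(all_anchors.shape)  # (19494, 4), i.e. 38 * 57 positions x 9 anchors
```

For ratio 1 the three base anchors are squares of side 128, 256 and 512 pixels (16 x 8, 16 x 16, 16 x 32); the other two ratios keep roughly the same areas while changing the aspect ratio.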
def assign_anchor(feat_shape, gt_boxes, im_info, feat_stride=16,
                  scales=(8, 16, 32), ratios=(0.5, 1, 2), allowed_border=0):
    """
    assign ground truth boxes to anchor positions
    :param feat_shape: infer output shape
    :param gt_boxes: assign ground truth
    :param im_info: filter out anchors overlapped with edges
    :param feat_stride: anchor position step
    :param scales: used to generate anchors, affects num_anchors (per location)
    :param ratios: aspect ratios of generated anchors
    :param allowed_border: filter out anchors with edge overlap > allowed_border
    :return: dict of label
        'label': of shape (batch_size, 1) <- (batch_size, num_anchors, feat_height, feat_width)
        'bbox_target': of shape (batch_size, num_anchors * 4, feat_height, feat_width)
        'bbox_inside_weight': *todo* mark the assigned anchors
        'bbox_outside_weight': used to normalize the bbox_loss, all weights sum to RPN_POSITIVE_WEIGHT
    """
    def _unmap(data, count, inds, fill=0):
        """ unmap a subset inds of data into original data of size count """
        if len(data.shape) == 1:
            ret = np.empty((count,), dtype=np.float32)
            ret.fill(fill)
            ret[inds] = data
        else:
            ret = np.empty((count,) + data.shape[1:], dtype=np.float32)
            ret.fill(fill)
            ret[inds, :] = data
        return ret

    im_info = im_info[0]
    scales = np.array(scales, dtype=np.float32)
    # base_anchors holds the initial anchors. The argument base_size=16 is the
    # downsampling factor from the input image to this feature map; for the
    # conv4_x stage of ResNet this factor is 16. ratios defaults to [0.5, 1, 2]
    # and scales defaults to [8, 16, 32], so base_anchors is by default a 9x4
    # numpy array: 9 anchors, each described by 4 coordinates, the top-left
    # corner (x1, y1) and the bottom-right corner (x2, y2). All 9 anchors share
    # the same center. This mirrors the sliding window of the RPN (its first 3x3
    # conv layer): at every 3x3 window position, 9 anchors are generated around
    # the center of that position, mapped back to input-image coordinates.
    base_anchors = generate_anchors(base_size=feat_stride, ratios=list(ratios), scales=scales)
    num_anchors = base_anchors.shape[0]
    # feat_height and feat_width are the size of this feature map. For ResNet's
    # res4 (conv4_x) stage the downsampling factor is 16, so a 600x900 input gives
    # roughly feat_height = 600/16 and feat_width = 900/16 (the exact integer
    # depends on how the network's strided layers round).
    feat_height, feat_width = feat_shape[-2:]

    logger.debug('anchors: %s' % base_anchors)
    logger.debug('anchor shapes: %s' % np.hstack((base_anchors[:, 2::4] - base_anchors[:, 0::4],
                                                 base_anchors[:, 3::