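NanoDet is a lightweight object detection project open-sourced in October 2020 by RangiLyu of the Shanghai AI Laboratory. It delivers strong results for its size and has attracted wide attention. In December 2021 the author released an update, NanoDet-Plus, which raises mAP on COCO val by 7 percentage points.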

2. Auxiliary training module (Assign Guidance Module)

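The NanoDet-Plus detection head uses only 2 depthwise separable convolutions to keep the parameter count down, but this also limits its learning capacity, so learning to predict classes and perform label assignment from scratch is difficult for it. The author follows the approach of the WACV paper "LAD: Improving Object Detection by Label Assignment Distillation": in a teacher-student setup, an extra network is trained alongside the detector to guide the training of the NanoDet-Plus detection head, in the same spirit as knowledge distillation.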


The figure below shows the part of the overall NanoDet-Plus network that the auxiliary training module occupies:



    head:
      name: NanoDetPlusHead
      num_classes: 80
      input_channel: 96
      feat_channels: 96
      stacked_convs: 2
      kernel_size: 5
      strides: [8, 16, 32, 64]
      activation: LeakyReLU
      reg_max: 7
      norm_cfg:
        type: BN
    # Auxiliary head, only use in training time.
    aux_head:
      name: SimpleConvHead
      num_classes: 80
      input_channel: 192
      feat_channels: 192
      stacked_convs: 4
      strides: [8, 16, 32, 64]
      activation: LeakyReLU
      reg_max: 7

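As the config shows, the detection head `head` has feat_channels: 96 and stacked_convs: 2, while the auxiliary head `aux_head` has feat_channels: 192 and stacked_convs: 4. The detection head also uses depthwise separable convolutions, so it has far fewer parameters than the auxiliary head, and the auxiliary head has the greater learning capacity. The auxiliary branch is only active during training. During training, the features output by the backbone are fed to both the detection branch and the auxiliary branch. Because the auxiliary branch has more parameters, it learns more easily from its initial state how to separate positive from negative samples and perform label assignment. Both branches output predicted boxes and class scores of the same dimensions. Since the auxiliary branch learns faster and better, its predicted boxes can be used for label assignment, and the resulting assignment is then reused as the assignment for the detection branch's predictions when computing the training loss.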
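
To make the size gap concrete, the following back-of-the-envelope sketch counts convolution weights for one stacked-conv tower of each head. It assumes that the detection head's layers are 5x5 depthwise separable convolutions (kernel_size: 5 in the config) and that the auxiliary head uses standard 3x3 convolutions. The 3x3 kernel is an assumption, since the config does not list it. Biases, normalization layers and the final prediction layers are ignored, and the helper names are illustrative only.

```python
def depthwise_separable_params(channels_in, channels_out, kernel_size):
    """Weights of a depthwise k x k conv followed by a 1 x 1 pointwise conv."""
    depthwise = channels_in * kernel_size * kernel_size
    pointwise = channels_in * channels_out
    return depthwise + pointwise


def standard_conv_params(channels_in, channels_out, kernel_size):
    """Weights of a standard k x k convolution."""
    return channels_in * channels_out * kernel_size * kernel_size


# Detection head: 2 stacked depthwise separable convs, 96 channels, 5x5 kernel.
head_params = 2 * depthwise_separable_params(96, 96, 5)

# Auxiliary head: 4 stacked standard convs, 192 channels.
# The 3x3 kernel size is an assumption made for this estimate.
aux_params = 4 * standard_conv_params(192, 192, 3)

print(head_params)               # 23232
print(aux_params)                # 1327104
print(aux_params / head_params)  # roughly 57x more weights in the auxiliary tower
```

Under these assumptions a single auxiliary tower carries over fifty times as many convolution weights as the detection head's tower, which is consistent with the claim that the auxiliary head learns the assignment more easily.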


# in nanodet_plus.py
def forward_train(self, gt_meta):
    img = gt_meta["img"]
    feat = self.backbone(img)
    fpn_feat = self.fpn(feat)
    if self.epoch >= self.detach_epoch:
        # After detach_epoch, stop gradients from the auxiliary branch
        # from flowing back into the backbone and the main FPN.
        aux_fpn_feat = self.aux_fpn([f.detach() for f in feat])
        dual_fpn_feat = (
            torch.cat([f.detach(), aux_f], dim=1)
            for f, aux_f in zip(fpn_feat, aux_fpn_feat)
        )
    else:
        aux_fpn_feat = self.aux_fpn(feat)
        dual_fpn_feat = (
            torch.cat([f, aux_f], dim=1) for f, aux_f in zip(fpn_feat, aux_fpn_feat)
        )
    head_out = self.head(fpn_feat)
    aux_head_out = self.aux_head(dual_fpn_feat)
    # The auxiliary predictions are passed in so that they drive label assignment.
    loss, loss_states = self.head.loss(head_out, gt_meta, aux_preds=aux_head_out)
    return head_out, loss, loss_states


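NanoDet-Plus follows the distributed bounding box representation from Generalized Focal Loss. When regressing, at feature-map scale, the distance from a grid cell center to each side of the box, it discretizes the regression range into reg_max intervals and predicts the probability of the distance falling on each of 0, 1, ..., reg_max. As a result, in addition to the classification and box IoU losses, NanoDet-Plus has an extra DistributionFocalLoss. The classification loss is QualityFocalLoss, and the box loss is based on Generalized Intersection over Union (GIoU).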
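
As a small illustration of the box term, the sketch below computes GIoU for two axis-aligned boxes in plain Python. It is a minimal reference for the formula (IoU minus the fraction of the enclosing box not covered by the union), not the batched tensor implementation used in training. The function name and the (x1, y1, x2, y2) box format are choices made for this example.

```python
def giou(box_a, box_b):
    """Generalized IoU of two boxes given as (x1, y1, x2, y2)."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Intersection area (zero when the boxes do not overlap).
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - inter
    iou = inter / union

    # Smallest axis-aligned box enclosing both inputs.
    enclose = (max(ax2, bx2) - min(ax1, bx1)) * (max(ay2, by2) - min(ay1, by1))
    return iou - (enclose - union) / enclose


print(giou((0, 0, 2, 2), (0, 0, 2, 2)))  # 1.0
print(giou((0, 0, 1, 1), (2, 0, 3, 1)))  # about -0.333
```

A GIoU loss is typically taken as 1 - GIoU, so disjoint boxes still produce a useful gradient, unlike plain IoU, which is 0 for every non-overlapping pair.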

# in nanodet_plus_head.py
class NanoDetPlusHead:
    def _get_loss_from_assign(self, cls_preds, reg_preds, decoded_bboxes, assign):
        device = cls_preds.device
        labels, label_scores, bbox_targets, dist_targets, num_pos = assign
        # Average the number of positives across processes, at least 1.
        num_total_samples = max(
            reduce_mean(torch.tensor(sum(num_pos)).to(device)).item(), 1.0
        )

        labels = torch.cat(labels, dim=0)
        label_scores = torch.cat(label_scores, dim=0)
        bbox_targets = torch.cat(bbox_targets, dim=0)
        cls_preds = cls_preds.reshape(-1, self.num_classes)
        reg_preds = reg_preds.reshape(-1, 4 * (self.reg_max + 1))
        decoded_bboxes = decoded_bboxes.reshape(-1, 4)
        # Classification loss over all priors.
        loss_qfl = self.loss_qfl(
            cls_preds, (labels, label_scores), avg_factor=num_total_samples
        )

        # Indices of positive priors (label is a valid foreground class).
        pos_inds = torch.nonzero(
            (labels >= 0) & (labels < self.num_classes), as_tuple=False
        ).squeeze(1)

        if len(pos_inds) > 0:
            # Weight each positive by its highest predicted class score.
            weight_targets = cls_preds[pos_inds].detach().sigmoid().max(dim=1)[0]
            bbox_avg_factor = max(reduce_mean(weight_targets.sum()).item(), 1.0)

            loss_bbox = self.loss_bbox(
                decoded_bboxes[pos_inds],
                bbox_targets[pos_inds],
                weight=weight_targets,
                avg_factor=bbox_avg_factor,
            )

            dist_targets = torch.cat(dist_targets, dim=0)
            loss_dfl = self.loss_dfl(
                reg_preds[pos_inds].reshape(-1, self.reg_max + 1),
                dist_targets[pos_inds].reshape(-1),
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0 * bbox_avg_factor,
            )
        else:
            # No positives: keep the graph connected with zero-valued losses.
            loss_bbox = reg_preds.sum() * 0
            loss_dfl = reg_preds.sum() * 0

        loss = loss_qfl + loss_bbox + loss_dfl
        loss_states = dict(loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
        return loss, loss_states

    def loss(self, preds, gt_meta, aux_preds=None):
        """Compute losses.

        Args:
            preds (Tensor): Prediction output.
            gt_meta (dict): Ground truth information.
            aux_preds (tuple[Tensor], optional): Auxiliary head prediction output.

        Returns:
            loss (Tensor): Loss tensor.
            loss_states (dict): State dict of each loss.
        """
        gt_bboxes = gt_meta["gt_bboxes"]
        gt_labels = gt_meta["gt_labels"]
        device = preds.device
        batch_size = preds.shape[0]
        input_height, input_width = gt_meta["img"].shape[2:]
        # ceil is applied to the quotient for both height and width
        featmap_sizes = [
            (math.ceil(input_height / stride), math.ceil(input_width / stride))
            for stride in self.strides
        ]
        # get grid cells of one image
        mlvl_center_priors = [
            self.get_single_level_center_priors(
                batch_size,
                featmap_sizes[i],
                stride,
                dtype=torch.float32,
                device=device,
            )
            for i, stride in enumerate(self.strides)
        ]
        center_priors = torch.cat(mlvl_center_priors, dim=1)

        cls_preds, reg_preds = preds.split(
            [self.num_classes, 4 * (self.reg_max + 1)], dim=-1
        )
        # expected distances (in bins) scaled by each prior's stride
        dis_preds = self.distribution_project(reg_preds) * center_priors[..., 2, None]
        decoded_bboxes = distance2bbox(center_priors[..., :2], dis_preds)

        if aux_preds is not None:
            # use auxiliary head to assign
            aux_cls_preds, aux_reg_preds = aux_preds.split(
                [self.num_classes, 4 * (self.reg_max + 1)], dim=-1
            )
            aux_dis_preds = (
                self.distribution_project(aux_reg_preds) * center_priors[..., 2, None]
            )
            aux_decoded_bboxes = distance2bbox(center_priors[..., :2], aux_dis_preds)
            batch_assign_res = multi_apply(
                self.target_assign_single_img,
                aux_cls_preds.detach(),
                center_priors,
                aux_decoded_bboxes.detach(),
                gt_bboxes,
                gt_labels,
            )
        else:
            # use self prediction to assign
            batch_assign_res = multi_apply(
                self.target_assign_single_img,
                cls_preds.detach(),
                center_priors,
                decoded_bboxes.detach(),
                gt_bboxes,
                gt_labels,
            )

        loss, loss_states = self._get_loss_from_assign(
            cls_preds, reg_preds, decoded_bboxes, batch_assign_res
        )

        if aux_preds is not None:
            # the auxiliary head is trained with the same assignment
            aux_loss, aux_loss_states = self._get_loss_from_assign(
                aux_cls_preds, aux_reg_preds, aux_decoded_bboxes, batch_assign_res
            )
            loss = loss + aux_loss
            for k, v in aux_loss_states.items():
                loss_states["aux_" + k] = v
        return loss, loss_states


class DistributionFocalLoss(nn.Module):
    def forward(
        self, pred, target, weight=None, avg_factor=None, reduction_override=None
    ):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted general distribution of bounding
                boxes (before softmax) with shape (N, n+1), n is the max value
                of the integral set `{0, ..., n}` in paper.
            target (torch.Tensor): Target distance label for bounding boxes
                with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override if reduction_override else self.reduction
        loss_cls = self.loss_weight * distribution_focal_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor
        )
        return loss_cls

# weighted_loss adds the weight, reduction and avg_factor arguments used above
@weighted_loss
def distribution_focal_loss(pred, label):
    r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection`.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (torch.Tensor): Target distance label for bounding boxes with
            shape (N,).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    # the two integer bins on either side of the continuous target
    dis_left = label.long()
    dis_right = dis_left + 1
    # linear interpolation weights: the closer bin gets the larger weight
    weight_left = dis_right.float() - label
    weight_right = label - dis_left.float()
    loss = (
        F.cross_entropy(pred, dis_left, reduction="none") * weight_left
        + F.cross_entropy(pred, dis_right, reduction="none") * weight_right
    )
    return loss
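
To see what this computes for a single box side, here is a plain-Python restatement of the same formula. It is only a worked example: the helper names `cross_entropy` and `dfl_single` are made up for this sketch, and it handles one distribution at a time instead of a batch of tensors.

```python
import math


def cross_entropy(logits, index):
    """Cross-entropy of one softmax distribution against a class index."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[index]


def dfl_single(logits, target):
    """DFL for one box side; target is a float in [0, reg_max)."""
    left = int(target)   # nearest bin below the target
    right = left + 1     # nearest bin above the target
    weight_left = right - target
    weight_right = target - left
    return (cross_entropy(logits, left) * weight_left
            + cross_entropy(logits, right) * weight_right)


reg_max = 7
uniform = [0.0] * (reg_max + 1)
print(dfl_single(uniform, 2.3))  # log(8), about 2.079

# Logits that put most of the mass on bins 2 and 3 give a smaller loss.
peaked = [0.0, 0.0, 6.0, 5.0, 0.0, 0.0, 0.0, 0.0]
print(dfl_single(peaked, 2.3) < dfl_single(uniform, 2.3))  # True
```

For a target of 2.3 the loss weights bin 2 by 0.7 and bin 3 by 0.3, so the distribution is pushed to concentrate on the two bins that bracket the true distance.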


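NanoDet-Plus uses DynamicSoftLabelAssigner (DSLA), which follows the SimOTA algorithm from YOLOX, for label assignment. SimOTA is a dynamic label assignment algorithm built on dynamic k: it first computes a cost matrix and then treats assignment as a task-allocation problem, as in YOLOX.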
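
The dynamic-k step can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not the DynamicSoftLabelAssigner code: the cost and IoU matrices are taken as given (the construction of the cost is not covered here), the function name and `topk` default are chosen for the example, and any pre-filtering of candidate priors is skipped.

```python
def dynamic_k_matching(cost, ious, topk=10):
    """Simplified dynamic-k assignment.

    cost[p][g] and ious[p][g] are given for prior p and ground-truth box g.
    Returns a dict mapping each assigned prior index to its ground-truth index.
    """
    num_priors = len(cost)
    num_gts = len(cost[0])
    candidates = {}  # prior index -> list of gt indices that selected it

    for g in range(num_gts):
        # dynamic k: truncated sum of the top IoUs for this gt, at least 1
        top_ious = sorted((ious[p][g] for p in range(num_priors)), reverse=True)
        k = max(1, int(sum(top_ious[:topk])))
        # pick the k priors with the lowest cost for this gt
        ranked = sorted(range(num_priors), key=lambda p: cost[p][g])
        for p in ranked[:k]:
            candidates.setdefault(p, []).append(g)

    # a prior claimed by several gts goes to the one with the lowest cost
    return {p: min(gts, key=lambda g: cost[p][g]) for p, gts in candidates.items()}


ious = [[0.9, 0.0], [0.8, 0.6], [0.4, 0.9], [0.0, 0.7]]
cost = [[0.1, 5.0], [0.3, 0.5], [3.0, 0.2], [6.0, 0.9]]
print(dynamic_k_matching(cost, ious))  # {0: 0, 1: 0, 2: 1}
```

In this toy case both ground-truth boxes get k = 2. Prior 1 is selected by both and is resolved to box 0, where its cost is lower, while prior 3 is selected by neither and stays a negative sample.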


Although the NanoDet-Plus author concatenates the model's final outputs into a single tensor, the figure below shows that NanoDet-Plus has four output heads, with strides [8, 16, 32, 64].


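The four output feature maps in the figure above have shapes [1, 33, 80, 80], [1, 33, 40, 40], [1, 33, 20, 20] and [1, 33, 10, 10], which correspond to [batch_size, num_class+4*(reg_max+1), feature_map_height, feature_map_width]. batch_size, num_class and feature_map_{height,width} are self-explanatory, but reg_max is a newly introduced hyperparameter that deserves an explanation.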
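
The shape arithmetic can be checked with a short sketch. The quoted shapes are consistent with a 640x640 input, a single class and reg_max = 7 (33 = 1 + 4 * 8); these three values are inferred from the numbers above rather than stated in the text, and the function name is illustrative.

```python
import math


def head_output_shapes(input_size, num_classes, reg_max, strides, batch_size=1):
    """Per-level output shapes [batch, channels, height, width] for a square input."""
    channels = num_classes + 4 * (reg_max + 1)
    return [
        (batch_size, channels, math.ceil(input_size / s), math.ceil(input_size / s))
        for s in strides
    ]


# Assumed values: 640x640 input, 1 class, reg_max = 7.
shapes = head_output_shapes(640, num_classes=1, reg_max=7, strides=[8, 16, 32, 64])
print(shapes)
# [(1, 33, 80, 80), (1, 33, 40, 40), (1, 33, 20, 20), (1, 33, 10, 10)]

# Number of grid cells once the four levels are concatenated.
print(sum(h * w for _, _, h, w in shapes))  # 8500
```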



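The center points are obtained with meshgrid(range(feature_width), range(feature_height))*stride, which maps feature-map grid points to the scale of the input image. For the (left, top, right, bottom) predictions, the author uses the discretized regression method proposed in Generalized Focal Loss (GFL).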
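
A minimal sketch of that grid construction, in plain Python instead of tensor operations, looks like this. It follows the formula in the text literally (grid index times stride, with no half-cell offset) and stores the stride next to each point. The function name and the (x, y, stride) tuple layout are assumptions made for the example.

```python
def center_priors(feat_height, feat_width, stride):
    """Grid cell positions mapped to input-image scale, as (x, y, stride)."""
    return [
        (col * stride, row * stride, stride)
        for row in range(feat_height)
        for col in range(feat_width)
    ]


priors = center_priors(2, 3, stride=8)
print(priors)
# [(0, 0, 8), (8, 0, 8), (16, 0, 8), (0, 8, 8), (8, 8, 8), (16, 8, 8)]
```

Keeping the stride with each point is convenient because the predicted distances are expressed in feature-map units and have to be multiplied by the stride before decoding.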

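Generalized Focal Loss (GFL) was proposed by Li Xiang of Nankai University in a paper published in June 2020. The method discretizes the range of the box regression and uses the discrete values in range(0, reg_max+1) as regression targets, where reg_max is the maximum regression range.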


class Integral(nn.Module):
    """A fixed layer for calculating integral result from distribution.
    This layer calculates the target location by :math: `sum{P(y_i) * y_i}`,
    P(y_i) denotes the softmax vector that represents the discrete distribution
    y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max}

    Args:
        reg_max (int): The maximal value of the discrete set. Default: 16. You
            may want to reset it according to your new dataset or related
            settings.
    """

    def __init__(self, reg_max=16):
        super(Integral, self).__init__()
        self.reg_max = reg_max
        # fixed projection vector [0, 1, ..., reg_max]; a buffer, not a parameter
        self.register_buffer(
            "project", torch.linspace(0, self.reg_max, self.reg_max + 1)
        )

    def forward(self, x):
        """Forward feature from the regression head to get integral result of
        bounding box location.

        Args:
            x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
                n is self.reg_max.

        Returns:
            x (Tensor): Integral result of box locations, i.e., distance
                offsets from the box center in four directions, shape (N, 4).
        """
        shape = x.size()
        # softmax over the reg_max + 1 bins of each of the four sides
        x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1)
        # expectation: dot product of the probabilities with [0, ..., reg_max]
        x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4)
        return x
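
The same computation for a single prediction can be followed by hand with the plain-Python sketch below: softmax over the bins of each side, the expected bin index, scaling by the stride, and finally conversion to box corners. The helper names and the example numbers are illustrative and do not come from the NanoDet code.

```python
import math


def softmax(values):
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def integral(logits):
    """Expected bin index sum(P(y_i) * y_i) for one box side."""
    return sum(i * p for i, p in enumerate(softmax(logits)))


def decode_box(center, side_logits, stride):
    """Turn four per-side distributions into (x1, y1, x2, y2) in pixels."""
    cx, cy = center
    left, top, right, bottom = (integral(side) * stride for side in side_logits)
    return (cx - left, cy - top, cx + right, cy + bottom)


reg_max = 7
uniform = [0.0] * (reg_max + 1)  # expectation is 3.5 bins
box = decode_box((100, 100), [uniform] * 4, stride=8)
print(box)  # (72.0, 72.0, 128.0, 128.0)
```

With reg_max = 7 the largest distance a side can express is 7 * stride, so on the stride-8 level a box can extend at most 56 pixels from its center point in each direction.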

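Apart from the Integral step, the rest is conventional post-processing: distance2bbox followed by multiclass_nms. One more point is that the author computes classification scores with a sigmoid function, so a single detection box may be assigned more than one label. Intuitively, NanoDet-Plus should therefore cope relatively well with occlusion between objects of different classes. For details, see the following part of the multiclass_nms function in nanodet/model/module/nms.py: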

def multiclass_nms(
    multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=-1, score_factors=None
):
    num_classes = multi_scores.size(1) - 1
    # exclude background category
    if multi_bboxes.shape[1] > 4:
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
    else:
        bboxes = multi_bboxes[:, None].expand(multi_scores.size(0), num_classes, 4)
    scores = multi_scores[:, :-1]

    # filter out boxes with low scores
    valid_mask = scores > score_thr

    # We use masked_select for ONNX exporting purpose,
    # which is equivalent to bboxes = bboxes[valid_mask]
    # we have to use this ugly code
    bboxes = torch.masked_select(
        bboxes, torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), -1)
    ).view(-1, 4)
    if score_factors is not None:
        scores = scores * score_factors[:, None]
    scores = torch.masked_select(scores, valid_mask)
    # column index of each surviving (box, class) pair is its class label
    labels = valid_mask.nonzero(as_tuple=False)[:, 1]
    # ... (the rest of the function is omitted here)
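
The multi-label behaviour comes from thresholding each class score independently. The toy example below shows the idea for one box in plain Python; the logits, the threshold and the function name are invented for the illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def labels_above_threshold(class_logits, score_thr):
    """Return (class_index, score) for every class whose sigmoid score passes."""
    scores = [sigmoid(v) for v in class_logits]
    return [(i, s) for i, s in enumerate(scores) if s > score_thr]


# One box with strong evidence for classes 0 and 2 at the same time.
kept = labels_above_threshold([2.0, -3.0, 1.5], score_thr=0.5)
print([i for i, _ in kept])  # [0, 2]
```

With a softmax the class scores would compete and sum to 1, whereas sigmoid scores are independent, so the same box can pass the threshold for several classes and enter NMS once per class.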




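In fact, some of the feature maps can be produced entirely by a single linear transformation $\Phi_{i}$. Part of a convolution layer's output channels therefore comes from the convolution itself and part from linear transformations of the convolution result, and the two parts are concatenated to give the final output.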
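
A rough weight count shows why this saves parameters. The sketch below assumes that half of the output channels come from an ordinary 1x1 convolution and that each remaining channel is produced from one of those by a 3x3 depthwise filter acting as the cheap linear transformation. The split ratio, both kernel sizes and the function names are assumptions for this example, and biases and normalization layers are ignored.

```python
def ghost_module_params(channels_in, channels_out, kernel_size=1, cheap_kernel=3, ratio=2):
    """Weights of a Ghost-style module: a primary conv plus cheap depthwise ops."""
    primary_channels = channels_out // ratio
    ghost_channels = channels_out - primary_channels
    primary = channels_in * primary_channels * kernel_size * kernel_size
    # With ratio=2, each ghost channel comes from one primary channel
    # through its own depthwise filter.
    cheap = ghost_channels * cheap_kernel * cheap_kernel
    return primary + cheap


def standard_conv_params(channels_in, channels_out, kernel_size=1):
    """Weights of a standard k x k convolution."""
    return channels_in * channels_out * kernel_size * kernel_size


print(standard_conv_params(96, 96))  # 9216
print(ghost_module_params(96, 96))   # 5040
```

Under these assumptions the module needs a little over half the weights of the plain convolution while still producing the same number of output channels.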




