SOLO代码阅读解析

2023-10-26

SOLO是一种直接预测instance mask的范式,摒弃了之前top-down和bottom-up两种主流的实例分割方法,从而pipeline更加简洁直观。这篇文章以官方代码中的demo为例,简单梳理一下SOLO在inference时的流程。整个代码基于mmdetection。

首先是demo.inference_demo.py

config_file = '../configs/solo/decoupled_solo_r50_fpn_8gpu_3x.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = '../checkpoints/DECOUPLED_SOLO_R50_3x.pth'

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# test a single image
img = 'demo.jpg'
result = inference_detector(model, img)

show_result_ins(img, result, model.CLASSES, score_thr=0.25, out_file="demo_out.jpg")

上述代码很简单,init_detector创建model,inference_detector做正向inference,并且show出最后的result。核心在于init_detector和inference_detector。这两个function存在于mmdet.apis中,下面看下这个模块:

mmdet.apis.inferece.py

def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_detector(model, img):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        If imgs is a str, a generator will be returned, otherwise return the
        detection results directly.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result

对于init_detector,其核心函数是build_detector,根据config文件信息创建模型,并将checkpoint加载进来;而inference_detector更简单了,首先做一系列augmentation,然后调用model做inference即可。

那么接下来仍然是两个分支,build_detector是如何创建模型的,以及该模型如何做inference,分开来说。

build_detector

build_detector方法存在于mmdet.model.builder.py

from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_backbone(cfg):
    return build(cfg, BACKBONES)


def build_neck(cfg):
    return build(cfg, NECKS)


def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)


def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)


def build_head(cfg):
    return build(cfg, HEADS)


def build_loss(cfg):
    return build(cfg, LOSSES)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

build_detector方法又调用了build方法,而build方法中调用了build_from_cfg。注意:在调用build方法中传入了DETECTORS这个注册器(Registry,一个类,传入的参数该class的一个实例,每一个部分i.e. backbone,FPN etc. 都对应一个Registry实例),可以先理解为创建这些module以及分开进行管理。

接着看mmdet.utils.registry.py中的build_from_cfg:

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)

这里其实就是对注册器进行注册的部分,也就是说通过config中的字典来对模型进行搭建。obj_cls就是要创建的module,如SOLO,ResNet,FPN等等,只有某个注册器中有配置文件中存在的type时,才会对该注册器进行register,通过args中的dict得到相应的module。这里一开始obj_cls返回的是SOLO(可以refer下配置文件),所以我们要找到SOLO这个模型的文件:

mmdet.models.detectors.solo.py

@DETECTORS.register_module
class SOLO(SingleStageInsDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SOLO, self).__init__(backbone, neck, bbox_head, None, train_cfg,
                                   test_cfg, pretrained)

可见第一行用了一个装饰器,也就是说在创建SOLO实例的时候,首先就自动调用装饰器中的方法,并且把SOLO这个类作为参数,注册到注册器DETECTORS里面。而SOLO又是继承自SingleStageInsDetector,所以接下来重点是SingleStageInsDetector类:

mmdet.models.detectors.single_stage_ins.py

@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_feat_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageInsDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        if mask_feat_head is not None:
            self.mask_feat_head = builder.build_head(mask_feat_head)

        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

上面是SingleStageInsDetector的核心代码,之前是将args作为参数传入作为这里的初始化。根据之前的config,依次创建模型的backbone,neck,bbox_head以及test_config(这里是inference),这些部分的创建又对应到builder中的函数,每一个module对应一个Registry,然后根据相应的config文件中的参数建立不同的module,最后都作为类内部变量,集中在这一个SingleStageInsDetector中。具体每一个module创建的代码就不贴了,无非是将args传递进去,根据现有的代码创建相应的模块。

至此模型的创建工作大致如此,下面来看Inference的过程。

Inference

SOLO类的forward继承自BaseDetector,其forward方法如下:

    def forward_test(self, imgs, img_metas, **kwargs):
        。。。。。。

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

以单gpu为例,调用的是simple_test,这个函数在SingleStageInsDetector中被重写过,如下:

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x
        
    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        outs = self.bbox_head(x, eval=True)

        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
        else:
            seg_inputs = outs + (img_meta, self.test_cfg, rescale)
        seg_result = self.bbox_head.get_seg(*seg_inputs)
        return seg_result  

这里Inference的顺序依次是backbone->neck->bbox_head,backbone为ResNet50,neck为FPN,bbox_head为(decoupled)solo_head。所以前面特征提取部分的代码很简单,就不做过多赘述。主要来看下bbox_head:

mmdet.models.anchor_heads.decoupled_solo_head.py

@HEADS.register_module
class DecoupledSOLOHead(nn.Module):
    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(DecoupledSOLOHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.cate_down_pos = cate_down_pos
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.with_deform = with_deform
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs_x = nn.ModuleList()
        self.ins_convs_y = nn.ModuleList()
        self.cate_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            #第一层+1表示采用coordconv concat上的position(如果非decouple则+2)
            chn = self.in_channels + 1 if i == 0 else self.seg_feat_channels
            # ins_x分支几个卷积+norm模块
            self.ins_convs_x.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
            # ins_y分支几个卷积+norm模块
            self.ins_convs_y.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            # cate分支几个卷积+norm模块
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.dsolo_ins_list_x = nn.ModuleList()
        self.dsolo_ins_list_y = nn.ModuleList()
        #每一个level对应的num_grid不同,针对所有level的feature设计对应维度的卷积
        for seg_num_grid in self.seg_num_grids:
            self.dsolo_ins_list_x.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
            self.dsolo_ins_list_y.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
        self.dsolo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
            
     def forward(self, feats, eval=False):
#        for i in feats:
#            print(i.shape)
#        torch.Size([1, 256, 200, 304])
#        torch.Size([1, 256, 100, 152])
#        torch.Size([1, 256, 50, 76])
#        torch.Size([1, 256, 25, 38])
#        torch.Size([1, 256, 13, 19])

        new_feats = self.split_feats(feats)      
#        for i in new_feats:
#            print(i[0].shape)
#		torch.Size([256, 100, 152])
#		torch.Size([256, 100, 152])
#		torch.Size([256, 50, 76])
#		torch.Size([256, 25, 38])
#		torch.Size([256, 25, 38])

            
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
#        print(upsampled_size)   (200, 304)
        ins_pred_x, ins_pred_y, cate_pred = multi_apply(self.forward_single, new_feats,
                                                        list(range(len(self.seg_num_grids))),
                                                        eval=eval, upsampled_size=upsampled_size)
        return ins_pred_x, ins_pred_y, cate_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'), 
                feats[1], 
                feats[2], 
                feats[3], 
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
#        print(ins_feat.shape)
#        print(x.shape)
        ins_feat_x = torch.cat([ins_feat, x], 1)
        ins_feat_y = torch.cat([ins_feat, y], 1)
#        print(ins_feat_x.shape)  (1, 256 + 1, ?, ?)

        for ins_layer_x, ins_layer_y in zip(self.ins_convs_x, self.ins_convs_y):
            ins_feat_x = ins_layer_x(ins_feat_x)
            ins_feat_y = ins_layer_y(ins_feat_y)

        ins_feat_x = F.interpolate(ins_feat_x, scale_factor=2, mode='bilinear')
        ins_feat_y = F.interpolate(ins_feat_y, scale_factor=2, mode='bilinear')

        ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat_x)
        ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat_y)
#        print(ins_pred_x.shape)   对应到每个feat_map对应的grid (1,256,?,?)->(1,40/36/24/16/12,?,?)

        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx] 	# idx对应特征图的level,不同level的num_grid不同
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)

        cate_pred = self.dsolo_cate(cate_feat)
#        print(cate_pred.shape)    (1, 80, num_grid, num_grid)

        if eval:
            ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
            ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred_x, ins_pred_y, cate_pred

上面的代码是solo_head正向传播以后得到的结果:ins_pred_x, ins_pred_y, cate_pred。但并不是完整的Inference,最终的maks生成还需要进行下面两个函数的操作:

    def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
        assert len(seg_preds_x) == len(cate_preds)
        num_levels = len(cate_preds)
#        print(num_levels)     5
        featmap_size = seg_preds_x[0].size()[-2:]
#        print(featmap_size)   [200, 304]

#        for i in range(5):
#            print(seg_preds_x[i].shape)
#            print(cate_preds[i].shape)
#       torch.Size([1, 40, 200, 304])
#		torch.Size([1, 40, 40, 80])
#		torch.Size([1, 36, 200, 304])
#		torch.Size([1, 36, 36, 80])
#		torch.Size([1, 24, 200, 304])
#		torch.Size([1, 24, 24, 80])
#		torch.Size([1, 16, 200, 304])
#		torch.Size([1, 16, 16, 80])
#		torch.Size([1, 12, 200, 304])
#		torch.Size([1, 12, 12, 80])

        result_list = []
        #由于是demo,这里只有一张img
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
            ]
#            print(cate_pred_list[0].shape)  (num_grid*num_grid, 80)
            seg_pred_list_x = [
                seg_preds_x[i][img_id].detach() for i in range(num_levels)
            ]
#            print(seg_pred_list_x[0].shape)    #(num_grid, 200, 304)
            seg_pred_list_y = [
                seg_preds_y[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)    #(3872, 80) == (40^2+36^2+24^2+16^2+12^2, 80)
            seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)    #(128, 200, 304) == (40+36+24+16+12, 200, 304)
#            print(seg_pred_list_x.shapes)
            seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
                                         featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds_x,
                       seg_preds_y,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False, debug=False):


        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)   # 原图大小

        # trans trans_diff.
        trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()    # [1600, 2896, 3472, 3728, 3872]
        trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
        seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)	# [1, 1, ..., 1]

        n_stage = len(self.seg_num_grids)
        trans_diff[:trans_size[0]] *= 0
        seg_diff[:trans_size[0]] *= 0
        num_grids[:trans_size[0]] *= self.seg_num_grids[0]
#        print(self.strides)	[8, 8, 16, 32, 32]
        strides[:trans_size[0]] *= self.strides[0]

        for ind_ in range(1, n_stage):
            trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
            seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
            num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
            strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]	# [0-1599:8, 1600-2895:8, 2896-3471: 16, 2372-3871:32]

        # process.
        inds = (cate_preds > cfg.score_thr)
#        print(inds.shape)    # [3872, 80]布尔矩阵  
        cate_scores = cate_preds[inds]
#        print(cate_scores)    # [3872, 80]

        inds = inds.nonzero()
#        print(inds.shape)  # (n, 2) n表示有多少个分数>thres
        trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
        seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
        num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
        strides = torch.index_select(strides, dim=0, index=inds[:, 0])

        y_inds = (inds[:, 0] - trans_diff) // num_grids
        x_inds = (inds[:, 0] - trans_diff) % num_grids
        y_inds += seg_diff
        x_inds += seg_diff

        cate_labels = inds[:, 1]
#        print(cate_labels)	# n维向量,表示类别num
        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]	# [n, 200, 304]
        seg_masks = seg_masks_soft > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()		# [n, 1]
        keep = sum_masks > strides		# 进一步筛除,总的mask之和小于stride就筛掉
#        print(keep)

        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]
        # maskness
        seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_score

        if len(cate_scores) == 0:
            return None

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]
#        print(cate_scores)

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
#        print(cate_scores)		#维度并没变,只是将IOU高的部分的score降低,类似于soft-NMS

        keep = cate_scores >= cfg.update_thr
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
#        print(cate_scores.shape)		#筛掉一部分
        cate_labels = cate_labels[keep]
        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:		# coco数据集最大一张img100个instance
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # 将mask的resolution还原到original图像大小
        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                    size=upsampled_size_out,
                                    mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                               size=ori_shape[:2],
                               mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr

        return seg_masks, cate_labels, cate_scores

最后在demo中在Matrix NMS之后,选择的筛除阈值为0.05,这个值有点小导致很多有小目标的img筛出来100个,最后demo在展示结果的时候又采用了0.25的阈值,这里会不会有些矛盾。

清明过后补上训练部分的代码解读。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

SOLO代码阅读解析 的相关文章

  • FISCO-BCOS如何把WEBASE部署通过的合约方法由api在前端调用

    参考文章 fisco bcos官方文档第五章部分 通过POST请求 数据格式要为json 调用hello合约中的get方法 按要求填写需要的信息
  • 决策树的学习

    决策树 从名字上看 就知道其模型的结构为树结构 决策树既可以用于分类 也可以用于回归之中 在分类问题中 我们可以认为其是if then规则的集合 也可以认为是定义在特征空间与类空间上的条件概率分布 在学习过程中 利用训练数据和损失函数最小化

随机推荐

  • 因果推理的do算子

    在因果推理中 我们一般都需要首先构建一个因果图 这是后续进行因果分析的基础 但是在现阶段笔者的知识看来 因果图的构建其实是一个比较主观的过程 但偏偏又是后续分析的基础 所以略感头疼 在构建因果图前 我们有必要明白 什么是因果关系 通俗来说
  • 【JUC并发编程】

    本笔记内容为狂神说JUC并发编程部分 目录 一 什么是JUC 二 线程和进程 1 概述 2 并发 并行 3 线程有几个状态 4 wait sleep 区别 三 Lock锁 重点 四 生产者和消费者问题 五 八锁现象 六 集合类不安全 七 C
  • 统计字符串中,中文字符、英文字符和数字字符的数量

    package com suanfa public class ZYSTotal 统计字符串中 中文字符 英文字符和数字字符的数量 public static void main String args int englishCount 0
  • 指针和数组的相关练习题

    目录 一 一维数组 二 字符数组 三 二维数组 注意 假设本练习题所用的VS编译器是64位平台下的 首先要明白数组名的意义 1 sizeof 数组名 这里的数组名表示整个数组 计算的是整个数组的大小 2 数组名 这里的数组名表示整个数组 取
  • 帆软之图表详解

    帆软之图表详解 饼图 饼图 玫瑰图 玫瑰图和饼图类似 仅选择不同的图例即可 多分类饼图 注 标题居中不是直接显示标题居中 而是隐藏标题偶按照下面的方法将标题加上去 柱状图 柱状图设置柱子宽度 boby 样式 系列 固定柱宽 注意事项 问题描
  • 4.3寸串口屏在智能炒菜机上应用分享

    现代人追求高效品质生活的美好愿望以及社会科技的不断发展持续推动着一种新兴经济形态的出现 即懒人经济 懒人经济的崛起也成为智能家电行业新的增长引擎 自动炒菜机便是这一经济形态下的产物 对于很多居住于快节奏生活的一二线城市人来说 在辛苦工作一整
  • vue3 递归无限分类树型菜单+搜索功能

    我们先来看一下大致实现效果 数据可以无限向下增加 搜索关键字会自动展开数据 vue3树形结构菜单 搜索 首先我这个需要自己设计数据源 一定要先搞清楚数据是什么结构才能顺利开展下一步 有接口的同学可以忽略这一步 其中children顾名思义
  • 区块链是如何做到交易记录不可被篡改的

    区块链是如何做到交易记录不可被篡改的 星目 关注 2017 07 19 23 03 字数 1912 阅读 1654评论 4喜欢 1 BlockChain 比特币前一阵子一度超过2万元一枚 而且长期来看这远远不是它的极限 假如你手里有比特币
  • Python实现队列

    Python实现队列 关于队列的介绍 请参考 https blog csdn net weixin 43790276 article details 104033337 队列的数据存储结构可以是顺序表 也可以是链表 本篇使用 Python
  • Keil中工程文件编译后没有显示.h文件

    一 第一种解决方法 打开Keil软件 重新打开试试 二 第二种解决方法 查看是否点击了Show include File Dependencies 1 右击源组 记住 一定是右击 不是双击 就可以看到如下画面 如果你没有打勾的话 那就是如下
  • 开源项目哪家强?Github年终各大排行榜超级盘点(内附开源项目学习资源)

    整理 Jane 出品 AI科技大本营 导语 提到开源项目 2018 年注定是不平凡的一年 据 Octoverse 报告数据 仅在 2018 年 Github 上的新用户就比过去六年的用户总数还要多 存储库数量近一亿 这些增长都要归功于开源社
  • Linux 系统中kill命令杀死进程常用技巧

    目录 前言 基础 进阶 1 查找进程号的方式进行改进 2 将常规的两步杀死进程合并为一步 3 强制踢掉登陆用户 kill的注意事项 前言 在Linux的系统中 kill是我们最常见的命令之一 kill 英语中为杀死的意思 顾名思义 就是用来
  • pytorch中使用detach()

    import torch nn as nn import torch class net nn Module def init self super init self conv nn Conv2d 3 6 3 stride 2 paddi
  • GOOGLE地图基站定位-Google Mobile Maps API

    如果你在你的手机装过Google Mobile Maps 你就可以发现只要你的手机能连接GPRS 即使没有GPS功能 也能定位到你手机所在的位置 只是精度不够准确 在探讨这个原理之前 我们需要了解一些移动知识 了解什么是MNC LAC Ce
  • Spark SQL 基本操作

    将下列JSON格式数据复制到Linux系统中 并保存命名为employee json id 1 name Ella age 36 id 2 name Bob age 29 id 3 name Jack age 29 id 4 name Ji
  • 【财富空间】一场史无前例的白领破产潮,正在来袭!

    来源 水木然 ID smr8700 最近 我们在上海做一个项目 准备招一个部门经理 于是发了招聘启示 应聘的人符合条件的很少 这不算什么 最令我吃惊的是 他们基本上个个都是要求年薪百万以上 放眼四望 诺大的上海 除去垄断国企 大牌的外企 再
  • xshell及xftp更新提示:Xshell出现要继续使用此程序必须应用到最新的更新或使用新版本

    一 前言 java开发者或者linux运维都肯定会用到xshell及xftp工具 说实话这两个工具真心是好用 但是有两个问题一直困扰这我 1 每次打开xshell或xftp总是会提示更新 2 今天打开xftp的时候 突然提示 要继续使用此程
  • Arduino基础 — Arduino 字符串

    Arduino 字符串 在Arduino编程中有两位字符串 1 字符数组 与C语言编程使用相同 2 Arduino 字符串 它允许我们在代码中使用字符对象 字符串数组 字符串是一个特殊的数组 在字符串的末尾有一个额外的元素 其值总是为0 零
  • 面试鹅厂,我三面被虐的体无完肤……

    戳蓝字 CSDN云计算 关注我们哦 作者 codegoose 来源 https segmentfault com a 1190000017864721 经过半年的沉淀 加上对MySQL redis和分布式这块的补齐 终于重拾面试信心 再次出
  • SOLO代码阅读解析

    SOLO是一种直接预测instance mask的范式 摒弃了之前top down和bottom up两种主流的实例分割方法 从而pipeline更加简洁直观 这篇文章以官方代码中的demo为例 简单梳理一下SOLO在inference时的