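Building on the detection codebase detectron2 and the distillation codebase RepDistiller, I put together a codebase that applies distillation methods to object detection. The full code is open source.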
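1. Adding the parameters
Add default values for the distillation parameters in config/defaults.py. This works much like declaring variables: a key has to exist here before a YAML file can override it.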
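_C.DISTILL = CN()
_C.DISTILL.ENABLE = False    # master switch; the YAML file and the model code both read ENABLE
_C.DISTILL.PATH_T = None     # path to the teacher checkpoint
_C.DISTILL.CFG_T = None      # path to the teacher config file
_C.DISTILL.DISTILL = 'kd'    # name of the distillation method
_C.DISTILL.TRIAL = None      # set in the YAML file, so it needs a default here
_C.DISTILL.B = None          # weight applied to the distillation loss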
Then set the distillation parameters you actually want in the YAML file used for the experiment:
DISTILL:
  ENABLE: True
  PATH_T: 'detectron2/teacher_models/FasterRCNN-R101-FPN-lr3x.pkl'
  CFG_T: 'configs/COCO-Detection/teacher_models/faster_rcnn_R_101_FPN_3x.yaml'
  DISTILL: 'hint'
  TRIAL: 1.0
  B: 1.0
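Why the defaults and the YAML file have to agree: detectron2's config is built on yacs-style CfgNode objects, which, as far as I know, refuse to merge a YAML key that has no default. The sketch below imitates that rule with plain dictionaries so it runs without detectron2 installed. `merge_strict` and the misspelled key `ENABEL` are hypothetical, made up for this illustration; they are not part of detectron2 or yacs.

```python
from copy import deepcopy


def merge_strict(defaults: dict, overrides: dict, prefix: str = "") -> dict:
    """Return a copy of `defaults` updated with `overrides`.

    Hypothetical stand-in for a yacs-style merge: a key in `overrides`
    that has no default raises KeyError instead of being silently added.
    """
    merged = deepcopy(defaults)
    for key, value in overrides.items():
        full_key = f"{prefix}{key}"
        if key not in merged:
            raise KeyError(f"Non-existent config key: {full_key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            # Recurse into nested nodes such as DISTILL.
            merged[key] = merge_strict(merged[key], value, prefix=f"{full_key}.")
        else:
            merged[key] = value
    return merged


# Illustrative defaults for a DISTILL node (a plain dict, not a real CfgNode).
DEFAULTS = {"DISTILL": {"ENABLE": False, "DISTILL": "kd", "B": None}}

# Overrides in the spirit of a YAML file that switches on hint distillation.
cfg = merge_strict(DEFAULTS, {"DISTILL": {"ENABLE": True, "DISTILL": "hint", "B": 1.0}})

# A misspelled key is rejected rather than ignored.
try:
    merge_strict(DEFAULTS, {"DISTILL": {"ENABEL": True}})
    typo_rejected = False
except KeyError:
    typo_rejected = True
```

The practical consequence is that a parameter name has to be spelled identically in the defaults, in the YAML file, and wherever the model code reads it.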
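2. Loading the teacher model
Download the teacher model from the official detectron2 download link and place it in the detectron2/teacher_models folder.
Put the config file that matches the teacher model in the configs/COCO-Detection/teacher_models folder.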
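3. Computing the distillation loss
Add the distillation loss to the forward function of GeneralizedRCNN, and pass the distillation parameters in both when the model is initialized and when it is built from the config.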
from detectron2.distiller_helper.distill import Distill

class GeneralizedRCNN(nn.Module):
    @configurable
    def __init__(
        self,
        ...
        distill_cfg: dict = None
    ):
        super().__init__()
        ...
        self.distill_cfg = distill_cfg
        if distill_cfg.ENABLE:
            # Build the teacher only when distillation is switched on.
            self.distill_model = Distill(self.distill_cfg)

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            ...
            "distill_cfg": cfg.DISTILL
        }

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        ...
        if self.distill_cfg.ENABLE:
            # `features` are the student's backbone outputs computed earlier in forward.
            distill_losses = self.distill_model.compute_distill_loss(batched_inputs, features)
            losses.update(distill_losses)
        return losses
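Returning the distillation term inside the same dictionary is enough for it to take part in training, because detectron2's default training loop sums every value in the returned loss dictionary. A minimal sketch of that bookkeeping follows, with plain floats standing in for scalar tensors. The helper `combine_losses` and all numeric values are made up for illustration.

```python
def combine_losses(detection_losses: dict, distill_losses: dict) -> dict:
    """Merge the two loss dictionaries without mutating the inputs.

    dict.update overwrites on key collisions, so the distillation key
    must not reuse the name of an existing detection loss.
    """
    clash = set(detection_losses) & set(distill_losses)
    if clash:
        raise ValueError(f"duplicate loss keys: {sorted(clash)}")
    losses = dict(detection_losses)
    losses.update(distill_losses)
    return losses


# Loss names follow Faster R-CNN in detectron2; the values are invented.
detection = {
    "loss_cls": 0.5,
    "loss_box_reg": 0.25,
    "loss_rpn_cls": 0.125,
    "loss_rpn_loc": 0.125,
}
losses = combine_losses(detection, {"loss_kd": 0.5})

# The trainer backpropagates through the sum of all entries.
total = sum(losses.values())
```

Keeping `loss_kd` as its own entry also means it is logged separately, which makes it easy to see how large the distillation term is compared with the detection losses.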
Create a get_features function (a method of GeneralizedRCNN) that returns the teacher's feature maps:
def get_features(self, batched_inputs):
    '''
    return: dict
        keys: 'p2'-'p6'
        values: the feature map of the corresponding FPN level
    '''
    images = self.preprocess_image(batched_inputs)
    features = self.backbone(images.tensor)
    return features
Create a distill.py file in the detectron2/distiller_helper folder. It builds the teacher model and computes the distillation loss.
from detectron2.config import get_cfg
from detectron2.modeling.meta_arch import build_model
# KDSVD is used below, so it has to be imported along with the other criteria.
from detectron2.distiller_zoo import HintLoss, Attention, Similarity, NSTLoss, RKDLoss, PKT, KDSVD
from detectron2.checkpoint import DetectionCheckpointer
import torch

class Distill():
    def __init__(self, distill_cfg):
        self.opt = distill_cfg
        self.model_t = self.build_teacher_model()

    def build_teacher_model(self):
        '''Build the teacher model and load its weights.'''
        teacher_cfg = get_cfg()
        teacher_cfg.merge_from_file(self.opt.CFG_T)
        teacher_cfg['student_identity'] = False
        model_t = build_model(teacher_cfg)
        DetectionCheckpointer(model_t).resume_or_load(self.opt.PATH_T, resume=False)
        return model_t

    def compute_distill_loss(self, batched_inputs, logit_s):
        '''Compute the distillation loss from student features `logit_s`.'''
        # The teacher only provides targets, so no gradients are tracked for it.
        with torch.no_grad():
            logit_t = self.model_t.get_features(batched_inputs)

        # Pick the criterion named by DISTILL.DISTILL in the config.
        if self.opt.DISTILL == 'hint':
            criterion_kd = HintLoss()
        elif self.opt.DISTILL == 'attention':
            criterion_kd = Attention()
        elif self.opt.DISTILL == 'nst':
            criterion_kd = NSTLoss()
        elif self.opt.DISTILL == 'similarity':
            criterion_kd = Similarity()
        elif self.opt.DISTILL == 'rkd':
            criterion_kd = RKDLoss()
        elif self.opt.DISTILL == 'pkt':
            criterion_kd = PKT()
        elif self.opt.DISTILL == 'kdsvd':
            criterion_kd = KDSVD()
        else:
            raise NotImplementedError(self.opt.DISTILL)

        # Accumulate the loss level by level over the feature maps.
        loss_kd = 0
        for t, s in zip(logit_t.values(), logit_s.values()):
            loss_kd += criterion_kd(s, t.detach())
        # Average over the 5 FPN levels ('p2'-'p6') and scale by the weight B.
        loss_kd = loss_kd / 5 * self.opt.B
        distill_losses = {'loss_kd': loss_kd}
        return distill_losses
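To make the arithmetic of the 'hint' case concrete: HintLoss in RepDistiller is, to my knowledge, a thin wrapper around mean squared error, so the loss is the per-level MSE between student and teacher features, averaged over the FPN levels and scaled by B. The sketch below reproduces that with plain Python lists standing in for feature tensors. The helpers `mse` and `hint_distill_loss` and the toy feature values are hypothetical. Unlike the code above, this version pairs levels by key and divides by the actual number of levels instead of a fixed 5.

```python
from typing import Dict, List


def mse(student: List[float], teacher: List[float]) -> float:
    """Mean squared error between two equally sized flat feature vectors."""
    if len(student) != len(teacher):
        raise ValueError("student and teacher features must have the same size")
    return sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(student)


def hint_distill_loss(
    feats_s: Dict[str, List[float]],
    feats_t: Dict[str, List[float]],
    weight: float,
) -> float:
    """Average the per-level MSE over the teacher's levels, then scale by `weight`."""
    levels = list(feats_t.keys())
    total = sum(mse(feats_s[level], feats_t[level]) for level in levels)
    return total / len(levels) * weight


# Toy features: the student differs from the teacher only on level 'p2'.
teacher = {name: [0.0, 0.0] for name in ("p2", "p3", "p4", "p5", "p6")}
student = dict(teacher, p2=[1.0, 2.0])

# Level 'p2' contributes (1 + 4) / 2 = 2.5; the other four levels contribute 0.
loss_kd = hint_distill_loss(student, teacher, weight=1.0)  # 2.5 / 5 * 1.0 = 0.5
```

Pairing by key guards against the two dictionaries listing their levels in a different order, and an element-wise loss like this one only works when the student and teacher feature maps have the same shape at every level.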
There are a few other details that I won't go into here; see my GitHub for the full code.