在onnx opset 12下转以下模型时因不支持hardswish激活函数而报错
- GhostNet
- MobileNetv3Small
- EfficientNetLite0
- PP-LCNet
解决方案是找到对应的nn.Hardswish
层,将其替换为自己覆写的Hardswish
实现:
class Hardswish(nn.Module):
@staticmethod
def forward(x):
return x * F.hardtanh(x + 3, 0., 6.) / 6.
以PP-LCNet
为例,找到哪些层是Hardswish
层,替换方法为
def _set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
cur_mod = getattr(cur_mod, s)
setattr(cur_mod, tokens[-1], module)
for k, m in model.named_modules():
if 'dw_sp.2' in k or 'dw_sp.6' in k:
_set_module(model, k, Hardswish())
当然也可以根据m
来判断是否为nn.Hardswish
的实例,
for k, m in model.named_modules():
if isinstance(m, nn.Hardswish):
_set_module(model, k, Hardswish())
参考
YOLOv5-Multibackbone-Compression
Pytorch替换model对象任意层的方法
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)