【pytorch】固定(freeze)住部分网络

2023-11-10

前言

最好、最高效、最简洁的,是 “方案一” 。

方案一

步骤一、固定基本网络

代码模板:

# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')

# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下: 
def freeze_model(model, to_freeze_dict, keep_step=None):

    for (name, param) in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False
        else:
            pass

    # # 打印当前的固定情况(可忽略):
    # freezed_num, pass_num = 0, 0
    # for (name, param) in model.named_parameters():
    #     if param.requires_grad == False:
    #         freezed_num += 1
    #     else:
    #         pass_num += 1
    # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num))

    return model

Note:

  • 如果预加载模型是在 model = nn.DataParallel(model) 模式下训练出来的分布式模型,那么每个参数名称会默认加上 .module 前缀。
  • 相应地,会导致无法对号导入单机模型。此时需要将如下语句:
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
改为: 
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
pre_state_dict = {k.replace('module.', ''): v for k, v in pre_state_dict.items()}

步骤二、让optimizer回避要freeze的参数

代码模板:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)

步骤三、train时通过.eval()来freeze

因为:即使对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。(详见【pytorch】bn
所以:train每个epoch之前都要统一重新定义一下这块,否则容易出问题。

model.eval()
model.stage4_xx.train()
model.pred_xx.train()

方案二

pytorch下进行freeze操作,一般需要经过以下四步。

步骤一、固定基本网络

代码模板:

# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')

# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下: 
def freeze_model(model, to_freeze_dict, keep_step=None):

    for (name, param) in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False
        else:
            pass

    # # 打印当前的固定情况(可忽略):
    # freezed_num, pass_num = 0, 0
    # for (name, param) in model.named_parameters():
    #     if param.requires_grad == False:
    #         freezed_num += 1
    #     else:
    #         pass_num += 1
    # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num))

    return model

步骤二、让optimizer回避要freeze的参数

代码模板:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)

步骤三、固定bn

即使通过步骤一对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。
所以还需要额外地深入固定bn:

  • 固定 momentum :momentum=0.0
  • 掐灭 track_running_stats :track_running_stats=False

举例:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

修改为:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

但是 track_running_stats=False 会带来副作用:受波及的每个bn都会在state_dict中丢失三个对应的键值对(每组对应的key都为xx.xx.bn.running_mean、xx.xx.bn.running_var 和 xx.xx.bn.num_batches_tracked)

步骤四、正常训练

训练过程中,记得定时check一下被固定部分是否恒定不变:

  • 比如每次eval的时候,顺便check一下被固定部分的预测精度。

步骤五、后处理

4.1 重启track_running_stats

举例:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

修改为:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)

此时,之前受波及的每个bn,都会在state_dict中恢复所丢失三个对应的键值对(但是value为空,待填充)。

Note:

  • 线上训练虽然用freeze过的网络,但线下测试时,还是要老老实实换回未被freeze的网络。否则结果不仅会对不齐,被freeze和未被freeze的task都会表现更差!
4.2 复原缺失的value

为了克服 track_running_stats=False 带来的副作用,最终模型需要依赖 “原始state_dict” 和 “训好的state_dict” 合并。前者为后者补充缺失的value。

# 原始state_dict:
origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
# 训好的state_dict:
new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))

# 后者从前者中补充缺失的键值对:
final_dict = new_state_dict.copy()
for (key, val) in origin_state_dict.items():
    if key not in final_dict:
        final_dict[key] = val

# 载入合并好的 state_dict,这时候一定是可以通过 strict=True 的:
model.load_state_dict(final_dict, strict=True)
这时重新再save一遍model,就是可最终直接用的model文件了。 
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【pytorch】固定(freeze)住部分网络 的相关文章

  • Python 按文件夹模块导入

    我有一个目录结构 example py templates init py a py b py a py and b py只有一个类 名称与文件相同 因为它们是猎豹模板 纯粹出于风格原因 我希望能够在中导入和使用这些类example py像
  • WTForms 中的小数字段舍入

    我有一个包含价格小数字段的表单 如下所示 from flask ext wtf import Form import wtforms from wtforms validators import DataRequired from deci
  • ValueError:在 R 中使用 keras 模型时在用户代码中

    我正在尝试使用 R 在 R 中运行一维 CNNkeras包裹 我正在使用以下代码 library MASS library keras Create some data data Boston data lt Boston create a
  • 如何最好地将包含列表或元组的 Pandas 列提取到多个列中[重复]

    这个问题在这里已经有答案了 我不小心用错误重复的链接关闭了这个问题 这是正确的 Pandas 将列表的列拆分为多列 https stackoverflow com questions 35491274 pandas split column
  • PyPI 项目页面中的“Py 版本”是什么意思?这有关系吗?

    我注意到 大多数在 PyPI 上发布的项目在其项目页面中都包含 Py 版本 元数据 但它们的值各不相同 如果包不是通用包或不是纯 python 包 那么它们的值是不同的 这是可以理解的 以便表示它们的目标平台 例如鼻页 https pypi
  • 图像堆栈的最大强度投影

    我正在尝试重新创建该功能 max array 3 来自 MatLab 它可以获取 N 个图像的 300x300px 图像堆栈 我在这里说 图像 因为我正在处理图像 实际上这只是一个大的双数组 300x300xN 并创建一个 300x300
  • 为图例中的点设置固定大小

    我正在制作一些散点图 我想将图例中的点的大小设置为固定的相等值 现在我有这个 import matplotlib pyplot as plt import numpy as np def rand data return np random
  • 对于 pygtk 应用程序来说,什么是好的嵌入式浏览器?

    我计划在我的 pygtk 应用程序中使用嵌入式浏览器 并且我正在 gtkmozembed 和 pywebkitgtk 之间进行辩论 两者之间有什么引人注目的区别吗 还有我不知道的第三种选择吗 应该注意的是 我不会使用它来访问网络上的内容 我
  • 监控单个文件

    我需要监控 使用watchdog http pythonhosted org watchdog index html 单个文件 而不是整个目录 避免监视整个目录的最佳方法是什么 我想this http pythonhosted org wa
  • 多线程写入文件

    前几天刚开始使用 python 对多线程的整个概念还很陌生 我在多线程时写入文件时遇到问题 如果我按照常规方式执行此操作 它会不断覆盖正在写入的内容 使用 5 个线程写入文件的正确方法是什么 不降低性能的最佳方法是在所有线程之间使用队列 每
  • 具有条件的重复行 pandas dataframe python

    我的数据框有问题 我的 df 是 product power brand product 1 3 x 1500W brand A product 2 2x1000W 1x100W product 3 1x1500W 1x500W brand
  • 收到的标签值 1 超出了 [0, 1) 的有效范围 - Python、Keras

    我正在使用具有张量流背景的 keras 开发一个简单的 cnn 分类器 def cnnKeras training data training labels test data test labels n dim print Initiat
  • 避免在列表理解中计算相同的表达式两次[重复]

    这个问题在这里已经有答案了 我在列表理解中使用一个函数和一个 if 函数 new list f x for x in old list if f x 0 令我恼火的是这个表达f x 在每个循环中计算两次 有没有办法以更清洁的方式做到这一点
  • 从 sublime_plugin.WindowCommand 获取当前文件名

    我开发插件sublime text 3 并想要获取当前打开的文件路径 absolute1 self window view file name 在哪里self is sublime plugin WindowCommand 但失败了 Att
  • 如何从python导入路径中删除当前目录

    我想使用 Mercurial 存储库hg本身 也就是说 我克隆了 Mercurialhttps www mercurial scm org repo hg https www mercurial scm org repo hg并想运行一些h
  • 在 python 中使用递归替代 len()

    作为 CS1301 问题的一部分 我正在尝试使用递归编写一个函数 该函数将执行与 len 完全相同的操作 但是 我有两个问题 我正在使用全局变量 但我在课程中还没有学到这一点 cs1301 自动评分器告诉我 我的函数返回 26 而不是 13
  • 用于桌面数据库应用程序的 Python 框架

    是否有一个框架可以为Python开发桌面数据库应用程序 一些带有CRUD屏幕的屏幕 我正在寻找类似于 Windows 窗体的东西 能够将 TextField Combos 和其他 UI 隐喻与datasets连接到关系数据库例如 MySQL
  • python pandas如何在多个条件下过滤字符串

    我有以下数据框 import pandas as pd data 5Star FiveStar five star fiv estar data pd DataFrame data columns columnName 当我尝试用一 种条件
  • Pandas 替换特定列上的值

    我知道这两个类似的问题 熊猫替换值 https stackoverflow com questions 27117773 pandas replace values Pandas 替换数据框中的列值 https stackoverflow
  • PyQt QFileDialog exec_ 很慢

    我正在使用自定义QFileDialog因为我想选择多个目录 但是exec 功能非常慢 我不明白为什么 我正在使用最新版本的 PyQt 代码片段 from PyQt4 import QtGui QtCore QtNetwork uic cla

随机推荐