使用NNI对BERT模型进行粗剪枝、蒸馏与微调

2023-11-18

前言

  • 模型剪枝(Model Pruning)是一种用于减少神经网络模型尺寸和计算复杂度的技术。通过剪枝,可以去除模型中冗余的参数和连接,从而减小模型的存储需求和推理时间,同时保持模型的性能。
  • 模型剪枝的一般步骤:
    • 训练初始模型:训练一个初始的神经网络模型
    • 评估参数重要性:计算每个参数的重要性指标来评估参数的贡献程度
    • 剪枝冗余参数:根据阈值将参数置零,或者直接将对应的连接删除,从而减小模型的大小。
    • 微调(Feine-tune):以恢复或提高模型的性能。
  • 模型剪枝策略可以大致分为两种:粗粒度剪枝(Coarse-Grained Pruning)和细粒度剪枝(Fine-Grained Pruning)
    • 粗粒度剪枝:被选择剪枝的单元是以较大的粒度进行的。在神经网络模型中,这通常意味着选择剪枝整个通道、卷积核、层或其他结构化的模块。
    • 细粒度剪枝:被选择剪枝的单元是以较小的粒度进行的。在神经网络模型中,这意味着选择剪枝单个参数或连接。
  • 本文仅针对模型粗剪枝进行示例,如果这篇博客看的人多的话,后面会推出细剪枝教程。
  • 需要NNI库的支持,请先安装好(pip install nni

简单示例

训练初始模型

  • 针对一个简单的MNIST分类神经网络进行剪枝。
import torch
import torch.nn.functional as F
from torch.optim import SGD

from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device

# 定义模型
model = TorchModel().to(device)

# 打印模型框架
print(model)
  • 模型框架——输出:
TorchModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
  • 定义优化器和损失函数预训练模型
# 定义优化器和损失函数

optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# 预训练模型在MNIST数据集上评估模型
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)
  • 模型一共跑了3个epoch,输出:
Average test loss: 0.6671, Accuracy: 8105/10000 (81%)
Average test loss: 0.2894, Accuracy: 9108/10000 (91%)
Average test loss: 0.1751, Accuracy: 9459/10000 (95%)

修剪冗余参数

  • 使用L1NormPruner方法评估参数重要性并修剪冗余参数
  • L1NormPruner(L1 范数剪枝器):支持线性层和卷积层。L1 范数剪枝器计算第一个维度上层权重的 l1 范数,然后用较小的 l1 范数值修剪该维度上的权重块。即,计算卷积层中滤波器的l1范数作为度量值,计算线性层中行权重的l1范数作为度量值。
  • 剪枝方法有很多,这里使用最简单的一种,后面微调Bert模型会使用Taylor FO Weight PrunerMovement Pruner
  • L1NormPruner策略参数L1NormPruner(model, config_list, mode='normal', dummy_input=None)
参数 含义
model 要修剪的模型
config_list 策略设定
mode 'normal''dependency_aware'
dummy_input 用于分析拓扑约束的虚拟输入
config_list 含义
sparsity 指定要压缩的配置中每一层稀疏性
sparsity_per_layer 等于sparsity
op_types L1NormPruner 支持 Conv2d 和 Linear
op_partial_names 要修剪的操作名称
exclude 设置为Trueop_typesop_names 图层将被排除在修剪之外
from nni.compression.pytorch.pruning import L1NormPruner
# 稀疏性为0.1,针对线性层和Conv2d层。不对fc3进行剪枝
config_list = [{
    'sparsity_per_layer': 0.1,
    'op_types': ['Linear', 'Conv2d']
}, {
    'exclude': True,
    'op_names': ['fc3']
}]

pruner = L1NormPruner(model, config_list)
# 显示被包裹的模型结构,`PrunerModuleWrapper`已经包裹了config_list中配置的层。
print(model)

输出:

TorchModel(
  (conv1): PrunerModuleWrapper(
    (module): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  )
  (conv2): PrunerModuleWrapper(
    (module): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  )
  (fc1): PrunerModuleWrapper(
    (module): Linear(in_features=256, out_features=120, bias=True)
  )
  (fc2): PrunerModuleWrapper(
    (module): Linear(in_features=120, out_features=84, bias=True)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)

压缩模型

# 压缩模型并产生掩码(masks矩阵)
_, masks = pruner.compress()
# 展示掩码的稀疏性
for name, mask in masks.items():
    print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))

输出:

conv1  sparsity :  1.0
conv2  sparsity :  0.94
fc1  sparsity :  0.9
fc2  sparsity :  0.9
  • 解开模型包装,压缩模型
# 在加速之前需要解开模型包装
pruner._unwrap_model()

# 加速模型
from nni.compression.pytorch.speedup import ModelSpeedup
ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()
# 打印模型
print(model)

输出:

TorchModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 15, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=240, out_features=108, bias=True)
  (fc2): Linear(in_features=108, out_features=76, bias=True)
  (fc3): Linear(in_features=76, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
  • 可以看到LinearConv2层参数量减少。重新定义优化器,并再次训练模型。
optimizer = SGD(model.parameters(), 1e-2)
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

输出:

Average test loss: 0.1495, Accuracy: 9539/10000 (95%)
Average test loss: 0.1148, Accuracy: 9662/10000 (97%)
Average test loss: 0.0962, Accuracy: 9715/10000 (97%)

对Bert模型剪枝、蒸馏和微调

模型原理图

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

修剪过程

  • 整个剪枝过程可以分为以下几个步骤:
    • 在下游任务上微调预训练模型。根据经验,在微调模型上剪枝的最终性能优于直接在预训练模型上剪枝。同时,这一步得到的Finetuned模型也将作为后面蒸馏训练的teacher模型。
    • 首先修剪注意力层。在这里,我们对注意力层权重应用块稀疏,如果头部被完全遮掩(mask),则直接修剪头部(压缩权重)。如果头部被部分遮盖,我们将不会修剪它并恢复其权重。
    • 通过蒸馏重新训练头部修剪模型。在修剪FFN层前恢复模型精度
    • 修剪FFN层。这里我们在第一层FFN层上应用输出通道修剪,由于第一层输出通道修剪,第二层FFN输入通道将被修剪。
    • 通过蒸馏重新训练最终修剪后的模型。
  • 在修剪Transformer过程中,我们获得了以下经验:
    • 我们在注意力层上使用Movement Pruner策略。在FFN层上使用Taylor FO Weight Pruner策略。这两个剪枝策略都是一些基于梯度的剪枝算法。
    • L1 Norm Pruner这样的基于权重的剪枝算法,但在这种场景下似乎效果不佳。
    • 蒸馏是恢复模型精度的好方法。从结果来看,在MNLI任务上对Bert进行剪枝通常可以实现1~2%的精度提升。
    • 有必要逐渐增加稀疏度,而不是一下子达到非常高的稀疏度。

实验

  • 在一台A100上,完整的修剪过程大约需要8小时。

准备

  • 在下游任务上得到一个微调模型(如果你熟悉如何在GLUE数据集上微调Bert,则可以跳过本节)
  • 进行一些基本设置,详情请看代码注释
from pathlib import Path
from typing import Callable, Dict

# 是否生成文档,这里不生成False
dev_mode = False

# 预训练模型名称
pretrained_model_name_or_path = 'bert-base-uncased'
# 任务名称mnli(自然语言推理)
task_name = 'mnli'
# 实验名称
experiment_id = 'pruning_bert_mnli'

# head_num和layer_num参数应该与预训练模型一致
heads_num = 12
layers_num = 12

# 保存实验记录
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
log_dir.mkdir(parents=True, exist_ok=True)

# 保存微调后的模型,并在具有相同预训练模型名称和任务名称实验之间共享。
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
model_dir.mkdir(parents=True, exist_ok=True)

# 保存GLUE数据集
data_dir = Path(f'./data')
data_dir.mkdir(parents=True, exist_ok=True)

# 固定随机数种子
from transformers import set_seed
set_seed(1024)

import torch
# 获取运行设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  • mnli(自然语言推理)数据的验证集有两个,通过将模型在这两个子集上的性能进行比较,可以获得更全面的评估结果,以便更好地了解模型的泛化能力和推理能力:
    • validation_matched:与训练数据类似的子集,其中包含了与训练集中的句子相似的句子对。模型在这个子集上进行评估时,要对给定的前提和假设进行推理判断,并将其分类为"蕴含"、"矛盾"或"中立"的关系。
    • validation_mismatched:与训练数据不同的子集,它包含了来自不同领域或流派的句子对。这个子集的目的是测试模型在面对与训练数据不同的文本类型时的泛化能力。模型在这个子集上的表现可以帮助评估其在不同上下文中的推理能力。
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import BertTokenizerFast, DataCollatorWithPadding

# 不同任务关键词
task_to_keys = {
    'cola': ('sentence', None),
    'mnli': ('premise', 'hypothesis'),
    'mrpc': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'qqp': ('question1', 'question2'),
    'rte': ('sentence1', 'sentence2'),
    'sst2': ('sentence', None),
    'stsb': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2'),
}

def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
    # 加载预训练模型
    tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
    # 获取sentence1、sentence2关键词
    sentence1_key, sentence2_key = task_to_keys[task_name]
    # DataCollatorWithPadding将具有不同长度序列数据组合成相同长度的batch,以便进行模型训练和推理
    data_collator = DataCollatorWithPadding(tokenizer)

    # 预处理原始数据
    def preprocess_function(examples):
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if 'label' in examples:
            # 将列重命名为标签,因为模型模型输入格式要求
            result['labels'] = examples['label']
        return result
    
    # 下载GLUE数据
    raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
    
    # 如果test关键字出现,删除
    for key in list(raw_datasets.keys()):
        if 'test' in key:
            raw_datasets.pop(key)
    
    # 将函数preprocess_function依次作用在raw_datasets的每个元素上
    processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                          remove_columns=raw_datasets['train'].column_names)

    train_dataset = processed_datasets['train']
    
    # 如果任务为mnli
    if task_name == 'mnli':
        # validation_matched:与训练数据类似的子集,其中包含了与训练集中的句子相似的句子对
        # validation_mismatched:与训练数据不同的子集,它包含了来自不同领域或流派的句子对
        validation_datasets = {
            'validation_matched:': processed_datasets['validation_matched'],
            'validation_mismatched': processed_datasets['validation_mismatched']
        }
    else:
        validation_datasets = {
            'validation': processed_datasets['validation']
        }
        
    # 训练数据加载器
    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
    # 验证数据加载器
    validation_dataloaders = {
        val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \
            for val_name, val_dataset in validation_datasets.items()
    }

    return train_dataloader, validation_dataloaders


# 获得训练、验证数据加载器
train_dataloader, validation_dataloaders = prepare_dataloaders()

训练、评估模型

  • 定义训练函数
import functools
import time

import torch.nn.functional as F
from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput

def training(model: torch.nn.Module,
             optimizer: torch.optim.Optimizer,
             criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
             max_steps: int = None,
             max_epochs: int = None,
             train_dataloader: DataLoader = None,
             distillation: bool = False,
             teacher_model: torch.nn.Module = None,
             distil_func: Callable = None,
             log_path: str = Path(log_dir) / 'training.log',
             save_best_model: bool = False,
             save_path: str = None,
             evaluation_func: Callable = None,
             eval_per_steps: int = 1000,
             device=None):
     
    # 检查训练加载器是否存在
    assert train_dataloader is not None
    
    # 将模型置为训练模式
    model.train()
    # 如果teacher_model不为空
    if teacher_model is not None:
        # 将teacher_model置为评估模式
        teacher_model.eval()
    
    # 当前step
    current_step = 0
    # 最佳结果
    best_result = 0

    # 总迭代次数,如果存在max_steps,则为max_steps // len(train_dataloader) + 1
    # 如果存在max_epochs,则为max_epochs,否则,默认为 3。
    total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
    # 如果存在max_steps,则将max_steps赋值给total_steps。否则,为每个训练轮数迭代次数乘以总训练轮数
    total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)

    print(f'Training {total_epochs} epochs, {total_steps} steps...')

    for current_epoch in range(total_epochs):
        # 取batch
        for batch in train_dataloader:
            # 如果当前step >= 总step
            if current_step >= total_steps:
                # 结束
                return
            # 将batch放入设备
            batch.to(device)
            # 得到输出
            outputs = model(**batch)
            # 得到loss
            loss = outputs.loss
            
            # 如果执行蒸馏操作
            if distillation:
                # 检查teacher_model是否存在
                assert teacher_model is not None
                # 不计算梯度
                with torch.no_grad():
                    # 得到teacher网络的输出
                    teacher_outputs = teacher_model(**batch)
                # 计算蒸馏loss    
                distil_loss = distil_func(outputs, teacher_outputs)
                # loss等于0.1 * model_loss + 0.9 * distil_loss
                loss = 0.1 * loss + 0.9 * distil_loss
                
            
            loss = criterion(loss, None)
            # 梯度清零
            optimizer.zero_grad()
            # 误差反向传播
            loss.backward()
            # 更新梯度
            optimizer.step()

            # 学习率策略更新
            if lr_scheduler:
                lr_scheduler.step()
            
            # 当前step自增
            current_step += 1
            
            # 如果达到检查节点
            if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
                # 返回模型在验证集上的结果
                result = evaluation_func(model) if evaluation_func else None
                # 输出日志
                with (log_path).open('a+') as f:
                    msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
                    f.write(msg)
                # 如果是最优模型,则保存
                if save_best_model and (result is None or best_result < result['default']):
                    # 检查save_path是否为空
                    assert save_path is not None
                    torch.save(model.state_dict(), save_path)
                    best_result = None if result is None else result['default']
  • 定义蒸馏模型loss
  • 首先函数定义了一个空列表encoder_hidden_state_loss,用于存储编码器隐藏状态的损失值。它使用一个循环来遍历encoder_layer_idxs的索引(除了最后一个索引),并计算学生模型输出的第i个隐藏状态与教师模型输出的第idx个隐藏状态之间的均方误差损失(MSE),并将其添加到encoder_hidden_state_loss列表中。
  • 然后计算logits损失。并使用 KL 散度(Kullback-Leibler divergence)来度量学生模型输出的logits与教师模型输出的logits之间的差异。具体来说,先对学生模型logits和教师模型logits进行对数softmax处理(缩放因子为2),然后使用F.kl_div函数计算它们之间的KL散度。reduction='batchmean'指定了计算KL散度的方式,并将结果乘以2 ** 2。
  • 最后将distil_loss初始化为0。然后,通过循环将encoder_hidden_state_loss列表中的每个损失值累加到distil_loss中。最后,将logits_loss加到distil_loss中,并返回最终蒸馏损失值。
# 蒸馏模型loss
def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
    # 存储编码器隐藏状态的损失值
    encoder_hidden_state_loss = []
    # 遍历encoder_layer_idxs的索引(除了最后一个)
    for i, idx in enumerate(encoder_layer_idxs[:-1]):
        # 计算学生模型输出的第i个隐藏状态与教师模型输出的第idx个隐藏状态之间的均方误差损失
        encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
    # 计算logits_loss损失,使用KL散度度量学生模型输出的logits与教师模型输出的logits之间的差异
    logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)

    distil_loss = 0
    # 将encoder_hidden_state_loss中loss累加
    for loss in encoder_hidden_state_loss:
        distil_loss += loss
    # 再加上KL散度损失
    distil_loss += logits_loss
    return distil_loss
  • 定义评估模型
  • 值得注意的是预测结果的获取,如果is_regression值为 False,即任务是分类任务,那么outputs.logits是一个张量,其形状为[batch_size, num_classes],表示模型对每个类别的预测分数。使用argmax(dim=-1)函数可以找到每个样本预测最大值所在的类别索引,这样可以得到模型对每个样本的分类预测。因此,predictions的形状是[batch_size],每个元素是一个整数,表示每个样本的预测类别。
  • 如果is_regression的值为True,即任务是回归任务,那么outputs.logits是一个张量,其形状为[batch_size, 1],表示模型对每个样本的回归预测值。使用squeeze()函数可以去除维度为1的维度,将outputs.logits形状变为[batch_size],这样得到每个样本的回归预测值。因此,predictions的形状是[batch_size],每个元素是一个实数,表示每个样本的回归预测值。
  • 本示例任务类型为MNLI所以is_regression值为 False,分类任务
# 定义评估函数
def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
    # 检查验证集数据加载器是否为空
    assert validation_dataloaders is not None
    # 指示当前模型是否处于训练模式,若为训练模式返回True
    training = model.training
    # 将模型置为评估模式
    model.eval()
    
    # 是否为回归
    is_regression = task_name == 'stsb'
    # 评估文本间文本相似性
    metric = load_metric('glue', task_name)

    result = {}
    default_result = 0
    for val_name, validation_dataloader in validation_dataloaders.items():
        # 取batch
        for batch in validation_dataloader:
            batch.to(device)
            outputs = model(**batch)
            #得到预测值
            predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
            metric.add_batch(
                predictions=predictions,
                references=batch['labels'],
            )
        result[val_name] = metric.compute()
        # 计算F1值和准确率
        default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
    result['default'] = default_result / len(result)
    
    # 将模型回归初始状态
    model.train(training)
    return result
  • 使用functools创建偏函数(Partial function),偏函数是指固定一个函数的部分参数,从而创建一个新的函数。这样可以将原始函数的部分参数值预先设置,以便在后续调用时只需提供剩余的参数值。
evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)

def fake_criterion(loss, _):
    return loss

预训练模型并微调

  • 生成预训练模型。由于MNLI数据集包含三个类别的标签,分别是entailment(蕴含),contradiction(矛盾),neutral(中性),需要将num_labels设置为 3,以确保模型的输出层能够适应这三个类别的分类任务。
  • 设置模型output_hidden_statesTrue,调用 BERT 模型,除了返回预测结果之外,还会返回一个包含所有隐藏层隐藏状态的列表。
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification


def create_pretrained_model():
    # 判断是否为回归任务
    is_regression = task_name == 'stsb'
    # 若为回归任务设为1,若不为回归任务并且task_name为mnli设为3,其他情况设为2
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    # 加载预训练模型,设定输出
    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
    model.bert.config.output_hidden_states = True
    return model
  • 生成微调模型
# 创建微调模型
def create_finetuned_model():
    # 将预训练模型赋给finetuned_model
    finetuned_model = create_pretrained_model()
    # 设置微调模型保存路径
    finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
    
    # 如果微调模型存在
    if finetuned_model_state_path.exists():
        # 直接加载模型,并指定模型在加载时应该被映射到的设备
        finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))
        # 转移模型到对应设备
        finetuned_model.to(device)
    elif dev_mode:
        pass
    # 如果微调模型不存在
    else:
        # 统计一个epoch中的steps
        steps_per_epoch = len(train_dataloader)
        # 训练epochs
        training_epochs = 3
        # 优化器为adam
        optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
        
        # 学习率更新函数
        def lr_lambda(current_step: int):
            return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
        # 学习率更新策略
        lr_scheduler = LambdaLR(optimizer, lr_lambda)
        # 微调模型为model,无teacher模型
        training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
                 max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
                 save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
    return finetuned_model

# 得到微调模型
finetuned_model = create_finetuned_model()

修剪

  • 根据经验,分阶段剪枝attention部分和FFN部分更容易取得好的效果。当然一起剪枝也可以达到类似的效果,只是需要更多的参数调整尝试。在本节中,我们分阶段进行剪枝。

修剪注意力层

  • 使用MovementPruner修剪策略修剪注意力层。MovementPruner是动态(Movement)修剪的一种实现。这是一种“微调”算法,这意味着掩码可能会在每个微调步骤中发生变化。在每一步中,每个权重元素将根据权重与其梯度的乘积之和的相反数进行评分。这意味着向零移动的权重元素将累积负分,远离零的权重元素将累积正分。得分较低的权重元素将在推理过程中被屏蔽。
  • MovementPruner函数参数:
参数 含义
model 要修剪的模型
config_list 策略设定
evaluator 评估器
warm_up_step 在开始修剪前预热step,该参数小于 cool_down_beginning_step
cool_down_beginning_step 稀疏度停止增长的步数,请注意,稀疏度停止增长并不意味着掩码没有改变
training_epochs 训练模型epoch,若同时设置了training_epochstraining_steps,则修剪将在达到其中一个时停止,该参数大于 cool_down_beginning_step
training_steps 训练模型steps,若同时设置了training_epochstraining_steps,则修剪将在达到其中一个时停止,该参数大于 cool_down_beginning_step
regular_scale 缩放运动分数常规损失在“软(soft)”模式下,更高的规则尺度意味着更高的最终稀疏度。
movement_mode “硬(hard)”或“软(soft)”
config_list 含义
sparsity 指定要压缩的配置中每一层稀疏性
sparsity_per_layer 等于sparsity
op_types L1NormPruner 支持 Conv2d 和 Linear
op_partial_names 要修剪的操作名称
exclude 设置为Trueop_typesop_names 图层将被排除在修剪之外
  • 初始化MovementPruner使用的评估器。
import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator

movement_training = functools.partial(training, train_dataloader=train_dataloader,
                                      log_path=log_dir / 'movement_pruning.log',
                                      evaluation_func=evaluation_func, device=device)
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)

def lr_lambda(current_step: int):
    if current_step < warmup_steps:
        return float(current_step) / warmup_steps
    return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))

traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
  • 设定MovementPruner参数
from nni.compression.pytorch.pruning import MovementPruner

# 统计一个epoch中的steps
steps_per_epoch = len(train_dataloader)

# 设定修剪steps或epochs
if not dev_mode:
    total_epochs = 4
    total_steps = total_epochs * steps_per_epoch
    warmup_steps = 1 * steps_per_epoch
    cooldown_steps = 1 * steps_per_epoch
else:
    total_epochs = 1
    total_steps = 3
    warmup_steps = 1
    cooldown_steps = 1

# 修剪,每一层encoder的attention中的Linear,稀疏率为0.1
config_list = [{
    'op_types': ['Linear'],
    'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
    'sparsity': 0.1
}]

pruner = MovementPruner(model=finetuned_model,
                        config_list=config_list,
                        evaluator=evaluator,
                        training_epochs=total_epochs,
                        training_steps=total_steps,
                        warm_up_step=warmup_steps,
                        cool_down_beginning_step=total_steps - cooldown_steps,
                        regular_scale=10,
                        movement_mode='soft',
                        sparse_granularity='auto')

# 执行剪枝
_, attention_masks = pruner.compress()
# 展示剪枝权重
pruner.show_pruned_weights()

# 保存attention_masks权重
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
  • 加载一个新的微调模型来做加速(speedup),可以认为这是使用微调状态来初始化剪枝后的模型权重。注意nni speedup不支持替换attention module,所以这里我们手动替换attention module
  • 如果头部(head)被整个屏蔽,则对其进行物理修剪并为FFN修剪创建config_list
# 重新创建一个新的微调模型
attention_pruned_model = create_finetuned_model().to(device)
# 取上面通过MovementPruner得到的attention_masks矩阵
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')

# 创建空列表
ffn_config_list = []
layer_remained_idxs = []
module_list = []

# 在模型初始设定部分,layers_nums = 12
for i in range(0, layers_num):
    prefix = f'bert.encoder.layer.{i}.'
    # 取各attention层mask矩阵权重
    value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
    # 观察头部是否被完全掩盖
    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
    # 取head_mask下标
    head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
    print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
    if len(head_idxs) != heads_num:
        attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
        # 加入module_list
        module_list.append(attention_pruned_model.bert.encoder.layer[i])
        # 最后的FFN剩余权重比率是attention剩余权重比率的一半
        # 这只是一个经验性的配置,你可以使用任何其他方法来确定这个稀疏度
        sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
        # 稀疏性计划,我们将在12次迭代中修剪FFN,每次迭代修剪sparsity_per_iter
        sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
        # FFN修剪计划配置:修剪intermediate.dense
        ffn_config_list.append({
            'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
            'sparsity': sparsity_per_iter
        })
        layer_remained_idxs.append(i)

# 将bert.encoder.layer替换为model_list
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
# 创建蒸馏损失偏函数
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
  • 通过蒸馏重新训练注意力修剪模型。
if not dev_mode:
    total_epochs = 5
    total_steps = None
    distillation = True
else:
    total_epochs = 1
    total_steps = 1
    distillation = False

# teacher模型
teacher_model = create_finetuned_model()
# Adam优化器
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)

# 学习率更新函数
def lr_lambda(current_step: int):
    return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))
# 学习率动态更新
lr_scheduler = LambdaLR(optimizer, lr_lambda)
# 模型保存路径
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
# student模型为attention_pruned_model,teacher模型为微调模型
training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
         max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
         distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
         evaluation_func=evaluation_func, device=device)

if not dev_mode:
    attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
  • 使用TaylorFOWeightPruner策略在 12 次迭代中迭代修剪 FFN。每次修剪迭代后微调 3000 步,然后在修剪完成后微调 2 个 epoch。
# 设定修剪steps或epochs
if not dev_mode:
    total_epochs = 7
    total_steps = None
    taylor_pruner_steps = 1000
    steps_per_iteration = 3000
    total_pruning_steps = 36000
    distillation = True
else:
    total_epochs = 1
    total_steps = 6
    taylor_pruner_steps = 2
    steps_per_iteration = 2
    total_pruning_steps = 4
    distillation = False

# 创建TaylorFOWeightPruner评估器
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup

distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
                                    teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)

# 当前step
current_step = 0
# 最好结果
best_result = 0
# 初始学习率
init_lr = 3e-5

dummy_input = torch.rand(8, 128, 768).to(device)

# 将模型置为训练模式
attention_pruned_model.train()
for current_epoch in range(total_epochs):
    for batch in train_dataloader:
        if total_steps and current_step >= total_steps:
            break
        # 用TaylorFOWeightPruner进行修剪,并重新初始化优化器
        if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
            # 模型权重
            check_point = attention_pruned_model.state_dict()
            # 修剪
            pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
            # 得到FFN掩码(mask)矩阵
            _, ffn_masks = pruner.compress()
            renamed_ffn_masks = {}
            # 重新命名掩码键,因为只是针对bert.encoder的速度
            for model_name, targets_mask in ffn_masks.items():
                renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
            # 将模型解包
            pruner._unwrap_model()
            # 加载模型权重
            attention_pruned_model.load_state_dict(check_point)
            # speedup操作
            ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
            # 重新定义优化器
            optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)

        batch.to(device)
        # 学习率更新策略
        for params_group in optimizer.param_groups:
            params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
            
        # 得到输出
        outputs = attention_pruned_model(**batch)
        # 计算loss
        loss = outputs.loss

        # 蒸馏
        if distillation:
            assert teacher_model is not None
            with torch.no_grad():
                teacher_outputs = teacher_model(**batch)
            distil_loss = distil_func(outputs, teacher_outputs)
            loss = 0.1 * loss + 0.9 * distil_loss
        
        # 梯度清零
        optimizer.zero_grad()
        # 误差反向传播
        loss.backward()
        # 梯度更新
        optimizer.step()

        current_step += 1
        
        # 输出模型中间精度,保存最佳模型
        if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
            result = evaluation_func(attention_pruned_model)
            with (log_dir / 'ffn_pruning.log').open('a+') as f:
                msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
                                                            current_epoch, current_step, result)
                f.write(msg)
            if current_step >= total_pruning_steps and best_result < result['default']:
                torch.save(attention_pruned_model, log_dir / 'best_model.pth')
                best_result = result['default']
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用NNI对BERT模型进行粗剪枝、蒸馏与微调 的相关文章

  • 【论文精读ACL_2021】Prefix-Tuning: Optimizing Continuous Prompts for Generation

    论文精读ACL 2021 Prefix Tuning Optimizing Continuous Prompts for Generation 0 前言 Abstract 1 Introduction 2 Related Work 2 1
  • LLM微调

    下面我只是分析讲解下这些方法的原理以及具体代码是怎么实现的 不对效果进行评价 毕竟不同任务不同数据集效果差别还是挺大的 文章目录 0 hard prompt soft prompt区别 1 Prefix Tuning 2 Prompt Tu
  • c语言六子棋(Alpha-Beta剪枝算法)

    c语言Alpha Beta剪枝算法六子棋 介绍 Alpha Beta剪枝算法是一种用于优化博弈树搜索的算法 可以在搜索过程中减少不必要的计算 从而提高搜索效率 该算法常用于博弈游戏 如六子棋 六子棋是一种类似于五子棋的棋类游戏 在一个六边形
  • 毕业设计-基于 BERT 的中文长文本分类系统

    目录 前言 课题背景和意义 实现技术思路 一 文本分类的相关技术 二 文本表示模型 三 文本分类模型 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力
  • Bert的NSP任务的loss原理

    Bert的NSP任务是预测上句和下句的关系 对一个句子的表征可以用CLS的embedding bert的NSP任务 NSP 是一个预测两段文本是否在原文本中连续出现的二元分类损失 NSP 是一种二进制分类损失 用于预测原始文本中是否有两个片
  • 一文读懂BERT(原理篇)

    一文读懂BERT 原理篇 2018年的10月11日 Google发布的论文 Pre training of Deep Bidirectional Transformers for Language Understanding 成功在 11
  • 【ACCV2022】论文阅读笔记Lightweight Alpha Matting Network Using Distillation-Based Channel Pruning

    Lightweight Alpha Matting Network Using Distillation Based Channel Pruning 使用基于蒸馏通道裁剪的轻量Alpha抠图网络 https arxiv org pdf 22
  • 【自然语言处理】ChatGPT 相关核心算法

    ChatGPT 相关核心算法 ChatGPT 的卓越表现得益于其背后多项核心算法的支持和配合 本文将分别介绍作为其实现基础的 Transformer 模型 激发出其所蕴含知识的 Prompt Instruction Tuning 算法 其涌
  • 机器学习——决策树剪枝

    目录 一 决策树剪枝策略 1 1剪枝目的 1 2剪枝策略 1 3判断决策树泛化性能是否提升的方法 二 预剪枝 prepruning 2 1概述 2 2预剪枝优缺点 2 3代码实现 三 后剪枝 postpruning 3 1概述 3 2后剪枝
  • 全排列 Ⅱ--回溯算法

    LeetCode 全排列 给定一个可包含重复数字的序列 返回所有不重复的全排列 示例 输入 1 1 2 输出 1 1 2 1 2 1 2 1 1 解法 回溯法 解题思路 思路很简单 因为要全排列 所以每一个数字都可能选择 即选择区间为 0
  • 【使用 BERT 的问答系统】第 1 章 : 自然语言处理简介

    大家好 我是Sonhhxg 柒 希望你看完之后 能对你有所帮助 不足请指正 共同学习交流 个人主页 Sonhhxg 柒的博客 CSDN博客 欢迎各位 点赞 收藏 留言 系列专栏 机器学习 ML 自然语言处理 NLP 深度学习 DL fore
  • 【NLP】1、BERT

    文章目录 一 背景 二 方法 论文 BERT Pre training of Deep Bidirectional Transformers for Language Understanding 出处 Google 一 背景 在 BERT
  • Bert: Pre-training of Deep Bidirectional Transformers for Language Understanding

    Abstract 我们介绍了一种语言表达模型称为BERT 也就是Transformer的双边编码表示 与当前语言表达模型不同 Peters et al 2018a Radford et al 2018 BERT设计通过考虑所有层左右上下文对
  • Bert机器问答模型QA(阅读理解)

    Github参考代码 https github com edmondchensj ChineseQA with BERT https zhuanlan zhihu com p 333682032 数据集来源于DuReader Dataset
  • 从浅到深理解bert

    更多查看https github com B C WANG AI Storage 4 2 4 2从浅到深理解bert 4 2 1 理解Attention 参考https www cnblogs com robert dlut p 86382
  • 【BERT类预训练模型整理】

    BERT类预训练模型整理 1 BERT的相关内容 1 1 BERT的预训练技术 1 1 1 掩码机制 1 1 2 NSP Next Sentence Prediction 1 2 BERT模型的局限性 2 RoBERTa的相关内容 2 1
  • 【论文翻译】【剪枝】Filter Pruning via Geometric Medianfor Deep Convolutional Neural Networks Acceleration

    目录 摘要 1 介绍 2 相关工作 3 方法 3 1 前置内容 3 2 基于规范的准则分析 3 3 真实场景中的Norm统计 3 4 通过几何中值进行过滤修剪 3 5 理论和现实加速 3 5 1理论加速 3 5 2 实际加速 4 实验 4
  • 【YOLOv7/YOLOv5系列算法改进NO.49】模型剪枝、蒸馏、压缩

    文章目录 前言 一 解决问题 二 基本原理 三 剪枝操作 四 知识蒸馏操作 前言 作为当前先进的深度学习目标检测算法YOLOv7 已经集合了大量的trick 但是还是有提高和改进的空间 针对具体应用场景下的检测难点 可以不同的改进方法 此后
  • BERT、BART、T5 等法学硕士的比较分析

    探索语言模型 介绍 在这篇博文中 我将讨论 BERT BART 和 T5 等大型语言模型 到 2020 年 法学硕士领域取得的重大进展包括这些模型的开发 BERT和T5是Google开发的 BART是Meta开发的 我将根据这些型号的发布日
  • YOLOv8改进开源

    大致介绍一下AI全栈技术社区的相关内容 主要涵盖了YOLO全系列模型的改进 量化 蒸馏 剪枝以及不同工具链的使用 同时也涵盖多目标跟踪 语义分割 3D目标检测 AI模型部署等内容 具体内容小伙伴们可以参考下面的目录部分 所有内容均有答疑服务

随机推荐

  • TensorRT部署YOLOv5(02)-环境介绍

    本文对TensorRT部署YOLOv5模型的整体环境配置及软件包进行介绍 实验环境主要从主机和JestonNano两方面进行介绍 在主机端完成模型训练并转换为onnx中间模型表示 在JestonNano进行onnx模型转换为TensorRT
  • cocos2dx 3.0 新建工程

    打开终端 转到cocos2d 3 0的这个目录下tools cocos2d console bin cocos py cocos py 运行命令 会显示该命令的帮助 根据帮助提示进行操作 有以下几个命令可用 compile Compiles
  • 游戏开发unity编译和调试系列:Unsafe code may only appear if compiling with /unsafe

    问题 Unsafe code may only appear if compiling with unsafe Enable Allow unsafe code in Player Settings to fix this error 解决
  • MySQL中的锁机制和MVCC

    MySQL中的锁和MVCC 概述 InnoDB的MVCC MVCC锁相关 SQL语句的加锁分析 RR隔离级别是如何解决幻读的 死锁 模拟死锁 在学习MySQL中的锁机制相关时搜集了几篇写得非常不错的博客 这里就不再花时间详细介绍 本篇仅做总
  • 批量异步更新策略及 nextTick 原理?

    在Vue中 当需要对多个响应式数据进行异步更新时 可以使用 批量异步更新策略 Vue在更新组件状态时 会将所有的异步更新合并成一个批量更新 从而避免不必要的重复渲染和提高性能 Vue中的异步更新策略是基于 事件循环 Event Loop 的
  • numpy 索引和切片

    目录 1 索引元素 2 切片 省略参数写法 3 应用举例 4 多维数组 二维数组获取一行 二维数组获取一列 5 多维数组的切片 如果要获取 第一行的第四和第五个元素 如果得到数组的最后两行和最后两列 如果得到数组的第三列 如果取出第3 5行
  • (局部特征)HOG+SVM,LBP,Haar

    在利用多幅二维图像进行三维重建 恢复场景三维结构的应用中 其基本出发点是要有一个可靠的图像对应点集合 而自动地建立图像之间 点与点之间的可靠对应关系通常都依赖于一个优秀的局部图像特征描述子 在物体识别中 目前非常流行以及切实可行的方法之一是
  • 记录QT4键盘无法置顶问题

    键盘类设置 QWidget setWindowFlags windowFlags Qt WindowStaysOnTopHint Qt FramelessWindowHint Qt X11BypassWindowManagerHint 莫名
  • bash脚本,自动输入sudo的密码

    解决方法 echo 管道 开始使用 echo admin sudo service tomcat7 stop 始终提示输入密码 后来查看了下sudo命令的使用 man sudo 发现有如下的解释 S The S stdin option c
  • go并发模型GPM

    线程模型的三种实现方式 用户级线程 M 1对应关系 多个用户态线程对应着一个内核线程 用户态线程的创建 终止 切换 同步等线程工作必须由自身来完成 内核级线程 1 1对应关系 直接调用操作系统的内核线程 所有线程的创建 终止 切换 同步等操
  • mysql数据库丢库记

    2021 1 29 记 环境 Windows10 mysql5 6 Navicat12 原因 电脑突然关机 启动后通过Navicat连接发现少了一个库 让人无比迷茫 为啥会出现丢库现象 解决方案 无比庆幸的是在几个小时之前 我对丢失的库做过
  • 赋值运算符的错误使用举例

    和 运算符的错误使用 运算符在C语言中为赋值运算符 运算符为检查符号两边的值是否是相等的操作 当某些程序的本意是比较运算符的时候 可能会无疑写成了赋值运算符 比如下面的程序 本意是要检查x是否等于y 但是实际上是把y的值的赋值给了x 然后去
  • android studio 编译时出现 Build completed with 1 failures错误

    okhttp keep class okhttp3 keep interface okhttp3 dontwarn okhttp3 出现 Build completed with 1 failures 这个问题上面只是导致这个错误的一种 出
  • 嵌入式linux基础学习全套精品视频教程

    嵌入式linux基础学习全套精品视频教程 在给大家分享教程之前 首先给大家简要的介绍一下嵌入式linux 嵌入式linux 是将日益流行的Linux操作系统进行裁剪修改 使之能在嵌入式计算机系统上运行的一种操作系统 嵌入式linux既继承了
  • nginx中自动在二级目录后加斜杠

    有时候在Nginx的伪静态中 二级目录无法自动加斜杠 导致目录无法打开 可以加 optimize server names off server name in redirect off 但是又时候可能不通用 所以在虚拟主机中加入一段万能的
  • kuangbin的模板

    直接链接 间接链接
  • 使用DbHelperSQL调用存储过程的方法

    下面代码是个调用存储过程的例子 对于学习怎么使用DbHelperSQL调用存储过程很有帮助
  • ceph 维护系列(二)--卸载osd

    一 摘要 本文主要介绍从ceph 某台节点上卸载一块或者多块osd 硬盘 二 环境信息 2 1 操作系统版本 root proceph05 cat etc centos release CentOS Linux release 7 6 18
  • SSM框架搭建,及遇到的问题

    SSM框架搭建 及遇到的问题 1 基本概念 1 1 Spring Spring是一个开源框架 Spring是于2003 年兴起的一个轻量级的Java 开发框架 由Rod Johnson 在其著作Expert One On One J2EE
  • 使用NNI对BERT模型进行粗剪枝、蒸馏与微调

    前言 模型剪枝 Model Pruning 是一种用于减少神经网络模型尺寸和计算复杂度的技术 通过剪枝 可以去除模型中冗余的参数和连接 从而减小模型的存储需求和推理时间 同时保持模型的性能 模型剪枝的一般步骤 训练初始模型 训练一个初始的神