如何从大型模型(BART)fine tune一个小模型及代码实现

2023-11-10

系列文章

  1. 如何从大型模型(BART)fine tune一个小模型及代码实现

  2. 文本自动摘要评价方法-金字塔方法

  3. pytorch 使用BART模型进行中文自动摘要

摘要

本文目的是从上游大型模型进行知识蒸馏以应用于下游自动摘要任务,主要总结了自动摘要目前面临的难题,BART模型的原理,与fine tune 模型的原理。对模型fine tune部分进行了代码复现,通过fine tune使得student模型能够在一块8G显存的GPU上进行训练。

论文标题:

  1. PRE-TRAINED SUMMARIZATION DISTILLATION
    url: https://arxiv.org/pdf/2010.13002.pdf
  2. BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
    url:https://arxiv.org/abs/1910.13461

自动摘要目前的问题

  • 自动摘要的输入长度不定,输出长度也不定
  • 需要被摘要的文本无统一结构,因此难以根据文章结构学习出合适的模型。
  • 传统的抽取式摘要不能很好概括文本信息

seq2seq

在推理模式中,即当我们想解码未知的输入序列时,我们会经历一个过程:
1)将输入序列编码为状态向量
2)从大小为1的目标序列开始(仅是序列开始字符)
3)将状态向量和1个字符的目标序列提供给解码器,以生成下一个字符的预测。
4)使用这些预测来采样下一个字符(argmax)。
5)将采样的字符追加到目标序列
6)重复上述过程直到生成序列结束字符或达到字符数限制。

模型

BART

BART是一种用于序列到序列模型预处理的去噪自编码器。它的训练方法是:
(1)用任意的噪声函数破坏文本,
(2)学习一个模型来重建原始文本。
在这里插入图片描述

  • 双向编码(类似BERT),单向解码
  • 训练前的任务包括随机打乱原始句子的顺序和一个新的填充方案,其中文本的范围被一个单一的掩码标记取代
    在这里插入图片描述

Fine-Tune

从一个训练好大型任务中直接迁移部分参数,再使用下移任务的训练集数据进行微调

在这里插入图片描述

  • Fine-Tune从一个模型(teacher model)进行参数的迁移得到新模型(student model)
  • 这篇论文student model复制teacher model的全部层,通过实验选取效果最好的3层decode:0, 5, 11.

Fine-Tune另外测评方法

Pseudo-labels
Fine-Tune的目的是获得和teacher模型一样的预测结果,即最小化损失函数:
在这里插入图片描述
Direct Knowledge Distillation (KD)
理想条件是student模型对下一个词的预测所产生的概率分布和teacher模型相同。

在这里插入图片描述
或者使得decode层输出隐状态相同

在这里插入图片描述
最终衡量标准
在这里插入图片描述
其中三个α基于不同损失函数以不同权重,其最终目的是使得加权损失函数最小

实验方案

beam search
预测的时候,假设词表大小为3,内容为a,b,c。beam size是2,decoder解码的时候:
1: 生成第1个词的时候,选择概率最大的2个词,假设为a,c,那么当前的2个序列就是a和c。
2:生成第2个词的时候,我们将当前序列a和c,分别与词表中的所有词进行组合,得到新的6个序列aa ab ac ca cb cc,计算每个序列的得分并选择得分最高2个序列,作为新的当前序列,假如为aa cb。
3:后面会不断重复这个过程,直到遇到结束符或者达到最大长度为止。最终输出得分最高的2个序列。
在这里插入图片描述
上图解释参考:
https://blog.csdn.net/weixin_43718786/article/details/116991489
提前停止
我们在任何满足满足以下条件的时间点停止训练:第五阶段结束,或者连续四次评估的分数不增加(一个完整的阶段)。
在使用全尺寸(完全复制)encod的实验中,我们在训练期间不改变它的参数。初步实验表明,这不会影响性能,但微调速度提高了5.6倍。
我们还冻结了位置和文字的embedding。

实验结果

在这里插入图片描述
从结果来看,保持encode层不变,从12个decode层中抽取三层是非常有效的解决方案。(兼顾成绩与训练时间)

在这里插入图片描述
摘要数据集(前4行)的得分为Rouge-2
成本测量为运行该方法所需的GPU小时数
Size为teacher和student两种模型中,decode层数的对比

超参数相关

在这里插入图片描述
在这里插入图片描述

代码复现

为了保持风格的一致,下面的部分解释用英语写

These codes below run in colab
first we will install transformer from hugging face

! pip install datasets transformers rouge-score nltk

在这里插入图片描述
success !

Loading the dataset

We will use the

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从大型模型(BART)fine tune一个小模型及代码实现 的相关文章

  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • 快速 shell 命令删除文本文件中的停用词

    我有一个 2GB 的文本文件 我正在尝试从此文件中删除经常出现的英语停用词 我有 stopwords txt 包含这样的 a an the for and I 使用 shell 命令 例如 tr sed 或 awk 执行此操作的快速方法是什
  • 使用正则表达式标记化进行 NLP 词干提取和词形还原

    定义一个函数 名为performStemAndLemma 它需要一个参数 第一个参数 textcontent 是一个字符串 编辑器中给出了函数定义代码存根 执行以下指定任务 1 对给出的所有单词进行分词textcontent 该单词应包含字
  • openNLP 与 Solr 集成时出现异常

    我正在尝试将 openNLP 与 Solr 6 1 0 集成 我配置了架构和 solrconfig 文件 详细信息请参见 wiki 链接 https wiki apache org solr OpenNLP https wiki apach
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 如何检测文本是否可读?

    我想知道是否有一种方法可以告诉给定的文本是人类可读的 我所说的人类可读的意思是 它有一些含义 格式就像某人写的文章 或者至少是由软件翻译器生成的供人类阅读的文章 这是背景故事 最近我正在制作一个应用程序 允许用户将短文本上传到数据库 在部署
  • AttributeError:类型对象“Word2Vec”没有属性“load_word2vec_format”

    我正在尝试实现 word2vec 模型并收到属性错误 AttributeError 类型对象 Word2Vec 没有属性 load word2vec format 下面是代码 wv Word2Vec load word2vec format
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下
  • 使用 PyTorch 分布式 NCCL 连接失败

    我正在尝试使用 torch distributed 将 PyTorch 张量从一台机器发送到另一台机器 dist init process group 函数正常工作 但是 dist broadcast 函数中出现连接失败 这是我在节点 0
  • 如何使用动词时态/语气制作稀疏匹配器模式?

    我一直在尝试使用动词时态和情绪为 spacy 匹配器创建一个特定的模式 我发现了如何使用 model vocab morphology tag map token tag 访问使用 spacy 解析的单词的形态特征 当动词处于虚拟语气模式
  • 对 FastAI 中的数据应用图像增强转换时出错

    我正在尝试复制这个 Kaggle 笔记本https www kaggle com tanlikesmath diabetic retinopathy with resnet50 oversampling https www kaggle c
  • 在requirements.txt中包含.whl安装

    如何将其包含在requirements txt 文件中 对于Linux pip install http download pytorch org whl cu75 torch 0 1 12 post2 cp27 none linux x8
  • 如何使用FeatureUnion转换PipeLine中的多个特征?

    我有一个 pandas 数据框 其中包含有关用户发送的消息的信息 对于我的模型 我感兴趣的是预测消息的缺失收件人 即给定消息的收件人 A B C 我想预测还有谁应该成为收件人的一部分 我正在使用 OneVsRestClassifier 和
  • Google Colab 使用 Transformers 和 PyTorch 微调 BERT Base Case 时出现间歇性“RuntimeError: CUDA out of memory”错误

    我正在运行以下代码来微调 Google Colab 中的 BERT Base Cased 模型 有时代码第一次运行良好 没有错误 其他时候 相同的代码使用相同的数据 会导致 CUDA 内存不足 错误 以前 重新启动运行时或退出笔记本 返回笔
  • 如何同时有效地运行多个 Pytorch 进程/模型? Traceback:分页文件太小,无法完成此操作

    背景 我有一个非常小的网络 我想用不同的随机种子进行测试 该网络几乎只使用了我的 GPU 计算能力的 1 因此理论上我可以同时运行 50 个进程来同时尝试许多不同的种子 Problem 不幸的是我什至无法在多个进程中导入 pytorch 当
  • Pytorch - 推断线性层 in_features

    我正在构建一个玩具模型来获取一些图像并进行分类 我的模型看起来像 conv2d gt pool gt conv2d gt linear gt linear 我的问题是 当我们创建模型时 我们必须计算第一个线性层的大小in features基
  • 如何将 35 类城市景观数据集转换为 19 类?

    以下是我的代码的一小段 使用它 我可以在城市景观数据集上训练名为 lolnet 的模型 但数据集包含 35 个类别 标签 0 34 imports trainloader torch utils data DataLoader datase
  • 无法在 Windows 10 上构建 Detectron2

    尽管 Windows 上的 Detectron2 没有官方支持 但有很多可用的说明 我尝试按照这些说明进行操作 但最终出现了相同的错误 这是我的设置 OS Windows 10 专业版 19043 1466 微软视觉工作室 2019 CUD
  • 对产品列表进行分类的算法? [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一个代表或多或少相同的产品的列表 例如 在下面的列表中 它们都是希捷硬盘 希捷硬盘 500Go 适用于笔记本电脑的希捷硬盘 120
  • 如何将句子或文档转换为向量?

    我们有将单词转换为向量的模型 例如 word2vec 模型 是否存在类似的模型 可以使用为单个单词学习的向量将句子 文档转换为向量 1 跳克法 以及使用它的工具 谷歌 word2vec https code google com p wor

随机推荐

  • Div点击显示再次点击隐藏

    1 先上效果 这是默认显示的时候 这是再点击隐藏的时候 下方代码贴出 有需要的C V直接浏览器查看
  • 【Unity基础】Input.GetAxis()函数

    根据输入设备 参数分为两类 一 触屏类 1 Mouse X 鼠标沿屏幕X移动时触发 2 Mouse Y 鼠标沿屏幕Y移动时触发 3 Mouse ScrollWheel 鼠标滚轮滚动是触发 二 键盘类 1 Vertical 键盘按上或下键时触
  • windows系统下,如何使用win+R快速打开安装的应用

    windows系统下 如何使用win R快速打开安装的应用 随着工作学习时间的增加 我们的桌面就会出现越来越多的文件和应用快捷方式 使得桌面变得很杂乱 有时候需要打开某个应用的时候就可能需要花费时间来找 那我们如何快速打开我们需要的应用呢
  • Layout的放大和缩小效果例子(ScaleAnimation)

    个Layout从中心放大和缩小的例子 直接上代码 1 ScaleDialog java文件 Java代码 package cn com import android app Activity import android graphics
  • TypeError: ‘function‘ object is not subscriptable

    关于错误 TypeError function object is not subscriptable 错误原因 get dummies函数 写成了
  • https网络编程——如何建立利用根证书(凭证)签发建立中继证书(凭证)详解

    参考 如何建立利用根证书 凭证 签发建立中继证书 凭证 详解 地址 https qingmu blog csdn net article details 108221568 spm 1001 2014 3001 5502 目录 在建立中继之
  • oracle 写入 权限设置,改变用户组文件的读写和执行权限

    网上找来一篇关于linux权限修改方式文章 对于我脑子记性不好的人有非常大的帮助 1 更改档案拥有者 命令 chown cfhvR help version user group file 功能 更改文件或者文件夹的拥有者 参数格式 use
  • c++读写文件

    目录 1 写文件 2 读文件 3 二进制方式写文件 4 3 二进制方式读文件 文件类型分为两种 文本文件 文件以文本的ASCII码形式存储在计算机中 二进制文件 文件以文本的二进制形式存储在计算机中 用户一般不能直接读懂它们 操作文件的三大
  • 小型机 PC服务器 性能,pc服务器小型机

    pc服务器小型机 内容精选 换一换 业务测试完成后或不再需要克隆服务器 您可参考本章节删除克隆服务器 删除克隆服务器后 请到弹性云服务器Console界面检查 使用主机迁移服务迁移Windows系统的源端服务器时 要求目的端服务器的磁盘大小
  • 进程间通讯的7种方式

    1 常见的通信方式 管道pipe 管道是一种半双工的通信方式 数据只能单向流动 而且只能在具有亲缘关系的进程间使用 进程的亲缘关系通常是指父子进程关系 命名管道FIFO 有名管道也是半双工的通信方式 但是它允许无亲缘关系进程间的通信 消息队
  • [架构之路-213]- 架构 - 架构设计过程快速概览与在线画图工具

    目录 第一步 业务系统 1 收集目标系统的用户需求 2 定义用例图 第二步 领域建模 1 业务流程定义 2 业务功能分解 3 非功能性架构 支撑架构 第三步 高层架构设计 1 应用展现层 2 业务功能层 3 框架支撑层 第四部 详解架构设计
  • 如何查gmail发件人ip_如何在Gmail中阻止来自特定发件人的电子邮件

    如何查gmail发件人ip There are some email senders from which you never want to hear You can t stop them from sending you emails
  • 瞳孔特征值提取,blink frequency,fixation frequency,saccad extent, pupil diameter等

    进行的分析有 滤波分析 fft psd database py 下面展示一些 内联代码片 import pandas as pd import numpy as np def read file raw path data pd DataF
  • unity 3d 原创制作射击游戏(一)

    目录 实验一 4 1 设计如下UI界面 其中包含了canvas Panel Text Button Image RawImage等UI元素 4 2 实现点击Play按钮转换场景 点击Exit退出游戏的功能 5 3 主界面添加音量滑动杆 静音
  • Flink1.11.0 SQL与hive整合

    一 前言 此次flink sql 整合 hive 主要是能在flink sql中读写hive数据 为flink实时写数据进入hive 构建实时数仓做准备工作 flink 1 11 0 hive 2 3 4 hadoop 2 7 2 主要步骤
  • 使用Python,OpenCV制作不同风格的素描图(正常,漫画,写实风格)

    使用Python OpenCV制作不同风格的素描图 正常 漫画 写实风格 这篇博客将介绍如何使用Python OpenCV制作不同风格的素描图 正常风格 漫画风格 写实风格 1 效果图 原始图 VS 正常风格素描图 VS 漫画风格素描图 V
  • 软件测试缺陷的定义、产生原因、缺陷报告格式、缺陷报告

    软件缺陷的定义 错误 静态存在于说明文档中的表述或编码错误 缺陷 存在于代码中或硬件系统中的错误 BUG 被测对象实际表现与用户显性需求或隐性需求中的差异 功能实现错误 功能实现遗漏 功能实现多余 功能实现不好 失效 因缺陷激发后导致功能的
  • 递归求斐波那契数列

    斐波那契数列 题目描述 编写一个函数 求斐波那契数列的第n项的值 首先 对于斐波那契数列 我们是非常熟悉了 对斐波那契定义为如下 f 0 0 f 1 0 f 2 1 f n f n 1 f n 2 其中n gt 1 对于这种求斐波那契数列第
  • Mockito(三)--完整功能介绍

    强烈建议不熟悉Mockito的同学先看看我写的Mockito 一 入门篇和 二 实例篇之后再来看这篇文章 因为只有看了前两篇文章才明白mockito的本质以及该如何使用它 下面是对Mockito全部功能的介绍 1 使用mockito验证行为
  • 如何从大型模型(BART)fine tune一个小模型及代码实现

    系列文章 如何从大型模型 BART fine tune一个小模型及代码实现 文本自动摘要评价方法 金字塔方法 pytorch 使用BART模型进行中文自动摘要 目录 系列文章 摘要 自动摘要目前的问题 seq2seq 模型 BART Fin