【自然语言处理】大模型高效微调:PEFT 使用案例

2023-11-11

一、PEFT介绍

PEFT(Parameter-Efficient Fine-Tuning,参数高效微调),是一个用于在不微调所有模型参数的情况下,高效地将预训练语言模型(PLM)适应到各种下游应用的库。

PEFT方法仅微调少量(额外的)模型参数,显著降低了计算和存储成本,因为对大规模PLM进行完整微调的代价过高。最近的最先进的PEFT技术实现了与完整微调相当的性能。

代码:

https://github.com/huggingface/peft

文档:

https://huggingface.co/docs/peft/index

二、PEFT 使用

接下来将展示 PEFT 的主要特点,并帮助在消费设备上通常无法访问的情况下训练大型预训练模型。您将了解如何使用LoRA来训练1.2B参数的bigscience/mt0-large模型,以生成分类标签并进行推理。

2.1 PeftConfig

每个 PEFT 方法由一个PeftConfig类来定义,该类存储了用于构建PeftModel的所有重要参数。

由于您将使用LoRA,您需要加载并创建一个LoraConfig类。在LoraConfig中,指定以下参数:

  • task_type,在本例中为序列到序列语言建模
  • inference_mode,是否将模型用于推理
  • r,低秩矩阵的维度
  • lora_alpha,低秩矩阵的缩放因子
  • lora_dropout,LoRA层的dropout概率
from peft import LoraConfig, TaskType

peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

有关您可以调整的其他参数的更多详细信息,请参阅LoraConfig参考。

2.2 PeftModel

使用 get_peft_model() 函数可以创建PeftModel。它需要一个基础模型 - 您可以从 Transformers 库加载 - 以及包含配置特定 PEFT 方法的PeftConfig。

首先加载您要微调的基础模型。

from transformers import AutoModelForSeq2SeqLM

model_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解您模型中可训练参数的数量,可以使用print_trainable_parameters方法。在这种情况下,您只训练了模型参数的0.19%!

from peft import get_peft_model

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# 输出示例: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

至此,我们已经完成了!现在您可以使用Transformers的Trainer、 Accelerate,或任何自定义的PyTorch训练循环来训练模型。

2.3 保存和加载模型

在模型训练完成后,您可以使用save_pretrained函数将模型保存到目录中。您还可以使用push_to_hub函数将模型保存到Hub(请确保首先登录您的Hugging Face帐户)。

model.save_pretrained("output_dir")

# 如果要推送到Hub
from huggingface_hub import notebook_login

notebook_login()
model.push_to_hub("my_awesome_peft_model")

这只保存了已经训练的增量PEFT权重,这意味着存储、传输和加载都非常高效。例如,这个在RAFT数据集的twitter_complaints子集上使用LoRA训练的bigscience/T0_3B模型只包含两个文件:adapter_config.json和adapter_model.bin,后者仅有19MB!

使用from_pretrained函数轻松加载模型进行推理:

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

三、PEFT支持任务

3.1 Models support matrix

3.1.1 Causal Language Modeling

在这里插入图片描述

3.1.2 Conditional Generation

在这里插入图片描述

3.1.3 Sequence Classification

在这里插入图片描述

3.1.4 Token Classification

在这里插入图片描述

3.1.5 Text-to-Image Generation

在这里插入图片描述

3.1.6 Image Classification

在这里插入图片描述

3.1.7 Image to text (Multi-modal models)

在这里插入图片描述

四、PEFT原理

4.1 LoRA

LoRA(Low-Rank Adaptation)是一种技术,通过低秩分解将权重更新表示为两个较小的矩阵(称为更新矩阵),从而加速大型模型的微调,并减少内存消耗。

为了使微调更加高效,LoRA的方法是通过低秩分解,使用两个较小的矩阵(称为更新矩阵)来表示权重更新。这些新矩阵可以通过训练适应新数据,同时保持整体变化的数量较少。原始的权重矩阵保持冻结,不再接收任何进一步的调整。为了产生最终结果,同时使用原始和适应后的权重进行合并。

4.2 Prompt tuning

训练大型预训练语言模型是非常耗时且计算密集的。随着模型尺寸的增长,越来越多的人对更高效的训练方法产生了兴趣,例如提示(Prompting)。提示通过包括描述任务的文本提示或甚至演示任务示例的文本提示来为特定的下游任务准备一个冻结的预训练模型。通过使用提示,您可以避免为每个下游任务完全训练单独的模型,而是使用相同的冻结预训练模型。这更加方便,因为您可以将同一模型用于多个不同的任务,而训练和存储一小组提示参数要比训练所有模型参数要高效得多。

提示方法可以分为两类:

  • 硬提示(Hard Prompts):手工制作的具有离散输入标记的文本提示;缺点是需要花费很多精力来创建一个好的提示。
  • 软提示(Soft Prompts):可与输入嵌入连接并进行优化以适应数据集的可学习张量;缺点是它们不太易读,因为您不是将这些“虚拟标记”与实际单词的嵌入进行匹配。

4.3 IA3

为了使微调更加高效,IA3(通过抑制和放大内部激活来注入适配器)使用学习向量对内部激活进行重新缩放。这些学习向量被注入到典型的基于Transformer架构中的注意力和前馈模块中。这些学习向量是微调过程中唯一可训练的参数,因此原始权重保持冻结。处理学习向量(而不是像LoRA一样对权重矩阵进行学习的低秩更新)可以大大减少可训练参数的数量。

与LoRA类似,IA3具有许多相同的优点:

  • IA3通过大大减少可训练参数的数量使微调更加高效(对于T0模型,IA3模型仅具有约0.01%的可训练参数,而即使是LoRA也有超过0.1%)。
  • 原始的预训练权重保持冻结,这意味着您可以在其之上构建多个轻量级和便携的IA3模型,用于各种下游任务。
  • 使用IA3进行微调的模型性能与完全微调的模型性能相当。
  • IA3不会增加任何推理延迟,因为适配器权重可以与基础模型合并。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【自然语言处理】大模型高效微调:PEFT 使用案例 的相关文章

随机推荐

  • This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.

    困难 是智者的机遇 是人与人差距所在 疑惑 D pytools anaconda PyCharm 2018 3 5 helpers pycharm matplotlib backend backend interagg py 62 User
  • Vue项目中grid布局的应用

    Vue项目中grid布局的应用 一 使用背景 二 常见属性 1 grid template 属性 1 1 columns列相关配置 1 1 1 指定列的个数 1 1 2 auto fill属性 自动填充 1 1 3 fr 比例关系 1 1
  • windows向linux传送文件

    windows与Linux之间传送文件 1 用putty的内置小组件PSCP exe 此法可行 pscp exe 可从putty官方下载 然后放到 windows 的c windows system32目录下 这样cmd 命令提示符窗口 输
  • linux下使用ffmpeg录屏

    linux系统中 使用ffmpeg进行录屏与截图 把 dev fb0设备的framebuffer显示图像录制为视频 ffmpeg f fbdev framerate 10 i dev fb0 out avi 编码帧率默认值为25fps 把
  • Android查看应用签名方法

    查看keystore文件签名 查看keystore文件签名信息 前提要有keystore文件和密钥 才能够获取keystore文件的签名信息 打开 AS工具窗口栏右边的 Gradle gt Project gt app gt Tasks g
  • QtCreator设置多个qmake

    qt Creator 有时候需要设置不同qt库文件 也就是不同qmake 我们可以设置 1 Tools gt KIts 然后选择Manual gt add 然后添加Name写5 15或者其它名字 然后点击Qt Version gt Manu
  • PID算法(没办法完全理解的东西)

    快速 P 准确 I 稳定 D P Proportion 比例 就是输入偏差乘以一个常数 I Integral 积分 就是对输入偏差进行积分运算 D Derivative 微分 对输入偏差进行微分运算 输入偏差 读出的被控制对象的值 设定值
  • 24. 二叉搜索树的最近公共祖先

    题目链接 235 二叉搜索树的最近公共祖先 大概思路 题目要求 给定一颗二叉搜索树 两个确定值q p 要求q p的最近公共祖先 思路 利用搜索树的特性 当q p的值均小于遍历的节点值的时候 可以判断q p均在根节点的左子树上 小于则在右子树
  • DUKE大学BOE数据集 OCT图像积液分割数据集

    使用此数据集用来做积液分割研究 地址 http people duke edu sf59 Chiu BOE 2014 dataset htm 使用python将 mat转换为图片格式 对BOE MAT格式文件处理成图片 import cv2
  • 数据生成

    数据生成 MATLAB实现MCMC马尔科夫蒙特卡洛模拟的数据生成 目录 数据生成 MATLAB实现MCMC马尔科夫蒙特卡洛模拟的数据生成 生成效果 基本描述 模型描述 程序设计 参考资料 生成效果 基本描述 1 MATLAB实现MCMC马尔
  • java常见轮询算法

    轮询算法 轮询算法就是通过一个算法 对提供的一组列表进行计算 按照一定规则取出列表中的元素 常见的有顺序模式 随机模式 加权模式 加权平滑模式 定义轮询算法的接口 轮询算法接口 public interface Balance
  • 计费服务器不响应,按小时计费的服务器不开机会计费吗

    按小时计费的服务器不开机会计费吗 内容精选 换一换 按需付费是后付费方式 可以随时开通 删除弹性云服务器 支持秒级计费 系统会根据云服务器的实际使用情况每小时出账单 并从账户余额里扣款 按需付费的弹性云服务器关机再次开机时 可能会出现由于资
  • NMOS作为开关的两种接法

    NMOS作为开关的两种接法 1 左边电路负载是接在S极对地 如果R1很小且Q1 G极一直为High 那么流过Q1的电流可能将会非常大 MOS管容易烧 2 R1 I Us VGS Vg Vs 此时VGS不一定会大于Vgs th MOS会不完全
  • html抽奖概率,求一个可挑概率的html5抽奖 圆盘的

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼圆盘抽奖 margin 0 padding 0 elm1 height 40px background color a00 elm2 height 50px background color 0a
  • mysql库的安装

    编译文件时找不到mysql库 使用以下命令查看是否安装mysql库 dpkg l grep libmysqlclient dev 安装 sudo apt get install libmysqlclient dev 安装完成可以正常编译
  • Parallels Desktop 17 发布 针对M1大幅优化

    今天 Parallels 公司发布了 Parallels Desktop 17 它对 Windows 11 和 macOS Monterey 进行了适配优化 同时为基于Apple M1 和Intel 芯片的Mac进行图形 性能提升和生产力的
  • 【.NET8】访问私有成员新姿势UnsafeAccessor(上)

    前言 前几天在 NET性能优化群里面 有群友聊到了 NET8新增的一个特性 这个类叫 UnsafeAccessor 有很多群友都不知道这个特性是干嘛的 所以我就想写一篇文章来带大家了解一下这个特性 其实在很早之前我就有关注到这个特殊的特性
  • Windows 常用运行库下载 (DirectX、VC++、.Net Framework等)

    经常听到有朋友抱怨他的电脑运行软件或者游戏时提示缺少什么 d3dx9 xx dll 或 msvcp71 dll msvcr71 dll又或者是 Net Framework 初始化之类的错误而无法正常使用 其实很多时候 只是因为你的电脑没有安
  • kettle8 新插件开发 调试

    参考 https blog csdn net u013468915 article details 82629810 https blog csdn net zougen article details 80825751 基于eclipse
  • 【自然语言处理】大模型高效微调:PEFT 使用案例

    文章目录 一 PEFT介绍 二 PEFT 使用 2 1 PeftConfig 2 2 PeftModel 2 3 保存和加载模型 三 PEFT支持任务 3 1 Models support matrix 3 1 1 Causal Langu