大模型微调代码解析,哪些方法可以加速训练?

2023-10-27

近期大模型层出不穷,大家对于大模型的微调也在跃跃欲试,像Lijia的BELLE,斯坦福的Alpaca[1], 清华的ChatGLM[2],中文的Chinese-Vicuna[3],让我这样的普通玩家也能训练自己的微调模型。

在微调和推理的时候仍然需要加速,有哪些方法可以加速微调呢?

Part1LoRA

低秩矩阵分解 LoRA[4]原理:冻结预训练模型权重,并将可训练的秩分解矩阵注入到Transformer层的每个权重中,大大减少了下游任务的可训练参数数量。LoRA 开源代码[5]见文末。

原理图:

公式:

结合原理图和公式,我们可以很容易明白LoRA了:

左侧是预训练模型的权重,输入输出维度都是d,在训练期间被冻结,不接受梯度更新。

右侧,对A使用随机的高斯初始化,B在训练开始时为零,r是秩,会对△Wx做缩放 α/r。

HuggingFace的包peft[6]对LoRA做了封装支持,几步即可使用:

from peft import get_peft_model, LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    inference_mode=False, 
    r=8, 
    lora_alpha=32, 
    lora_dropout=0.1,
    target_modules=['query_key_value']
)

model = "加载的模型"
model = get_peft_model(model, peft_config)
# 打印参数情况
model.print_trainable_parameters()
接下来和正常训练模型一样

论文中提到了LoRA的诸多优点:

Part2Accelerate 和 deepspeed

Accelerate[7]库提供了简单的 API,使我们可以在任何类型的单节点或分布式节点(单CPU、单GPU、多GPU 和 TPU)上运行,也可以在有或没有混合精度(fp16)的情况下运行。

这里是我用Accelerator和DeepSpeedPlugin做个示例:

需要提前知道梯度累积步骤 gradient_accumulation_steps 和 梯度累积计算

from accelerate import Accelerator, DeepSpeedPlugin
import tqdm

model = ...

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2, 
    gradient_accumulation_steps=2)

accelerator = Accelerator(
    mixed_precision='fp16', 
    gradient_accumulation_steps=2, 
    deepspeed_plugin=deepspeed_plugin)

device = accelerator.device
... ...
optimizer = ...
lr_scheduler = ...

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

for epoch in range(epochs):
    total_loss = 0
    for step, batch in enumerate(t:=tqdm.tqdm(train_dataloader)):
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss_detach = outputs.loss.detach().cpu().float()
            t.set_description(f"loss: {loss_detach}")
            total_loss += loss_detach
            loss = outputs.loss
            # 不再是 loss.backward()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    # 每个epoch 保存
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.save(model.state_dict(accelerator.unwrap_model(model), '/saved/model.pt')

# 其他参考保存方法
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(save_dir, 
                                save_function=accelerator.save, 
                                state_dict=accelerator.get_state_dict(model))

Part3Autocast 自动混合精度

autocast是在GPU上训练时一种用于降低显存消耗的技术。原理是用更短的总位数来保存浮点数,能够有效将显存消耗降低,从而设置更大的batch来加速训练。但会造成精度的损失,导致收敛效果也会变差。

PyTorch的AMP有2种精度是torch.FloatTensor和torch.HalfTensor。

使用方法:

from torch.cuda.amp import autocast as autocast, GradScaler

dataloader = ...
model = model.cuda()
optimizer = ...
scheduler = ...
# scaler的大小在每次迭代中动态估计,为了尽可能减少梯度,scaler应该更大;
# 但太大,半精度浮点型又容易 变成inf或NaN.
# 动态估计原理就是在不出现if或NaN梯度的情况下,尽可能的增大scaler值。 
scaler = GradScaler()

for epoch in range(epochs):
    for batch_idx, (data, targets) in enumerate(train_dataloader):
        optimizer.zero_grad()
        data = data.cuda(0)
        with autocast(dtype=torch.bfloat16): # 自动混精度
            logits = model(data)
            loss = loss(logits, targets)
        # 反向传播梯度放大
        scaler.scale(loss).backward()
        # 首先 把梯度值unscale回来, 优化器中的值也需要放缩
        # 如果梯度值不是inf或NaN, 则调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新。
        scaler.step(optimizer)
        # 看是否要增大scaler, 更新scaler
        scaler.update()

Part4单机多GPU、多机多卡

如果条件允许的话,可以使用单机多卡和多机多卡分布式训练。

那么:

  • 模型怎么同步参数与梯度?

  • 数据怎么划分到多个GPU中?

pytorch框架给我们封装了对应的接口函数:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

PyTorch提供的torchrun命令以及一些API封装了多进程的实现。 我们只要在普通【单进程程序前后】加入: 开头 setup()和 结尾 cleanup()

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost' # ip
    os.environ['MASTER_PORT'] = '8848'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

就能用多个进程来运行训练程序,每个进程分配一个GPU,我们可以用dist.get_rank()来查看当前进程的GPU号的。

setup()

rank = dist.get_rank()
print(f'Current rank {rank}')
pid = os.getpid()
print(f'current pid: {pid}')
device_id = rank % torch.cuda.device_count()

1数据并行:

只要在生成Dataloader时,把DistributedSampler的实例传入sampler参数就行了,DistributedSampler会自动对数据采样,并放到不同的进程中。这里需要注意的是:sampler自动完成了打乱数据集的作用,所以在定义DataLoader时,不用再开启shuffle选项

dataset = MyDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

2模型并行

在并行训练时,各个进程并行,每个模型使用同一份模型参数 weights。在梯度下降时,各个进程会同步一次,致使每个进程的模型都更新相同的梯度。

做法也很简单,只需要把Model套一层DistributedDataParallel,就可以实现backward的自动同步梯,其他的操作都照旧,把新模型ddp_model当成旧模型model调用就行。

model = MyModel().to(device_id)
ddp_model = DistributedDataParallel(model, device_ids=[device_id])
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

训练流程照常:

在每个新epoch中,要用sampler.set_epoch(epoch)更新sampler打乱数据集。训练流程和普通深度学习训练流程一样。

# 这里根据自己的数据格式修改一下
for epoch in range(2):
    sampler.set_epoch(epoch)
    for data in dataloader:
        print(f'epoch {epoch}, rank {rank} data: {data}')
        data = data.to(device_id)
        y = ddp_model(data)
        optimizer.zero_grad()
        loss = loss_fn(data, y)
        loss.backward()
        optimizer.step()

3模型保存和读取:

在保存的时候,我们只需要保存一个进程下的模型即可,另外使用barrier()确保进程1在进程0保存模型之后加载模型。

存储参数时会保存设备信息。由于刚刚只保存了0号GPU进程的模型,所有参数的device都是cuda:0。而读取模型时,每个设备上都要去加载这个模型,device要做一个调整。

# 保存模型。
# 由于每个进程的模型都是一样的,我们只需要保存一个进程下的模型即可。
if rank == 0:
    torch.save(ddp_model.state_dict(), ckpt_path)
dist.barrier()

cleanup()

map_location = {'cuda:0': f'cuda:{device_id}'}
state_dict = torch.load(ckpt_path, map_location=map_location)
print(f'rank {rank}: {state_dict}')
ddp_model.load_state_dict(state_dict)

使用DistributedDataParallel把model封装成ddp_model后,模型的参数名称里多了一个module,这是因为原来的模型model被保存到了ddp_model.module这个成员变量中(model == ddp_model.module)。

在混用单GPU和多GPU的训练代码时,要注意这个参数名不兼容的问题,包括上面我们使用LoRA加载模型的时候,也会出现模型层名称变换了的情况。最好的做法是每次存取ddp_model.module,这样单GPU和多GPU的checkpoint可以轻松兼容。

END

大模型快速微调和训练是我们做自然语言处理必备技能之一,尤其现在大语言模型及其微调模型不断涌现,只有掌握了这些技能才能跟上AI的浪潮。

大模型微调代码解析,哪些方法可以加速训练? - mdnice 墨滴 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

大模型微调代码解析,哪些方法可以加速训练? 的相关文章

  • 【路径规划】基于A*算法路径规划研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现
  • 机器学习:基于python微博舆情分析系统+可视化+Django框架 K-means聚类算法(源码)✅

    博主介绍 全网粉丝10W 前互联网大厂软件研发 集结硕博英豪成立工作室 专注于计算机相关专业 毕业设计 项目实战6年之久 选择我们就是选择放心 选择安心毕业 感兴趣的可以先收藏起来 点赞 关注不迷路 毕业设计 2023 2024年计算机毕业
  • 问CHAT很繁琐的问题会不会有答案呢?

    问CHAT 什么已有的基于极值理论的极端温度重现期主要针对极端高温事件 对极端低温事件研究较少 CHAT 回复 为这主要可能是由于以下几个原因 1 气候变化与全球变暖 当前 全球变暖和气候变化的问题备受关注 这导致科研者更加关注极端高温事件
  • Soul App:年轻人的社交状态,还有多少种可能?

    查尔斯 狄更斯在 双城记 的开篇写下 这是最好的时代 这是最坏的时代 这是智慧的时代 这是愚蠢的时代 这是信仰的时期 这是怀疑的时期 人们面前有着各样事物 人们面前一无所有 既然万事万物都和狄更斯所说般 好坏参半 那又何必执着于过去 苦恼于
  • 利用CHAT上传文件的操作

    问CHAT autox js ui 上传框 CHAT回复 上传文件的操作如果是在应用界面中的话 由于Android对于文件权限的限制 你可能不能直接模拟点击选择文件 一般来说有两种常见的解决方案 一种是使用intent来模拟发送一个文件路径
  • 用CHAT分析高校体育智慧教学体系构建与探索研究现状

    CHAT回复 现阶段 高校体育智慧教学体系的构建与探索研究还处于初级阶段 但全球数字化转型大潮的推动下 一些较为前沿的研究和实践已经开始出现 1 教学平台的建设 很多高校已经开始尝试使用在线教育平台进行体育教学 把传统的面对面授课模式转变为
  • 毕业设计:基于卷积神经网络的验证码识别系统 机器视觉 人工智能

    目录 前言 设计思路 一 课题背景与意义 二 算法理论原理 2 1 字符分割算法 2 2 深度学习 三 检测的实现 3 1 数据集 3 2 实验环境搭建 3 3 实验及结果分析 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实
  • 无人机视角、多模态、模型剪枝、国产AI芯片部署

    无人机视角 多模态 模型剪枝 国产AI芯片部署是当前无人机技术领域的重要研究方向 其原理和应用价值在以下几个方面进行详细讲述 一 无人机视角 无人机视角是指在无人机上搭载摄像头等设备 通过航拍图像获取环境信息 并进行图像处理和分析 这种技术
  • 性能大减80%,英伟达芯片在华“遇冷”,我方霸气回应:不强求

    中国这么大一块市场 谁看了不眼馋 在科技实力大于一切的今天 高端芯片的重要性不言而喻 作为半导体产业发展过程中不可或缺的一环 芯片技术也一直是我国技术发展的一大 心病 在美西方等国的联手压制下 我国芯片技术发展处处受阻 至今也未能在高端芯片
  • 2024 人工智能与大数据专业毕业设计(论文)选题指导

    目录 前言 毕设选题 选题迷茫 选题的重要性 更多选题指导 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求的毕设项目越来越难 有不少课题是研究生
  • 如何快速申请GPT账号?

    详情点击链接 如何快速申请GPT账号 一OpenAI 1 最新大模型GPT 4 Turbo 2 最新发布的高级数据分析 AI画图 图像识别 文档API 3 GPT Store 4 从0到1创建自己的GPT应用 5 模型Gemini以及大模型
  • 机器学习算法实战案例:LSTM实现多变量多步负荷预测

    文章目录 1 数据处理 1 1 数据集简介 1 2 数据集处理 2 模型训练与预测 2
  • AI在保护环境、应对气候变化中的作用

    对于AI生命周期数据领域的全球领导者而言 暂时搁置我们惯常的AI见解和AI生命周期数据内容产出 来认识诸如世界地球日这样的自然环境类活动日 似乎是个奇怪的事情 我们想要知道 数据是否真的会影响我们的地球环境 简而言之 是 确实如此 但作为一
  • AI帮助终结全球饥饿问题

    全球饥饿问题是牵动人心的头等大事 5月28日是 世界饥饿日 这一问题更值得关注 让人人都能吃饱的想法不仅令人向往 而且很快就会变成现实 与大多数新事物引进一样 对于在控制世界粮食供应这样复杂的任务中AI究竟应该发挥多大的作用 人们还踟蹰不前
  • 【固定翼飞机】基于最优控制的固定翼飞机着陆控制器设计研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码及文章
  • CorelDRAW2024官方中文版重磅发布更新

    35年专注于矢量设计始于1988年并不断推陈出新 致力为全球设计工作者提供更高效的设计工具 CorelDRAW 滋养并见证了一代设计师的成长 在最短的时间内交付作品 CorelDRAW的智能高效会让你一见钟情 CorelDRAW 全称 Co
  • 15天学会Python深度学习,我是如何办到的?

    陆陆续续有同学向我们咨询 Python编程如何上手 深度学习怎么学习 如果有人能手把手 一对一帮帮我就好了 我们非常理解初学者的茫然和困惑 大量视频 书籍 广告干扰了大家的判断 学习Python和人工智能 成为内行人不难 为此 我们推出了
  • 自动驾驶离不开的仿真!Carla-Autoware联合仿真全栈教程

    随着自动驾驶技术的不断发展 研发技术人员开始面对一系列复杂挑战 特别是在确保系统安全性 处理复杂交通场景以及优化算法性能等方面 这些挑战中 尤其突出的是所谓的 长尾问题 即那些在实际道路测试中难以遇到的罕见或异常驾驶情况 这些问题暴露了实车
  • 两个月进口猛增10倍,买近百台光刻机,难怪ASML不舍中国市场

    据统计数据显示 2023年11月和12月 中国从荷兰进口的光刻机设备同比猛增10倍 进口金额超过19亿美元 让ASML赚得盆满钵满 ASML早前表示中国客户在2023年订购的光刻机全数交付 2023年11月中国进口的光刻机达到42台 进口金
  • Making Large Language Models Perform Better in Knowledge Graph Completion论文阅读

    文章目录 摘要 1 问题的提出 引出当前研究的不足与问题 KGC方法 LLM幻觉现象 解决方案 2 数据集和模型构建

随机推荐

  • Hadoop分布式集群时间同步(ntp)配置

    目录 时间服务器配置 必须root用户 1 查看所有节点ntpd 时间服务器 服务状态和开机自启状态 2 修改hadoop102的ntp conf配置文件 3 重新启动ntpd服务并设置开机自启 配置其他服务器 1 关闭所有节点上的ntpd
  • vm options什么意思_什么是锂电池保护板,保护板的基础知识和不良分析!

    点击上面 电动知家 可以订阅哦 锂电池保护板是对串联锂电池组的充放电保护 在充满电时能保证各单体电池之间的电压差异小于设定值 一般 20mV 实现电池组各单体电池的均充 有效地改善了串联充电方式下的充电效果 同时检测电池组中各个单体电池的过
  • SiriKit 新变化:让 Intent 更强大

    Python实战社群 Java实战社群 长按识别下方二维码 按需求添加 扫码关注添加客服 进Python社群 扫码关注添加客服 进Java社群 作者 wiilen iOS 开发者 来源丨老司机技术周报 ID LSJCoding Sessio
  • mysql 视图的作用

    转自 http blog csdn net fm0517 article details 5625949 视图是从一个或几个基本表 或视图 导出的表 它与基本表不同 是一个虚表 数据库只存放视图的定义 而不存放视图对应的数据 这些数据仍存放
  • JS赋值运算符详解

    赋值运算符左侧的操作数必须是变量 对象属性或数组元素 也称为左值 例如 下面的写法是错误的 因为左侧的值是一个固定的值 不允许操作 1 100 返回错误 赋值运算有以下两种形式 简单的赋值运算 把等号右侧操作数的值直接复制给左侧的操作数 因
  • [下载演讲稿]数字藏品与元宇宙存储—数字新世界的“土壤”

    和上次 下载 元宇宙存储 演讲稿 相比 增加了 1 两厅印发的 关于推进实施国家文化数字化战略的意见 对数字藏品的发展有积极促进作用 2 NFT和数字藏品的分类 新玩法 高质量体验 守诺 受朱嘉明老师 朱嘉明 数字经济和非同质时代 NFT
  • Java BigInteger的使用

    前言 在Java中 由CPU原生提供的整型最大范围是64位 long 型整数 使用 long 型整数可以直接通过CPU指令进行计算 速度非常快 但是如果我们使用的整数范围超过了 long 型怎么办 这个时候 就只能用软件来模拟一个大整数 j
  • unity 3D RPG高级教程(十)

    目录 声明 1 Action Button 快捷栏按键 2 Stats Info 显示 Player 相关信息 3 Change Animator 切换动画控制器 4 Item Tooltip 物品信息显示栏 5 Loot Items 掉落
  • misc.func.php,完美解决 discuz 您的管理面板已经锁定!

    出现 对不起 由于您多次输入错误密码 所以管理面板暂时锁定 您现在无法进入管理面板 15 分钟以后 锁定会自动解除 的提示 是出于安全的考虑 在您连续输入五次密码 仍然没有成功登陆的情况下所提示的 并且会在 15 分钟内禁止此 IP 再次登
  • mysql怎样设置默认值约束_MySQL默认值约束怎么用

    本篇文章将介绍default 默认约束 如何使用和改动后的效果 常用数据库约束 default 默认约束 not null 非空约束 指定某列不为NULL unique 唯一约束 指定某列和几列组合的数据不能重复 primary key 主
  • 【C/C++多线程编程之九】pthread读写锁

    多线程编程之读写锁 Pthread 是 POSIX threads 的简称 是POSIX的 线程标准 pthread读写锁把对共享资源的访问者分为读者和写者 读者只对共享资源进行读访问 写者只对共享资源进行写操作 在互斥机制 读者和写者都需
  • 详解Nodejs中的模块化

    Nodejs是一个基于Chrome V8引擎的JavaScript运行时环境 它允许开发者使用JavaScript在服务器端运行代码 在Nodejs中 模块化是一种组织和重用代码的重要方式 模块化允许我们将代码拆分成小块 使得代码结构更清晰
  • Windows10访问Ubuntu子系统(WSL)的桌面环境

    Windows10访问Ubuntu子系统 WSL 的桌面环境 文章目录 Windows10访问Ubuntu子系统 WSL 的桌面环境 Why Linux Why WSL 开启WSL Ubuntu换源 更新与升级 安装桌面环境xubuntu
  • Pytorch-YOLOV4-火焰目标检测

    首先感谢大佬提供的代码bubbliiiing 0 效果展示 1 所需环境 torch 1 2 0 2 注意事项 代码中的yolo4 weights pth是基于608x608的图片训练的 代码中的默认anchors是基于608x608的图片
  • SystemC在Ubuntu16.04上安装测试

    使用SystemC进行硬件仿真 环境 linux x86 64 bash g 下载解压SystemC SystemC下载地址 解压下载的包 tar zxvf systemc 2 3 3 tar gz 进入解压出来的目录 准备编译安装 cd
  • 自动控制原理快速入门+理解

    用最简单的话认识全貌 PS 默认都是线性系统 即输入和输出之间是线性的 默认你知道什么是线性 初步认识控制 假设你在推箱子 你推的力气是 f f f 箱子位移是 x x x 质量是
  • 09-----检测对方网络是否正常在线的方法

    1 使用ping ping是我们经常检测对方能否正常通信的方法 它使用的协议是传输层中的ICMP 下面是我自己抓包的内容 2 使用telnet telnet一般是检测某个端口是否开放 格式为 telnet ip port 3 使用nc命令
  • 汇编语言DW、DB和DD的区别

    DW 是定义2字节空间的意思 DW属于汇编的一个伪指令 DW定义字类型变量 一个字数据占2个字节单元 读完一个 偏移量加2 DB定义字节类型变量 一个字节数据占1个字节单元 读完一个 偏移量加1 DD定义双字类型变量 一个双字数据占4个字节
  • 7.26 二进制练习题

    给你个礼物你能收到吗 打开这个exe文件后 我们看到了它让我们输入礼物提取码 我们先随便输入数据 按回车显示提取码错误还有输错的次数 我们发现这里存在着一个循环 然后我们在IDA里面打开这个文件 int cdecl main int arg
  • 大模型微调代码解析,哪些方法可以加速训练?

    近期大模型层出不穷 大家对于大模型的微调也在跃跃欲试 像Lijia的BELLE 斯坦福的Alpaca 1 清华的ChatGLM 2 中文的Chinese Vicuna 3 让我这样的普通玩家也能训练自己的微调模型 在微调和推理的时候仍然需要