Llama2 指令精调脚本

2023-11-08

指令精调脚本

⚠️重要提示⚠️

  • 该代码仅适用于特定PEFT版本,运行脚本前请从源码安装commit id为13e53fc的Peft

  • 如果使用其他版本的PEFT或修改部分训练参数设置(如不使用deepspeed),不能保证模型可以正常训练。

  • 运行前确保拉取仓库最新版代码:git pull

训练步骤

进入项目的scripts/training目录,运行bash run_sft.sh进行指令精调,默认使用单卡。运行前用户应先修改脚本并指定相关参数,脚本中的相关参数值仅供调试参考。run_sft.sh的内容如下:

########参数部分########
lr=1e-4
lora_rank=64
lora_alpha=128
lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
modules_to_save="embed_tokens,lm_head"
lora_dropout=0.05

pretrained_model=path/to/hf/llama-2/or/merged/llama-2/dir/or/model_id
chinese_tokenizer_path=path/to/chinese/llama-2/tokenizer/dir
dataset_dir=path/to/sft/data/dir
per_device_train_batch_size=1
per_device_eval_batch_size=1
gradient_accumulation_steps=1
output_dir=output_dir
peft_model=path/to/peft/model/dir
validation_file=validation_file_name
max_seq_length=1024

deepspeed_config_file=ds_zero2_no_offload.json

########启动命令########
torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
    --deepspeed ${deepspeed_config_file} \
    --model_name_or_path ${pretrained_model} \
    --tokenizer_name_or_path ${chinese_tokenizer_path} \
    --dataset_dir ${dataset_dir} \
    --validation_split_percentage 0.001 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --do_train \
    --do_eval \
    --seed $RANDOM \
    --fp16 \
    --num_train_epochs 2 \
    --lr_scheduler_type cosine \
    --learning_rate ${lr} \
    --warmup_ratio 0.03 \
    --weight_decay 0 \
    --logging_strategy steps \
    --logging_steps 10 \
    --save_strategy steps \
    --save_total_limit 3 \
    --evaluation_strategy steps \
    --eval_steps 250 \
    --save_steps 500 \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --preprocessing_num_workers 8 \
    --max_seq_length ${max_seq_length} \
    --output_dir ${output_dir} \
    --overwrite_output_dir \
    --ddp_timeout 30000 \
    --logging_first_step True \
    --lora_rank ${lora_rank} \
    --lora_alpha ${lora_alpha} \
    --trainable ${lora_trainable} \
    --modules_to_save ${modules_to_save} \
    --lora_dropout ${lora_dropout} \
    --torch_dtype float16 \
    --validation_file ${validation_file} \
    --peft_path ${peft_model} \
    --gradient_checkpointing \
    --ddp_find_unused_parameters False

其中一些参数的含义不言自明。部分参数的解释如下:

  • --tokenizer_name_or_path: Chinese-LLaMA-2 tokenizer所在的目录。⚠️ 本项目中LLaMA-2模型与Alpaca-2模型使用相同的tokenizer,不再进行区分。
  • --dataset_dir: 指令精调数据的目录,包含一个或多个以json结尾的Stanford Alpaca格式的指令精调数据文件
  • --validation_file: 用作验证集的单个指令精调文件,以json结尾,同样遵循Stanford Alpaca格式
  • --flash_attn: 启用FlashAttention-2加速训练

Stanford Alpaca格式如下:

[
  {"instruction" : ...,
   "input" : ...,
   "output" : ...},
  ...
]

该脚本支持以下训练模式。不支持未在表格中的模式,如要修改请自行debug。

模型 model_name_or_path peft_path lora params
基于Chinese-LLaMA-2 LoRA进行指令精调 原版HF格式的LLaMA-2 Chinese-LLaMA-2 LoRA 无需指定
基于Chinese-Alpaca-2 LoRA进行指令精调 原版HF格式的LLaMA-2 Chinese-Alpaca-2 LoRA 无需指定
基于Chinese-LLaMA-2训练全新的指令精调LoRA权重 完整(合并Chinese-LLaMA-2-LoRA后)的HF格式Chinese-LLaMA-2模型 勿提供此参数,并且从脚本中删除 --peft_path 需设置--lora_rank--lora_alpha--lora_dropout--trainable--modules_to_save参数
基于Chinese-Alpaca-2训练全新的指令精调LoRA权重 完整(合并Chinese-Alapca-2-LoRA后)的HF格式Chinese-Alpaca-2模型 勿提供此参数,并且从脚本中删除 --peft_path 需设置--lora_rank--lora_alpha--lora_dropout--trainable--modules_to_save参数

这里列出的其他训练相关超参数(尤其是学习率,以及和total batch size大小相关的参数)仅供参考。请在实际使用时根据数据情况以及硬件条件进行配置。

节省显存小提示

  • 如果机器的显存比较紧张,可以删去脚本中的--modules_to_save ${modules_to_save} \, 即不训练embed_tokens和lm_head(这两部分参数量较大),只训练LoRA参数。
    • 如果是在已有LoRA基础上继续微调,需要修改peft_path下的adapter_config.json文件,改为"modules_to_save": null
    • 如果执行修改后程序报错,请删除--gradient_checkpointing \再尝试
  • 减小max_seq_length也可降低训练时显存占用,如可将block_size设置为512。

使用多机多卡

请参考以下启动方式:

torchrun \
  --nnodes ${num_nodes} \
  --nproc_per_node ${num_gpu_per_node} 
  --node_rank ${node_rank} \
  --master_addr ${master_addr} \
  --master_port ${master_port} \
  run_clm_sft_with_peft.py \
    ...

训练后文件整理

训练后的LoRA权重和配置存放${output_dir}/sft_lora_model,可用于后续的合并流程。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Llama2 指令精调脚本 的相关文章

随机推荐

  • Python爬虫可以干什么?Python入门必看!

    在爬虫领域 Python几乎是霸主地位 虽然C Java GO等编程语言也可以写爬虫 但Python更具优势 不仅拥有优秀的第三方库 还可以为我们做很多的事情 那么Python爬虫可以干什么 Python爬虫有什么用 想必很多人都比较好奇
  • 【机考】华为OD2022.11.01机考题目思路与代码

    题目一 描述 输入一个长度为4的倍数的字符串 字符串中仅包含WASD四个字母 将这个字符串中的连续子串用同等长度的仅包含WASD的字符串替换 如果替换后整个字符串中WASD四个字母出现的频数相同 那么我们称替换后的字符串是 完美走位 求子串
  • keil5如何打开智能提示

    在使用keil中需要敲上许多重复代码 并且经常需要调用别人写好的包 这时候我们总不能每句代码都重复的敲一遍 这样不仅没有效率 还要去花时间记住许多自己或许不常用的代码 这时候就需要智能提示来帮助我们了 第一步 打开编辑Edit 目录里找到设
  • Kubernetes(k8s)安装和搭建集群时kubeadm init失败

    Kubernetes k8s 按官方文档描述安装和搭建集群遇到kubelet状态异常 环境 Cenots 7 9 2009 adm64 我在搭建master节点时通过以下命令安装了docker kubelet kubectl kubeadm
  • 建立实体-关系模型(案例)

    一 标识实体 通常有用户 角色这两个实体 二 标识关系 用户与角色间为多对多的互相拥有关系 三 标识实体 关系的属性 不仅仅是实体有属性 关系同样也有属性 这些属性在实体间建立关系时才会存在 有时属性太多 无法在图上一一列出 可以用表格 在
  • AndroidStudio运行项目时的Run/debug configurations问题

    今天遇到的问题一个接一个 在调试项目时突然不能调试 但并没有报代码出错 看Logcat提示的是Android SDK没配置 还有一个明显不同之处 就是右上角那个显示当前项目名称的地方 显示的是app还有一个红叉 根据提示是配置Android
  • Spring Cloud Bus消息总线

    目录 一 概述简介 1 1 Bus是什么 1 2 Bus能干嘛 1 3 为何被称为总线 二 RabbitMQ环境配置 2 1 windows下载与安装 2 2 使用RabbitMQ 三 Bus动态刷新全局广播 3 1 Bus设计思想 3 2
  • PHP 获取当天凌晨时间戳

    总结几种PHP 获取当天凌晨时间戳方法 首先设置时区 header Content type text html charset utf 8 设置北京时间为默认时区 date default timezone set PRC 方法一 当天的
  • Django Error——Requested setting INSTALLED_APPS, but settings are not configured.

    django core exceptions ImproperlyConfigured Requested setting INSTALLED APPS but settings are not configured You must ei
  • jupyter notebook主题、字体、字号管理工具

    jupyter notebook编写 调试代码非常方便 但是其默认主题和字体实在是太难看了 因此大家一般都有修改主题的想法 感谢GitHub上的大神提供了一款主题管理工具 网上已经有文章提出其使用方法 如 jupyter notebook
  • Servlet基础_0500_Application

    一 application概念 application即ServletContext 能够被所有的客户端页面共享 不同的浏览器 不同电脑上的浏览器 演示 ServletContextTest java package com servlet
  • docker下使用apt install报错E: Unable to locate package

    解决方法 方法1 方法2 问题背景 由于docker环境是独立的 gcc vim等需要重新安装 输入安装命令 sudo apt install gcc 7 报错 E Unable to locate package gcc 7 原因是软件源
  • airpods固件更新方法_AirPods Pro迎来首个固件更新,检查耳机版本及更新方法

    airpods pro AirPods Pro推出了一段时间 获得一致好评 但有不少bug存在 针对此 苹果推出了airpods Pro的Firmware 固件 更新 早前购买的AirPods Pro都是 2B584 版本 在11月15日
  • Linux网络发送流程概述

    Linux网络的数据发送 本文主要是学习一下有关Linux 基于Linux3 10 网络层数据写入的流程 在Linux中通过网络写入的数据是如何发送到设备层 socket数据写入 在应用层一般写入的往已经创建好的连接进行数据发送的都会使用s
  • ubuntu20.04下载谷歌浏览器

    第一步 打开终端输入 wget https dl google com linux direct google chrome stable current amd64 deb 第二步 在终端中输入 sudo apt install goog
  • 关于临时表空间问题总结

    oracle经常需要查数据库临时表空间大小 使用率 加表空间等 这里总结临时表空间相关的语句 0 查看实例的临时表空间 SELECT FROM dba tablespaces t where t CONTENTS TEMPORARY SEL
  • 移动端通用404页面

  • 【分享NVIDIA GTC 23大会干货】在 GPU 上使用 Video Codec SDK,CV-CUDA 和 TensorRT 加速现代云上视频应用 [SE51229]

    在 GPU 上使用 Video Codec SDK CV CUDA 和 TensorRT 加速现代云上视频应用 前言 基于现代 的视频流水线架构与运用场景 NVIDIA 视频处理的工具集 1 视频编解码工具 2 前后处理部分 CV CUDA
  • Python中的魔术方法详解

    介绍 在Python中 所有以 双下划线包起来的方法 都统称为 Magic Method 中文称 魔术方法 例如类的初始化方法 init Python中所有的魔术方法均在官方文档中有相应描述 但是对于官方的描述比较混乱而且组织比较松散 很难
  • Llama2 指令精调脚本

    指令精调脚本 重要提示 该代码仅适用于特定PEFT版本 运行脚本前请从源码安装commit id为13e53fc的Peft 如果使用其他版本的PEFT或修改部分训练参数设置 如不使用deepspeed 不能保证模型可以正常训练 运行前确保拉