ColossalAI-Chat训练手册(RLHF)

2023-11-05

目录

1、什么的RLHF流程?

2、环境安装

3、训练&运行

3.1、模型下载

3.1、SFT(supervised fine-tuning)

3.2、训练奖励模型(Training reward model)

3.3、RL(Training model using prompts with RL)

3.4、使用模型进行应答

3.5、playground

3.6、应答效果

4、异常记录

4.1 llama爆显存

4.2 bloom模型报Error while deserializing header: HeaderTooLarge

4.3 wandb异常

4.4 RL 训练爆显存

4.5 模型加载应答



源码:GitHub - hpcaitech/ColossalAI: Making large AI models cheaper, faster and more accessible

官网:快速演示 | Colossal-AI

官方demo:ColossalChat

ColossalAI-Chat是一款基于人工智能技术的智能聊天机器人,是由Colossal AI开发的一款聊天机器人。该机器人使用了最先进的自然语言处理技术和深度学习算法,可以回答各种问题、提供建议、提供娱乐和与用户进行轻松对话。ColossalAI-Chat可以在多种平台上使用,例如Facebook Messenger、Slack、WeChat等。

ColossalAI-Chat通过使用自然语言处理技术和深度学习算法,机器人可以理解人类语言的含义,从而生成更加自然和准确的回答。在聊天过程中,机器人可以不断学习和优化自己的回答能力,提高其整体的智能水平。

随着ChatGPT的火爆,业界内也有很多机构开始着手训练自己的大语言模型,比如百度的文心一言,阿里的通义千问等。那么训练自己的模型,需要做些什么呢?RLHF流程又该如何复现?ColossalAI开源了一套方案,但是在复现过程中也有很多坑,接下来看看如何复现吧。

1、什么的RLHF流程?

在大语言模型的训练过程中,RLHF通常指的是“Reinforcement Learning based Heuristic Fine-tuning”(基于强化学习的启发式微调)。RLHF是指在训练大型语言模型时,使用强化学习算法对模型进行微调,以进一步提高其性能。RLHF的主要目标是通过引入额外的语言模型内部评估指标,使得语言模型在生成文本时更加准确和流畅。

RLHF可以分为以下几个阶段:

  • 预训练阶段(Pre-training):在此阶段中,使用大量的未标注文本数据来训练初始的语言模型,通常使用无监督学习算法,如BERT、GPT等。
  • 微调阶段(Fine-tuning):在此阶段中,使用有标注的任务数据对语言模型进行微调,使其能够完成具体的任务。此阶段的任务可以是文本分类、命名实体识别、问答等。
  • 强化学习微调阶段(RLHF Fine-tuning):在此阶段中,使用强化学习算法对语言模型进行微调,以进一步提高其性能。强化学习算法可以根据所生成的文本序列的整体质量,对语言模型进行反馈和调整。
  • 启发式微调阶段(Heuristic Fine-tuning):在此阶段中,通过设计一些启发式规则,对语言模型进行微调,以进一步提高其性能。启发式规则可以是语言学知识、常识知识等。

这些阶段在语言模型训练中通常是相互关联的,且不一定是线性的顺序,可能会进行多次迭代和交叉训练。RLHF在语言模型训练中扮演了重要的角色,能够帮助语言模型更好地理解和生成自然语言,提高其在各种任务上的表现。

2、环境安装

git clone https://github.com/hpcaitech/ColossalAI

# 创建环境
conda create -n ColossalAI-Chat python=3.10

conda activate ColossalAI-Chat

# 安装依赖
pip install . -i https://mirrors.aliyun.com/pypi/simple/

cd applications/Chat

pip install . -i https://mirrors.aliyun.com/pypi/simple/

git clone https://github.com/hpcaitech/transformers
cd transformers
pip install . -i https://mirrors.aliyun.com/pypi/simple/

pip install pytest -i https://mirrors.aliyun.com/pypi/simple/

数据集:

InstructionWild/data at main · XueFuzhao/InstructionWild · GitHub

3、训练&运行

3.1、模型下载

这一步可以不做,不做的话默认会在将模型下载到 ~/.cache 目录

模型文件较大,需要安装git lfs,否则模型可能损坏

curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
git lfs install

例如我将模型下载到

/data/chenhao/train/ColossalAI/models目录下,那么需要操作

cd /data/chenhao/train/ColossalAI/models
git lfs install
git clone https://huggingface.co/bigscience/bloom-560m

如何按照这种方式,那么下面模型的路径也需要改,否则他还是会下载模型到~/.cache 目录

3.1、SFT(supervised fine-tuning)

torchrun --standalone --nproc_per_node=4 train_sft.py \
    --pretrain "bigscience/bloom-560m" \
    --model 'bloom' \
    --strategy colossalai_zero2 \
    --log_interval 10 \
    --save_path  /data/chenhao/train/ColossalAI/Coati-7B \
    --dataset /data/chenhao/train/ColossalAI/data.json \
    --batch_size 4 \
    --accimulation_steps 8 \
    --lr 2e-5 \
    --max_datasets_size 512 \
    --max_epochs 1 

3.2、训练奖励模型(Training reward model

torchrun --standalone --nproc_per_node=4 train_reward_model.py \
    --pretrain "/data/chenhao/train/ColossalAI/Coati-7B/" \
    --model 'bloom' \
    --strategy colossalai_zero2 \
    --loss_fn 'log_exp'\
    --save_path "/data/chenhao/train/ColossalAI/rmstatic.pt"

这里面 --pretrain 参数,从官方文档上看不明白是第一步的产出模型还是原模型,希望有大佬解答。

资源占用情况

3.3、RL(Training model using prompts with RL)

torchrun --standalone --nproc_per_node=4 train_prompts.py \
         --pretrain "bigscience/bloom-560m" \
         --model 'bloom' \
         --strategy colossalai_zero2 \
         --prompt_path /data/chenhao/train/ColossalAI/prompt_dataset/data.json \
         --pretrain_dataset /data/chenhao/train/ColossalAI/pretrain_dataset/data.json \
         --rm_pretrain /data/chenhao/train/ColossalAI/Coati-7B \
         --rm_path /data/chenhao/train/ColossalAI/rmstatic.pt \
         --train_batch_size 4 \
         --experience_batch_size 4 \
         --max_epochs 1 \
         --num_episodes 1

为了快速走完流程,我这里的数据集和第一步数据集实际上是同一份数据集,这里应该是需要自行准备数据集的。
 

3.4、使用模型进行应答

ColossalAI/inference.py at main · hpcaitech/ColossalAI · GitHub

python chat.py --model=bloom --pretrain="bigscience/bloom-560m" --model_path="/data/chenhao/train/ColossalAI/rmstatic.pt" --input  你好


回复

你好,谢谢。\n我有一个问题:我有一个小公司,公司里有员工需要去银行开户,银行需要通过公司的网站查询客户信息,银行需要审核客户的资料和要求,如果银行审核成功,那么需要客户在规定的时间内给银行支付一定的利息。银行也知道公司的客户需要这些资料,所以银行可能会给客户转账,那么是不是银行可以不审核或者审核时不给转账,而只是给一些现金呢?\n这样我们就要问银行,如果公司要

3.5、playground

import gradio as gr

import torch
from coati.models.bloom import BLOOMActor
from transformers import AutoTokenizer

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2

# 这里换成自己模型的路径
model_path_dict = {
    'SFT': '/data/chenhao/train/ColossalAI/Coati-7B/pytorch_model.bin',
    'RM': '/data/chenhao/train/ColossalAI/rmstatic.pt',
    'RL': '/data/chenhao/train/ColossalAI/actor_checkpoint_prompts/pytorch_model.bin',
}


def predict(model, input, max_length, history):
    updates = []
    actor = BLOOMActor(pretrained='bigscience/bloom-560m').to(torch.cuda.current_device())
    state_dict = torch.load(model_path_dict[model])
    actor.model.load_state_dict(state_dict, strict=False)

    tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
    tokenizer.pad_token = tokenizer.eos_token
    actor.eval()
    question = f'Question: {input} ? Answer:'
    input_ids = tokenizer.encode(question, return_tensors='pt').to(torch.cuda.current_device())
    outputs = actor.generate(input_ids,
                             max_length=max_length,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    for i in history:
        if not i.get('visible'):
            continue
        print(i)
        value = i.get('value')
        updates.append(gr.update(visible=True, value=value))

    updates.append(gr.update(visible=True, value="提问:" + input))
    updates.append(gr.update(visible=True, value=f"{model}:" + output[0].replace(question, '').replace(question.replace(' ', ''), '')))
    if len(updates) < MAX_BOXES:
        updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
    history.extend(updates)
    return [history] + updates


with gr.Blocks() as demo:
    state = gr.State([])
    text_boxes = []

    with gr.Row():
        with gr.Column(scale=1):
            model = gr.Radio(["SFT", "RM", "RL"], label="model",
                             interactive=True, value='SFT')
            max_length = gr.Slider(0, 200, value=100, step=1.0, label="max_length", interactive=True)
            button = gr.Button("Generate")

        with gr.Column(scale=4):
            for i in range(MAX_BOXES):
                if i % 2 == 0:
                    text_boxes += [gr.Markdown(visible=False, label="提问:")]
                else:
                    text_boxes += [gr.Markdown(visible=False, label="回复:")]
            input = gr.Textbox(show_label=True, placeholder="input", lines=5, label='input').style(container=False)

    button.click(predict, [model, input, max_length, state],
                 [state] + text_boxes)
demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0')

3.6、应答效果

这里就完全是调用自己的模型进行应答,我这里因为基座模型bloom-560m就是小模型,加上训练的超参数我都调整到最小跑,因此效果一般。

4、异常记录

4.1 llama爆显存

[BUG]: ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -9) local_rank: 3 (pid: 812917) of binary · Issue #3514 · hpcaitech/ColossalAI · GitHub


 

[BUG]: LlamaRM model has no attribute 'resize_token_embeddings' · Issue #3389 · hpcaitech/ColossalAI · GitHub


 

方案:跑小一点的模型(bloom bloom-560m)

4.2 bloom模型报Error while deserializing header: HeaderTooLarge

方案:使用transformers加载预训练模型
 

torchrun --standalone --nproc_per_node=4 train_sft.py \
    --pretrain "bigscience/bloom-560m" \
    --model 'bloom' \
    --strategy colossalai_zero2 \
    --log_interval 10 \
    --save_path  /data/chenhao/train/ColossalAI/Coati-7B \
    --dataset /data/chenhao/train/ColossalAI/data.json \
    --batch_size 4 \
    --accimulation_steps 8 \
    --lr 2e-5 \
    --max_datasets_size 512 \
    --max_epochs 1

4.3 wandb异常

需要我们一直重复选择

直接禁用

wandb disabled

4.4 RL 训练爆显存

按照最小规格跑
 

4.5 模型加载应答

actor.model.load_state_dict(state_dict)

# 改为

actor.model.load_state_dict(state_dict, strict=False)

欢迎关注我们的微信公众号IT一氪,我们将不定期更新AI、大数据相关的高质量文章。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ColossalAI-Chat训练手册(RLHF) 的相关文章

随机推荐

  • 面试宝典----数据库(总结来自知乎路人甲)

    一 什么是存储过程 有哪些优缺点 存储过程是一些预编译的SQL语句 更加直白的理解 存储过程可以说是一个记录集 它是由一些T SQL语句组成的代码块 这些T SQL语句代码像一个方法一样实现一些功能 对单表或多表的增删改查 然后再给这个代码
  • svn 打patch

    patch patch 即 补丁 的意思 当代码有改动的时候 svn会产生diff 可以查看diff和打patch 使用Mac终端来打patch也是非常方便的 首先查看本地的修改 确认无误后 使用 svn diff gt PATCH 命令可
  • C++ main函数中参数argc和argv含义及用法

    argc 是 argument count的缩写 表示传入main函数的参数个数 argv 是 argument vector的缩写 表示传入main函数的参数序列或指针 并且第一个参数argv 0 一定是程序的名称 并且包含了程序所在的完
  • 2022美赛C题思路分享

    美赛c题 比特币和金子投资分析 问题翻译 下附思路 1 问题分析 本题题目理解较为简单 就是利用历史数据对于投资策略的分析 每一天的决策只能使用之前的历史数据 求解最佳的投资回报 并分析模型的可行性 2模型准备 时间序列分析模型选择 以及模
  • 学习实践-Whisper语音识别模型实战(部署+运行)

    1 Whisper内容简单介绍 OpenAI的语音识别模型Whisper Whisper 是一个自动语音识别 ASR Automatic Speech Recognition 系统 OpenAI 通过从网络上收集了 68 万小时的多语言 9
  • 【matplotlib】饼图+legend()、loc、color位置颜色图例中文显示(一个饼图的例子)

    博客已经搬家到 捕获完成 https www v2python com 1 原来自己做的饼图 http mp blog csdn net postedit 79222127 见文章 matplotlib 中文显示 负号显示 统计微信好友性别
  • 《再也不怕elasticsearch》Spring Boot集成Elasticsearch

    大家好我是迷途 一个在互联网行业 摸爬滚打的学子 热爱学习 热爱代码 热爱技术 热爱互联网的一切 再也不怕elasticsearch系列 帅途会慢慢由浅入深 为大家剖析一遍 各位大佬请放心 虽然这个系列帅途有时候更新的有点慢 但是绝对不会烂
  • django获取某一个字段的列表,values/values_list/flat

    django获取某一个字段的列表 values values list flat 2017年11月01日 11 43 28 阅读数 2241 python view plain copy class Building models Mode
  • C语言实现邻接矩阵(无向图的顺序表示)

    文章目录 有向 无向不带权图 带权图 定义图的结构体 初始化 分析 分配堆空间 对矩阵的行开辟空间 对矩阵 即二维数组 进行初始化 edge 0 0 edge 0 9 edge 1 0 edge 1 9 edge 2 0 edge 2 9
  • 全国物流快递查询网址大全

    http www kiees cn default htm
  • VMware&Linux详细安装步骤

    VMware Linux详细安装步骤 一 VmWare虚拟机的安装 1 安装虚拟机 注意 虚拟机安装完成后会在网络连接中多出两个虚拟网卡 二 在虚拟机上安装CentOS 1 创建新虚拟机 文件 新建虚拟机 或 直接点击 创建新的虚拟机 图标
  • python1_2列表(2)

    列表增删改查 1 增 all in list 0 3 hello True all in list append hello world 新增元素 print all in list 运行结果 2 插入 all in list 0 3 he
  • 机器学习-基础

    欢迎来到机器学习的世界 博客主页 卿云阁 欢迎关注 点赞 收藏 留言 本文由卿云阁原创 本阶段属于练气阶段 希望各位仙友顺利完成突破 首发时间 2021年5月5日 希望可以和大家一起完成进阶之路 作者水平很有限 如果发现错误 请留言轰炸哦
  • Java学习心得4——Java中的包是什么

    Java中的包完全可以理解成一个文件夹 如果你不信 我们可以做一些测试 1 我们先在eclipse中创建一个项目 java Project 命名为test 2 我们可以在文件资源管理器的中找到这个项目 3 我们双击进入test文件夹 再进入
  • 系统架构设计专业技能 · 信息安全技术

    点击进入系列文章目录 现在的一切都是为将来的梦想编织翅膀 让梦想在现实中展翅高飞 Now everything is for the future of dream weaving wings let the dream fly in re
  • 聊聊 cookie 管理那些事

    1 前言 在浏览内核加载网络资源的过程中我们离不开 HTTP 协议 它是在 Web 上进行数据交换的基础 同时也是一种无状态的 client server 协议 这种无状态的属性促使许多端存储技术产生 其中最重要的技术之一就是 cookie
  • 深入理解自增自减运算符,看懂表达式不糊涂

    自增运算符 和自减运算符 在算术表达式中容易造成使用上的错误 主要原因有两点 一是自增运算符和自减运算符在变量前后的位置不一样 其内部逻辑不一样 二是自增运算符和自减运算符只能用于变量 不能用于常量 首先讲解一下自增自减运算符的概念 自增自
  • lcm in qcom

    文章目录 lcm需要生产的相关文件 lcm in lk lcm in kernel 一些注意的事项 其他平台 sdm845 in kernel in uefi lcm需要生产的相关文件 根据fae提供的相关资料去配置自己的 xml文件 如下
  • 电信资源管理系统性能测试总结

    1 电信资源管理系统性能测试总结 陈建慧 2007 7 30 1 1 技术问题与解决方法 1 1 1 Loadrunner JAVA脚本 唯一参数问题 最初采用JNI 调用delphi的DLL JAVA脚本中未使用static synchr
  • ColossalAI-Chat训练手册(RLHF)

    目录 1 什么的RLHF流程 2 环境安装 3 训练 运行 3 1 模型下载 3 1 SFT supervised fine tuning 3 2 训练奖励模型 Training reward model 3 3 RL Training m