大语言模型浅探一

2023-10-26

目录

1 前言

2 GPT模型解码

3 InstructGPT

4 基于RWKV微调模型

4.1 RWKV简介

4.2 增量预训练

4.3 SFT微调

4.4 RM和PPO

5 测试

6 总结


1 前言

近来,人工智能异常火热,ChatGPT的出现极大的推动了自然语言处理的发展,在推出仅两个月后,月活跃用户已达1亿,成为历史上增长最快的消费应用。OpenAI一直在研究生成式模型,在2018年6月发布了GPT,在2020年5月发布了GPT3,GPT3的模型参数也达到了1750 亿。但是为什么到现在ChatGPT才突然间火起来呢?

众所周知,模型只有达到一定的参数量才会出现涌现能力,但是两年强模型参数已经达到了1750亿,说明单纯的堆模型参数并不能达到想要的效果。或许在OpenAI的一篇论文InstructGPT(Training language models to follow instructions with human feedback)中可以找到答案。接下来本文会简单介绍InstructGPT,然后基于RWKV去复现,之所以选择RWKV是因为其速度快,占用GPU显存低,便于快速实验。

2 GPT模型解码

这里我感觉有必要把生成式模型的解码输出单独拿出来讲一下,只有理解了答案生成的原理才能更好理解InstructGPT所做的工作。

GPT属于生成式预训练语言模型,只采用了Transformers的Decoder结构,并对Decoder进行了一些改动,去掉了第二个Multi-Head Attention。推理解码过程就是利用当前token和前面输入所有token的状态矩阵去预测下一个token的过程,直到输出位终止符。例如输入tokens序列为[u1,u2,u3,u4,u5],词典大小为20000,则输出

y_{n},state_{n}=f(state_{n-1}, u_{n})

 上面公式中,state为前面输入所有词的状态矩阵,即保存的是每个词的词向量,y_{n}是一个1*20000的矩阵,此时的模型若是没有经过任何微调,此时的分布是比较分散的,如何从20000个词中挑选合适的词有几种方案:

第一种:贪心搜索(greedy search),每次都取概率最大的。贪心搜索是局部最优,但是不能保证全局最优。

第二种:集束搜索(beam search),每次取top-b个得分最高的句子,得分的计算很关键,但是这里不是我要讲的重点,感兴趣的可以自己去了解一下。这种方法比贪心搜索好一点儿。

第三种:随机采样,每次按照概率随机取一个,概率越高越有可能被采到。基于此又有了temperature sampling,top-k sampling和top-p sampling等改进方案。随机采用是目前使用的比较频繁的方式,这里简单说一下。

temperature sampling:设置一个temperature参数来控制概率分布的弥散程度,temperature越接近0,则数值间差异被指数级放大,被采样的数值范围就越小,体现在回答上面就是多次解码结果不会有太大变化。

top-k sampling:取概率最高的top-k个词作为候选采样词,剩余的置为0。

top-p sampling:将词按照概率从大到小排序,从第1个词开始累加,直到累加和大于等于top-p为止,作为候选采样词,剩余的置为0。

随机采用一般是几种方法组合使用,但是有时不管如何组合调整,模型还是会遇到不能生成终止符的情况,即产生一个概率为零的无限长序列,实际上也会添加最长长度限制或者随着输出长度的增加,提高终止符的概率等等。

基于以上问题,学术界一直在探索新的方法可以让正确的答案(人们想要的答案)生成的概率更高,从而更容易被解码出来,而InstructGPT利用人工反馈强化学习进行了探索。

3 InstructGPT

InstructGPT论文的地址:https://arxiv.org/pdf/2203.02155.pdf

总结一下就是:大语言模型可能会生成不真实、有害和对用户没有帮助的简单答案。InstructGPT通过在多种任务上面指令微调以及人类反馈强化学习(RLHF)来降低这种无效输出的概率。

从图中可以看到,两个过程是分三步来进行的。监督微调SFT(supervised fine-tuning),训练reward model(RM)和强化学习(RL)微调。

SFT:也有人把这个阶段叫做指令微调,还有叫做模型精调,不过这都不重要,这个阶段的目的只有一个,就是让模型能识别指令。训练完成后体现在解码上就是指令可能有很多种回答,这些回答的概率都很高,通过多次随机采用解码,每次解码基本上都不一致。若是没有经过微调的模型,只能通过prompt利用语言模型的能力来让模型理解指令,这时答非所问的情况就更容易出现。

RM:训练一个得分模型,即可以对模型的输出打分,便于后面训练PPO。得分模型需要人工标注数据,同一个prompt需要有一条好的回答,一条差的回答构成一条训练语料(也可以理解为一条比另一条好就行)。

r为打分模型,可以在SFT的模型基础之上添加全连接层改造成一个分类器得到,x为prompt,yw和yl为RM模型的两个输出,一个是好的回答,一个是不好的回答,最小化loss即将二者的得分差扩大。人工标注的数据训练,则会让符合人类回答的答案得分更高。

RL:主要采用的是近端策略优化模型( reinforcement learning via proximal policy optimization)。架构图参考ColossalAI。强化学习微调,会让得分高的答案,概率更高,得分低的答案,概率更低,提高差异,采样时即使不怎么调整参数,也会只采样到正确的答案。

4 基于RWKV微调模型

我们基于RWKV1.5B模型作为基模型,使用CSDN博客100W篇和问答30W条数据进行中文垂直领域增量训练,使用BELLE开源的50K指令数据进行指令微调。这里将训练后的模型命名为ChatCSDN,作为入门大模型的Hello World来参考学习使用。

项目地址:https://gitcode.net/csdn/ai/chatcsdn

模型参数地址:zxm2023/ChatCSDN · Hugging Face

4.1 RWKV简介

RWKV(Receptance Weighted Key Value)的 介绍除了作者的项目外,目前已出,更详细的可以参考论文:https://arxiv.org/pdf/2305.13048.pdf

想了解RWKV必须先看苹果的AFT(An Attention Free Transformer)论文,它和标准的注意力算法结构相同,同样包含QKV结构,K和V首先与一组学习得到的位置偏差(position bias)结合,然后再进行同位元素Q对应相乘(element-wise multiplication)。

 从上图可以看出,QKV的计算转换成了线性计算,速度得到了极大提升。但是wt是一个需要训练的矩阵,大小为ctx_len*ctx_len,即模型的参数量随输入长度的增加而指数级增加,模型能处理的句子长度严重受限。

而RWKV主要是针对AFT进行了改进,从AFT的attention计算公式可以看出,其时间复杂度为O(N^{2}),wt改成了通过公式计算出来,不再是训练得到的,于是时间复杂度就变成了O(N),同时也能扩展到任意长度。主要是引入了Position Matrix、Time-shit、TimeMix和ChannelMix 等结构,相较于原始的GPT结构,RWKV将self-attention替换为Position Encoding和TimeMix,将FFN替换为ChannelMix。

 RWKV项目地址:RWKV-LM/RWKV-v4neo at main · BlinkDL/RWKV-LM · GitHub 

1.5B模型参数地址:BlinkDL/rwkv-4-pile-1b5 · Hugging Face

4.2 增量预训练

RWKV模型基本上都是采用英文进行的训练,这里使用中文的CSDN博客数据和问答数据进行垂直领域增量式训练,同时也会引入部分代码。如果大家也想尝试一下增量式训练,可以参考如下的步骤。

数据预处理。直接参考原作者的数据处理步骤需要按照很多包,这里已经将需要的包都移植到tools里面了。首先使用clean_data.py中的clean_ask_data和clean_blog_data方法可以将从数据仓库中拉取的数据转换成jsonl文件。之后进入tools文件夹下使用如下命令将数据转换成idx和bin文件:

python preprocess_data.py \
            --input ../data/data.txt \
            --output-prefix ../data/blog \
            --vocab ../20B_tokenizer.json \
            --dataset-impl mmap \
            --tokenizer-type HFTokenizer \
            --append-eod

增量预训练。使用前面n-1个词预测第n个词。

x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)

从上面断面可以看出x和y只是错开了一个token,例如:原始句子的tokens为[u1,u2,u3,u4,u5,u6],x为[u1,u2,u3,u4,u5],y为[u2,u3,u4,u5,u6]。

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

 使用如上的代码计算loss。

python train.py  --load_model "RWKV-4-Pile-1B5-EngChn-test4-20230115.pth" --wandb "" --proj_dir "out" \
--data_file "data/blog_text_document" --data_type "binidx" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 10 \
--micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1

使用上面的命令开始训练。 load_model代表要加载的模型,data_type是数据类型,data_file是训练数据,precision是数据的精度,epoch_save是多久保存一次,ctx_len是窗口长度,每条数据多少个token,micro_bsz是批大小,gpu显存不够的可以设置的小一些。

4.3 SFT微调

构造训练用的指令数据,这里添加了相应的角色指令。

data_list = []

for index, row in pf.iterrows():
    input = row["input"]
    target = row["target"]
    input_instruction = f"Below is an instruction that describes a task. Write a response that appropriately " \
                        f"completes the request.\n\n### Instruction: {input}\n\n### Response:"
    input_tokens = self.tokenizer.tokenizer.encode(input_instruction)
    target_tokens = self.tokenizer.tokenizer.encode(target)
    if len(input_tokens) + len(target_tokens) > self.args.ctx_len:
        continue
    data_list.append((input_tokens, target_tokens))

“### Instruction”和“### Response”问答对话的两个角色。这里写什么后续推理时最好也要用同样的方式构造prompt。

def __getitem__(self, index):

    ctx_len = self.args.ctx_len
    req_len = ctx_len + 1

    i = np.random.randint(0, len(self.data))
    question, answer = self.data[i]
    text = question + answer
    text = text[:req_len]

    text = text + [0] * (req_len - len(text))
    x = torch.tensor(text[:-1], dtype=torch.long)
    y = torch.tensor(text[1:], dtype=torch.long)

    z = [0] * len(question) + [1] * (ctx_len - len(question))
    z = torch.tensor(z, dtype=torch.long)

构造x,y和z,z代表mask,后面计算loss使用,会将前面的instruction给屏蔽掉。

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
# loss_raw = loss
loss = torch.sum(loss * mask) / sum_mask

4.4 RM和PPO

RM和PPO代码部分已经初步完善,待测试训练。

5 测试

将训练好的模型使用chat.py运行,解码方式采用的是随机采用解码。部分结果如下:

### Instruction: 你好
### Response:Hello! Nice to meet you!

### Instruction: 用java实现冒泡排序
### Response:以下是使用Java实现冒泡排序的代码:

```java
public void bubbleSort(int[] arr) {
    for (int i = 0; i < arr.length - 1; i++) {
        for (int j = 0; j < arr.length - i - 1; j++) {
            if (arr[j] > arr[j+1]) {
                int temp = arr[j];
                arr[j] = arr[j+1];
                arr[j+1] = temp;
            }
        }
    }
}
```

该方法的工作方式是,首先将数组排序,然后比较相邻的两个数字是否相等,如果相等,则交换它们的位置,并重新计算下一个数字的位置

### Instruction: +++
### Response:。

在实际应用中,您可以将此方法替换为任何其他适用的排序算法,例如快速排序、归并排序或基数排序等。

### Instruction: 帮我写一段JavaScript数组去重代码
### Response:```JavaScript
function uniqueArray(arr) {
    // 去重
    let unique = [];
    for (let i = 0; i < arr.length; i++) {
        if (arr[i] === arr[i]) {
            unique.push(arr[i]);
        }
    }
    return unique;
}
```

此方法将一个数组转化为一个新的数组,在此过程中,会判断每个元素是否已经在原数组中,如果是则不会重复元素,否则则会添加到新数组中去。这样就可以将原始数组转化为新数组,避免重复元素。

### Instruction: How about JOYI Sports Company? 翻译成中文
### Response:科基力学(JOYI)是一个怎样的公司?

6 总结

1、rwkv速度已经很快了,单卡可达到450token/s,但是在亿级甚至更大的tokens面前,也显得有点儿力不从心(不,应该叫望洋兴叹更形象)。

2、没有经过强化学习微调的模型,多次随机采用的结果都会不一致,要想得到理想的回答,可能需要多试几次。

3、1.5B受限于语言模型本身的限制,语言能力较弱,用来做大模型的入门学习还是挺不错的。后续我们会将代码和模型参数开源出来供大家一起学习研究。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

大语言模型浅探一 的相关文章

随机推荐

  • C++从0到1(9):指针

    目录 1 指针的基本概念 2 指针变量的定义和使用 3 指针所占内存空间 4 空指针和野指针 5 const修饰指针 6 指针和数组 7 指针和函数 8 指针 数组和函数 1 指针的基本概念 作用 通过指针间接访问内存 内存编号从0开始记录
  • ubuntu22.04安装podman及cockpit并在WEB中管理容器

    目录 前言 一 准备工具 二 安装步骤 1 更新系统到最新版本 2 使用以下命令安装podman 3 使用以下命令安装cockpit及相关插件 三 启动服务 四 登录管理界面 五 使用podman容器管理 1 创建容器 2 管理容器 六 总
  • sqli-labs————Less-33

    Less 33 查看源代码
  • QProcess处理带管道的shell

    代码中需要调用shell 原写法为 QProcess proc new QProcess QString qCmd find name so print0 xargs 0 objdump x grep oE T 0 9 a f A F 4
  • 护网

    在HVV期间 蓝队主要就是通过安全设备看告警信息 后续进行分析研判得出结论及处置建议 在此期间要注意以下内容 内网攻击告警需格外谨慎 可能是进行内网渗透 1 攻击IP是内网IP 攻击行为不定 主要包括 扫描探测行为 爆破行为 命令执行等漏扫
  • 笑脸工具COORD批量转换2000大地到空间坐标

    数据格式txt 1 31 48 14 118687N 119 38 07 130943E 2 32 3 19 06731100008N 119 31 20 422269001200302E 3 31 50 31 89348499992000
  • 变频调速系统c语言编程,基于8098单片机的SPWM变频调速系统

    数字控制的交流调速系统所选用的微处理器 功率器件及产生PWM波的方法是影响交流调速系统性能好坏的直接因素 在介绍了正弦脉宽调制 SPWM 技术的基础上 设计了一种以8098单片机作为控制器 以智能功率模块IPM为开关器件的变频调速系统 通过
  • 小样本学习(Few-shot Learning)综述

    作者丨耿瑞莹 李永彬 黎槟华 单位丨阿里巴巴智能服务事业部小蜜北京团队 分类非常常见 但如果每个类只有几个标注样本 怎么办呢 笔者所在的阿里巴巴小蜜北京团队就面临这个挑战 我们打造了一个智能对话开发平台 Dialog Studio 以赋能第
  • [Flutter]封装了个Toast组件

    Flutter官方插件市场上已经有了很多成熟的Toast组件 如 fluttertoast 等等 使用了一年多的Flutter框架 一时兴起 自己封装了一个简单的Toast组件 注 本人觉得 自动关闭的时候 不宜使用 Navigator p
  • 西门子PLC S7-1200的硬件中断组织块简介

    西门子PLC S7 1200系列是一款中小型西门子PLC 可以在各种自动化项目中进行应用 S7 1200系列设计较为紧凑 经济性较好 而且指令功能较为强大 因此在各种自动化控制解决方案中有较广泛的应用 作为西门子PLC S7 200系列的升
  • [1218]hive之Map Join使用方法

    文章目录 介绍 mapjoin的使用方法 介绍 MAPJION会把小表全部加载到内存中 在map阶段直接拿另外一个表的数据和内存中表数据做匹配 由于在map端是进行了join操作 省去了reduce运行的时间 算是hive中的一种优化 如上
  • 开放原子训练营(第三季)inBuilder低代码开发实验室之探秘

    一 活动介绍 以开放原子训练营为主办方的inBuilder低代码实验室活动现已开启 参与者无论身居计算机业界 偏好低代码开发抑或是普通用户 均可在社区版inBuilder低代码开发平台 一款基于UBML开源项目的广泛适用的发行版 中尝试向导
  • ECMAScript2020 可选链操作符(?.)的应用

    一 前言 const programmer user lin department name 技术部 getSite return 在以前的语法中 想要获得深层次的属性或方法 如果不做前置校验的话 那么就很容易出现这种错误 这可能会导致你整
  • MFC 之 重绘按键Cbutton

    上次我们学习了如何美化对话框的界面 这次我们为上次的对话框添加两个按钮 一个是关闭按钮 另一个是最小化按钮 好 现在我们先看一下效果 是不是很难看 因为我们的对话框美化了 所以我们的按钮也要美化 因为采用贴图的方式来美化 所以 我先给出这两
  • 笔试面试算法经典--矩阵的最短路径和(Java)

    题目 给定一个矩阵m 从左上角开始每次只能向右或者向下走 最后到达右下角的位置 路径上所有的数字累加起来就是路径和 返回所有路径中最小的路径和 例子 给定m如下 1 3 5 9 8 1 3 4 5 0 6 1 8 8 4 0 路径1 3 1
  • 信号去噪 - 基于SVD实现数字信号降噪含Matlab源码

    信号去噪 基于SVD实现数字信号降噪含Matlab源码 介绍 信号处理中的一个重要问题是如何降噪 这在各种应用领域中都有非常重要的作用 奇异值分解 SVD 是一种广泛使用的信号处理技术 可以用于有效地降低信号噪声 本文将介绍如何使用SVD进
  • Elasticsearch 安装及启动【Windows】

    一 下载 Elasticsearch 官网下载地址 https www elastic co cn downloads past releases elasticsearch 选择自己所需版本进行下载 这里以Elasticsearch 8
  • 【操作系统】王道考研 p64-66 IO软件层次结构、IO核心子系统、假脱机技术(SPOOLing技术)

    IO软件层次结构 IO核心子系统 假脱机技术 SPOOLing技术 以下是IO软件层次结构的内容 知识总览 用户层软件 实现了与用户交互的接口 将用户的请求翻译为格式化的IO请求 并通过 系统调用 请求操作系统内核的服务 设备独立性软件 又
  • PyQt5 QTableWidget内容复制功能

    为了更快速的将QTableWidget的内容复制到剪贴板 只需重写这个控件的keyPressEvent event 废话不多说 直接上代码 复制功能 def keyPressEvent self event Ctrl C复制表格内容 if
  • 大语言模型浅探一

    目录 1 前言 2 GPT模型解码 3 InstructGPT 4 基于RWKV微调模型 4 1 RWKV简介 4 2 增量预训练 4 3 SFT微调 4 4 RM和PPO 5 测试 6 总结 1 前言 近来 人工智能异常火热 ChatGP