Proximal Policy Optimization(PPO)和文本生成

2023-11-20

ChatGPT的RLHF步使用了强化学习PPO算法。
PPO是一种策略梯度方法,其交替地进行与环境交互采样数据和使用随机梯度上升优化“代理”目标函数。标准策略梯度方法对每个数据样本执行一次梯度更新,而PPO可以采样一批数据后,对模型进行多次梯度更新。

策略梯度

策略梯度(Policy Gradient)方法梯度的计算如下:
E ( a t , s t ) ∈ π θ [ A ^ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \nabla_ \theta \log \pi_\theta(a_t | s_t)] E(at,st)πθ[A^tθlogπθ(atst)] A ^ t \hat A_t A^t是优势函数(advantage function) A t A_t At的估计。
A t = Q ( s t , a t ) − V ( s t ) A_t=Q(s_t, a_t)-V(s_t) At=Q(st,at)V(st)优势函数计算的是,在该状态下采取这个行动的奖励与在该状态下的平均奖励的差值。
上面的导数可以通过对下面的目标求导获得:
L P G ( θ ) = E ( a t , s t ) ∈ π θ [ A ^ t log ⁡ π θ ( a t ∣ s t ) ] L^{PG}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \log \pi_\theta(a_t | s_t)] LPG(θ)=E(at,st)πθ[A^tlogπθ(atst)]

PPO(Proximal Policy Optimization)

PPO有两个形式,其中一种形式PPO_CLIP的优化目标函数是:
L C L I P ( θ ) = E ( a t , s t ) ∈ π θ o l d [ min ⁡ ( r t ( θ ) A ^ t , c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] (1) L^{CLIP}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_{\theta_{old}}}[\min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t)] \tag{1} LCLIP(θ)=E(at,st)πθold[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)](1)其中 r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)=\frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{old}}(a_t | s_t)} rt(θ)=πθold(atst)πθ(atst)
PPO算法中的advantage用下面的公式估计:
A ^ t = δ t + ( γ λ ) δ t + 1 + ⋯ + ( γ λ ) T − t + 1 δ T − 1 \hat A_t = \delta_t + (\gamma \lambda)\delta_{t+1} + \cdots+ (\gamma \lambda)^{T-t+1}\delta_{T-1} A^t=δt+(γλ)δt+1++(γλ)Tt+1δT1其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st) r t r_t rt是reward, V ( s t ) V(s_t) V(st)是状态价值。
通常情况下,我们用一个网络同时学习策略和价值函数,这样策略和价值函数能共享参数,那么就需要结合策略代理和价值函数误差项的损失函数。再加上熵奖励(entropy bonus)来确保足够的探索,优化目标变为:
L C L I P + V F + S ( θ ) = E ( a t , s t ) ∈ π θ [ L t C L I P ( θ ) − c 1 L t V F ( θ ) + c 2 S [ π θ ] ( s t ) ] L^{CLIP+VF+S}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t)] LCLIP+VF+S(θ)=E(at,st)πθ[LtCLIP(θ)c1LtVF(θ)+c2S[πθ](st)]其中 L t V F ( θ ) = ( V θ ( s t ) − V t t a r g ) 2 L_t^{VF}(\theta)=(V_\theta(s_t)-V_t^{targ})^2 LtVF(θ)=(Vθ(st)Vttarg)2是价值函数的误差项, S [ π θ ] S[\pi_\theta] S[πθ]是entropy bonus。

完整的PPO算法如下。可以看到每个循环中,先采样N个T个时间步的数据,然后用采样的数据进行K个epoch的优化。
在这里插入图片描述

文本生成

在文本生成的情况下,给一个prompt,生成完整的response,是一个episode。动作空间是vocabulary。每生成一个词是一个时间步。

公式(1)需要advantage的估计,为了计算advantage,我们需要定义奖励(reward) r r r和估计状态价值函数 V ( s ) V(s) V(s)

用于强化学习的reward计算如下:
R ( x , y ) = r ( x , y ) − β log ⁡ π ( y ∣ x ) ρ ( y ∣ x ) R(x,y) = r(x,y) - \beta\log\frac{\pi(y|x)}{\rho(y|x)} R(x,y)=r(x,y)βlogρ(yx)π(yx)x是prompt,y是response, r ( x , y ) r(x,y) r(x,y)是reward model的输出,也就是下面代码中的score。注意这里reward model的输出称之为score,送入强化学习部分的才称为reward。 π ( y ∣ x ) \pi(y|x) π(yx)是要学习的生成模型, ρ ( y ∣ x ) \rho(y|x) ρ(yx)是参数固定的原始生成模型。

在trl库中reward的计算如下。只将reward model的score添加到最后一个token的reward上,其他token的reward来自当前模型和 原始生成模型之间KL散度。这么做是为了减轻奖励模型的过度优化问题。

   def compute_rewards(
       self,
       scores: torch.FloatTensor,
       logprobs: torch.FloatTensor,
       ref_logprobs: torch.FloatTensor,
       masks: torch.LongTensor,
   ):
       """
       Compute per token rewards from scores and KL-penalty.

       Args:
           scores (`torch.FloatTensor`):
               Scores from the reward model, shape (`batch_size`)
           logprobs (`torch.FloatTensor`):
               Log probabilities of the model, shape (`batch_size`, `response_length`)
           ref_logprobs (`torch.FloatTensor`):
               Log probabilities of the reference model, shape (`batch_size`, `response_length`)
       """
       rewards, non_score_rewards = [], []
       for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
           # compute KL penalty (from difference in logprobs)
           kl = self._kl_penalty(logprob, ref_logprob)
           non_score_reward = -self.kl_ctl.value * kl
           non_score_rewards.append(non_score_reward)
           reward = non_score_reward.clone()
           last_non_masked_index = mask.nonzero()[-1]

           # reward is preference model score + KL penalty
           reward[last_non_masked_index] += score
           rewards.append(reward)
       return torch.stack(rewards), torch.stack(non_score_rewards)

在trl库中用一个网络AutoModelForCausalLMWithValueHead学习策略 π θ ( s ) \pi_\theta(s) πθ(s)和状态价值函数 V ( s ) V(s) V(s)。AutoModelForCausalLMWithValueHead在普通AutoModelForCausalLM模型上了一个线性层nn.Linear(hidden_size, 1),用于估计状态价值函数 V ( s ) V(s) V(s)
普通AutoModelForCausalLM模型估计token概率即可作为策略 π θ ( s ) \pi_\theta(s) πθ(s)

在trl库中advantage的计算如下。compute_advantage函数返回的returns作为 V t t a r g V_t^{targ} Vttarg用于学习状态价值函数 V ( s ) V(s) V(s)

    def compute_advantages(
        self: torch.FloatTensor,
        values: torch.FloatTensor, # AutoModelForCausalLMWithValueHead输出的状态价值估计V
        rewards: torch.FloatTensor, # compute_rewards函数计算得到的rewards
        mask: torch.FloatTensor,
    ):
        lastgaelam = 0
        advantages_reversed = []
        gen_len = rewards.shape[-1]

        values = values * mask
        rewards = rewards * mask

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        returns = advantages + values
        advantages = masked_whiten(advantages, mask)
        advantages = advantages.detach()
        return values, advantages, returns

完整的loss计算可以参看ppo_trainer.py中的loss函数,因为函数较长,这里不粘贴了。
因为在reward的计算中可以考虑了entropy bonus,所以在更新网络的时候,没有再用 S [ π θ ] S[\pi_\theta] S[πθ]的loss。

    def loss(
        self,
        old_logprobs: torch.FloatTensor,
        values: torch.FloatTensor,
        logits: torch.FloatTensor,
        vpreds: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        mask: torch.LongTensor,
        advantages: torch.FloatTensor,
        returns: torch.FloatTensor,
    )

Reference

《Proximal Policy Optimization Algorithms》
《Fine-Tuning Language Models from Human Preferences》
《Training language models to follow instructions with human feedback》
https://github.com/huggingface/trl

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Proximal Policy Optimization(PPO)和文本生成 的相关文章

  • 小白入门C#初探Web简易页面显示信息小案例

    1 创建新项目 选择ASP NET Core Web应用 模型 视图 控制器 然后点击下一步 然后在项目名称里面填写CSharpDemo 点击下一步 直至创建即可 目录结构 Connected Services 是Visual Studio
  • CentOS 8安装并配置NFS服务

    先决条件 我们假设您有一台运行CentOS 8的服务器 我们将在该服务器上设置NFS服务器和其他充当NFS客户端的计算机 服务器和客户端应该能够通过专用网络相互通信 如果您的托管服务提供商不提供私有IP地址 则可以使用公共IP地址并配置服务
  • Python学习第十二天——logging

    1 日志级别 CRITICAL 50 FATAL CRITICAL ERROR 40 WARNING 30 WARN WARNING INFO 20 DEBUG 10 NOTSET 0 不设置 日志的设置是自下而上的 如果等级为ERROR

随机推荐

  • vs2008常用操作汇总

    1 OpenCV2 1环境配置 1 Tools gt Options gt Projects and Solutions gt VC Drectories Show directories for选择include files 加入目录 D
  • Android-给RecyclerView添加分隔线

    RecyclerView和ListView不同 是不自带分隔线的 如此 在讲为Item加入分割线本质的前 先来介绍 认识一下ChildView 也就是平时我们用到的ListView RecyclerView中的getChildAt int
  • 【解决】Win 10+Visual Studio community 2017,许可证到期,不能登录问题

    Win 10 Visual Studio community 2017 许可证到期 不能登录问题 试了很多种方式 会出现很多问题 最终尝试成功 1 在打开vs之后 第一时间点击帮助 发送反馈 报告问题 2 在弹出的对话框中点击发现新的许可证
  • UCI提供给shell和lua使用的配置接口

    转自 http m blog csdn net article details id 47989493 1 uci提供给shell使用的配置借口有两套 1 config get用来读取一个config值 命令格式如下 config getv
  • 【Python爬虫】Python 爬虫的学习和案例,一篇文章带你了解爬虫的密码

    爬虫基础 我们可以把互联网比作一张大网 而爬虫 即网络爬虫 便是在网上爬行的蜘蛛 把网的节点比作一个个网页 爬虫爬到这就相当于访问了该页面 获取了其信息 可以把节点的连线比作网页与网页之间的链接关系 这样蜘蛛通过一个节点后 可以顺着节点连线
  • Linux系统下载并安装Redis

    Linux上下载并安装Redis 下面是下载安装过程 如果只是想快速安装 那就直接看图中命令 全部下载安装命令都在图中 1 在home目录下下载Redis安装包 下载Redis安装包命令 wget http download redis i
  • html5取消了哪些标签,html5删除的标签有哪些

    html5删除的标签 basefont big center font s strike tt u frame noframes frameset bgsound blink marquee applet isindex listing等
  • 使用函数打印无符号整形的二进制表达式

    目录 目录 目录 1 问题描述 输入两个非负整数a b 并输出这两个整数的二进制形式以及这两个数的反码执行逻辑或和逻辑与操作后的二进制形式 2 三个函数作用的详细解释 2 1第一个函数 2 2第二个函数 2 3第三个函数 3 结语 请多多指
  • 一招解决Tomcat闪退

    tomcat的运行需要JRE 一般启动闪退都是因为找不到JRE 也就是说环境安装JDK时环境变量没有配置好 首先检查JDK配置是否正确 确认JDK配置好了以后开始检查错误 在Tomcat的安装目录下的bin文件夹里面找到startup ba
  • 在Windows11系统中安装Anaconda

    1 在电脑自带的应用商店下载或者去Anaconda官网下载 Anaconda官网 2 打开Anaconda官网 如下图 3 点击Download 选择自己电脑对应的版本 这里选择Windows 4 将下载的Anaconda 放在电脑的某个地
  • Python手册(Scientific Computing)--numpy

    文章目录 NumPy的ndarray 创建ndarray ndarrary索引和切片 ndarrary属性 ndarrary方法 numpy函数 NumPy的random随机库 生成n维随机数组 Numba NumPy Numerical
  • 大话数据结构读书笔记 1---线性表

    大话数据结构读书笔记 编程基础 数据结构 算法 1 线性表 顺序储存结构的结构代码 define MAXSIZE 20 储存空间的起始分配量 typedef int ElemType ElemType类型根据实际类型而定 这里假设是int
  • framebuffer驱动详解

    1 framebuffer介绍 1 什么是framebuffer 1 裸机中如何操作LCD 2 OS下操作LCD的难点 显存放在内核中还是应用中是个问题 之前讲的应用和内核之间传递数据用的是copy from usr copy to usr
  • dell灵越笔记本后盖怎么拆_戴尔灵越5584笔记本按键拆卸、安装教程

    最近一直用笔记本 用着用着我发现U键变得迟钝 不灵敏 虽然这是小问题 但对于我的打字造成较大影响 去维修站修又有点浪费 所以就萌生了自己修的念头 发现网上笔记本键帽拆卸的教程不好用 便决定写篇教程 方便他人 第一步 关机 在拆卸笔记本任何部
  • 2023蓝桥杯python 组试题A:2023

    题目 请求出在 12345678 至 98765432 中 有多少个数中完全不包含 2023 完全不包含 2023 是指无论将这个数的哪些数位移除都不能得到 2023 例如 20322175 33220022 都完全不包含 2023 而 2
  • 记录一次生产环境MySQL死锁以及解决思路

    一 背景 1 业务背景 这里因为涉及到公司的业务问题不进行深入讨论 下面换成通用的一些业务场景就是举例 2 技术背景 众所周知 所谓锁的产生本质上是想解决资源竞争问题 在MySQL的前提下 MySQL为了解决事务并发独写的问题 在进行ins
  • Ubuntu查看cuda版本号 cudnn版本号

    cuda版本号 nvcc V nvcc version 若遇到 nvcc command not found 添加环境变量 打开 bashrc 添加环境变量如下 export LD LIBRARY PATH usr local cuda l
  • 常用系统命令

    重定向 cat aa txt gt bbb txt 将输出定向到bbb txt cat aaa txt gt gt bbb txt 输出并追加 查看进程 ps ps ef 显示所有进程 例 ps ef grep mysql 管道符 kill
  • webpack从此不再是我们的痛点 — 核心基础

    webpack一直是前端工程师的痛点 因为他的复杂 分散 loader plugin这些第三方 让我们的学习成本陡然上升 使我们一直对他的配置模棱两可 今天带大家彻底明白他如何配置 摆脱困扰我们很久的痛点 本篇主要是webpack基础配置详
  • Proximal Policy Optimization(PPO)和文本生成

    ChatGPT的RLHF步使用了强化学习PPO算法 PPO是一种策略梯度方法 其交替地进行与环境交互采样数据和使用随机梯度上升优化 代理 目标函数 标准策略梯度方法对每个数据样本执行一次梯度更新 而PPO可以采样一批数据后 对模型进行多次梯