Arxiv 2307

2023-11-08

Retentive Network: A Successor to Transformer for Large Language Models

image.png

image.png

本文从序列建模的角度,构建了一种类似Transformer且更加高效的结构。在语言任务上展现出了良好的效率和性能。

  • 利用类似于Transformer的并行组件实现了对于GPU并行能力的利用。
  • 利用循环机制确保了 O ( 1 ) O(1) O(1)级别的存储和计算复杂度。
  • 利用分块循环策略从而执行有效的长序列建模。

实际中,并行编码每个局部的块来加速计算,同时循环编码全局块来节省显存。

序列建模

对于输入的长度为 N N N的文本嵌入序列,由于其本身信息的前后依赖关系和因果关系的需求,所以本文是从循环模型的角度开始构建模型的。

基础的迭代形式:

对于第n次迭代的输入 X n X_n Xn,有

Q n = X n ⋅ W Q , K n = X n ⋅ W K , V n = X n ⋅ W V ∈ R 1 × d Q_n = X_n \cdot W_Q, K_n = X_n \cdot W_K, V_n = X_n \cdot W_V \in \mathbb{R}^{1 \times d} Qn=XnWQ,Kn=XnWK,Vn=XnWVR1×d

将序列建模认为成通过状态 S n S_n Sn,将 V ( n ) V(n) V(n)映射为 O ( n ) O(n) O(n)**的过程。**于是可以得到下式:

S n = A s n − 1 + K n ⊤ V n = A n − 1 K 1 ⊤ V 1 + A n − 2 K 2 ⊤ V 2 + ⋯ + K n ⊤ V n = ∑ m = 1 n A n − m K m ⊤ V m S_n = As_{n-1} + K^{\top}_n V_n = A^{n-1} K^{\top}_1 V_1 + A^{n-2} K^{\top}_2 V_2 + \dots + K^{\top}_n V_n = \sum^{n}_{m=1} A^{n-m} K^{\top}_m V_m Sn=Asn1+KnVn=An1K1V1+An2K2V2++KnVn=m=1nAnmKmVm

这里的 A ∈ R d × d A \in \mathbb{R}^{d \times d} ARd×d描述了各个位置之间的相对关系。

O n = Q n S n = ∑ m = 1 n Q n A n − m K m ⊤ V m , Q n ∈ R 1 × d O_n = Q_n S_n = \sum^{n}_{m=1}Q_n A^{n-m} K^{\top}_m V_m, Q_n \in \mathbb{R}^{1 \times d} On=QnSn=m=1nQnAnmKmVm,QnR1×d

Parallel Retention

通过设置一个特殊的矩阵 A A A,将其对角化处理获得 A = Λ ( γ e i θ ) Λ − 1 A = \Lambda (\gamma e^{i \theta}) \Lambda^{-1} A=Λ(γeiθ)Λ1,这里的两个矩阵 Λ \Lambda Λ由于在公式中紧邻 Q n , K n Q_n, K_n Qn,Kn,所以可以将其合并到二者各自的权重矩阵 W Q , W K W_Q, W_K WQ,WK中一同随着网络去学习,从而上式可以改写:

O n = Q n S n = ∑ m = 1 n Q n ( γ e i θ ) n − m K m ⊤ V m = ∑ m = 1 n [ Q n ( γ e i θ ) n ] [ K m ( γ e i θ ) − m ] ⊤ V m = ∑ m = 1 n γ n − m ( Q n e i n θ ) ( K m e i m θ ) † V m O_n = Q_n S_n = \sum^{n}_{m=1} Q_n (\gamma e^{i \theta})^{n-m} K^{\top}_m V_m = \sum^{n}_{m=1} [Q_n (\gamma e^{i \theta})^{n}] [K_m (\gamma e^{i \theta})^{-m}]^{\top} V_m = \sum^{n}_{m=1} \gamma^{n-m} (Q_n e^{i n \theta}) (K_m e^{i m \theta})^{\dagger} V_m On=QnSn=m=1nQn(γeiθ)nmKmVm=m=1n[Qn(γeiθ)n][Km(γeiθ)m]Vm=m=1nγnm(Qneinθ)(Kmeimθ)Vm

这里将指数与转置融合和获得共轭转置。这里的复数系数实际上可以看做是一种位置嵌入,由于这里的计算反映出了与位置n和m的关联,所以可以认为是一种相对位置关系的表示。

由于这里 Q , K Q, K Q,K索引上的独立性,所以很容易改为并行的基于矩阵运算的结构。将复数矩阵系数极其共轭形式分别合并到 Q , K Q, K Q,K计算过程中,从而可以得到:

Q = ( X W Q ) ⊙ Θ , K = ( X W K ) ⊙ Θ ˉ , V = X W V , Θ n = e i n θ D n m = γ n − m  if  n ≥ m  else  0 Q=(XW_Q) \odot \Theta, K=(XW_K) \odot \bar{\Theta}, V=XW_V, \Theta_{n} = e^{i n \theta} \\ D_{nm}=\gamma^{n-m} \text{ if } n \ge m \text{ else } 0 Q=(XWQ)Θ,K=(XWK)Θˉ,V=XWV,Θn=einθDnm=γnm if nm else 0

从而得到整体模块的计算过程:

R e t e n t i o n ( X ) = ( Q K ⊤ ⊙ D ) V , D ∈ R N × N Retention(X) = (QK^{\top} \odot D) V, D \in \mathbb{R}^{N \times N} Retention(X)=(QKD)V,DRN×N

def ParallelRetention(
    q, # bsz ∗ num_head ∗ len ∗ qk_dim
    k, # bsz ∗ num_head ∗ len ∗ qk_dim
    v, # bsz ∗ num_head ∗ len ∗ v_dim
    decay_mask # num_head ∗ len ∗ len
 ):
     retention = q @ k.transpose(1,2)
     retention = retention ∗ decay_mask
     output = retention @ v
     output = group_norm(output)
     return output

这一形式实际上与Transformer的带mask的计算形式非常类似。

这里由于有 Q K ⊤ QK^\top QK,使用了三种归一化方式来提升数值精度,这些归一化策略实际上都是在GN输入上乘以了一个常数,而由于GN本身的尺度不变性,所以必不会影响GN的输出和反向的梯度。

  • 使用特征维度归一化 Q K ⊤ / d Q K^\top / \sqrt{d} QK/d
  • 设置 D = { D n m ∑ i = 1 n D n i } D = \{\frac{D_{nm}}{\sqrt{\sum^n_{i=1}D_{ni}}}\} D={i=1nDni Dnm}
  • 假定 R = Q K ⊤ ⊙ D R = Q K^{\top} \odot D R=QKD,设置 R = { R n m max ⁡ ( ∣ ∑ i = 1 n R n i ∣ , 1 ) } R = \{ \frac{R_{nm}}{\max(|\sum^{n}_{i=1} R_{ni}|, 1)} \} R={max(i=1nRni,1)Rnm}

Recurrent Retention

但是,如果从序列形式的角度来看,前面的最一开始的建模过程也可以改写成另外一种类似于RNN的形式。
先将状态参数写成迭代形式:

S n = γ S n − 1 + K n ⊤ V n ∈ R d × d S_n = \gamma S_{n-1} + K^{\top}_n V_n \in \mathbb{R}^{d \times d} Sn=γSn1+KnVnRd×d

最终可以得到整体迭代计算过程:

R e t e n t i o n ( X n ) = Q n S n , n ∈ { 1 , … , N } Retention(X_n) = Q_n S_n, n \in \{1,\dots,N\} Retention(Xn)=QnSn,n{1,,N}

def RecurrentRetention(
    q, k, v, # bsz ∗ num_head ∗ len ∗ qkv_dim
    past_kv, # bsz ∗ num_head ∗ qk_dim ∗ v_dim
    decay # num_head ∗ 1 ∗ 1
):
    current_kv = decay ∗ past_kv + k.unsqueeze(1) ∗ v.unsqueeze(2)
    output = torch.sum(q.unsqueeze(1) ∗ current_kv, dim=2)
    output = group_norm(output)
    return output, current_kv

实际上这里的形式与线性Attention先计算KV的思路颇有相通之处。

Chunkwise Recurrent Retention

作者也提出了一种将上述两种形式进行混合的形式,通过将序列划分为连续的块,块内部执行并行形式的处理,块之间执行循环处理,实际的,对于第 i i i个块,处理形式如下:

R e t e n t i o n ( X [ i ] ) = ( Q [ i ] K [ i ] ⊤ ⊙ D ) V [ i ] ⏟ 块内并行 + ( Q [ i ] S i ) ⊙ ξ ⏟ 块间循环 , ξ i j = γ i + 1 Retention(X_{[i]})=\underbrace{(Q_{[i]} K^{\top}_{[i]} \odot D)V_{[i]}}_{块内并行} + \underbrace{(Q_{[i]} S_i) \odot \xi}_{块间循环}, \xi_{ij} = \gamma^{i+1} Retention(X[i])=块内并行 (Q[i]K[i]D)V[i]+块间循环 (Q[i]Si)ξ,ξij=γi+1

def ChunkwiseRetention(
    q, k, v, # bsz ∗ num_head ∗ chunk_size ∗ qkv_dim
    past_kv, # bsz ∗ num_head ∗ qk_dim ∗ v_dim
    decay_mask, # num_head ∗ chunk_size ∗ chunk_size
    chunk_decay, # num_head ∗ 1 ∗ 1
    inner_decay, # num_head ∗ chunk_size
):
    retention = q @ k.transpose(1,2)
    retention = retention ∗ decay_mask
    inner_retention = retention @ v
    cross_retention = (q @ past_kv) ∗ inner_decay
    retention = inner_retention + cross_retention
    output = group_norm(retention)
    current_kv = chunk_decay ∗ past_kv + k.transpose(1,2) @ v
    return output, current_kv

Gated Multi-Scale Retention

引入“头”的概念,为不同的头使用不同的变换权重,同时为不同的头分配不同的 γ \gamma γ

image.png

image.png

由于不同的头引入了不同的参数 γ \gamma γ,所以实际的方差统计量会有所差异。所以这里使用GroupNorm独立归一化不同的头。

实际效果

仅在文本任务上进行了实验。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Arxiv 2307 的相关文章

  • 蓝桥杯备赛Day8——队列

    大家好 我是牛哥带你学代码 本专栏详细介绍了蓝桥杯备赛的指南 特别适合迎战python组的小白选手 专栏以天作为单位 定期更新 将会一直更新 直到所有数据结构相关知识及高阶用法全部囊括 欢迎大家订阅本专栏 队列也属于基础数据结构 队列概念
  • C#串口通信三步走

    第一步 实例化串口通讯类 SerialPort sp new SerialPort 第二步 设置串口信息并打开串口 串口设置 public void SetSP string PortName string BaudRate string
  • 项目开发总结报告(GB8567——88)(转载)

    项目开发总结报告 GB8567 88 1引言1 1编写目的说明编写这份项目开发总结报告的目的 指出预期的阅读范围 1 2背景说明 a 本项目的名称和所开发出来的软件系统的名称 b 此软件的任务提出者 开发者 用户及安装此软件的计算中心 1
  • unity3D 巡逻兵

    游戏要求 创建一个地图和若干巡逻兵 使用动画 每个巡逻兵走一个3 5个边的凸多边型 位置数据是相对地址 即每次确定下一个目标位置 用自己当前位置为原点计算 巡逻兵碰撞到障碍物 则会自动选下一个点为目标 巡逻兵在设定范围内感知到玩家 会自动追
  • UPC思维题--移动

    题目描述 考虑333的立方体 有六个面 每个面有九个正方形 染色方法如下 角上的方格是red 中心是green 其他为blue 初始有一个机器人站在立方体顶面中心 面朝一个blue方格 它将接受到一系列如下指令 L 左转90度 R 右转90

随机推荐

  • gzip 命令

    NAME gzip compression decompression tool using Lempel Ziv coding LZ77 SYNOPSIS gzip cdfhkLlNnqrtVv S suffix file file gu
  • SQL Server连接字符串句法

    Application Name 应用程序名称 应用程序的名称 如果没有被指定的话 它的值为 NET SqlClient Data Provider 数据提供程序 AttachDBFilename extended properties 扩
  • ts总结 之 ts中的类型

    其他内容 ts中的类型 编译选项 webpack打包 类 文章目录 ts是什么 ts增加了什么 TypeScript中的基本类型 字面量 number boolean string any unknown 类型断言 void never o
  • (一)(C语言)实现顺序表(静态分配)的基本操作(初始化、判断是否为空,打印表,插入和删除等)讲解(含相关C语言代码讲解及运行结果)

    一 C语言 实现顺序表 静态分配 的基本操作 初始化 查找 打印表 插入和删除等 讲解 含C语言完整代码讲解及运行结果 文章目录 一 顺序表 二 顺序表相关操作 1 初始化 2 插入 3 删除 4 打印表 5 查找 三 完整代码讲解 C语言
  • 如何在chrome浏览器调试JS代码

    文章目录 资源 Sources 面板 控制台 Console 断点 Breakpoints debugger 命令 暂停并查看 日志记录 总结 参考文献 在编写更复杂的代码前 让我们先来聊聊调试吧 调试是指在一个脚本中找出并修复错误的过程
  • 如何解决merge conflict的方法

    如何解决merge conflict的方法 首先在pull的时候加上rebase 解决conflict 最后push git pull rebase origin remote if there is conflict clean it a
  • 3月份的字节跳动面经

    本人2本毕业 目前工作四年 一直是Android 做的都是些二线公司 没做过一线 四年跳了三家公司 在家休息了几个月 今年3月份开始面试 由于跳槽过多而且已经是现在Android市场的原因 内推的我的字节哥们儿 推了不知道多少个部门 才把我
  • Python轻松搞定免费语音合成,利用百度AI为短视频配音

    1 创建百度AI账号 1 1 点击进入百度AI 左上角 开放能力 gt 语音合成 gt 立即使用 如果是试用 可以直接点击在线语音合成 不过语音不能下载 要下载还得用下面方式 调用百度AI的API 1 2 然后登录百度云账户 进入管理中心
  • qemu-virtio基本原理

    virtio是相当复杂的 网上写virtio原理解析的文章也不少 这里我想通过最简练易懂的方式来解释一下virtio的原理 一方面也完善一下自己对virtio的理解 文中含有大量个人理解 如果发现有错误的地方欢迎与我交流 virtio整体流
  • 掌财社:掌握CCI指标捕捉爆发牛股

    什么是CCI指标 CCI指标又叫顺势指标 其英文全名为 Commodity Channel Index 是由美国股市分析家唐纳德R 兰伯特 Donald r Lambert 于20世纪80年代所创 是指导股市投资的一种中短线指标 CCI指标
  • linuxas3+apache2+mysql5+php5+discuz5+zend3.3+supesite.docx

    最近领导要装个supesite discuz 方便公司内部用 对于公司内部用来说是大了点 感觉有些大财小用了 但如果考虑以后做成门户 还是很值得的 于是就动手配置 出于linux系统的稳定与安全 选择linux作为平台 本配置所用系统与软件
  • 认识glBegin

    初学OpenGL的时候总有很多函数或者函数的参数不会用 不明白其作用 今天主要总结一下关于glBegin 中的参数用法 一 glBegin glBegin表示一组用于定义一个或者多个图元的顶点的开始 此函数通常与glEnd函数联用 在glB
  • 深度学习中常见的loss函数汇总

    损失函数 Loss Function 分为经验风险损失函数和结构风险损失函数 经验风险损失函数反映的是预测结果和实际结果之间的差别 结构风险损失函数则是经验风险损失函数加上正则项 L1或L2 深度学习中的损失函数被用于模型参数的估计 通常作
  • SpringBoot (6)- 自定义Starter

    SpringBoot 6 自定义Starter 1 简介 1 1启动器starter命名 1 2什么是SpringBoot starter机制 1 3为什么要自定义starter 1 4什么时候需要创建自定义starter 1 5自动加载核
  • Matlab——图像缩放(插值法)

    实验内容 用双线性内插法实现位深度为8的灰度图像的缩放 思路 输入原图像以及缩放后图像的像素要求 宽度 高度 处理后输出新图像 我是用matlab来实现scale input img scale size 函数的 输入图像路径以及要求实现的
  • 排序函数c++函数模板实现

    冒泡排序 插入排序 选择排序 归并排序 快排 堆排序 冒泡排序 插入排序 选择排序 这种简单的时间复杂度是O n2 归并排序 快排 堆排序时间复杂度O nlogn include
  • 最短路的应用(G - Easy Glide );

    题意 就是到达一个点可以进行加速 加速时间3s给出了他的加速后的速度 计算从起点终点的最短距离 思路 首先把每个点都加上2这可以进行点的设置 起点为1 第一个加速点就是2 下一个就是3 设置点后进行计算 一个点都另一个点的全部时间进行计算
  • 如何评估加解密代码?

    在不深入研究代码的具体实现的情况下 如何评估加解密代码的有效性 强度 背景 迫于无赖 项目组只能安排1位新手设计一系列的加密算法 用于对本地文件和二进制代码的加密 幸运的是 对加密强度并没有过高的要求 但也希望能够有效的评估代码 并实现自动
  • Enter passphrase for key提示

    Enter passphrase for key提示 找了很多博客 都没有解决 后来通过删除密码解决了 删除方法 ssh keygen p enpty new passphrase后按回车
  • Arxiv 2307

    Retentive Network A Successor to Transformer for Large Language Models 论文 https arxiv org abs 2307 08621 代码 https github