分享

麻省理工(MIT) | 提出跨层Attention,减少Transformer大模型键值(KV)缓存,加快LLM推理!

 netouch 2024-05-28 发布于北京
更多干货,第一时间送达

引言

键值 (KV) 缓存能够显著提升Transformer大模型的解码速度。但是当面对长序列的时候,键值 (KV) 缓存需要大量的内存资源。当前减少键值 (KV) 缓存的两个主要方法分别为:Multi-Query Attention(MQA)和Grouped-Query Attention (GQA)。这两种方法主要是修改了Attention块,使得多头请求头共享单个KV头,从而大大减少了不同KV的数量。

而本文作者受前人启发,提出了一种新的Attention设计方法:跨层注意力(Cross-Layer Attention, CLA),即通过在不同层之间共享KV头,减少了KV缓存的大小。对比其它方法,在相同准确性的情况下,可以将KV缓存的大小缩小2倍!图片https:///pdf/2405.12981

背景介绍

大模型实际部署应用的时候,键值 (KV) 缓存的内存可能会成为其应用的瓶颈。由于 KV 缓存的大小与序列长度以及批处理大小存在比例关系,因此在长序列长度上操作时,KV 缓存内存的大小可能会限制批处理大小,并且当面对设备内存不足时,也有会采用成本较高的方法,例如:offloading策略。但为了提升推理速度减少冗余计算,人们也非常希望能够长期保存键值 (KV) 。然而,KV 缓存的大小直接决定了存储和检索此类持久缓存的成本

随着LLM 新应用的出现,需要更长的序列长度,KV 缓存的内存占用的挑战越来越受到研究人员的关注。并且当前研究人员也提出了多种减少KV缓存内存占用的方法,例如:采用低精度来缓存KV、驱逐不重要的KV缓存条目以及跨请求头共享KV等。

与之前方法不同,本文提出了一种新的方法:跨层Attention(Cross-Layer Attention,CLA),简单来说,该方法主要是通过减少KV缓存中唯一层的数量来减小KV缓存的大小

MQA and GQA

在介绍跨层Attention之前,先带大家简单的了解一下多头Attention(MHA)多请求Attention(MQA)分组请求Attention(GQA)

最初的Transformer架构主要是用多头Attention(MHA),其中每个请求头主要关注不同KV头生成的KV。在MHA中,每个KV头的KV激活必须单独存储在KV缓存中,这对于每个token来说,它的存储开销为:,其中表示每个请求头数量,表示每个头的嵌入维度。

为了减少减少 Transformer 解码期间存储和访问 KV 缓存相关的开销,有研究人员提出了多请求Attention,并逐渐的将其推广至分组请求Attention。分组查询Attention通过将每个Attention层的请求头编制成组来修改Transformer架构,其中每组请求头共享单个KV头。由于 KV 缓存的大小仅随着不同KV头的数量而变化,而不是请求头的数量,因此 GQA 将 KV 缓存的存储开销降低到 ,其中 表示 GQA的组数,且很明显:。另外,MQA 可以看作是 GQA 的特例,其中

研究发现,与具有相同头尺寸的 MHA 架构相比,MQA 和 GQA 能够显着减少 KV 缓存大小和Transformer解码延迟,但精度会有略微下降。所以在模型设计过程中,需要平衡Attention架构的准确性和KV缓存大小之间的关系

跨层Attention

受MQA 和 GQA的启发,本文作者提出了跨层共享KV头,并将这种Attention架构称为:跨层Attention(CLA),如下图所示:图片可以看到在CLA中,只有模型中的一部分层会计算KV投影,而没有计算KV投影的层的Attention块会重新使用之前层的KV激活值。这意味着只有计算了KV投影的那些层会使用KV缓存,从而与传统架构相比,后者在每一层都应用了独立的KV投影,对比之下,CLA能够减少对内存的使用。

除此之外,CLA可以与MQA、GQA、MHA 进行组合。此外,与 GQA 允许不同的 访问一系列不同的Attention配置一样,CLA 可以改变共享每个 KV 投影输出的层数,作者将其称为共享因子。通过共享因子来引用 CLA 的不同配置,从而产生了 CLA2,它在一对相邻层之间共享每个 KV 投影,CLA3,它在一组 3 层之间共享每个 KV 投影,依此类推。如下图所示:图片另外,作者还在系统工程的角度总结了 CLA 对相关关键指标的影响:

  • KV 缓存内存:CLA 显着减少了 KV 缓存内存占用量,减少的倍数等于共享因子
  • 训练内存占用:CLA 减少了训练期间具体化的中间 KV 激活张量的内存占用,尽管对于 GQA 和 MQA 模型,此类 KV 张量与模型的隐藏状态和 MLP 激活相比通常很小。
  • 模型并行性:CLA 与标准完全兼容 张量并行技术,可用于跨多个加速器分片模型权重。
  • 参数和FLOP:由于CLA 减少了模型中KV投影块的总数,因此CLA 略微减少了模型中参数的数量以及前向或后向传递期间所需的FLOP 数量。
  • 解码延迟:在完整的LLM 服务堆栈的背景下,CLA 可以实现比其他方式更大的批量大小和更长的KV 缓存持久时间,可以减少推理延迟。
  • 核心Attention延迟:与MQA和GQA不同,CLA对每个解码步骤中Attention机制消耗的内存带宽没有直接影响。

实验结果

下图展示了CLA在准确性/内存权衡上的影响。可以看到在1B和3B参数规模的模型上,CLA结合MQA相比于单纯的MQA基线,KV缓存所需内存缩小了2倍,同时仅造成微小的困惑度(perplexity)增加。同时作者也展示了不同CLA共享模式的实验结果,可以发现CLA2在性能上一致优于其他配置。图片图片

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多