deepseek最近比较出圈,本人也一直关注deepseek发布的一些技术报告。在模型训练、推理性能和计算成本上一直能给大家惊喜。读了deepseek的技术报告,我个人有两个比较强的感受。第一:deepseek在模型细节上扣的比较极致,魔改了一些模型框架(比如模型优化方面: MLA, GRPO,MTP);第二:工程能力上确实比较强,对于主流的一些框架和技术点能敏捷地整合到自己的系统内(比如:在Infra方面,能看到deepspeed, Megatron,DistServer、vLLM等框架的核心技术点)。后面准备用几篇笔记学习和整理下deepseek的技术。
本文重点讲解下MLA(Multi-Head Latent Attention)
注:我在学习的过程中,通常会有些知识盲点,或掌握不精确的地方,我会递归学习一些扩展的脉络。本文也是沿着一些必要的背景知识,逐层解读下MLH的提出背景、要解决的问题和最终的效果。
MLA主要通过优化KV-cache来减少显存占用,从而提升推理性能。直接抛出这个结论可能不太好理解。首先我们来看下,对于生成模型,一个完整的推理阶段是什么样的,推理性能上有什么问题。
1. LLM模型推理过程
LLM推理分为两个阶段:prefill阶段和 decode阶段
- prefill阶段:是模型对全部的Prompt tokens一次性并行计算,最终会生成第一个输出token
- decode阶段:每次生成一个token,直到生成EOS(end-of-sequence)token,产出最终的response
在推理过程中,由于模型堆叠了多层transformer,所以核心的计算消耗在Transformer内部,包括MHA,FFN等操作,其中MHA要计算Q,K ,V 矩阵,来做多头注意力的计算。
在LLM生成过程中,是一个基于前向序token列预测下一个token的过程,序列中的token(无论是prefill阶段,还是decode阶段)只与它前面的token交互来计算attention,我们也称这种Attention为Causal Attention。矩阵计算上通过一个下三角的Causal Attention Mask来实现token交互只感知前向序列。如图1所示,展现的Transformer内部的细节:

图1、Transformer 内部的计算细节

接下来我们再详细看看对于一个典型的推理架构有几级访存速率,模型推理过程中又有哪些数据要做存储下来,应该如何分配存储。
2. LLM推理阶段显存使用情况
2.1 访存速率分级
为了直观理解访存的速率,我们以一个分布式推理架构为例。
比如2台机器,每台机器有8张A100, 那么在这样一个系统内,卡内,单机卡间,机器之间的数据访问效率如图3所示。
注:我们的例子中,只描述了一种访存介质HBM (也就是我们常说的显卡的显存),我们知道通常GPU的存储介质除了显存,还有SRAM和DRAM。SRAM也被成为片上存储,是GPU计算单元上即时访问更快的存储,所有的计算都要先调度到片上存储SRAM才能做计算,一般只有几十M大小,带宽可达到20T/s左右,SRAM是跟计算单元强绑定的,推理阶段一般不考虑将SRAM作为存储单元使用。而DRAM是我们常说的CPU的内存,由于访问速率较慢,推理阶段一般也不考虑使用。所以我们讨论的推理存储介质,一般就指的是HBM(显存)

图3、分布式推理架构卡内、卡间、跨机存储和带宽
由上图的访存带宽可知,卡内的带宽是单机卡间的带宽的3倍,是跨机带宽的20倍,所以我们对于存储的数据应该优先放到卡内,其次单机内,最后可能才考虑跨机存储。
接下来我们再看下,推理过程中,有哪些数据要存储到显存上。
2.2. 模型推理阶段显存分配
下面我画了一张图,如图4所示,推理阶段主要有三部分数据会放到显存里。


图4. 推理阶段显存占用
由上述可知,推理阶段主要存储消耗是两部分: 模型参数和 KV Cache。那么模型参数占多少,KV Cache又占多少?
首先我们先以一个token的计算过程为例,看下一个token计算要存储多少KV?为了方便理解,我们以Qwen-72B模型为例,模型配置详见: Qwen-72B-Chat。



图5、单token kv缓存数据


这里还要多啰嗦几句,推理阶段根据离线、在线的业务场景,到底组多大的Batch,其实是一个Balance的过程,Batch选择比较小,虽然并发度不高,但可能单卡就能装下完整模型参数和KV Cache,这时候卡内带宽会比较高,性能可能依然出众,可以考虑适当增加Batch把单卡显存用满,进一步提升性能。但当Batch再增大,超出单卡范围、甚至超出单机范围,此时并发会比较大,但跨卡或跨机访存性能会降低,导致访存成为瓶颈,GPU计算资源使用效率不高,可能实际导致整体推理性能不高。所以单从推理Batch设置角度来看,要实测找到性能最佳的平衡点。
当前LLM都比较大,而访存的容量和访存速率有分级的特点。所以推理过程中,减少跨卡、卡机的访存读写是优化推理性能的一个有效路径。一方面单次读写的数据越少,整体速度会越快;另一方面整体显存占用越少,就能尽量把数据放到单卡或单机上,能使用更高的带宽读写数据。
本文要学习的MLA就是通过减少KV Cache来压缩显存占用,从而优化推理速度。我们在展开了解MLA之前,先看看当前有哪些优化KV Cache的方法。
3. 减小KV cache的方法
3.1. KV Cache 优化方法汇总
业界针对KV Cache的优化,衍生出很多方法,这里我根据自己的积累,稍微总结下,只简单描述优化的思路,不过多展开。
方法主要有四类:
- 共享KV:多个Head共享使用1组KV,将原来每个Head一个KV,变成1组Head一个KV,来压缩KV的存储。代表方法:GQA,MQA等
- 窗口KV:针对长序列控制一个计算KV的窗口,KV cache只保存窗口内的结果(窗口长度远小于序列长度),超出窗口的KV会被丢弃,通过这种方法能减少KV的存储,当然也会损失一定的长文推理效果。代表方法:Longformer等
- 量化压缩:基于量化的方法,通过更低的Bit位来保存KV,将单KV结果进一步压缩,代表方法:INT8等
- 计算优化:通过优化计算过程,减少访存换入换出的次数,让更多计算在片上存储SRAM进行,以提升推理性能,代表方法:flashAttention等
本文要讨论的MLA是共享KV分支下的一种优化方法,下面我们先展开看看共享KV方法有哪些,这些方法也是MLA拿来对比的方法。
3.2. 共享KV优化显存方法
共享KV主要有两种方法,MQA和GQA都是Google提出的,详见: MQA(2019),GQA(2023),如图6所示。

图6、KV Cache优化方法- 共享KV方法


图7、MHA,MQA,GQA KVcache对比图

4. MLA
4.1. MLA KV优化速览
我们先走马观花看看MLA的计算方式和与MQA、GQA的压缩KV的效果对比。
首先我们看看MLA计算Attention的完整公式,如下图8所示

图8、MLA Attention计算公式









图9、MLA,MHA,GQA,MQA对比图

图10、MLA与其他方法压缩性能和效果对比
注:图中能力的比较上,描述比MHA更强我比较存疑,并没看到有消融的实验对比,也不太好从原理上解释。
5. 总结
本文试图通过引入更多基础知识和辅助信息,来深入理解MLA。内容比较长,可能觉得比较啰嗦。这是本人在理解MLA过程递归总结的一些扩展信息,最终整理了一个系统的脉络,发出来供大家参考。
6.参考文献
- deepseek-v1:https://arxiv.org/pdf/2401.02954
- deepseek-v2:https://arxiv.org/pdf/2405.04434
- deepseek-v3:https://arxiv.org/pdf/2412.19437
- 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA – 科学空间|Scientific Spaces
- https://zhuanlan.zhihu.com/p/659770503
- GQA:https://arxiv.org/pdf/2305.13245
- MQA:https://arxiv.org/pdf/1911.02150
个人水平有限,欢迎指正~
— 文章来源 知乎 链接:https://zhuanlan.zhihu.com/p/16730036197
Comments NOTHING