Drill · 手撕

GQA/MQA 从零实现学习练习

这是一个基于纯 PyTorch 从零实现的 Grouped-Query Attention (GQA) 和 Multi-Query Attention (MQA) 的学习练习。

1. 数学原理

核心是 缩放点积注意力,其输入为 Query QQ, Key KK, Value VV,输出为:

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

Grouped-Query Attention 中,假设总查询头数为 HH,KV头数为 HkvH_{kv},则每组查询头数 G=H/HkvG = H / H_{kv}。投影过程为:

Q=XWQRB×H×S×dhead,K=XWKRB×Hkv×S×dhead,V=XWVRB×Hkv×S×dheadQ = XW_Q \in \mathbb{R}^{B \times H \times S \times d_{head}}, \quad K = XW_K \in \mathbb{R}^{B \times H_{kv} \times S \times d_{head}}, \quad V = XW_V \in \mathbb{R}^{B \times H_{kv} \times S \times d_{head}}

GQA 核心操作:将 KKVV 沿头维度重复 GG 次,以匹配查询头数:

K=repeat_interleave(K,G,dim=1)RB×H×S×dheadK' = \text{repeat\_interleave}(K, G, \text{dim}=1) \in \mathbb{R}^{B \times H \times S \times d_{head}} V=repeat_interleave(V,G,dim=1)RB×H×S×dheadV' = \text{repeat\_interleave}(V, G, \text{dim}=1) \in \mathbb{R}^{B \times H \times S \times d_{head}}

之后计算标准注意力(带因果掩码):

scores=Q(K)dheadRB×H×S×S\text{scores} = \frac{Q (K')^\top}{\sqrt{d_{head}}} \in \mathbb{R}^{B \times H \times S \times S} causal_maskij={0if ijif i<j\text{causal\_mask}_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases} weights=softmax(scores+causal_mask)\text{weights} = \text{softmax}(\text{scores} + \text{causal\_mask}) output=weightsVRB×H×S×dhead\text{output} = \text{weights} \cdot V' \in \mathbb{R}^{B \times H \times S \times d_{head}}

最终投影回模型维度:

Out=Concat(head1,...,headH)WO\text{Out} = \text{Concat}(\text{head}_1, ..., \text{head}_H) W_O

特例

2. 直觉与复杂度

直觉: GQA 是 MHA 和 MQA 之间的折中方案。它通过让多个查询头共享一组 KV 头(即一个“组”),在保持较高模型表达能力的同时,显著减少了 KV 缓存的内存占用和计算量。MQA 是其极限情况(所有查询头共享同一组 KV)。

计算复杂度(以 FLOPs 计,忽略激活函数等): 对于序列长度 SS,头维度 dheadd_{head}

  1. Q/K/V 投影: O(Sdmodel(H+2Hkv)dhead)O(S \cdot d_{model} \cdot (H + 2H_{kv}) \cdot d_{head})
  2. 注意力分数计算 QKQK^\top: O(BHS2dhead)O(B \cdot H \cdot S^2 \cdot d_{head})
  3. 注意力加权求和 weightsV\text{weights} \cdot V': O(BHS2dhead)O(B \cdot H \cdot S^2 \cdot d_{head})
  4. 输出投影: O(SHdheaddmodel)O(S \cdot H \cdot d_{head} \cdot d_{model})

主要优势在于推理时的 KV 缓存 大小从 MHA 的 O(HSdhead)O(H \cdot S \cdot d_{head}) 降低为 GQA 的 O(HkvSdhead)O(H_{kv} \cdot S \cdot d_{head})

3. 文件说明

本练习目录包含 EXACTLY 三个文件:

4. 运行命令

  1. 运行演示/自我测试

    python from_scratch.py
    

    这将实例化 GQA、MQA 和 MHA 模块,进行前向传播并检查输出形状和梯度流。

  2. 运行单元测试

    python test_gqa_mqa.py
    

5. 追问分层 / Stratified follow-ups

L1 基础

  1. Grouped-Query Attention 与标准的 Multi-Head Attention 在架构上的主要区别是什么?
  2. 什么是 Multi-Query Attention?它与 Grouped-Query Attention 有什么关系?
  3. 为什么说 Grouped-Query Attention 能降低推理时的内存消耗?具体影响了哪一部分的内存?

L2 中间

  1. 在 GQA 的实现中,n_headsn_kv_heads 参数需要满足什么约束条件?为什么?
  2. 代码中如何将 KV 头的数量“扩展”以匹配查询头的数量?请描述具体操作(repeat_interleave 的作用)。
  3. 除了降低内存,GQA 对模型训练的计算量有直接影响吗?与 MHA 相比是增加、减少还是基本不变?

L3 深入

  1. 在 GQA 中,一个组内的多个查询头共享同一组 KV 投影(WKW_K, WVW_V)。从表示学习的角度看,你认为这种共享可能会带来什么优势或潜在问题?
  2. 如果将 GQA 应用于超长序列(例如上下文长度超过 2162^{16}),除了 KV 缓存外,还有哪些性能或计算瓶颈可能会变得更加突出?
  3. 在本实现的因果注意力掩码中,我们使用了一个固定的上三角布尔矩阵。在分布式训练或序列并行中,这个掩码可能需要如何调整?