纯 PyTorch 手写 Transformer 自回归解码 + KV Cache 实现,无任何外部推理框架依赖。
1. 数学原理 / Math
Scaled Dot-Product Causal Attention(带因果遮罩的缩放点积注意力):
Attention(Q,K,V)=softmax(dkQK⊤+M)V
其中因果遮罩 M 为上三角(对角线除外):
Mij={0−∞if i≥jif i<j
KV Cache 核心操作: Prefill 阶段处理全部 prompt 建立缓存,Decode 阶段每步仅输入一个 token,将新投影的 Kcur,Vcur 与缓存拼接:
Knew=concat(Kcache,Kcur),Vnew=concat(Vcache,Vcur)
注意分数维度为 (B,H,Tcur,Tkv),其中 Tkv=Tcache+Tcur。因果遮罩取全长矩阵的最后 Tcur 行:
M^=M[Tkv−Tcur:Tkv,:]
Pre-Norm Decoder Block:
x′=x+CausalMHA(LN(x)),x^=x′+FFN(LN(x′))
FFN: FFN(x)=W2⋅GELU(W1x),其中 W1∈Rdff×dmodel,W2∈Rdmodel×dff。
Position Encoding: 可学习的绝对位置嵌入,解码时通过 position_offset 保持位置索引连续:
x=TokEmb(t)+PosEmb(offset+t)
采样策略: temperature scaling + top-k filtering
pi=∑jexp(zj/τ)exp(zi/τ),zi′={zi/τ−∞if zi/τ∈top-kotherwise
当 τ=0 时退化为 greedy(argmax)。
2. 直觉与复杂度 / Intuition & Complexity
无 Cache vs 有 Cache 解码对比:
|
无 Cache |
有 KV Cache |
| 第 t 步 Attention 计算 |
O(t⋅d) |
O(1⋅d)(仅新 token 作 query) |
| 生成 n token 总量 |
O(n2⋅d) |
O(n⋅d)(缓存线性增长) |
直觉: KV Cache 本质是用 内存换计算——将之前所有 step 的 Key/Value 向量缓存起来,避免重复计算。Prefill 一次处理全部 prompt 填充缓存,之后每步 decode 只需处理一个 token,Attention 的 query 维度恒为 1。
因果遮罩的精妙之处: 当 decode 阶段 Tcur=1 时,遮罩从 Tkv×Tkv 矩阵中取最后 1 行,确保新 token 只能看到它自己及之前所有位置。
3. Files
| 文件 |
说明 |
from_scratch.py |
核心实现:CausalMultiHeadAttention、TransformerDecoderBlock、MiniGPT、generate 解码循环、_sample 采样函数,以及 __main__ 自测试 |
test_kv_cache.py |
单元测试:验证 KV Cache 的正确性(缓存拼接、形状、与无缓存结果的一致性等) |
README.md |
本文件 |
4. Run
python from_scratch.py
python test_kv_cache.py
5. 追问分层 / Stratified Follow-ups
L1 — 基础 / Basic
- Prefill 阶段和 Decode 阶段分别输入模型的 token 数是多少?为什么要区分这两个阶段?
- 因果遮罩(causal mask)的作用是什么?如果没有它会怎样?
temperature 参数如何影响生成的多样性?temperature=0 意味着什么?
L2 — 中级 / Intermediate
- 为什么 KV Cache 只缓存 Key 和 Value 而不缓存 Query?从计算图角度解释。
- 代码中
position_offset 的作用是什么?如果去掉它会导致什么问题?
- 解码阶段因果遮罩的切片
causal[T_kv - T_cur : T_kv, :] 为什么不能直接用完整的 T_kv × T_kv 矩阵?
- Pre-norm(先 LayerNorm 再 Attention/FFN)相比 Post-norm 有什么训练稳定性上的优势?
L3 — 深度 / Deep
- 如果将
cached_k 的 torch.cat 操作替换为预分配 buffer + 原地写入,具体会减少哪些开销?在什么场景下收益最大?
- 本实现中 KV Cache 的内存占用为 O(L⋅nheads⋅T⋅dk),当序列长度 T 极大时有哪些经典的压缩策略(如 GQA、MQA、Sliding Window)?它们各自牺牲了什么?
- 当前采样使用的是独立的 top-k + temperature;如果引入 nucleus sampling(top-p),概率分布的截断逻辑有何本质区别?在什么分布特征下 top-p 优于 top-k?