Drill · 手撕

Drill: Sequence Packing from scratch

可运行的 from-scratch 实现 + 测试。目标:每一行都能在面试里推导和辩护。 Runnable from-scratch implementation with tests — derive and defend every line.

背景 / Background

大模型训练时一个 batch 通常含多条不等长文本。Padding 的做法是把所有序列补齐到最长者,浪费 O(LmaxLi)O(L_{max} - L_i) 计算;更严重的是在同一个 padded batch 里做 self-attention 时,普通 padding mask 无法阻止不同文档之间的 attention 泄漏——需要额外 per-sample causal mask 或牺牲批次效率。

Sequence Packing 把多条序列直接拼接成一条长序列,用 cu_seqlens(cumulative sequence lengths)记录边界,然后通过 block-diagonal attention mask(或 flash-attn varlen 内核)保证跨文档 attention 权重严格为 0。零填充、零泄漏。

为什么不能用普通 padding mask 替代 cu_seqlens?

方面 Padding mask cu_seqlens + block-diagonal
空间 (B,Lmax)(B, L_{max}) 标记哪些是 pad (N+1,)(N+1,) 整数,记文档边界
计算浪费 O(LmaxLi)O(L_{max} - L_i) 个无效 token 仍参与 QK 矩阵乘法 T=LiT = \sum L_i,零浪费
跨文档隔离 无法阻止同 batch 内不同文档间的 attention(padding mask 只标记 pad 位,不标记文档边界) block-diagonal mask 把跨文档位置填 -\infty,softmax 后严格为 0
内存 需要 (B,Lmax,Lmax)(B, L_{max}, L_{max}) 的注意力矩阵 同理 (T,T)(T, T),但 TBLmaxT \ll B \cdot L_{max}(当序列长短不均时)
生产实践 适合序列等长 / 小 batch flash-attn varlen 接口直接消费 cu_seqlens,O(T) 索引避免构造完整 mask

一句话:padding mask 标记的是"哪里是 pad",而 cu_seqlens 标记的是"哪里是文档边界"——这是两个不同的概念,前者无法推导后者。

数学 / The math

NN 条文档长度为 l1,,lNl_1, \dots, l_N,总长 T=i=1NliT = \sum_{i=1}^{N} l_i

cu_seqlens: cu[0]=0,cu[i]=j=1ilj\text{cu}[0] = 0,\quad \text{cu}[i] = \sum_{j=1}^{i} l_j

Block-diagonal mask: M[s,t]=1[doc(s)=doc(t)]    (causalst)M[s, t] = \mathbf{1}[\text{doc}(s) = \text{doc}(t)] \;\wedge\; (\text{causal} \Rightarrow s \geq t)

其中 doc(s)=i\text{doc}(s) = i 当且仅当 cu[i]s<cu[i+1]\text{cu}[i] \leq s < \text{cu}[i+1]

Attention over packed sequence: Attn(Q,K,V)s=t:M[s,t]exp ⁣(qsktdk)vtt:M[s,t]exp ⁣(qsktdk)\mathrm{Attn}(Q, K, V)_s = \frac{\sum_{t: M[s,t]} \exp\!\left(\frac{q_s^\top k_t}{\sqrt{d_k}}\right) v_t}{\sum_{t: M[s,t]} \exp\!\left(\frac{q_s^\top k_t}{\sqrt{d_k}}\right)}

跨文档位置 M[s,t]=0M[s,t]=0 → scores 填 -\inftyexp()=0\exp(-\infty)=0 → 权重严格为 0。

Position IDs 在每条文档内从 0 重置,确保 RoPE 等位置编码在文档级别正确。

文件

python test_sequence_packing.py        # 或 python -m pytest test_sequence_packing.py

追问分层 / Stratified follow-ups