
Drill: SFT loss masking from scratch

A runnable from-scratch implementation with tests. Goal: be able to derive and defend every line in an interview.

What this covers

The supervised fine-tuning (SFT) loss involves two distinct sub-problems:

  1. Label masking: compute the loss only on assistant-response tokens; prompt and user-turn tokens are set to ignore_index=-100 and contribute no gradient.

  2. Masked cross-entropy: safely compute cross-entropy over a labels tensor that contains ignore_index, avoiding index-out-of-bounds errors with the clamp-then-mask pattern.

The math

Label masking

Given a sequence $x_{1:L}$ and assistant spans $\{[s_i, e_i)\}$, construct the labels:

$$y_t = \begin{cases} x_t & \text{if } t \in \bigcup_i [s_i, e_i) \\ -100 & \text{otherwise} \end{cases}$$
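The labelling rule fits in a few lines of plain Python. This is a minimal sketch, assuming spans arrive as half-open (start, end) token indices into input_ids; build_labels and IGNORE_INDEX are hypothetical names, not taken from the original implementation.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # the value PyTorch's cross-entropy skips by default


def build_labels(
    input_ids: Sequence[int],
    spans: Sequence[Tuple[int, int]],
    ignore_index: int = IGNORE_INDEX,
) -> List[int]:
    """y_t = x_t if t lies in some half-open span [s_i, e_i), else ignore_index."""
    labels = [ignore_index] * len(input_ids)  # start fully masked
    for start, end in spans:
        if not 0 <= start <= end <= len(input_ids):
            raise ValueError(f"span [{start}, {end}) is outside the sequence")
        # Copy the assistant tokens; positions outside every span stay masked.
        labels[start:end] = list(input_ids[start:end])
    return labels


if __name__ == "__main__":
    ids = [11, 12, 13, 14, 15, 16]
    print(build_labels(ids, [(2, 4), (5, 6)]))  # [-100, -100, 13, 14, -100, 16]
```

Overlapping spans are harmless because the rule takes their union. The labels stay aligned with input_ids; the one-position shift implied by $p_\theta(y_t \mid x_{<t})$ is left to the loss computation.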

Masked cross-entropy loss

Per-token negative log-likelihood, summed over active positions only and divided by their count:

$$\mathcal{L} = \frac{1}{|\mathcal{A}|} \sum_{t \in \mathcal{A}} -\log p_\theta(y_t \mid x_{<t})$$

where $\mathcal{A} = \{t : y_t \neq -100\}$ and $p_\theta$ is the model's softmax output.
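The formula translates almost line for line into PyTorch once the set $\mathcal{A}$ is materialised as a boolean mask. This is a minimal sketch for a single sequence, assuming the usual causal-LM layout where logits row t predicts token t+1, so the shift behind $x_{<t}$ is done explicitly; masked_ce_from_formula is a hypothetical name. The result is checked against torch.nn.functional.cross_entropy with ignore_index=-100.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def masked_ce_from_formula(logits: torch.Tensor, labels: torch.Tensor,
                           ignore_index: int = IGNORE_INDEX) -> torch.Tensor:
    """(1/|A|) * sum over t in A of -log p(y_t | x_<t), for one sequence.

    logits: (L, V), where row t is the model's prediction for position t + 1
    (the usual causal-LM layout); labels: (L,) from the label-masking rule.
    """
    # Shift by one so that row t of `pred` is scored against `target[t]`.
    pred, target = logits[:-1], labels[1:]
    log_p = F.log_softmax(pred, dim=-1)      # log p_theta(. | x_<t), shape (L-1, V)
    active = target != ignore_index          # the set A as a boolean mask
    # Boolean indexing keeps only rows in A, so -100 never reaches gather.
    nll = -log_p[active].gather(1, target[active].unsqueeze(1)).squeeze(1)
    # Divide by |A|; clamp(min=1) makes an all-masked sequence return 0 (a choice made here).
    return nll.sum() / active.sum().clamp(min=1)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(6, 11)
    labels = torch.tensor([-100, -100, -100, 4, 9, 2])
    loss = masked_ce_from_formula(logits, labels)
    ref = F.cross_entropy(logits[:-1], labels[1:], ignore_index=-100)
    print(torch.allclose(loss, ref))  # True
```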

Two normalisation conventions:

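| mode | denominator | typical use |
|---|---|---|
| "token" | $\lvert\mathcal{A}\rvert$ (number of unmasked tokens) | HuggingFace Trainer / TRL default; every token is weighted equally |
| "sample" | $L$ (total sequence length) | Some RL trainers; the denominator does not depend on how many tokens are masked, so the loss scale stays stable across batches |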
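The two conventions differ only in the denominator. The sketch below assumes logits of shape (B, L, V) already aligned with labels of shape (B, L), and assumes "sample" mode means dividing each sequence's summed NLL by its padded length L and then averaging over the batch; sft_loss and its mode strings are illustrative, not an existing API. It relies on F.cross_entropy with reduction="none", which returns 0 at positions whose target equals ignore_index.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def sft_loss(logits: torch.Tensor, labels: torch.Tensor, mode: str = "token",
             ignore_index: int = IGNORE_INDEX) -> torch.Tensor:
    """Masked SFT loss; logits (B, L, V) are assumed already aligned with labels (B, L)."""
    B, L, V = logits.shape
    # reduction="none" yields 0.0 at every position whose target is ignore_index.
    per_tok = F.cross_entropy(
        logits.reshape(-1, V), labels.reshape(-1),
        ignore_index=ignore_index, reduction="none",
    ).view(B, L)
    if mode == "token":
        # Denominator |A|: number of unmasked tokens in the whole batch.
        n_active = (labels != ignore_index).sum().clamp(min=1)
        return per_tok.sum() / n_active
    if mode == "sample":
        # Assumed reading: divide each sequence's sum by its (padded) length L,
        # then average over the batch, i.e. total / (B * L).
        return (per_tok.sum(dim=1) / L).mean()
    raise ValueError(f"unknown mode: {mode!r}")


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 10)
    labels = torch.tensor([[-100, -100, 3, 5],
                           [-100, 7, 1, 2]])   # |A| = 5, B * L = 8
    tok = sft_loss(logits, labels, "token")
    smp = sft_loss(logits, labels, "sample")
    print(torch.allclose(smp, tok * 5 / 8))    # True: same sum, different denominator
```

Under this reading, "sample" mode gives every active token the same weight 1/(B L) in every batch, while the "token" weight 1/|A| shifts with how many response tokens a batch happens to contain.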

The clamp-then-mask pattern

直接用 -100torch.gather 的索引会触发越界错误(CUDA backend 尤甚)。正确做法:

Directly using -100 as an index into torch.gather raises an out-of-bounds error on most backends. The safe pattern:

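ignore_index = -100  # log_probs: (N, V) from F.log_softmax(logits, dim=-1); labels: (N,) int64
safe_labels = labels.clamp(min=0)                                # 1. make every index legal (-100 -> 0)
nll = -log_probs.gather(1, safe_labels.unsqueeze(1)).squeeze(1)  # 2. gather: all indices in [0, V-1]
nll = nll * (labels != ignore_index)                             # 3. zero out the masked positions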

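The multiply-by-zero in step 3 ensures that the meaningless gather result at -100 positions (after clamping, the log-probability of token 0) contributes 0 to the loss and receives zero gradient, which is mathematically identical to skipping those positions entirely. Only the numerator is handled here; the denominator still follows the chosen normalisation mode.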
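The clamp-then-mask steps can be checked end to end: the masked sum should match PyTorch's built-in ignore_index handling, and the masked rows should receive exactly zero gradient. This is a minimal sketch; clamp_then_mask_nll is a hypothetical helper, and ignore_index is assumed negative so that clamp(min=0) maps it to a legal index. The mask is computed from the original labels, so a genuine target of token 0 (the last row below) is still counted.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def clamp_then_mask_nll(logits: torch.Tensor, labels: torch.Tensor,
                        ignore_index: int = IGNORE_INDEX) -> torch.Tensor:
    """Per-position NLL of shape (N,), exactly zero where labels == ignore_index."""
    log_probs = F.log_softmax(logits, dim=-1)                        # (N, V)
    safe_labels = labels.clamp(min=0)                                # 1. -100 -> 0, a legal index
    nll = -log_probs.gather(1, safe_labels.unsqueeze(1)).squeeze(1)  # 2. (N,), junk at masked rows
    mask = (labels != ignore_index).to(nll.dtype)                    # 3. 1.0 on A, 0.0 elsewhere
    return nll * mask


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(5, 7, requires_grad=True)
    labels = torch.tensor([3, -100, 6, -100, 0])
    nll = clamp_then_mask_nll(logits, labels)
    loss = nll.sum() / (labels != IGNORE_INDEX).sum()              # token-mode normalisation
    print(torch.allclose(loss, F.cross_entropy(logits, labels)))   # True (default ignore_index=-100)
    loss.backward()
    print(logits.grad[1].abs().sum().item(), logits.grad[3].abs().sum().item())  # 0.0 0.0
```

One caveat worth knowing: if a logits row can contain -inf (for example from vocabulary masking), the gathered value at index 0 can be inf, and inf * 0 is nan; torch.where(mask.bool(), nll, torch.zeros_like(nll)) avoids that in the forward pass.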

Why it matters

Files

python test_sft_loss_mask.py        # or: python -m pytest test_sft_loss_mask.py

Stratified follow-ups