Drill · 手撕

从头实现温度 / top-k / top-p (nucleus) 采样学习演练 README

本演练基于纯 PyTorch 从头实现温度缩放、top-k 和 top-p(nucleus)采样,用于从 logits 中采样 token。目录中仅包含三个文件,无其他依赖。

数学基础 / Mathematical Foundations

所有操作基于模型输出的 logits 向量 zRVz \in \mathbb{R}^{V},其中 VV 为词表大小。采样过程如下:

  1. 温度缩放 / Temperature Scaling
    给定温度 τ>0\tau > 0,缩放后的 logits 为 zi/τz_i / \tau,然后通过 softmax 转换为概率:
    pi=ezi/τj=1Vezj/τp_i = \frac{e^{z_i / \tau}}{\sum_{j=1}^V e^{z_j / \tau}}

    • τ0\tau \to 0:分布趋近贪心解码(argmax)。
    • τ\tau \to \infty:分布趋近均匀分布。
    • τ=1.0\tau = 1.0:无变化。
  2. Top-k 过滤 / Top-k Filtering
    对每个样本,保留 logits 中最大的 kk 个值,其余设为 -\infty
    z~i={zi,if zitop-k values,otherwise\tilde{z}_i = \begin{cases} z_i, & \text{if } z_i \in \text{top-}k \text{ values} \\ -\infty, & \text{otherwise} \end{cases}
    等价于 mask 掉非 top-k 的 token。

  3. Top-p(Nucleus)过滤 / Top-p Filtering
    先将 logits 转换为概率 pi=softmax(zi)p_i = \text{softmax}(z_i),并按降序排列得到 p(1)p(2)p(V)p_{(1)} \geq p_{(2)} \geq \dots \geq p_{(V)}。计算累积概率:
    cumulativem=j=1mp(j)\text{cumulative}_m = \sum_{j=1}^m p_{(j)}
    找到最小的 mm 使得 cumulativemp\text{cumulative}_m \geq ppp 为 top_p 阈值),则保留前 mm 个 token(至少保留 min_tokens_to_keep 个),其余设为 -\infty。在代码中,通过移位累积和实现 mask。

  4. 采样 / Sampling
    应用上述过滤后,对 filtered logits 计算 softmax 概率,并从多项分布中采样一个 token:
    next_tokenMultinomial(softmax(z~))\text{next\_token} \sim \text{Multinomial}(\text{softmax}(\tilde{z}))
    处理顺序为:温度缩放 → top-k → top-p → 多项采样。

直觉与复杂度 / Intuition and Complexity

文件 / Files

演练目录包含 EXACTLY 以下三个文件:

运行 / Run

仅支持以下两个命令:

追问分层 / Stratified follow-ups

L1 基础 / Basic

  1. 温度缩放(temperature scaling)如何影响概率分布?当温度接近 0 时会发生什么?
  2. Top-k 采样中,参数 kk 的作用是什么?为什么需要限制 kvocab_sizek \leq \text{vocab\_size}
  3. 在 top-p 采样中,什么是“nucleus”?如何通过阈值 pp 控制候选 token 数量?

L2 中级 / Intermediate

  1. 比较 top-k 和 top-p 采样:在什么场景下 top-p 可能比 top-k 更优?
  2. 解释为什么在采样管道中,温度缩放通常先于 top-k 和 top-p 应用?
  3. 如何调整温度、top-k 和 top-p 参数来平衡文本生成的多样性和连贯性?

L3 深入 / Deep

  1. 分析温度缩放与 softmax 的数学关系:为什么缩放 logits 等价于调整分布熵?
  2. 在 top-p 过滤中,移位累积和(shifted cumulative sum)的算法设计如何确保正确 mask?讨论边界条件。
  3. 从信息论角度,讨论温度、top-k 和 top-p 如何影响生成文本的多样性和困惑度(perplexity)。在实际应用中,如何联合调参以优化模型性能?