Drill · 手撕

DPO 偏好优化损失函数 — 从零实现学习演练

Direct Preference Optimization (DPO) Loss — From-Scratch Study Drill


1. 数学原理 / Mathematical Formulation

隐式奖励 / Implicit Reward

DPO 将策略模型 πθ\pi_\theta 与参考模型 πref\pi_{\text{ref}} 的对数概率比定义为隐式奖励:

r(yx)  =  β[logπθ(yx)    logπref(yx)]r(y \mid x) \;=\; \beta \,\bigl[\log \pi_\theta(y \mid x) \;-\; \log \pi_{\text{ref}}(y \mid x)\bigr]

DPO 损失 / Loss (per sample)

给定偏好对 (x,yw,yl)(x, y_w, y_l),其中 ywy_w 为优选(chosen)响应,yly_l 为拒绝(rejected)响应:

LDPO=logσ ⁣(β[logπθ(ywx)logπref(ywx)logπθ(ylx)+logπref(ylx)])\mathcal{L}_{\text{DPO}} = -\log \sigma\!\Bigl(\beta \cdot \bigl[\log \pi_\theta(y_w \mid x) - \log \pi_{\text{ref}}(y_w \mid x) - \log \pi_\theta(y_l \mid x) + \log \pi_{\text{ref}}(y_l \mid x)\bigr]\Bigr)

等价地,令 z=β(chosen_log_ratiorejected_log_ratio)z = \beta(\text{chosen\_log\_ratio} - \text{rejected\_log\_ratio}),则 L=logσ(z)\mathcal{L} = -\log\sigma(z)

带 Label Smoothing 的保守变体 / Conservative Variant

ε>0\varepsilon > 0 时(参见 Rafailov et al. 2023 §5):

L=(1ε)logσ(z)    εlogσ(z)\mathcal{L} = -(1 - \varepsilon)\,\log\sigma(z) \;-\; \varepsilon\,\log\sigma(-z)

逐序列对数概率 / Per-Sequence Log-Probability

标准 next-token cross-entropy,对序列长度求和并 mask 掉 ignore_index=-100 的 padding 位置:

logP(yx)=t1[yt+1-100]logπθ(yt+1yt,x)\log P(y \mid x) = \sum_{t} \mathbf{1}[y_{t+1} \neq \texttt{-100}] \cdot \log \pi_\theta(y_{t+1} \mid y_{\le t}, x)


2. 直觉与复杂度 / Intuition & Complexity

直觉 / Intuition

DPO 的核心洞察是:无需训练单独的 reward model,可以直接用 logπθπref\log\frac{\pi_\theta}{\pi_{\text{ref}}} 作为隐式奖励。损失函数本质上是一个二分类目标——让优选回答的隐式奖励高于被拒绝回答的奖励。β\beta 控制 KL 散度约束的强度:β\beta 越大,策略越贴近参考模型。

复杂度 / Complexity

BB 为 batch size,SS 为序列长度,VV 为词表大小:

操作 时间复杂度 备注
compute_log_probs_from_logits O(BSV)O(B \cdot S \cdot V) log_softmax + gather
dpo_loss O(B)O(B) 仅逐元素算术

总计算量约 4×O(BSV)4 \times O(BSV)(四个 logits 张量各过一次 log-prob 计算)。内存主要受 logits 张量 (B,S,V)(B, S, V) 支配。


3. 文件清单 / Files

本演练目录下仅有以下三个文件:

文件 说明
from_scratch.py 核心实现:compute_log_probs_from_logitsdpo_lossDPOLossModule,以及自测入口
test_dpo_loss.py 单元测试,验证损失函数数值正确性与梯度传播
README.md 本文件

4. 运行 / Run

# 演示与自测 / Demo & smoke-test
python from_scratch.py

# 运行单元测试 / Run tests
python test_dpo_loss.py

5. 追问分层 / Stratified Follow-ups

L1 — 基础 / Basic

  1. 什么是 DPO? 为什么 DPO 可以替代 RLHF 中的 reward model + PPO 流程?
  2. 隐式奖励的含义: r(y|x) = β · (log π_θ − log π_ref) 这个量在直觉上代表什么?为什么它与 Bradley-Terry 偏好模型一致?
  3. beta 的作用:beta 从 0.1 改为 1.0 或 0.01,loss 行为会如何变化?什么情况下 beta=0 会导致问题?
  4. ignore_index 的用途: 为什么需要在 compute_log_probs_from_logits 中 mask 掉 labels == -100 的位置?如果不 mask 会怎样?

L2 — 进阶 / Intermediate

  1. Label Smoothing 的效果: 代码中 label_smoothing > 0 时,损失函数变成了什么形式?这如何防止对偏好数据的过拟合(overfitting to preference data)?
  2. Reference model 冻结: 为什么 reference_chosen_logpsreference_rejected_logps 在训练中不应当有梯度流过?如果参考模型也被更新,会出现什么问题?
  3. Reward margin 的监控: 代码返回了 reward_margin = chosen_rewards - rejected_rewards。训练过程中这个值应如何变化?持续上升是好事吗?为什么?
  4. 数值稳定性: 直接用 torch.log(torch.sigmoid(...)) 代替 F.logsigmoid(...) 可能在哪些场景下导致数值溢出(overflow/underflow)?

L3 — 深入 / Deep

  1. DPO 的理论假设: DPO 推导假设偏好服从 Bradley-Terry 模型。如果真实偏好数据违反了 BT 假设(例如存在循环偏好或不可传递偏好),DPO 的优化目标会出现什么偏差?label smoothing 能在多大程度上缓解这个问题?
  2. 与 RLHF 的等价性边界: DPO 论文证明了在 BT 模型下 DPO 与 RLHF 等价。但当策略 πθ\pi_\theta 偏离 πref\pi_{\text{ref}} 较远时,这种等价性是否仍然成立?离线(offline)DPO 的主要局限是什么?
  3. β\beta 与 KL 约束的关系: 从带 KL 约束的 RL 目标 maxπE[r(x,y)]βDKL(ππref)\max_\pi \mathbb{E}[r(x,y)] - \beta \, D_{\text{KL}}(\pi \| \pi_{\text{ref}}) 出发,推导 DPO 的闭式最优策略 π\pi^*,并解释 β\beta 如何在奖励最大化和 KL 约束之间做权衡。
  4. 从 DPO 到后续改进: SimPO、IPO、KTO 等方法分别解决了 DPO 的哪些已知问题?如果要在本代码基础上实现 IPO(Identity Preference Optimization),需要修改哪些部分?