Cheatsheet · 题解

持续 / 终身 Post-training / Continual & Lifelong(只收生产验证方法)

模型是迭代更新的(加新数据/新能力/新对齐轮)→ 灾难性遗忘与能力回归是规模化时的真问题。 ⚠️ 本页只收在大规模生产中验证过的方法;经典学术 CL 算法单列、明确标注「未经生产验证」,面试别当工业标准答。

1. 生产里「持续」长什么样

不是教科书式的在线流式持续学习,而是:周期性从 base / checkpoint 重训 + 调数据配比。目标 = 加新能力/新对齐,同时不退化已有能力(避免 alignment tax / 回归)。

1.1 遗忘的机制与度量

机制:在新数据上做梯度下降时,更新方向 θLnew-\nabla_\theta \mathcal{L}_{\text{new}} 常与旧任务的下降方向冲突(gradient interference:Lold,Lnew<0\langle \nabla\mathcal{L}_{\text{old}},\, \nabla\mathcal{L}_{\text{new}}\rangle < 0),于是权重朝"对新好、对旧差"的方向漂移(weight drift),旧能力随之损失。漂移越大、LR 越高、训得越久,遗忘越重——这正是"低 LR + 少 epoch + PEFT"奏效的原因(限制漂移)。

度量:用 BWT(backward transfer) 量化遗忘。设学完任务 ii 后在任务 jj 上的指标为 Ri,jR_{i,j},训完最后任务 TT 后:

BWT=1T1j=1T1(RT,jRj,j)\mathrm{BWT} = \frac{1}{T-1}\sum_{j=1}^{T-1}\big(R_{T,j} - R_{j,j}\big)

BWT<0\mathrm{BWT}<0 即遗忘(越负越严重)。生产里常配合保持率(retention)与对旧 benchmark 的回归探针(regression probe)一起监控。

2. ✅ 生产验证过的工具箱

2.1 数据回放 / 混合(replay / rehearsal)—— 最主力

持续微调时混入一定比例的旧/通用数据(指令数据配比)。最朴素也最有效的防遗忘手段;工程重点是配比、去重、质量过滤。

2.2 低学习率 + 少 epoch + PEFT

小步微调限制权重漂移;LoRA / adapter 做廉价增量适配 + 改动隔离(改坏了可丢弃 adapter)。

2.3 对 base 的 KL 正则

RLHF 的 βKL(πθπref)\beta\,\mathrm{KL}(\pi_\theta\,\|\,\pi_{\mathrm{ref}}) 本质就是把策略锚在 base 附近、防漂移与遗忘。

2.4 模型合并 / 权重平均

2.5 蒸馏 consolidation

把多个专家 / 更新后的 teacher 蒸馏成一个模型,巩固能力、压缩多轮迭代。

2.6 多阶段顺序遗忘(SFT → DPO → RL)

后阶段会侵蚀前阶段:DPO / RL 阶段的策略漂移会抹掉部分 SFT 习得的能力与格式,通常最后的 RL 步最严重(无标注约束、只追奖励,易过优化)。缓解:RL 阶段保留对 SFT-ref 的 KL、混入 SFT replay、对关键能力加 verifier 约束;并在每阶段后跑回归探针。

3. ❌ 未经生产验证(学术——别当工业标准)

4. 把你的 CL 背景诚实地用上

Fed-TaLoRA(联邦持续微调)、Continual Agent → 可迁移的洞察(遗忘度量、保持率视角、聚合一致性)。

5. 代码:replay 混合 + 权重合并

39 行 / lines
import torch, itertools

# (1) replay 混合:按比例把旧/通用数据交错进新数据,防遗忘
def make_replay_stream(new_data, old_data, replay_ratio=0.3, seed=0):
    """每条新数据后,以 replay_ratio 概率插入一条循环复用的旧数据。"""
    g = torch.Generator().manual_seed(seed)
    old_cycle = itertools.cycle(old_data)
    stream = []
    for x in new_data:
        stream.append(("new", x))
        if torch.rand(1, generator=g).item() < replay_ratio:
            stream.append(("old", next(old_cycle)))
    return stream

# (2) model soup:等权平均多个同构 checkpoint(需同一 init)
def model_soup(state_dicts):
    avg = {k: torch.zeros_like(v) for k, v in state_dicts[0].items()}
    for sd in state_dicts:
        for k, v in sd.items():
            avg[k] += v / len(state_dicts)
    return avg

# (3) task arithmetic:θ0 + Σ scale_i·(θ_ft_i − θ0),加得能力 / 减则遗忘
def task_arithmetic(theta0, finetuned, scales):
    merged = {k: v.clone() for k, v in theta0.items()}
    for sd, s in zip(finetuned, scales):
        for k in merged:
            merged[k] += s * (sd[k] - theta0[k])     # τ_i = θ_ft_i − θ0
    return merged

# --- 玩具验证 ---
t0 = {"w": torch.zeros(3)}
a  = {"w": torch.tensor([1., 0., 0.])}
b  = {"w": torch.tensor([0., 2., 0.])}
print("soup:", model_soup([a, b])["w"])                                  # [0.5, 1.0, 0.0]
print("θ0+τa+τb:", task_arithmetic(t0, [a, b], [1.0, 1.0])["w"])         # [1., 2., 0.]
print("forget b (−τb):", task_arithmetic(t0, [a, b], [1.0, -1.0])["w"])  # [1., -2., 0.]
print("replay stream:", [tag for tag, _ in make_replay_stream(range(4), range(100, 103), 0.5)])

分层面试题 / Stratified follow-ups

L1 基础

L2 进阶

L3 深挖

§A 核心论文时间线 / Key Papers Timeline