Drill · 手撕

LoRA 前向与权重合并 · 从零实现学习钻

低秩适配(LoRA) 的单文件最小实现:前向传播 + 权重合并/拆分。 从零手写每一行,用于理解核心原理,不依赖任何框架。


数学公式 / Math

前向传播 Forward

给定冻结预训练权重 WRdout×dinW \in \mathbb{R}^{d_{out} \times d_{in}},可训练低秩矩阵 ARr×dinA \in \mathbb{R}^{r \times d_{in}}(下投影)、BRdout×rB \in \mathbb{R}^{d_{out} \times r}(上投影),缩放因子 scaling=α/r\text{scaling} = \alpha / r,LoRA 前向为:

y=Wx+αrBAxy = Wx + \frac{\alpha}{r} \cdot B A x

分步实现(代码中的 forward):

hdown=AxRrh_{\text{down}} = Ax \quad \in \mathbb{R}^{r}

hup=BhdownRdouth_{\text{up}} = B h_{\text{down}} \quad \in \mathbb{R}^{d_{out}}

y=Wx+αrhupy = Wx + \frac{\alpha}{r} \cdot h_{\text{up}}

批量维度:输入 xR(,din)x \in \mathbb{R}^{(*, d_{in})},输出 yR(,dout)y \in \mathbb{R}^{(*, d_{out})}* 表示任意前导 batch 维度。

权重增量 Delta Weight

ΔW=αrBARdout×din\Delta W = \frac{\alpha}{r} \cdot BA \quad \in \mathbb{R}^{d_{out} \times d_{in}}

合并 Merge

推理前一次性折叠,消除额外矩阵乘法开销:

WW+ΔW=W+αrBAW \leftarrow W + \Delta W = W + \frac{\alpha}{r} \cdot BA

拆分 Unmerge

逆操作,恢复原始冻结权重:

WWΔWW \leftarrow W - \Delta W

初始化 Initialization


直觉与复杂度 / Intuition & Complexity

直觉 Intuition

复杂度 Complexity

可训练参数量 Trainable Params 前向 FLOPs(每 token)
LoRA(未合并) r(din+dout)r \cdot (d_{in} + d_{out}) dindout+r(din+dout)d_{in} \cdot d_{out} + r \cdot (d_{in} + d_{out})
LoRA(已合并) 同上(推理时为零额外开销) dindoutd_{in} \cdot d_{out}(与原模型相同)
全量微调 Full FT dindoutd_{in} \cdot d_{out}

rmin(din,dout)r \ll \min(d_{in}, d_{out}) 时,LoRA 参数量远小于全量微调。


文件 / Files

本钻(drill)目录下仅有以下三个文件

文件 说明
from_scratch.py LoRALinear 类的完整从零实现(含 forwardmerge_weightsunmerge_weights)及主函数自测
test_lora_forward.py 单元测试,验证前向传播、合并、拆分的正确性
README.md 本说明文件

运行 / Run

# 演示与自测 —— 打印前向结果、手动复核、合并后一致性、往返恢复
python from_scratch.py

# 单元测试
python test_lora_forward.py

追问分层 / Stratified Follow-ups

L1 · 基础 Basic

  1. 为什么 BB 初始化为零而不是 AA 如果两者都随机初始化,ΔW=BA\Delta W = BA 初始值非零,会破坏预训练权重,导致训练初期性能剧烈下降。B=0B=0 确保 ΔW=0\Delta W=0,实现"零初始化身份起点"。

  2. scaling = α / r 的作用是什么?改变 α 会怎样? α\alpha 控制 LoRA 更新的总强度。固定 α\alpha 后增大 rr,每条秩的贡献自动缩小(1/r1/r),使不同 rank 之间的学习率等效。实际训练中常用 α=2r\alpha = 2r 作为默认值,使 scaling = 1。

  3. merge_weights() 后为什么 forward 中的 LoRA 分支不再执行? 代码通过 self.merged 标志位判断。合并后 self.merged = Trueforward 跳过 LoRA 分支,直接用 F.linear(x, W) 计算。此时 WW 已包含 ΔW\Delta W,结果与未合并时数学等价。

L2 · 进阶 Intermediate

  1. 合并与拆分是精确可逆的吗?有没有数值陷阱? 数学上精确可逆(加法逆元),但浮点运算存在舍入误差。每次合并/拆分累积的误差约为 ϵ107\epsilon \approx 10^{-7}(float32)。频繁反复合并/拆分可能累积误差;实际使用中建议只合并一次用于推理,或切换到新权重前重新加载。

  2. 为什么用 F.linear 而不是手动 x @ A.T?能否用一个 F.linear 完成两步投影? F.linear(input, weight) 内部执行 input @ weight.T,代码更简洁且利用了 PyTorch 内部的 fused kernel 优化。两步投影(down → up)必须拆成两个 F.linear,因为中间维度从 dind_{in} 变为 rr 再变为 doutd_{out},无法用单一矩阵乘法表达。

  3. self.weight 设置了 requires_grad=False,但 lora_Alora_B 没有显式设置 requires_grad=True——为什么它们仍是可训练的? nn.Parameter 默认 requires_grad=True。冻结权重通过显式设置 requires_grad=False 来"关闭梯度"。这正是 LoRA 的核心:只训练低秩矩阵,冻结原始权重。

L3 · 深入 Deep

  1. 本实现的 compute_delta_weight() 显式构造了完整的 dout×dind_{out} \times d_{in} 矩阵。在真正的推理框架(如 vLLM)中,合并大权重矩阵的内存和计算瓶颈在哪里?如何优化? 对于 70B 模型,单个线性层的 ΔW\Delta W 可能是 8192×8192256MB8192 \times 8192 \approx 256\text{MB}(float16)。合并操作需要分配临时张量并执行加法。优化手段包括:(a) 就地加法(add_,本实现已使用)避免额外分配;(b) 对多 LoRA adapter 使用运行时矩阵乘法而非合并(避免 OOM);(c) 量化合并(如 GPTQ 后量化时再融合 LoRA)。

  2. 如果要支持多个 LoRA adapter 的动态切换(如多租户推理),合并/拆分策略还适用吗? 不适用。反复合并/拆分会累积浮点误差,且无法同时服务多个 adapter。正确做法是不合并,保持 WW 冻结,在前向时动态计算 ΔWx=B(Ax)\Delta W \cdot x = B(Ax)。进一步的优化是 Punica/S-LoRA 等方案,用 batch 内分组的稀疏矩阵乘法(BGMV)在一次 kernel launch 中处理多个 adapter。

  3. 当前实现中 AA 使用 Kaiming 初始化,这在数学上等价于什么假设?与原始论文(Hu et al., 2021)中使用高斯初始化有何区别? Kaiming Uniform 的设计目标是保持 ReLU 激活的方差在前向传播中恒定,隐含假设是 AA 后有非线性激活。原始 LoRA 论文中 AN(0,σ2)A \sim \mathcal{N}(0, \sigma^2)B=0B=0,不依赖非线性假设。实际差异很小:两种初始化的方差量级相同,最终效果基本由 B=0B=0 的零初始化和训练过程主导。