Drill · 手撕

AdamW 从零实现学习练习 / AdamW from Scratch Study Drill

本项目旨在通过从零实现 AdamW 优化器的一个步骤,深入理解其数学原理和实现细节。实现基于纯 PyTorch,专注于核心更新逻辑。

1. 数学原理 / Mathematics

AdamW 是 Adam 优化器(Kingma & Ba, 2014)的变体,引入了解耦权重衰减(decoupled weight decay)(Loshchilov & Hutter, 2017)。对于每个参数 pp 在步骤 tt 的梯度 gg,关键更新方程如下:

第一步:更新有偏矩估计(Biased Moment Estimates)

mt=β1mt1+(1β1)g(一阶矩,梯度的指数移动平均)vt=β2vt1+(1β2)g2(二阶矩,梯度平方的指数移动平均)\begin{aligned} m_t &= \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g \quad &\text{(一阶矩,梯度的指数移动平均)} \\ v_t &= \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g^2 \quad &\text{(二阶矩,梯度平方的指数移动平均)} \end{aligned}

第二步:偏差校正(Bias Correction)

m^t=mt1β1t(校正一阶矩)v^t=vt1β2t(校正二阶矩)\begin{aligned} \hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \quad &\text{(校正一阶矩)} \\ \hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \quad &\text{(校正二阶矩)} \end{aligned}

第三步:参数更新(Parameter Update)

pt=pt1αm^tv^t+ϵαλpt1p_t = p_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \alpha \cdot \lambda \cdot p_{t-1}

其中:

注意:权重衰减项 αλp-\alpha \cdot \lambda \cdot p 直接施加于参数 pp,而不通过梯度,这与 L2 正则化不同,称为“解耦”。

2. 直觉与复杂度 / Intuition and Complexity

直觉:AdamW 结合了 Adam 的自适应学习率(通过一阶和二阶矩调整每个参数的更新幅度)和解耦的权重衰减。权重衰减直接惩罚参数值,促进模型稀疏性和泛化能力,而不干扰 Adam 的自适应机制。偏差校正项解决了初始化时矩估计为零导致的偏差问题。

复杂度

3. 文件 / Files

本项目目录包含仅有以下三个文件:

4. 运行 / Run

仅支持以下两个运行命令:

5. 追问分层 / Stratified follow-ups

以下问题旨在分层深入理解 AdamW 优化器步骤的从零实现: