Drill · 手撕

Rotary Position Embedding (RoPE) 从零实现学习演练

本演练旨在通过一个纯净的 PyTorch 实现,深入理解旋转位置编码 (RoPE) 的数学原理与工程实现。

1. 数学原理

RoPE 的核心思想是为序列中的每个 token 的查询 (q) 和键 (k) 向量施加一个依赖于其绝对位置 mm 的旋转变换,使得变换后两个向量的点积仅依赖于它们的相对位置 (mn)(m - n)

对于一个 dd 维的向量 x\mathbf{x},将其视为 d/2d/2 个二维子空间。对于第 ii 个子空间,其旋转角度 θm,i\theta_{m,i} 定义为:

θm,i=mbase2i/d\theta_{m,i} = m \cdot \text{base}^{-2i/d}

其中 base\text{base} 是一个预定义的常数(通常为 10000.0),mm 是 token 在序列中的绝对位置。

对于向量 x\mathbf{x} 中的每一对元素 (x2i,x2i+1)(x_{2i}, x_{2i+1}),应用旋转后的结果为:

[x2ix2i+1]=[cosθm,isinθm,isinθm,icosθm,i][x2ix2i+1]\begin{bmatrix} x'_{2i} \\ x'_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos \theta_{m,i} & -\sin \theta_{m,i} \\ \sin \theta_{m,i} & \cos \theta_{m,i} \end{bmatrix} \begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix}

上述公式可以写成更紧凑的元素操作形式:

x=xcosθm+rotate_half(x)sinθm\mathbf{x}' = \mathbf{x} \odot \cos\boldsymbol{\theta}_m + \text{rotate\_half}(\mathbf{x}) \odot \sin\boldsymbol{\theta}_m

其中 \odot 表示元素乘法,rotate_half\text{rotate\_half} 操作将 (x2i,x2i+1)(x_{2i}, x_{2i+1}) 变换为 (x2i+1,x2i)(-x_{2i+1}, x_{2i})cosθm\cos\boldsymbol{\theta}_msinθm\sin\boldsymbol{\theta}_m 是预先计算并重复交错以匹配原始维度 dd 的查找表。

其核心性质(在 from_scratch.py 中已验证)是 相对位置不变性RoPE(q,m),RoPE(k,n)=RoPE(q,m+Δ),RoPE(k,n+Δ)\langle \text{RoPE}(\mathbf{q}, m), \text{RoPE}(\mathbf{k}, n) \rangle = \langle \text{RoPE}(\mathbf{q}, m + \Delta), \text{RoPE}(\mathbf{k}, n + \Delta) \rangle

2. 直觉与复杂度

3. 文件说明

本演练目录包含以下三个文件:

  1. from_scratch.py:RoPE 的核心实现文件。包含频率预计算、旋转辅助函数、RoPE 应用函数、nn.Module 封装类以及自测试代码。运行此文件可执行简单的功能验证。
  2. test_rope.py:针对 from_scratch.py 中实现的单元测试文件,用于更系统地验证 RoPE 的数学性质和边界条件。
  3. README.md:本说明文档。

4. 运行指南

5. 追问分层 / Stratified follow-ups

L1 - 基础 (Basic):

  1. 旋转位置编码 (RoPE) 相对于学习的绝对位置编码 (Learned Absolute PE) 的主要优势是什么?
  2. 为什么 RoPE 的实现需要将 cossin 表进行“交错重复”操作(例如 [cos0, cos0, cos1, cos1, ...])?
  3. from_scratch.pyapply_rope 函数中,rotate_half 操作的具体作用是什么?

L2 - 中级 (Intermediate):

  1. RoPE 的频率公式 θi=base2i/d\theta_i = \text{base}^{-2i/d} 中,base 参数的作用是什么?如果改变它(例如从 10000 变为 500000),可能会对模型的长序列建模能力产生什么影响?
  2. 代码中使用 register_buffer 而不是 nn.Parameter 来存储 cos_tablesin_table,这是出于什么考虑?这对模型的保存和加载有什么影响?
  3. RoPE 是如何实现对相对位置的“软性”建模的?它与使用相对位置偏置(Relative Position Bias)的方法有何异同?

L3 - 深入 (Deep):

  1. 外推性 (Extrapolation):在推理时,如果遇到比训练时最大序列长度更长的序列,此基础 RoPE 实现会遇到什么问题?社区提出了哪些改进技术(如 NTK-aware Scaling, YaRN)来增强 RoPE 的外推能力?
  2. 梯度分析:从反向传播的角度分析,RoPE 的旋转操作对梯度流有什么特性?它是否有助于缓解深层 Transformer 中的梯度消失问题?
  3. 与其他位置编码的融合:如何将 RoPE 与卷积 (Convolution) 或循环 (Recurrence) 等局部操作结合使用?例如,在混合架构(如 RWKV 或 State Space Models)中,RoPE 可以扮演什么角色?