一个用于学习数值稳定实现的 PyTorch 练习,包含 softmax、带 label smoothing 的交叉熵损失函数。
1. 数学原理
Softmax 函数: 给定 logits 向量 , softmax 输出概率分布 : 为了数值稳定性,计算时会减去最大值 :
Label Smoothing: 给定真实类别 one-hot 向量 (仅第 维为1) 和 smoothing 因子 ,平滑后的分布 为:
交叉熵损失: 对于单个样本,计算预测分布 与目标分布 的交叉熵 : 最终,对一个 batch 的 个样本取平均得到总损失 :
2. 直觉与复杂度
- 数值稳定性: 核心技巧是在
exp前减去logits的最大值,防止浮点数上溢。 - Label Smoothing: 一种正则化技术,防止模型对预测过于自信。它将真实标签的 one-hot 分布(尖锐)变为更平滑的分布,鼓励模型输出更低的熵(即更不确定)。
- 时间复杂度: 两个函数的核心操作(
max,sum,exp)都在最后一个维度(类别维度)进行。对于形状为(N, K)的输入,单次操作的时间复杂度为 。 - 空间复杂度: 需要存储与输入同形状的中间结果(如
shifted,exp_shifted),空间复杂度为 。
3. 文件
本学习目录仅包含以下三个文件:
from_scratch.py: 核心实现,包含stable_softmax和label_smoothing_cross_entropy两个函数。test_cross_entropy.py: 单元测试文件,用于验证实现的正确性和数值稳定性。README.md: 本说明文档。
4. 运行
执行内置的自检演示:
python from_scratch.py
运行完整的单元测试套件:
python test_cross_entropy.py
5. 追问分层 / Stratified follow-ups
L1 基础
stable_softmax函数中减去max_logits的作用是什么?label_smoothing_cross_entropy函数中,smoothed变量是如何从one_hot和epsilon计算得到的?- 最终损失值是如何从
loss_per_sample计算得出的?
L2 中级
- 如果
epsilon = 0,label_smoothing_cross_entropy的计算结果会等价于标准的交叉熵损失吗?为什么? - 代码中计算
log_softmax的方式 (shifted - log_sum_exp) 与直接使用torch.log(torch.softmax(logits, dim=-1))相比,为何在数值上更稳定? scatter_函数在此处的作用是什么?它完成了一项什么关键的数据转换?
L3 深度
- 假设某个样本的
logits值非常大(例如[1000, 1000, 0]),请追踪代码执行过程,解释为什么stable_softmax和label_smoothing_cross_entropy能够避免产生NaN或Inf。 - 从梯度的角度,分析 label smoothing () 如何影响模型对正确类别的参数更新。
- 此实现与 PyTorch 内置的
torch.nn.CrossEntropyLoss在功能和计算路径上有何主要异同?(提示:考虑内置函数是否集成了 softmax、是否支持 label smoothing 参数)。