GradScaler的数学原理简述
引言¶
什么是GradScaler?
目的: 在半精度训练(float16)时,许多梯度值过小会下溢到 0,导致网络无法学习。GradScaler 通过把 损失 \(L\) 放大 再反向传播,把梯度整体抬高到 float16 可表示的动态范围内;反向之后再按比例缩回,从而 既避免下溢,又不改变更新方向和大小。
损失放大¶
给定放大因子\(S(S\gg1)\)(通常取 2 的幂,便于位移实现)
\[L_s = S \, L\]
在后向传播时,链式法则保证梯度同样被放大:
\[\nabla_w L_s \;=\; \frac{\partial\, (S L)}{\partial w} \;=\; S \,\frac{\partial L}{\partial w}\;=\; S \,\nabla_w L\]
- 若\(\nabla_w L\)原本落在\(\bigl[10^{-7},10^{-3}\bigr]\),而 float16 最小正数约\(6\times10^{-5}\),则乘以\(S=2^{10}=1024\)后便能落入可表示区间,避免被截成 0。
溢出检测¶
反向结束后,将 放大后的梯度\(\nabla_w L_s\)检查是否含有 Inf/NaN。
- 若发现溢出 → 说明\(S\)太大;本次迭代 整步丢弃,并将\(S\)乘以衰减系数\(\alpha<1\)(如 0.5)。
- 若无溢出 → 继续下一步。
梯度缩放回原尺度¶
未溢出的情况下,在执行优化器更新前把梯度除回\(S\):
\[\nabla_w L \;=\; \frac{\nabla_w L_s}{S}\]
因为微分线上性,\(\nabla_w L\)与常规 FP32 训练得到的梯度数值完全一致;随后再做梯度裁剪、Adam 动量更新等,数值就已安全地处在 FP32 域。
动态调整放大因子¶
GradScaler 维护一个计数器\(n_{\text{good}}\):
- 成功步(无溢出)累计\(n_{\text{good}}\)。
- 当\(n_{\text{good}} \ge k\)(如 2000 步)时,将\(S\)增长:\(S \gets S \times \beta\)(常取\(\beta=2\)),然后清零计数器。
- 溢出步 则立即下降放大因子:\(S \gets S \times \alpha\),并重置\(n_{\text{good}}=0\)。
这样\(S\)在 稳定区间 内自动寻找“尽可能大但不溢出”的值,兼顾数值安全与下溢缓解。
数值等价性与实现细节¶
- 放大与缩回互为乘除,故 参数更新与不使用 GradScaler 时完全一致。
- 选取\(S=2^{k}\)能用指针移位而非乘法完成缩放,避免额外舍入误差。
- BF16 动态范围与 FP32 相同,通常 可略过损失缩放;FP16 则几乎必需。
总结
GradScaler 通过“放大损失→检测溢出→缩回梯度→动态调\(S\)”四步,使训练过程在低精度下 最大化利用硬件加速,同时保持与全精度训练 数值等价 的梯度更新,是自动混合精度 (AMP) 的核心组件。