跳转至

GradScaler的数学原理简述

标签: 模型量化LLM

引言

什么是GradScaler?

目的: 在半精度训练(float16)时,许多梯度值过小会下溢到 0,导致网络无法学习。GradScaler 通过把 损失 \(L\) 放大 再反向传播,把梯度整体抬高到 float16 可表示的动态范围内;反向之后再按比例缩回,从而 既避免下溢,又不改变更新方向和大小


损失放大

给定放大因子\(S(S\gg1)\)(通常取 2 的幂,便于位移实现)

\[L_s = S \, L\]

在后向传播时,链式法则保证梯度同样被放大:

\[\nabla_w L_s \;=\; \frac{\partial\, (S L)}{\partial w} \;=\; S \,\frac{\partial L}{\partial w}\;=\; S \,\nabla_w L\]
  • \(\nabla_w L\)原本落在\(\bigl[10^{-7},10^{-3}\bigr]\),而 float16 最小正数约\(6\times10^{-5}\),则乘以\(S=2^{10}=1024\)后便能落入可表示区间,避免被截成 0。

溢出检测

反向结束后,将 放大后的梯度\(\nabla_w L_s\)检查是否含有 Inf/NaN。

  • 若发现溢出 → 说明\(S\)太大;本次迭代 整步丢弃,并将\(S\)乘以衰减系数\(\alpha<1\)(如 0.5)。
  • 若无溢出 → 继续下一步。

梯度缩放回原尺度

未溢出的情况下,在执行优化器更新前把梯度除回\(S\)

\[\nabla_w L \;=\; \frac{\nabla_w L_s}{S}\]

因为微分线上性,\(\nabla_w L\)与常规 FP32 训练得到的梯度数值完全一致;随后再做梯度裁剪、Adam 动量更新等,数值就已安全地处在 FP32 域。


动态调整放大因子

GradScaler 维护一个计数器\(n_{\text{good}}\)

  1. 成功步(无溢出)累计\(n_{\text{good}}\)
  2. \(n_{\text{good}} \ge k\)(如 2000 步)时,将\(S\)增长\(S \gets S \times \beta\)(常取\(\beta=2\)),然后清零计数器。
  3. 溢出步 则立即下降放大因子:\(S \gets S \times \alpha\),并重置\(n_{\text{good}}=0\)

这样\(S\)稳定区间 内自动寻找“尽可能大但不溢出”的值,兼顾数值安全与下溢缓解。


数值等价性与实现细节

  • 放大与缩回互为乘除,故 参数更新与不使用 GradScaler 时完全一致
  • 选取\(S=2^{k}\)能用指针移位而非乘法完成缩放,避免额外舍入误差。
  • BF16 动态范围与 FP32 相同,通常 可略过损失缩放;FP16 则几乎必需。

总结
GradScaler 通过“放大损失→检测溢出→缩回梯度→动态调\(S\)”四步,使训练过程在低精度下 最大化利用硬件加速,同时保持与全精度训练 数值等价 的梯度更新,是自动混合精度 (AMP) 的核心组件。