训练神经网络的每一步都要算「损失对所有参数的梯度」。原理是初等微积分里就有的链式法则,但深度网络的参数动辄 10810^8 起步,手算推导既容易出错也无法扩展;数值微分(差分法)则要为每个参数单独扰动一遍,代价随参数量线性增长,并且扰动步长两头不讨好——大有近似误差、小有浮点精度损失。

本文介绍 反向传播(backpropagation)+ 自动微分(automatic differentiation, AD)——前者把链式法则按计算图的反向拓扑顺序套用一遍,后者把这套递推抽象成一个不需要显式雅可比的算子。两者合起来让一次梯度计算的代价与一次前向同阶(O(params)O(\text{params}) 而不是 O(params2)O(\text{params}^2)),并构成 PyTorch、JAX 这类可微编程框架的算法核心

这篇文章是 从感知机到反向传播再到深层训练 § 3.4 反向传播:计算图、链式法则、雅可比 的深度展开——主文章给「为什么这么做」的高层视角,本文按 § 0 ~ § 5 顺序展开「梯度具体怎么算」:先讲清楚需求与方案选择(§ 0),铺垫多维求导与雅可比矩阵(§ 1)、矩阵链式法则与 Y=WXY = WX 的实战推导(§ 2);再过渡到计算图与反向传播(§ 3)、VJP 与反向模式 AD(§ 4)、内存账本与动态/静态图的工程取舍(§ 5);§ 6 收束 takeaways。

0. 引言:为什么需要自动微分?

把神经网络看成一个标量函数 L(θ)L(\theta)θRP\theta \in \mathbb{R}^P 是所有参数的拼接。训练每一步都要拿到 θL\nabla_\theta L,量级 PP 经常是 107101110^7 \sim 10^{11}。一般有如下三种方案。

手算解析导数:对每一层、每一个参数手工写出偏导。链式法则保证结果正确,但在 10 层以上的网络里,写错一个转置就会让整个梯度系统性偏移;模型一改架构,公式重推一遍。可读、可教学、不可扩展。

数值微分(finite differences):把每个参数 θi\theta_i 单独扰动 ε\varepsilon,估计 L/θi(L(θ+εei)L(θ))/ε\partial L / \partial \theta_i \approx (L(\theta + \varepsilon e_i) - L(\theta)) / \varepsiloneie_i 是第 ii 个标准基向量,只在第 ii 位为 1,因此 θ+εei\theta + \varepsilon e_i 只给第 ii 个参数加 ε\varepsilon、其余不动)。优点是与模型解耦、写两行就能用。第一个缺点是计算量——每个参数都要一次额外前向,10810^8 参数就是 10810^8 次前向,单这一条就排除了用于训练的可能

第二个缺点更隐蔽,在于步长 ε\varepsilon 怎么取都不对。先看 ε\varepsilon 偏大的一端:把 L(θ+εei)L(\theta + \varepsilon e_i) 沿 eie_i 方向做 Taylor 展开代回差分公式,

L(θ+εei)L(θ)ε=Lθi+ε22Lθi2+O(ε2),\frac{L(\theta + \varepsilon e_i) - L(\theta)}{\varepsilon} = \frac{\partial L}{\partial \theta_i} + \frac{\varepsilon}{2} \frac{\partial^2 L}{\partial \theta_i^2} + O(\varepsilon^2),

真梯度之外多出一个正比于 ε\varepsilon 的余项。这就是截断误差(truncation error)——差分本质是一阶近似,ε\varepsilon 越大、近似越偏

直觉上 ε\varepsilon 取越小、截断误差越小,但 ε\varepsilon 过小会触发浮点运算另一端的误差。浮点数精度有限(fp32 约 7 位有效十进制位、即 24 位尾数),任何一个数从写进内存起,第 8 位以后就被舍入、带上量级约 εmachine107\varepsilon_{\text{machine}} \approx 10^{-7} 的相对误差——这些低位不再对应真实值,本身就是噪声。把真值记为 a=L(θ)a = L(\theta)b=L(θ+εei)b = L(\theta + \varepsilon e_i),存进内存是 a^=a+na\hat a = a + n_ab^=b+nb\hat b = b + n_bna,nb107n_a, n_b \sim 10^{-7} 为各自的舍入噪声)。ε\varepsilon 很小时 aba \approx b、两者高位完全相同,相减把可靠的高位精确抵消:b^a^=(ba)+(nbna)\hat b - \hat a = (b - a) + (n_b - n_a)信号 baεb - a \sim \varepsilon 被抵消到只剩这么点,噪声 nbna107n_b - n_a \sim 10^{-7} 原地不动——绝对误差没变,相对误差却放大成 107/ε10^{-7} / \varepsilon

举具体数(fp32、L(θ)=1.0L(\theta) = 1.0、真梯度为 1):ε=108\varepsilon = 10^{-8}L(θ+εei)L(\theta + \varepsilon e_i) 数学上是 1.000000011.00000001,但 fp32 只存 7 位有效数字、被舍入回 1.01.0,相减得 0、除以 ε\varepsilon 仍是 0(真梯度 1 估成了 0);ε=104\varepsilon = 10^{-4} 时表面上 (1.00011.0)/104=1(1.0001 - 1.0)/10^{-4} = 1 像是算对了,但 1.00011.0001 存进 fp32 本身带着约 6×1086 \times 10^{-8}Lεmachine\approx |L|\,\varepsilon_{\text{machine}})的舍入误差,除以 ε\varepsilon 后放大成约 6×1046 \times 10^{-4} 的相对误差——中心值确实接近 1,可信的只有约 3-4 位有效数字。这就是浮点相减灾难(catastrophic cancellation)——ε\varepsilon 越小、信号越弱、噪声占比越高,相减后剩的有效位越少

把两类误差写成步长 ε\varepsilon 的函数,总误差是一增一减两项之和:

E(ε)12Lε+Lεmachineε.E(\varepsilon) \approx \frac{1}{2}\,|L''|\,\varepsilon + \frac{|L|\,\varepsilon_{\text{machine}}}{\varepsilon}.

右边第一项是截断误差(上面 Taylor 展开的余项,L|L''| 是二阶导 2L/θi2\partial^2 L / \partial \theta_i^2 的量级、衡量函数弯曲程度),随 ε\varepsilon 线性增大;第二项是舍入误差:εmachine\varepsilon_{\text{machine}} 是相对误差,乘上数值量级 L|L| 才是存一个数时的绝对噪声 Lεmachine\sim |L|\,\varepsilon_{\text{machine}},它在相减后留存、再除以 ε\varepsilon 被放大(即上一段的 107/ε10^{-7}/\varepsilon),随 ε\varepsilon 缩小越发严重。形如 aε+b/εa\varepsilon + b/\varepsilon 的和对 ε\varepsilon 求导置零,在 ε=b/a\varepsilon = \sqrt{b/a} 处取最小、最小值 2ab2\sqrt{ab};这里 b/a=2LLεmachineb/a = \tfrac{2|L|}{|L''|}\,\varepsilon_{\text{machine}}。在神经网络中,LLLL'' 通常都在 O(1)O(1) 量级,略去常数可得最优步长 εεmachine\varepsilon \approx \sqrt{\varepsilon_{\text{machine}}}、对应的最小总误差也 εmachine\sim \sqrt{\varepsilon_{\text{machine}}}——两端相抵,总误差降不到这个下限以下

代入数字:fp64(机器精度 εmachine1016\varepsilon_{\text{machine}} \approx 10^{-16})最优步长约 10810^{-8}、最小误差 108\sim 10^{-8}(约 8 位有效数字),fp32(εmachine1.19×107\varepsilon_{\text{machine}} \approx 1.19 \times 10^{-7})最优步长约 3.4×1043.4 \times 10^{-4}、只剩 3-4 位。数值微分因此只够做梯度检查(gradient checking),不能用来训练

解析自动微分(analytic AD):把网络拆成基本算子(matmul、add、σ\sigma、loss 等),每个算子手写解析的局部偏导一次,框架按计算图反向将局部偏导组合还原成 θL\nabla_\theta L。每个算子的局部导数在该算子的 OO-级代价内完成,总代价与一次前向同阶。

自动微分不是「算得更准」,而是用解析雅可比 + 反向拓扑顺序,让每次梯度计算的代价从「参数次前向」降到「一次前向」。本文剩下的篇幅就是把这套机制从数学到算法到工程拆开。

1. 多维求导基础:从标量到雅可比矩阵

1.1 标量、向量与梯度的直觉回顾

标量函数 y=f(x)y = f(x) 的导数是一个数 dy/dx\mathrm{d}y/\mathrm{d}x,几何上是曲线在 xx 点的斜率。把输入升到 nnf:RnRf: \mathbb{R}^n \to \mathbb{R},标量对每个分量都有偏导 f/xi\partial f / \partial x_i,将这 nn 个偏导按分量排成一个向量,即 梯度向量(gradient) fRn\nabla f \in \mathbb{R}^n

梯度的几何意义不变:f\nabla f 指向 ff 在当前点增长最快的方向,大小是该方向上的瞬时增长率。约定上 f\nabla f 是列向量(与 xx 同形状);这是后续讨论「分子 vs 分母布局」前的默认起点。

1.2 向量值函数与雅可比矩阵

输出也升到向量 f:RnRmf: \mathbb{R}^n \to \mathbb{R}^m 时,每个输出分量 fif_i 对每个输入分量 xjx_j 都有一个偏导 fi/xj\partial f_i / \partial x_j,共 m×nm \times n 个;把这些偏导排成一个矩阵,就是 ff 在某点的一阶信息——雅可比矩阵(Jacobian matrix)。不过「排成矩阵」有两种约定,互为转置。

  • 分子布局(numerator layout):行数 = 输出维数 mm,列数 = 输入维数 nn(f/x)ij=fi/xj(\partial f / \partial x)_{ij} = \partial f_i / \partial x_j
  • 分母布局(denominator layout):把分子布局转置。行 = 输入维数,列 = 输出维数。

两种约定数学上都对,但 链式法则在它们下面写法不同:分子布局下,链式法则是「左乘上游、右乘本层」L/x=L/yy/x\partial L / \partial x = \partial L / \partial y \cdot \partial y / \partial x(矩阵右乘);分母布局下要先转置才能凑出乘法。一旦在同一份实现里前后混用,转置 bug 会沿着梯度公式系统性地传染。本文剩下的所有推导统一用分子布局,链式法则按右乘排列

按分子布局,神经网络中常见三类层的雅可比:

  • 线性层 y=Wx+by = Wx + byRmy \in \mathbb{R}^mxRnx \in \mathbb{R}^nWRm×nW \in \mathbb{R}^{m \times n}):y/x=W\partial y / \partial x = Wm×nm \times n);对参数 WW 的偏导按元素写为 yi/Wij=xj\partial y_i / \partial W_{ij} = x_j
  • 激活层 y=σ(x)y = \sigma(x)(element-wise):y/x=diag(σ(x))\partial y / \partial x = \mathrm{diag}(\sigma'(x)),是一个对角矩阵。
  • 损失层 L=(y,ytrue)L = \ell(y, y_{\text{true}})(标量输出):L/yR1×m\partial L / \partial y \in \mathbb{R}^{1 \times m} 是行向量。

线性层的雅可比就是权重矩阵本身、element-wise 激活层的雅可比是对角阵——这两个观察让大多数层的 VJP(向量-雅可比积)退化成结构化乘法,是 § 4 让反向传播在 O(n)O(n) 而非 O(n2)O(n^2) 内通过激活层的关键

2. 矩阵微积分与高维链式法则

2.1 多元复合函数的链式法则

把两层复合摆出来:z=f(y)z = f(y)y=g(x)y = g(x),其中 xRnx \in \mathbb{R}^nyRmy \in \mathbb{R}^mzRkz \in \mathbb{R}^k。标量情形下链式法则是 dz/dx=(dz/dy)(dy/dx)\mathrm{d}z/\mathrm{d}x = (\mathrm{d}z/\mathrm{d}y)(\mathrm{d}y/\mathrm{d}x);向量情形把它原样搬上来:

zx=zyyx=JfJgRk×n.\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial x} = J_f \cdot J_g \in \mathbb{R}^{k \times n}.

链式法则在向量化形式下就是雅可比矩阵的连乘——形状校验是 (k×m)(m×n)=(k×n)(k \times m)(m \times n) = (k \times n)。一条长度为 LL 的复合链 xy1y2yLx \to y_1 \to y_2 \to \cdots \to y_L 的端到端雅可比就是 JLJL1J1J_L \cdot J_{L-1} \cdots J_1 ;如果中间维数都是 nn,全显式相乘要做 L1L-1n×nn \times n 矩阵乘法、共 O(Ln3)O(L \cdot n^3) 乘加(朴素算法;Strassen 等可降到 O(n2.372.81)O(n^{2.37 \sim 2.81}),但深度学习实践仍用 O(n3)O(n^3) 的 Basic Linear Algebra Subprograms,即 BLAS)。真正的浪费不在内存(流式计算每步只需 O(n2)O(n^2) 暂存一个稠密 Jacobian),而在两点:一是这些都是矩阵-矩阵乘法、每层 O(n3)O(n^3);二是它算出了完整的 n×nn \times n 雅可比矩阵,可反向传播里损失是标量,我们要的只是损失对各层的梯度——一个长度 nn 的向量 vJv Jv=L/yv = \partial L / \partial y 是上游梯度,按分子布局是 1×m1 \times m 行向量);而 vJv J 并不需要先把 JJ 造出来再相乘:每个基本算子都有一条直接从上游 vv 算出 vJv J 的「反向规则」(例如逐元素激活层 JJ 是对角阵,vJv J 就是 vσ(x)v \odot \sigma'(x)O(n)O(n) 完成、不建任何矩阵),所以 JJ 这个矩阵从头到尾都不必显式存在。这正是 § 3.2「维度灾难」与 § 4 VJP 改良的直接动因——VJP 让一个梯度向量向后逐层传播,每经过一层只做一次向量 × 矩阵乘法(即矩阵-向量乘法,每层从 O(n3)O(n^3) 降到 O(n2)O(n^2)),全程不构造稠密 JJ

上一段提到的 Basic Linear Algebra Subprograms(BLAS) 是一套标准化的底层线性代数运算接口,数值计算栈里的矩阵运算最终都通过它执行。它按规模分三级——Level 1 向量-向量(O(n)O(n))、Level 2 矩阵-向量(O(n2)O(n^2))、Level 3 矩阵-矩阵(gemmO(n3)O(n^3),深度学习的算力几乎都集中在这一级)。BLAS 只是接口规范,具体实现由厂商按硬件优化:CPU 上有 OpenBLAS、Intel MKL,NVIDIA GPU 上是 cuBLAS。这些实现把朴素 O(n3)O(n^3) 算法的常数因子(缓存分块、SIMD、Tensor Core 调度)优化到极限,而不换用 Strassen——后者常数大、数值稳定性差,在实际矩阵尺寸下反而更慢

2.2 实战演练:线性层 Y=WXY = WX 的梯度对齐

我们以深度学习中最常见的线性层 Y=WXY = WX 为例。假设损失 LL 是标量,反向传播已经得到本层输出的上游梯度 L/Y\partial L / \partial Y,现在要算 L/W\partial L / \partial WL/X\partial L / \partial X

各矩阵的维度(这里的 Y=WXY = WX 就是全连接层的运算,也是 Attention 里 Q/K/V 投影的运算):

  • 参数矩阵 WRm×nW \in \mathbb{R}^{m \times n}
  • 输入矩阵 XRn×kX \in \mathbb{R}^{n \times k}kk 可以是 batch size 或序列长度)。
  • 输出矩阵 Y=WXRm×kY = WX \in \mathbb{R}^{m \times k}
  • 上游梯度 L/YRm×k\partial L / \partial Y \in \mathbb{R}^{m \times k}(与 YY 同形状)。

目标:得到 L/WRm×n\partial L / \partial W \in \mathbb{R}^{m \times n}L/XRn×k\partial L / \partial X \in \mathbb{R}^{n \times k}。在动手前先确立深度学习工程实现里的一个铁律——维度相容原则(shape-matching principle)标量 LL 对某个参数矩阵求导,其梯度的形状必须与该参数矩阵严格一致。任何推导结果都用这条规则反向校验,转置位置一目了然。

下面分别用最朴素的「标量偏导法」和最优雅的「全微分 + 迹技巧」推一遍。两种方法殊途同归,但相互印证:标量法让读者踏实地看到每个数怎么乘加,消除黑盒恐惧;全微分法展示高维张量求导的工业级写法,是后续推导 Attention、softmax 等复杂层的常用工具。

方法一:标量偏导与求和法(Index Notation)

把矩阵拆成标量元素,用初等微积分链式法则求导,再还原回矩阵形式。Y=WXY = WX 中任意一个元素:

Yij=p=1nWipXpj.Y_{ij} = \sum_{p=1}^{n} W_{ip} X_{pj}.

L/W\partial L / \partial W。考察 WW 中位置 WabW_{ab} 对最终标量 LL 的影响。在 Yij=pWipXpjY_{ij} = \sum_p W_{ip} X_{pj} 里,WabW_{ab} 只出现在 i=ai = ap=bp = b 的项中,即它只影响输出矩阵第 aa 行的元素 YajY_{aj}jj 任意):

YijWab={Xbji=a,0otherwise.\frac{\partial Y_{ij}}{\partial W_{ab}} = \begin{cases} X_{bj} & i = a, \\ 0 & \text{otherwise}. \end{cases}

按链式法则把所有路径求和:

LWab=i,jLYijYijWab=j=1kLYajXbj.\frac{\partial L}{\partial W_{ab}} = \sum_{i, j} \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial W_{ab}} = \sum_{j=1}^{k} \frac{\partial L}{\partial Y_{aj}} X_{bj}.

最右侧是 L/Y\partial L / \partial Y 的第 aa 行与 XX 的第 bb 行的内积。要凑成矩阵乘法,需要把 XX 转置使其第 bb 行变成第 bb 列:

  LW=LYXT.  \boxed{\;\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \, X^T.\;}

形状校验:(m×k)(k×n)=(m×n)(m \times k)(k \times n) = (m \times n),与 WW 一致。

L/X\partial L / \partial X。同理,考察 XabX_{ab}。在 Yij=pWipXpjY_{ij} = \sum_p W_{ip} X_{pj} 里,XabX_{ab} 只在 p=ap = aj=bj = b 时出现,影响第 bb 列的元素 YibY_{ib}

YijXab={Wiaj=b,0otherwise.\frac{\partial Y_{ij}}{\partial X_{ab}} = \begin{cases} W_{ia} & j = b, \\ 0 & \text{otherwise}. \end{cases}

代入链式法则:

LXab=i,jLYijYijXab=i=1mWiaLYib.\frac{\partial L}{\partial X_{ab}} = \sum_{i, j} \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial X_{ab}} = \sum_{i=1}^{m} W_{ia} \, \frac{\partial L}{\partial Y_{ib}}.

这是 WW 的第 aa 列与 L/Y\partial L / \partial Y 的第 bb 列的内积。把 WW 转置,凑成矩阵乘法:

  LX=WTLY.  \boxed{\;\frac{\partial L}{\partial X} = W^T \, \frac{\partial L}{\partial Y}.\;}

形状校验:(n×m)(m×k)=(n×k)(n \times m)(m \times k) = (n \times k),与 XX 一致。

方法二:全微分与迹技巧(Matrix Differential & Trace Trick)

标量法不会出错,但面对 Self-Attention 里 Y=softmax(QKT)VY = \mathrm{softmax}(QK^T)V 这种连乘公式时,满屏下标求和让人窒息。矩阵微积分提供了一套不展开下标的工具:全微分(differential)+ 迹(trace)

在矩阵微积分里,标量 LL 关于矩阵 XX 的微分 dL\mathrm{d}L 与其梯度 L/X\partial L / \partial X 之间,通过迹建立一个内积关系:

dL=Tr ⁣((LX)TdX).\mathrm{d}L = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial X}\right)^T \mathrm{d}X \right).

加上迹的循环置换性 Tr(ABC)=Tr(CAB)=Tr(BCA)\mathrm{Tr}(ABC) = \mathrm{Tr}(CAB) = \mathrm{Tr}(BCA),就足以从 dL\mathrm{d}L 的表达式里反向剥离出梯度。

L/W\partial L / \partial W。保持 XX 不变,对 WW 求微分:dY=dWX\mathrm{d}Y = \mathrm{d}W \cdot X。代入 dL\mathrm{d}L 的标准形式:

dL=Tr ⁣((LY)TdY)=Tr ⁣((LY)TdWX)=Tr ⁣(X(LY)TdW)=Tr ⁣((LYXT)TdW).\begin{aligned} \mathrm{d}L &= \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y}\right)^T \mathrm{d}Y \right) = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y}\right)^T \mathrm{d}W \, X \right) \\ &= \mathrm{Tr}\!\left( X \left(\frac{\partial L}{\partial Y}\right)^T \mathrm{d}W \right) = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y} \, X^T\right)^T \mathrm{d}W \right). \end{aligned}

第二行第一步用了迹的循环置换 Tr(AB)=Tr(BA)\mathrm{Tr}(AB) = \mathrm{Tr}(BA)XX 移到最前;第二步用了 XAT=(AXT)TX A^T = (A X^T)^T。对比标准形式 dL=Tr((L/W)TdW)\mathrm{d}L = \mathrm{Tr}\big((\partial L / \partial W)^T \mathrm{d}W\big),剥掉外壳:

LW=LYXT.\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \, X^T.

L/X\partial L / \partial X。保持 WW 不变,对 XX 求微分:dY=WdX\mathrm{d}Y = W \, \mathrm{d}X。代入:

dL=Tr ⁣((LY)TWdX)=Tr ⁣((WTLY)TdX).\mathrm{d}L = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y}\right)^T W \, \mathrm{d}X \right) = \mathrm{Tr}\!\left( \left(W^T \frac{\partial L}{\partial Y}\right)^T \mathrm{d}X \right).

最后一步用 (ATB)T=BTA(A^T B)^T = B^T AWTW^T 提到外面。对比标准形式:

LX=WTLY.\frac{\partial L}{\partial X} = W^T \, \frac{\partial L}{\partial Y}.

两种方法结论完全一致。对参数求梯度 = 上游梯度 ×\times 本层输入的转置;对输入求梯度 = 本层参数转置 ×\times 上游梯度——维度只能这么对齐,转置位置由形状唯一确定。这条直觉是后续推导 Attention、softmax、归一化等复杂层梯度时反复用到的通用 framework。

3. 从数学到算法:计算图与反向传播

3.1 拆解前向传播:构建有向无环图

现代深度学习框架(PyTorch / TensorFlow / JAX)的底层抽象都是 计算图(computational graph)——把一次前向计算拆成一张有向无环图(DAG),节点是所有变量(输入、参数 WWbb、中间量、输出),边携带从上游节点到下游节点的局部偏导(即一个雅可比)。注意这跟主文章 § 1.1 / § 2.1 的「网络结构图」不一样:那里 权重画在边上(神经元之间的加权连接),而 计算图里权重是节点,连接它和下游中间量的边携带的是局部偏导 z/W\partial z / \partial W,不是权重本身。

以最简单的回归损失 L=(σ(Wx+b)y)2L = (\sigma(Wx + b) - y)^2 为例,把它拆成中间变量 z=Wx+bz = Wx + ba=σ(z)a = \sigma(z)L=(ay)2L = (a - y)^2,对应的计算图如下。

graph LR
  x((x)) --> z[z]
  W((W)) --> z
  b((b)) --> z
  z --> a[a]
  a --> diff["a - y"]
  y((y)) --> diff
  diff --> L[L]

每条边对应一个 局部偏导z/W=xT\partial z / \partial W = x^T(按 § 2.2 的推导,分子布局下线性层的参数偏导就是这个形式)、a/z=diag(σ(z))\partial a / \partial z = \mathrm{diag}(\sigma'(z))L/a=2(ay)\partial L / \partial a = 2(a - y),等等。求 L/W\partial L / \partial W 只需把 LazWL \to a \to z \to W 这条路径上的所有局部偏导按链式法则乘起来:

LW=LaazzW=2(ay)σ(z)xT.\frac{\partial L}{\partial W} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial W} = 2(a - y) \cdot \sigma'(z) \cdot x^T.

反向传播不是新的求导规则——它是按 DAG 的反向拓扑顺序,用每条边的局部偏导执行链式法则。任何可写成 DAG 的函数都能这样求导;框架要做的只是记住每条边上的局部偏导算子。

3.2 维度灾难:为什么不能直接相乘

把 § 2.1 的开销代入深度网络:一个 L=10L = 10 层、每层宽度 n=1000n = 1000 的网络,把 10 个 1000×10001000 \times 1000 的层雅可比全显式连乘起来,约 101010^{10} 次乘加,且每个稠密层 Jacobian 占 n2=106n^2 = 10^6 个浮点数。这只是单条样本的反向;放进 batch 后,矩阵-矩阵乘法的算力开销迅速变得不可承受。

破局点在于雅可比矩阵的结构。在 § 1.2 已经看到:element-wise 激活层的雅可比 y/x=diag(σ(x))\partial y / \partial x = \mathrm{diag}(\sigma'(x)) 是对角阵,独立非零元只有 nn 个而不是 n2n^2 个;线性层的雅可比 y/x=W\partial y / \partial x = W 虽然稠密,但它本来就在内存里、不需要新存一份。element-wise 激活层的雅可比是对角阵——这一观察是反向传播能在 O(n)O(n) 而非 O(n2)O(n^2) 内传一层激活梯度的关键。如果激活是「全连接非对角」(如 softmax 的输入-输出耦合),雅可比就不再稀疏,存储与计算成本都会飙升。

§ 4 把这条结构性观察推到底——根本不去显式存或构造 JJ,而是只问「给定上游 vv,输出 vJv \cdot J」。

4. 现代自动微分的核心:向量-雅可比积

4.1 前向模式(JVP)vs 反向模式(VJP)

把 § 2.1 的链式法则连乘 JLJL1J1J_L \cdot J_{L-1} \cdots J_1 看成一串矩阵乘法,求解顺序有两种选择。

前向模式(forward mode;JVP,Jacobian-vector product,雅可比-向量积):选一个输入方向 vRnv \in \mathbb{R}^n(典型情况是 vv 取标准基向量 eie_i),从最左侧开始累乘 J1vJ2(J1v)J_1 v \to J_2 (J_1 v) \to \cdots,最终得到 y/xvRm\partial y / \partial x \cdot v \in \mathbb{R}^m——JVP 这个名字的顺序(先 JJvv)就是它算的 JvJ \cdot v。每一步是矩阵-向量乘法,代价随 输入维度 增长——要拿到整张 Jacobian,需要 nn 次 JVP(每次选一个 eie_i)。

反向模式(reverse mode;VJP,vector-Jacobian product,向量-雅可比积):取一个上游梯度行向量 vR1×mv \in \mathbb{R}^{1 \times m}(在 backprop 中典型取 v=L/yv = \partial L / \partial y),从最右侧开始累乘 vJL(vJL)JL1v J_L \to (v J_L) J_{L-1} \to \cdots,最终得到 vy/xR1×nv \cdot \partial y / \partial x \in \mathbb{R}^{1 \times n}——VJP 名字的顺序(先 vvJJ)正是它算的 vJv \cdot J。每一步也是矩阵-向量乘法,代价随 输出维度 增长——要拿到整张 Jacobian,需要 mm 次 VJP。

深度学习里损失是标量(m=1m = 1)、参数是 10810^8 量级(nn 极大),反向模式只需一次 VJP 就能拿到完整梯度;前向模式则需要 10810^8 次 JVP——这是 PyTorch、JAX 默认用反向模式的根本原因。GPT-3 量级(175 B 参数、标量损失):反向模式一次 backward 拿到全部梯度,前向模式则要跑 1.75×10111.75 \times 10^{11} 次 JVP,完全不可承受。反过来,如果要算 Jacobian-vector 二阶量(Hessian-vector product 之类),前向模式才有用武之地。

4.2 VJP 的工作流抽象

VJP 的形式定义:给定上游行向量 vR1×mv \in \mathbb{R}^{1 \times m} 与函数 y=f(x)y = f(x) 的雅可比 Jf=f/xRm×nJ_f = \partial f / \partial x \in \mathbb{R}^{m \times n},计算 vJfR1×nv \cdot J_f \in \mathbb{R}^{1 \times n}。链式法则下,从损失 LL 反向传到任意中间变量 xl1x_{l-1} 的梯度就是 VJP 的递推:

Lxl1=Lxlxlxl1=LxlvlJfl.\frac{\partial L}{\partial x_{l-1}} = \frac{\partial L}{\partial x_l} \cdot \frac{\partial x_l}{\partial x_{l-1}} = \underbrace{\frac{\partial L}{\partial x_l}}_{v_l} \cdot J_{f_l}.

关键在于这一步 不需要显式构造 JflJ_{f_l}——只要能算「给定 vv,输出 vJflv \cdot J_{f_l}」就够了。对激活层 Jfl=diag(σ(xl1))J_{f_l} = \mathrm{diag}(\sigma'(x_{l-1})),VJP 退化成 element-wise 乘法 vσ(xl1)v \odot \sigma'(x_{l-1})O(n)O(n) 完成;对线性层 Jfl=WJ_{f_l} = W,VJP 是 vWv \cdot W,复用前向已有的 WW

反向模式自动微分的算法骨架(伪代码,用 § 2.2 的列向量约定:xlRnlx_l \in \mathbb{R}^{n_l}gg 是与 xlx_l 同形状的列向量梯度):

# Forward pass: cache activations needed by backward
for l in range(1, L + 1):
    z_l = W_l @ x_prev + b_l
    x_l = sigma(z_l)
    cache(x_prev, z_l)
L_value = loss(x_L, y_true)

# Backward pass: propagate VJP
g = dL_dx_L                       # initial upstream gradient, shape (m,)
for l in range(L, 0, -1):
    g = g * sigma_prime(z_l)      # VJP through activation, O(n)
    dW_l = g[:, None] @ x_prev[None, :]   # outer product, shape (m, n)
    db_l = g
    g = W_l.T @ g                 # VJP through linear layer, O(m * n)
return {dW_l, db_l for l in 1..L}

dW_l = g[:, None] @ x_prev[None, :] 是 § 2.2 的 L/W=(L/Y)XT\partial L / \partial W = (\partial L / \partial Y) X^T 在 batch size = 1 时的标量化写法(gg 是列向量梯度、xprevx_{\text{prev}} 是列向量输入,外积得到 m×nm \times n 的参数梯度)。W_l.T @ g 对应 L/X=WT(L/Y)\partial L / \partial X = W^T (\partial L / \partial Y)

每一层反向只做两件事:用 VJP 把上游梯度过一遍当前层、把过程中拿到的 gg 与缓存的 xprevx_{\text{prev}} 拼成参数梯度 dWl\mathrm{d}W_l反向模式 AD 的总计算量与前向一致,但内存代价等于缓存所有正向 activations——这是下一节工程取舍的起点。

5. 现代框架底层的工程取舍

5.1 计算与内存的博弈

反向模式 AD 的计算账是好看的(与前向同阶),内存账是另一回事。要让 backward 能算出每层的参数梯度,forward 必须缓存每层的输入 xl1x_{l-1}(和有时缓存 zlz_l 以避免 σ\sigma' 二次计算)。把这件事写成账本:

activation memoryLBdbytes per scalar,\text{activation memory} \approx L \cdot B \cdot d \cdot \text{bytes per scalar},

其中 LL 是层数、BB 是 batch size、dd 是每层特征宽度。一个 12 层 transformer block、d=768d = 768、batch 32、序列长度 512、fp32:单条 activation 张量约 12×32×512×768×4B600MB12 \times 32 \times 512 \times 768 \times 4\text{B} \approx 600\,\mathrm{MB};transformer block 里实际有 attention 中间量、MLP 中间量、residual 等约 5-10 个独立缓存,乘起来轻松上 GB。深度网络训练的瓶颈通常不是参数本身的内存,而是反向所需的 activation 缓存——它随网络深度、batch、序列长度线性增长

Gradient checkpointing(陈天奇 2016 那篇)给的方案是 以时间换空间:forward 时只在 L\sqrt{L} 个「检查点」位置保留 activation,其余位置丢弃;backward 走到没缓存的位置时,从最近的上游检查点起局部重跑 forward 重建 activations。代价是 forward 总算力多约 33%33\%,activation 内存从 O(L)O(L) 降到 O(L)O(\sqrt{L})在大模型训练里,gradient checkpointing 是最常用的内存压缩手段;选择「哪些层做检查点」本身是个工程优化问题

5.2 动态图 vs 静态图的 Trace 模式

反向模式 AD 还需要知道「图本身」长什么样。两条主流路线在「图什么时候构出来」上做了不同选择。

动态图(eager + autograd,PyTorch 路线):前向跑用户写的普通 Python,每个 tensor op 在执行时把自己记录到一张随后被释放的临时图里;调用 loss.backward() 时框架沿这张图反向跑 VJP。优点是控制流(if、while、Python 异常)完全自由,调试器、断点、print 都能直接用;缺点是每次 forward 都要重建图,缺少跨算子全局优化的机会,编译器拿不到完整算子序列。

静态图(trace + jit,JAX / TensorFlow 2 路线):先用一遍 trace 把整个函数走成一张静态图(输入用 placeholder tensor 跑过去),再交给编译器(XLA)做全局算子融合、内存预分配、并行调度;反向沿编译后的图算。优点是算子融合、kernel 合并、推理延迟可控;缺点是 Python 控制流不能直接 trace,需要写 lax.scanlax.cond 等显式控制流原语,调试体验更接近编译器栈。

eager 把 backprop 的图延迟到运行时,trace 把它提前到编译时——选择决定了控制流自由度、编译器优化空间、推理延迟在哪边付出代价。PyTorch 2.x 通过 torch.compile 加上了「先 eager 跑、识别热路径后再 trace 编译」的混合方案;JAX 通过 jit 把整个被装饰的函数搬到静态图侧;两边都在向中间靠拢。

6. 结语

反向传播是把链式法则按 DAG 反向拓扑顺序系统化执行的工程化算法。本文从「为什么需要 AD」(§ 0)出发,依次铺垫多维求导(§ 1)、矩阵链式法则(§ 2,含 Y=WXY = WX 的两种推导)、计算图与反向传播(§ 3)、VJP 与反向模式 AD(§ 4),最后展开内存账本与图模式的工程取舍(§ 5)。圈定了边界,几条没在正文里直接强调、但用之前最好心里有数的事:

  • 布局约定必须从头到尾贯穿。分子布局(本文)让链式法则写成右乘,分母布局让链式法则写成左乘——两种约定都对,但任何混用都会在梯度公式里留下系统性转置错误。代码、推导、注释里只用一种。
  • 维度相容是隐式校验,也是推导的捷径。任何梯度推导(哪怕复杂如 Attention)的结果都能用「梯度 shape = 参数 shape」反向校验;遇到不确定要不要转置时,先写 shape 再决定形式,比硬推下标快得多。
  • VJP 替代显式 Jacobian 才是 AD 的核心。框架不会在内存里实例化完整的 JfJ_f;缓存的是 forward activations,反向时按需算 vJfv \cdot J_f。这是 PyTorch / JAX 与「手推梯度」的本质区别——也是为什么深度网络的雅可比即便维度爆炸、训练依然可行。
  • 内存账本随深度、batch、序列长度线性增长O(LBd)O(L \cdot B \cdot d) 的 activation 缓存常先于参数本身、先于算力成为深层网络训练的瓶颈;gradient checkpointing 是首选缓解手段,它用约 33%33\% 多余算力换 O(L)O(L)O(L) \to O(\sqrt{L}) 内存。
  • 反向模式不是唯一选择。当输出维度远大于输入维度(例如 Jacobian-vector / Hessian-vector 这类二阶量),前向模式 JVP 反而更省。深度学习只是「标量损失 + 巨量参数」这一极端情形让反向模式明显占优。