训练神经网络的每一步都要算「损失对所有参数的梯度」。原理是初等微积分里就有的链式法则,但深度网络的参数动辄 108 起步,手算推导既容易出错也无法扩展;数值微分(差分法)则要为每个参数单独扰动一遍,代价随参数量线性增长,并且扰动步长两头不讨好——大有近似误差、小有浮点精度损失。
本文介绍 反向传播(backpropagation)+ 自动微分(automatic differentiation, AD)——前者把链式法则按计算图的反向拓扑顺序套用一遍,后者把这套递推抽象成一个不需要显式雅可比的算子。两者合起来让一次梯度计算的代价与一次前向同阶(O(params) 而不是 O(params2)),并构成 PyTorch、JAX 这类可微编程框架的算法核心。
这篇文章是 从感知机到反向传播再到深层训练 § 3.4 反向传播:计算图、链式法则、雅可比 的深度展开——主文章给「为什么这么做」的高层视角,本文按 § 0 ~ § 5 顺序展开「梯度具体怎么算」:先讲清楚需求与方案选择(§ 0),铺垫多维求导与雅可比矩阵(§ 1)、矩阵链式法则与 Y=WX 的实战推导(§ 2);再过渡到计算图与反向传播(§ 3)、VJP 与反向模式 AD(§ 4)、内存账本与动态/静态图的工程取舍(§ 5);§ 6 收束 takeaways。
0. 引言:为什么需要自动微分?
把神经网络看成一个标量函数 L(θ),θ∈RP 是所有参数的拼接。训练每一步都要拿到 ∇θL,量级 P 经常是 107∼1011。一般有如下三种方案。
手算解析导数:对每一层、每一个参数手工写出偏导。链式法则保证结果正确,但在 10 层以上的网络里,写错一个转置就会让整个梯度系统性偏移;模型一改架构,公式重推一遍。可读、可教学、不可扩展。
数值微分(finite differences):把每个参数 θi 单独扰动 ε,估计 ∂L/∂θi≈(L(θ+εei)−L(θ))/ε(ei 是第 i 个标准基向量,只在第 i 位为 1,因此 θ+εei 只给第 i 个参数加 ε、其余不动)。优点是与模型解耦、写两行就能用。第一个缺点是计算量——每个参数都要一次额外前向,108 参数就是 108 次前向,单这一条就排除了用于训练的可能。
第二个缺点更隐蔽,在于步长 ε 怎么取都不对。先看 ε 偏大的一端:把 L(θ+εei) 沿 ei 方向做 Taylor 展开代回差分公式,
εL(θ+εei)−L(θ)=∂θi∂L+2ε∂θi2∂2L+O(ε2),
真梯度之外多出一个正比于 ε 的余项。这就是截断误差(truncation error)——差分本质是一阶近似,ε 越大、近似越偏。
直觉上 ε 取越小、截断误差越小,但 ε 过小会触发浮点运算另一端的误差。浮点数精度有限(fp32 约 7 位有效十进制位、即 24 位尾数),任何一个数从写进内存起,第 8 位以后就被舍入、带上量级约 εmachine≈10−7 的相对误差——这些低位不再对应真实值,本身就是噪声。把真值记为 a=L(θ)、b=L(θ+εei),存进内存是 a^=a+na、b^=b+nb(na,nb∼10−7 为各自的舍入噪声)。ε 很小时 a≈b、两者高位完全相同,相减把可靠的高位精确抵消:b^−a^=(b−a)+(nb−na)。信号 b−a∼ε 被抵消到只剩这么点,噪声 nb−na∼10−7 原地不动——绝对误差没变,相对误差却放大成 10−7/ε。
举具体数(fp32、L(θ)=1.0、真梯度为 1):ε=10−8 时 L(θ+εei) 数学上是 1.00000001,但 fp32 只存 7 位有效数字、被舍入回 1.0,相减得 0、除以 ε 仍是 0(真梯度 1 估成了 0);ε=10−4 时表面上 (1.0001−1.0)/10−4=1 像是算对了,但 1.0001 存进 fp32 本身带着约 6×10−8(≈∣L∣εmachine)的舍入误差,除以 ε 后放大成约 6×10−4 的相对误差——中心值确实接近 1,可信的只有约 3-4 位有效数字。这就是浮点相减灾难(catastrophic cancellation)——ε 越小、信号越弱、噪声占比越高,相减后剩的有效位越少。
把两类误差写成步长 ε 的函数,总误差是一增一减两项之和:
E(ε)≈21∣L′′∣ε+ε∣L∣εmachine.
右边第一项是截断误差(上面 Taylor 展开的余项,∣L′′∣ 是二阶导 ∂2L/∂θi2 的量级、衡量函数弯曲程度),随 ε 线性增大;第二项是舍入误差:εmachine 是相对误差,乘上数值量级 ∣L∣ 才是存一个数时的绝对噪声 ∼∣L∣εmachine,它在相减后留存、再除以 ε 被放大(即上一段的 10−7/ε),随 ε 缩小越发严重。形如 aε+b/ε 的和对 ε 求导置零,在 ε=b/a 处取最小、最小值 2ab;这里 b/a=∣L′′∣2∣L∣εmachine。在神经网络中,L 与 L′′ 通常都在 O(1) 量级,略去常数可得最优步长 ε≈εmachine、对应的最小总误差也 ∼εmachine——两端相抵,总误差降不到这个下限以下。
代入数字:fp64(机器精度 εmachine≈10−16)最优步长约 10−8、最小误差 ∼10−8(约 8 位有效数字),fp32(εmachine≈1.19×10−7)最优步长约 3.4×10−4、只剩 3-4 位。数值微分因此只够做梯度检查(gradient checking),不能用来训练。
解析自动微分(analytic AD):把网络拆成基本算子(matmul、add、σ、loss 等),每个算子手写解析的局部偏导一次,框架按计算图反向将局部偏导组合还原成 ∇θL。每个算子的局部导数在该算子的 O-级代价内完成,总代价与一次前向同阶。
自动微分不是「算得更准」,而是用解析雅可比 + 反向拓扑顺序,让每次梯度计算的代价从「参数次前向」降到「一次前向」。本文剩下的篇幅就是把这套机制从数学到算法到工程拆开。
1. 多维求导基础:从标量到雅可比矩阵
1.1 标量、向量与梯度的直觉回顾
标量函数 y=f(x) 的导数是一个数 dy/dx,几何上是曲线在 x 点的斜率。把输入升到 n 维 f:Rn→R,标量对每个分量都有偏导 ∂f/∂xi,将这 n 个偏导按分量排成一个向量,即 梯度向量(gradient) ∇f∈Rn。
梯度的几何意义不变:∇f 指向 f 在当前点增长最快的方向,大小是该方向上的瞬时增长率。约定上 ∇f 是列向量(与 x 同形状);这是后续讨论「分子 vs 分母布局」前的默认起点。
1.2 向量值函数与雅可比矩阵
输出也升到向量 f:Rn→Rm 时,每个输出分量 fi 对每个输入分量 xj 都有一个偏导 ∂fi/∂xj,共 m×n 个;把这些偏导排成一个矩阵,就是 f 在某点的一阶信息——雅可比矩阵(Jacobian matrix)。不过「排成矩阵」有两种约定,互为转置。
- 分子布局(numerator layout):行数 = 输出维数 m,列数 = 输入维数 n。(∂f/∂x)ij=∂fi/∂xj。
- 分母布局(denominator layout):把分子布局转置。行 = 输入维数,列 = 输出维数。
两种约定数学上都对,但 链式法则在它们下面写法不同:分子布局下,链式法则是「左乘上游、右乘本层」∂L/∂x=∂L/∂y⋅∂y/∂x(矩阵右乘);分母布局下要先转置才能凑出乘法。一旦在同一份实现里前后混用,转置 bug 会沿着梯度公式系统性地传染。本文剩下的所有推导统一用分子布局,链式法则按右乘排列。
按分子布局,神经网络中常见三类层的雅可比:
- 线性层 y=Wx+b(y∈Rm,x∈Rn,W∈Rm×n):∂y/∂x=W(m×n);对参数 W 的偏导按元素写为 ∂yi/∂Wij=xj。
- 激活层 y=σ(x)(element-wise):∂y/∂x=diag(σ′(x)),是一个对角矩阵。
- 损失层 L=ℓ(y,ytrue)(标量输出):∂L/∂y∈R1×m 是行向量。
线性层的雅可比就是权重矩阵本身、element-wise 激活层的雅可比是对角阵——这两个观察让大多数层的 VJP(向量-雅可比积)退化成结构化乘法,是 § 4 让反向传播在 O(n) 而非 O(n2) 内通过激活层的关键。
2. 矩阵微积分与高维链式法则
2.1 多元复合函数的链式法则
把两层复合摆出来:z=f(y),y=g(x),其中 x∈Rn、y∈Rm、z∈Rk。标量情形下链式法则是 dz/dx=(dz/dy)(dy/dx);向量情形把它原样搬上来:
∂x∂z=∂y∂z⋅∂x∂y=Jf⋅Jg∈Rk×n.
链式法则在向量化形式下就是雅可比矩阵的连乘——形状校验是 (k×m)(m×n)=(k×n)。一条长度为 L 的复合链 x→y1→y2→⋯→yL 的端到端雅可比就是 JL⋅JL−1⋯J1 ;如果中间维数都是 n,全显式相乘要做 L−1 次 n×n 矩阵乘法、共 O(L⋅n3) 乘加(朴素算法;Strassen 等可降到 O(n2.37∼2.81),但深度学习实践仍用 O(n3) 的 Basic Linear Algebra Subprograms,即 BLAS)。真正的浪费不在内存(流式计算每步只需 O(n2) 暂存一个稠密 Jacobian),而在两点:一是这些都是矩阵-矩阵乘法、每层 O(n3);二是它算出了完整的 n×n 雅可比矩阵,可反向传播里损失是标量,我们要的只是损失对各层的梯度——一个长度 n 的向量 vJ(v=∂L/∂y 是上游梯度,按分子布局是 1×m 行向量);而 vJ 并不需要先把 J 造出来再相乘:每个基本算子都有一条直接从上游 v 算出 vJ 的「反向规则」(例如逐元素激活层 J 是对角阵,vJ 就是 v⊙σ′(x),O(n) 完成、不建任何矩阵),所以 J 这个矩阵从头到尾都不必显式存在。这正是 § 3.2「维度灾难」与 § 4 VJP 改良的直接动因——VJP 让一个梯度向量向后逐层传播,每经过一层只做一次向量 × 矩阵乘法(即矩阵-向量乘法,每层从 O(n3) 降到 O(n2)),全程不构造稠密 J。
上一段提到的 Basic Linear Algebra Subprograms(BLAS) 是一套标准化的底层线性代数运算接口,数值计算栈里的矩阵运算最终都通过它执行。它按规模分三级——Level 1 向量-向量(O(n))、Level 2 矩阵-向量(O(n2))、Level 3 矩阵-矩阵(gemm,O(n3),深度学习的算力几乎都集中在这一级)。BLAS 只是接口规范,具体实现由厂商按硬件优化:CPU 上有 OpenBLAS、Intel MKL,NVIDIA GPU 上是 cuBLAS。这些实现把朴素 O(n3) 算法的常数因子(缓存分块、SIMD、Tensor Core 调度)优化到极限,而不换用 Strassen——后者常数大、数值稳定性差,在实际矩阵尺寸下反而更慢。
2.2 实战演练:线性层 Y=WX 的梯度对齐
我们以深度学习中最常见的线性层 Y=WX 为例。假设损失 L 是标量,反向传播已经得到本层输出的上游梯度 ∂L/∂Y,现在要算 ∂L/∂W 与 ∂L/∂X。
各矩阵的维度(这里的 Y=WX 就是全连接层的运算,也是 Attention 里 Q/K/V 投影的运算):
- 参数矩阵 W∈Rm×n。
- 输入矩阵 X∈Rn×k(k 可以是 batch size 或序列长度)。
- 输出矩阵 Y=WX∈Rm×k。
- 上游梯度 ∂L/∂Y∈Rm×k(与 Y 同形状)。
目标:得到 ∂L/∂W∈Rm×n 与 ∂L/∂X∈Rn×k。在动手前先确立深度学习工程实现里的一个铁律——维度相容原则(shape-matching principle):标量 L 对某个参数矩阵求导,其梯度的形状必须与该参数矩阵严格一致。任何推导结果都用这条规则反向校验,转置位置一目了然。
下面分别用最朴素的「标量偏导法」和最优雅的「全微分 + 迹技巧」推一遍。两种方法殊途同归,但相互印证:标量法让读者踏实地看到每个数怎么乘加,消除黑盒恐惧;全微分法展示高维张量求导的工业级写法,是后续推导 Attention、softmax 等复杂层的常用工具。
方法一:标量偏导与求和法(Index Notation)
把矩阵拆成标量元素,用初等微积分链式法则求导,再还原回矩阵形式。Y=WX 中任意一个元素:
Yij=p=1∑nWipXpj.
求 ∂L/∂W。考察 W 中位置 Wab 对最终标量 L 的影响。在 Yij=∑pWipXpj 里,Wab 只出现在 i=a、p=b 的项中,即它只影响输出矩阵第 a 行的元素 Yaj(j 任意):
∂Wab∂Yij={Xbj0i=a,otherwise.
按链式法则把所有路径求和:
∂Wab∂L=i,j∑∂Yij∂L∂Wab∂Yij=j=1∑k∂Yaj∂LXbj.
最右侧是 ∂L/∂Y 的第 a 行与 X 的第 b 行的内积。要凑成矩阵乘法,需要把 X 转置使其第 b 行变成第 b 列:
∂W∂L=∂Y∂LXT.
形状校验:(m×k)(k×n)=(m×n),与 W 一致。
求 ∂L/∂X。同理,考察 Xab。在 Yij=∑pWipXpj 里,Xab 只在 p=a、j=b 时出现,影响第 b 列的元素 Yib:
∂Xab∂Yij={Wia0j=b,otherwise.
代入链式法则:
∂Xab∂L=i,j∑∂Yij∂L∂Xab∂Yij=i=1∑mWia∂Yib∂L.
这是 W 的第 a 列与 ∂L/∂Y 的第 b 列的内积。把 W 转置,凑成矩阵乘法:
∂X∂L=WT∂Y∂L.
形状校验:(n×m)(m×k)=(n×k),与 X 一致。
方法二:全微分与迹技巧(Matrix Differential & Trace Trick)
标量法不会出错,但面对 Self-Attention 里 Y=softmax(QKT)V 这种连乘公式时,满屏下标求和让人窒息。矩阵微积分提供了一套不展开下标的工具:全微分(differential)+ 迹(trace)。
在矩阵微积分里,标量 L 关于矩阵 X 的微分 dL 与其梯度 ∂L/∂X 之间,通过迹建立一个内积关系:
dL=Tr((∂X∂L)TdX).
加上迹的循环置换性 Tr(ABC)=Tr(CAB)=Tr(BCA),就足以从 dL 的表达式里反向剥离出梯度。
求 ∂L/∂W。保持 X 不变,对 W 求微分:dY=dW⋅X。代入 dL 的标准形式:
dL=Tr((∂Y∂L)TdY)=Tr((∂Y∂L)TdWX)=Tr(X(∂Y∂L)TdW)=Tr((∂Y∂LXT)TdW).
第二行第一步用了迹的循环置换 Tr(AB)=Tr(BA) 把 X 移到最前;第二步用了 XAT=(AXT)T。对比标准形式 dL=Tr((∂L/∂W)TdW),剥掉外壳:
∂W∂L=∂Y∂LXT.
求 ∂L/∂X。保持 W 不变,对 X 求微分:dY=WdX。代入:
dL=Tr((∂Y∂L)TWdX)=Tr((WT∂Y∂L)TdX).
最后一步用 (ATB)T=BTA 把 WT 提到外面。对比标准形式:
∂X∂L=WT∂Y∂L.
两种方法结论完全一致。对参数求梯度 = 上游梯度 × 本层输入的转置;对输入求梯度 = 本层参数转置 × 上游梯度——维度只能这么对齐,转置位置由形状唯一确定。这条直觉是后续推导 Attention、softmax、归一化等复杂层梯度时反复用到的通用 framework。
3. 从数学到算法:计算图与反向传播
3.1 拆解前向传播:构建有向无环图
现代深度学习框架(PyTorch / TensorFlow / JAX)的底层抽象都是 计算图(computational graph)——把一次前向计算拆成一张有向无环图(DAG),节点是所有变量(输入、参数 W 和 b、中间量、输出),边携带从上游节点到下游节点的局部偏导(即一个雅可比)。注意这跟主文章 § 1.1 / § 2.1 的「网络结构图」不一样:那里 权重画在边上(神经元之间的加权连接),而 计算图里权重是节点,连接它和下游中间量的边携带的是局部偏导 ∂z/∂W,不是权重本身。
以最简单的回归损失 L=(σ(Wx+b)−y)2 为例,把它拆成中间变量 z=Wx+b、a=σ(z)、L=(a−y)2,对应的计算图如下。
graph LR
x((x)) --> z[z]
W((W)) --> z
b((b)) --> z
z --> a[a]
a --> diff["a - y"]
y((y)) --> diff
diff --> L[L]
每条边对应一个 局部偏导:∂z/∂W=xT(按 § 2.2 的推导,分子布局下线性层的参数偏导就是这个形式)、∂a/∂z=diag(σ′(z))、∂L/∂a=2(a−y),等等。求 ∂L/∂W 只需把 L→a→z→W 这条路径上的所有局部偏导按链式法则乘起来:
∂W∂L=∂a∂L⋅∂z∂a⋅∂W∂z=2(a−y)⋅σ′(z)⋅xT.
反向传播不是新的求导规则——它是按 DAG 的反向拓扑顺序,用每条边的局部偏导执行链式法则。任何可写成 DAG 的函数都能这样求导;框架要做的只是记住每条边上的局部偏导算子。
3.2 维度灾难:为什么不能直接相乘
把 § 2.1 的开销代入深度网络:一个 L=10 层、每层宽度 n=1000 的网络,把 10 个 1000×1000 的层雅可比全显式连乘起来,约 1010 次乘加,且每个稠密层 Jacobian 占 n2=106 个浮点数。这只是单条样本的反向;放进 batch 后,矩阵-矩阵乘法的算力开销迅速变得不可承受。
破局点在于雅可比矩阵的结构。在 § 1.2 已经看到:element-wise 激活层的雅可比 ∂y/∂x=diag(σ′(x)) 是对角阵,独立非零元只有 n 个而不是 n2 个;线性层的雅可比 ∂y/∂x=W 虽然稠密,但它本来就在内存里、不需要新存一份。element-wise 激活层的雅可比是对角阵——这一观察是反向传播能在 O(n) 而非 O(n2) 内传一层激活梯度的关键。如果激活是「全连接非对角」(如 softmax 的输入-输出耦合),雅可比就不再稀疏,存储与计算成本都会飙升。
§ 4 把这条结构性观察推到底——根本不去显式存或构造 J,而是只问「给定上游 v,输出 v⋅J」。
4. 现代自动微分的核心:向量-雅可比积
4.1 前向模式(JVP)vs 反向模式(VJP)
把 § 2.1 的链式法则连乘 JL⋅JL−1⋯J1 看成一串矩阵乘法,求解顺序有两种选择。
前向模式(forward mode;JVP,Jacobian-vector product,雅可比-向量积):选一个输入方向 v∈Rn(典型情况是 v 取标准基向量 ei),从最左侧开始累乘 J1v→J2(J1v)→⋯,最终得到 ∂y/∂x⋅v∈Rm——JVP 这个名字的顺序(先 J 后 v)就是它算的 J⋅v。每一步是矩阵-向量乘法,代价随 输入维度 增长——要拿到整张 Jacobian,需要 n 次 JVP(每次选一个 ei)。
反向模式(reverse mode;VJP,vector-Jacobian product,向量-雅可比积):取一个上游梯度行向量 v∈R1×m(在 backprop 中典型取 v=∂L/∂y),从最右侧开始累乘 vJL→(vJL)JL−1→⋯,最终得到 v⋅∂y/∂x∈R1×n——VJP 名字的顺序(先 v 后 J)正是它算的 v⋅J。每一步也是矩阵-向量乘法,代价随 输出维度 增长——要拿到整张 Jacobian,需要 m 次 VJP。
深度学习里损失是标量(m=1)、参数是 108 量级(n 极大),反向模式只需一次 VJP 就能拿到完整梯度;前向模式则需要 108 次 JVP——这是 PyTorch、JAX 默认用反向模式的根本原因。GPT-3 量级(175 B 参数、标量损失):反向模式一次 backward 拿到全部梯度,前向模式则要跑 1.75×1011 次 JVP,完全不可承受。反过来,如果要算 Jacobian-vector 二阶量(Hessian-vector product 之类),前向模式才有用武之地。
4.2 VJP 的工作流抽象
VJP 的形式定义:给定上游行向量 v∈R1×m 与函数 y=f(x) 的雅可比 Jf=∂f/∂x∈Rm×n,计算 v⋅Jf∈R1×n。链式法则下,从损失 L 反向传到任意中间变量 xl−1 的梯度就是 VJP 的递推:
∂xl−1∂L=∂xl∂L⋅∂xl−1∂xl=vl∂xl∂L⋅Jfl.
关键在于这一步 不需要显式构造 Jfl——只要能算「给定 v,输出 v⋅Jfl」就够了。对激活层 Jfl=diag(σ′(xl−1)),VJP 退化成 element-wise 乘法 v⊙σ′(xl−1),O(n) 完成;对线性层 Jfl=W,VJP 是 v⋅W,复用前向已有的 W。
反向模式自动微分的算法骨架(伪代码,用 § 2.2 的列向量约定:xl∈Rnl、g 是与 xl 同形状的列向量梯度):
# Forward pass: cache activations needed by backward
for l in range(1, L + 1):
z_l = W_l @ x_prev + b_l
x_l = sigma(z_l)
cache(x_prev, z_l)
L_value = loss(x_L, y_true)
# Backward pass: propagate VJP
g = dL_dx_L # initial upstream gradient, shape (m,)
for l in range(L, 0, -1):
g = g * sigma_prime(z_l) # VJP through activation, O(n)
dW_l = g[:, None] @ x_prev[None, :] # outer product, shape (m, n)
db_l = g
g = W_l.T @ g # VJP through linear layer, O(m * n)
return {dW_l, db_l for l in 1..L}
dW_l = g[:, None] @ x_prev[None, :] 是 § 2.2 的 ∂L/∂W=(∂L/∂Y)XT 在 batch size = 1 时的标量化写法(g 是列向量梯度、xprev 是列向量输入,外积得到 m×n 的参数梯度)。W_l.T @ g 对应 ∂L/∂X=WT(∂L/∂Y)。
每一层反向只做两件事:用 VJP 把上游梯度过一遍当前层、把过程中拿到的 g 与缓存的 xprev 拼成参数梯度 dWl。反向模式 AD 的总计算量与前向一致,但内存代价等于缓存所有正向 activations——这是下一节工程取舍的起点。
5. 现代框架底层的工程取舍
5.1 计算与内存的博弈
反向模式 AD 的计算账是好看的(与前向同阶),内存账是另一回事。要让 backward 能算出每层的参数梯度,forward 必须缓存每层的输入 xl−1(和有时缓存 zl 以避免 σ′ 二次计算)。把这件事写成账本:
activation memory≈L⋅B⋅d⋅bytes per scalar,
其中 L 是层数、B 是 batch size、d 是每层特征宽度。一个 12 层 transformer block、d=768、batch 32、序列长度 512、fp32:单条 activation 张量约 12×32×512×768×4B≈600MB;transformer block 里实际有 attention 中间量、MLP 中间量、residual 等约 5-10 个独立缓存,乘起来轻松上 GB。深度网络训练的瓶颈通常不是参数本身的内存,而是反向所需的 activation 缓存——它随网络深度、batch、序列长度线性增长。
Gradient checkpointing(陈天奇 2016 那篇)给的方案是 以时间换空间:forward 时只在 L 个「检查点」位置保留 activation,其余位置丢弃;backward 走到没缓存的位置时,从最近的上游检查点起局部重跑 forward 重建 activations。代价是 forward 总算力多约 33%,activation 内存从 O(L) 降到 O(L)。在大模型训练里,gradient checkpointing 是最常用的内存压缩手段;选择「哪些层做检查点」本身是个工程优化问题。
5.2 动态图 vs 静态图的 Trace 模式
反向模式 AD 还需要知道「图本身」长什么样。两条主流路线在「图什么时候构出来」上做了不同选择。
动态图(eager + autograd,PyTorch 路线):前向跑用户写的普通 Python,每个 tensor op 在执行时把自己记录到一张随后被释放的临时图里;调用 loss.backward() 时框架沿这张图反向跑 VJP。优点是控制流(if、while、Python 异常)完全自由,调试器、断点、print 都能直接用;缺点是每次 forward 都要重建图,缺少跨算子全局优化的机会,编译器拿不到完整算子序列。
静态图(trace + jit,JAX / TensorFlow 2 路线):先用一遍 trace 把整个函数走成一张静态图(输入用 placeholder tensor 跑过去),再交给编译器(XLA)做全局算子融合、内存预分配、并行调度;反向沿编译后的图算。优点是算子融合、kernel 合并、推理延迟可控;缺点是 Python 控制流不能直接 trace,需要写 lax.scan、lax.cond 等显式控制流原语,调试体验更接近编译器栈。
eager 把 backprop 的图延迟到运行时,trace 把它提前到编译时——选择决定了控制流自由度、编译器优化空间、推理延迟在哪边付出代价。PyTorch 2.x 通过 torch.compile 加上了「先 eager 跑、识别热路径后再 trace 编译」的混合方案;JAX 通过 jit 把整个被装饰的函数搬到静态图侧;两边都在向中间靠拢。
6. 结语
反向传播是把链式法则按 DAG 反向拓扑顺序系统化执行的工程化算法。本文从「为什么需要 AD」(§ 0)出发,依次铺垫多维求导(§ 1)、矩阵链式法则(§ 2,含 Y=WX 的两种推导)、计算图与反向传播(§ 3)、VJP 与反向模式 AD(§ 4),最后展开内存账本与图模式的工程取舍(§ 5)。圈定了边界,几条没在正文里直接强调、但用之前最好心里有数的事:
- 布局约定必须从头到尾贯穿。分子布局(本文)让链式法则写成右乘,分母布局让链式法则写成左乘——两种约定都对,但任何混用都会在梯度公式里留下系统性转置错误。代码、推导、注释里只用一种。
- 维度相容是隐式校验,也是推导的捷径。任何梯度推导(哪怕复杂如 Attention)的结果都能用「梯度 shape = 参数 shape」反向校验;遇到不确定要不要转置时,先写 shape 再决定形式,比硬推下标快得多。
- VJP 替代显式 Jacobian 才是 AD 的核心。框架不会在内存里实例化完整的 Jf;缓存的是 forward activations,反向时按需算 v⋅Jf。这是 PyTorch / JAX 与「手推梯度」的本质区别——也是为什么深度网络的雅可比即便维度爆炸、训练依然可行。
- 内存账本随深度、batch、序列长度线性增长。O(L⋅B⋅d) 的 activation 缓存常先于参数本身、先于算力成为深层网络训练的瓶颈;gradient checkpointing 是首选缓解手段,它用约 33% 多余算力换 O(L)→O(L) 内存。
- 反向模式不是唯一选择。当输出维度远大于输入维度(例如 Jacobian-vector / Hessian-vector 这类二阶量),前向模式 JVP 反而更省。深度学习只是「标量损失 + 巨量参数」这一极端情形让反向模式明显占优。