Every training step of a neural network requires the gradient of the loss with respect to every parameter. The underlying principle is the chain rule from elementary calculus, but a deep network easily has $10^8$ parameters. Deriving the gradient by hand is both error-prone and impossible to scale. Numerical differentiation (finite differences) must perturb each parameter individually, so its cost grows linearly in parameter count, and it has no good choice of step size: too large gives approximation error, too small gives floating-point precision loss.

This post introduces backpropagation and automatic differentiation (AD). Backpropagation applies the chain rule in reverse topological order over a computational graph. AD packages that recursion into an operator that never needs to materialize Jacobians explicitly. Together they bring the cost of one gradient computation down to the same order as one forward pass ($O(\text{params})$, versus $O(\text{params}^2)$ for finite differences). They form the algorithmic core of differentiable-programming frameworks like PyTorch and JAX.

This post is the deep-dive companion to § 3.4 “Backpropagation: Computational Graph, Chain Rule, Jacobian” of “Deep Learning Foundations: From Perceptrons to Backpropagation to Training Deep Networks”. The main post gives the high-level why; this post walks through how the gradient is actually computed in §§ 0-5:

  • the requirement and the choice of method (§ 0);
  • multidimensional calculus and the Jacobian (§ 1);
  • the matrix chain rule, with a worked derivation for $Y = WX$ (§ 2);
  • computational graphs and backprop (§ 3);
  • VJPs and reverse-mode AD (§ 4);
  • the memory accounting and the dynamic-vs-static-graph tradeoffs (§ 5).

§ 6 wraps up with takeaways.

0. Why We Need Automatic Differentiation

Think of a neural network as a scalar function $L(\theta)$, where $\theta \in \mathbb{R}^P$ stacks every parameter. Each training step needs $\nabla_\theta L$, and $P$ is routinely $10^7 \sim 10^{11}$. There are three common approaches.

Symbolic / by-hand derivation: write out the partials for every layer and every parameter by hand. The chain rule guarantees correctness in principle. But in a network deeper than ten layers, a single misplaced transpose propagates a systematic error through the entire gradient, and any architectural change forces a full re-derivation. Readable and teachable, but not scalable.

Numerical differentiation (finite differences): perturb each $\theta_i$ by $\varepsilon$ and estimate $\partial L / \partial \theta_i \approx (L(\theta + \varepsilon e_i) - L(\theta)) / \varepsilon$. Here $e_i$ is the $i$-th standard basis vector (1 in position $i$, 0 elsewhere), so $\theta + \varepsilon e_i$ adds $\varepsilon$ to the $i$-th parameter only and leaves the rest fixed. The method is model-agnostic and trivial to implement. The first drawback is cost: each parameter needs one extra forward pass, so $10^8$ parameters means $10^8$ forward passes, which alone rules it out for training.

The second drawback is subtler: no choice of step size $\varepsilon$ works well. Start with the large-$\varepsilon$ end. Taylor-expand $L(\theta + \varepsilon e_i)$ along the $e_i$ direction and substitute it back into the difference formula:

$$\frac{L(\theta + \varepsilon e_i) - L(\theta)}{\varepsilon} = \frac{\partial L}{\partial \theta_i} + \frac{\varepsilon}{2} \frac{\partial^2 L}{\partial \theta_i^2} + O(\varepsilon^2),$$

This leaves the true gradient plus a remainder proportional to $\varepsilon$. That remainder is truncation error: the difference quotient is a first-order approximation, and the larger $\varepsilon$ is, the further it deviates.

Intuitively, then, a smaller $\varepsilon$ means a smaller truncation error. But too small an $\varepsilon$ triggers an error from the other end: floating-point arithmetic.

Floating-point numbers have finite precision. fp32 carries about 7 significant decimal digits (23 explicit bits + 1 implicit bit = 24 bits of precision). The moment a value is stored, everything past roughly the 7th digit is rounded off, leaving a relative error of order $\varepsilon_{\text{machine}} \approx 1.19 \times 10^{-7}$. Those low-order bits no longer reflect the true value and are effectively noise.

Write the true values as $a = L(\theta)$ and $b = L(\theta + \varepsilon e_i)$. Once stored, they become $\hat a = a + n_a$ and $\hat b = b + n_b$, where $n_a, n_b \sim 10^{-7}$ are the respective rounding noise (for values of order 1). When $\varepsilon$ is small, $a \approx b$ and their high-order digits are identical, so the subtraction cancels the reliable high-order part exactly:

$$\hat b - \hat a = (b - a) + (n_b - n_a).$$

The signal $b - a \sim \varepsilon$ shrinks with $\varepsilon$, while the noise $n_b - n_a \sim 10^{-7}$ stays put. The absolute error is unchanged, but the relative error grows to about $10^{-7} / \varepsilon$.

Concretely, take fp32 with $L(\theta) = 1.0$ and a true gradient of 1.

  • At $\varepsilon = 10^{-8}$: $L(\theta + \varepsilon e_i)$ is mathematically $1.00000001$, but fp32 keeps only about 7 significant digits and rounds it back to $1.0$. The difference is 0, and dividing by $\varepsilon$ still gives 0, so a true gradient of 1 is estimated as 0.
  • At $\varepsilon = 10^{-4}$: the literal $(1.0001 - 1.0)/10^{-4} = 1$ looks exact. But $1.0001$ stored in fp32 already carries a rounding error of up to about $6 \times 10^{-8}$ (half the fp32 spacing near 1, i.e. $\tfrac{1}{2}|L|\,\varepsilon_{\text{machine}}$). Dividing by $\varepsilon$ inflates that into a relative error of up to about $6 \times 10^{-4}$. The estimate is close to 1, but only about 3-4 significant digits can be trusted.

This is catastrophic cancellation: the smaller $\varepsilon$ is, the weaker the signal and the larger the noise fraction, so the fewer significant digits survive the subtraction.
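The two fp32 cases above can be reproduced with a few lines of NumPy. This is a minimal sketch that assumes the toy loss $L(\theta) = \theta$ evaluated at $\theta = 1$, so $L = 1$ and the true gradient is 1, matching the numbers in the text. `forward_difference` and `identity` are illustrative names, not part of any library.

```python
import numpy as np

def forward_difference(L, theta, eps, dtype):
    # One-sided difference quotient with every value held in the given precision
    theta, eps = dtype(theta), dtype(eps)
    return (L(theta + eps) - L(theta)) / eps

identity = lambda t: t    # toy loss L(θ) = θ: at θ = 1, L = 1 and the true gradient is 1

for eps in (1e-8, 1e-4):
    print(f"eps={eps:.0e}  fp32 estimate={forward_difference(identity, 1.0, eps, np.float32)}")
```

With `np.float32`, the first line reports 0.0. The second reports a value that differs from 1 in the fourth decimal place. Passing `np.float64` instead recovers a nonzero estimate at $\varepsilon = 10^{-8}$.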

Writing both error types as functions of the step $\varepsilon$, the total error is the sum of a rising term and a falling term:

$$E(\varepsilon) \approx \frac{1}{2}\,|L''|\,\varepsilon + \frac{|L|\,\varepsilon_{\text{machine}}}{\varepsilon}.$$

The first term is truncation error, the Taylor remainder from above. Here $|L''|$ is the magnitude of the second derivative $\partial^2 L / \partial \theta_i^2$, i.e. the curvature. This term grows linearly with $\varepsilon$.

The second term is roundoff error. $\varepsilon_{\text{machine}}$ is a relative error, so multiplying by the value's magnitude $|L|$ gives the absolute rounding noise $\sim |L|\,\varepsilon_{\text{machine}}$ from storing one value. That noise survives the subtraction and is then amplified by the division by $\varepsilon$ (the $10^{-7}/\varepsilon$ of the previous paragraph), so this term worsens as $\varepsilon$ shrinks.

A sum of the form $\alpha\varepsilon + \beta/\varepsilon$, differentiated and set to zero, is minimized at $\varepsilon = \sqrt{\beta/\alpha}$, with minimum value $2\sqrt{\alpha\beta}$. Here $\alpha = \tfrac{1}{2}|L''|$ and $\beta = |L|\,\varepsilon_{\text{machine}}$, so $\beta/\alpha = \tfrac{2|L|}{|L''|}\,\varepsilon_{\text{machine}}$. Assuming $L$ and $L''$ are both of order $O(1)$, dropping the constant gives the optimal step $\varepsilon \approx \sqrt{\varepsilon_{\text{machine}}}$, and the minimal total error is also $\sim \sqrt{\varepsilon_{\text{machine}}}$. The two effects offset each other, and the total error cannot be pushed below that floor.

Plugging in numbers:

  • fp64 ($\varepsilon_{\text{machine}} \approx 2.2 \times 10^{-16}$) has an optimal step of about $10^{-8}$ and a minimal error of $\sim 10^{-8}$ (about 8 significant digits).
  • fp32 ($\varepsilon_{\text{machine}} \approx 1.19 \times 10^{-7}$) has an optimal step of about $3.4 \times 10^{-4}$ and keeps only 3-4 digits.

Numerical differentiation is therefore good for gradient checking, not for training.
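The U-shaped error curve can be observed directly by sweeping the step size. The sketch below assumes the toy loss $L(\theta) = e^{\theta}$ at $\theta = 0$. It is chosen because $L = L' = L'' = 1$ there, which matches the $O(1)$ assumption behind $\varepsilon \approx \sqrt{\varepsilon_{\text{machine}}}$. The half-decade grid of steps is an arbitrary choice.

```python
import numpy as np

def forward_difference_error(eps, dtype):
    # Toy loss L(θ) = exp(θ) at θ = 0: L = L' = L'' = 1, so the true gradient is 1
    theta, eps = dtype(0.0), dtype(eps)
    estimate = (np.exp(theta + eps) - np.exp(theta)) / eps
    return abs(float(estimate) - 1.0)

steps = np.logspace(-12, -1, 23)           # half-decade grid from 1e-12 to 1e-1
for dtype in (np.float32, np.float64):
    errors = [forward_difference_error(s, dtype) for s in steps]
    best = int(np.argmin(errors))
    print(dtype.__name__, f"best step ~{steps[best]:.1e}", f"error ~{errors[best]:.1e}",
          f"sqrt(eps_machine) = {np.sqrt(np.finfo(dtype).eps):.1e}")
```

On a typical NumPy build, the fp32 minimum should land between $10^{-4}$ and $10^{-3}$ with an error around $10^{-4}$, and the fp64 minimum near $10^{-8}$. Which grid point wins exactly depends on how each value happens to round.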

Analytic automatic differentiation (analytic AD): split the network into primitive ops (matmul, add, $\sigma$, loss, …) and write each primitive's analytic local derivative once. The framework then stitches the local derivatives back into $\nabla_\theta L$ by walking the computational graph in reverse. Each primitive's backward rule costs about as much as its forward computation, so the total cost is the same order as one forward pass.

Automatic differentiation is not about “computing more accurately”. It is about using analytic local derivatives plus reverse topological order to bring the per-step cost down from “$P$ forward passes” to “one forward pass” (up to a small constant factor). The rest of this post unpacks that mechanism from math to algorithm to engineering.

1. Multidimensional Calculus: From Scalars to the Jacobian

1.1 Scalars, Vectors, and the Gradient

For a scalar function $y = f(x)$, the derivative is a single number $\mathrm{d}y/\mathrm{d}x$: geometrically, the slope of the curve at $x$. Now lift the input to $n$ dimensions, $f: \mathbb{R}^n \to \mathbb{R}$. The function has one partial derivative $\partial f / \partial x_i$ per coordinate, and arranging those $n$ partials into a vector gives the gradient $\nabla f \in \mathbb{R}^n$.

The geometric picture survives: $\nabla f$ points in the direction of fastest increase, and its magnitude is the rate of increase along that direction. By convention $\nabla f$ is a column vector (same shape as $x$). Under the numerator layout adopted in § 1.2, the derivative $\partial f / \partial x$ of a scalar is the corresponding $1 \times n$ row vector, $\partial f / \partial x = (\nabla f)^T$.

1.2 Vector-Valued Functions and the Jacobian

When the output is also a vector, $f: \mathbb{R}^n \to \mathbb{R}^m$, each output component $f_i$ has a partial $\partial f_i / \partial x_j$ with respect to each input component $x_j$, for $m \times n$ partials in total. Arranging them into a matrix gives the first-order information at a point: the Jacobian matrix. But “arranging them into a matrix” admits two conventions, which are transposes of each other.

  • Numerator layout: rows = output dimension $m$, columns = input dimension $n$, i.e. $(\partial f / \partial x)_{ij} = \partial f_i / \partial x_j$.
  • Denominator layout: the transpose of numerator layout. Rows = input dimension, columns = output dimension.

Both conventions are mathematically valid, but the chain rule looks different under each.

  • Numerator layout reads “upstream on the left, current layer on the right”: $\partial L / \partial x = (\partial L / \partial y)(\partial y / \partial x)$, a right multiply by the current layer's Jacobian.
  • Denominator layout swaps the factors: $\partial L / \partial x = (\partial y / \partial x)(\partial L / \partial y)$, a left multiply.

Mix the two anywhere in an implementation and the resulting transpose bug propagates systematically through every gradient formula. The rest of this post uses numerator layout throughout, with the chain rule written as a right multiply.

Under numerator layout, the three common neural-network layers have these Jacobians:

  • Linear layer $y = Wx + b$ ($y \in \mathbb{R}^m$, $x \in \mathbb{R}^n$, $W \in \mathbb{R}^{m \times n}$): $\partial y / \partial x = W$ (shape $m \times n$). Element-wise with respect to $W$, $\partial y_i / \partial W_{ij} = x_j$, and $\partial y_k / \partial W_{ij} = 0$ for $k \ne i$.
  • Activation layer $y = \sigma(x)$ (element-wise): $\partial y / \partial x = \mathrm{diag}(\sigma'(x))$, a diagonal matrix.
  • Loss layer $L = \ell(y, y_{\text{true}})$ (scalar output): $\partial L / \partial y \in \mathbb{R}^{1 \times m}$, a row vector.
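The three Jacobians above can be checked numerically. This sketch builds each Jacobian column by column with central differences in fp64. Central differences are a symmetric variant of the § 0 difference quotient with $O(h^2)$ truncation error, used here purely as a check. The sigmoid activation, the half-squared-error loss, the sizes, and the helper name `numerical_jacobian` are all assumptions for illustration.

```python
import numpy as np

def numerical_jacobian(f, x, h=1e-6):
    # Central differences in fp64; numerator layout: row i = output i, column j = input j
    m = np.atleast_1d(f(x)).size
    J = np.zeros((m, x.size))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = h
        J[:, j] = (np.atleast_1d(f(x + e)) - np.atleast_1d(f(x - e))) / (2 * h)
    return J

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
W, b, x = rng.standard_normal((3, 4)), rng.standard_normal(3), rng.standard_normal(4)
y, t = rng.standard_normal(3), rng.standard_normal(3)     # t: hypothetical target

J_linear = numerical_jacobian(lambda v: W @ v + b, x)                # expect W, shape (3, 4)
J_act = numerical_jacobian(sigmoid, x)                                # expect diag(sigma'(x))
J_loss = numerical_jacobian(lambda v: 0.5 * np.sum((v - t) ** 2), y)  # expect (y - t) as a 1 x 3 row

s = sigmoid(x)
print(np.allclose(J_linear, W), np.allclose(J_act, np.diag(s * (1 - s))),
      np.allclose(J_loss, (y - t)[None, :]))
```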

The Jacobian of a linear layer is the weight matrix itself, and the Jacobian of an element-wise activation is diagonal. These two observations collapse most layer VJPs (vector-Jacobian products) into structured products. They are why backprop in § 4 can pass a gradient through an activation in $O(n)$ instead of $O(n^2)$.

2. Matrix Calculus and the High-Dimensional Chain Rule

2.1 The Chain Rule for Composed Multivariate Functions

Lay out a two-level composition: $z = f(y)$, $y = g(x)$, with $x \in \mathbb{R}^n$, $y \in \mathbb{R}^m$, $z \in \mathbb{R}^k$. The scalar chain rule is $\mathrm{d}z/\mathrm{d}x = (\mathrm{d}z/\mathrm{d}y)(\mathrm{d}y/\mathrm{d}x)$, and the vector form lifts it verbatim:

$$\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial x} = J_f \cdot J_g \in \mathbb{R}^{k \times n}.$$

Vectorized, the chain rule is just a product of Jacobians, and the shape check $(k \times m)(m \times n) = (k \times n)$ confirms the formula. A composition chain of length $L$, $x \to y_1 \to y_2 \to \cdots \to y_L$, has end-to-end Jacobian $J_L \cdot J_{L-1} \cdots J_1$. If every intermediate dimension is $n$, the explicit product takes $L-1$ matrix-matrix multiplies, or $O(L \cdot n^3)$ multiply-adds in total with the naive algorithm. Strassen-type algorithms lower the exponent to somewhere between about 2.37 and 2.81, but deep learning in practice still uses the $O(n^3)$ routines of the Basic Linear Algebra Subprograms (BLAS).

The real waste is not memory: streaming the product keeps only $O(n^2)$ for one dense Jacobian at a time. The waste comes from two other things.

  • These are matrix-matrix products, $O(n^3)$ per layer.
  • They compute the full $n \times n$ Jacobian. In backprop the loss is scalar, so all we ultimately need is the loss's gradient with respect to each layer: a length-$n$ vector $vJ$, where $v = \partial L / \partial y$ is the upstream gradient (a $1 \times m$ row vector under numerator layout).

Moreover, $vJ$ need not be obtained by building $J$ first and then multiplying. Every primitive op has a backward rule that maps the upstream $v$ directly to $vJ$. For an element-wise activation, $J$ is diagonal and $vJ$ is just $v \odot \sigma'(x)$: $O(n)$, with no matrix built. The matrix $J$ never needs to exist explicitly.

This is the direct motivation for the “dimensionality disaster” in § 3.2 and the VJP refactor in § 4. A VJP propagates a single gradient vector backward layer by layer, so each layer costs one vector-matrix product ($O(n^2)$ per layer instead of $O(n^3)$), and no dense $J$ is ever formed.

The Basic Linear Algebra Subprograms (BLAS) mentioned above is a standardized low-level linear-algebra interface through which virtually every numerical stack routes its matrix operations. It splits into three levels by problem size:

  • Level 1: vector-vector, $O(n)$.
  • Level 2: matrix-vector, $O(n^2)$.
  • Level 3: matrix-matrix (gemm), $O(n^3)$. Almost all of deep learning's compute lives here.

BLAS is only an interface specification. The implementations are tuned per hardware by vendors and open-source projects: OpenBLAS and Intel MKL on CPU, cuBLAS on NVIDIA GPUs. These implementations push the constant factor of the naive $O(n^3)$ algorithm to its limit (cache blocking, SIMD, Tensor Core scheduling) rather than switching to Strassen. Strassen's larger constant and weaker numerical stability make it slower at practical matrix sizes.
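A small NumPy experiment makes the cost argument concrete. The sketch assumes a hypothetical chain of bias-free sigmoid layers $y_l = \sigma(W_l y_{l-1})$, whose layer Jacobian is $J_l = \mathrm{diag}(\sigma'(z_l))\,W_l$. It compares two routes to the same gradient:

  • (a) building and multiplying the dense Jacobians;
  • (b) pushing the upstream row vector backward one VJP at a time.

The width, depth, and random weights are arbitrary.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
n, depth = 64, 6
Ws = [rng.standard_normal((n, n)) / np.sqrt(n) for _ in range(depth)]

# Forward pass through y_l = sigmoid(W_l y_{l-1}), caching sigma'(z_l) for each layer
x = rng.standard_normal(n)
slopes, h = [], x
for W in Ws:
    s = sigmoid(W @ h)
    slopes.append(s * (1 - s))
    h = s

v = rng.standard_normal(n)          # upstream gradient dL/dy_L, used as a row vector

# (a) Explicit route: build each dense J_l = diag(sigma'(z_l)) W_l and multiply them,
#     one O(n^3) matrix-matrix product per layer, then apply v at the end
J_total = np.eye(n)
for W, d in zip(Ws, slopes):
    J_total = (d[:, None] * W) @ J_total
grad_explicit = v @ J_total

# (b) VJP route: push the vector backward; O(n) for the diagonal, O(n^2) for W_l
g = v.copy()
for W, d in zip(reversed(Ws), reversed(slopes)):
    g = (g * d) @ W
grad_vjp = g

print(np.allclose(grad_explicit, grad_vjp))
```

Both routes give the same gradient. Route (a) performs one $O(n^3)$ matrix-matrix product per layer and materializes dense $n \times n$ intermediates. Route (b) does one $O(n)$ scaling and one $O(n^2)$ vector-matrix product per layer.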

2.2 Worked Example: Aligning Gradients for the Linear Layer $Y = WX$

We take the most common building block in deep learning, the linear layer $Y = WX$, as our worked example. Assume the loss $L$ is scalar and backpropagation has already produced the upstream gradient $\partial L / \partial Y$ at this layer's output. The goal is $\partial L / \partial W$ and $\partial L / \partial X$.

The shapes are as follows. This $Y = WX$ is exactly what a fully-connected layer computes, and what the Q/K/V projections in attention compute.

  • Parameter matrix $W \in \mathbb{R}^{m \times n}$.
  • Input matrix $X \in \mathbb{R}^{n \times k}$ ($k$ is batch size or sequence length).
  • Output matrix $Y = WX \in \mathbb{R}^{m \times k}$.
  • Upstream gradient $\partial L / \partial Y \in \mathbb{R}^{m \times k}$ (same shape as $Y$).

Target: $\partial L / \partial W \in \mathbb{R}^{m \times n}$ and $\partial L / \partial X \in \mathbb{R}^{n \times k}$.

Before doing any algebra, fix one engineering invariant, the shape-matching principle: the gradient of a scalar $L$ with respect to a parameter matrix must have exactly the same shape as that matrix. This is the gradient-shaped convention. For $k = 1$ it yields a column, i.e. the transpose of the numerator-layout row vector of § 1.2. Every derivation result is verified against this rule, which makes transposes unambiguous.

We derive the gradient two ways: the naive “index notation” method and the elegant “matrix differential + trace trick” method. They reach the same answer, but each carries a different value: index notation makes every multiply-and-add visible to the reader and removes any sense of a black box; matrix differentials show the industrial-strength technique you need when later layers (attention, softmax, layer norm) produce expressions where unrolling indices is no longer practical.

Method 1: Index Notation (Scalar Partials with Summation)

Decompose each matrix into scalar entries, apply the elementary chain rule, then reassemble. For $Y = WX$, any entry is

$$Y_{ij} = \sum_{p=1}^{n} W_{ip} X_{pj}.$$

Solving $\partial L / \partial W$. Consider the influence of a single entry $W_{ab}$ on the final scalar $L$. In $Y_{ij} = \sum_p W_{ip} X_{pj}$, $W_{ab}$ appears only when $i = a$ and $p = b$. It therefore touches only the entries $Y_{aj}$ of row $a$ (any $j$):

$$\frac{\partial Y_{ij}}{\partial W_{ab}} = \begin{cases} X_{bj} & i = a, \\ 0 & \text{otherwise}. \end{cases}$$

Sum over all paths via the chain rule:

$$\frac{\partial L}{\partial W_{ab}} = \sum_{i, j} \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial W_{ab}} = \sum_{j=1}^{k} \frac{\partial L}{\partial Y_{aj}} X_{bj}.$$

The right-hand side is the inner product of row $a$ of $\partial L / \partial Y$ with row $b$ of $X$. Folding it into a matrix product requires transposing $X$ so that row $b$ becomes column $b$:

$$\boxed{\;\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \, X^T.\;}$$

Shape check: $(m \times k)(k \times n) = (m \times n)$, matching $W$.

Solving $\partial L / \partial X$. Same drill for $X_{ab}$. In $Y_{ij} = \sum_p W_{ip} X_{pj}$, $X_{ab}$ appears only when $p = a$ and $j = b$. It therefore affects only column $b$ of the output, the entries $Y_{ib}$:

$$\frac{\partial Y_{ij}}{\partial X_{ab}} = \begin{cases} W_{ia} & j = b, \\ 0 & \text{otherwise}. \end{cases}$$

Substitute into the chain rule:

$$\frac{\partial L}{\partial X_{ab}} = \sum_{i, j} \frac{\partial L}{\partial Y_{ij}} \frac{\partial Y_{ij}}{\partial X_{ab}} = \sum_{i=1}^{m} W_{ia} \, \frac{\partial L}{\partial Y_{ib}}.$$

This is the inner product of column $a$ of $W$ with column $b$ of $\partial L / \partial Y$. Transpose $W$ to fold it into a matrix product:

$$\boxed{\;\frac{\partial L}{\partial X} = W^T \, \frac{\partial L}{\partial Y}.\;}$$

Shape check: $(n \times m)(m \times k) = (n \times k)$, matching $X$.
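Both boxed results can be verified numerically before moving on. The sketch below assumes a hypothetical scalar loss $L = \sum_{ij} G_{ij} Y_{ij}$ with a fixed random $G$. It is chosen because its gradient with respect to $Y$ is exactly $G$, so $G$ plays the role of the upstream $\partial L / \partial Y$. A central-difference gradient over every entry then serves as the reference. The sizes and helper names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 3, 4, 5
W, X = rng.standard_normal((m, n)), rng.standard_normal((n, k))
G = rng.standard_normal((m, k))   # stands in for the upstream gradient dL/dY

def loss(W, X):
    # Hypothetical scalar loss L = sum(G * (W @ X)); its gradient w.r.t. Y = W @ X is exactly G
    return np.sum(G * (W @ X))

def numerical_grad(f, A, h=1e-6):
    # Central difference on every entry of A; the result has the same shape as A
    grad = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        E = np.zeros_like(A)
        E[idx] = h
        grad[idx] = (f(A + E) - f(A - E)) / (2 * h)
    return grad

dW = G @ X.T   # dL/dW = (dL/dY) X^T, shape (m, n) like W
dX = W.T @ G   # dL/dX = W^T (dL/dY), shape (n, k) like X

print(dW.shape, dX.shape)
print(np.allclose(dW, numerical_grad(lambda A: loss(A, X), W)),
      np.allclose(dX, numerical_grad(lambda A: loss(W, A), X)))
```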

Method 2: Matrix Differential and the Trace Trick

Index notation never lies, but for self-attention's $Y = \mathrm{softmax}(QK^T)V$ the cascading subscript sums become unmanageable. Matrix calculus offers an index-free alternative: differentials and the trace.

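In matrix calculus, the differential $\mathrm{d}L$ of a scalar $L$ with respect to a matrix $X$ and its gradient $\partial L / \partial X$ are linked by an inner product written as a trace. Since $\mathrm{Tr}(A^T B) = \sum_{ij} A_{ij} B_{ij}$, the relation below is just the matrix form of $\mathrm{d}L = \sum_{ij} (\partial L / \partial X_{ij})\, \mathrm{d}X_{ij}$: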

$$\mathrm{d}L = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial X}\right)^T \mathrm{d}X \right).$$

Combined with the cyclic property of the trace, $\mathrm{Tr}(ABC) = \mathrm{Tr}(CAB) = \mathrm{Tr}(BCA)$, this is enough to peel the gradient out of any expression for $\mathrm{d}L$.
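The identities this method relies on are easy to spot-check numerically. The sketch below uses random matrices of arbitrary compatible shapes, plus a hypothetical loss $L(X) = \sum_{ij} \sin X_{ij}$ whose gradient is $\cos X$. It verifies three things:

  • the cyclic property of the trace;
  • the reading of $\mathrm{Tr}(P^T Q)$ as the sum of entry-wise products;
  • the first-order relation $\mathrm{d}L \approx \mathrm{Tr}\big((\partial L/\partial X)^T \mathrm{d}X\big)$ for a small perturbation.

```python
import numpy as np

rng = np.random.default_rng(1)
A, B, C = rng.standard_normal((3, 4)), rng.standard_normal((4, 5)), rng.standard_normal((5, 3))

# Cyclic property: the products are 3x3, 5x5 and 4x4, yet their traces agree
traces = [np.trace(A @ B @ C), np.trace(C @ A @ B), np.trace(B @ C @ A)]

# Tr(P^T Q) equals the sum of entry-wise products of P and Q
P, Q = rng.standard_normal((3, 4)), rng.standard_normal((3, 4))
frobenius_ok = np.isclose(np.trace(P.T @ Q), np.sum(P * Q))

# First-order check of dL = Tr((dL/dX)^T dX) for the hypothetical L(X) = sum(sin(X))
X = rng.standard_normal((3, 4))
dX = 1e-6 * rng.standard_normal((3, 4))
actual_change = np.sum(np.sin(X + dX)) - np.sum(np.sin(X))
predicted_change = np.trace(np.cos(X).T @ dX)     # gradient cos(X) has the shape of X

print(np.allclose(traces, traces[0]), frobenius_ok, actual_change, predicted_change)
```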

Solving $\partial L / \partial W$. Hold $X$ fixed and perturb $W$: $\mathrm{d}Y = \mathrm{d}W \, X$. Substitute into the canonical form of $\mathrm{d}L$:

$$\begin{aligned} \mathrm{d}L &= \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y}\right)^T \mathrm{d}Y \right) = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y}\right)^T \mathrm{d}W \, X \right) \\ &= \mathrm{Tr}\!\left( X \left(\frac{\partial L}{\partial Y}\right)^T \mathrm{d}W \right) = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y} \, X^T\right)^T \mathrm{d}W \right). \end{aligned}$$

The first step on the second line uses the cyclic property $\mathrm{Tr}(AB) = \mathrm{Tr}(BA)$ to move $X$ to the front. The second uses $X A^T = (A X^T)^T$. Comparing with the canonical form $\mathrm{d}L = \mathrm{Tr}\big((\partial L / \partial W)^T \mathrm{d}W\big)$ and peeling off the wrapping gives

$$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \, X^T.$$

Solving $\partial L / \partial X$. Hold $W$ fixed and perturb $X$: $\mathrm{d}Y = W \, \mathrm{d}X$. Substitute:

$$\mathrm{d}L = \mathrm{Tr}\!\left( \left(\frac{\partial L}{\partial Y}\right)^T W \, \mathrm{d}X \right) = \mathrm{Tr}\!\left( \left(W^T \frac{\partial L}{\partial Y}\right)^T \mathrm{d}X \right).$$

The last step uses $(A^T B)^T = B^T A$ to regroup the product as the transpose of $W^T (\partial L / \partial Y)$. Compare with the canonical form:

$$\frac{\partial L}{\partial X} = W^T \, \frac{\partial L}{\partial Y}.$$

Both methods reach the same conclusion. Parameter gradient = upstream gradient × the transpose of the layer’s input; input gradient = the transpose of the layer’s parameter × upstream gradient — the shape constraint determines the transpose positions uniquely. This intuition is the reusable framework behind gradient derivations for attention, softmax, normalization, and the rest.

3. From Math to Algorithm: The Computational Graph and Backpropagation

3.1 Decomposing the Forward Pass: Building the DAG

The lowest-level abstraction shared by modern deep-learning frameworks (PyTorch / TensorFlow / JAX) is the computational graph. It is a directed acyclic graph (DAG) whose nodes are the variables (inputs, parameters $W$ and $b$, intermediates, output) and whose edges carry the local partial derivative (a Jacobian) from an upstream node to its downstream node.

Note that this differs from the “network architecture diagrams” in § 1.1 / § 2.1 of the main post. Those put weights on the edges (the weighted connection between neurons). A computational graph instead treats each weight as a node, and the edge from that node to a downstream intermediate carries the local partial $\partial z / \partial W$, not the weight itself.

Take the simplest regression loss, $L = (\sigma(Wx + b) - y)^2$. With a vector output, read the square as the squared norm $\lVert a - y \rVert^2$. Split it into the intermediates $z = Wx + b$, $a = \sigma(z)$, and $L = (a - y)^2$. The corresponding computational graph is below.

```mermaid
graph LR
  x((x)) --> z[z]
  W((W)) --> z
  b((b)) --> z
  z --> a[a]
  a --> diff["a - y"]
  y((y)) --> diff
  diff --> L[L]
```

Each edge carries a local derivative:

  • $\partial a / \partial z = \mathrm{diag}(\sigma'(z))$;
  • $\partial L / \partial a = 2(a - y)$;
  • the parameter edge $\partial z / \partial W$ which, by the linear-layer result of § 2.2, acts as “multiply the upstream gradient by $x^T$” (informally, $\partial z / \partial W = x^T$).

To compute $\partial L / \partial W$, multiply the local derivatives along the path $L \to a \to z \to W$ using the chain rule:

$$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial W} = \big(2(a - y) \odot \sigma'(z)\big)\, x^T.$$

For a scalar output this is simply $2(a - y)\,\sigma'(z)\,x^T$. For a vector output, $2(a - y) \odot \sigma'(z)$ is a column vector, and the result is an outer product with the same shape as $W$.

Backpropagation isn’t a new differentiation rule — it’s the chain rule executed in reverse topological order over a DAG using local derivatives. Any function expressible as a DAG admits this procedure; all the framework has to do is remember the local-derivative operator on each edge.
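A toy define-by-run implementation shows how little machinery the procedure needs. This is a minimal scalar sketch, not how any real framework is structured. The `Value` class and its per-op `vjp` closures are hypothetical, and only `+`, `-`, `*`, and sigmoid are supported. That is enough for the scalar-output case of $L = (\sigma(wx + b) - y)^2$.

```python
import math

class Value:
    """Hypothetical minimal scalar node: stores its parents and a local VJP rule."""
    def __init__(self, data, parents=(), vjp=None):
        self.data, self.parents, self.vjp, self.grad = data, parents, vjp, 0.0

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), lambda g: (g, g))

    def __sub__(self, other):
        return Value(self.data - other.data, (self, other), lambda g: (g, -g))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other),
                     lambda g: (g * other.data, g * self.data))

    def sigmoid(self):
        s = 1.0 / (1.0 + math.exp(-self.data))
        return Value(s, (self,), lambda g: (g * s * (1.0 - s),))

    def backward(self):
        # Topologically sort the DAG (inputs before outputs), then walk it in reverse
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0                          # dL/dL = 1
        for node in reversed(order):
            if node.vjp is not None:
                for parent, g in zip(node.parents, node.vjp(node.grad)):
                    parent.grad += g             # sum contributions where paths merge

# L = (sigmoid(w*x + b) - y)^2 with scalar w, x, b, y
w, x, b, y = Value(0.5), Value(2.0), Value(-0.3), Value(1.0)
a = (w * x + b).sigmoid()
diff = a - y
L = diff * diff
L.backward()
print(w.grad, b.grad)   # should equal 2(a - y) sigma'(z) x and 2(a - y) sigma'(z)
```

`backward` performs exactly the two steps described in this section. It first orders the DAG topologically. It then walks the DAG in reverse, applying each node's local derivative to the gradient arriving from downstream and summing contributions where paths merge (as in `diff * diff`).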

3.2 The Dimensionality Disaster: Why You Can’t Just Multiply Jacobians

Apply the § 2.1 cost to a deep network. For a 10-layer, 1000-wide network, multiplying ten $1000 \times 1000$ layer Jacobians explicitly takes roughly $10^{10}$ multiply-adds (nine products at $10^9$ each), and each dense layer Jacobian occupies $n^2 = 10^6$ floats. That is for a single sample. The activation Jacobians depend on the input, so a batch needs one such product chain per sample, and the matrix-matrix compute quickly becomes unaffordable.

The escape route lies in the structure of those Jacobians. From § 1.2:

  • An element-wise activation has $\partial y / \partial x = \mathrm{diag}(\sigma'(x))$, with only $n$ nonzeros instead of $n^2$.
  • A linear layer's Jacobian $\partial y / \partial x = W$ is dense but already in memory, so it requires no extra storage.

Because the activation Jacobian is diagonal, backprop can pass a gradient through an activation in $O(n)$ instead of $O(n^2)$.

A “fully coupled” activation such as softmax is different: every output depends on every input, so its Jacobian is dense, and materializing it costs $O(n^2)$ storage and compute. Even then a structured backward rule avoids the blow-up. Softmax's VJP is $s \odot (v - \langle v, s \rangle)$ with $s = \mathrm{softmax}(x)$, which is $O(n)$.
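The softmax case can be checked directly. The sketch below (NumPy, arbitrary random input) compares two ways of computing the same product:

  • the dense route, building $J = \mathrm{diag}(s) - s s^T$ and computing $vJ$;
  • the structured $O(n)$ rule $s \odot (v - \langle v, s \rangle)$.

The function names are illustrative.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())                 # shift by the max for numerical stability
    return e / e.sum()

def softmax_jacobian(x):
    s = softmax(x)
    return np.diag(s) - np.outer(s, s)      # dense n x n: O(n^2) memory and compute

def softmax_vjp(x, v):
    s = softmax(x)
    return s * (v - v @ s)                  # O(n); the Jacobian is never materialized

rng = np.random.default_rng(0)
x, v = rng.standard_normal(6), rng.standard_normal(6)
print(np.allclose(v @ softmax_jacobian(x), softmax_vjp(x, v)))
```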

§ 4 pushes this structural observation to its limit: never explicitly store or construct $J$; only ask the operator “given upstream $v$, return $v \cdot J$”.

4. The Core of Modern Autodiff: Vector-Jacobian Products

4.1 Forward Mode (JVP) vs Reverse Mode (VJP)

View the § 2.1 product $J_L \cdot J_{L-1} \cdots J_1$ as a chain of matrix multiplications. There are two ways to evaluate it.

Forward mode (JVP, Jacobian-vector product): pick an input direction $v \in \mathbb{R}^n$, typically a standard basis vector $e_i$. Accumulate from the input side, starting with the rightmost factor: $J_1 v \to J_2 (J_1 v) \to \cdots$, arriving at $\frac{\partial y}{\partial x}\, v \in \mathbb{R}^m$. The name's order (Jacobian, then vector) matches what it computes, $J \cdot v$. Each step is a matrix-vector product, and the cost of the full Jacobian grows with the input dimension: recovering it requires $n$ JVPs, one per basis vector.

Reverse mode (VJP, vector-Jacobian product): take an upstream-gradient row vector $v \in \mathbb{R}^{1 \times m}$ (in backprop, $v = \partial L / \partial y$). Accumulate from the output side, starting with the leftmost factor: $v J_L \to (v J_L) J_{L-1} \to \cdots$, arriving at $v \, \frac{\partial y}{\partial x} \in \mathbb{R}^{1 \times n}$. The name's order (vector, then Jacobian) matches what it computes, $v \cdot J$. Each step is a vector-matrix product, and the cost of the full Jacobian grows with the output dimension: recovering it requires $m$ VJPs.

In deep learning the loss is scalar ($m = 1$) while parameters reach $10^8$ ($n$ huge). Reverse mode therefore recovers the full gradient in a single VJP, whereas forward mode would need $10^8$ JVPs. This is the fundamental reason PyTorch and JAX default to reverse mode for gradients. At GPT-3 scale (175 B parameters, scalar loss), reverse mode finishes the gradient in one backward sweep, while forward mode would need $1.75 \times 10^{11}$ JVPs, which is simply impractical.

The roles flip when the input dimension is small relative to the output dimension; there, forward mode is the cheaper one. Forward mode also earns its keep in second-order quantities such as Hessian-vector products, which are typically computed forward-over-reverse (a JVP of the reverse-mode gradient).
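The column-versus-row picture can be checked on a toy chain. For illustration only, the sketch below stores three small layer Jacobians explicitly; a real AD system would never materialize them. It recovers the same $m \times n$ end-to-end Jacobian two ways:

  • $n$ JVPs with the input basis vectors give its columns;
  • $m$ VJPs with the output basis vectors give its rows.

The layer sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5, 3
# Hypothetical chain x (n) -> h1 (4) -> h2 (4) -> y (m) with explicit layer Jacobians
J1, J2, J3 = rng.standard_normal((4, n)), rng.standard_normal((4, 4)), rng.standard_normal((m, 4))

def jvp(v):
    # Forward mode: start from the input side (rightmost factor J1)
    return J3 @ (J2 @ (J1 @ v))

def vjp(u):
    # Reverse mode: start from the output side (leftmost factor J3)
    return ((u @ J3) @ J2) @ J1

J_fwd = np.column_stack([jvp(e) for e in np.eye(n)])   # n JVPs -> columns
J_rev = np.vstack([vjp(e) for e in np.eye(m)])         # m VJPs -> rows
print(J_fwd.shape, J_rev.shape, np.allclose(J_fwd, J_rev))
```

With a scalar loss, $m = 1$ and the reverse route needs a single call, whereas the forward route still needs one call per input dimension.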

4.2 The VJP Workflow Abstraction

A VJP is defined as follows: given an upstream row vector $v \in \mathbb{R}^{1 \times m}$ and a function $y = f(x)$ with Jacobian $J_f = \partial f / \partial x \in \mathbb{R}^{m \times n}$, compute $v \cdot J_f \in \mathbb{R}^{1 \times n}$. Under the chain rule, propagating the gradient from the loss $L$ backward to any intermediate $x_{l-1}$ is exactly a VJP recursion:

$$\frac{\partial L}{\partial x_{l-1}} = \frac{\partial L}{\partial x_l} \cdot \frac{\partial x_l}{\partial x_{l-1}} = \underbrace{\frac{\partial L}{\partial x_l}}_{v_l} \cdot J_{f_l}.$$

The key point: this step does not require explicitly constructing $J_{f_l}$. A routine that, given $v$, returns $v \cdot J_{f_l}$ is enough.

  • For an activation layer with $J_{f_l} = \mathrm{diag}(\sigma'(x_{l-1}))$, the VJP collapses to the element-wise product $v \odot \sigma'(x_{l-1})$, which is $O(n)$.
  • For a linear layer with $J_{f_l} = W$, the VJP is $v \cdot W$, and it reuses the $W$ already in memory.

The skeleton of reverse-mode automatic differentiation follows the § 2.2 column-vector convention: $x_l \in \mathbb{R}^{n_l}$, and $g$ is a column-vector gradient with the same shape as $x_l$.

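```python
import numpy as np

def sigma(z):                          # assumed activation: logistic sigmoid
    return 1.0 / (1.0 + np.exp(-z))

def sigma_prime(z):
    s = sigma(z)
    return s * (1.0 - s)

def forward_backward(Ws, bs, x, y_true):
    # Forward pass: cache activations needed by backward
    caches, x_prev = [], x
    for W_l, b_l in zip(Ws, bs):
        z_l = W_l @ x_prev + b_l
        caches.append((x_prev, z_l))
        x_prev = sigma(z_l)
    L_value = np.sum((x_prev - y_true) ** 2)    # assumed loss: squared error

    # Backward pass: propagate VJPs from the loss down to layer 1
    g = 2.0 * (x_prev - y_true)        # initial upstream gradient dL/dx_L, shape (m,)
    grads = []
    for W_l, (x_prev, z_l) in zip(reversed(Ws), reversed(caches)):
        g = g * sigma_prime(z_l)       # VJP through activation, O(m)
        dW_l = g[:, None] @ x_prev[None, :]   # outer product, shape (m, n)
        db_l = g
        grads.append((dW_l, db_l))
        g = W_l.T @ g                  # VJP through linear layer, O(m * n)
    return L_value, grads[::-1]        # [(dW_1, db_1), ..., (dW_L, db_L)]
```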

`dW_l = g[:, None] @ x_prev[None, :]` is the batch-size-1 form of the § 2.2 result $\partial L / \partial W = (\partial L / \partial Y) X^T$. With the column-vector gradient $g$ and the column-vector input $x_{\text{prev}}$, the outer product produces the $m \times n$ parameter gradient. `W_l.T @ g` is the corresponding $\partial L / \partial X = W^T (\partial L / \partial Y)$.

Each backward layer does only two things: it propagates the upstream gradient through the current layer via VJPs, and it combines the propagating $g$ with the cached $x_{\text{prev}}$ to form the parameter gradient $\mathrm{d}W_l$. Reverse-mode AD matches the forward pass in compute cost up to a small constant factor. But it pays a memory cost equal to caching every forward activation, and that is exactly where the next section's engineering tradeoffs start.

5. Engineering Tradeoffs in Modern Frameworks

5.1 The Compute-Memory Tradeoff

Reverse-mode AD's compute budget is friendly (same order as the forward pass); its memory budget is another matter. For backward to produce each layer's parameter gradient, forward must cache each layer's input $x_{l-1}$, and sometimes $z_l$ too, to avoid recomputing $\sigma'$. Written as accounting:

$$\text{activation memory} \approx L \cdot B \cdot S \cdot d \cdot \text{bytes per scalar},$$

where $L$ is the number of layers, $B$ the batch size, $S$ the sequence length (drop this factor for non-sequence inputs), and $d$ the per-layer feature width.

Take a 12-layer transformer with $d = 768$, batch 32, sequence length 512, in fp32. Caching just one $d$-wide tensor per layer already takes $12 \times 32 \times 512 \times 768 \times 4\,\text{B} \approx 600\,\mathrm{MB}$. A real transformer block holds roughly 5-10 such caches (attention intermediates, MLP intermediates, residuals), which easily sums to several GB.

The bottleneck in training deep networks is therefore usually not parameter memory but the activation cache that reverse-mode AD demands, which grows linearly with depth, batch size, and sequence length.
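Here is the arithmetic above as a tiny helper. It assumes exactly the post's simplified accounting: one $d$-wide tensor with $B \cdot S$ rows per cached slot, `tensors_per_layer` slots per layer (the post's rough 5-10 range), and fp32 by default. Treat the result as an order-of-magnitude estimate.

```python
def activation_bytes(layers, batch, seq_len, width, bytes_per_scalar=4, tensors_per_layer=1):
    # L * B * S * d * bytes, times an assumed number of cached tensors per layer
    return layers * batch * seq_len * width * bytes_per_scalar * tensors_per_layer

one = activation_bytes(12, 32, 512, 768)
print(one / 1e6)                                    # one d-wide tensor per layer, in MB
for k in (5, 10):
    print(k, activation_bytes(12, 32, 512, 768, tensors_per_layer=k) / 1e9)   # in GB
```

The first line works out to about 604 MB, and the 5-10 range gives roughly 3-6 GB.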

Gradient checkpointing (Chen et al., 2016) trades time for memory. During forward, keep activations only at about $\sqrt{L}$ checkpoint positions and drop the rest. During backward, when execution reaches a segment with no cached activations, re-run forward locally from the nearest upstream checkpoint to rebuild them.

The price is roughly one extra forward pass. If backward costs about twice as much as forward, that is about $33\%$ more compute per training step. In exchange, activation memory drops from $O(L)$ to $O(\sqrt{L})$. At large-model scale, gradient checkpointing is the most common activation-memory mitigation, and choosing which layers to checkpoint becomes an engineering question of its own.
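A toy simulation of the $\sqrt{L}$ scheme makes the trade visible. The sketch assumes a chain of hypothetical layers $h_l = \sigma(W_l h_{l-1})$ and the loss $L = \sum h_L$, so the initial upstream gradient is all ones. It keeps a checkpoint every $k = \lfloor\sqrt{L}\rfloor$ layers and re-runs the forward pass segment by segment during backward. It reports how many activations were held at once (a simple count) and how many layer evaluations were recomputed. This illustrates the idea only; it is not the implementation from Chen et al. or any framework's checkpointing API.

```python
import math
import numpy as np

def layer(W, h):
    # Hypothetical layer: h_next = sigmoid(W h)
    return 1.0 / (1.0 + np.exp(-(W @ h)))

def layer_vjp(W, h_in, h_out, g):
    # Given g = dL/dh_out, return (dL/dh_in, dL/dW); sigmoid' = s (1 - s)
    gz = g * h_out * (1.0 - h_out)
    return W.T @ gz, np.outer(gz, h_in)

def grads_full(Ws, x):
    acts = [x]                               # cache every activation: L + 1 arrays
    for W in Ws:
        acts.append(layer(W, acts[-1]))
    g = np.ones_like(acts[-1])               # assumed loss L = sum(h_L)
    dWs = [None] * len(Ws)
    for l in range(len(Ws) - 1, -1, -1):
        g, dWs[l] = layer_vjp(Ws[l], acts[l], acts[l + 1], g)
    return dWs, len(acts)

def grads_checkpointed(Ws, x):
    n_layers = len(Ws)
    k = max(1, math.isqrt(n_layers))         # segment length ~ sqrt(L)
    checkpoints = {0: x}                     # layer index -> stored input activation
    h = x
    for l, W in enumerate(Ws):               # forward: keep only every k-th activation
        h = layer(W, h)
        if (l + 1) % k == 0 and l + 1 < n_layers:
            checkpoints[l + 1] = h
    g = np.ones_like(h)
    dWs = [None] * n_layers
    peak, recomputed = 0, 0
    for s in sorted(checkpoints, reverse=True):
        e = min(s + k, n_layers)
        seg = [checkpoints[s]]               # re-run forward inside this segment only
        for l in range(s, e):
            seg.append(layer(Ws[l], seg[-1]))
            recomputed += 1
        peak = max(peak, len(checkpoints) + len(seg))
        for l in range(e - 1, s - 1, -1):    # backward through the rebuilt segment
            g, dWs[l] = layer_vjp(Ws[l], seg[l - s], seg[l - s + 1], g)
    return dWs, peak, recomputed

rng = np.random.default_rng(0)
Ws = [0.5 * rng.standard_normal((8, 8)) for _ in range(16)]
x = rng.standard_normal(8)
full, n_cached = grads_full(Ws, x)
ckpt, peak, recomputed = grads_checkpointed(Ws, x)
print(n_cached, peak, recomputed, all(np.allclose(a, b) for a, b in zip(full, ckpt)))
```

With 16 layers, full caching holds 17 activations. The checkpointed version counts a peak of 9 (4 checkpoints plus a 5-entry segment buffer) and recomputes 16 layer evaluations, i.e. one extra forward pass. Both versions produce the same gradients.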

5.2 Dynamic vs Static Graphs: The Trace Question

Reverse-mode AD also needs to know “what the graph looks like”. The two mainstream lines differ on when the graph is built.

Dynamic graphs (eager + autograd, the PyTorch line): forward runs ordinary user-written Python, and each tensor op records itself into a transient graph as it executes. On `loss.backward()`, the framework walks that graph in reverse and runs the VJPs.

  • Strengths: control flow (`if`, `while`, Python exceptions) is unrestricted, and debuggers, breakpoints, and `print` work as-is.
  • Weaknesses: the graph is rebuilt on every forward pass, so a compiler never sees the full op sequence ahead of time and has no opportunity for cross-op global optimization.

Static graphs (trace + jit, the JAX / TensorFlow 2 line): trace the function once, running it on abstract placeholder values, into a static graph. Then hand the graph to a compiler (XLA) for global op fusion, memory pre-allocation, and parallel scheduling; backward then runs against the compiled graph.

  • Strengths: op fusion, kernel merging, predictable inference latency.
  • Weaknesses: Python control flow that branches on traced values cannot be traced. In JAX it must be rewritten with explicit control-flow primitives such as `lax.scan` and `lax.cond`. Debugging looks more like working in a compiler stack.
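A toy tracer shows why value-dependent branching breaks tracing. The sketch below is hypothetical and far simpler than JAX's real tracers. It runs the function on a placeholder that records each operation into a list, and the placeholder refuses to be converted to `bool`, because it has no concrete value for a Python `if` to inspect.

```python
class Tracer:
    """Hypothetical placeholder value: records ops instead of computing them."""
    def __init__(self, name, graph):
        self.name, self.graph = name, graph

    def _record(self, op, other):
        out = Tracer(f"t{len(self.graph)}", self.graph)
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        self.graph.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    def __gt__(self, other):
        return self._record("gt", other)

    def __bool__(self):
        # A Python `if` needs a concrete True/False, which a placeholder does not have
        raise TypeError(f"cannot branch on traced value {self.name}")

def trace(fn):
    graph = []
    fn(Tracer("x", graph))
    return graph

def straight_line(x):
    return x * x + 1.0

def branchy(x):
    if x > 0:              # value-dependent branch
        return x * 2.0
    return x + 1.0

print(trace(straight_line))
try:
    trace(branchy)
except TypeError as err:
    print("trace failed:", err)
```

The straight-line function yields the recorded op list `[('t0', 'mul', 'x', 'x'), ('t1', 'add', 't0', '1.0')]`, while `branchy` fails at the `if`. This is why value-dependent branches have to be expressed with graph-level primitives such as `lax.cond` instead.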

Eager defers the backprop graph to runtime; trace pulls it forward to compile time — the choice determines where you pay the cost: control-flow freedom, compiler optimization headroom, or inference latency. PyTorch 2.x with torch.compile adds a hybrid path (“run eager, then trace and compile the hot region”); JAX with jit moves the decorated function fully to the static-graph side; both ecosystems are converging from their respective extremes.

6. Wrapping Up

Backpropagation is the chain rule turned into an algorithm, applied systematically in reverse topological order over a DAG. Starting from “why do we need AD” (§ 0), this post walked through:

  • multidimensional calculus (§ 1);
  • the matrix chain rule (§ 2, with two derivations of $Y = WX$);
  • the computational graph and backpropagation (§ 3);
  • VJPs and reverse-mode AD (§ 4);
  • the memory accounting and graph-mode tradeoffs in real frameworks (§ 5).

A few points were not foregrounded in the body but are worth knowing before you reach for these tools:

  • A layout convention must hold end-to-end. Numerator layout (this post) writes the chain rule as a right multiply; denominator layout writes it as a left multiply — both are valid, but mixing them anywhere leaves a systematic transpose error in the gradient formulas. Pick one for code, derivations, and comments alike.
  • Shape-matching is both a check and a shortcut. Any gradient derivation — even one as elaborate as attention — can be verified by “gradient shape = parameter shape”; when uncertain about a transpose, writing shapes before forms is faster than grinding through indices.
  • VJP replacing explicit Jacobians is the core of AD. Frameworks never materialize the full $J_f$ in memory; what they cache is forward activations, and backward computes $v \cdot J_f$ on demand. This is the defining difference between PyTorch / JAX and hand-derived gradients, and it is why training remains feasible even when a network's Jacobian dimensions would otherwise be ruinous.
  • Activation memory grows linearly with depth, batch size, and sequence length. In deep-network training, the $O(L \cdot B \cdot S \cdot d)$ activation cache (depth × batch × sequence length × width) usually becomes the limit before parameter memory or compute does. Gradient checkpointing is the standard mitigation, trading roughly $33\%$ extra compute for an $O(L) \to O(\sqrt{L})$ memory reduction.
  • Reverse mode is not the only choice. When the input dimension is small relative to the output dimension, forward-mode JVP is the more economical operator. Second-order quantities such as Hessian-vector products are typically computed forward-over-reverse (a JVP of the reverse-mode gradient). Deep learning's “scalar loss + huge parameter count” is just the extreme corner case where reverse mode wins overwhelmingly.