Attention 的核心公式只有一行,但这行公式同时压缩了序列建模、矩阵运算、概率归一化和 GPU 执行模型。直接从 QKTQK^Tdk\sqrt{d_k}、Softmax 开始推导,很容易把它看成一组需要背诵的符号,而不是一组被问题逼出来的设计。Attention 的前置知识不是概念清单,而是一条从「序列如何表示」到「关系如何并行计算」的约束链

本文介绍 Attention 正文之前需要补齐的背景:传统序列模型为什么受限,文本如何变成张量,内积为什么能表示相关,Softmax 如何把打分变成可微选择,深层网络为什么需要 Embedding、Residual Connection 与 LayerNorm,最后再说明 GPU 的内存层级如何影响 Attention 的实现。这篇文章不推导完整 Attention,而是给下一篇的 Scaled Dot-Product Attention 公式建立数学和工程语境

本文写在 深度学习基础:从感知机到深层训练反向传播与自动微分:从雅可比到计算图到 VJP 之后。前两篇分别回答了「神经网络为什么能表达复杂函数」和「梯度如何高效计算」,本文把视角换到序列数据:输入长度不固定、元素之间存在顺序关系、任意两个位置之间都可能存在依赖。当输入从定长向量变成序列,模型的核心问题就从函数拟合扩展为信息如何保存、传递和并行读取

1. 序列建模的固定维度困境

MLP 最自然的输入是固定长度向量,例如一条样本的数值特征、一个展平后的图像块,或者一个已经聚合好的统计指标。文本、代码、日志和用户行为序列不是这种形态:它们长度不同,顺序有意义,后一个元素经常依赖前面很远的位置。固定维度网络遇到序列数据时,第一个问题不是模型够不够深,而是输入如何进入同一套张量接口

最直接的处理方式是截断、补齐或聚合。截断把超长输入裁掉,代价是丢失尾部信息;padding 把短序列补到同一长度,代价是引入大量无效计算;平均池化把所有 token 合成一个向量,代价是顺序和局部结构被抹平。这些方案都把「变长序列」改造成「定长向量」,但同时也把序列里最有价值的位置信息和依赖结构弱化了

循环神经网络(RNN)给出了一个更自然的抽象:模型按时间步读取输入,每一步维护一个隐藏状态 hth_t。最简形式可以写成:

ht=f(ht1,xt)h_t = f(h_{t-1}, x_t)

这里 xtx_t 是第 tt 个输入,ht1h_{t-1} 是前一时刻留下的历史摘要,hth_t 是读入当前 token 后的新摘要。一个 vanilla RNN 通常把 ff 写成:

ht=ϕ(Wxxt+Whht1+bh)h_t = \phi(W_x x_t + W_h h_{t-1} + b_h)

其中 WxW_x 作用在当前输入上,WhW_h 作用在上一隐藏状态上,bhb_h 是偏置,ϕ\phi 可以是 tanh\tanh 或 ReLU 这类非线性函数。RNN 的关键设计是把任意长度的历史压缩进一个不断更新的状态向量,并让同一组参数 Wx,Wh,bhW_x, W_h, b_h 在所有时间步复用

设输入维度为 DxD_x、隐藏状态维度为 DhD_h,常见列向量记法下各项维度是:

符号维度含义
xtx_tDx×1D_x \times 1tt 个输入向量
ht1h_{t-1}Dh×1D_h \times 1上一时间步隐藏状态
WxW_xDh×DxD_h \times D_x把输入映射到隐藏状态空间
WhW_hDh×DhD_h \times D_h把上一隐藏状态映射到新的隐藏状态空间
bhb_hDh×1D_h \times 1隐藏状态偏置
hth_tDh×1D_h \times 1当前隐藏状态

因此 WxxtW_x x_tWhht1W_h h_{t-1} 都是 Dh×1D_h \times 1,可以相加后再过非线性函数。DxD_x 由输入表示决定,DhD_h 是模型选择的隐藏状态容量;RNN 把每个时间步的输入都写入同一个 DhD_h 维状态空间

如果任务需要每个时间步都有输出,可以再从隐藏状态读出:

yt=g(Wyht+by)y_t = g(W_y h_t + b_y)

这里的 Wy,byW_y, b_y 也通常在时间步之间共享;如果只做整段序列分类,也可以只读取最后的 hTh_T时间展开图里的多个 ff 不是多层各自独立的权重,而是同一个 RNN 单元在不同时间步上的重复调用

把同一个递推式沿时间展开,就能直接看到这条依赖链。输出 yty_t 可以在每个时间步读出,也可以只在末尾读出,具体取决于任务;这里讨论的瓶颈主要在隐藏状态路径上。每个时间步复用同一个转移函数 ff,隐藏状态则沿序列方向把信息继续传给下一步

Compact and unrolled recurrent neural network A compact recurrent neural network cell with a hidden-state feedback loop on the left, and the same cell unfolded over time on the right. ht-1 ht yt f xt h0 h1 h2 h3 hT y1 y2 y3 yT f f f f x1 x2 x3 xT ...
RNN 单元及其沿时间展开后的形式:同一个转移函数会在每个时间步复用。

这个设计解决了变长输入接口,却引入了新的表示瓶颈。一个长度为 NN 的序列会被逐步折叠进 h1,h2,,hNh_1, h_2, \ldots, h_N,越早出现的信息需要经过越多次状态更新才能影响最后输出;如果中间每一步都要把新信息写入同一个容量有限的状态向量,早期信息就容易被覆盖或稀释。RNN 的隐藏状态既承担记忆职责,又承担当前表示职责,这种双重角色让长程依赖天然困难

梯度传播也面对同一条长链。损失从后往前传回第 tt 个位置时,需要连续穿过 hNhN1hth_N \to h_{N-1} \to \cdots \to h_t 的递推路径;每一步都会乘上局部雅可比,反复相乘后容易出现梯度消失或梯度爆炸。长程依赖在前向里表现为信息被压缩,在反向里表现为梯度要穿过过长的乘法链

工程上还有一个更直接的限制:hth_t 依赖 ht1h_{t-1},因此时间步之间存在严格数据依赖。即使一个 batch 里有很多样本,单条序列内部也很难把所有位置一次性展开成同一轮矩阵乘法。RNN 的递推结构让序列维度接近串行队列,GPU 的大规模并行能力无法被充分使用

Attention 想解决的正是这组问题:每个位置不再只依赖上一个隐藏状态,而是可以直接读取其他位置的信息;这种读取不是 query == key 这类硬匹配规则,而是先给所有位置算相关性分数,再用 Softmax 变成权重,最后按权重把其他位置的信息加权求和。从 RNN 到 Attention 的转变,本质上是从「逐步传递历史」改为「一次性构造位置之间的关系矩阵」

2. 从 token 到张量的表示路径

神经网络不能直接处理字符串,它处理的是数值张量。文本进入模型前通常先经过 tokenization,被切成 token;每个 token 再按词表映射成一个整数 token id,用这个 id 到 Embedding 表中取出对应的连续向量。Embedding 的作用是把离散符号接入连续空间,使后续层可以用矩阵乘法和梯度下降处理语言单位

可以把 Embedding 看成一个可训练的查表矩阵。设词表大小为 VV,向量维度为 DD,则 Embedding 矩阵 ERV×DE \in \mathbb{R}^{V \times D};token id ii 对应矩阵第 iiEiE_i。训练过程中,经常出现在类似上下文里的 token 会在参数更新中获得相近表示。Embedding 不是词典释义,而是一组会随任务目标一起学习的连续坐标

从实现角度看,这一步是按行 gather;从线性代数角度看,它等价于用 one-hot 向量乘 EE。如果 oiRVo_i \in \mathbb{R}^{V} 只在 token id ii 的位置为 1,那么 oiTE=Eio_i^T E = E_i。框架实际执行时不会物化 one-hot 向量,而是直接按 token id 取出 EE 的对应行;但 EE 仍然是模型参数,训练时会通过反向传播更新被查到的那些行。Embedding 表是一个普通的可训练权重,只是前向路径用索引代替了稠密矩阵乘法

进入 Attention 之前,我们通常把输入写成三维张量:

XRB×N×DX \in \mathbb{R}^{B \times N \times D}

其中 BB 是 batch size,NN 是 sequence length,DD 是 hidden size 或 embedding dimension。为了推导清晰,后续经常先忽略 batch 维度,把单条序列写成:

XRN×DX \in \mathbb{R}^{N \times D}

这个形状约定很重要:NN 表示有多少个位置,DD 表示每个位置用多少个特征描述

矩阵乘法在这里可以从两个层面理解。对单个 token 向量 xiRDx_i \in \mathbb{R}^{D} 来说,xiWx_i W 是一次线性映射,把它从原来的 DD 维特征空间映射到新的 DD' 维特征空间。对整条序列 XRN×DX \in \mathbb{R}^{N \times D} 来说,XWXW 则是把同一个线性映射同时应用到 NN 个 token 上。所以矩阵乘法在神经网络里既是批量计算,也是特征空间变换

一个线性层可以写成:

Y=XW,XRN×D,WRD×DY = XW,\qquad X \in \mathbb{R}^{N \times D},\quad W \in \mathbb{R}^{D \times D'}

输出 YRN×DY \in \mathbb{R}^{N \times D'} 保留了 NN 个位置,但每个位置的表示从 DD 维变成 DD' 维。线性层不会改变序列长度,它改变的是每个 token 所处的特征空间

这个观察会直接进入 Attention 正文。后续的 Query、Key、Value 都来自对同一个输入 XX 做不同线性投影;具体角色在下一篇展开。这里先抓住一点:线性投影不会改变 token 个数,只会改变每个 token 的表示空间

3. 内积相似度与可微寻址

两个向量的内积可以写成:

xixj=k=1Dxikxjkx_i \cdot x_j = \sum_{k=1}^{D} x_{ik}x_{jk}

它对应的几何恒等式是:

xixj=xixjcosθx_i \cdot x_j = \lVert x_i\rVert\,\lVert x_j\rVert\cos\theta

其中 θ\theta 是两个向量的夹角。如果两个向量都经过归一化,内积就等价于余弦相似度;在范数固定时,方向越一致,内积越大,正交时为 0,方向相反时为负。在连续表示空间里,内积提供了一种基于方向对齐程度的、简单、可微、硬件友好的相关性打分方式

这件事和传统 Key-Value 查询有明显差别。普通哈希表查询是硬匹配:key 相等就命中,不相等就失败;神经网络需要的是软匹配:某个位置和另一个位置不是绝对相关或绝对无关,而是有一个连续分数。Attention 的寻址方式不是 query == key,而是用内积给所有候选位置同时打分

对单条序列 XRN×DX \in \mathbb{R}^{N \times D} 来说,如果直接计算 XXTXX^T,结果是一个 N×NN \times N 的矩阵:

S=XXT,Sij=xixjS = XX^T,\qquad S_{ij} = x_i \cdot x_j

矩阵中第 ii 行表示第 ii 个 token 与所有 token 的相似度。N×NN \times N 关系矩阵把序列建模从「沿时间传递状态」改成了「显式计算任意两个位置之间的关系」

为什么不用欧氏距离或其他打分函数?从表达能力看,很多相似度都能工作;从工程实现看,dot product 可以批量写成 GEMM(General Matrix Multiply,GPU、BLAS、cuBLAS、深度学习框架都会把 GEMM 优化得非常彻底。只要一个操作能写成大规模矩阵乘法,通常就能高效利用 GPU 的并行计算能力)。Attention 选择内积,不只是数学上合理,也因为它能把全局关系计算压到高度优化的线性代数内核上

直接用 XXTXX^T 可以说明一件事:序列里的任意两个位置都能通过一次矩阵乘法得到相关性分数。但这还不是完整 Attention。真正的 Attention 会先把输入投影到用于打分和用于输出的不同表示空间;这些角色会在下一篇介绍 Q、K、V 时展开。这里需要先抓住的只是:内积可以把所有位置之间的相关性一次性写成矩阵乘法

下一篇再展开投影矩阵和 Multi-Head Attention 的细节。本文在这里先停在前置层面:只要能把 token 表示组织成矩阵,内积打分就能一次性覆盖所有位置对

4. Softmax 与可微选择

内积只能得到一组实数分数,还不能直接说明每个位置应该读取多少信息。Softmax 把一组实数映射成非负且和为 1 的权重:

softmax(zi)=ezijezjsoftmax(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}

分数大的位置得到更高权重,分数小的位置仍保留非零权重。Softmax 把硬选择改成了概率分布,使多个 token 可以按比例共同参与当前表示的构造

如下图所示,Softmax 在两种输入尺度下表现不同:分数差距适中时,多个位置都会保留权重;分数差距很大时,输出会接近 one-hot。Softmax 的输出永远归一化为权重分布,但分数尺度决定这个分布是分散还是尖锐

Softmax turns scores into normalized weights Two examples of Softmax: moderate scores become a distributed probability vector, while large score gaps become a nearly one-hot probability vector. scores z softmax(z) weights s exp + normalize sum = 1 exp + normalize sum = 1 moderate scores large score gap 1.2 2.0 0.6 0.27 0.60 0.13 1 6 0 0.01 0.99 0.00
Softmax 把原始分数转换成和为 1 的权重;分数差距越大,权重分布越尖锐。

在这篇前置文章里,先把 Softmax 理解成一个把分数变成权重的函数:输入是一行分数,输出是一行权重;权重越大,后续加权求和时对应项的贡献越大。本节先保留 Softmax 的两个关键点:归一化成权重、分数尺度会影响分布形状

实际实现不会直接计算 ezie^{z_i},因为指数函数很容易溢出。常见写法会先减去这一行的最大值:

softmax(zi)=ezimax(z)jezjmax(z)softmax(z_i) = \frac{e^{z_i - \max(z)}}{\sum_j e^{z_j - \max(z)}}

这个变换不改变 Softmax 结果,因为分子分母同时乘上了同一个常数因子 emax(z)e^{-\max(z)}。这种写法通常叫 stable softmax(数值稳定版 Softmax):它不是改变数学定义,而是让指数运算落在浮点数能稳定表示的范围内。stable softmax 的“stable”指计算过程更稳定,不是输出含义发生变化

Softmax 还有另一个和训练相关的问题:如果一行分数的差距过大,最大分数会拿走接近全部权重,输出接近 one-hot,梯度也会变小。这个原因可以通过求导看出来:

si=eziZ,Z=kezksizj=zj(eziZ)=δijeziZeziezjZ2=eziZ(δijezjZ)=si(δijsj)\begin{aligned} s_i &= \frac{e^{z_i}}{Z}, \qquad Z = \sum_k e^{z_k} \\ \frac{\partial s_i}{\partial z_j} &= \frac{\partial}{\partial z_j}\left(\frac{e^{z_i}}{Z}\right) \\ &= \frac{\delta_{ij} e^{z_i} Z - e^{z_i} e^{z_j}}{Z^2} \\ &= \frac{e^{z_i}}{Z}\left(\delta_{ij} - \frac{e^{z_j}}{Z}\right) \\ &= s_i(\delta_{ij} - s_j) \end{aligned}

其中 s=softmax(z)s = softmax(z)δij\delta_{ij} 是 Kronecker delta:当 i=ji=j 时取 1,否则取 0。关键在导数项本身:对角项是 si(1si)s_i(1 - s_i),非对角项是 sisj-s_i s_j。如果输出已经接近 one-hot,比如某一项 si=0.99s_i=0.99,那么对角项只有 0.99×0.01=0.00990.99 \times 0.01 = 0.0099;其他项接近 0 时,非对角项 sisjs_i s_j 也接近 0。如果分数太接近,Softmax 输出的权重会缺少明确偏向;如果分数差距过大,分布又会进入饱和区并削弱梯度

这正是 Scaled Dot-Product Attention 中 dk\sqrt{d_k} 的前置动机。设两个向量 q,kRdq,k \in \mathbb{R}^{d} 的各维独立、均值为 0、方差为 1,则内积为:

qk=r=1dqrkrq \cdot k = \sum_{r=1}^{d} q_r k_r

在独立同分布的简化假设下,每一项 qrkrq_r k_r 的均值为 0、方差为 1,dd 项相加后的方差约为 dd维度越高,未经缩放的内积分数方差越大,Softmax 越容易被推入饱和区

因此,后续公式中的缩放项不是装饰:

QKTdk\frac{QK^T}{\sqrt{d_k}}

在上述简化假设下,除以 dk\sqrt{d_k} 后,内积分数的方差会从约 dkd_k 拉回到约 1 的量级。真实模型里的 QQKK 不会严格满足独立同分布、均值 0、方差 1,但这个推导说明了缩放项要控制的量是什么。dk\sqrt{d_k} 缩放的核心作用是控制打分尺度,避免分数差距随着维度增大而过大,让 Softmax 在更稳定的数值区间里工作

5. 深层网络的三个稳定组件

Attention 本身只描述「如何从其他位置读取信息」,但 Transformer block 不是单独一个 Attention 算子。真正可训练的深层网络还需要输入表示、残差连接提供的直接传递路径和归一化机制配合。Embedding、Residual Connection 和 LayerNorm 是 Attention 能堆叠成深层模型的基础组件

5.1 Embedding:初始 token 表示

Embedding 前面已经出现过,它负责把 token id 映射成连续向量。这里再强调一次:Embedding 是模型参数,会被训练目标更新;它通常和 tokenizer 以及词表绑定,因为 token id 本质上是在词表中的行号。每个大模型通常都有自己的 Embedding 矩阵,这个矩阵在预训练中从初始化状态逐步学出来;即使两个模型结构相似,只要 tokenizer、训练数据或训练过程不同,最后得到的 Embedding 也通常不同。同一个 token 在不同模型里不一定有相同坐标,同一组坐标也不必对应人类能直接命名的语义维度。Embedding 提供的是可学习表示空间,而不是人工定义的语言知识表

Text becomes token ids and embedding vectors A left-to-right pipeline: raw text is split by a tokenizer into tokens, tokens are mapped to token ids, and token ids look up rows in the embedding table to produce vectors. raw text tokens token ids Embedding table E vectors "I like AI" string "I" "like" "AI" 120 5321 42 row indices select rows by id E[120] = [ ... ] E[5321] = [ ... ] E[42] = [ ... ] continuous vectors
文本先被 tokenizer 切成 token,再映射成 token id;token id 是 Embedding 表的行号,用来取出对应的连续向量。

5.2 Residual Connection:学习修正量

Residual Connection 的形式很短:

y=x+F(x)y = x + F(x)
Residual connection with an add-back path Input x flows through a sublayer F and also bypasses it through a direct path. The two paths are added to produce y equals x plus F of x. x F(x) sublayer + y x + F(x) direct path F path
Residual Connection 保留从输入到输出的直接相加路径;子层只需要学习修正量 F(x)。

数学上,F(x)F(x) 是子层学到的变换,xx 是原输入本身;Residual Connection 不只输出 F(x)F(x),而是输出 x+F(x)x + F(x)。如果某一层暂时学不到有用修正,可以让 F(x)F(x) 接近 0,这一层就近似变成 yxy \approx x,不会强行破坏已经有用的表示。Residual Connection 把每一层要学习的目标从「重新生成完整表示」改成了「学习相对输入的修正量」

直觉上,它像是在已有草稿上批注修改,而不是每一层都把草稿撕掉重写。没有残差连接时,信息必须一层层被重新加工,走到很深之后容易变形或丢失;有了 xx 这条直接相加路径,原表示至少有机会原样进入下一步,子层只负责补充、删除或调整一部分内容。这就是为什么残差连接能让深层模型更容易训练:模型不必每层都从零开始表达同一份信息

5.3 LayerNorm:按 token 稳定特征尺度

先看 BatchNorm 的思路:训练时,对某一层输出里的某个特征维度,从一个 mini-batch 里的多条样本上统计均值和方差,再用这组 batch 统计量做归一化。用简化记法写,设 ab,da_{b,d} 是第 bb 条样本在第 dd 个维度上的数值,也就是这一层输出张量里的一个元素;很多教材会把这个数叫作 activation(激活值)。BatchNorm 会按 batch 维度计算:

μd=1Bb=1Bab,dσd2=1Bb=1B(ab,dμd)2BN(ab,d)=γdab,dμdσd2+ϵ+βd\begin{aligned} \mu_d &= \frac{1}{B}\sum_{b=1}^{B} a_{b,d} \\ \sigma_d^2 &= \frac{1}{B}\sum_{b=1}^{B}(a_{b,d} - \mu_d)^2 \\ BN(a_{b,d}) &= \gamma_d \frac{a_{b,d} - \mu_d}{\sqrt{\sigma_d^2 + \epsilon}} + \beta_d \end{aligned}

在大模型 Transformer 里,一条样本通常是一段 token 序列;把多条样本组成 batch 后,可以粗略看成形状 B×N×DB \times N \times D,其中 BB 是 batch size,NN 是序列长度,DD 是 hidden dimension。不同样本原本的 token 数可以不一样,工程上可以通过 padding、attention mask、packing 等方式把它们组织成可计算的 batch;第 6 节会展开这些做法。上面的公式省略了序列位置维度,只强调 BatchNorm 的关键点:统计量来自 batch 里的多条样本。BatchNorm 的统计范围跨样本,因此它需要 batch 里的样本形态足够稳定

变长序列会让这个前提变麻烦:padding 解决的是“张量形状能不能对齐”,不是“统计量是否稳定”。如果归一化依赖当前 batch 的均值 μ\mu 和方差 σ2\sigma^2,这些统计量就会被当前这批样本的长度组成、padding 比例和 batch size 影响。在 Transformer 里,BatchNorm 的问题不是能不能把序列组成 batch,而是 batch 统计量是否适合作为稳定的归一化依据

Layer Normalization 来自 Layer Normalization 这篇工作。这里的 Layer 不是说跨很多层一起归一化,而是说对同一层里某个样本的输出向量做归一化;在 Transformer 里,就是对某一层中某个 token 的 hidden vector 做归一化。它要解决的问题是:深层网络里,每个 token 的隐藏向量会连续经过 Attention、MLP、残差相加等操作,数值尺度可能一层层漂移;有的维度变得很大,有的维度变得很小,后续子层就更难稳定处理。LayerNorm 的作用是把单个 token 的隐藏向量先拉回到相对稳定的尺度,再交给下一步计算

这里的隐藏向量可以理解成 token 在模型内部的当前表示。Embedding 向量是最开始的表示:token id 查 Embedding 表得到 eie_i,再加上位置相关信息后进入第一个 Transformer block;经过一层 Attention、MLP、Residual Connection 等操作后,同一个位置上的向量会被更新成新的表示。Embedding 向量是第 0 层输入,hidden vector 是模型内部某一层、某个位置上的当前 token 表示

具体地,给定某个位置的向量 xRDx \in \mathbb{R}^{D},LayerNorm 会在这个向量自己的 DD 个 hidden dimension 上计算均值和方差,然后进行归一化、缩放和平移:

LN(x)=γxμσ2+ϵ+βLN(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

公式中各项含义如下:

符号含义
xRDx \in \mathbb{R}^{D}单个 token 在当前层、当前子层入口处的隐藏向量
DDhidden dimension 的大小
μ\mu这个向量 DD 个维度的均值
σ2\sigma^2这个向量 DD 个维度的方差
ϵ\epsilon很小的常数,防止分母为 0
γ\gamma可训练的缩放参数,让模型可以重新调整每个维度的尺度
β\beta可训练的平移参数,让模型可以重新调整每个维度的偏移

LayerNorm 对每个 token 独立计算均值和方差,不需要从 batch 里估计稳定统计量;无论当前步处理 1 个 token 还是一批 token,它的定义都一样。在 Transformer block 里,LayerNorm 通常和 Residual Connection 成对出现,放在 Attention 子层和 MLP 子层的前后,常见变体包括 Pre-LN 和 Post-LN。LayerNorm 常用在 Transformer 的每个子层附近,用来稳定 hidden vector 的尺度,而不是改变 token 的长度或位置结构

这三个组件在 Transformer 中会和 Attention 反复组合。一个简化的 block 可以看成:输入先进入 Attention 子层,经过残差相加和归一化,再进入 MLP 子层,继续经过残差相加和归一化。Attention 负责跨位置读信息,MLP 负责逐位置做非线性变换,Residual 与 LayerNorm 负责让这些子层可以稳定堆叠

这里不展开 Pre-LN、Post-LN、激活函数选择和训练稳定性细节。本文只需要建立一个边界:Attention 是核心算子,但不是完整网络;真正的大模型把 Attention 放进一套可训练、可堆叠、可并行执行的结构里。把 Attention 和 Transformer block 区分开,有助于后续分别讨论数学公式和工程实现

6. 并行计算与 GPU 内存墙

6.1 Transformer 中的 batch:训练、prefill 与 decode

在 Transformer 训练中,batch 先出现在 token id 张量上。设一批里有 BB 条序列,padding 或 packing 后的共同长度是 NN,那么几个关键张量的形状可以粗略写成下面这样。训练时的 batch 本质上是把多条序列整理成同一套张量形状

  • input_ids: B × N:每个位置是一个 token id。
  • embedding: B × N × D:经过 Embedding 表查表后,每个 token id 变成一个 DD 维向量。
  • hidden: B × N × D:进入 Transformer block 之后,每一层内部的当前表示也通常叫 hidden states。

如果不同样本原本长度不同,简单做法是把短序列 padding 到同一个 NN,再用 attention mask 标记哪些位置是真实 token、哪些位置是 padding;预训练里也常把长文本流切成固定长度 block,或者把多段文本 packing 到同一段里,减少 padding 浪费。训练 batch 的目标,是把多条序列整理成规则张量,让 GPU 可以一次处理大量 token

Packing 之后并不是把样本边界丢掉。预训练通常会用 <eos> 这类分隔 token 标记文本结束,再做 next-token prediction;SFT 或指令微调里还会配合 attention mask、loss mask 等信息,控制哪些位置可以互相看到、哪些位置参与 loss。Packing 是训练数据组织方式的优化,不是把多个样本无条件混成一段文本

在 Attention 里,Query 可以粗略理解成“我现在要找什么信息”,Key 是“我可以被什么条件匹配到”,Value 是“如果我被匹配到,我贡献什么内容”。推理的第一段通常叫 prefill:模型先处理用户已经输入的 prompt,并在每一层 Attention 中为这些历史 token 计算 Key 和 Value;这些 Key/Value 会被保存在 KV Cache 里,供后续 decode 步骤复用。不同请求的 prompt 长度可能不同,简单实现仍然可以 padding 到同一长度再用 mask;高性能推理框架则会用 variable-length attention、paged attention 或类似机制,用元数据记录每个请求的真实长度,减少无效 padding。Prefill 阶段 batch 的对象是多条 prompt,它的难点是这些 prompt 的长度不一定一样

第二段是 decode:模型开始自回归生成,每一步通常为每个仍在生成的请求新增 1 个 token。假设当前 decode batch 里有 3 个请求,它们在 prefill 和前面若干 decode 步之后,已经积累了不同长度的 KV Cache:

请求 A: 已有 120 个历史 token 的 KV Cache
请求 B: 已有 900 个历史 token 的 KV Cache
请求 C: 已有 35 个历史 token 的 KV Cache

下一步生成时,每个请求都只新增 1 个 token,所以当前输入更像 Bactive×1×DB_{\text{active}} \times 1 \times D;但做 Attention 时,A 的新 token 要读 A 的 120 个历史 token,B 的新 token 要读 B 的 900 个历史 token,C 的新 token 要读 C 的 35 个历史 token。Decode 阶段 batch 的当前输入很短,但每条请求关联的历史缓存长度不同

当前输入: 3 × 1 × D
A 的新 token -> 读取 A 自己的 120 个历史 token
B 的新 token -> 读取 B 自己的 900 个历史 token
C 的新 token -> 读取 C 自己的 35 个历史 token

工程实现会用 position id、sequence length、slot mapping、paged KV Cache 等元数据,把“当前这一步的 token”与“每个请求自己的历史缓存”对应起来。Transformer 推理时也 batch,但这是变长请求的动态 batch,不是把所有请求都强行补成同一个巨大矩形

6.2 矩阵并行与 GPU 内存墙

Attention 相比 RNN 的一个重要优势,是它把序列内部的依赖关系组织成矩阵运算。RNN 的第 tt 步依赖第 t1t-1 步,序列维度难以并行;Attention 可以一次性构造 N×NN \times N 的关系矩阵,再通过矩阵乘法聚合信息。Attention 更适合 GPU,不只是因为模型效果好,也因为它把序列关系改写成了大规模并行线性代数

GPU 的强项是高吞吐矩阵计算。大量乘加可以分配到许多 SM(Streaming Multiprocessor,GPU 上的基本并行计算单元)和 Tensor Core(专门加速矩阵乘法的硬件单元)上执行,只要数据供应及时,硬件就能维持很高利用率。问题在于,计算单元和显存之间并不是同一种速度:HBM(High Bandwidth Memory,高带宽显存)容量大、带宽高,但访问仍然远慢于片上寄存器、shared memory 和缓存。GPU 性能经常不只由 FLOPs(Floating Point Operations,浮点运算次数)决定,还由张量位于哪个存储层级、以及数据在层级之间搬运多少次决定

粗略看,GPU 的存储层级按离计算单元的距离排列:寄存器最快、最小,通常服务于线程内部的临时值;shared memory 和缓存位于片上,容量较小但访问很快;HBM 是 GPU 的主要显存,用来存模型权重、激活、中间矩阵和 KV Cache;CPU 主机内存更远,通常不希望核心计算频繁访问。以 NVIDIA 官方页面中的 H100 SXM 为例,它有 80GB HBM3,GPU memory bandwidth 为 3.35TB/s;这些 HBM 是片外大显存,而寄存器、shared memory 和缓存仍然是更靠近计算单元的小容量工作区。越靠近计算单元,访问越快但容量越小;越远离计算单元,容量越大但搬运代价越高

寄存器 / shared memory / cache   片上,小容量,靠近 SM 和 Tensor Core
HBM3(H100 SXM 示例)             80 GB,3.35 TB/s,片外显存,存权重、激活和 KV Cache
CPU 主机内存                      更远,容量更大,但核心 kernel 不希望频繁访问

这就引出了 Compute Bound 与 I/O Bound 的区分。Compute Bound 表示瓶颈主要在乘加计算,增加算力能直接带来收益;I/O Bound 表示瓶颈主要在读写数据,计算单元可能在等待数据到达。当模型把大量中间矩阵写回 HBM 又读回来时,性能瓶颈可能从计算转移到内存带宽

roofline 模型用 arithmetic intensity 表达同一个判断:一个算子每搬运一个字节能执行多少浮点运算。如果某个算子相对于硬件的算力带宽比来说每个 FLOP 要搬运太多字节,它就会成为访存受限。Attention 的 N×NN \times N 分数矩阵和概率矩阵有足够多的数据搬运,因此朴素 kernel 即使写成矩阵乘法,也可能落在访存受限一侧

标准 Attention 会产生一个 N×NN \times N 的分数矩阵,再经过 Softmax 得到一个 N×NN \times N 的概率矩阵。序列长度翻倍时,这类中间结果按 N2N^2 增长;在长上下文场景下,内存读写开销会迅速变成主要成本。Attention 的数学公式是并行友好的,但朴素实现会制造巨大的中间矩阵 I/O

这也是 FlashAttention 这类工作的出发点。它不改变 Attention 的数学结果,而是通过 tiling 和 online softmax 减少对 HBM 的读写,把更多计算留在更快的片上存储附近完成。推理阶段的 KV Cache 则做了互补取舍:把历史 key 和 value 存在 HBM 里,避免每生成一个新 token 都重新计算它们。FlashAttention 与 KV Cache 的关键都不是重新定义 Attention,而是有意识地使用存储层级

因此,学习 Attention 不能只停在公式层面。公式解释了 QKTQK^T、Softmax 和 VV 的组合关系;系统实现还要回答这些矩阵如何被切分、搬运、缓存,以及长序列下哪些中间结果可以不落到 HBM。大模型里的 Attention 是数学算子,也是受内存层级强约束的系统组件

结语

这篇文章实际回答的是一个前置问题:Attention 公式为什么会长成现在这样?答案不是某个单独技巧,而是多组约束叠加后的结果。序列输入需要保留位置之间的关系,连续向量需要可微相似度,Softmax 需要受控的数值尺度,深层网络需要稳定组件,GPU 实现需要减少无效串行和内存往返。Attention 的公式短,是因为很多动机已经被压缩进每个符号的角色分工里

几条在进入正文前需要保持清楚的边界:

  • Attention 不是完整的 Transformer。Attention 只负责跨位置的信息读取;Transformer block 还包括 MLP、Residual Connection、LayerNorm、位置编码和 mask 等结构。
  • 内积打分不是语义理解本身。它只是连续空间中的相似度计算;语义能力来自数据、目标函数、参数规模和多层组合共同塑造的表示空间。
  • Softmax 的概率权重不是可解释性保证。权重高表示当前层当前 head 的聚合比例高,不等价于人类意义上的因果解释。
  • 矩阵化不等于没有代价。Attention 摆脱了 RNN 的时间步串行依赖,但 N×NN \times N 中间矩阵会把长序列成本转移到显存容量和带宽上。
  • 硬件优化不改变数学语义。FlashAttention 这类优化的重点是减少 HBM 往返;它让同一个 Attention 公式以更适合 GPU 的方式执行。

下一篇可以正式进入 Attention Is All You Need 中的 Scaled Dot-Product Attention:从 QQKKVV 的角色拆分开始,逐项推导 softmax(QKT/dk)Vsoftmax(QK^T / \sqrt{d_k})V,再把 mask、Multi-Head Attention、RoPE、KV Cache 和推理侧内存优化接到同一条线上。