
1. Attention 对于矩阵计算的优化
1.1 \(Q\)、\(K\)、\(V\) 矩阵融合成单个矩阵
对应的原始操作为:
# ...
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
# ...
q = self.w_q(x)
k = self.w_k(x)
v = self.w_v(x)
常规的优化操作为:
# ...
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# ...
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
1.2 多头矩阵融合成单个矩阵
对应的原始操作为:
assert config.n_embd % config.num_heads == 0
# --- Head 1 计算 ---
# x: (B, L, D) -> q1: (B, L, head_dim)
q1, k1, v1 = self.q1(x), self.k1(x), self.v1(x)
# --- 每个头内部进行 Attention 计算 ---
# --- Head 2 计算 ---
q2, k2, v2 = self.q2(x), self.k2(x), self.v2(x)
# --- 每个头内部进行 Attention 计算 ---
# --- 合并 (Concat) ---
# 将两个 (B, L, 64) 拼成 (B, L, 128)
combined = torch.cat([head1_out, head2_out], dim=2)
output = self.w_o(combined)
常规的优化操作为:
assert config.n_embd % config.num_heads == 0
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
# 变换过程: (B, L, D) -> (B, L, Heads, Head_Dim)
q = q.view(B, L, self.num_heads, self.head_dim)
k = k.view(B, L, self.num_heads, self.head_dim)
v = v.view(B, L, self.num_heads, self.head_dim)
# ...
# 变换: (B, L, Heads, Head_Dim) -> (B, Heads, L, Head_Dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# --- 融合的 Attention 计算 ---
context = context.transpose(1, 2)
context = context.contiguous().view(B, L, self.d_model)
output = self.w_o(context)
2. 底层数学逻辑
2.1 矩阵算子融合
矩阵融合其实非常简单,按照传统的做法,\(Q\)、\(K\)、\(V\) 分别对应三个独立的权重矩阵,且由于得到三个矩阵所用到的输入数据 \(x\) 都是同一个,因此如同 1.1 中操作的,每个输入都要与矩阵相乘:
- \(q = x W_q\)
- \(k = x W_k\)
- \(v = x W_v\)
此时,GPU 需要进行三次小规模的矩阵乘法,而融合算子的做法是将这三个独立矩阵横向(dim = 1)拼成一个大矩阵\(W_{qkv}\),此时列数变为原始的 3 倍,而 GPU 只需要进行一次乘法:
当 \(x\) 经过这个 Linear 层(如 1.1 所示的 nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias))的时候,实际进行的运算为:
为了进一步得到独立的 \(q, k, v\) 矩阵,所以需要在最后的维度上进行切分:q, k, v = self.c_attn(x).split(self.n_embd, dim=2)。由于 \(x\) 的实际 shape 为 [B, L, D],所以列是下标为 2 的维度,而不是下标为 1 的维度。
算子通常对应计算图的一个节点,粒度非常细。算子融合对应的操作就是将多个小节点“坍缩”成一个大节点,目的不是减少工作量而是减少对显存的访问次数。现代 GPU 的计算能力极强,但是访存带宽/数据传输速度往往是瓶颈。
2.2 多头注意力的优化
无论是哪种实现方式,多头注意力开始的数据维度均为 [B, L, D],多头注意力计算结束的数据维度均为 [B, L, D]。换句话说,每个头只能且必定分到均匀的一部分 embedding 空间,在计算结束之后再拼接和融合。
注意,为了更纯粹地集中在线性变换的矩阵融合上,下文忽略了 Scaled Attention Score 的计算,不影响推导的正确性。
2.2.1 Type 1 - 独立小矩阵
对于独立小矩阵实现的多头注意力,如 1.2 所示,我们一开始就声明了 \(H=2\) 个独立的小矩阵 [D, d](\(d = D // head\_num\)),并独立计算完了 Attention,此时,我们手里有两个独立的输出张量,shape 都是 [L, d]。我们以每个头中 \(Q\) 的权重矩阵 \(W_q\) 为例,回顾 Attention 前的处理:
- Head 1 输出 (\(O_1\)),\(XH_1\):
第 1 行是 Token 1 (\(x_1\))的结果,第 2 行是 Token 2 (\(x_2\))的结果;
- Head 2 输出 (\(O_2\)),\(XH_2\):
接着,我们会使用拼接操作:torch.cat([O_1, O_2], dim=-1),由于行数完全一样,在数值分布上,这相当于把 \(O_2\) 的每一行,硬生生地衔接到 \(O_1\) 对应行的后面。拼接后的结果是一个 shape 为 [L, D] (\(D = 2d\)) 的矩阵:
不难看出,\(O_{cat}\) 其实就等于:
2.2.2 Type 2 - 融合的矩阵
既然目标是得到 [L, D] 的 \(O_{cat}\),我们进一步看看融合的矩阵是如何在不拆分小矩阵的情况下实现的。
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
阶段一:Linear 计算
大矩阵 \(W_q\)(\(W_k\),\(W_v\) 同理)的 shape 为 [D, D],而不是小矩阵的 [D, d],因为不需要切割,此时:
让我们依然沿用 1.2 的设定,假设 \(L=2, H=2\)。在刚完成计算时,维度 \(L\) 的每个元素对应的结果依然是完整且水平展开的 embedding 向量(长度为 \(D\)):
阶段二:View 升维
紧接着,执行q = q.view(B, L, self.num_heads, self.head_dim)。按照 head_num 在最后一个维度上将上一步得到的结果进行拆分(每个头都均匀地分享一部分 embedding 空间),此时形状的变化为:由 [L, D] 变为 [L, H, d](\(d = D // head\_num\))。在上面的矩阵乘法中,列的方向(维度 \(D\))由 \(Q\) 控制,因此相当于将 \(Q\) 进行了切分。
此时,维度 \(L\) 的每个元素依然有完整的 embedding 信息(\(D\)),只是内部从 \(H\) 的维度进行了进一步细分:
这里有非常关键的一个细节,也是我踩坑了很久的点,\(L[0]\) 此时的矩阵结构中,\(x_1W_{q1}\) 与 \(x_1W_{q2}\) 不再是水平展开在同一“行”的相邻元素,而是变成了沿新增的 Head 维度垂直堆叠的不同的“行”。
升维时,很容易关注到的是轴数的增加,非常容易忽视随着特征维度的拆解,原本沿着单一维度水平展开的 \(D\) 个元素,会被折叠成由新增的 \(H\) 和 \(d\) 共同描述的一个二维矩阵切面。这导致单独一行所能承载的特征大小从 \(D\) 缩小成了 \(d\) 个元素。
我们从 [H, d] 的角度再来观察 \(L[0]\):
\(L[1]\)同理:
阶段三:Transpose 换轴
执行 q = q.transpose(1, 2)进行转置,shape 将从 [L, H, d] 变化为 [H, L, d](省略 \(B\) 维度的变化,此时 \(H\) 即 dim = 1)。
这一步的含义是,不再按照 Token 的顺序(即维度 \(L\))进行第一优先级的分组,而是将 Head 的顺序作为第一优先级。所以 \(H[0]\) 此时就有了完整的 \(x_1\) 与 \(x_2\) 的信息,相当于:第一顺位 head,第二顺位 Token,第三顺位 Embedding 值。\(H[0]\) 将“抽调”所有的 \(\text{Head 1}\) 的数据组成新的矩阵(\(H[1]\)同理):
阶段四:多头独立计算(Batch Matmul)
至此,\(Q, K, V\) 三个矩阵都通过升维实现了多头的融合,shape 为 [H, L, d]。
在实际进行 Attention 计算时,需要用到一个基础知识:对于 PyTorch,当我们对两个多维张量进行矩阵乘法时,无论维数多大,永远是最后两维进行乘法,前面的维度均会被当成 Batch。换句话说,矩阵 \(Q\) 的 \(\text{Head 1}\) 只能看到 \(K\) 和 \(V\) 的 \(\text{Head 1}\),根本不会看到 \(\text{Head 2}\),前置的维度会被隔离,所以不用担心会混合。这里的计算等价于:
for h in range(H): # 取出第 h 个头的 Q 和 K^T (形状都是二维的矩阵)
Q_head = Q[h, :, :] # 形状 [L, d]
K_head_T = K_T[h, :, :] # 形状 [d, L] # 只在当前这个头内部进行标准的二维矩阵乘法
Output[h, :, :] = torch.matmul(Q_head, K_head_T)
Attention 的计算不影响我们对维度的分析,大矩阵并行算 Attention 时,我们的输出张量形状会发生 [H, L, d] → [H, L, L] → [H, L, d] 的变化,最终依然是 [H, L, d]。虽然我们前面讲 \(H\) 这个维度拆解成了 \(H[0]\) 和 \(H[1]\),但数据依然在同一个矩阵中,只是更高维的数据我们不好直观用文本描述。
事实上,在内存里这个大矩阵是这样排的:先把 Head 1 的所有 Token 排完,再排 Head 2 的所有 Token(竖线分隔两个 head)。
阶段五:逆向重排与还原
最后实施转置与融合 context = context.transpose(1, 2).contiguous().view(B, L, self.d_model)。这是前两步的逆过程,相当于重新以维度 \(L\) 优先于维度 \(B\) 的优先级进行数据排列,接着再将最后两个维度的数据“捋直”(从 [H, d] 变为 [D]):
-
第一步:Transpose
- 执行
.transpose(1, 2)后,形状变回[L, H, d](忽略维度 B)。 -
在内存中,它被强制重排(更准确地说,transpose 只改变逻辑视图,contiguous 强制内存重新分配),成了以 Token 作为第一优先级(竖线分隔两个 token):
\[[ \underbrace{o_{11}, o_{12}}_{H1}, \underbrace{o_{21}, o_{22}}_{H2} \quad | \quad \underbrace{o_{31}, o_{32}}_{H1}, \underbrace{o_{41}, o_{42}}_{H2} ]\] -
此时如果我们在一张二维的纸面上描述,按照维度 L 的顺序它的逻辑视图为:
- 执行
-
第二步:View 捋直 (取消换行)
- View 只是改变最后一个维度的“换行规则”。原本的最后一个维度是 \(d=2\),也就是每读 2 个数就换行。所以同一个 Token 内部,Head 1 和 Head 2 是分两行垂直排列的。
-
现在的最后一个维度变成了 \(D=4\)(即 self.d_model,原始 embedding 的空间大小),于是,垂直折叠的 Head 1 和 Head 2,就像被一只手“捋直”成了一根水平的面条:
\[\begin{bmatrix} o_{11} & o_{12} & o_{21} & o_{22} \\ o_{31} & o_{32} & o_{41} & o_{42} \end{bmatrix}\] -
这个“捋直”后的矩阵,和上面 2.2.1 我们用
torch.cat硬拼接出来的矩阵 \(O_{cat}\) 结构完全等价。
3. 重要知识点汇总
- 矩阵算子融合其实有一个基础且直白的原理。横向拼接矩阵时,如:\(W_{qkv} = \begin{bmatrix} W_q & W_k & W_v \end{bmatrix}\) ,矩阵乘法一直是用 左侧的“行”与右侧的“列” 相乘,因此无论拼接多少个新矩阵,后续拼接的矩阵都不会对前置的运算结果造成影响;
- 两个高维张量的矩阵乘法只看最后两个维度,即:前面的维度完全相同(或满足广播机制)时才会进行最后两个维度的二维计算;
- View 之后不仅维度增加了一个,最后一个维度的空间也缩小了。我们可以从右到左去观察 shape,然后模拟 torch 的操作,按照张量最后一个维度的数值,每次读取对应数值的元素后进行“换行”;
0 条评论