Transformer 没那么难:我用“查表“理解注意力机制

张开发
2026/6/5 1:57:55 15 分钟阅读
Transformer 没那么难:我用“查表“理解注意力机制
Transformer 没那么难我用查表理解注意力机制抛开数学公式聊聊注意力机制到底在干嘛一个问题面试时我问“注意力机制是啥”候选人背公式“Q 乘 K 转置除以根号 d然后 softmax…”我打断他“等等Q 和 K 到底是啥”他懵了。说实话我一开始也这样。学了半年 Transformer只记得一堆矩阵乘法但心里没底。直到有人跟我说了四个字“注意力就是查表”。一下子通透了。一句话理解 Attention核心思想每个词都要看看其他词决定关注谁。怎么做到查表。想象一个图书馆Query你要找什么书“我想找编程相关的”Key书脊上的标签“这本书是编程类”Value书的内容Query 去匹配 Key找到相关的把 Value 拿出来加权平均。完事。图书馆查表模型 ┌─────────────┐ │ Query │ 我想找编程书 │ (查询意图) │ └──────┬──────┘ │ 匹配 ▼ ┌─────────────────────────────────────┐ │ Key Value │ │ (标签) (内容) │ │ ───────────────────────── │ │ 编程 → 《Python入门》 │ │ 小说 → 《三体》 │ │ 编程 → 《算法导论》 │ └─────────────────────────────────────┘ │ ▼ 加权平均 ┌─────────────┐ │ 结果 │ 《Python入门》× 0.6 │ (相关书籍) │ 《算法导论》 × 0.4 └─────────────┘# 核心就三步defattention(Q,K,V):# 1. 算相似度Query 和 Key 有多像scoresQ K.T# 2. 归一化变成权重加起来1weightssoftmax(scores)# 3. 按权重取 Valuereturnweights V工程直觉不是玄学就是「查相似度 加权取值」。Q、K、V 从哪来同一个词通过三个不同的投影矩阵变成三种身份输入: 写 │ │ 词向量: [0.3, -0.2, 0.5, ...] ▼ ┌─────────────────────────────────────┐ │ 三个投影矩阵 │ │ ┌─────────┐ ┌─────────┐ ┌────────┐ │ │ │ W_q │ │ W_k │ │ W_v │ │ │ └────┬────┘ └────┬────┘ └───┬────┘ │ │ │ │ │ │ │ ▼ ▼ ▼ │ │ Query Key Value │ │ (我要查啥) (我是啥标签) (我有啥内容)│ │ │ │ │ │ │ └───────────┴──────────┘ │ │ │ │ │ ▼ │ │ Attention │ │ (查表加权) │ └─────────────────────────────────────┘# 输入写 这个词的向量word_vec[0.3,-0.2,...]# 512 维# 三种身份Qword_vec W_q# 我要查什么Kword_vec W_k# 我是什么标签Vword_vec W_v# 我携带什么信息关键Query 是主动的我要查Key 是被动的我被查Value 是真正要取的内容。为什么要多头类比CNN 有多个卷积核有的抓边缘有的抓纹理。Transformer 也一样8 个头就是 8 个视角┌────────────────────────────────────────┐ │ Multi-Head Attention │ │ │ │ 输入向量 (512维) │ │ │ │ │ ▼ │ │ ┌────────────────┐ │ │ │ 切分成8份 │ │ │ │ 每份64维 │ │ │ └───────┬────────┘ │ │ │ │ │ ┌─────┼─────┬──────────────┐ │ │ │ │ │ │ │ │ ▼ ▼ ▼ ▼ │ │ 头1 头2 头3 ... 头8 │ │ │ │ │ │ │ │ ▼ ▼ ▼ ▼ │ │ 语法 语义 指代 ... 其他关系 │ │ 关系 关系 关系 │ │ │ │ │ │ │ │ └─────┴─────┴──────────────┘ │ │ │ │ │ ▼ │ │ 拼接起来 (8×64512维) │ │ │ │ │ ▼ │ │ 线性投影输出 │ └────────────────────────────────────────┘头 A看语法关系主谓宾头 B看语义关系“苹果和水果”头 C看指代关系它指谁把向量切成 8 份各算各的最后拼起来。# 伪代码把 512 维切成 8 个 64 维并行计算headssplit(QKV,num_heads8)outputs[attention(h)forhinheads]# 并行resultconcat(outputs)# 拼回来位置信息从哪来Attention 本身不关心顺序我喜欢猫和猫喜欢我算出来一样。解决方案给每个位置一个唯一编码加到词向量上。输入句子: 我喜欢猫 词1:我 词2:喜欢 词3:猫 │ │ │ ▼ ▼ ▼ 词向量 词向量 词向量 │ │ │ │ │ │ │ │ │ 位置1编码 位置2编码 位置3编码 (第1个词) (第2个词) (第3个词) │ │ │ ▼ ▼ ▼ 最终向量 最终向量 最终向量 │ │ │ └───────────┴───────────┘ │ ▼ 带位置信息的表示就像给每个词贴了个序号标签模型就知道这是第几个词了。# 第1个词 第1个位置编码# 第2个词 第2个位置编码# ...outputword_vecpos_encoding[pos]完整流程一句话输入句子 │ ▼ ┌─────────────────┐ │ 词嵌入层 │ 查表得到词向量 └────────┬────────┘ │ ▼ ┌─────────────────┐ │ 位置编码 │ 加位置标签 └────────┬────────┘ │ ▼ ┌──────────────────────────────────────┐ │ Transformer Layer (×N) │ │ │ │ ┌────────────────────────────────┐ │ │ │ Multi-Head Attention │ │ │ │ (词之间互相看看查表) │ │ │ └─────────────┬──────────────────┘ │ │ │ │ │ ▼ │ │ 残差连接 LayerNorm │ │ │ │ │ ▼ │ │ ┌────────────────────────────────┐ │ │ │ Feed Forward │ │ │ │ (每个词自己加工) │ │ │ └─────────────┬──────────────────┘ │ │ │ │ │ ▼ │ │ 残差连接 LayerNorm │ │ │ └─────────────────┬────────────────────┘ │ ▼ 下一层 / 输出Transformer Layer 就干两件事Attention词之间互相看看查表FFN每个词自己加工一下堆 6 层或 12 层提取层次化特征。三个工程坑1. 推理慢用 KV Cache生成文本时别重复算历史词的 K、V缓存下来复用。第一次生成: 输入: 我喜欢 → 计算 K1,V1 K2,V1 → 输出 猫 ↓ 缓存下来 第二次生成: 输入: 猫 → 只用算 K3,V3 K1,K2 直接用缓存 ↓ 输出 # 第一次算完整句子的 K、V# 后面只算新词的老的直接用缓存next_token,K_cache,V_cachemodel.generate(token,cache(K_cache,V_cache))2. 显存炸注意 O(n²)Attention 计算量是「长度平方」。512 长度变 4096计算量翻 64 倍。长度 计算量 显存占用 ───────────────────────── 512 1× 正常 1024 4× 注意 2048 16× 危险 4096 64× 爆炸解决方案稀疏 Attention、滑动窗口、分层处理。3. 长度限制位置编码只预计算了固定长度比如 512超了会报错。解决方案用 RoPE、ALiBi 这些可外推的位置编码。总结概念工程理解图示Attention查表Query 查 Key加权取 Value图书馆模型Q/K/V同一个词的三种身份三分支投影图Multi-Head8 个视角并行看切分-并行-拼接位置编码给词贴序号标签向量相加残差连接抄近路防止梯度消失跳跃连接一句话Transformer 就是「查表 加工」重复 N 次。别被公式吓到本质挺简单的。写在最后你是咋理解 Attention 的欢迎评论区交流觉得有用给个 下篇见

更多文章