神经网络Transformer架构中的多查询注意力
字数 963 2025-12-01 07:44:56
神经网络Transformer架构中的多查询注意力
步骤1:基础注意力机制回顾
在标准Transformer中,多头注意力(MHA)通过多个独立的“头”并行计算注意力,每个头拥有自己的查询(Q)、键(K)、值(V)投影矩阵。这种设计能捕获不同方向的语义信息,但计算和内存开销随头数线性增长。
步骤2:多查询注意力(MQA)的核心思想
多查询注意力是MHA的变体,其核心改进在于:所有注意力头共享同一组键和值投影,仅保留查询向量独立投影。即:
- 查询(Q):每个头仍保留独立的投影矩阵,生成不同的查询序列。
- 键和值(K、V):所有头共享同一组投影矩阵,生成单一的键和值序列。
步骤3:MQA的数学形式化对比
- 传统MHA:每个头 \(h\) 计算 \(\text{Attention}(Q_h W_h^Q, K_h W_h^K, V_h W_h^V)\),需存储 \(H \times (d_q + d_k + d_v)\) 个投影参数。
- MQA:每个头计算 \(\text{Attention}(Q_h W_h^Q, K W^K, V W^V)\),参数量减少为 \(H \cdot d_q + d_k + d_v\)。其中 \(H\) 为头数,\(d_q, d_k, d_v\) 为投影维度。
步骤4:MQA的优势与代价
- 优势:
- 内存效率:键值缓存(KV-Cache)大小降低为MHA的 \(1/H\),显著减少推理时显存占用。
- 推理加速:解码阶段只需重复计算查询投影,键值只需计算一次并复用,提升生成速度。
- 代价:
由于键值共享,模型捕获多样上下文的能力可能弱于MHA,需在效率和表达力间权衡。
步骤5:MQA的实际应用场景
MQA常用于大规模语言模型(如PaLM、Falcon)的推理优化,尤其在长序列生成任务中:
- 在自回归解码时,键值缓存减少使得相同硬件可处理更长的上下文窗口。
- 适合对延迟敏感的应用(如实时对话系统),但训练阶段仍可能使用MHA保证模型容量。
步骤6:扩展变体——分组查询注意力(GQA)
GQA是MQA与MHA的折中方案:将头分组,组内共享键值投影,组间保持独立。平衡了效率与表达能力,被Llama 2等模型采用。