神经网络Transformer架构中的多查询注意力
字数 963 2025-12-01 07:44:56

神经网络Transformer架构中的多查询注意力

步骤1:基础注意力机制回顾
在标准Transformer中,多头注意力(MHA)通过多个独立的“头”并行计算注意力,每个头拥有自己的查询(Q)、键(K)、值(V)投影矩阵。这种设计能捕获不同方向的语义信息,但计算和内存开销随头数线性增长。

步骤2:多查询注意力(MQA)的核心思想
多查询注意力是MHA的变体,其核心改进在于:所有注意力头共享同一组键和值投影,仅保留查询向量独立投影。即:

  • 查询(Q):每个头仍保留独立的投影矩阵,生成不同的查询序列。
  • 键和值(K、V):所有头共享同一组投影矩阵,生成单一的键和值序列。

步骤3:MQA的数学形式化对比

  • 传统MHA:每个头 \(h\) 计算 \(\text{Attention}(Q_h W_h^Q, K_h W_h^K, V_h W_h^V)\),需存储 \(H \times (d_q + d_k + d_v)\) 个投影参数。
  • MQA:每个头计算 \(\text{Attention}(Q_h W_h^Q, K W^K, V W^V)\),参数量减少为 \(H \cdot d_q + d_k + d_v\)。其中 \(H\) 为头数,\(d_q, d_k, d_v\) 为投影维度。

步骤4:MQA的优势与代价

  • 优势
    1. 内存效率:键值缓存(KV-Cache)大小降低为MHA的 \(1/H\),显著减少推理时显存占用。
    2. 推理加速:解码阶段只需重复计算查询投影,键值只需计算一次并复用,提升生成速度。
  • 代价
    由于键值共享,模型捕获多样上下文的能力可能弱于MHA,需在效率和表达力间权衡。

步骤5:MQA的实际应用场景
MQA常用于大规模语言模型(如PaLM、Falcon)的推理优化,尤其在长序列生成任务中:

  • 在自回归解码时,键值缓存减少使得相同硬件可处理更长的上下文窗口。
  • 适合对延迟敏感的应用(如实时对话系统),但训练阶段仍可能使用MHA保证模型容量。

步骤6:扩展变体——分组查询注意力(GQA)
GQA是MQA与MHA的折中方案:将头分组,组内共享键值投影,组间保持独立。平衡了效率与表达能力,被Llama 2等模型采用。

神经网络Transformer架构中的多查询注意力 步骤1:基础注意力机制回顾 在标准Transformer中,多头注意力(MHA)通过多个独立的“头”并行计算注意力,每个头拥有自己的查询(Q)、键(K)、值(V)投影矩阵。这种设计能捕获不同方向的语义信息,但计算和内存开销随头数线性增长。 步骤2:多查询注意力(MQA)的核心思想 多查询注意力是MHA的变体,其核心改进在于: 所有注意力头共享同一组键和值投影 ,仅保留查询向量独立投影。即: 查询(Q):每个头仍保留独立的投影矩阵,生成不同的查询序列。 键和值(K、V):所有头共享同一组投影矩阵,生成单一的键和值序列。 步骤3:MQA的数学形式化对比 传统MHA:每个头 \( h \) 计算 \( \text{Attention}(Q_ h W_ h^Q, K_ h W_ h^K, V_ h W_ h^V) \),需存储 \( H \times (d_ q + d_ k + d_ v) \) 个投影参数。 MQA:每个头计算 \( \text{Attention}(Q_ h W_ h^Q, K W^K, V W^V) \),参数量减少为 \( H \cdot d_ q + d_ k + d_ v \)。其中 \( H \) 为头数,\( d_ q, d_ k, d_ v \) 为投影维度。 步骤4:MQA的优势与代价 优势 : 内存效率 :键值缓存(KV-Cache)大小降低为MHA的 \( 1/H \),显著减少推理时显存占用。 推理加速 :解码阶段只需重复计算查询投影,键值只需计算一次并复用,提升生成速度。 代价 : 由于键值共享,模型捕获多样上下文的能力可能弱于MHA,需在效率和表达力间权衡。 步骤5:MQA的实际应用场景 MQA常用于大规模语言模型(如PaLM、Falcon)的推理优化,尤其在长序列生成任务中: 在自回归解码时,键值缓存减少使得相同硬件可处理更长的上下文窗口。 适合对延迟敏感的应用(如实时对话系统),但训练阶段仍可能使用MHA保证模型容量。 步骤6:扩展变体——分组查询注意力(GQA) GQA是MQA与MHA的折中方案:将头分组,组内共享键值投影,组间保持独立。平衡了效率与表达能力,被Llama 2等模型采用。