神经网络Transformer架构中的相对位置偏置
字数 2132 2025-12-10 12:43:57
神经网络Transformer架构中的相对位置偏置
相对位置偏置是一种在Transformer的自注意力机制中,替代传统绝对位置编码(如正弦波或可学习绝对位置嵌入)来建模序列中元素间位置关系的方法。其核心思想是:注意力权重不仅应依赖于查询和键的内容(语义),还应依赖于它们之间的相对距离,而非它们在序列中的绝对位置。
第一步:从绝对位置编码到相对位置关系的动机
- 传统绝对位置编码为序列中每个位置分配一个独立的向量(无论是通过正弦函数生成还是可学习的)。当模型在训练时遇到从未见过的序列长度(如长于训练时的序列)时,这些绝对位置编码可能无法泛化,因为新位置没有对应的编码向量。
- 对于许多任务(如理解文本、代码),两个元素之间的相对距离(例如,“单词A在单词B前面3个位置”)往往比它们的绝对索引位置(例如,“单词A在第5个位置”)更重要、更可迁移。
- 相对位置偏置旨在直接建模这种“相对距离”对注意力权重的影响,从而潜在地提升模型对长度变化的泛化能力和对结构关系的理解。
第二步:基本实现原理——在注意力得分中注入偏置
- 在标准的缩放点积注意力中,注意力得分矩阵
A通过查询Q和键K计算:A = QK^T / sqrt(d_k)。这个矩阵A的每个元素A_{ij}表示位置i对位置j的关注程度。 - 相对位置偏置方法不修改
Q和K的内容部分,而是直接向注意力得分矩阵A添加一个偏置矩阵B。即,新的注意力得分A'_{ij} = A_{ij} + B_{ij}。 - 关键点在于,偏置矩阵
B中的值B_{ij}不是一个可学习的针对每个绝对位置对(i, j)的参数,而是一个基于i和j之间的相对距离(i - j)的函数。例如,B_{ij} = r_{i-j},其中r是一个可学习的嵌入向量表,该表以相对距离(或距离桶)为索引。
第三步:相对距离的建模与“距离桶”
- 最简单的方式是直接使用
(i - j)作为索引。但由于序列长度可变,(i-j)的范围可能很大(从-L+1到L-1),需要大量参数且对极远距离的泛化可能不佳。 - 因此,实践中常使用“距离桶”(distance buckets)策略。将所有可能的相对距离
(i-j)映射到少数几个预设的桶中。例如:- 距离为0(自身)一个桶。
- 距离为1(相邻)一个桶。
- 距离为2-4一个桶。
- 距离为5-16一个桶。
- 距离大于16一个桶。
- 也可以区分正向(
i>j)和反向(i<j)距离。
- 模型维护一个小的可学习嵌入表
R,其大小等于桶的数量。计算B_{ij}时,首先根据(i-j)确定其所属的桶索引b,然后取R[b]作为偏置值B_{ij}。这使得模型能用有限的参数处理任意长的相对距离,并鼓励模型对相似距离范围的行为进行泛化。
第四步:在Transformer中的具体集成方式(以T5模型为例)
一种经典且高效的实现是Google T5模型采用的“共享偏置”方式:
- 不再将相对位置信息编码进词嵌入或查询/键向量中,而是将其作为注意力逻辑(logit)的一个加性偏置。
- 定义最大相对距离
k(例如32)。所有大于k或小于-k的距离都映射到同一个桶。 - 模型学习两个小的可学习嵌入表:一个用于“内容-内容”偏置,一个用于“位置-内容”偏置(在某些变体中),但本质都是基于相对距离索引。
- 在计算注意力时,先计算基于内容的点积
A = QK^T,然后根据查询位置i和键位置j的相对距离(i-j),从嵌入表中查出对应的偏置标量b_{i-j},直接加到A_{ij}上。 - 这种方法的计算和存储开销很小,因为它只增加了一个与序列长度相关的偏置矩阵加法,而不是改变
Q/K/V的维度或计算。
第五步:相对位置偏置的优势与影响
- 长度外推性增强:由于偏置仅依赖于相对距离,且距离桶机制能处理未见过的长距离,模型在推理时处理比训练序列更长的输入时,性能下降通常比使用绝对位置编码更平缓。
- 计算效率:相对于一些需要复杂计算的相对位置编码变体,简单的加性偏置实现效率很高。
- 对称性/灵活性:可以灵活地设计距离桶来捕捉不同的先验,例如局部精细建模、远程粗略建模、方向敏感性等。
- 成为现代架构标准:由于其有效性和简洁性,相对位置偏置(或其变种)已被许多先进的Transformer变体(如T5、DeBERTa、ALBERT的部分设计)以及大语言模型(如LLaMA系列)广泛采用,基本取代了原始的绝对正弦位置编码作为默认的位置建模方式。
总结:神经网络Transformer架构中的相对位置偏置是一种通过向注意力得分直接添加基于词元间相对距离的、可学习的偏置项,来建模序列顺序信息的技术。它通过“距离桶”机制高效地参数化相对距离,相比绝对位置编码,在长度外推、计算效率和模型灵活性上展现出优势,已成为现代Transformer模型处理位置信息的主流方法之一。