神经网络Transformer架构中的序列到序列模型蒸馏
字数 1693 2025-12-05 15:16:18
神经网络Transformer架构中的序列到序列模型蒸馏
-
基本概念与动机
- 序列到序列模型:是一类神经网络模型,其核心任务是接收一个输入序列(如一句英文句子),并生成一个对应的输出序列(如该句的法语翻译)。典型架构由编码器和解码器组成,编码器将输入序列压缩为上下文表示,解码器基于此表示自回归地生成输出序列。
- 模型蒸馏:是一种模型压缩技术,其核心思想是将一个庞大、复杂但性能优异的“教师模型”中的知识,迁移到一个更小、更高效的“学生模型”中,使学生模型在保持近似性能的同时,大幅减少计算和存储开销。
- 序列到序列模型蒸馏的动机:Transformer架构的序列到序列模型(如用于翻译、摘要的模型)通常参数量极大、推理延迟高。直接将其部署到资源受限环境(如移动设备、边缘计算或需要低延迟响应的在线服务)中面临挑战。蒸馏技术旨在获得一个轻量级的“学生”模型,以解决部署难题。
-
知识形式与蒸馏目标
- 在序列到序列任务中,需要蒸馏的知识主要体现在两个层面:
- 输出分布知识:这是最核心的知识。对于解码器的每一个生成步骤,教师模型会为整个词汇表输出一个概率分布(即下一个词是什么的软目标),这个分布包含了模型对各类选项的“信心”程度,比硬标签(仅正确词为1,其余为0)蕴含更丰富的信息。学生模型的学习目标是模仿教师模型在每个生成步骤上的输出概率分布。
- 隐藏状态知识:除了最终输出,教师模型中间层的隐藏状态(编码器的输出、解码器的中间表示等)也被认为包含了有价值的语义和结构信息。学生模型可以通过对齐其隐藏状态与教师模型相应层的隐藏状态来学习。
- 在序列到序列任务中,需要蒸馏的知识主要体现在两个层面:
-
主要蒸馏方法
- 序列级蒸馏:
- 原理:不再要求学生对齐每个生成步骤的分布,而是利用教师模型先为整个训练集生成一组“软序列”或“伪目标序列”。例如,在机器翻译中,教师模型为每个源语句生成一个翻译(可能通过束搜索获得),这个生成的翻译序列就作为学生模型训练时的目标标签。
- 过程:学生模型学习的目标是最大化它生成这个“教师提供的伪序列”的概率。这种方法简化了训练目标,使学生更专注于复制教师的最终输出行为。
- 词级蒸馏:
- 原理:要求学生在解码的每一个时间步上都对齐教师的输出分布。这是最直接的蒸馏形式。
- 损失函数:通常使用Kullback-Leibler散度来衡量学生输出概率分布与教师输出概率分布之间的差异,并将其作为额外的损失项,与传统的交叉熵损失(针对真实标签)结合。
- 隐藏状态蒸馏:
- 原理:强制学生模型的某些中间层表示(如编码器最后层的输出、解码器各层的隐藏状态)与教师模型的对应层表示相似。
- 方法:通过计算均方误差或余弦相似度等损失函数,最小化学生与教师对应隐藏状态之间的差距。这有助于学生模型内部学习到与教师相似的表示空间。
- 结合方法:在实际应用中,常将上述方法结合使用,例如同时使用词级蒸馏损失和隐藏状态蒸馏损失,以更全面、更强力地指导学生模型的学习。
- 序列级蒸馏:
-
技术挑战与策略
- 学生模型架构选择:学生模型可以是一个小型的Transformer,也可以是其他更高效的架构(如使用深度可分离卷积、或更少的层和头数)。关键在于设计容量足够但结构更精简的网络。
- 蒸馏温度:在词级蒸馏中常引入“温度”参数。通过提高温度来“软化”教师模型的输出分布,使概率分布更平滑,从而揭示不同类别间更丰富的相似性关系,便于学生模仿。
- 序列长度不匹配:在序列级蒸馏中,教师生成的伪序列长度可能与真实标签不同。训练时需要使用动态的序列处理方法来应对。
- 多教师蒸馏:有时会集成多个教师模型(可能是不同架构或在不同数据上训练)的知识,共同指导一个学生模型,以期获得更稳健、泛化能力更强的学生。
-
应用与影响
- 应用场景:这项技术广泛应用于需要部署大型序列生成模型的领域,包括但不限于神经机器翻译、文本摘要、对话生成、语音识别等。
- 核心价值:它成功地在模型大小、推理速度与生成质量之间取得了显著更好的权衡。使得原本只能在强大服务器上运行的复杂模型,得以在手机、嵌入式设备或高并发在线服务中实用化,极大地推动了Transformer模型在真实世界的落地应用。