字节Seed团队PHD-Transformer突破预训练长度扩展!破解KV缓存膨胀难题


字节Seed团队PHD-Transformer突破预训练长度扩展!破解KV缓存膨胀难题

仅用于站内搜索,没有排版格式,具体信息请跳转上方微信公众号内链接

机器之心报道
编辑:杜伟
最近,DeepSeek-R1和OpenAIo1/03等推理大模型在后训练阶段探索了长度扩展(lengthscaling),通过强化学习(比如PPO、GPRO)训练模型生成很长的推理链(CoT),并在奥数等高难度推理任务上取得了显著的效果提升。
受此启发,研究人员开始探索预训练阶段的长度扩展,已有方法包括在序列中插入文本、插入潜在向量(如Coconut)、复用中间层隐藏状态(如CoTFormer)以及将中间隐藏状态映射为概念(如COCOMix)。不过,这些方法普遍存在问题,比如需要更大的KV缓存导致推理慢/占内存多。
本文中,来自ByteDanceSeed团队的研究者提出了更简单的方法:直接重复输入tokens(1/2/3/4次),不做中间层处理。他们观察到了训练损失和模型性能随重复倍数扩展的趋势,如下图1a和1b所示。但是,直接重复tokens也带来了新问题,包括KV缓存规模线性增加,内存压力大;预填充时间超线性增加;解码延迟变长。这些都是实现预训练长度扩展需要重点解决的挑战。
论文标题:EfficientPretrainingLengthScaling
arXiv地址:https ://arxiv.org/pdf/2504.14992

具体来讲,研究者将第一个token表示原始token,将重复的token表示为解码token。同时仅保留从原始token生成的KV缓存来用于长距离依赖建模,并在隐藏解码token用于下一个token预测之后丢弃它们的KV缓存。因此,PHD-Transformer提供了与原始transformer相同的KV缓存,同时相较于简单的token重复实现了显著的推理加速(如图1d所示)。
另外,为了更好地保留隐藏解码token的KV缓存的性能优势,研究者引入了一种滑动窗口注意力——PHD-SWA,保持了这些token的局部滑动窗口缓存,在实现显著性能提升的同时,仅需要的额外KV缓存内存。
研究者还注意到,在PHD-SWA中,隐藏解码token的KV缓存表现出了顺序依赖关系,这导致预填充时间呈线性增长。为了解决这个问题,研究者提出了逐块滑动窗口注意力——PHD-CSWA,从而限制了每个块内的顺序依赖关系。
因此,得益于只有最后一个块的预填充时间呈线性增长,PHD-CSWA显著缩短了预填充时间(如图1c所示)。
方法概览
PHD的架构下图2所示,与原始Transformer相比,PHD保留了相同的模型架构,仅在输入序列和注意力矩阵的设计上有所不同。具体而言,他们仅允许原始token生成KV缓存,并且可以被所有token全局关注;同时隐藏状态的KV缓存在并行隐藏解码后会被立即丢弃。注意力矩阵的策略具体如下:
研究者在推理过程中实现了与原始Transformer相同的KV缓存大小和内存访问模式。虽然需要K次FLOP,但这些计算可以并行处理,从而在内存受限的推理场景中最大限度地降低延迟开销。该架构的核心优势在于原始token和隐藏解码token之间的解耦。在预填充期间,只有原始token需要计算。
这种设计确保预填充时间与原始Transformer相同,并且无论扩展因子K如何变化,预填充时间都保持不变。而对于损失计算,研究者仅使用token的最终副本进行下一个token的预测。总之,使用token的第一个副本进行KV缓存生成,使用token的最后一个副本进行下一个token的预测。
内核设计
M^ij_mn的简单实现会导致注意力层计算量增加K^2倍,FFN层计算量也增加K倍。然而,由于注意力是稀疏计算的,的注意力可以大幅降低。因此,研究者将原始token和隐藏解码token分成两组,并将它们连接在一起。
下图3展示了K=3的示例,可以得到一个包含t个原始token的序列和一个包含2t个隐藏解码序列的序列。通过重新排列token的位置,研究者将掩码注意力的位置保留在一个连续块中,从而优化了注意力计算,将注意力计算复杂度降低到。
PHD-SWA和PHD-CSWA
与简单的token重复相比,PHD-Transformer在保持原始KV缓存大小的同时实现了长度扩展。然而通过经验观察到,为隐藏解码token保留一些KV缓存可以带来显著的性能提升。因此,为了在保持效率的同时获得这些优势,研究者引入了PHD-SWA,将滑动窗口注意力限制在W个先前的隐藏解码token上。
如下图4所示,PHD-SWA的注意力模式将对原始token的全局访问与对W个最近隐藏解码token的局部访问相结合。这种改进的注意力机制实现了显著的性能提升,同时仅需要的额外KV缓存内存。
虽然PHD-SWA滑动窗口方法提升了模型性能,但由于隐藏解码token的KV缓存中存在顺序依赖关系,它会产生K倍的预填充开销。为了解决这个问题,研究者引入了PHD-CSWA,它可以在独立的块内处理注意力。
如下图4所示,PHD-CSWA将滑动窗口注意力限制在单个块内运行。这种架构创新将额外的预填充开销减少到最终块内的K次重复,而不是整个序列重复,这使得额外的计算成本几乎可以忽略不计,同时保留了局部注意力模式的优势。
实验结果
在实验中,研究者使用OLMo2作为代码库,并在ARC、HellaSwag、PIQA、Winogrande、MMLU和CommonsenseQA等公开基准测试集上进行了评估。
训练细节:研究者使用1.2B参数规模的模型,它是一个16层的密集模型。每个token的隐藏层维数设置为2048,FFN层的隐藏层大小设置为16384。同时使用组查询注意力(Group-QueryAttention,GQA),它包含32个查询头和8个键/值头,每个头的隐藏层维数设置为64。研究者使用500B个token训练该模型。
对于本文提出的PHD系列设置,研究者预训练了以下两种PHD-CSWA变体:
PHD-CSWA-2-16-32,其中训练token重复两次。保留一个包含16个token的局部窗口,并将块大小设置为32个token。
PHD-CSWA-3-16-32,其中训练token重复三次。局部窗口大小和块大小与PHD-CSWA-2-16-32的设置相同。
PHD-CSWA在各个基准测试中均实现了持续的性能提升。下图5中展示了训练曲线,下表1中展示了主要结果。本文提出的PHD-CSWA-2-16-32在这些基准测试中平均实现了1.5%的准确率提升,训练损失降低了0.025;而PHD-CSWA-3-16-32在这些基准测试中平均实现了2.0%的准确率提升,训练损失降低了0.034。
研究者还分析了PHD和PHD-SWA的扩展性能,以分析扩展解码计算的性能。训练细节:使用相同的550M模型配置,将窗口大小W设置为16,并在{2,3,5}范围内改变扩展因子K。对于局部窗口大小,研究者在所有实验中都将窗口大小设置为16。
PHD-SWA的性能在增加扩展因子时有效扩展。如下图8所示,使用固定窗口大小时,损失曲线和下游性能会随着token重复次数而有效扩展。通过将扩展因子设置为5,可以实现接近0.06的损失降低,同时显著提升下游性能。
下表2中的定量结果表明,当扩展至K=5时,所有基准测试的平均准确率提高了1.8%,这证实了本文的方法在更激进的扩展方面仍然有效。
更多实验结果请参阅原论文。
©THEEND
转载请联系本公众号获得授权
投稿或寻求报道:liyazhou@jiqizhixin.com


文章作者: ZejunCao
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 ZejunCao !
  目录