原来Scaling Law还能被优化?Meta这招省token又提效


原来Scaling Law还能被优化?Meta这招省token又提效

仅用于站内搜索,没有排版格式,具体信息请跳转上方微信公众号内链接

机器之心报道
编辑:Panda
2017年,一篇《AttentionIsAllYouNeed》论文成为AI发展的一个重要分水岭,其中提出的Transformer依然是现今主流语言模型的基础范式。尤其是在基于Transformer的语言模型的ScalingLaw得到实验验证后,AI领域的发展更是进入了快车道。

随着AI的不断发展,现如今的一个重要挑战是如何获得足够多高质量的token。又或者,该如何更高效地利用这些token?为此,还必须对Transformer进行进一步的升级改造。
近日,Meta的一篇论文公布了他们在这方面取得的一个新进展,提出了一种旋转不变型三线性注意力机制,并证明其表示能力与2-simplicialTransformer相当。更重要的是,它的表现甚至足以改变ScalingLaw中的系数。Meta也用Triton实现了这种注意力机制。

论文标题:FastandSimplex:2-SimplicialAttentioninTriton
论文地址:https ://arxiv. org/pdf/2507. 02754.pdf
他们进一步证明,在有限的token预算下,2-simplicialTransformer的扩展性优于Transformer。

研究结果表明,在token约束下运行时,与点积注意力机制Transformer相比,2-simplicialTransformer可以更有效地逼近自然语言的不可约熵。
神经ScalingLaw概述
要理解这项研究的意义,首先需要了解一下ScalingLaw。
简单来说,就是损失L会随模型参数总数N和token数量D呈幂律衰减:
其中,第一项E通常被描述为不可约损失,对应于自然文本的熵。第二项描述了这样一个事实:具有N个参数的模型的表现达不到理想的生成过程。第三项则对应于这样一个事实:我们仅使用有限的数据样本进行训练,并且没有将模型训练到收敛。
理论上,当N→∞且D→∞时,大型语言模型应该接近底层文本分布的不可约损失E。
对于给定的计算预算C,其中FLOPs(N,D)=C,可以将最佳参数数量表示为Nopt∝Ca,将最佳数据集大小表示为Dopt∝Cb。Hoffmann等人(2022)的作者进行了多项实验,并将参数函数拟合到损失函数中,以估计指数a和b:多种不同的方法证实,a大约为0. 49,b大约为0. 5。这引出了Hoffmann等人(2022)的核心论点:必须根据模型大小按比例缩放token数量。
对于给定的计算预算C,其中FLOPs(N,D)=C,可以将最佳参数数量表示为N_opt∝C^a,将最佳数据集大小表示为D_opt∝C^b。Hoffmannetal.(2022)进行了多次实验,并根据损失拟合了参数函数,以估计指数a和b。
结果,通过多种不同方法发现:a约为0. 49,b约为0. 5。
如此,便引出了Hoffmannetal.(2022)的一个核心论点:必须根据模型大小按比例扩展token数量。
但是,正如前面讨论的那样,足够高质量且足够数量的token是预训练扩展的新瓶颈,因此需要探索替代的训练算法和架构。另一方面,最近的研究表明,之前文献中提出的大多数建模和优化技术仅仅改变了误差(偏移了E),并没有从根本上改变幂律中的指数。谷歌DeepMind的研究者KatieEverett对此进行过精彩的讨论:
https ://x. com/_katieeverett/status/1925665335727808651
2-simplicialTransformer
2-simplicialTransformer由Cliftetal.(2019)提出,他们将点积注意力机制从双线性扩展为三线性形式,也就是从1-simplex扩展成了2-simplex。
先来看看标准的注意力机制:
其中,每一项都是点积。
然后,通过逐行softmax运算将注意力分数(logit)转换为概率权重:
注意力层的最终输出是根据这些注意力分数对这些值进行线性组合得到的。

从而注意力张量变为:
注意力运算的最终输出定义为:
其中表示两个向量的元素级Hadamard积。2-simplicialTransformer的伪代码如算法1所示。注意,公式5不包含RoPE等任何位置编码。
基于行列式的三线性形式
Suetal.,2024提出RoPE时,是想将其作为一种用于Transformer语言模型的序列位置信息捕获方法。RoPE对查询q_i和键k_j应用位置相关的旋转,使得点积是相对距离i-j的函数。特别需要注意的是,点积对于正交变换R具有不变性:
这对于RoPE至关重要,因为对于同一位置i相同的查询q_i和键k_i,我们期望其点积不会因基于位置的旋转而发生变化。请注意,(5)式中定义的三线性形式并非是旋转不变,并且对q_i、k_i和k′_i进行相同的旋转不再保留内积。因此,为了将RoPE泛化到2-simplicial注意力模型,探索其他具有旋转不变性的双线性和三线性形式至关重要。
而Meta的这个团队注意到,以下函数也具有旋转不变性:
可以使用带符号的行列式运算来计算A^(det)∈ℝ^n×n×n。对于任意向量q,令q^(l)=q=q[3(l-1):3l]为其第l个大小为3的块。其logit定义为:
由于公式8根据Sarrus规则包含2个点积项,因此需要修改算法1,使用2个einsum而不是第2行中的1个。最终的注意力权重S是通过对上述logit应用softmax函数来计算的,类似于公式6。然后,tokeni的输出是值向量的加权和,如公式7所示。
定理:对于任意输入大小n和输入范围m=n^{O(1)},存在一个具有单个注意力头的Transformer架构,其logit计算方式如公式(9)所示,注意力头维度为d=7,使得对于所有X∈[M]^N,如果,则Transformer对元素x_i的输出为1,否则为0。
对该定理的证明请见原论文附录。
模型设计
由于2-simplicial注意力在序列长度n上的扩展复杂度为O(n^3),因此将其应用于整个序列是不切实际的。该团队的做法是将其参数化为O(n×w_1×w_2),其中w_1和w_2定义的是序列上滑动窗口的维度。每个查询向量Q_i会关注w_1个K键和w_2个K′键的局部区域,从而减轻计算负担。该团队系统地评估了w_1和w_2的各种配置,以确定计算效率和模型性能之间的最佳平衡点(见表1)。
对于因果点积注意力机制,长度为n的序列的复杂度由下式给出:
其中n是序列长度。这涉及两次矩阵乘法:一次用于Q@K,一次用于P@V,每次乘法每个元素都需要两次浮点运算。因果掩码使其能够跳过1/2的计算。
相比之下,以w_1和w_2为参数的2-simplicial注意力机制的复杂度表示为:
其复杂度的增长来源是三线性einsum运算,与标准点积注意力机制相比,它需要进行一次额外的乘法运算。
该团队选择窗口大小为(512,32),以平衡延迟和质量。在此配置下,2-simplicial注意力机制的计算复杂度与48k上下文长度的点积注意力机制相当。
图2给出了一个实现。因此,像在Flash注意力机制中那样平铺式查询Q会导致计算吞吐量较低。受NativeSparseAttention的启发,Meta该团队采用的模型架构利用了较高(64)的分组查询注意力(GQA)比率。这种方法能够沿着查询头高效地平铺,确保密集计算,并消除昂贵的逐元素掩码。
该团队还引入了一系列针对2-simplicial注意力的核优化,这些优化基于使用在线softmax的FlashAttention。详见原论文。下面来重点看看实验表现。
实验与结果
这个团队训练了一系列MoE模型,其参数范围从1B活动参数和57B总参数到3. 5B活动参数和176B总参数。具体配置见原论文。
该团队发现,从1B(活动)参数模型到3. 5B(活动)参数模型,负对数似然的扩展(∆)出现了下降。
此外,在小于2B(活动)参数的模型中,使用2-simplicial注意力机制没有任何好处。
基于此,该团队估算了2-simplicial注意力机制与点积注意力机制的幂律系数有何不同。基于前述方法,其损失可以表示为:
由于训练这两个模型使用的token数量相同,因此可以忽略第三项,将损失简化为:
其中β=-logE′′-logA,由于E′较小,E′′是E′的近似值。注意,这里使用了log(a+b)=log(1+a/b)+log(b)来分离这两个项,并将1+a/b项隐藏在E′′中。
因此,可以根据表2中的损失估算两组模型的α和β,其中N代表每个模型中的有效参数。
该团队在表3中估计了Transformer和2-simplicialTransformer的斜率α和截距β。
可以看到,与点积注意力Transformer相比,2-simplicial注意力具有更陡的斜率α,即其ScalingLaw的指数更高。
©THEEND
转载请联系本公众号获得授权
投稿或寻求报道:liyazhou@jiqizhixin. com


文章作者: ZejunCao
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 ZejunCao !
  目录