原來ScalingLaw還能被最佳化?Meta這招省token又提效

機器之心報道
編輯:Panda
2017 年,一篇《Attention Is All You Need》論文成為 AI 發展的一個重要分水嶺,其中提出的 Transformer 依然是現今主流語言模型的基礎正規化。尤其是在基於 Transformer 的語言模型的 Scaling Law 得到實驗驗證後,AI 領域的發展更是進入了快車道。
現如今,這篇論文的引用量正向 19 萬衝刺,而 Transformer 和注意力機制本身也已經歷了很多改進和創新,比如我們前段時間報道過的「Multi-Token Attention」和「Multi-matrix Factorization Attention」等。
隨著 AI 的不斷發展,現如今的一個重要挑戰是如何獲得足夠多高質量的 token。又或者,該如何更高效地利用這些 token?為此,還必須對 Transformer 進行進一步的升級改造。
近日,Meta 的一篇論文公佈了他們在這方面取得的一個新進展,提出了一種旋轉不變型三線性注意力機制,並證明其表示能力與 2-simplicial Transformer 相當。更重要的是,它的表現甚至足以改變 Scaling Law 中的係數。Meta 也用 Triton 實現了這種注意力機制。
該研究基於 RoPE 向三線性函式的泛化;而 2-simplicial Transformer 則源自 2019 年 Clift et al. 的研究《Logic and the 2-Simplicial Transformer》,其中將 Transformer 的點積注意力機制泛化到了三線性形式。
  • 論文標題:Fast and Simplex: 2-Simplicial Attention in Triton
  • 論文地址:https://arxiv.org/pdf/2507.02754.pdf
他們進一步證明,在有限的 token 預算下,2-simplicial Transformer 的擴充套件性優於 Transformer。
此外,他們的實驗還表明,2-simplicial Transformer 相對於 Transformer 具有更有利的引數數量 scaling 指數。這表明,與 Chinchilla scaling 不同,有可能以比 2-simplicial Transformer 的引數增長更慢的速度增加 token 數量。
研究結果表明,在 token 約束下執行時,與點積注意力機制 Transformer 相比,2-simplicial Transformer 可以更有效地逼近自然語言的不可約熵。
神經 Scaling Law 概述
要理解這項研究的意義,首先需要了解一下 Scaling Law。
簡單來說,就是損失 L 會隨模型引數總數 N 和 token 數量 D 呈冪律衰減:
其中,第一項 E 通常被描述為不可約損失,對應於自然文字的熵。第二項描述了這樣一個事實:具有 N 個引數的模型的表現達不到理想的生成過程。第三項則對應於這樣一個事實:我們僅使用有限的資料樣本進行訓練,並且沒有將模型訓練到收斂。
理論上,當 N → ∞ 且 D → ∞ 時,大型語言模型應該接近底層文字分佈的不可約損失 E。
對於給定的計算預算 C,其中 F LOP s (N, D) = C,可以將最佳引數數量表示為 Nopt ∝ C a,將最佳資料集大小表示為 Dopt ∝ C b。Hoffmann 等人 (2022) 的作者進行了多項實驗,並將引數函式擬合到損失函式中,以估計指數 a 和 b:多種不同的方法證實,a 大約為 0.49,b 大約為 0.5。這引出了 Hoffmann 等人 (2022) 的核心論點:必須根據模型大小按比例縮放 token 數量。
對於給定的計算預算 C,其中 FLOPs (N, D) = C,可以將最佳引數數量表示為 N_opt ∝ C^a,將最佳資料集大小表示為 D_opt ∝ C^b。Hoffmann et al. (2022) 進行了多次實驗,並根據損失擬合了引數函式,以估計指數 a 和 b。
結果,透過多種不同方法發現:a 約為 0.49,b 約為 0.5。
如此,便引出了 Hoffmann et al. (2022) 的一個核心論點:必須根據模型大小按比例擴充套件 token 數量。
但是,正如前面討論的那樣,足夠高質量且足夠數量的 token 是預訓練擴充套件的新瓶頸,因此需要探索替代的訓練演算法和架構。另一方面,最近的研究表明,之前文獻中提出的大多數建模和最佳化技術僅僅改變了誤差(偏移了 E),並沒有從根本上改變冪律中的指數。谷歌 DeepMind 的研究者 Katie Everett 對此進行過精彩的討論:
https://x.com/_katieeverett/status/1925665335727808651
2-simplicial Transformer
2-simplicial Transformer 由 Clift et al. (2019) 提出,他們將點積注意力機制從雙線性擴充套件為三線性形式,也就是從 1-simplex 擴充套件成了 2-simplex。
先來看看標準的注意力機制:
其中,每一項都是點積 

然後,透過逐行 softmax 運算將注意力分數(logit)轉換為機率權重:
注意力層的最終輸出是根據這些注意力分數對這些值進行線性組合得到的

Clift et al. (2019) 的 2-simplicial Transformer 論文將其推廣到三線性積,其中有兩個額外的鍵和值投射矩陣 W_K′ 和 W_V′,從而得到 K′ = XW_K′ 和 V′ = XW_V′。然後,2-simplicial Transformer 的注意力 logit 由 Q、K 和 K′ 的三線性積給出,從而得到以下三階張量:
從而注意力張量變為:
注意力運算的最終輸出定義為:
其中 

表示兩個向量的元素級 Hadamard 積。2-simplicial Transformer 的虛擬碼如演算法 1 所示。注意,公式 5 不包含 RoPE 等任何位置編碼。

基於行列式的三線性形式
Su et al., 2024 提出 RoPE 時,是想將其作為一種用於 Transformer 語言模型的序列位置資訊捕獲方法。RoPE 對查詢 q_i 和鍵 k_j 應用位置相關的旋轉,使得點積 <q_i, K_j> 是相對距離 i-j 的函式。特別需要注意的是,點積對於正交變換 R 具有不變性:
這對於 RoPE 至關重要,因為對於同一位置 i 相同的查詢 q_i 和鍵 k_i,我們期望其點積不會因基於位置的旋轉而發生變化。請注意,(5) 式中定義的三線性形式並非是旋轉不變,並且對 q_i 、k_i 和 k′_i 進行相同的旋轉不再保留內積。因此,為了將 RoPE 泛化到 2-simplicial 注意力模型,探索其他具有旋轉不變性的雙線性和三線性形式至關重要。
而 Meta 的這個團隊注意到,以下函式也具有旋轉不變性:
可以使用帶符號的行列式運算 

來計算 A^(det) ∈ ℝ^n×n×n。對於任意向量 q,令 q^(l) = q = q [3 (l – 1) : 3l] 為其第 l 個大小為 3 的塊。其 logit 定義為:

由於公式 8 根據 Sarrus 規則包含 2 個點積項,因此需要修改演算法 1,使用 2 個 einsum 而不是第 2 行中的 1 個。最終的注意力權重 S 是透過對上述 logit 應用 softmax 函式來計算的,類似於公式 6。然後,token i 的輸出是值向量的加權和,如公式 7 所示。
定理:對於任意輸入大小 n 和輸入範圍 m = n^{O (1)},存在一個具有單個注意力頭的 Transformer 架構,其 logit 計算方式如公式 (9) 所示,注意力頭維度為 d = 7,使得對於所有 X ∈ [M]^N,如果

則 Transformer 對元素 x_i 的輸出為 1,否則為 0。

對該定理的證明請見原論文附錄。
模型設計
由於 2-simplicial 注意力在序列長度 n 上的擴充套件複雜度為 O (n^3),因此將其應用於整個序列是不切實際的。該團隊的做法是將其引數化為 O (n× w_1 × w_2),其中 w_1 和 w_2 定義的是序列上滑動視窗的維度。每個查詢向量 Q_i 會關注 w_1 個 K 鍵和 w_2 個 K′ 鍵的區域性區域,從而減輕計算負擔。該團隊系統地評估了 w_1 和 w_2 的各種配置,以確定計算效率和模型效能之間的最佳平衡點(見表 1)。
對於因果點積注意力機制,長度為 n 的序列的複雜度由下式給出:
其中 n 是序列長度。這涉及兩次矩陣乘法:一次用於 Q@K,一次用於 P@V,每次乘法每個元素都需要兩次浮點運算。因果掩碼使其能夠跳過 1/2 的計算。
相比之下,以 w_1 和 w_2 為引數的 2-simplicial 注意力機制的複雜度表示為:
其複雜度的增長來源是三線性 einsum 運算,與標準點積注意力機制相比,它需要進行一次額外的乘法運算。
該團隊選擇視窗大小為 (512, 32),以平衡延遲和質量。在此配置下,2-simplicial 注意力機制的計算複雜度與 48k 上下文長度的點積注意力機制相當。
圖 2 給出了一個實現。因此,像在 Flash 注意力機制中那樣平鋪式查詢 Q 會導致計算吞吐量較低。受 Native Sparse Attention 的啟發,Meta 該團隊採用的模型架構利用了較高 (64) 的分組查詢注意力 (GQA) 比率。這種方法能夠沿著查詢頭高效地平鋪,確保密集計算,並消除昂貴的逐元素掩碼。
該團隊還引入了一系列針對 2-simplicial 注意力的核最佳化,這些最佳化基於使用線上 softmax 的 Flash Attention。詳見原論文。下面來重點看看實驗表現。
實驗與結果
這個團隊訓練了一系列 MoE 模型,其引數範圍從 1B 活動引數和 57B 總引數到 3.5B 活動引數和 176B 總引數。具體配置見原論文。
該團隊發現,從 1B (活動)引數模型到 3.5B (活動)引數模型,負對數似然的擴充套件(∆)出現了下降。
此外,在小於 2B (活動)引數的模型中,使用 2-simplicial 注意力機制沒有任何好處。
基於此,該團隊估算了 2-simplicial 注意力機制與點積注意力機制的冪律係數有何不同。基於前述方法,其損失可以表示為:
由於訓練這兩個模型使用的 token 數量相同,因此可以忽略第三項,將損失簡化為:
其中 β = – log E′′ – logA ,由於 E′ 較小,E′′ 是 E′ 的近似值。注意,這裡使用了 log (a + b) = log (1 + a/b) + log (b) 來分離這兩個項,並將 1 + a/b 項隱藏在 E′′ 中。
因此,可以根據表 2 中的損失估算兩組模型的 α 和 β,其中 N 代表每個模型中的有效引數。
該團隊在表 3 中估計了 Transformer 和 2-simplicial Transformer 的斜率 α 和截距 β。
可以看到,與點積注意力 Transformer 相比,2-simplicial 注意力具有更陡的斜率 α,即其 Scaling Law 的指數更高。
© THE END 
轉載請聯絡本公眾號獲得授權
投稿或尋求報道:[email protected]


相關文章