初探最大更新引數化muP:超引數的跨模型尺度遷移規律

©PaperWeekly 原創 · 作者 | 蘇劍林
單位 | 科學空間
研究方向 | NLP、神經網路
眾所周知,完整訓練一次大型 LLM 的成本是昂貴的,這就決定了我們不可能直接在大型 LLM 上反覆測試超引數。一個很自然的想法是希望可以在同結構的小模型上仔細搜尋超引數,找到最優組合後直接遷移到大模型上。
儘管這個想法很樸素,但要實現它並不平凡,它需要我們瞭解常見的超引數與模型尺度之間的縮放規律,而 muP 正是這個想法的一個實踐。
muP,有時也寫 ,全名是 Maximal Update Parametrization,出自論文《Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer》[1],隨著 LLM 訓練的普及,它逐漸已經成為了科學煉丹的事實標配之一。
方法大意
在接入主題之前,必須先吐槽一下 muP 原論文寫得實在太過晦澀,並且結論的表達也不夠清晰,平白增加了不少理解難度,所以接下來筆者儘量以一種(自認為)簡明扼要的方式來複現 muP 的結論。
先說結論,muP 主要研究超引數跨模型尺度的遷移規律。這裡有幾個關鍵詞:
1. 超引數,目前主要指學習率
2. 模型尺度,目前主要是模型寬度
3. 這的核心是“遷移”
請注意,muP 不研究什麼是最優的超引數,只研究最優超引數隨著模型尺度的變化規律,所以我們需要在某個小模型上搜索最優的超引數組合,然後遷移到大模型上,這就是 muP 的使用場景和使用方法。
推導 muP 的原理是讓模型的前向傳播、反向傳播、損失增量和特徵變化都不隨模型尺度的變化而發生明顯變化:
1. 具體做法是分析初始化的數量級,然後認為結論可以代表後續最佳化的規律;
2. 說白了就是假設做好初始化,後面就會自動沿著正確的軌跡走(好的開始是成功的一大半?);
3. 當然也可以給這個假設講大數定律中心極限定理的故事,但個人認為非必須。
前向傳播
我們從前向傳播開始討論,因為這是相對簡單且成熟的部分。首先,考慮線性層 ,其中 。我們用RMS(Root Mean Square)來作為矩陣尺度的指標,例如
我們知道,要讓初始化階段  的 RMS 跟  的 RMS 大致相等(簡稱“穩定”),那麼  要用:
LeCun 初始化:“均值為 0、方差為 ”的隨機初始化。
這已經算是深度學習的基礎結論之一,所以不再展開推導,還不大瞭解的讀者可以參考以往的《從幾何視角來理解模型引數的初始化策略》[2]、《淺談Transformer的初始化、引數化與標準化》[3] 等文章。
接著,我們考慮非線性層 ,其中  是 Element-wise 的啟用函式。如果還是要維持  的 RMS 跟  的 RMS 近似相等,那麼結果會稍有不同,比如  啟用時我們得到
Kaiming 初始化:“均值為 0、方差為 ”的隨機初始化。
容易看出,Kaiming 初始化跟 LeCun 初始化相比,只是方差相差一個(跟模型尺度無關的)常數 2,可以證明其他啟用函式的結果也類似。所以我們可以下一個結論:
fan_in 初始化:要保證前向傳播的穩定性,那麼應該要用“均值為 0、方差正比於 ”的隨機初始化。
這個結論也可以理解為“啟用函式的影響是模型尺度無關的”,所以如果我們只想分析模型尺度的效應,那麼可以忽略(Element-wise 的)啟用函式的存在,由 LeCun 初始化直接得到縮放規律 
反向傳播
現在我們繼續分析反向傳播(梯度),注意這裡約定變數及其梯度具有相同的 shape,那麼可以算得
第一個公式是當前層內參數的梯度,第二個公式則是該層往前傳播的梯度, 是 Hadamard 積, 是  的導函式。
注意到一個事實:我們常用的啟用函式,其導數都可以被一個(尺度無關的)常數給 Bound 住,所以至少在數量級上我們可以寫出
我們先來看第二個公式,跟  相比,它右端乘的矩陣變成了 ,那麼按照上一節的結論,如果要保持反向傳播的 RMS 穩定性,那麼  的初始化就應該是:
fan_out 初始化:“均值為 0、方差為 ”的隨機初始化。
當  時,前向傳播和反向傳播的要求就出現衝突,這時候有人提了一個折中策略:
Xavier 初始化:“均值為 0、方差為 ”的隨機初始化。
這也叫“fan_avg 初始化”,因為就是將  和  簡單代數平均了一下,其他平均方式也可以考慮,參考《初始化方法中非方陣的維度平均策略思考》[4]。Xavier 初始化看上去同時兼顧了前向和反向,但也可以說兩者都沒兼顧,更好的辦法是設計模型讓大部分引數都是方陣,如後面討論的模型簇(8)。
損失增量
有了前向傳播和反向傳播的鋪墊,我們就可以嘗試分析損失函式的增量了。考慮  時損失函式的變化量
這裡的  是 Frobenius 內積,即把矩陣展平成向量後算向量內積。考慮梯度下降 ,這裡  自然是學習率,結合式(4),我們有
事實上,這個式子已經告訴了我們同一個學習率  不能跨模型尺度使用的原因:
1.  是一個  的矩陣;
2.  是  個數的平方和;
3.  正好是前向和反向的乘積;
4. 如果前向和反向都穩定,那麼  每個元素都是 
5. 所以  就是 
第 4 點可能要多加評述一下。 是一個  矩陣, 是一個  矩陣,兩者相乘就是  個  維向量對做內積,內積是  項求和,而損失  通常是對樣本求平均(即包含了除以  操作),所以如果  和  都是尺度無關的,那麼它們乘起來基本也是尺度無關的【即 RMS 都是 】。
最後的結論表明,如果我們直接將小模型的學習率用於大模型,那麼對於足夠大的模型,它的每一步損失增量就會隨著引數尺度(即 )的變大而爆炸,這意味著沒法複製小模型的收斂過程,甚至可能因為步子邁得太大導致無法收斂。
此時大家可能想到的一個做法是讓  來縮放 ,事實上這個想法已經跟上了 muP 的思路,但實際場景中由於前面說的前向和反向的不相容性,導致第 4 點“如果前向和反向都穩定,那麼  每個元素就是 ”不能總是成立,所以實際情況更為複雜一些。
模型假設
現在讓我們考慮一個更接近實踐的場景。我們的任務是訓練一個  的模型,其中  是資料決定的,不可改變。開頭我們就說了,muP 旨在研究超引數隨著模型尺度的縮放規律,所以一切固定不變的量,都相當於是常數或者說 ,比如初始化方差為 ,等價於說初始化方差為 
我們可以改變的是模型的架構、引數量等部分,但 muP 主要考慮寬度的規律,所以我們把模型的架構定一下。這裡主要考慮的模型簇是:
其中:
1. (帶上了 batch size);
2. 
3.  是任意  的神經網路;
4. 這裡  其實就是我們常說的 hidden size;
5. 我們可以隨意調大 ,來提升模型的引數量和潛力;
6. muP 就是想研究超引數關於  的變化規律。
更具體一點,這裡我們考慮的  是 K 層 MLP:
這裡 ,即都是  的方陣,全都用 fan_in 初始化(等價地,也是 fan_out 初始化)。
補充一下,這裡約定所有引數矩陣都是  方陣,純粹是為了簡化分析,並不是強制要求。因為這裡真正的目的是假設  的引數裡沒有尺度無關的形狀,比如不允許  這樣的形狀,因為 64 是一個常數,但  這樣的形狀是允許的,因為你不管 fan_in、fan_out 或 fan_avg 初始化,方差都是正比於 
組裝起來
確立後具體模型後,我們就可以把前面的結論都組裝起來了。要更新的引數分為  三部分,分別求梯度:
這裡的  運算需要稍微解釋一下: 都是一個矩陣,所以  原則上是一個四階張量,鏈式法則  實際是高階張量的乘法,但這裡不打算展開介紹了,所以簡單用一個  代替,讀者只需要知道它是矩陣乘法的一般推廣就行。
現在來觀察規律:
1. 三個式子都有 
2. 後兩式都有 
3.  裡都是方陣, 和  都是穩定的【RMS 是 】;
4. 如果  也用 fan_in 初始化,那麼  也是穩定的;
5. 要想  穩定,那麼初始化方差是 ,但  是尺度無關的,相當於常數。
這樣一來:
1.  的 RMS 是  是  個數平方和,所以大小是 ,別忘了  是常數,所以實際上就是 ,於是為了得到  的 ,它的學習率要滿足 
2.  是  個數求和, 和  的 RMS 都是 ,我們直接將  的初始化方差設為 ,那麼  的 RMS 就是 ,平方求和後就正好是 ,因此學習率不用變化;
3. 此時  的 RMS 也是 ,但  只是  個數平方和,所以結果是  的,為了得到  的 ,學習率反而需要放大  倍來抵消這個影響,即 
特徵變化
以上結果是沒有問題的,但仔細思考我們會發現推導過程的一個問題:上面的第 2、3 點,都建立在“我們直接將  的初始化方差設為 ”這個設定上,然而這個設定目前來說並沒有直接的依據。如果不對此進一步解釋,那麼推導過程還是不夠完備的。
事實上,單看  這個要求的話,確實是無法排除其他選擇的可能性的,比如  的初始化方差設為 ,此時  的 RMS 是 ,平方求和後是 ,那麼只要學習率  同樣可以實現 。因此,為了解釋 “ 的初始化方差設為 ”的必要性,那麼就需要引入新的條件。
損失函式  是模型的一個宏觀指標,或者說外部指標,單看它的變化已經不足以解釋全部結果了,那麼就需要細化到模型內部了。具體來說,我們希望模型每一層的輸出(通常也稱為特徵,有時也稱啟用值)變化量也具有尺度不變性。比如線性層 ,引數  帶來的輸出變化是
注意 ,所以  就是  個  維向量對的內積。
注意這裡  是精心設計的更新量,它不大可能跟初始化那樣跟  是獨立的,所以“ 維向量對的內積”更有可能是  維內積共有  項求和),因此如果  的 RMS 是 ,那麼可以認為  的 RMS 將是 
於是,為了讓  的 RMS 是 ,我們得到了對  的一個額外要求:
結合  和 ,我們就可以得到 “ 的初始化方差設為 ”的結果。
(注:這一節依賴於 @Chenyu Zheng 的指點,非常感謝!)
Adam 版本
以上就是 SGD 的 muP,對於 Adam,我們通常用 SignSGD 近似做數量級分析:
1. 
2. 
3. 這裡的  指每個元素取絕對值然後求和。
關於 SignSGD 近似本身,讀者還可以參考《當Batch Size增大時,學習率該如何隨之變化?》《Adam的epsilon如何影響學習率的Scaling Law?》等文章,這裡也不展開討論了。總而言之,SignSGD 是分析 Adam 相關縮放規律時一個常用的近似方式。
現在可以模仿 SGD 的過程進行分析:
1.  的 RMS 是  是  個數求和,大小是 ,所以它的學習率要滿足  來抵消尺度影響;
2.  是  個數求和, 和  的 RMS 都是 ,我們將  的初始方差設為 ,那麼  的 RMS 就是  個數求和後是 ,所以學習率按照  變換來抵消尺度影響;
3. 此時  的 RMS 也是 ,但  只是  個數求和,所以它已經是 ,從而學習率不用隨尺度改變。
(注:讀者可以自行檢查一下式(14)是滿足的。)
Muon 版本
接下來自然少不了 Muon 的分析。對於 Muon 本身,我們已經在《Muon最佳化器賞析:從向量到矩陣的本質跨越》、《Muon續集:為什麼我們選擇嘗試Muon?》[5] 做了詳細介紹,這裡不再重複。跟 Adam 用 SignSGD 類似,我們用 MSignSGD 來近似 Muon:
1. 
3. 這裡的  指 Nuclear 範數 [6],是矩陣的所有奇異值之和
4. Nuclear 範數並不好算,但 F 範數好算,它等於矩陣的所有奇異值的平方和的平方根
5. 我們用  範數作為 Nuclear 範數近似,因此 
6.  範數又等於矩陣的所有元素的平方和的平方根
那麼可以開始分析過程:
1.  的 RMS 是 ,所以  大小是 ,要消除尺度的影響,那麼它的學習率要滿足 
2.  是  個數的平方和的平方根, 和  的 RMS 都是 ,我們將  的初始方差設為 ,那麼  的 RMS 就是 ,平方和後再平方根,結果是 ,所以學習率不用變;
3. 此時  的 RMS 也是 ,但  只是  個數的平方和平方根,所以它是  的,學習率反而需要放大  倍來抵消這個影響,即 
(注:這裡 Muon 的結論是對的,但它不滿足條件(14),因為式(14)要細說的話還依賴於一個更新量是 Element-wise 的假設,而 Muon 不符合這個假設,所以實際上不可用。這裡沒有仔細展開相關討論,而是直接沿用了“ 的初始化方差設為 ”的結論,迴避了式(14)。)
結論彙總
將上述結論彙總在一起是:
這裡的  指的是除  外的所有引數,還有要強調的是,這裡的關係都是“正比於”而不是“等於”。另外實踐中可以根據具體需求稍作變化,比如實際我們用 Muon 時, 和  的最佳化通常不用 Muon 而是用 Adam,這將導致兩個變化:
1. 
2.  不變。
如果結合我們在《Muon is Scalable for LLM Training》[7] 所提的 Adujst LR 的話,那麼學習率要多乘一個  是引數矩陣的形狀,我們已經假設了  部分的引數總等比例縮放,所以 。因此,如果要抵消 Adujst LR 帶來的尺度影響,那麼就需要
3.  。
文章小結
本文以儘可能簡明清晰的方式介紹了  muP(Maximal Update Parametrization),這是旨在研究超引數跨模型尺度的遷移規律的工作。基於 muP,我們可以在小模型上以相對較小的成本仔細搜尋超引數(這裡主要是學習率和初始化),然後遷移到大模型上,降低大模型的煉丹成本。
客觀來講,這裡的介紹和分析還比較初步,比如沒有考慮 Bias 項、沒有評估結論在 MLP 以外架構的通用性、也沒有仔細考慮 Normalization 和殘差的作用等。
沒有考慮 Bias 項這個單純是偷懶,權當留給讀者的習題了;至於不同架構下的 muP,一般分析起來比較麻煩,但由於神經網路的相似性,結論大致上是相同的,我們可以不加證明地用著。
個人認為比較關鍵的改進點是 Normalization 和殘差的影響,尤其是 Normalization,它使得不依賴特殊的初始化就可以穩定前向傳播,帶來了更大的自由度和可能性。
當然,這些都留給後續分析了。
參考文獻
[1] https://papers.cool/arxiv/2203.03466
[2] https://kexue.fm/archives/7180
[3] https://kexue.fm/archives/8620
[4] https://kexue.fm/archives/8725
[5] https://kexue.fm/archives/10739
[6] https://en.wikipedia.org/wiki/Nuclear_norm
[7] https://papers.cool/arxiv/2502.16982
更多閱讀

#投 稿 通 道#
 讓你的文字被更多人看到 
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋樑,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。 
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析科研心得競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝 稿件基本要求:
• 文章確係個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標註 
• 稿件建議以 markdown 格式撰寫,文中配圖以附件形式傳送,要求圖片清晰,無版權問題
• PaperWeekly 尊重原作者署名權,並將為每篇被採納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬 投稿通道:
• 投稿郵箱:[email protected] 
• 來稿請備註即時聯絡方式(微信),以便我們在稿件選用的第一時間聯絡作者
• 您也可以直接新增小編微信(pwbot02)快速投稿,備註:姓名-投稿
△長按新增PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜尋「PaperWeekly」
點選「關注」訂閱我們的專欄吧
·


相關文章