谷歌超硬核教科書來了!JeffDean帶貨揭Gemini訓練秘籍:在TPU上scaling


新智元報道  

編輯:KingHZ
【新智元導讀】谷歌團隊釋出LLM硬核技術教科書,從「系統檢視」揭秘LLM Scaling的神秘面紗。Jeff Dean強調書中藏著谷歌最強AI模型Gemini訓練的更多資訊。
由於深度學習的「黑箱」本性,從業者自我調侃道:
如果說深度神經網路是現代版的「鍊金術」,我們在古代就是「鍊金術士」。
2018年5月3日, Science發表新聞,標題直指「鍊金術」,強調加強AI的科學基礎
但這次的谷歌的團隊,卻有不一樣的看法:
在許多方面,深度學習仍然帶有一定的「鍊金術」色彩,但理解和最佳化模型效能並不一定如此——即使是在大規模環境下!
近日,谷歌DeepMind科學家Jacob Austint在X上, 釋出了基於JAX和TPU的大模型Scaling教科書《How to Sacle Your Model》。
Jeff Dean轉發原帖,並打起了廣告:
谷歌最強的Gemini模型的訓練,重度依賴JAX軟體棧+TPU硬體平臺。
如果你想了解更多詳情,來看看這本超棒的書:「How to Sacle Your Model」。
進入教科書網站,可以看到大寫的標題:「如何擴大模型規模(How to  Sacle Your Model)」。
正如小標題所言,這本書關於在張量處理單元(TPU)上大語言模型的的系統觀點。
這是一本關於LLM底層硬核技術的教科書,簡介如下:
訓練大語言模型(LLMs)常常讓人感覺就像鍊金術,但理解和最佳化模型的效能其實並不複雜。
本書的目標是揭開在TPU上擴充套件語言模型的科學謎團:TPU是如何工作的,它們如何相互通訊,LLM在實際硬體上是如何執行的,以及在訓練和推理過程中如何對模型進行並行化,以便在大規模執行時實現高效性。
如果你想知道「訓練這個LLM需要多貴的成本」、「要自己部署這個模型需要多少記憶體」或者「什麼是AllGather」這些問題的答案,希望本書能對你有所幫助。
教科書連結:https://jax-ml.github.io/scaling-book/
模型Scaling,無需恐懼
三四年前,大多數機器學習研究人員,可能並不需要了解模型擴充套件(model scaling)。
但如今,即便是「較小」的模型,也已經逼近硬體極限,因此研究要有真正的創新性,就必須考慮如何在大規模環境下提高效率。
作者詳細解釋了為什麼要模型擴充套件及其目標:
如果某種方法能在基準測試中提升20%的效能,但同時使Roofline效率下降20%,那麼這樣的最佳化是沒有意義的。
許多有前景的模型架構最終失敗,並不是因為它們在理論上不可行,而是因為它們無法高效擴充套件,或者沒有人投入足夠的精力去最佳化它們的計算效率。
模型擴充套件的目標是在增加用於訓練或推理的晶片數量時,實現吞吐量的線性增長,這被稱為 「強擴充套件」(Strong Scaling)。
通常,增加晶片數量(即「平行計算」)可以減少計算時間,但同時也會帶來額外的晶片間通訊開銷。如果通訊時間超過了計算時間,就會遇到 「通訊瓶頸」,導致無法實現理想的擴充套件效能。如果對硬體足夠了解,能夠預測這些瓶頸的出現位置,就可以透過調整模型設計或重新配置系統來規避它們。
看不懂這些,也沒關係,畢竟這是谷歌最強模型Gemini同款的技術棧
但作者誠意十足,表示:如果認真看完後,有晦澀的地方,請及時反饋,保證一定改。
作者保證

:必有所得

從處理單個加速器到處理數萬個加速器,相對簡單的原則無處不在,瞭解這些原則可以讓你做很多有用的事情:
  • 粗略評估模型的各個部分與理論最優效能的接近程度。
  • 在不同規模下,合理選擇平行計算方案(如何在多個裝置間分配計算任務)。
  • 估算訓練和執行大型Transformer模型所需的成本和時間。
  • 設計能夠充分利用特定硬體特性的演算法。
  • 設計硬體時,基於對當前演算法效能瓶頸的明確理解來進行最佳化。
此書的目標是解釋TPU(以及 GPU)的工作原理,以及為了當前硬體上實現高效計算,Transformer架構如何不斷演化。
希望這些內容既能幫助研究人員設計新的模型架構,也能為工程師提供指導,以最佳化當前一代的大語言模型(LLM)計算效能
作者保證,讀完此書一定有所收穫:
在閱讀完本書後,應該能自信地為特定硬體平臺上的Transformer模型選擇最佳並行方案,並大致估算訓練和推理的耗時。
如果你仍然感到困惑,請告訴我們!我們希望知道如何讓這些內容更加清晰易懂。

基礎知識

要閱讀此書,作者提醒讀者:
對LLM(大語言模型)和Transformer架構有基本的瞭解,但不一定熟悉它們在大規模計算中的運作方式。
應該瞭解LLM訓練的基礎知識,並且最好對JAX有一定的瞭解。
下面的背景資料,有助於瞭解所需的基礎知識:
部落格連結:https://jalammar.github.io/illustrated-transformer/
JAX講義:https://github.com/rwitten/HighPerfLLMs2024
整體結構
在本書中,將解答以下問題:
  • 矩陣乘法的計算時間如何估算?在多大規模下,它的計算受限於計算能力、記憶體頻寬還是通訊頻寬?
  • TPU是如何連線在一起組成訓練叢集的?系統的各個部分分別具備多少頻寬?
  • 在多個TPU之間進行資料收集(gather)、分發(scatter)或重新分佈(re-distribute)需要多少時間?
  • 如何高效地計算跨裝置分佈的矩陣乘法?
這些內容能幫助讀者,深入理解LLM在現代硬體上的執行機制,並學會如何最佳化訓練和推理的效率。
《第1章》介紹屋頂線分析(Roofline Analysis),並探討限制模型擴充套件的關鍵因素,包括通訊、計算和記憶體。
《第2章》和《第3章》詳細講解TPU和現代GPU的工作原理,既包括作為獨立晶片的執行機制,也涵蓋了更關鍵的內容——它們如何透過晶片間互連(inter-chip links)形成一個計算叢集,並受到頻寬和延遲的限制。
五年前,機器學習領域的架構還十分多樣化包——括卷積神經網路、長短時記憶網路、多層感知機和Transformer等。如今,Transformer架構一家獨大
Transformer結構的每一個細節,都非常值得深入理解,包括:矩陣的具體尺寸、歸一化(Normalization)發生的位置、各部分包含多少引數和FLOPs(浮點運算次數)。
《第4章》將詳細解析Transformer的數學計算,幫助你掌握如何計算訓練和推理過程中的引數量和FLOPs。
這些計算將揭示:
  • 模型的記憶體佔用有多大?
  • 計算和通訊的時間消耗分佈如何?
  • 注意力機制(Attention)和前饋網路(Feed-Forward Blocks)何時成為計算的瓶頸?
透過這些分析,將能夠更精確地最佳化Transformer訓練和推理的效率,並更深入地理解其計算特性。
圖示2:標準Transformer層,每個矩陣乘法(matmul)以圓圈中的點表示。所有引數(不包括歸一化層)以紫色顯示。
《第5章:訓練》和《第7章:推理》是本書的核心內容,在這兩章中將討論一個根本問題:
給定一個大小和一定數量晶片的模型,如何將模型並行化,以保持在「強擴充套件」(strong scaling)範疇內?
這個看似簡單的問題,其實有著令人意外的複雜答案。
從高層次來看,主要有四種並行化技術用於將模型分佈到多個晶片上:資料並行(Data Parallelism)、張量並行(Tensor Parallelism)、流水線並行(Pipeline Parallelism)以及專家並行(Expert Parallelism)。
圖3:純資料並行(前向傳播)示意圖。啟用(Activations)(左側)完全按照批次維度(batch dimension) 進行分片。
這種方法透過將批次分配到多個 TPU 上,實現了資料並行,從而在沒有額外通訊負擔的情況下,加速模型的計算。
此外,還有多種技術可以減少記憶體需求,比如重新計算(Rematerialization)、最佳化器/模型分片(Optimizer/Model Sharding,也稱為ZeRO)、主機解除安裝(Host Offload)、梯度累積(Gradient Accumulation)。
在這兩章中將討論這些技術,並幫助理解如何在新的架構或設定中選擇最適合的並行化策略。
《第6章》和《第8章》是實際操作教程,應用這些概念於LLaMA-3,更直觀地理解如何在實際應用中進行操作。
最後,《第9章》和《第10章》將討論如何在JAX中實現這些想法,並介紹當代碼出現問題時如何進行效能分析和除錯。
在《第11章》中,會給出進一步閱讀清單和更深入的參考文獻。
在整個過程中,會給出一些需要自己動手解決的問題。
作者溫馨提示:
請不要覺得有壓力要按順序閱讀所有章節,也不一定要全部閱讀完。
我們鼓勵你留下反饋意見。
目前這是草稿版本,未來會繼續修訂和改進。
當前的目錄,翻譯如下:
參考資料:
https://x.com/jacobaustin132/status/1886844716446007300
https://x.com/JeffDean/status/1886852442815652188
https://jax-ml.github.io/scaling-book/

相關文章