
來源 | 機器之心
普林斯頓大學計算機科學系助理教授陳丹琦團隊又有了新論文了。
近期,諸如「長思維鏈」等技術的興起,帶來了需要模型生成數萬個 token 的全新工作負載。
大多數語言模型都基於 Transformer 架構,其在進行自迴歸解碼(即逐字生成文字)時,需要將所有先前 token 的注意力狀態儲存在一個名為 KV 快取的記憶體區域中。
KV 快取是模型進行快速推理的基石,但它的大小會隨著輸入文字的長度線性增長。例如,使用 Llama-3-70B 模型處理一個長度為 128K token 的提示(這大約相當於 Llama 3 技術報告本身的長度),就需要分配高達 42GB 的記憶體專門用於儲存 KV 快取。
許多先前的工作意識到了這個問題,並提出了從記憶體中丟棄(驅逐)部分鍵值對的方法,以實現所謂的「稀疏注意力」。然而,在一個公平的環境下對它們進行橫向比較卻異常困難。
生成過程 = 預填充(對輸入進行前向傳播並儲存鍵值對)+ 後填充(一次解碼一個輸出詞元)。
有些論文旨在加速預填充階段;另一些則忽略該階段,轉而致力於最小化後填充階段的記憶體開銷。同樣,有的研究側重於吞吐量,而另一些則著力於最佳化記憶體使用。
陳丹琦團隊提出了「KV 足跡」作為一種統一的度量標準,它是在所有時間步中,未被逐出的鍵值快取條目所佔比例的聚合值。這一個指標就同時涵蓋了預填充和解碼兩個階段的全部開銷,使得在同等基礎上比較不同方法成為可能。
-
論文標題:Cache Me If You Can: How ManyKVsDoYouNeed for Effective Long-Context LMs?
-
論文地址:https://arxiv.org/pdf/2506.17121v1
-
程式碼地址: https://github.com/princeton-pli/PruLong
為了確保比較的實用價值,團隊定義了「關鍵 KV 足跡」:即在模型效能相對於完整的全注意力機制不低於 90% 的前提下,一個方法所能達到的最小 KV 足跡。這個「90% 效能」的硬性標準,確保了我們比較的是真正有用的、未嚴重犧牲模型能力的最佳化方法。
該度量標準揭示了先前 KV 驅逐方法存在的高峰值記憶體問題。其中後填充驅逐由於與預填充階段的驅逐不相容,導致其 KV 足跡非常高。團隊對這類方法進行了改進,使其能夠在預填充期間驅逐 KV,從而顯著降低了 KV 足跡。
接著,團隊轉向「新近度驅逐」方法,並在此基礎上提出了 PruLong,這是一種端到端的最佳化方法,用於學習哪些注意力頭需要保留完整的 KV 快取,而哪些則不需要。PruLong 在節省記憶體的同時保持了長上下文效能,其 KV 足跡比先前的方法小 12%,並且在具有挑戰性的召回任務中保持了原有的效能。
KV 快取驅逐的統一框架
測量關鍵的 KV 佔用空間
給定一個包含
,基於 Transformer 的語言模型通常分兩個階段來生成一個響應
:

個 token 的提示語


-
預填充
整個輸入序列
在一次前向傳播過程中被處理。每個注意力頭 的鍵值狀態



-
解碼
逐個解碼生成

這些 token,每次生成時都會讀取並更新 KV 快取。
KV 快取的儲存消耗會隨著提示長度和生成長度的增加而線性增長,研究人員提出了許多方法來解決這一開銷問題。總體而言,這些方法透過稀疏化注意力模式,從而允許某些 KV 條目被驅逐。
然而,這些方法針對推理流程的不同階段進行了定製:有些方法在預填充階段之後丟棄 KV 條目,而另一些方法則在預填充階段也對 KV 快取進行修剪。這使得對不同方法進行公平且全面的比較變得困難。首先探討為何常用的 KV 快取大小指標無法衡量模型在實際應用中的實用性。
在實際應用中,對長上下文進行單次前向傳播的預填充操作成本高昂。對於長輸入序列,將輸入序列分割成多個塊,並在多次前向傳播中處理這些塊的分塊預填充方法正日益成為標準實踐。這種方法通常能夠減少與長輸入相關的峰值 GPU 記憶體佔用,並使得較短提示的解碼過程能夠與較長提示的額外塊同時進行。
此外,像多輪對話或交錯工具呼叫等場景,還需要多個解碼和預填充階段,這就需要一種全面的方法來衡量 KV 佔用空間。而推測性解碼進一步模糊了預填充階段和解碼階段之間的界限,因為解碼過程變得更加依賴計算資源。
在考慮預填充和解碼過程中都進行多次前向傳播的推理情況時,「KV 佔用空間」應考慮隨時間變化的記憶體使用情況。例如,它應反映出在分塊預填充過程中,是否在預填充完成之前就已經驅逐了 KV 條目。
具體的推理過程由輸入長度、輸出長度以及因方法而異的實現細節來表徵。由於缺乏能夠捕捉所有這些細微差別的指標,本研究提出了一種理想化的指標,該指標能夠:(1)跟蹤整個預填充和解碼過程中的 KV 快取記憶體使用情況;(2)考慮每個 KV 條目的生命週期,從而實現對不同方法的公平且全面的比較。
本研究檢查這些方法的注意力模式(圖 1),並將每個鍵值(KV)條目分類為:活躍的(在當前步驟中使用)、非活躍的(在當前步驟中儲存但未使用)或被驅逐的(在任何未來的步驟中都未使用,並從記憶體中移除)。本研究將 KV 佔用空間定義為所有時間步中未被驅逐的注意力條目的數量。該數值被歸一化為完全因果注意力。
例如,在圖 1 中,KV 佔用空間為 。一種理想的方法會盡早驅逐 KV,以儘量減少佔用空間。本研究考慮了另一種指標,該指標跟蹤注意力矩陣中的峰值 KV 佔用率。在實驗中,這兩種指標得出的結論相似。
本研究還討論了方法與實際效能指標(如總令牌吞吐量和 GPU 記憶體利用率)之間的關係。研究發現,在許多情況下,KV 佔用空間與吞吐量密切相關,但具體的排名取決於 KV 驅逐之外的實現細節——不同方法在不同實現框架下的實際效率差異很大。
關鍵 KV 佔用空間:以往的研究通常在固定的稀疏度水平下報告任務效能,但本研究認為,一個更有意義的指標是在保留大部分原始效能的情況下所能達到的稀疏度。本研究將關鍵 KV 佔用空間定義為一種方法在長上下文任務中保留完整注意力效能的一部分(本文中
)時所需的最小佔用空間。低於此閾值,效能下降可能會過於嚴重,導致該方法無法繼續使用。

高效長上下文推理的現有方法
本研究調研了高效的長上下文方法,並討論了它們如何契合本研究的 KV 佔用空間框架。表 1 概述了主要方法,展示了這些方法如何進行不同的權衡以及使用不同的稀疏性概念。
動態和預填充稀疏性方面:Native Sparse Attention、MoBA、QUEST 和 TokenButler 將 KV 快取視為兩級層次結構,僅將相關的注意力塊從高頻寬記憶體(HBM)載入到片上 SRAM 進行處理。像 MInference 和 FTP 這類技術,在預填充階段使用動態稀疏注意力來近似全注意力。動態稀疏性方法會產生更多非活躍的 KV,能夠提升吞吐量,但它們並未減少 KV 記憶體,因此這些方法與本研究的關注點正交。
近期性驅逐:先前的研究確定了流式注意力頭,這些注意力頭僅關注區域性滑動視窗和一組初始的「匯聚令牌」。驅逐遠距離的鍵值(KV)條目會大幅減少 KV 佔用空間(圖 2),因為在上下文長度增加時,KV 快取的大小保持固定,並且這種方法可在預填充和解碼過程中應用。然而,近期性驅逐可能會「遺忘」相關的遠距離上下文,這促使 DuoAttention 和 MoA 僅將一部分注意力頭轉換為流式頭。作為 KV 快取壓縮的有前景的候選方法,後續將更詳細地討論這些方法。
後填充驅逐:我們使用「後填充驅逐」這一術語來指代在預填充階段結束後從鍵值(KV)快取中刪除令牌的方法。這些方法依賴於通常基於注意力分數的啟發式規則來識別上下文中最重要鍵值對。這些方法可以在預填充後大量修剪鍵值對,並在解碼過程中減少 KV 記憶體。然而,在具有長提示和短生成的推理場景中,由於所有 KV 條目在預填充期間都儲存在記憶體中,這也會在驅逐前導致相當大的峰值記憶體,後填充驅逐只能實現有限的 KV 佔用空間減少。
正交技術:量化透過降低 KV 快取的精度而非基數來節省記憶體,並且可以與本文考慮的任何方法結合使用。另一個方向是在預訓練新語言模型之前設計記憶體高效的架構。這可能涉及在查詢或層之間重用 KV 狀態,降低鍵值維度,或者交錯全域性和區域性注意力層。其他方法是用迴圈層、線性注意力或狀態空間層替換 softmax 注意力。這些方法與 KV 驅逐正交。
PruLong:一種用於注意力頭專業化的端到端方法
本研究探討過:驅逐「陳舊」鍵值對(KVs)雖能顯著降低記憶體佔用,但可能導致重要歷史資訊的丟失。這一發現推動了後續研究工作,旨在識別哪些注意力頭關注全域性上下文、哪些聚焦區域性上下文,從而僅對區域性注意力頭中的 KVs 執行驅逐操作。
DuoAttention 將注意力頭分為兩類:檢索頭,從整個上下文中召回相關資訊;流式頭,僅關注最近的 token 和輸入序列開頭的少量「匯聚」token。DuoAttention 透過將注意力機制表示為流式注意力和全注意力的疊加,並透過引數化來學習注意力頭的型別。
其中,
雖然 DuoAttention 在實證中表現出色,但團隊發現了幾種進一步降低其關鍵 KV 佔用空間的方法。團隊結合這些見解,設計出 PruLong(長程精簡注意力機制),一種用於 KV 驅逐的端到端方法。PruLong 像 DuoAttention 一樣將注意力頭分為兩類,但在訓練目標、引數化和訓練資料方面進行了創新。接下來將依次介紹這些內容。
-
下一個 token 預測損失
PruLong(長程精簡注意力機制)直接最小化混合注意力模型的下一個 token 預測損失,而非最後一個隱藏狀態的重建誤差,這與這些模型在文字生成中的使用方式更為契合。
-
針對注意力型別最佳化離散掩碼
DuoAttention 學習一個連續的門控變數

,該變數易於最佳化,但沒有反映出在推理過程中

會被四捨五入為 0 或 1,因此引入了訓練-測試差距。PruLong(長程精簡注意力機制)將

視為從由

引數化的伯努利分佈中抽取的二進位制掩碼,並透過來自剪枝文獻的既定方法——將伯努利分佈重新引數化為硬實體隨機變數,實現端到端最佳化。最終目標如下

其中,
(正則化損失)透過約束掩碼整體稀疏度
(稀疏度函式)逼近目標值



(目標稀疏度),該過程透過 min-max 最佳化實現——
-
利用自然長上下文資料
PruLong 利用自然長上下文資料。DuoAttention 的合成訓練資料僅需要簡單的長程回憶能力,而實際應用場景可能需要更復雜的能力。PruLong 由高天宇等人在自然長上下文預訓練資料上進行訓練,這些資料包含程式碼倉庫和書籍等,具有多樣的長程依賴關係。
PruLong 論文地址:https://arxiv.org/abs/2410.02660
技術交流群邀請函
△長按新增小助手
掃描二維碼新增小助手微信
請備註:姓名-學校/公司-研究方向
(如:小張-哈工大-對話系統)
即可申請加入自然語言處理/Pytorch等技術交流群
關於我們
MLNLP 社群是由國內外機器學習與自然語言處理學者聯合構建的民間學術社群,目前已經發展為國內外知名的機器學習與自然語言處理社群,旨在促進機器學習,自然語言處理學術界、產業界和廣大愛好者之間的進步。
社群可以為相關從業者的深造、就業及研究等方面提供開放交流平臺。歡迎大家關注和加入我們。

掃描二維碼新增小助手微信
關於我們
