Deepseek的RL演算法GRPO解讀

MLNLP

社群是國內外知名的機器學習與自然語言處理社群,受眾覆蓋國內外NLP碩博生、高校老師以及企業研究人員。


社群的願景是促進國內外自然語言處理,機器學習學術界、產業界和廣大愛好者之間的交流和進步,特別是初學者同學們的進步。
來源 | 知乎
作者|AIQL
在本文中,我們將深入探討Deepseek採用的策略最佳化方法GRPO,並順帶介紹一些強化學習(Reinforcement Learning, RL)的基礎知識,包括PPO等關鍵概念。

策略函式(policy)

在強化學習中, 表示在狀態 下采取動作 的條件機率。具體來說,它是由策略函式 決定的。

詳細說明

  • 表示在時間步 時的狀態(state)。
  • 狀態是環境對智慧體的當前描述,例如在遊戲中可能是角色的位置、速度等資訊。
  • 表示在時間步 時智慧體採取的動作(action)。
  • 動作是智慧體在給定狀態下可以執行的操作,例如在遊戲中可能是“向左移動”或“跳躍”。
  • 是策略函式(policy),表示在狀態 下選擇動作 的機率。
  • 如果是確定性策略, 會直接輸出一個確定的動作;如果是隨機策略,它會輸出一個動作的機率分佈。
  • 在 PPO 中, 是新策略 和舊策略 在狀態 下選擇動作 的機率比。
  • 這個比值用於衡量策略更新的幅度,並透過裁剪機制限制其變化範圍,確保訓練的穩定性。

舉例說明

假設我們有一個簡單的遊戲環境:
  • 狀態 :角色的位置。
  • 動作 :可以執行的動作是“向左”或“向右”。
  • 策略 :在某個位置 下,策略可能以 70% 的機率選擇“向左”,以 30% 的機率選擇“向右”。
在 PPO 中,我們會比較新舊策略在相同狀態 下選擇相同動作 的機率,從而計算機率比 ,並用於最佳化目標函式。

總結

表示在狀態 下選擇動作 的條件機率,由策略函式 決定。在 PPO 中,這一機率用於計算新舊策略的比值,從而控制策略更新的幅度。

近端策略最佳化(PPO)

PPO(Proximal Policy Optimization)是一種用於強化學習的策略最佳化演算法,由 OpenAI 提出。它透過限制策略更新的幅度,確保訓練過程的穩定性。

核心思想

PPO 的核心在於限制策略更新的幅度,避免因更新過大導致效能下降。它透過引入“裁剪”機制,控制新舊策略之間的差異。

公式

PPO 的替代目標函式 用於最佳化策略 ,公式如下:
其中:
期望符號 表示對查詢 和輸出 的期望:
  • : 查詢 從分佈 中取樣。
  • : 輸出 由舊策略 生成。
對輸出 的每個時間步 求平均:
  • 是輸出序列的長度。
其核心目標函式為:
其中:
  • 是新舊策略的機率比。
  • 是優勢函式,衡量動作的相對好壞。
  • 是裁剪引數,通常為 0.1 或 0.2。

步驟

  1. 取樣:使用當前策略與環境互動,收集資料,在語言模型中,可以類比為生成補全(generating completions)。
  2. 計算優勢值:基於收集的資料計算優勢值函式 。
  3. 最佳化目標函式:透過梯度上升最佳化目標函式 。
  4. 更新策略:重複上述步驟,直到策略收斂。

優點

  • 穩定性:透過裁剪機制,避免策略更新過大。
  • 高效性:比 TRPO,PPO 實現更簡單,計算效率更高。

補充

在強化學習中,策略的目標是最大化期望回報,而不是最小化損失。所以,在PPO中使用的是梯度上升,原因在於它的最佳化目標是最大化目標函式(如強化學習中的期望回報),而不是最小化損失函式(如分類或迴歸問題)。

Advantage(優勢函式)

定義
Advantage函式用於衡量在某個狀態(State)下,採取某個動作(Action)相對於平均表現的優劣程度。它的數學定義為:, 其中:
  • 動作值函式,表示在狀態 下采取動作 後,未來累積回報的期望。
  • 狀態值函式,表示在狀態 下,按照當前策略採取動作後,未來累積回報的期望。
  • 優勢函式,表示在狀態 下采取動作 比平均表現好多少(或差多少)。

作用

  • Advantage函式用於指導策略更新:
  • 如果 ,說明動作 比平均表現更好,策略應該更傾向於選擇這個動作;
  • 如果 ,說明動作 比平均表現更差,策略應該減少選擇這個動作的機率。
  • 在PPO等演算法中,Advantage函式通常透過GAE(Generalized Advantage Estimation)來估計。

直觀理解

Advantage函式就像一個“評分”,告訴模型某個動作在當前狀態下是好還是壞,以及好(或壞)的程度。

KL Penalty(KL散度懲罰)

定義
KL Penalty是基於KL散度(Kullback-Leibler Divergence)的一種正則化手段。KL散度用於衡量兩個機率分佈之間的差異。在強化學習中,KL Penalty通常用於限制當前策略 和參考策略 之間的差異。其數學定義為: 其中:
  • 是當前策略(由模型引數 決定)。
  • 是參考策略(通常是更新前的策略或某個基線策略)。
  • 是KL散度,用於衡量兩個策略之間的差異。

作用

  • KL Penalty用於防止策略更新過大,確保當前策略不會偏離參考策略太遠。這樣可以避免訓練過程中的不穩定現象(如策略崩潰)。
  • 在PPO等演算法中,KL Penalty通常被新增到目標函式中,作為正則化項。

直觀理解

KL Penalty就像一個“約束”,告訴模型在更新策略時不要“步子邁得太大”,以免失去穩定性。

Advantage和KL Penalty的關係

  • Advantage 用於指導策略更新,告訴模型哪些動作更好。
  • KL Penalty 用於約束策略更新,防止策略變化過大。
  • 在PPO等演算法中,Advantage和KL Penalty共同作用,既鼓勵模型選擇更好的動作,又確保更新過程穩定可靠。

舉例說明

假設我們訓練一個機器人走迷宮:
  • Advantage:機器人發現“向右轉”比“向左轉”更容易找到出口,於是Advantage函式會給“向右轉”一個正的值,鼓勵策略更傾向於選擇“向右轉”。
  • KL Penalty:為了防止策略突然變得只選擇“向右轉”而忽略其他可能性,KL Penalty會限制策略的變化幅度,確保策略更新是平滑的。

總結

  • Advantage(優勢函式):衡量某個動作比平均表現好多少,用於指導策略更新。
  • KL Penalty(KL散度懲罰):限制策略更新的幅度,確保訓練過程的穩定性。

群體相對策略最佳化(GRPO)

GRPO 是一種線上學習演算法(online learning algorithm),這意味著它透過使用訓練過程中由訓練模型自身生成的資料來迭代改進。GRPO 的目標直覺是最大化生成補全(completions)的優勢函式(advantage),同時確保模型保持在參考策略(reference policy)附近。
其目標函式為:
為了理解 GRPO 的工作原理,可以將其分解為四個主要步驟:
  1. 生成補全(Generating completions)
  2. 計算優勢值(Computing the advantage)
  3. 估計KL散度(Estimating the KL divergence)
  4. 計算損失(Computing the loss)

1. 生成補全(Generating completions)

在每一個訓練步驟中,我們從提示(prompts)中取樣一個批次(batch),併為每個提示生成一組 個補全(completions)(記為 )。

2. 計算優勢值(Computing the advantage)

對於每一個 序列,使用獎勵模型(reward model)計算其獎勵(reward)。為了與獎勵模型的比較性質保持一致——通常獎勵模型是基於同一問題的輸出之間的比較資料集進行訓練的——優勢的計算反映了這些相對比較。其歸一化公式如下:
這種方法賦予了該方法其名稱:群體相對策略最佳化(Group Relative Policy Optimization, GRPO)
GRPO透過最佳化PPO演算法,解決了計算優勢值時需要同時依賴獎勵模型(reward model)和價值模型(value model)的問題,成功移除了value model(價值模型),顯著降低了推理時的記憶體佔用和時間開銷。Advantage(優勢值)的核心價值在於為模型輸出提供更精準的評估,不僅衡量答案的絕對質量,還透過相對比較(與其他回答的對比)來更全面地定位其優劣。

3. 估計KL散度(Estimating the KL divergence)

在實際演算法實現中,直接計算KL散度可能會面臨一些挑戰:
  • 計算複雜度高:KL散度的定義涉及對兩個機率分佈的對數比值的期望計算。對於複雜的策略分佈,直接計算KL散度可能需要大量的計算資源;
  • 數值穩定性:在實際計算中,直接計算KL散度可能會遇到數值不穩定的問題,尤其是當兩個策略的機率分佈非常接近時,對數比值可能會趨近於零或無窮大。近似器可以透過引入一些數值穩定性的技巧(如截斷或平滑)來避免這些問題;
  • 線上學習:在強化學習中,策略通常需要在每一步或每幾步更新一次。如果每次更新都需要精確計算KL散度,可能會導致訓練過程變得非常緩慢。近似器可以快速估計KL散度,從而支援線上學習和即時更新。
Approximating KL Divergence 提出的近似器可以根據當前策略和參考策略的差異動態調整估計的精度,從而在保證計算效率的同時,儘可能減少估計誤差,其定義如下:
這個近似器的核心思想是透過對當前策略和參考策略的機率比值的簡單變換來估計KL散度。具體來說:
  • 第一項: 是參考策略與當前策略的機率比值。
  • 第二項: 是對數機率比值。
  • 第三項: 是一個常數項,用於調整近似器的偏差。
這個近似器的優勢在於它只需要計算當前策略和參考策略的機率比值,而不需要直接計算KL散度的積分或期望。因此,它可以在保證一定精度的同時,顯著降低計算複雜度。
近似器的直觀理解
這個近似器的設計靈感來自於泰勒展開。KL散度可以看作是兩個分佈之間的某種“距離”,而這個近似器透過一階或二階近似來估計這個距離。具體來說:
  • 當 和 非常接近時,,此時 ,近似器的值趨近於零,符合KL散度的性質。
  • 當 和 差異較大時,近似器會給出一個較大的正值,反映出兩個分佈之間的差異。

4. 計算損失(Computing the loss)

這一步的目標是最大化優勢,同時確保模型保持在參考策略附近。因此,損失定義如下:
其中第一項表示縮放後的優勢,第二項透過KL散度懲罰與參考策略的偏離。
在原始論文中,該公式被推廣為在每次生成後透過利用裁剪替代目標(clipped surrogate objective)進行多次更新:
其中 透過將策略比率限制在 和 之間,確保更新不會過度偏離參考策略。
在很多程式碼實現,比如Huggingface的TRL中,與原始論文一樣每次生成只進行一次更新,因此可以將損失簡化為第一種形式。

總結

GRPO透過最佳化PPO演算法,移除了價值模型,降低了計算開銷,同時利用群體相對優勢函式和KL散度懲罰,確保策略更新既高效又穩定。
想象一下,你是個銷售員,這個月業績10萬塊,PPO演算法就像個精明的老會計,拿著算盤噼裡啪啦一頓算,考慮市場行情、產品型別,最後得出結論:“嗯,這10萬還算靠譜,但GAE一算,發現你的優勢值還不夠高,還得再加把勁啊”
而GRPO呢,就像老闆直接搞了個“內卷大賽”,把所有銷售員拉到一個群裡,每天曬業績:“你10萬,他15萬,她20萬……”老闆還時不時發個紅包,刺激大家繼續卷。你的10萬塊在群裡瞬間被淹沒,老闆搖搖頭:“你這水平,還得加把勁啊!”
GRPO這招絕了,它把PPO的“算盤”扔了,省了不少計算功夫,直接搞“內卷PK”,用KL散度懲罰來確保大家別躺平。這樣一來,策略更新既快又穩,老闆再也不用擔心有人摸魚了,畢竟大家都在拼命卷,誰敢鬆懈?

寫在最後

總結一下:PPO是“單打獨鬥看實力”,GRPO是“內卷大賽拼到死”,最後GRPO還省了算盤錢,老闆笑得合不攏嘴,而我們只能默默加班,繼續卷。
技術交流群邀請函
△長按新增小助手
掃描二維碼新增小助手微信
請備註:姓名-學校/公司-研究方向
(如:小張-哈工大-對話系統)
即可申請加入自然語言處理/Pytorch等技術交流群

關於我們

MLNLP 社群是由國內外機器學習與自然語言處理學者聯合構建的民間學術社群,目前已經發展為國內外知名的機器學習與自然語言處理社群,旨在促進機器學習,自然語言處理學術界、產業界和廣大愛好者之間的進步。
社群可以為相關從業者的深造、就業及研究等方面提供開放交流平臺。歡迎大家關注和加入我們。

相關文章