「古董」GPU也能跑DeepSeek同款GRPO!視訊記憶體只需1/10,上下文爆漲10倍


新智元報道  

編輯:KingHZ
【新智元導讀】開源微調神器Unsloth帶著黑科技又來了:短短兩週後,再次最佳化DeepSeek-R1同款GRPO訓練演算法,上下文變長10倍,而視訊記憶體只需原來的1/10!
開源微調神器Unsloth帶著黑科技又來了:上次更新把GRPO需要的記憶體見到了7GB,這次只需要5GB的VRAM,就能訓練自己的推理模型Qwen2.5(1.5B),比上次要少2GB。
這次徹底把推理模型訓練視訊記憶體打下來了!
這次把GRPO訓練推理模型的上下文變長10倍,同時需要的視訊記憶體少了90%。
使用最新的Unsloth,只要5GB視訊記憶體就能訓練自己的推理模型,而且Qwen2.5-1.5B不會損失準確率。
5GB視訊記憶體什麼概念呢?
16年開始發售的GPU比如GTX 1060的視訊記憶體都有8GB。16年GTX 1060放到現在,堪稱電子古董!
目前,實現更長的上下文是GRPO面臨的最大挑戰之一。
與其他GRPO LoRA/QLoRA實現相比,即使是基於Flash Attention 2(FA2)的實現,Unsloth新推出的高效GRPO演算法上下文長度增加了10倍,同時使用的VRAM只要10%。
在配備TRL+FA2的GRPO設定中,Llama 3.1(8B)在20K上下文長度下,訓練需要510.8GB的VRAM。
而Unsloth將VRAM減少了90%,降至僅54.3GB。
減少長上下文90%VRAM
和使用Flash Attention 2的標準實現相比,Unsloth使用多種技巧,巧妙地把GRPO的VRAM使用量減少了90%多!
在20K的上下文長度下,每個提示生成8次,Unsloth在Llama-3.1-8B模型上僅使用54.3GB的VRAM,而標準實現需要510.8GB(Unsloth減少了90%)。這一切得益於下列3項突破:
  1. 全新設計的記憶體高效線性演算法:將GRPO的記憶體使用量削減了8倍以上,節省了68.5GB的記憶體。藉助torch.compile,在num_generations=8和20K上下文長度下,實際上還更快。
  2. 利用了Unsloth已釋出的智慧梯度checkpoint演算法:將中間啟用值非同步解除安裝到系統RAM中,速度僅慢了1%。由於需要num_generations=8,這節省了高達372GB的VRAM。透過中間梯度累積,甚至可以進一步減少記憶體使用。
  3. 與底層推理引擎(vLLM)共享相同的GPU/CUDA記憶體空間,不像其他包中的實現那樣。這又節省了16GB的VRAM。
Unsloth和基於Flash Attention 2(FA2)的標準實現記憶體比較
在典型的GRPO標準實現中,需要建立兩個大小為(8,20K)的logits來計算GRPO損失。這需要2*2位元組*8(生成次數)*20K(上下文長度)*128256(詞彙表大小)=78.3GB的VRAM。
Unsloth將長上下文GRPO的記憶體使用量削減了8倍,因此對於20K的上下文長度,只需要額外的9.8GBVRAM!
還需要以16位格式儲存KV快取。Llama3.18B有32層,K和V的大小均為1024。因此,對於20K的上下文長度,記憶體使用量=2*2位元組*32層*20K上下文長度*1024=每個批次2.5GB。
可以將vLLM的批次大小設定為8,但為了節省VRAM,在計算中將其保持為1。否則,需要20GB來儲存KV快取。
數學原理
分組相對策略最佳化(Group Relative Policy Optimization,GRPO),出自DeepSeek去年發表的論文。
如果一生只能讀一篇DeepSeek的論文,網友建議選擇首次提出GRPO的DeepSeekMath論文。
論文連結:https://arxiv.org/abs/2402.03300
隨後在DeepSeek的論文中,利用GRPO演算法建立了DeepSeek-R1。

發現的問題

在這裡利用了Hugging Face的TRL GRPO實現。
注意到,TRL實現的公式如下:
其中使用的是反向KL散度(而不是正向KL散度)。β是一個設為0.04的縮放因子,A是考慮所有獎勵函式後得到的優勢值。q是新訓練的模型,P是原始參考模型。
然後注意到,該實現將反向KL散度計算為:
但這真的是正確的嗎?
首先嚐試推導並整理類似項:
這意味著什麼?實現中可能缺少一個與q(新分佈項)的乘法嗎?
但這似乎是正確的,和DeepSeek-Math論文第14頁首次引入GRPO時一樣。
DeepSeek-Math論文第14頁:在損失函式中新增KL散度,正則化GRPO演算法
同樣,John Schulman的部落格也提到,反向KL項的無偏估計,實際上並不需要額外的q項。
連結地址:http://joschu.net/blog/kl-approx.html
在部落格中看到:
還發現了一個有趣的現象:
torch.exp(q-q.detach()) * advantages.unsqueeze(1)
這應該等於1,對嗎?
Hugging Face的TRL GRPO實現
實際上,發現這是必要的——似乎自動梯度autograd引擎可能無法正確傳播梯度。
因此,進行了4個實驗:
  1. 使用參考實現的常規GRPO(紅線)
  2. 移除detach程式碼(藍線)
  3. 按照之前討論的完整反向KL,新增額外項(黃線)
  4. 使用正向KL散度代替(綠線)
總體來說,移除detach顯然會破壞訓練,所以必須保留它——這很可能需要進一步調查。其他實現似乎也類似?可能需要執行模型更長時間,以觀察不同的效果。
在所有實現中,還利用了logsumexp技巧:
Unsloth高效GRPO演算法

但沒想到華人工程師Horace He的線性交叉熵實現,帶給unsloth靈感併成功應用於GRPO!

Horace He,在Meta從事PyTorch相關工作
實際上,unsloth發現了一些令人驚訝的要點:
1 GRPO參考實現使用的是反向KL散度,而不是正向KL散度。
2 如果不正確處理,在float16混合精度(以及float8)上直接實現線性交叉熵,並使用自動混合精度縮放機制,會導致崩潰。
3 發現了GRPO損失實現中的其他一些奇怪之處,主要是在反向KL散度的公式表述方面。
線性交叉商連結:https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899
其他功能

GRPO的完整日誌記錄

之前,unsloth只顯示總聚合獎勵函式本身,新版本為所有獎勵函式提供完整的日誌記錄詳情!
也不再需要呼叫函式來給GRPO打補丁了!也就是說,新版本會自動處理,可以刪除下列程式碼:
from unsloth import PatchFastRLPatchFastRL("GRPO", FastLanguageModel)

vLLM推理選項

現在在vLLM中還能使用FP8 KV快取,這可以在較新的GPU(RTX 3090、A100及更新型號)上將KV快取空間使用量減少2倍。
model,tokenizer = FastLanguageModel.from_pretrained(model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",max_seq_length = max_seq_length,load_in_4bit = True, fast_inference = True, max_lora_rank = lora_rank,gpu_memory_utilization = 0.6, float8_kv_cache = True, )
如果想在vLLM中使用min_p=0.1或其他取樣引數,也支援傳遞vLLM的SamplingParams引數中的任何內容!
max_prompt_length = 256fromtrl import GRPOConfig, GRPOTrainerfromunsloth import vLLMSamplingParamsvllm_sampling_params = vLLMSamplingParams(min_p = 0.1,seed = 3407,...)training_args = GRPOConfig(...vllm_sampling_params = vllm_sampling_params,temperature = 1.5,)
參考資料:
https://unsloth.ai/blog/grpo

相關文章