使用A10單卡24G復現DeepSeekR1強化學習過程

阿里妹導讀
本文描述DeepSeek的三個模型的學習過程,其中DeepSeek-R1-Zero模型所涉及的強化學習演算法,是DeepSeek最核心的部分之一會重點展示。
一、背景
隨著DeepSeek的火爆使用,其背後的訓練技術也值得深入學習,整體DeepSeek相關的訓練過程如下圖所示。
其中主要涉及以下三個模型,其中DeepSeek-R1-Zero模型所涉及的強化學習演算法,是DeepSeek最核心的部分之一,本次我們主要重現的也是這個部分。
1. DeepSeek-R1-Zero
是在基礎模型DeepSeek-V3上進行強化學習(RL)後得到了DeepSeek-R1-Zero模型。該模型學會了如何推理、建立思維鏈序列,並具備自我驗證和反思等能力。儘管DeepSeek-R1-Zero的學習能力令人驚歎,但它存在語言混合、可讀性差等嚴重問題。
2. DeepSeek-R1
首先使用數千個思維鏈(CoT)序列示例形式的冷啟動資料,在DeepSeek-V3上進行監督微調(SFT),目的是為強化學習建立一個更穩定的起點,解決DeepSeek-R1-Zero存在的問題。接著進行強化學習,並設定獎勵機制,以促進語言一致性,增強在科學、編碼和數學等任務上的推理能力。然後,再次進行監督微調,這次加入了非推理重點的訓練示例,幫助模型保留寫作、角色扮演等更多通用能力。最後,再次進行強化學習,以更好地符合人類偏好。最終得到了一個擁有6710億引數的高效能模型。
3. DeepSeek-R1-Distill*
他們基於Qwen和Llama架構,對引數在15億 – 700億之間的較小模型進行微調,得到了一組更輕量、更高效且推理能力更強的模型。這極大地提高了開發人員的可及性,因為許多提煉後的模型可以在他們的裝置上快速執行。
二、方案
1. 環境資訊
強化學習(TRL):主要採用了huggingface提供的grpo_trainer方案(參考連結:https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)
資料集:主要透過資料集gsm8k進行訓練
GPU: 單張A10,視訊記憶體24G
模型:Qwen2.5-0.5B-Instruct
2. 依賴安裝
# 基於目前最新的vllm 0.7.2進行驗證pip install vllm -U# 基於目前最新的trl 0.15.1進行驗證pip install trl -U
3. 訓練
import reimport torchfrom modelscope import AutoTokenizer, AutoModelForCausalLMfrom modelscope.msdatasets import MsDatasetfrom trl import GRPOConfig, GRPOTrainerSYSTEM_PROMPT = """You need to answer in XML format, include <reasoning> and <answer>, respond in the following format:<reasoning>...</reasoning><answer>...</answer>"""XML_COT_FORMAT = """\<reasoning>{reasoning}</reasoning><answer>{answer}</answer>"""def extract_xml_answer(text: str) -> str:    answer = text.split("<answer>")[-1]    answer = answer.split("</answer>")[0]return answer.strip()def extract_hash_answer(text: str) -> str | None:if"####"not in text:return Nonereturn text.split("####")[1].strip()def get_gsm8k_questions(split="train") -> MsDataset:    data = MsDataset.load('modelscope/gsm8k', subset_name='main', split=split)    data = data.map(lambda x: {'prompt': [            {'role': 'system', 'content': SYSTEM_PROMPT},            {'role': 'user', 'content': x['question']}        ],'answer': extract_hash_answer(x['answer'])    })return datadataset = get_gsm8k_questions()# Reward functionsdef correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:    responses = [completion[0]['content'] for completion in completions]    q = prompts[0][-1]['content']    extracted_responses = [extract_xml_answer(r) for r in responses]    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}",          f"\nExtracted:\n{extracted_responses[0]}")return [2.0if r == a else0.0for r, a in zip(extracted_responses, answer)]def int_reward_func(completions, **kwargs) -> list[float]:    responses = [completion[0]['content'] for completion in completions]    extracted_responses = [extract_xml_answer(r) for r in responses]return [0.5if r.isdigit() else0.0for r in extracted_responses]# def strict_format_reward_func(completions, **kwargs) -> list[float]:#     pattern = r"\n<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"#     responses = [completion[0]["content"] for completion in completions]#     matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]#     return [0.5 if match else 0.0 for match in matches]def strict_format_reward_func(completions, **kwargs) -> list[float]:    pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>"    responses = [completion[0]["content"] for completion in completions]    # 新增除錯日誌    matches = []for idx, r in enumerate(responses):        print(f"\n--- Processing response {idx} ---")        print("Raw content:", repr(r))  # 使用 repr() 顯示跳脫字元        match = re.fullmatch(pattern, r, re.DOTALL)        print("Match result:", bool(match))        matches.append(match)return [0.5if match else0.0for match in matches]def soft_format_reward_func(completions, **kwargs) -> list[float]:    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"    responses = [completion[0]["content"] for completion in completions]    matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]return [0.5if match else0.0for match in matches]def count_xml(text) -> float:    count = 0.0if text.count("<reasoning>\n") == 1:        count += 0.125if text.count("\n</reasoning>\n") == 1:        count += 0.125if text.count("\n<answer>\n") == 1:        count += 0.125        count -= len(text.split("\n</answer>\n")[-1]) * 0.001if text.count("\n</answer>") == 1:        count += 0.125        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001return countdef xmlcount_reward_func(completions, **kwargs) -> list[float]:    contents = [completion[0]["content"] for completion in completions]return [count_xml(c) for c in contents]model_name = "Qwen/Qwen2.5-0.5B-Instruct"output_dir = "outputs/Qwen-0.5B-GRPO"run_name = "Qwen-0.5B-GRPO-gsm8k"training_args = GRPOConfig(    output_dir=output_dir,    run_name=run_name,    learning_rate=5e-6,    adam_beta1=0.9,    adam_beta2=0.99,    weight_decay=0.1,    warmup_ratio=0.1,    lr_scheduler_type='cosine',    logging_steps=1,    bf16=True,    per_device_train_batch_size=8,    gradient_accumulation_steps=4,    num_generations=8,    max_prompt_length=256,    max_completion_length=200,    num_train_epochs=1,    save_steps=100,    max_grad_norm=0.1,    log_on_each_node=False,    use_vllm=True,    vllm_gpu_memory_utilization=.2,    vllm_device="cuda:0",    report_to="none")model = AutoModelForCausalLM.from_pretrained(    model_name,    torch_dtype=torch.bfloat16,    device_map=None).to("cuda")tokenizer = AutoTokenizer.from_pretrained(model_name)tokenizer.pad_token = tokenizer.eos_tokentrainer = GRPOTrainer(    model=model,    processing_class=tokenizer,    reward_funcs=[        xmlcount_reward_func,        soft_format_reward_func,        strict_format_reward_func,        int_reward_func,        correctness_reward_func],    args=training_args,    train_dataset=dataset,)trainer.train()
4. reward_funcs(獎勵函式)
如上面程式碼所示,主要涉及以下5個獎勵函式

4.1.  correctness_reward_func(正確性獎勵函式)

檢查模型的輸出是否與參考答案 (answer) 完全匹配,匹配則獎勵 2.0,否則 0.0。

4.2. int_reward_func(整數檢測獎勵函式)

檢查模型輸出是否是純數字(整數),是則獎勵 0.5,否則 0.0。

4.3. strict_format_reward_func(嚴格格式獎勵函式)

嚴格格式獎勵,必須完全匹配 <reasoning>…</reasoning><answer>…</answer>,包括其中的換行符,都必須滿足格式,如果符合格式的獎勵 0.5,否則 0.0

4.4. soft_format_reward_func(寬鬆格式獎勵函式)

允許更靈活的格式,只要包含 <reasoning>…</reasoning><answer>…</answer>,即獎勵 0.5,對比嚴格模式更加寬鬆

4.5. count_xml,xmlcount_reward_func(XML 結構評分函式)

計算模型輸出 XML 結構的完整度,並給予相應獎勵。獎勵規則:
檢查 XML 結構完整度:
每個正確的標籤匹配增加 0.125 獎勵:
<reasoning>\\n:+0.125
</reasoning>\\n:+0.125
<answer>\\n:+0.125
</answer>:+0.125
考慮額外文字的懲罰:
如果 </answer> 後面有多餘的內容,則減少獎勵 0.001 × 額外字元數
5. 訓練引數
核心引數說明如下:
1.gradient_accumulation_steps=4:每進行4次的前向傳播和反向傳播後,才會執行一次權重更新;
2.max_completion_length=200: 表示限制模型返回最大長度200;
3.save_steps=100:表示每執行100步才儲存一次checkpoint;
gsm8k資料集一共接近8000條資料,每4次會更新一次,則需要更新2000次,每100步儲存一次,則需要生成20個checkpoint。
三、過程日誌分析
1. 日誌分析
透過python train.py > train.log執行程式碼,透過tail -f train.log進行即時日誌檢視,最後整體效果如下圖所示,最後有效資料1868個,執行時間是2:25:25。
2. 訓練資料分析
GRPO Trainer會記錄很多訓練過程中的指標,主要包括在:
  • completion_length:完成時長;
  • reward/{reward_func_name}:每個 reward 函式計算的獎勵;
  • reward:平均獎勵;
  • reward_std :獎勵組內的平均標準差;
  • kl : 根據完成次數計算的模型和參考模型之間的平均 KL 散度。
其中我們主要關注以下兩個獎勵指標:
  • 準確性獎勵:基於響應的正確性(對應correctness_reward_func)
  • 格式獎勵:確保響應符合結構指南(對應strict_format_reward_func和soft_format_reward_func)

2.1. 準確性獎勵

2.2. 格式獎勵

四、推理驗證
1. 微調前的模型
格式和答案都不對,而且不穩定:
2. 微調後的模型
格式和答案都滿足要求:
五、思考
透過對比微調前後的模型,雖然我們這次使用的是一個0.5B的小模型,資料量也不大,但是還是可以透過這個流程,體驗強化學習的整個流程,對我們理解強化學習還是很有好處的。並且從整個實驗中,也理解了DeepSeek整個方案設計的原因,其中以下幾個點印象深刻。
1. 訓練資料分析
透過對訓練後的獎勵函式資料進行分析發現,其中模型的格式獎勵函式strict_format_reward_func和soft_format_reward_func,都是在訓練到固定步數左右的時候,得分開始突然上升,然後後續就逐漸穩定,如下圖所示。可以看到,寬鬆校驗在500步的時候已經基本穩定到0.5的分數,而由於嚴格模式對格式更加嚴格,所以嚴格模式在1000步的時候才到穩定。透過這樣的資料,可以指導我們下一步進行實驗資料調整,從而獲取最佳的checkponit模型進行匯出。
2. 冷啟動的問題
我們可以看到模型在早期訓練的時候,效果很差,模型基本都是在瞎試。所以為了加快訓練,deepseek加入了SFT的資料解決冷啟動的問題,如下面的截圖所示。透過R1-Zero生成SFT的資料,解決了R1的冷啟動問題。
Lindorm泛時序資料一站式解決方案
隨著業務增長帶來的資料量激增,如何高效地獲取和分析這些資料成為業務洞察和決策的關鍵挑戰,Lindorm作為阿里雲自研的雲原生多模資料庫,具備低成本儲存、彈性高可用的能力,提供一站式的分析與洞察。    
點選閱讀原文檢視詳情。

相關文章