
阿里妹導讀
本文描述DeepSeek的三個模型的學習過程,其中DeepSeek-R1-Zero模型所涉及的強化學習演算法,是DeepSeek最核心的部分之一會重點展示。
一、背景
隨著DeepSeek的火爆使用,其背後的訓練技術也值得深入學習,整體DeepSeek相關的訓練過程如下圖所示。

其中主要涉及以下三個模型,其中DeepSeek-R1-Zero模型所涉及的強化學習演算法,是DeepSeek最核心的部分之一,本次我們主要重現的也是這個部分。
1. DeepSeek-R1-Zero
是在基礎模型DeepSeek-V3上進行強化學習(RL)後得到了DeepSeek-R1-Zero模型。該模型學會了如何推理、建立思維鏈序列,並具備自我驗證和反思等能力。儘管DeepSeek-R1-Zero的學習能力令人驚歎,但它存在語言混合、可讀性差等嚴重問題。
2. DeepSeek-R1
首先使用數千個思維鏈(CoT)序列示例形式的冷啟動資料,在DeepSeek-V3上進行監督微調(SFT),目的是為強化學習建立一個更穩定的起點,解決DeepSeek-R1-Zero存在的問題。接著進行強化學習,並設定獎勵機制,以促進語言一致性,增強在科學、編碼和數學等任務上的推理能力。然後,再次進行監督微調,這次加入了非推理重點的訓練示例,幫助模型保留寫作、角色扮演等更多通用能力。最後,再次進行強化學習,以更好地符合人類偏好。最終得到了一個擁有6710億引數的高效能模型。
3. DeepSeek-R1-Distill*
他們基於Qwen和Llama架構,對引數在15億 – 700億之間的較小模型進行微調,得到了一組更輕量、更高效且推理能力更強的模型。這極大地提高了開發人員的可及性,因為許多提煉後的模型可以在他們的裝置上快速執行。
二、方案
1. 環境資訊
強化學習(TRL):主要採用了huggingface提供的grpo_trainer方案(參考連結:https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)
資料集:主要透過資料集gsm8k進行訓練
GPU: 單張A10,視訊記憶體24G
模型:Qwen2.5-0.5B-Instruct
2. 依賴安裝
# 基於目前最新的vllm 0.7.2進行驗證
pip install vllm -U
# 基於目前最新的trl 0.15.1進行驗證
pip install trl -U
3. 訓練
import re
import torch
from modelscope import AutoTokenizer, AutoModelForCausalLM
from modelscope.msdatasets import MsDataset
from trl import GRPOConfig, GRPOTrainer
SYSTEM_PROMPT = """
You need to answer in XML format, include <reasoning> and <answer>, respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str | None:
if"####"not in text:
return None
return text.split("####")[1].strip()
def get_gsm8k_questions(split="train") -> MsDataset:
data = MsDataset.load('modelscope/gsm8k', subset_name='main', split=split)
data = data.map(lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
dataset = get_gsm8k_questions()
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}",
f"\nExtracted:\n{extracted_responses[0]}")
return [2.0if r == a else0.0for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5if r.isdigit() else0.0for r in extracted_responses]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>"
responses = [completion[0]["content"] for completion in completions]
# 新增除錯日誌
matches = []
for idx, r in enumerate(responses):
print(f"\n--- Processing response {idx} ---")
print("Raw content:", repr(r)) # 使用 repr() 顯示跳脫字元
match = re.fullmatch(pattern, r, re.DOTALL)
print("Match result:", bool(match))
matches.append(match)
return [0.5if match else0.0for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]
return [0.5if match else0.0for match in matches]
def count_xml(text) -> float:
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1]) * 0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
output_dir = "outputs/Qwen-0.5B-GRPO"
run_name = "Qwen-0.5B-GRPO-gsm8k"
training_args = GRPOConfig(
output_dir=output_dir,
run_name=run_name,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
logging_steps=1,
bf16=True,
per_device_train_batch_size=8,
gradient_accumulation_steps=4,
num_generations=8,
max_prompt_length=256,
max_completion_length=200,
num_train_epochs=1,
save_steps=100,
max_grad_norm=0.1,
log_on_each_node=False,
use_vllm=True,
vllm_gpu_memory_utilization=.2,
vllm_device="cuda:0",
report_to="none"
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=None
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func],
args=training_args,
train_dataset=dataset,
)
trainer.train()
4. reward_funcs(獎勵函式)
如上面程式碼所示,主要涉及以下5個獎勵函式
4.1. correctness_reward_func(正確性獎勵函式)
檢查模型的輸出是否與參考答案 (answer) 完全匹配,匹配則獎勵 2.0,否則 0.0。
4.2. int_reward_func(整數檢測獎勵函式)
檢查模型輸出是否是純數字(整數),是則獎勵 0.5,否則 0.0。
4.3. strict_format_reward_func(嚴格格式獎勵函式)
嚴格格式獎勵,必須完全匹配 <reasoning>…</reasoning><answer>…</answer>,包括其中的換行符,都必須滿足格式,如果符合格式的獎勵 0.5,否則 0.0。
4.4. soft_format_reward_func(寬鬆格式獎勵函式)
允許更靈活的格式,只要包含 <reasoning>…</reasoning> 和 <answer>…</answer>,即獎勵 0.5,對比嚴格模式更加寬鬆
4.5. count_xml,xmlcount_reward_func(XML 結構評分函式)
計算模型輸出 XML 結構的完整度,並給予相應獎勵。獎勵規則:
檢查 XML 結構完整度:
每個正確的標籤匹配增加 0.125 獎勵:
<reasoning>\\n:+0.125
</reasoning>\\n:+0.125
<answer>\\n:+0.125
</answer>:+0.125
考慮額外文字的懲罰:
如果 </answer> 後面有多餘的內容,則減少獎勵 0.001 × 額外字元數
5. 訓練引數
核心引數說明如下:
1.gradient_accumulation_steps=4:每進行4次的前向傳播和反向傳播後,才會執行一次權重更新;
2.max_completion_length=200: 表示限制模型返回最大長度200;
3.save_steps=100:表示每執行100步才儲存一次checkpoint;
gsm8k資料集一共接近8000條資料,每4次會更新一次,則需要更新2000次,每100步儲存一次,則需要生成20個checkpoint。
三、過程日誌分析
1. 日誌分析
透過python train.py > train.log執行程式碼,透過tail -f train.log進行即時日誌檢視,最後整體效果如下圖所示,最後有效資料1868個,執行時間是2:25:25。

2. 訓練資料分析
GRPO Trainer會記錄很多訓練過程中的指標,主要包括在:
-
completion_length:完成時長; -
reward/{reward_func_name}:每個 reward 函式計算的獎勵; -
reward:平均獎勵; -
reward_std :獎勵組內的平均標準差; -
kl : 根據完成次數計算的模型和參考模型之間的平均 KL 散度。
其中我們主要關注以下兩個獎勵指標:
-
準確性獎勵:基於響應的正確性(對應correctness_reward_func) -
格式獎勵:確保響應符合結構指南(對應strict_format_reward_func和soft_format_reward_func)
2.1. 準確性獎勵

2.2. 格式獎勵


四、推理驗證
1. 微調前的模型
格式和答案都不對,而且不穩定:

2. 微調後的模型
格式和答案都滿足要求:

五、思考
透過對比微調前後的模型,雖然我們這次使用的是一個0.5B的小模型,資料量也不大,但是還是可以透過這個流程,體驗強化學習的整個流程,對我們理解強化學習還是很有好處的。並且從整個實驗中,也理解了DeepSeek整個方案設計的原因,其中以下幾個點印象深刻。
1. 訓練資料分析
透過對訓練後的獎勵函式資料進行分析發現,其中模型的格式獎勵函式strict_format_reward_func和soft_format_reward_func,都是在訓練到固定步數左右的時候,得分開始突然上升,然後後續就逐漸穩定,如下圖所示。可以看到,寬鬆校驗在500步的時候已經基本穩定到0.5的分數,而由於嚴格模式對格式更加嚴格,所以嚴格模式在1000步的時候才到穩定。透過這樣的資料,可以指導我們下一步進行實驗資料調整,從而獲取最佳的checkponit模型進行匯出。


2. 冷啟動的問題
我們可以看到模型在早期訓練的時候,效果很差,模型基本都是在瞎試。所以為了加快訓練,deepseek加入了SFT的資料解決冷啟動的問題,如下面的截圖所示。透過R1-Zero生成SFT的資料,解決了R1的冷啟動問題。

Lindorm泛時序資料一站式解決方案
隨著業務增長帶來的資料量激增,如何高效地獲取和分析這些資料成為業務洞察和決策的關鍵挑戰,Lindorm作為阿里雲自研的雲原生多模資料庫,具備低成本儲存、彈性高可用的能力,提供一站式的分析與洞察。
點選閱讀原文檢視詳情。