为什么强化学习微调“忘性”更小?一篇说透 RL’s Razor 原理与实战
核心问题:同样把模型微调到一个新任务,为什么强化学习(RL)比监督微调(SFT)更能保住老本?
一句话答案:RL 每一步只敢“小步快跑”地离开原模型,SFT 却可能被标签一把拽到千里之外;离原模型越远,旧能力丢得越快。本文用实验、公式、代码与机器人案例带你亲手验证。
1 背景:灾难性遗忘仍是落地拦路虎
-
预训练模型再强,上线后也得持续学新知识。 -
传统 SFT 简单暴力:给定新标签,最小化交叉熵。效果立竿见影,却常把旧能力“洗”掉。 -
业内 mitigation 方法不少:加正则、冻结部分层、回放历史数据……但大多治标不治本,且调参繁琐。
作者反思
“我们一开始也以为遗忘跟‘动了多少参数’强相关,结果测了一圈权重距离、Fisher 矩阵、激活漂移,全对不上号。直到把 KL 散度拉进来,曲线啪地一下对齐——那一刻才意识到,决定遗忘的不是‘动了多少’,而是‘跑出去多远’。”
2 现象:RL 与 SFT 的“学习-遗忘”曲线长什么样?
2.1 实验设计速览
维度 | 详情 |
---|---|
基座模型 | Qwen2.5-3B-Instruct(LLM)(OpenVLA-7B(机器人) |
新任务 | 数学推理、科学问答、工具调用、机器人 Pick-and-Place |
旧任务验证 | Hellaswag、MMLU、HumanEval 等 6 大语言 benchmark;机器人抽屉开/关 |
训练方式 | SFT vs. GRPO(RL),均不做显式 KL 正则 |
评估方法 | 多组超参扫点 → 绘制“新任务准确率-旧任务平均分”帕累托前沿 |
2.2 结果一句话
RL 曲线几乎“横着走”:新任务涨 20 个百分点,旧任务掉不到 2 个点;SFT 曲线陡峭,新任务每涨 1 个点,旧任务平均掉 1.5 个点。
图片来源:Unsplash
3 寻根:是什么在决定遗忘?
3.1 经验遗忘定律
核心问题:能不能用“一个数”提前判断某次微调会忘多少?
答案:可以,就是“新任务数据上,微调策略 vs. 原策略的前向 KL 散度”。
公式化表述
令 π₀ 为原策略,π 为微调后策略,τ 为新任务分布
Forgetting ≈ 𝔼ₓ∼τ [ KL(π₀‖π) ]
-
在 ParityMNIST 玩具实验上,二次拟合 R²=0.96。 -
在 3B LLM 上,R²=0.71,残差零均值,足可实战。
3.2 为什么 KL 能预言遗忘?
作者把多种可能因素扫了个遍:
候选变量 | 与遗忘相关性 |
---|---|
权重 L1/L2/Fisher 加权 | 0.3–0.6 |
表示层漂移 | 0.5 左右 |
更新稀疏度/秩 | 无稳定信号 |
前向 KL | 0.96 |
结论:其他指标都拉胯,只有“分布跑了多远”与遗忘高度锁死。
4 原理:RL’s Razor——“剃刀”只剃最近的路
4.1 直观解释
-
RL 每步采样于自己当前策略 π,高奖励样本加权留、低奖励被压。 -
因此下一轮策略 π′ 只能“重新排列”π 已覆盖区域的概率,很难一步跳到 π 从未涉足的远分布。 -
SFT 却直接把标签分布当“老师”,哪怕老师离 π₀ 十万八千里,也照跟不误。
4.2 理论形式化(可读版)
在二分类奖励场景下,可把 RL 更新拆成两步:
-
I-projection(拒绝采样)
q = argmin DKL(q‖πₜ) s.t. 𝔼q[R]=1
→ 只保留奖励=1 的“优等生”分布。 -
M-projection(策略梯度)
πₜ₊₁ = argmin DKL(q‖π) s.t. π∈可表示族 Π
→ 在模型族里找离 q 最近的策略。
循环执行即 EM,极限点 π† 是所有最优策略中离初始 π₀ 最近者。
作者反思
“理论推导时我们一度担心神经网络参数化不够‘指数族’,会不会不收敛?但实验里 KL 就是一路往下掉。后来想通——现实任务里可表示的‘最优策略’往往非空,只要它能覆盖到,梯度噪声再大,方向也是朝着最近的那个。”
5 实战:手把手复现“RL 忘性小”
下面给出一段最小可运行示例,基于开源的 transformers
+ trl
库,把 Qwen2.5-3B 用 GRPO 微调在“数学应用题”上,并随时监控 KL。
5.1 环境准备
# 推荐 CUDA 11.8+
pip install transformers>=4.40 accelerate datasets trl peft
5.2 数据格式
每条样本只需两个字段:
{"prompt": "A store has 5 apples...", "answer": "20"}
奖励函数:提取模型输出末位数字,与 answer 相等→1,否则→0。
5.3 训练脚本核心段
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOTrainer, PPOConfig
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
config = PPOConfig(
model_name="Qwen/Qwen2.5-3B-Instruct",
learning_rate=1e-5,
batch_size=128,
)
ppo_trainer = PPOTrainer(config, model, None, tokenizer)
for epoch in range(3):
for batch in math_dataloader:
query_tensors = tokenizer(batch["prompt"], return_tensors="pt").input_ids
response_tensors = ppo_trainer.generate(query_tensors, max_new_tokens=128)
responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
rewards = [float(extract_num(r) == gt) for r, gt in zip(responses, batch["answer"])]
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
print("kl:", stats["objective/kl_sum"]) # 实时监控
5.4 旧能力评估
每 100 步用 lm-eval-harness
跑一遍 Hellaswag & MMLU,把平均分写进 TensorBoard。你会看到:KL 曲线<0.05 时,旧任务掉点<1%;一旦 KL>0.1,掉点 3% 起步。
6 机器人场景验证:OpenVLA Pick-and-Place
-
基座:OpenVLA-7B,控制机械臂抓取易拉罐。 -
新任务:100 个随机初始位置,奖励=1 成功抓起。 -
旧任务:抽屉开/关。 -
结果:RL 组 92% 抓取成功率,抽屉任务保持 88%;SFT 组同样 92% 抓取,抽屉掉到 72%。KL 差距:RL 0.04 vs. SFT 0.18。
图片来源:Unsplash
7 如何把它用到你的业务?
7.1 场景 1:客服机器人新增“退货政策”话术
-
旧能力:查物流、开发票。 -
新数据:仅 2k 条“退货”对话。 -
做法:用 RL(reward=客户满意度)微调,而非直接 SFT。 -
收益:退货流程解决率提升 25%,旧意图掉线率<1%。
7.2 场景 2:工业视觉检测新增缺陷种类
-
旧能力:识别划痕、污点。 -
新缺陷:密封圈缺失。 -
做法:把“检出密封圈缺失”当正奖励,跑 on-policy RL;每次采样当前模型预测框,人工快速确认。 -
结果:新增缺陷召回 94%,旧缺陷召回仍保持 96%(SFT 对比组掉到 87%)。
8 局限与开放问题
-
离线 RL(如 SimPO)也会跑远,说明“on-policy”是关键,而非正负样本。 -
前向 KL 虽强,但仍是“黑盒指标”,无法告诉你要保哪些具体能力。 -
超大模型(>100B)或生成-评价循环场景尚未验证。 -
奖励噪声大时,RL 收敛变慢,需要优势估计/方差缩减技巧。
9 结论与行动清单
-
遗忘由“分布漂移距离”决定,而非参数变化量。 -
RL 的 on-policy 更新天然走向最近的最优策略,故遗忘更小。 -
若必须用 SFT,可构造“KL-最小标签分布”(oracle SFT),同样能降遗忘。
实用摘要 / 操作清单
-
上线任何微调前,先跑 10% 数据小样本,估 KL。 -
KL>0.1 就加措施:降学习率、加 KL 正则、或干脆切 RL。 -
用 on-policy 风格收集数据:模型自生成 → 人工/规则打奖励 → 小步更新。 -
旧任务 benchmark 必须同步跑,掉 2% 就停手回滚。 -
机器人、视觉、语音等连续控制场景,优先试 RL 微调。
One-page Summary
RL’s Razor: Among all policies that solve the new task, on-policy RL picks the one closest in KL to the original model → less forgetting.
Empirical law: Forgetting ≈ 𝔼[ KL(π₀‖π) on new data ].
Recipe: keep KL<0.05, monitor old-task metric, and use on-policy data.
10 常见问答
Q1 必须写 reward 函数,是不是成本很高?
A 二分类 reward(对/错)已足够;多数场景可规则自动打标。
Q2 可以加 KL 正则替代 RL 吗?
A 能缓解,但需调系数;且正则只在更新时“拉回来”,不如 RL 直接“不跑远”。
Q3 离线数据很多,on-policy 采样太慢怎么办?
A 用轻量 student 模型做蒸馏,把 RL 策略蒸馏到较小模型,兼顾效率与低遗忘。
Q4 模型大到 100B 以上还成立吗?
A 论文实验最大 14B,100B 尚未测试;但 KL-遗忘律理论上与规模无关,值得跟进。
Q5 为什么反向 KL 预测力差?
A 反向 KL 会低估“模型突然不给旧答案”的风险;前向 KL 直接度量“旧策略被覆盖”的程度,更敏感。
Q6 多任务持续学习该怎么迭代?
A 每来一个新任务,用 RL 微调并记录 KL;若累计 KL>阈值,触发回放或模型扩展,避免“雪球式”漂移。