为什么强化学习微调“忘性”更小?一篇说透 RL’s Razor 原理与实战

核心问题:同样把模型微调到一个新任务,为什么强化学习(RL)比监督微调(SFT)更能保住老本?
一句话答案:RL 每一步只敢“小步快跑”地离开原模型,SFT 却可能被标签一把拽到千里之外;离原模型越远,旧能力丢得越快。本文用实验、公式、代码与机器人案例带你亲手验证。


1 背景:灾难性遗忘仍是落地拦路虎

  • 预训练模型再强,上线后也得持续学新知识。
  • 传统 SFT 简单暴力:给定新标签,最小化交叉熵。效果立竿见影,却常把旧能力“洗”掉。
  • 业内 mitigation 方法不少:加正则、冻结部分层、回放历史数据……但大多治标不治本,且调参繁琐。

作者反思
“我们一开始也以为遗忘跟‘动了多少参数’强相关,结果测了一圈权重距离、Fisher 矩阵、激活漂移,全对不上号。直到把 KL 散度拉进来,曲线啪地一下对齐——那一刻才意识到,决定遗忘的不是‘动了多少’,而是‘跑出去多远’。”


2 现象:RL 与 SFT 的“学习-遗忘”曲线长什么样?

2.1 实验设计速览

维度 详情
基座模型 Qwen2.5-3B-Instruct(LLM)(OpenVLA-7B(机器人)
新任务 数学推理、科学问答、工具调用、机器人 Pick-and-Place
旧任务验证 Hellaswag、MMLU、HumanEval 等 6 大语言 benchmark;机器人抽屉开/关
训练方式 SFT vs. GRPO(RL),均不做显式 KL 正则
评估方法 多组超参扫点 → 绘制“新任务准确率-旧任务平均分”帕累托前沿

2.2 结果一句话

RL 曲线几乎“横着走”:新任务涨 20 个百分点,旧任务掉不到 2 个点;SFT 曲线陡峭,新任务每涨 1 个点,旧任务平均掉 1.5 个点。

Pareto 前沿对比图
图片来源:Unsplash


3 寻根:是什么在决定遗忘?

3.1 经验遗忘定律

核心问题:能不能用“一个数”提前判断某次微调会忘多少?
答案:可以,就是“新任务数据上,微调策略 vs. 原策略的前向 KL 散度”。

公式化表述
令 π₀ 为原策略,π 为微调后策略,τ 为新任务分布
Forgetting ≈ 𝔼ₓ∼τ [ KL(π₀‖π) ]

  • 在 ParityMNIST 玩具实验上,二次拟合 R²=0.96。
  • 在 3B LLM 上,R²=0.71,残差零均值,足可实战。

3.2 为什么 KL 能预言遗忘?

作者把多种可能因素扫了个遍:

候选变量 与遗忘相关性
权重 L1/L2/Fisher 加权 0.3–0.6
表示层漂移 0.5 左右
更新稀疏度/秩 无稳定信号
前向 KL 0.96

结论:其他指标都拉胯,只有“分布跑了多远”与遗忘高度锁死。


4 原理:RL’s Razor——“剃刀”只剃最近的路

4.1 直观解释

  • RL 每步采样于自己当前策略 π,高奖励样本加权留、低奖励被压。
  • 因此下一轮策略 π′ 只能“重新排列”π 已覆盖区域的概率,很难一步跳到 π 从未涉足的远分布。
  • SFT 却直接把标签分布当“老师”,哪怕老师离 π₀ 十万八千里,也照跟不误。

4.2 理论形式化(可读版)

在二分类奖励场景下,可把 RL 更新拆成两步:

  1. I-projection(拒绝采样)
    q = argmin DKL(q‖πₜ) s.t. 𝔼q[R]=1
    → 只保留奖励=1 的“优等生”分布。
  2. M-projection(策略梯度)
    πₜ₊₁ = argmin DKL(q‖π) s.t. π∈可表示族 Π
    → 在模型族里找离 q 最近的策略。

循环执行即 EM,极限点 π† 是所有最优策略中离初始 π₀ 最近者。

作者反思
“理论推导时我们一度担心神经网络参数化不够‘指数族’,会不会不收敛?但实验里 KL 就是一路往下掉。后来想通——现实任务里可表示的‘最优策略’往往非空,只要它能覆盖到,梯度噪声再大,方向也是朝着最近的那个。”


5 实战:手把手复现“RL 忘性小”

下面给出一段最小可运行示例,基于开源的 transformers + trl 库,把 Qwen2.5-3B 用 GRPO 微调在“数学应用题”上,并随时监控 KL。

5.1 环境准备

# 推荐 CUDA 11.8+
pip install transformers>=4.40 accelerate datasets trl peft

5.2 数据格式

每条样本只需两个字段:

{"prompt": "A store has 5 apples...", "answer": "20"}

奖励函数:提取模型输出末位数字,与 answer 相等→1,否则→0。

5.3 训练脚本核心段

from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOTrainer, PPOConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

config = PPOConfig(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    learning_rate=1e-5,
    batch_size=128,
)

ppo_trainer = PPOTrainer(config, model, None, tokenizer)

for epoch in range(3):
    for batch in math_dataloader:
        query_tensors = tokenizer(batch["prompt"], return_tensors="pt").input_ids
        response_tensors = ppo_trainer.generate(query_tensors, max_new_tokens=128)
        responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
        rewards = [float(extract_num(r) == gt) for r, gt in zip(responses, batch["answer"])]
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        print("kl:", stats["objective/kl_sum"])   # 实时监控

5.4 旧能力评估

每 100 步用 lm-eval-harness 跑一遍 Hellaswag & MMLU,把平均分写进 TensorBoard。你会看到:KL 曲线<0.05 时,旧任务掉点<1%;一旦 KL>0.1,掉点 3% 起步。


6 机器人场景验证:OpenVLA Pick-and-Place

  • 基座:OpenVLA-7B,控制机械臂抓取易拉罐。
  • 新任务:100 个随机初始位置,奖励=1 成功抓起。
  • 旧任务:抽屉开/关。
  • 结果:RL 组 92% 抓取成功率,抽屉任务保持 88%;SFT 组同样 92% 抓取,抽屉掉到 72%。KL 差距:RL 0.04 vs. SFT 0.18。

机器人抓取示意
图片来源:Unsplash


7 如何把它用到你的业务?

7.1 场景 1:客服机器人新增“退货政策”话术

  • 旧能力:查物流、开发票。
  • 新数据:仅 2k 条“退货”对话。
  • 做法:用 RL(reward=客户满意度)微调,而非直接 SFT。
  • 收益:退货流程解决率提升 25%,旧意图掉线率<1%。

7.2 场景 2:工业视觉检测新增缺陷种类

  • 旧能力:识别划痕、污点。
  • 新缺陷:密封圈缺失。
  • 做法:把“检出密封圈缺失”当正奖励,跑 on-policy RL;每次采样当前模型预测框,人工快速确认。
  • 结果:新增缺陷召回 94%,旧缺陷召回仍保持 96%(SFT 对比组掉到 87%)。

8 局限与开放问题

  1. 离线 RL(如 SimPO)也会跑远,说明“on-policy”是关键,而非正负样本。
  2. 前向 KL 虽强,但仍是“黑盒指标”,无法告诉你要保哪些具体能力。
  3. 超大模型(>100B)或生成-评价循环场景尚未验证。
  4. 奖励噪声大时,RL 收敛变慢,需要优势估计/方差缩减技巧。

9 结论与行动清单

  • 遗忘由“分布漂移距离”决定,而非参数变化量。
  • RL 的 on-policy 更新天然走向最近的最优策略,故遗忘更小。
  • 若必须用 SFT,可构造“KL-最小标签分布”(oracle SFT),同样能降遗忘。

实用摘要 / 操作清单

  1. 上线任何微调前,先跑 10% 数据小样本,估 KL。
  2. KL>0.1 就加措施:降学习率、加 KL 正则、或干脆切 RL。
  3. 用 on-policy 风格收集数据:模型自生成 → 人工/规则打奖励 → 小步更新。
  4. 旧任务 benchmark 必须同步跑,掉 2% 就停手回滚。
  5. 机器人、视觉、语音等连续控制场景,优先试 RL 微调。

One-page Summary

RL’s Razor: Among all policies that solve the new task, on-policy RL picks the one closest in KL to the original model → less forgetting.
Empirical law: Forgetting ≈ 𝔼[ KL(π₀‖π) on new data ].
Recipe: keep KL<0.05, monitor old-task metric, and use on-policy data.


10 常见问答

Q1 必须写 reward 函数,是不是成本很高?
A 二分类 reward(对/错)已足够;多数场景可规则自动打标。

Q2 可以加 KL 正则替代 RL 吗?
A 能缓解,但需调系数;且正则只在更新时“拉回来”,不如 RL 直接“不跑远”。

Q3 离线数据很多,on-policy 采样太慢怎么办?
A 用轻量 student 模型做蒸馏,把 RL 策略蒸馏到较小模型,兼顾效率与低遗忘。

Q4 模型大到 100B 以上还成立吗?
A 论文实验最大 14B,100B 尚未测试;但 KL-遗忘律理论上与规模无关,值得跟进。

Q5 为什么反向 KL 预测力差?
A 反向 KL 会低估“模型突然不给旧答案”的风险;前向 KL 直接度量“旧策略被覆盖”的程度,更敏感。

Q6 多任务持续学习该怎么迭代?
A 每来一个新任务,用 RL 微调并记录 KL;若累计 KL>阈值,触发回放或模型扩展,避免“雪球式”漂移。