Why Reinforcement Learning Fine-Tuning Forgets Less: Inside MIT’s “RL’s Razor”
What makes RL forget less than supervised fine-tuning?
It stays closest to the original model in KL-divergence on the new task—every update is a small, on-policy re-weighting rather than a lunge toward an arbitrary label distribution.
1 The Catastrophic-Forgetting Pain Is Still Real
One-sentence takeaway
Foundation models learn new tricks quickly, but they also lose old ones—unless you train with on-policy RL.
Summary
- Post-training is now the default path to adapt large models.
- Supervised Fine-Tuning (SFT) is easy to implement but notorious for erasing prior capabilities.
- Previous remedies (weight regularizers, replay, adapter layers) help only partially and add tuning burden.
- We still lacked a single, measurable lever that reliably predicts how much will be forgotten—until now.
Author reflection
“We began this project suspecting parameter distance was the culprit. After ablating everything from Fisher-weighted L2 to activation drift, nothing correlated with forgetting better than a humble KL measurement on the new task alone. That counter-intuitive moment reshaped our entire fine-tuning pipeline.”
2 RL vs. SFT: The Empirical Gap
Question answered
“At equal performance on the new task, which approach preserves old skills?”
Answer
On-policy RL wins by a large margin; SFT pays for every new-task gain with a measurable drop on prior benchmarks.
2.1 Experimental Setup at a Glance
| Component | Details |
|---|---|
| Base models | Qwen-2.5-3B-Instruct (LLM), OpenVLA-7B (vision-language-robotics) |
| New tasks | Math reasoning, Science QA, Tool-use, Sim-to-real Pick-&-Place |
| Prior benchmarks | HellaSwag, MMLU, TruthfulQA, HumanEval (LLM); drawer open/close (robot) |
| RL algorithm | GRPO (no explicit KL penalty) |
| Hyper-parameter sweep | >40 settings per method; full Pareto front reported |
2.2 Pareto Front Results
- Math task: SFT pushes new-task accuracy from 30 % to 50 %, but the prior-task average collapses by 20 points. RL reaches the same 50 % with <2 % loss on priors.
- Science QA & Tool-use: identical pattern; the RL frontier lies strictly above SFT's.
- Robot policy: RL attains 92 % grasp success while retaining 88 % on the drawer task; SFT matches the 92 % yet the drawer score falls to 72 %.
3 A Law of Forgetting: KL on the New Task Predicts Everything
Question answered
“Is there one number I can monitor to know how much old stuff my model will forget?”
Answer
Yes—expect prior-task degradation to grow quadratically with 𝔼ₓ∼τ[ KL(π₀‖π) ], measured only on the new-task inputs.
3.1 Deriving the ParityMNIST Toy Replica
To enable hundreds of converged runs, the authors recast MNIST as a parity problem:
- Label = even/odd.
- Multiple correct labelings exist (any even digit is “even”).
- Pre-train jointly on ParityMNIST + FashionMNIST; fine-tune on parity only; test on FashionMNIST to quantify forgetting.
Outcome: Plotting FashionMNIST accuracy vs. KL(π₀‖π) on ParityMNIST yields a single quadratic curve (R² = 0.96) regardless of training algorithm or specific labeling.
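For readers who want to rebuild the toy setup, here is a minimal sketch of the ParityMNIST / FashionMNIST pair using torchvision; the joint pre-training and parity-only fine-tuning loops are left out.

```python
# Minimal sketch of the ParityMNIST construction (assumes torchvision is installed).
# ParityMNIST = ordinary MNIST images relabeled as even (0) / odd (1);
# FashionMNIST keeps its original labels and is used only to measure forgetting.
from torchvision import datasets, transforms

to_tensor = transforms.ToTensor()

parity_mnist = datasets.MNIST(
    root="./data", train=True, download=True,
    transform=to_tensor,
    target_transform=lambda digit: digit % 2,   # 0 = even, 1 = odd
)

fashion_mnist = datasets.FashionMNIST(
    root="./data", train=True, download=True,
    transform=to_tensor,
)

# Pre-train one classifier jointly on both datasets, fine-tune on parity_mnist only,
# then re-evaluate on fashion_mnist to quantify forgetting.
```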
3.2 LLM Corroboration
The same quadratic pattern appears in the 3-billion-parameter language-model experiments with R² = 0.71; the residuals are mean-zero, validating the law at scale.
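If you log (KL, prior-task accuracy) pairs from your own runs, a few lines of NumPy are enough to check how well a quadratic describes them; the arrays below are placeholders for your measurements, not the paper's data.

```python
# Sketch: fit a quadratic to (KL, prior-task accuracy) pairs and report R².
# `kl` and `prior_acc` are placeholders for values logged from your own runs.
import numpy as np

kl = np.array([0.01, 0.03, 0.06, 0.10, 0.15])         # E_x[KL(pi_0 || pi)] on new-task inputs
prior_acc = np.array([0.64, 0.63, 0.60, 0.55, 0.48])  # accuracy on prior benchmarks

coeffs = np.polyfit(kl, prior_acc, deg=2)   # quadratic fit
pred = np.polyval(coeffs, kl)

ss_res = np.sum((prior_acc - pred) ** 2)
ss_tot = np.sum((prior_acc - prior_acc.mean()) ** 2)
r_squared = 1.0 - ss_res / ss_tot
print(f"quadratic coefficients: {coeffs}, R^2 = {r_squared:.2f}")
```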
4 Why RL Stays Close: The “RL’s Razor” Principle
Question answered
“What forces RL toward low-KL solutions?”
Answer
On-policy sampling plus re-weighting performs an iterative information projection that favours the nearest optimal distribution.
4.1 Intuitive Walk-Through
- Sample from the current π → obtain y.
- Compute the reward R(y); keep high-reward answers, down-weight low-reward ones.
- The policy-gradient step moves π only within its own support; it cannot jump to a far-away distribution in one step.
- Over iterations the policy inches along the minimal-KL path to the optimal set (a toy sketch of this loop follows the list).
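To make the walk-through concrete, here is a toy sketch of a few rounds of that loop on a four-answer categorical policy with a binary reward. The sizes, learning rate, and step count are arbitrary, and the value printed at the end is the forward KL(π₀‖π) tracked throughout this article; this illustrates the mechanism, not the paper's training code.

```python
# Toy on-policy re-weighting: sample from pi, reward, REINFORCE step, repeat.
# Because answers 0 and 1 are both correct, the extra mass spreads over both of them
# instead of collapsing onto one arbitrary label, keeping the policy close to pi_0.
import torch

torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)      # current policy pi over 4 candidate answers
pi_0 = torch.softmax(torch.zeros(4), dim=0)      # frozen base policy (uniform here)
reward = torch.tensor([1.0, 1.0, 0.0, 0.0])      # binary reward: answers 0 and 1 are "correct"

optimizer = torch.optim.SGD([logits], lr=0.5)

for step in range(20):
    probs = torch.softmax(logits, dim=0)
    samples = torch.multinomial(probs.detach(), num_samples=64, replacement=True)  # on-policy samples
    log_probs = torch.log(probs[samples])
    advantages = reward[samples] - reward[samples].mean()   # simple baseline
    loss = -(advantages * log_probs).mean()                 # REINFORCE objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    pi = torch.softmax(logits, dim=0)
    kl = torch.sum(pi_0 * (torch.log(pi_0) - torch.log(pi)))   # forward KL(pi_0 || pi)
    print(pi.tolist(), kl.item())
```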
4.2 Theoretical Backing (Binary-Reward Case)
- I-projection (E-step): Rejection-sampling from π onto {y : R(y) = 1} gives the closest feasible q in KL.
- M-projection (M-step): Policy gradient finds πₙ₊₁ minimizing KL(q‖π).
- Alternating these steps is an EM algorithm that converges to π† = argmin_{π : 𝔼_π[R] = 1} KL(π‖π₀).
Lemma (Rejection sampling = I-projection)
q_RS(y) = π(y|R=1) solves min_q KL(q‖π) s.t. 𝔼_q[R]=1.
Theorem (Policy gradient = M-projection)
For an exponential-family policy class, the procedure converges to the minimum-KL optimal policy.
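The lemma is easy to sanity-check numerically: conditioning the base policy on R = 1 should beat every other reward-1 distribution in KL, including an arbitrary one-hot labeling of the kind an SFT target might impose. The base distribution and reward pattern below are toy values chosen for illustration.

```python
# Numerical check: rejection sampling onto {R = 1} is the KL-closest reward-1 distribution.
import numpy as np

rng = np.random.default_rng(0)

pi_0 = np.array([0.40, 0.25, 0.15, 0.15, 0.05])   # base policy over five answers (toy numbers)
R = np.array([1, 1, 0, 1, 0])                     # binary reward: answers 0, 1 and 3 are correct

def kl(p, q, eps=1e-12):
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

# E-step / I-projection: keep the accepted mass and renormalise, i.e. q_RS = pi_0(. | R = 1).
q_rs = pi_0 * R / np.sum(pi_0 * R)

# Every distribution supported on the correct answers achieves E[R] = 1; compare their KLs.
support = np.flatnonzero(R == 1)
def embed(weights):
    q = np.zeros_like(pi_0)
    q[support] = weights
    return q

random_q = [embed(w) for w in rng.dirichlet(np.ones(len(support)), size=10_000)]
one_hot = embed(np.eye(len(support))[0])          # arbitrary single-label target, SFT-style

print(f"KL(q_RS || pi_0)          = {kl(q_rs, pi_0):.3f}")                      # the minimum
print(f"best of 10k random optima = {min(kl(q, pi_0) for q in random_q):.3f}")  # never smaller
print(f"KL(one-hot || pi_0)       = {kl(one_hot, pi_0):.3f}")                   # much larger
```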
Author reflection
“The EM angle surprised us—we were expecting a variance-reduction story. Once we drew the connection to information projections, hyper-parameters that previously seemed mysterious (e.g., why KL penalties often hurt sample efficiency) suddenly made sense: RL was already doing implicit KL minimisation.”
5 Controlled Ablations: On-Policy vs. Everything Else
Question answered
“Is the magic ingredient ‘negative examples’ or ‘on-policy data’?”
Answer
On-policy sampling; negative gradients alone do not shrink KL.
5.1 Four-Quadrant Comparison
| Algorithm | Sampling Source | Neg. Gradient | New-Task Acc. | Prior-Task Acc. | KL(π₀‖π) |
|---|---|---|---|---|---|
| GRPO | On-policy | Yes | 0.68 | 0.64 | 0.04 |
| 1-0 REINFORCE | On-policy | No | 0.67 | 0.63 | 0.05 |
| SimPO | Offline | Yes | 0.68 | 0.55 | 0.15 |
| SFT | Offline | No | 0.68 | 0.54 | 0.16 |
Both on-policy variants retain prior accuracy; both offline variants drop ≈10 points while KL more than triples.
6 Hands-On: Reproducing the Result on Your Laptop
Below is a minimal, GPU-friendly example that fine-tunes a small (~0.5 B-parameter) model on a toy reasoning set while printing the KL to the base policy after each step. The sketch assumes trl's classic PPOTrainer API; `toy_loader` (your prompt batches) and `reward_fn` (your binary reward) are placeholders. Swap in your data and scale up.
```python
# pip install transformers trl datasets torch   (assumes trl's classic PPOTrainer API)
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

from reward_fn import reward_fn  # your binary (0/1) reward code; module name is a placeholder

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
tok.pad_token = tok.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)  # PPO needs a value head

config = PPOConfig(
    model_name=model_name,
    learning_rate=1e-5,
    batch_size=64,          # toy_loader should yield 64 prompts per batch
    mini_batch_size=8,
    init_kl_coef=0.0,       # no KL penalty term: we want to watch the *raw* KL
    adap_kl_ctrl=False,
)
trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tok)

step = 0
for epoch in range(2):
    for batch in toy_loader:  # your data loader; yields dicts with a "prompt" field
        query_tensors = [tok(q, return_tensors="pt").input_ids.squeeze(0) for q in batch["prompt"]]
        response_tensors = trainer.generate(query_tensors, max_new_tokens=50, return_prompt=False)
        responses = tok.batch_decode(response_tensors, skip_special_tokens=True)
        rewards = [torch.tensor(float(reward_fn(r))) for r in responses]  # 0/1 reward
        stats = trainer.step(query_tensors, response_tensors, rewards)
        step += 1
        print(f"step {step}: KL={stats['objective/kl']:.4f}")  # mean KL to the frozen reference
```
Expected outcome
KL stays below 0.03; quick evaluation on any prior QA set shows <1 % drop.
7 Industry Scenarios Where RL’s Razor Wins
7.1 Customer-Support Dialogue
- Old skills: order lookup, refund status.
- New data: 3 k sentences on “warranty extension”.
- Approach: on-policy RL with reward = intent-match × user-sentiment (sketched below).
- Result: warranty accuracy +28 %, legacy intents −1.2 % (SFT baseline −9 %).
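As a rough illustration of that composite reward, the sketch below multiplies an intent-match check by a sentiment score; `intent_classifier` and `sentiment_model` are hypothetical stand-ins for whatever classifiers you already run.

```python
# Sketch of a composite reward: intent match (0/1) times user sentiment (0..1).
# `intent_classifier` and `sentiment_model` are hypothetical placeholders for your own models.
def dialogue_reward(response: str, expected_intent: str,
                    intent_classifier, sentiment_model) -> float:
    predicted_intent = intent_classifier(response)   # e.g. "warranty_extension"
    intent_match = 1.0 if predicted_intent == expected_intent else 0.0
    sentiment = sentiment_model(response)            # assumed to return a score in [0, 1]
    return intent_match * sentiment
```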
7.2 Visual Defect Detection
- Old classes: scratch, stain.
- New defect: seal-ring missing.
- Approach: treat the detector’s softmax as a policy; sample predictions, reward = IoU > 0.7 (sketched below).
- Outcome: new defect recall 94 %, old defects 96 % (SFT 87 %).
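The binary IoU-threshold reward mentioned above fits in a few lines; the (x1, y1, x2, y2) box format and the 0.7 threshold are assumptions.

```python
# Sketch of a binary detection reward: 1.0 if IoU with the ground-truth box exceeds 0.7.
# Boxes are assumed to be (x1, y1, x2, y2) in pixel coordinates.
def iou_reward(pred_box, gt_box, threshold: float = 0.7) -> float:
    x1 = max(pred_box[0], gt_box[0])
    y1 = max(pred_box[1], gt_box[1])
    x2 = min(pred_box[2], gt_box[2])
    y2 = min(pred_box[3], gt_box[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_pred = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
    area_gt = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
    union = area_pred + area_gt - inter
    iou = inter / union if union > 0 else 0.0
    return 1.0 if iou > threshold else 0.0
```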
8 What Still Needs Proving
- Scale: the law is verified up to 14 B; behaviour at 100 B+ remains unseen.
- Off-policy RL: not studied here; could violate the Razor if importance weights explode.
- Mechanistic explanation: why exactly does larger KL on task A disrupt task B? Representation interference vs. capacity vs. gradient conflict—still open.
9 Action Checklist / Implementation Steps
- Benchmark your base model on both the new and the prior tasks.
- Run a mini-sweep (3–5 learning rates) with on-policy RL; log KL(π₀‖π) on new-task data every step (a minimal monitoring sketch follows this list).
- If KL > 0.05 and prior metrics dip by more than 2 %:
  - halve the learning rate, or
  - add a KL-penalty coefficient of 0.1–0.3, or
  - switch to a smaller mini-batch to reduce the per-step change.
- Keep an early-stop buffer: as soon as KL reaches the budget that history says is safe, stop or switch to regularised SFT for the final stretch.
- Document the KL–accuracy Pareto front for each release; it becomes your forgetting “allowance” in future continual updates.
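A minimal sketch of the monitoring rule from the checklist, assuming you already compute the empirical KL and the prior-metric dip at each evaluation step; the 0.05 and 2 % budgets mirror the numbers above.

```python
# Sketch of the checklist's monitoring rule. `kl` is the empirical E_x[KL(pi_0 || pi)]
# on new-task prompts; `prior_drop` is the percentage-point dip on prior benchmarks.
def forgetting_guard(kl: float, prior_drop: float,
                     kl_budget: float = 0.05, drop_budget: float = 2.0) -> str:
    """Return the action suggested by the checklist for the current training step."""
    if kl <= kl_budget and prior_drop <= drop_budget:
        return "continue"
    if kl > kl_budget and prior_drop > drop_budget:
        return "halve learning rate / add KL penalty (0.1-0.3) / shrink mini-batch"
    return "watch closely: one budget already exceeded"

# Example: print(forgetting_guard(kl=0.07, prior_drop=2.5))
```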
10 One-Page Overview
- Forgetting ∝ forward KL(π₀‖π) evaluated on the new task.
- On-policy RL implicitly minimises that KL → keeps old skills.
- SFT follows arbitrary labels; it can land far from π₀ → catastrophic forgetting.
- Monitor KL; keep it <0.05 for near-zero loss on prior benchmarks.
- Theory: RL = alternating I-projection / M-projection = EM toward the nearest optimal policy.
11 FAQ
Q1 Do I need a complicated reward function?
A Binary “correct/incorrect” is enough; advantage baselines speed convergence but don’t change the minimal-KL limit.
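For concreteness, a binary reward can be as small as an exact-match check; the “last line of the completion is the answer” convention below is just one possible format.

```python
# Sketch of a minimal binary reward: exact match between the extracted answer and the reference.
# The "last line is the answer" convention is an assumption; adapt it to your output format.
def binary_reward(completion: str, reference: str) -> float:
    lines = completion.strip().splitlines()
    answer = lines[-1].strip() if lines else ""
    return 1.0 if answer == reference.strip() else 0.0
```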
Q2 Can I just add a KL-penalty to SFT?
A Helps, but you must tune the coefficient and it still allows big jumps; on-policy RL provides a softer, adaptive constraint.
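If you do experiment with a KL-penalised SFT objective, a per-batch sketch could look like the following; the coefficient `beta` and the full-vocabulary forward KL are assumptions, not a recipe from the paper.

```python
# Sketch: cross-entropy SFT loss plus a forward KL(pi_0 || pi) penalty on the same batch.
# `model` is the policy being tuned; `ref_model` is a frozen copy of the base model.
import torch
import torch.nn.functional as F

def sft_with_kl_penalty(model, ref_model, input_ids, labels, beta: float = 0.1):
    logits = model(input_ids).logits                 # (batch, seq, vocab)
    with torch.no_grad():
        ref_logits = ref_model(input_ids).logits
    ce = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)
    # KL(pi_0 || pi) per token: sum_v pi_0(v) * (log pi_0(v) - log pi(v)), then averaged
    log_p = F.log_softmax(logits, dim=-1)
    ref_log_p = F.log_softmax(ref_logits, dim=-1)
    kl = torch.sum(ref_log_p.exp() * (ref_log_p - log_p), dim=-1).mean()
    return ce + beta * kl
```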
Q3 Is the law valid for very small models?
A Yes; paper demonstrates it on a 3-layer MLP and on 0.3-B, 3-B, 7-B transformers.
Q4 What if my RL data is offline?
A Strictly off-policy algorithms were not tested; importance sampling may drift and break the Razor—proceed with caution.
Q5 How do I compute KL cheaply at scale?
A Sample 2–4 k new-task prompts, gather log-probs from both π₀ and π, use the empirical mean; takes seconds on a single GPU.
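A sketch of that estimate, assuming both policies are Hugging Face causal LMs and that the expectation is approximated with completions sampled from the frozen base model; prompt counts and lengths are placeholders.

```python
# Sketch: Monte-Carlo estimate of the forward KL(pi_0 || pi) on new-task prompts.
# Samples completions from the frozen base model pi_0, then scores them under both models.
import torch

def sequence_log_prob(model, input_ids, prompt_len):
    """Sum of token log-probs of the completion part of `input_ids` under `model`."""
    with torch.no_grad():
        logits = model(input_ids).logits[:, :-1, :]          # position t predicts token t+1
    log_probs = torch.log_softmax(logits, dim=-1)
    targets = input_ids[:, 1:]
    token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_lp[:, prompt_len - 1:].sum(dim=-1)          # completion tokens only

def estimate_forward_kl(base_model, tuned_model, tok, prompts, max_new_tokens=64):
    kls = []
    for prompt in prompts:                                   # e.g. 2-4 k new-task prompts
        ids = tok(prompt, return_tensors="pt").input_ids
        with torch.no_grad():
            full = base_model.generate(ids, do_sample=True, max_new_tokens=max_new_tokens)
        lp_base = sequence_log_prob(base_model, full, ids.shape[1])
        lp_tuned = sequence_log_prob(tuned_model, full, ids.shape[1])
        kls.append((lp_base - lp_tuned).item())              # log pi_0(y|x) - log pi(y|x)
    return sum(kls) / len(kls)
```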
Q6 Does larger model size solve forgetting by itself?
A Bigger models start higher but exhibit the same slope in the accuracy–forgetting trade-off; scale alone is not a cure.
Q7 Which hyper-parameter most affects KL growth?
A Learning rate dominates; halving LR often halves final KL at equal number of steps.
Q8 Is the razor unique to GRPO?
A No; vanilla REINFORCE, PPO, A2C all share the on-policy sampling core and show the same low-KL behaviour.