让大模型“边学边改错”:On-Policy Distillation 原理与实战全解
“
核心问题:如何在只利用学生模型自己生成的文本、不依赖人工标注或昂贵 RL 的前提下,把大模型在数学、私域知识、对话格式等任务上的能力稳定提升到教师水平?
本文用一份完整实验报告告诉你:On-Policy Distillation 把「学生自己采样的轨迹」与「教师对每个 Token 的实时点评」拼在一起,兼顾「在线纠错」和「密集奖励」,用 1/10 算力就能得到 RL 级别的效果,还能持续更新模型而不遗忘。
速览摘要
| 关键词 | 一句话释义 |
|---|---|
| On-Policy Distillation | 学生采样 → 教师逐 Token 给分 → 反向 KL 损失更新,全程在线。 |
| 反向 KL | 让学生概率逼近教师概率,Mode-seeking,避免“平均化”带来的模糊输出。 |
| 密集奖励 | 每个 Token 都有梯度信号,比 RL「整题 1 比特」效率高 50-100×。 |
| 持续学习 | 学完私域知识后,用旧版本自己当教师蒸馏一次,即可把遗忘的对话能力“一键找回”。 |
目录
-
为什么后训练阶段需要第三条路 -
算法拆解:四行伪代码就能复现 -
数学推理实战:从 60% → 70% 的三种成本对比 -
私域助手案例:新知识进、旧能力退,再用蒸馏“补锅” -
作者反思:我们踩过的 4 个坑 -
结论与落地清单 -
一页速览(One-page Summary) -
FAQ
1. 为什么后训练阶段需要第三条路
核心问题:已有监督微调(SFT)和强化学习(RL),为什么还要折腾“在线蒸馏”?
| 方案 | 数据来自 | 奖励密度 | 主要痛点 |
|---|---|---|---|
| 监督微调(Off-policy) | 教师轨迹 | 高(逐 Token) | 学生状态一旦偏离教师,误差会累积,长链推理崩掉。 |
| 强化学习(On-policy) | 学生轨迹 | 极低(整题 1 比特) | 采样贵、信用分配难,10 万题可能只学到“答案对/错”。 |
| On-Policy Distillation | 学生轨迹 | 高(教师逐 Token 给分) | 需要教师实时算 logprob,但 GPU 并行后成本≈1/10 RL。 |
一句话:在线蒸馏同时解决“数据失配”和“奖励稀疏”两大顽疾。
2. 算法拆解:四行伪代码就能复现
核心问题:系统落地到底要改几行代码?
2.1 损失函数
反向 KL 逐 Token 计算:
reverse_kl[t] = logP_student(x_t | x_<t) - logP_teacher(x_t | x_<t)
loss = mean(reverse_kl)
-
不需要等整题生成完,中途截断也能训练,省显存。 -
教师只前向一次,学生负责采样,计算量可并行。
2.2 四步流程
-
采样:学生模型生成 4 条轨迹 -
打分:教师模型对“学生已生成的每个 Token”算 logprob -
求差:得到 reverse_kl 作为优势值 -
更新:用重要性采样策略梯度做一次参数步
伪代码(已跑通 Tinker 框架):
teacher_logprobs = teacher.compute_logprobs(trajectories)
reverse_kl = student_logprobs - teacher_logprobs
trajectories["advantages"] = -reverse_kl
training_client.forward_backward(trajectories, loss_fn="importance_sampling")
反思 / 踩坑:
“
我们曾尝试把折扣因子 γ 调到 0.9,想让模型“看远点”,结结果 AIME 分数反而掉 2%。事后看,数学题每步对错即时可见,未来奖励几乎无信息量,γ=0 最干净。
3. 数学推理实战:从 60% → 70% 的三种成本对比
核心问题:把 8B 模型 AIME’24 成绩从 60 提到 70,到底要花多少卡时?
3.1 实验设定
-
学生:Qwen3-8B-Base -
教师:Qwen3-32B(已能在 AIME 拿到 85+) -
起点:40 万题 SFT → 60% 正确率
| 路线 | 预估样本量 / 步数 | 总算力(FLOP) | 最终 AIME | 相对成本 |
|---|---|---|---|---|
| 继续 SFT(Off-policy) | ~200 万题 | 4.9×10²¹ | 70% | 1× |
| 纯 RL(PPO) | 17 920 GPU h | 与左列近似 | 68% | ≈1× |
| On-Policy Distillation | 150 步×77 k 题 | 1.66×10²⁰ | 70% | 0.11× |
“
注:FLOP 只计训练阶段,不含前期数据生成;若把教师采样成本也算上,蒸馏节省可达 30×。
3.2 学习曲线
-
SFT 呈对数线性,前期便宜后期贵。 -
蒸馏在 150 步内即饱和,反向 KL 掉到 0.002 附近,步数少 7-10×。
图片来源:Unsplash
4. 私域助手案例:新知识进、旧能力退,再用蒸馏“补锅”
核心问题:先灌公司内网文档,再想把对话格式捡回来,只能重跑昂贵 RL 吗?
4.1 任务定义
-
Knowledge:内部 QA,43 题,考察文档事实回忆。 -
Behavior:IF-eval,541 条指令格式题,考察“是否按指定格式回答”。
4.2 Mid-Training 结果
把 Qwen3-8B 在“70% 文档 + 30% 闲聊”上继续 SFT 后:
-
Knowledge ↑ 18% → 36% -
IF-eval ↓ 85% → 79%(继续训还会掉)
4.3 蒸馏救场
用训前原版 Qwen3-8B 当教师,对自己“走样”后的新版做 1 轮 On-Policy Distillation(仅闲聊数据):
-
IF-eval 回到 83%,几乎无损。 -
Knowledge 仍保持 41%,反而略升(正迁移)。
| 阶段 | Internal QA | IF-eval | 备注 |
|---|---|---|---|
| 原始 Qwen3-8B | 18% | 85% | 起点高,但不懂私域。 |
| +SFT(70% 文档) | 36% | 79% | 知识涨,格式丢。 |
| 再 +Distill | 41% | 83% | 格式找回,知识还在。 |
反思 / 独特见解:
“
把“自己过去的优秀版本”当教师,相当于给模型做“版本热修复”。这招对持续部署太友好:先快速用 SFT 灌新数据,再用蒸馏找回行为,无需重写奖励模型,也无需人类再标一轮。
5. 作者反思:我们踩过的 4 个坑
-
LoRA rank 并非越小越好
数学任务用 rank=32 蒸馏后只比全参微调低 6%,可接受;rank=8 会掉 15%,别省过头。 -
教师模型也可“离线”缓存 logprob
一开始我们实时调教师 API,QPS 成瓶颈。后来把轨迹+教师 logprob 先落盘,再开多卡训练,整体提速 3×。 -
单题多 epoch 训不会过拟合
反向 KL 目标是分布对齐,不是背答案。我们用 1 条极限题连训 20 epoch,AIME 分数仍涨到教师水平,说明“分布级”损失天然抗记忆。 -
温度 0 采样反而梯度噪声大
温度=0 导致同 prompt 每次轨迹完全一致,优势值无方差,更新方向僵硬。把温度提到 0.3-0.5 后,收敛更平滑。
6. 结论与落地清单
-
On-Policy Distillation =「学生轨迹」+「教师逐 Token 点评」+「反向 KL」,四行代码即可在任意 RL 框架里替换奖励模型。 -
数学推理:150 步、77 k 题、1 800 GPUh 就能完成 60→70% 的跃升,成本≈SFT 的 1/9,RL 的 1/10。 -
私域助手:先 SFT 灌知识 → 用旧自己当教师蒸馏找回格式,两轮搞定持续学习,IF-eval 遗忘<3%。 -
梯度效率:逐 Token 监督带来 50-100× 信息密度,单题 20 epoch 也不爆炸。
实用摘要 / 操作清单
-
准备阶段 -
选学生 SFT checkpoint(至少能生成合法格式)。 -
备教师(更大或同尺寸旧版本),确保 compute_logprobs接口可用。
-
-
采样脚本 -
temperature=0.3-0.5,每 prompt 2-8 条轨迹,保存 token-level logprob。
-
-
训练脚本 -
用重要性采样,优势值=-reverse_kl,学习率 5e-7 起步,LoRA rank≥32。
-
-
监控指标 -
反向 KL ↓ 到 0.002 附近基本饱和;下游任务早停。
-
-
持续学习 -
SFT 新数据后 IF-eval 掉点?直接拿旧模型自蒸馏 1 轮,80% 场景可恢复。
-
7. 一页速览(One-page Summary)
| 关键词 | 值 |
|---|---|
| 方法本质 | 学生采样,教师给每 token 概率,反向 KL 更新 |
| 最大卖点 | 兼具 On-policy 纠错 + Dense reward,成本 1/10 RL |
| 适用场景 | 数学推理、私域知识、格式/风格/工具使用等后训练 |
| 不适用 | 预训练(参数空间探索太大),需教师略优于学生 |
| 代码改动 | 4 行伪代码,可在任何 RL 框架把奖励模型换成教师 logprob |
| 硬件 Tips | 教师 logprob 可并行,多卡 prefetch,磁盘缓存提速 3× |
| 持续学习 | SFT→掉行为→自蒸馏 1 轮,IF-eval 可回弹至 98% 原性能 |
8. FAQ
-
反向 KL 与正向 KL 有何区别?
反向 KL 让学生逼近教师的最优模式,避免“平均化”;正向 KL 会强迫学生覆盖教师所有小概率区域,易模糊。 -
教师必须远大于学生吗?
不必须。文中私域案例用同尺寸旧版即可,只要教师行为比“走样”后的学生更好。 -
可以脱离 RL 框架纯用 SFT 实现吗?
可以,把 reverse_kl 当损失直接 backward 就行;但 RL 框架已集成重要性采样、日志、早停,更省事。 -
为什么不用 BLEU/ROUGE 当损失?
这些指标只匹配表层 n-gram,无法对齐教师分布,长链推理仍崩。 -
训练时温度要固定吗?
建议固定 0.3-0.5,让轨迹有轻微差异,梯度方差更稳定。 -
会过拟合教师缺点吗?
若教师有系统性偏见,学生也会学。解决方法是多教师集成或后期 RL 微调。 -
能否把环境奖励和 KL 混合?
理论上可以,用total_advantage = env_reward - β·reverse_kl;留作未来工作。 -
单卡 24G 能跑吗?
8B 学生 + 32B 教师同时加载需 60G+,建议教师端用 CPU offload 或离线落盘 logprob 后再单卡训练。
