让大模型“边学边改错”:On-Policy Distillation 原理与实战全解

核心问题:如何在只利用学生模型自己生成的文本、不依赖人工标注或昂贵 RL 的前提下,把大模型在数学、私域知识、对话格式等任务上的能力稳定提升到教师水平?

本文用一份完整实验报告告诉你:On-Policy Distillation 把「学生自己采样的轨迹」与「教师对每个 Token 的实时点评」拼在一起,兼顾「在线纠错」和「密集奖励」,用 1/10 算力就能得到 RL 级别的效果,还能持续更新模型而不遗忘。


速览摘要

关键词 一句话释义
On-Policy Distillation 学生采样 → 教师逐 Token 给分 → 反向 KL 损失更新,全程在线。
反向 KL 让学生概率逼近教师概率,Mode-seeking,避免“平均化”带来的模糊输出。
密集奖励 每个 Token 都有梯度信号,比 RL「整题 1 比特」效率高 50-100×。
持续学习 学完私域知识后,用旧版本自己当教师蒸馏一次,即可把遗忘的对话能力“一键找回”。

目录

  1. 为什么后训练阶段需要第三条路
  2. 算法拆解:四行伪代码就能复现
  3. 数学推理实战:从 60% → 70% 的三种成本对比
  4. 私域助手案例:新知识进、旧能力退,再用蒸馏“补锅”
  5. 作者反思:我们踩过的 4 个坑
  6. 结论与落地清单
  7. 一页速览(One-page Summary)
  8. FAQ

1. 为什么后训练阶段需要第三条路

核心问题:已有监督微调(SFT)和强化学习(RL),为什么还要折腾“在线蒸馏”?

方案 数据来自 奖励密度 主要痛点
监督微调(Off-policy) 教师轨迹 高(逐 Token) 学生状态一旦偏离教师,误差会累积,长链推理崩掉。
强化学习(On-policy) 学生轨迹 极低(整题 1 比特) 采样贵、信用分配难,10 万题可能只学到“答案对/错”。
On-Policy Distillation 学生轨迹 高(教师逐 Token 给分) 需要教师实时算 logprob,但 GPU 并行后成本≈1/10 RL。

一句话:在线蒸馏同时解决“数据失配”和“奖励稀疏”两大顽疾。


2. 算法拆解:四行伪代码就能复现

核心问题:系统落地到底要改几行代码?

2.1 损失函数

反向 KL 逐 Token 计算:

reverse_kl[t] = logP_student(x_t | x_<t) - logP_teacher(x_t | x_<t)
loss = mean(reverse_kl)
  • 不需要等整题生成完,中途截断也能训练,省显存。
  • 教师只前向一次,学生负责采样,计算量可并行。

2.2 四步流程

  1. 采样:学生模型生成 4 条轨迹
  2. 打分:教师模型对“学生已生成的每个 Token”算 logprob
  3. 求差:得到 reverse_kl 作为优势值
  4. 更新:用重要性采样策略梯度做一次参数步

伪代码(已跑通 Tinker 框架):

teacher_logprobs = teacher.compute_logprobs(trajectories)
reverse_kl = student_logprobs - teacher_logprobs
trajectories["advantages"] = -reverse_kl
training_client.forward_backward(trajectories, loss_fn="importance_sampling")

反思 / 踩坑

我们曾尝试把折扣因子 γ 调到 0.9,想让模型“看远点”,结结果 AIME 分数反而掉 2%。事后看,数学题每步对错即时可见,未来奖励几乎无信息量,γ=0 最干净。


3. 数学推理实战:从 60% → 70% 的三种成本对比

核心问题:把 8B 模型 AIME’24 成绩从 60 提到 70,到底要花多少卡时?

3.1 实验设定

  • 学生:Qwen3-8B-Base
  • 教师:Qwen3-32B(已能在 AIME 拿到 85+)
  • 起点:40 万题 SFT → 60% 正确率
路线 预估样本量 / 步数 总算力(FLOP) 最终 AIME 相对成本
继续 SFT(Off-policy) ~200 万题 4.9×10²¹ 70%
纯 RL(PPO) 17 920 GPU h 与左列近似 68% ≈1×
On-Policy Distillation 150 步×77 k 题 1.66×10²⁰ 70% 0.11×

注:FLOP 只计训练阶段,不含前期数据生成;若把教师采样成本也算上,蒸馏节省可达 30×。

3.2 学习曲线

  • SFT 呈对数线性,前期便宜后期贵。
  • 蒸馏在 150 步内即饱和,反向 KL 掉到 0.002 附近,步数少 7-10×。

学习曲线对比
图片来源:Unsplash


4. 私域助手案例:新知识进、旧能力退,再用蒸馏“补锅”

核心问题:先灌公司内网文档,再想把对话格式捡回来,只能重跑昂贵 RL 吗?

4.1 任务定义

  • Knowledge:内部 QA,43 题,考察文档事实回忆。
  • Behavior:IF-eval,541 条指令格式题,考察“是否按指定格式回答”。

4.2 Mid-Training 结果

把 Qwen3-8B 在“70% 文档 + 30% 闲聊”上继续 SFT 后:

  • Knowledge ↑ 18% → 36%
  • IF-eval ↓ 85% → 79%(继续训还会掉)

4.3 蒸馏救场

训前原版 Qwen3-8B 当教师,对自己“走样”后的新版做 1 轮 On-Policy Distillation(仅闲聊数据):

  • IF-eval 回到 83%,几乎无损。
  • Knowledge 仍保持 41%,反而略升(正迁移)。
阶段 Internal QA IF-eval 备注
原始 Qwen3-8B 18% 85% 起点高,但不懂私域。
+SFT(70% 文档) 36% 79% 知识涨,格式丢。
再 +Distill 41% 83% 格式找回,知识还在。

反思 / 独特见解

把“自己过去的优秀版本”当教师,相当于给模型做“版本热修复”。这招对持续部署太友好:先快速用 SFT 灌新数据,再用蒸馏找回行为,无需重写奖励模型,也无需人类再标一轮。


5. 作者反思:我们踩过的 4 个坑

  1. LoRA rank 并非越小越好
    数学任务用 rank=32 蒸馏后只比全参微调低 6%,可接受;rank=8 会掉 15%,别省过头。

  2. 教师模型也可“离线”缓存 logprob
    一开始我们实时调教师 API,QPS 成瓶颈。后来把轨迹+教师 logprob 先落盘,再开多卡训练,整体提速 3×。

  3. 单题多 epoch 训不会过拟合
    反向 KL 目标是分布对齐,不是背答案。我们用 1 条极限题连训 20 epoch,AIME 分数仍涨到教师水平,说明“分布级”损失天然抗记忆。

  4. 温度 0 采样反而梯度噪声大
    温度=0 导致同 prompt 每次轨迹完全一致,优势值无方差,更新方向僵硬。把温度提到 0.3-0.5 后,收敛更平滑。


6. 结论与落地清单

  • On-Policy Distillation =「学生轨迹」+「教师逐 Token 点评」+「反向 KL」,四行代码即可在任意 RL 框架里替换奖励模型。
  • 数学推理:150 步、77 k 题、1 800 GPUh 就能完成 60→70% 的跃升,成本≈SFT 的 1/9,RL 的 1/10。
  • 私域助手:先 SFT 灌知识 → 用旧自己当教师蒸馏找回格式,两轮搞定持续学习,IF-eval 遗忘<3%。
  • 梯度效率:逐 Token 监督带来 50-100× 信息密度,单题 20 epoch 也不爆炸。

实用摘要 / 操作清单

  1. 准备阶段

    • 选学生 SFT checkpoint(至少能生成合法格式)。
    • 备教师(更大或同尺寸旧版本),确保 compute_logprobs 接口可用。
  2. 采样脚本

    • temperature=0.3-0.5,每 prompt 2-8 条轨迹,保存 token-level logprob。
  3. 训练脚本

    • 用重要性采样,优势值=-reverse_kl,学习率 5e-7 起步,LoRA rank≥32。
  4. 监控指标

    • 反向 KL ↓ 到 0.002 附近基本饱和;下游任务早停。
  5. 持续学习

    • SFT 新数据后 IF-eval 掉点?直接拿旧模型自蒸馏 1 轮,80% 场景可恢复。

7. 一页速览(One-page Summary)

关键词
方法本质 学生采样,教师给每 token 概率,反向 KL 更新
最大卖点 兼具 On-policy 纠错 + Dense reward,成本 1/10 RL
适用场景 数学推理、私域知识、格式/风格/工具使用等后训练
不适用 预训练(参数空间探索太大),需教师略优于学生
代码改动 4 行伪代码,可在任何 RL 框架把奖励模型换成教师 logprob
硬件 Tips 教师 logprob 可并行,多卡 prefetch,磁盘缓存提速 3×
持续学习 SFT→掉行为→自蒸馏 1 轮,IF-eval 可回弹至 98% 原性能

8. FAQ

  1. 反向 KL 与正向 KL 有何区别?
    反向 KL 让学生逼近教师的最优模式,避免“平均化”;正向 KL 会强迫学生覆盖教师所有小概率区域,易模糊。

  2. 教师必须远大于学生吗?
    不必须。文中私域案例用同尺寸旧版即可,只要教师行为比“走样”后的学生更好。

  3. 可以脱离 RL 框架纯用 SFT 实现吗?
    可以,把 reverse_kl 当损失直接 backward 就行;但 RL 框架已集成重要性采样、日志、早停,更省事。

  4. 为什么不用 BLEU/ROUGE 当损失?
    这些指标只匹配表层 n-gram,无法对齐教师分布,长链推理仍崩。

  5. 训练时温度要固定吗?
    建议固定 0.3-0.5,让轨迹有轻微差异,梯度方差更稳定。

  6. 会过拟合教师缺点吗?
    若教师有系统性偏见,学生也会学。解决方法是多教师集成或后期 RL 微调。

  7. 能否把环境奖励和 KL 混合?
    理论上可以,用 total_advantage = env_reward - β·reverse_kl;留作未来工作。

  8. 单卡 24G 能跑吗?
    8B 学生 + 32B 教师同时加载需 60G+,建议教师端用 CPU offload 或离线落盘 logprob 后再单卡训练。