
On-Policy Distillation: The Cheap Way to Supercharge Small Language Models

Teaching Models to Correct Themselves: A Complete Guide to On-Policy Distillation

What is the cheapest way to make a small language model as good as a big one at narrow tasks?
Let the small model generate its own answers, then let the big model grade every single token in real time. On-policy distillation does exactly this—online, dense, and 5-30× cheaper than RL.


Table of Contents

  1. Why Post-Training Needs a Third Way
  2. Algorithm in One Breath
  3. Math Reasoning: 60 % → 70 % with 1/10 the GPU Hours
  4. Company Assistant: Add Private Knowledge, Then Get Chat Skills Back for Free
  5. Author’s Reflection: Four Traps We Fell Into
  6. Action Checklist & One-Page Overview
  7. FAQ

1. Why Post-Training Needs a Third Way

Core question: If supervised fine-tuning (SFT) and reinforcement learning (RL) already exist, why invent on-policy distillation?

| Method | Data Source | Reward Density | Key Pain Point |
| --- | --- | --- | --- |
| SFT (off-policy) | teacher trajectories | high, per-token | compounding error once the student leaves the teacher’s manifold |
| RL (on-policy) | student roll-outs | 1 bit per answer | sample-inefficient, credit-assignment nightmare |
| On-policy distillation | student roll-outs | high, per-token | needs teacher log-probs, but those are GPU-parallelisable |

SFT teaches the student to copy the teacher, but only inside states the teacher has seen.
RL lets the student learn from its own mistakes, but only tells it “right/wrong” after the whole episode.
On-policy distillation samples trajectories from the student and immediately supplies a dense teacher signal for every token—combining the best of both worlds.

Application vignette:
Imagine teaching a 4-billion-parameter model to solve AMC math problems. With RL, you generate 256 roll-outs, check the final answer, and back-propagate a single reward. The model knows “21” is wrong but has no idea which of the 200 tokens was the first mistake. With on-policy distillation, the 32B teacher gives a negative score to the exact token where the student first mis-applied the quadratic formula—like a tutor circling the wrong line in red pen.


2. Algorithm in One Breath

Core question: What exactly do I need to change in my training script?

2.1 Loss = reverse KL per token

reverse_kl[t] = log π_θ(x_t | x_<t) − log π_teacher(x_t | x_<t),  with x_t sampled from the student π_θ
loss = mean(reverse_kl)
  • A discount factor of zero (each token is graded only on its own reverse KL, with no look-ahead) worked best in our runs.
  • The teacher only needs a single forward pass over the student’s tokens; all sampling is done by the student.
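
A minimal PyTorch sketch of this loss, assuming you already have both models’ logits over the student’s sampled tokens; the function name and tensor shapes are illustrative, not tied to any particular framework:

import torch
import torch.nn.functional as F

def per_token_reverse_kl(student_logits, teacher_logits, sampled_ids):
    # student_logits, teacher_logits: [batch, seq, vocab], computed over the same student rollout
    # sampled_ids: [batch, seq] token ids actually sampled from the student
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    idx = sampled_ids.unsqueeze(-1)
    s = student_logp.gather(-1, idx).squeeze(-1)   # log pi_student(x_t | x_<t)
    t = teacher_logp.gather(-1, idx).squeeze(-1)   # log pi_teacher(x_t | x_<t)
    return s - t                                   # single-sample estimate of reverse KL per token

# loss = per_token_reverse_kl(student_logits, teacher_logits, sampled_ids).mean()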

2.2 Four lines of pseudocode (Tinker API shown)

# 1. Teacher scores the student's own rollout (forward pass only, no gradients).
teacher_logp = teacher.compute_logprobs(student_rollout)
# 2. The student's per-token log-probs were already recorded during sampling.
student_logp = student_rollout.logprobs
# 3. Dense per-token advantage = negative reverse KL.
advantage = -(student_logp - teacher_logp)
# 4. Reuse the RL trainer's importance-sampling update with this advantage.
trainer.update(advantage, loss_fn="importance_sampling")

Application vignette:
We started from an RL script that already computed student log-probs for a KL penalty against a frozen reference model. We swapped the reference-model path for the teacher checkpoint and replaced the KL penalty with the reverse KL against the teacher. Every other hyper-parameter (batch 64, 4 roll-outs per prompt, LoRA rank 128) stayed identical, and the training loss approached zero within 150 steps.
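
If you would rather skip the RL machinery entirely (see FAQ 3), one distillation step can also be written directly in PyTorch. This is a rough sketch under the assumption that both models are Hugging Face-style causal LMs returning .logits; every function and variable name here is ours, not a published API:

import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_from_student(student, prompt_ids, max_new_tokens=128, temperature=0.3):
    # Roll out a completion token by token; no gradients are needed while sampling.
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits = student(ids).logits[:, -1, :] / temperature
        next_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        ids = torch.cat([ids, next_id], dim=-1)
    return ids

def distillation_step(student, teacher, prompt_ids, optimizer):
    # One on-policy distillation update: the student samples, the teacher grades every token.
    ids = sample_from_student(student, prompt_ids)
    student_logp = F.log_softmax(student(ids).logits[:, :-1], dim=-1)
    with torch.no_grad():
        teacher_logp = F.log_softmax(teacher(ids).logits[:, :-1], dim=-1)
    tgt = ids[:, 1:].unsqueeze(-1)
    rev_kl = (student_logp.gather(-1, tgt) - teacher_logp.gather(-1, tgt)).squeeze(-1)
    # Penalise only the tokens the student actually generated, not the prompt.
    loss = rev_kl[:, prompt_ids.shape[1] - 1:].mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Back-propagating through the mean reverse KL like this is the simplest variant; routing the same quantity through an RL trainer as an advantage (as in the pseudocode above) gives you importance sampling, logging, and checkpointing for free.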


3. Math Reasoning: 60 % → 70 % with 1/10 the GPU Hours

Core question: How much compute does it really take to push an 8B model from 60 % to 70 % on AIME’24?

3.1 Setup

  • Student: Qwen3-8B-Base
  • Teacher: Qwen3-32B
  • Starting checkpoint: 400 k prompts SFT → 60 % AIME
| Route | Samples / Compute | Teacher FLOPs | Student FLOPs | AIME’24 | Relative Cost |
| --- | --- | --- | --- | --- | --- |
| More SFT (2 M prompts) | ~2 M samples | 3.4×10²¹ | 1.5×10²¹ | 70 % | n/a |
| RL (Qwen3 report) | 17 920 GPU h | n/a | n/a | 68 % | ~1× |
| On-policy distillation | 77 k prompts × 4 roll-outs | 8.4×10¹⁹ | 8.2×10¹⁹ | 70 % | 0.11× |

3.2 Learning curve snapshot

  • SFT follows log-linear scaling: cheap at first, brutally expensive later.
  • Distillation saturates in <150 steps; reverse KL drops below 0.002.

Application vignette:
Our cluster queue was full, so we ran the teacher on 8×A100-40G and the student update on 4×A100-80G. Because teacher log-probs are embarrassingly parallel, we finished the 150-step experiment overnight instead of the usual week-long RL window.


4. Company Assistant: Add Private Knowledge, Then Get Chat Skills Back for Free

Core question: After I fine-tune on internal documents, IF-eval drops. Must I rerun costly RL?

4.1 Task definition

  • Knowledge eval: 43 internal QA pairs (completely unseen during pre-training).
  • Behaviour eval: IF-eval (541 format instructions).

4.2 Mid-training outcome (70 % docs + 30 % chat SFT)

  • Knowledge ↑ 18 % → 36 %
  • IF-eval ↓ 85 % → 79 % and still falling

4.3 Self-distillation rescue

Use the original weights as teacher, run on-policy distillation on open-domain chat prompts for one more round:

  • IF-eval recovers to 83 % (–2 % vs original).
  • Knowledge does not regress; it edges up from 36 % to 41 %.
| Stage | Internal QA | IF-eval | Notes |
| --- | --- | --- | --- |
| Base Qwen3-8B | 18 % | 85 % | zero private knowledge |
| + SFT (70 % docs) | 36 % | 79 % | knowledge up, format down |
| + distill (self) | 41 % | 83 % | format back, knowledge kept |
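
For concreteness, the whole rescue is just the Section 2 loop with a different teacher and different prompts. A minimal sketch, assuming Hugging Face-style checkpoints and the hypothetical distillation_step from Section 2 (the paths and the prompt loader are placeholders):

import torch
from transformers import AutoModelForCausalLM

# The checkpoint saved BEFORE the document SFT acts as teacher of its own former behaviour.
teacher = AutoModelForCausalLM.from_pretrained("path/to/original-chat-checkpoint").eval()
student = AutoModelForCausalLM.from_pretrained("path/to/after-doc-sft-checkpoint")
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-7)

# Prompts come from generic open-domain chat data, not from the internal documents.
for prompt_ids in open_domain_chat_prompt_batches:
    distillation_step(student, teacher, prompt_ids, optimizer)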

Application vignette:
Legal team pushed 3 k new clauses on a Friday. We injected them via SFT over the weekend, saw IF-eval slip Monday morning, and ran self-distillation during lunch break. The assistant passed both the knowledge regression test and the format regression test before the coffee cooled.


5. Author’s Reflection: Four Traps We Fell Into

  1. LoRA rank too small
    rank=8 loses 15 % AIME versus a full fine-tune; rank=32 loses only 6 % and still trains 3× faster.

  2. Teacher real-time bottleneck
    Scoring roll-outs with the teacher on the fly became the bottleneck, so we now cache teacher log-probs to disk and replay them across GPUs for a 3× end-to-end speed-up (see the sketch after this list).

  3. Over-fearing single-prompt over-fitting
    Distillation targets the teacher’s distribution, not exact answers. Training for 20 epochs on one prompt copied the teacher’s style without memorising a single solution.

  4. Zero-temperature sampling
    Greedy (temperature-0) sampling produces identical trajectories for every roll-out of a prompt, so the extra samples add no learning signal. Bumping the temperature to 0.3 stabilised convergence.
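
A rough sketch of the caching trick from trap 2, assuming a Hugging Face-style teacher that returns .logits; the file layout and names are ours. The cached scores are only reusable while the same roll-outs are replayed (for example, several optimiser passes over one sampling batch):

import torch
import torch.nn.functional as F

@torch.no_grad()
def score_and_cache(teacher, rollout_ids, path):
    # Score a batch of student roll-outs with the teacher once and save the result,
    # so other GPUs or later passes can reuse the log-probs without rerunning the teacher.
    logits = teacher(rollout_ids).logits[:, :-1]                       # predicts tokens 1..L-1
    logp = F.log_softmax(logits, dim=-1)
    token_logp = logp.gather(-1, rollout_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    torch.save({"ids": rollout_ids.cpu(), "teacher_logp": token_logp.cpu()}, path)
    return token_logp

# Replay side: cached = torch.load("shard_03.pt"); teacher_logp = cached["teacher_logp"]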


6. Action Checklist & One-Page Overview

Implementation Steps

  1. Pick a student checkpoint that can at least produce well-formed output (e.g., valid JSON or the required format).
  2. Ensure teacher exposes compute_logprobs(batch_of_tokens)—no backward pass needed.
  3. Sampling loop: temperature 0.3, 2–8 roll-outs per prompt, store token-level log-probs.
  4. Training loop: importance sampling, advantage = −(student_logp − teacher_logp), LR 5e-7, LoRA rank ≥ 32 (see the example settings after this list).
  5. Stop when reverse KL < 0.002 or downstream metric plateaus.
  6. For continual learning: alternate SFT (new knowledge) → self-distil (old behaviour).
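
The checklist’s numbers, collected into one illustrative configuration; this is a starting point under our assumptions, not a canonical recipe:

config = {
    "sampling": {"temperature": 0.3, "rollouts_per_prompt": 4},    # anywhere in 2-8 works
    "optimizer": {"lr": 5e-7},
    "lora": {"rank": 32},                                          # rank >= 32; Section 2 used 128
    "loss": "per_token_reverse_kl",                                # no discounting across tokens
    "stopping": {"reverse_kl_below": 0.002},                       # or when the downstream metric plateaus
}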

One-Page Overview

| Essence | Student rolls out → teacher scores each token → minimise reverse KL. |
| Benefit | On-policy error correction plus dense per-token rewards; 5-30× cheaper than RL. |
| Speed | 150 steps, 77 k prompts, ~1 800 GPU hours for +10 % AIME. |
| Continual Learning | Self-distillation restores IF-eval after SFT; no human labels, no RL. |
| Code Change | ~4 lines; swap the reward model for teacher log-probs inside any RL framework. |


7. FAQ

  1. Reverse vs forward KL—why prefer reverse?
    Reverse KL is mode-seeking: the student concentrates on the teacher’s dominant modes instead of spreading probability over every low-probability region the teacher covers, which is what forward KL encourages.

  2. Does the teacher have to be larger?
    No. The original weights can teach the SFT-damaged version even if size is identical.

  3. Can I discard the RL framework entirely?
    Yes—just backward through reverse_kl.mean()—but RL scripts already offer logging, checkpointing, and importance sampling for free.

  4. Will the student copy teacher biases?
    Yes, systematically wrong teacher tokens will transfer. Use ensemble teachers or later RL polishing if critical.

  5. Is on-policy distillation safe for multi-epoch training?
    Empirically yes. Because the objective is distribution-matching, not exact-answer cloning, we observed no memorisation even after 20 epochs on a single prompt.

  6. Minimum GPU memory for 8B student + 32B teacher?
    ~60 GB if loaded together. Off-load teacher to CPU or pre-compute log-probs to cut requirement to 24 GB.

  7. Any special tokenisers needed?
    Both models must share a vocabulary; if they don’t, remap the teacher’s probabilities into the student’s vocabulary with a simple alignment matrix before computing the KL (a rough sketch follows below).
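
A very rough sketch of such an alignment matrix, assuming Hugging Face tokenisers and matching tokens purely by their string form; tokens that exist in only one vocabulary are dropped and sequence-level tokenisation mismatches are ignored, so treat this as a crude approximation rather than a recipe:

import torch

def build_vocab_alignment(student_tok, teacher_tok):
    # Sparse [student_vocab, teacher_vocab] 0/1 matrix that matches tokens by their string form.
    s_vocab = student_tok.get_vocab()     # token string -> id
    t_vocab = teacher_tok.get_vocab()
    rows, cols = [], []
    for tok, s_id in s_vocab.items():
        t_id = t_vocab.get(tok)
        if t_id is not None:
            rows.append(s_id)
            cols.append(t_id)
    indices = torch.tensor([rows, cols])
    values = torch.ones(len(rows))
    return torch.sparse_coo_tensor(indices, values, (len(s_vocab), len(t_vocab)))

# Usage: map teacher probabilities into student space, then renormalise before the KL.
# M = build_vocab_alignment(student_tok, teacher_tok)
# probs_s = torch.sparse.mm(M, teacher_probs.T).T
# probs_s = probs_s / probs_s.sum(-1, keepdim=True)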
