Task-Specific Knowledge Distillation via Intermediate Probes
Abstract
Standard knowledge distillation from large language models (LLMs) assumes that the teacher's output distribution is a high-quality training signal. On reasoning tasks, this assumption is frequently violated: a model's intermediate representations may encode the correct answer, yet that information is lost or distorted by the projection onto the vocabulary, where prompt formatting and the choice of answer tokens create a brittle, noisy interface. We introduce ProbeKD, a distillation framework that bypasses this bottleneck by training lightweight probes on frozen teacher hidden states and using the probes' predictions, rather than the teacher's output logits, as supervision for the student. This simple change yields consistent improvements across four reasoning benchmarks (AQuA-RAT, ARC-Easy, ARC-Challenge, and MMLU), with the largest gains when training data is limited. The key mechanism is that probes trained on intermediate representations provide cleaner labels than the teacher's own outputs, effectively denoising the distillation signal. ProbeKD is architecture-agnostic, requires no changes to the student or teacher, and adds little compute, since probe training is cheap and teacher representations can be cached. By tapping into internal representations, ProbeKD lets practitioners extract more value from large teacher models without additional training data or architectural complexity.
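The two stages described above (fit a probe on frozen, cached teacher states, then supervise the student with the probe's predictions instead of the teacher's logits) can be illustrated with a minimal sketch. It rests on several assumptions that the abstract does not specify. The probe is taken to be a linear softmax (multinomial logistic-regression) classifier over answer options. The hidden states are synthetic stand-ins for cached teacher activations from some unspecified layer. The student is trained with cross-entropy against the probe's soft labels, although whether ProbeKD uses soft or hard probe labels is not stated here. All function names (`train_probe`, `probe_soft_labels`, `distill_loss`) are hypothetical.

```python
import math
import random

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def probe_logits(W, b, h):
    # Hypothetical linear probe: one score per answer option from a hidden-state vector h.
    return [sum(w * x for w, x in zip(W[k], h)) + b[k] for k in range(len(W))]

def train_probe(hidden, labels, n_classes, lr=0.5, epochs=200):
    """Fit a multinomial logistic-regression probe by full-batch gradient descent.

    hidden: cached teacher hidden states (the teacher itself is never updated).
    labels: gold answer indices for the probe's training examples.
    """
    d, n = len(hidden[0]), len(hidden)
    W = [[0.0] * d for _ in range(n_classes)]
    b = [0.0] * n_classes
    for _ in range(epochs):
        gW = [[0.0] * d for _ in range(n_classes)]
        gb = [0.0] * n_classes
        for h, y in zip(hidden, labels):
            p = softmax(probe_logits(W, b, h))
            for k in range(n_classes):
                err = p[k] - (1.0 if k == y else 0.0)  # d(cross-entropy)/d(logit_k)
                gb[k] += err
                for i in range(d):
                    gW[k][i] += err * h[i]
        for k in range(n_classes):
            b[k] -= lr * gb[k] / n
            for i in range(d):
                W[k][i] -= lr * gW[k][i] / n
    return W, b

def probe_soft_labels(W, b, hidden):
    # Probe predictions replace the teacher's output logits as the distillation target.
    return [softmax(probe_logits(W, b, h)) for h in hidden]

def distill_loss(student_logits, target_probs):
    """Cross-entropy H(target, student); equals KL(target || student) up to a constant."""
    q = softmax(student_logits)
    return -sum(p * math.log(qk + 1e-12) for p, qk in zip(target_probs, q))

# Synthetic stand-in for cached teacher hidden states: 4 answer options (A-D),
# with each option's examples clustered around a different direction in a 4-dim space.
rng = random.Random(0)
N_CLASSES, DIM = 4, 4
labels = [i % N_CLASSES for i in range(80)]
hidden = [[(2.0 if j == y else 0.0) + rng.gauss(0.0, 0.3) for j in range(DIM)] for y in labels]

W, b = train_probe(hidden, labels, N_CLASSES)
targets = probe_soft_labels(W, b, hidden)
accuracy = sum(
    max(range(N_CLASSES), key=lambda k: t[k]) == y for t, y in zip(targets, labels)
) / len(labels)
print(f"probe accuracy on cached states: {accuracy:.2f}")

# A student whose answer logits agree with the probe incurs a lower loss than one that disagrees.
agree = distill_loss([3.0, 0.0, 0.0, 0.0], targets[0])
disagree = distill_loss([0.0, 3.0, 0.0, 0.0], targets[0])
print(f"loss (agree) = {agree:.3f}, loss (disagree) = {disagree:.3f}")
```

Because the probe only reads cached activations, the teacher's forward pass runs once per example, which is consistent with the abstract's point that probe training adds little compute. In a real setup, the student's logits would come from its own forward pass, and the loss would be backpropagated through the student only.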