

Poster

Inference-Aware Fine-Tuning for Best-of-N Sampling in Large Language Models

Yinlam Chow · Guy Tennenholtz · Izzeddin Gur · Vincent Zhuang · Bo Dai · Aviral Kumar · Rishabh Agarwal · Sridhar Thiagarajan · Craig Boutilier · Aleksandra Faust

Hall 3 + Hall 2B #204
Sat 26 Apr, 12:00 a.m.–2:30 a.m. PDT

Abstract:

Recent studies indicate that effectively utilizing inference-time compute is crucial for attaining strong performance from large language models (LLMs). In particular, the Best-of-N (BoN) inference strategy, in which an LLM generates multiple candidate responses and a verifier selects the best one, has shown strong empirical performance. Motivated by this, we develop a novel inference-aware fine-tuning paradigm that encompasses BoN-aware inference as a special case. We devise the first imitation learning and reinforcement learning (RL) methods for fine-tuning LLMs with BoN, overcoming the challenging, non-differentiable argmax operator at its core. We empirically demonstrate that our BoN-aware models implicitly learn a per-example "meta-strategy" that interleaves the best known responses with more diverse responses that may be better suited to a given test-time input, a process reminiscent of the exploration-exploitation trade-off in RL. Our experiments demonstrate the effectiveness of BoN-aware fine-tuning both in improved performance and in more effective use of inference-time compute. In particular, our methods improve the BoN performance of Gemma 2B on Hendrycks MATH from 26.8% to 30.8%, and Pass@K from 60% to 67%.
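For readers unfamiliar with the setup, below is a minimal sketch of plain BoN inference. The `generate` and `score` callables are hypothetical stand-ins for an LLM sampler and a verifier (they are not from the paper), and the final argmax over verifier scores is exactly the non-differentiable step that the paper's imitation learning and RL fine-tuning methods are designed to work around; those methods themselves are not reproduced here.

```python
import random
from typing import Callable

def best_of_n(
    prompt: str,
    generate: Callable[[str], str],      # hypothetical: samples one response from an LLM
    score: Callable[[str, str], float],  # hypothetical: verifier score for (prompt, response)
    n: int = 8,
) -> str:
    """Best-of-N inference: sample N responses, return the verifier's top pick.

    The argmax selection below is non-differentiable in the sampling
    distribution, which is why naive gradient-based fine-tuning against
    the BoN objective is not straightforward.
    """
    candidates = [generate(prompt) for _ in range(n)]
    scores = [score(prompt, c) for c in candidates]
    return candidates[max(range(n), key=lambda i: scores[i])]

# Toy usage with stub sampler and verifier (illustration only):
pick = best_of_n(
    "What is 2 + 2?",
    generate=lambda p: random.choice(["4", "5", "four"]),
    score=lambda p, r: 1.0 if r == "4" else 0.0,
    n=8,
)
print(pick)
```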
