

Poster in Workshop: Workshop on Sparsity in LLMs (SLLM): Deep Dive into Mixture of Experts, Quantization, Hardware, and Inference

RLMedusa: Reinforcement Learning for Multiple Decoding Heads to Accelerate LLM Inference

Aadit Juneja · Parsa Idehpour


Abstract:

Traditional transformer inference generates tokens one step at a time, with each step depending on the previous one, which creates a bottleneck in inference speed. The Medusa technique used LoRA fine-tuning to train multiple decoding heads, each predicting a token a different number of positions ahead, so that several tokens can be generated in parallel as a draft that the base model then verifies. In this paper, we propose a reinforcement-learning-based approach to training multiple decoding heads. Our method introduces a reward model scheme that uses feed-forward networks to estimate token probabilities from context hidden states and candidate token embeddings. We contrast our interpretation of reinforcement learning in language modeling research with traditional, RLHF-centric interpretations, and discuss our experiments with RLMedusa.
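To make the moving parts concrete, here is a minimal toy sketch in plain Python, under several assumptions that the abstract does not confirm. The decoding head borrows the residual SiLU block plus vocabulary projection from the original Medusa design. The reward model is a hypothetical feed-forward network over the concatenation of a context hidden state and a candidate token embedding, ending in a sigmoid so that its output can be read as an estimated token probability. The RL update is plain REINFORCE applied only to the head's output projection, which is a swapped-in choice for illustration, not necessarily the authors' algorithm. The names `MedusaHead`, `RewardFFN`, `reinforce_step`, and the toy sizes are all hypothetical.

```python
import math
import random

# Toy sizes (hypothetical); a real setup would use the base LLM's hidden size and vocabulary.
HIDDEN, VOCAB, EMBED = 4, 6, 3

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def silu(x):
    return x / (1.0 + math.exp(-x))

def softmax(logits):
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

class MedusaHead:
    """Hypothetical decoding head: residual SiLU block, then a projection onto the vocabulary."""
    def __init__(self, rng):
        self.w1 = rand_matrix(HIDDEN, HIDDEN, rng)
        self.w_out = rand_matrix(VOCAB, HIDDEN, rng)

    def features(self, h):
        # z = h + SiLU(W1 h)
        return [x + silu(y) for x, y in zip(h, matvec(self.w1, h))]

    def probs(self, h):
        return softmax(matvec(self.w_out, self.features(h)))

class RewardFFN:
    """Hypothetical reward model: FFN over [hidden state ; candidate embedding] -> value in (0, 1)."""
    def __init__(self, rng, width=8):
        self.w1 = rand_matrix(width, HIDDEN + EMBED, rng)
        self.w2 = [rng.gauss(0.0, 0.5) for _ in range(width)]

    def __call__(self, h, e):
        hidden = [max(0.0, a) for a in matvec(self.w1, h + e)]  # ReLU over the concatenated input
        score = sum(w * a for w, a in zip(self.w2, hidden))
        return 1.0 / (1.0 + math.exp(-score))  # sigmoid: read as an estimated token probability

def reinforce_step(head, reward_model, h, embeddings, rng, lr=0.1):
    """Sample a draft token from the head, score it, and apply REINFORCE to the output layer only."""
    z = head.features(h)
    p = head.probs(h)
    token = rng.choices(range(VOCAB), weights=p)[0]
    reward = reward_model(h, embeddings[token])
    # d log p[token] / d logits = onehot(token) - p; logits = w_out @ z, so dW_out = outer(grad, z).
    for i in range(VOCAB):
        grad_logit = (1.0 if i == token else 0.0) - p[i]
        for j in range(HIDDEN):
            head.w_out[i][j] += lr * reward * grad_logit * z[j]
    return token, reward

rng = random.Random(0)
head = MedusaHead(rng)
reward_model = RewardFFN(rng)
embeddings = rand_matrix(VOCAB, EMBED, rng)
h = [rng.gauss(0.0, 1.0) for _ in range(HIDDEN)]
before = head.probs(h)
token, reward = reinforce_step(head, reward_model, h, embeddings, rng)
after = head.probs(h)
print(token, round(reward, 3), after[token] > before[token])
```

Because the reward comes from the feed-forward estimator rather than from human preference labels, the head is pushed toward tokens the estimator scores as likely in context; with a positive reward, a single step always raises the sampled token's probability. A fuller version would presumably subtract a baseline, train every head at its own offset, and backpropagate into the residual block as well.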
