Adaptive Order Policies for Masked Diffusion
Abstract
Masked diffusion models have seen great success in capturing data distributions over discrete sequences in domains such as text and proteins. These models generate data by iteratively unmasking tokens, starting from a fully masked sequence; the unmasking order is typically chosen at random or by a heuristic based on the denoiser's probabilities. In this work, we propose a scheme for learning the unmasking order using an additional lightweight policy network on top of an existing diffusion model. Our proposed loss reweights the terms of the masked diffusion loss according to the policy's probabilities, yielding a policy that prefers positions where the denoiser is more likely to be correct. We demonstrate that our approach outperforms common heuristics on problems that are sensitive to token ordering, such as Sudoku and Boolean satisfiability (3-SAT).
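The generation loop described in the abstract can be sketched with the unmasking order treated as a pluggable rule. This is a minimal illustration under stated assumptions, not the authors' implementation. The `MASK` sentinel, the toy denoiser, and the names `unmask`, `random_order`, `max_confidence_order` and `make_policy_order` are all hypothetical. The "largest top-probability" rule stands in for a heuristic based on denoiser probabilities, and the softmax-sampling rule stands in for a learned policy whose scoring function is assumed, not specified by the source.

```python
import math
import random
from typing import Callable, List, Optional

MASK = None  # hypothetical sentinel for a masked position

Dist = List[float]  # a probability distribution over the vocabulary
Denoiser = Callable[[List[Optional[int]]], List[Dist]]


def random_order(seq, probs, rng):
    """Order rule: pick a masked position uniformly at random."""
    masked = [i for i, tok in enumerate(seq) if tok is MASK]
    return rng.choice(masked)


def max_confidence_order(seq, probs, rng):
    """Example heuristic: unmask the position with the largest top probability."""
    masked = [i for i, tok in enumerate(seq) if tok is MASK]
    return max(masked, key=lambda i: max(probs[i]))


def make_policy_order(score_fn):
    """Wrap a hypothetical policy scorer into an order rule.

    score_fn(seq, probs) returns one real score per position; a masked
    position is sampled with probability softmax(scores) over masked positions.
    """
    def rule(seq, probs, rng):
        masked = [i for i, tok in enumerate(seq) if tok is MASK]
        scores = score_fn(seq, probs)
        top = max(scores[i] for i in masked)  # subtract max for numerical stability
        weights = [math.exp(scores[i] - top) for i in masked]
        return rng.choices(masked, weights=weights)[0]
    return rule


def unmask(length, denoiser: Denoiser, choose_position, rng=None, greedy=True):
    """Start fully masked and unmask one position per step until none remain."""
    if rng is None:
        rng = random.Random(0)
    seq: List[Optional[int]] = [MASK] * length
    order = []
    while any(tok is MASK for tok in seq):
        probs = denoiser(seq)                 # per-position distributions
        i = choose_position(seq, probs, rng)  # where to unmask next
        dist = probs[i]
        if greedy:
            tok = max(range(len(dist)), key=dist.__getitem__)  # argmax token
        else:
            tok = rng.choices(range(len(dist)), weights=dist)[0]
        seq[i] = tok
        order.append(i)
    return seq, order


# Hypothetical toy denoiser over a binary vocabulary: a position is confident
# (0.9) only once its left neighbour is unmasked, otherwise unsure (0.6).
TARGET = [1, 0, 1, 1]


def toy_denoiser(seq):
    probs = []
    for i in range(len(seq)):
        p = 0.9 if i == 0 or seq[i - 1] is not MASK else 0.6
        dist = [0.0, 0.0]
        dist[TARGET[i]] = p
        dist[1 - TARGET[i]] = 1 - p
        probs.append(dist)
    return probs


if __name__ == "__main__":
    print(unmask(len(TARGET), toy_denoiser, max_confidence_order))
```

With this toy denoiser, the confidence heuristic unmasks strictly left to right (order `[0, 1, 2, 3]`), because each position becomes confident only after its left neighbour is fixed. Keeping the order rule separate from the denoiser is what lets a random order, a heuristic, or a learned policy share the same sampling loop.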