Efficient Autoregressive Inference for Transformer Probabilistic Models
Conor Hassan ⋅ Nasrulloh Satrio ⋅ Cen-You Li ⋅ Daolang Huang ⋅ Paul Chang ⋅ Yang Yang ⋅ Francesco Silvestrin ⋅ Samuel Kaski ⋅ Luigi Acerbi
Abstract
Set-based transformer models for amortized probabilistic inference and meta-learning, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass _marginal_ prediction. However, many applications require _joint distributions_ over multiple predictions. Purely autoregressive architectures generate these efficiently but sacrifice flexible set-conditioning. Obtaining joint distributions from set-based models requires re-encoding the entire context at each autoregressive step, which scales poorly. We introduce a _causal autoregressive buffer_ that combines the strengths of both paradigms. The model encodes the context once and caches it; a lightweight causal buffer captures dependencies among generated targets, with each new prediction attending to both the cached context and all previously predicted targets added to the buffer. This enables efficient batched autoregressive sampling and joint predictive density evaluation. Training integrates set-based and autoregressive modes through masked attention with minimal overhead. Across synthetic functions, EEG time series, a Bayesian model comparison task, and tabular regression, our method closely matches the performance of full context re-encoding while delivering up to $20\times$ faster joint sampling and density evaluation, and up to $7\times$ lower memory usage.
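To make the mechanism concrete, the following is a minimal numpy sketch of the encode-once-then-buffer idea described above. All names (`attention`, `context_cache`, the identity "encoder", and random query embeddings) are illustrative placeholders, not the paper's actual architecture: the point is that each autoregressive step attends to the cached context plus the growing buffer of previous predictions, with no re-encoding of the context.

```python
import numpy as np

def attention(q, K, V):
    # Scaled dot-product attention for a single query vector q.
    scores = K @ q / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

rng = np.random.default_rng(0)
d = 8

# Encode the context set ONCE and cache it.
# (A real model would run a set-based transformer encoder here;
# we use the raw features as a stand-in.)
context_cache = rng.normal(size=(16, d))

buffer = np.empty((0, d))  # causal buffer of previously generated targets
for step in range(5):
    q = rng.normal(size=d)  # query embedding for the next target (hypothetical)
    # Attend jointly to the cached context and the buffer -- the context
    # is never re-encoded as the buffer grows.
    keys = np.vstack([context_cache, buffer])
    h = attention(q, keys, keys)
    buffer = np.vstack([buffer, h[None, :]])  # append the new prediction

print(buffer.shape)  # (5, 8)
```

Per step, the cost grows only with the buffer length rather than requiring a full re-encode of the 16-point context, which is the source of the speed and memory savings the abstract reports.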