Approximating Function Space Distance for Continual Learning in Transformers
Abstract
Measuring how a neural network's function changes during training, finetuning, or editing is critical for applications such as continual learning. Such shifts can be formalized via a function space distance (FSD), the expected squared difference in network outputs under a data distribution, but computing the true FSD requires access to data that is often unavailable. The previously proposed Linearized Activation Function TRick (LAFTR) circumvents this challenge with approximations specific to networks built from linear layers and ReLU activations. We extend it to a more general LInearized Function TRick (LIFTR) that enables data-free FSD estimation for arbitrary architectures, with a particular focus on transformers. Our approach decomposes FSD estimation into moment propagation that uses only pre-computed activation statistics of the data, yielding a modular implementation that generalizes readily to arbitrary functions. On a modular arithmetic continual learning task, we show that a stochastic variant of LIFTR approaches oracle performance while outperforming parameter-space linearization baselines. LIFTR estimates correlate strongly with the oracle FSD and produce better-aligned gradients than competing methods. We further demonstrate that LIFTR degrades more gracefully with network depth than global parameter-space linearization.
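To make the idea of data-free FSD estimation from pre-computed statistics concrete, the sketch below treats the simplest possible case: a single linear map f(x) = Wx, compared before and after an update. Here only the input mean `mu` and covariance `Sigma` are needed, because E_x ||(W_new - W_old)x||^2 = tr(dW (Sigma + mu mu^T) dW^T). This is an illustrative assumption-laden example, not the LIFTR algorithm: the function names are hypothetical, FSD is assumed to sum squared differences over output dimensions, and the nonlinear moment propagation that LIFTR performs for transformer blocks is not shown. A Monte Carlo estimate on sampled inputs plays the role of the data-dependent oracle.

```python
import numpy as np


def linear_fsd_from_moments(W_old, W_new, mu, Sigma):
    """Data-free FSD for f(x) = W x using only the input mean and covariance.

    Uses E_x ||dW x||^2 = tr(dW E[x x^T] dW^T) with E[x x^T] = Sigma + mu mu^T.
    (Hypothetical helper; assumes FSD sums squared differences over outputs.)
    """
    dW = W_new - W_old
    second_moment = Sigma + np.outer(mu, mu)  # E[x x^T]
    return float(np.trace(dW @ second_moment @ dW.T))


def monte_carlo_fsd(W_old, W_new, X):
    """Oracle-style estimate: mean squared output difference over samples X (rows)."""
    diff = X @ (W_new - W_old).T
    return float(np.mean(np.sum(diff ** 2, axis=1)))


rng = np.random.default_rng(0)
d_in, d_out, n = 8, 4, 200_000

# Pre-computed "activation statistics" of the input distribution.
mu = rng.normal(size=d_in)
A = rng.normal(size=(d_in, d_in))
Sigma = A @ A.T / d_in  # symmetric positive semi-definite covariance

# Old and updated weights (the update stands in for finetuning or editing).
W_old = rng.normal(size=(d_out, d_in))
W_new = W_old + 0.1 * rng.normal(size=(d_out, d_in))

# Samples are used only to check the moment-based estimate.
X = rng.multivariate_normal(mu, Sigma, size=n)

closed_form = linear_fsd_from_moments(W_old, W_new, mu, Sigma)
sampled = monte_carlo_fsd(W_old, W_new, X)
print(f"moment-based FSD: {closed_form:.4f}  sampled FSD: {sampled:.4f}")
```

For a linear map the moment-based value is exact, so the two numbers agree up to sampling noise; the difficulty LIFTR addresses is carrying such moments through nonlinear components, where this closed form no longer applies directly.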