

Invited Talk in Workshop: Bridging the Gap Between Practice and Theory in Deep Learning

Invited Talk 2: Transformers learn in-context by implementing gradient descent

Suvrit Sra


Abstract:

We study the theory of in-context learning, investigating how Transformers can implement learning algorithms in their forward pass. We show that a linear attention Transformer naturally learns to implement gradient descent, which enables it to learn linear functions in-context. More generally, we show that a (non-linear attention based) Transformer can implement functional gradient descent with respect to some RKHS metric, which allows it to learn a broad class of nonlinear functions in-context. We show that the RKHS metric is determined by the choice of attention activation, and that the optimal choice of attention activation depends in a natural way on the class of functions to be learned.
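To make the first claim concrete, here is a minimal numerical sketch (not the speaker's exact construction) for the simplest case: with in-context examples (x_i, y_i) and a query x_q, an unnormalized linear-attention readout that uses the x_i as keys, x_q as the query, and the y_i as values reproduces the prediction of one gradient-descent step (from zero initialization, learning rate eta) on the in-context least-squares loss. The names eta, n_ctx, and d are illustrative choices, not taken from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_ctx, eta = 5, 32, 0.1

# In-context linear regression task: y_i = w_true . x_i
w_true = rng.normal(size=d)
X = rng.normal(size=(n_ctx, d))   # context inputs x_1..x_n
y = X @ w_true                    # context targets y_1..y_n
x_q = rng.normal(size=d)          # query input

# (1) One gradient-descent step on L(w) = 1/(2n) * sum_i (w.x_i - y_i)^2 from w0 = 0:
#     w1 = (eta / n) * sum_i y_i x_i, so the prediction is w1 . x_q.
w1 = (eta / n_ctx) * (y @ X)
pred_gd = w1 @ x_q

# (2) Unnormalized linear attention: keys x_i, query x_q, values y_i,
#     readout (eta / n) * sum_i y_i <x_i, x_q>  (no softmax).
attn_scores = X @ x_q
pred_attn = (eta / n_ctx) * (y @ attn_scores)

print(pred_gd, pred_attn)         # identical up to floating point
assert np.allclose(pred_gd, pred_attn)
```

The two predictions coincide because both reduce to (eta / n) * sum_i y_i <x_i, x_q>; replacing the inner product with a kernel evaluation is the analogue of the functional-gradient-descent view for non-linear attention described in the abstract.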
