Invited Talk in Workshop: Bridging the Gap Between Practice and Theory in Deep Learning
Invited Talk 2: Transformers learn in-context by implementing gradient descent
Suvrit Sra
Abstract:
We study the theory of in-context learning by investigating how Transformers can implement learning algorithms in their forward pass. We show that a linear-attention Transformer naturally learns to implement gradient descent, which enables it to learn linear functions in-context. More generally, we show that a Transformer with nonlinear attention can implement functional gradient descent with respect to an RKHS metric, which allows it to learn a broad class of nonlinear functions in-context. We show that the RKHS metric is determined by the choice of attention activation, and that the optimal choice of attention activation depends in a natural way on the class of functions to be learned.
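To make the linear case concrete, the sketch below checks numerically that a single linear-attention read-out over in-context examples (with identity projections and no softmax) matches one gradient-descent step, from W = 0, on the in-context least-squares loss. This is an illustrative construction, not the talk's exact parameterization; the names eta, d, and n_ctx are assumed for the example.

```python
import numpy as np

# Minimal sketch: linear attention over (x_i, y_i) context tokens vs. one
# gradient-descent step on the in-context least-squares objective.
rng = np.random.default_rng(0)
d, n_ctx, eta = 5, 32, 0.1            # illustrative dimensions / step size

W_true = rng.normal(size=(d,))
X = rng.normal(size=(n_ctx, d))        # context inputs x_1..x_n
y = X @ W_true                         # context targets y_i = <w*, x_i>
x_q = rng.normal(size=(d,))            # query input

# (1) One GD step on L(W) = 1/(2n) * sum_i (<W, x_i> - y_i)^2, from W = 0:
#     W_1 = (eta / n) * sum_i y_i * x_i
W1 = (eta / n_ctx) * (y @ X)
pred_gd = W1 @ x_q

# (2) Linear attention: query = x_q, keys = x_i, values = y_i, no softmax,
#     identity projections, rescaled by eta / n.
attn_scores = X @ x_q                  # <x_q, x_i> for each context point
pred_attn = (eta / n_ctx) * (attn_scores @ y)

print(pred_gd, pred_attn)              # the two predictions coincide
assert np.allclose(pred_gd, pred_attn)
```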