Poster
Transformers Provably Learn Two-Mixture of Linear Classification via Gradient Flow
Hongru Yang · Zhangyang Wang · Jason Lee · Yingbin Liang
Hall 3 + Hall 2B #440
Understanding how transformers learn and utilize hidden connections between tokens is crucial to understanding the behavior of large language models. To understand this mechanism, we consider the task of two-mixture of linear classification, which possesses a hidden correspondence structure among tokens, and study the training dynamics of a symmetric two-headed transformer with ReLU neurons. Motivated by the stage-wise learning phenomenon in our experiments, we design and theoretically analyze a three-stage training algorithm, which effectively characterizes the actual gradient descent dynamics when the neuron weights and the softmax attention are trained simultaneously. The first stage is a neuron learning stage, in which the neurons align with the underlying signals. The second stage is an attention feature learning stage, in which we analyze how the attention learns to utilize the relationship between the tokens to classify certain hard samples. Meanwhile, the attention features evolve from a nearly non-separable state at initialization to a well-separated state. The third stage is a convergence stage, in which the population loss is driven towards zero. The key technique in our analysis of softmax attention is to identify a critical sub-system inside a large dynamical system and to bound the growth of this non-linear sub-system by a linear system. Finally, we discuss the setting with more than two mixtures. We empirically show that it is difficult to generalize our analysis of the gradient flow dynamics even to the case of three mixtures, although the transformer can still successfully learn such a distribution. On the other hand, we show by construction that there exists a transformer that can solve mixture of linear classification for an arbitrary number of mixtures.
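To make the phrase "hidden correspondence structure" concrete, here is a minimal sketch of a generic two-mixture of linear classifiers. It is an illustrative assumption, not the authors' exact distribution: the abstract does not specify how the mixture is encoded across tokens, so this version uses plain vectors. Each point draws a hidden component z, a Gaussian input x, and the label sign(<w_z, x>). The helpers `sample_two_mixture` and `linear_accuracy` are hypothetical names introduced only for this example. With w2 = -w1, the two components assign opposite labels to the same input, so no single linear classifier does better than chance overall. A learner must therefore recover which component a sample belongs to.

```python
import random


def sample_two_mixture(n, d, w1, w2, seed=0):
    """Hypothetical generator for a two-mixture of linear classifiers.

    Each point picks a hidden component z in {0, 1} uniformly, draws
    x ~ N(0, I_d), and receives the label y = sign(<w_z, x>) in {+1, -1}.
    Returns a list of (x, y, z) triples; z is what a learner cannot see.
    """
    rng = random.Random(seed)
    weights = (w1, w2)
    data = []
    for _ in range(n):
        z = rng.randrange(2)  # hidden mixture component
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        score = sum(wi * xi for wi, xi in zip(weights[z], x))
        y = 1 if score >= 0 else -1
        data.append((x, y, z))
    return data


def linear_accuracy(data, w):
    """Accuracy of the single linear classifier sign(<w, x>) on (x, y, z) triples."""
    correct = 0
    for x, y, _ in data:
        pred = 1 if sum(wi * xi for wi, xi in zip(w, x)) >= 0 else -1
        correct += pred == y
    return correct / len(data)


if __name__ == "__main__":
    # Opposite components: w2 = -w1, so the labels flip between components.
    data = sample_two_mixture(2000, 2, [1.0, 0.0], [-1.0, 0.0], seed=1)
    comp0 = [p for p in data if p[2] == 0]
    comp1 = [p for p in data if p[2] == 1]
    print("component 0:", linear_accuracy(comp0, [1.0, 0.0]))
    print("component 1:", linear_accuracy(comp1, [1.0, 0.0]))
    print("overall:", linear_accuracy(data, [1.0, 0.0]))
```

The classifier w1 is perfect on component 0 and always wrong on component 1, so its overall accuracy is simply the fraction of component-0 samples, which is close to one half.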