DNT: A Deeply Normalized Transformer That Can Be Trained by Momentum SGD
Abstract
Transformers have become the de facto backbone of modern deep learning, yet their training typically demands an advanced optimizer with an adaptive learning rate, such as AdamW, rather than momentum SGDW (mSGDW). Previous works show that this is mainly due to the heavy-tailed distribution of the gradients. In this paper, we introduce a Deeply Normalized Transformer (DNT), which is meticulously engineered to overcome the heavy-tailed gradient issue, enabling seamless training with vanilla mSGDW while yielding performance comparable to Transformers trained via AdamW. Specifically, in DNT, we strategically integrate normalization techniques at proper positions in the Transformer to effectively modulate the Jacobian matrices of each layer, balance the influence of weights, activations, and their interactions, and thus concentrate the distributions of the gradients. We provide both theoretical justification for the normalization techniques used in our DNT and extensive empirical evaluation on two popular Transformer architectures (\ie, ViT and GPT), validating that: a) DNT can be effectively trained with vanilla mSGDW; and b) DNT outperforms its counterparts.
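To make the abstract's core idea concrete, the following is a minimal PyTorch sketch of what a "deeply normalized" block trained with momentum SGD and weight decay might look like. The specific normalization placements (extra LayerNorms on each sublayer output, in addition to the usual pre-norms) are illustrative assumptions, not the paper's actual recipe, and `torch.optim.SGD` with `weight_decay` is a coupled-decay stand-in for the decoupled SGDW the abstract refers to.

```python
# A minimal sketch (not the authors' released code) of the abstract's idea:
# insert normalization at additional positions in a Transformer block so each
# sublayer's Jacobian stays well-conditioned, then train with plain momentum
# SGD + weight decay instead of AdamW. The exact placements below are
# illustrative assumptions, not the paper's recipe.
import torch
import torch.nn as nn

class DeeplyNormalizedBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        # Standard pre-norms on each residual branch.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )
        # Hypothetical extra norms on each sublayer output, meant to keep
        # gradient magnitudes concentrated (the "deep" in deeply normalized).
        self.norm_attn_out = nn.LayerNorm(dim)
        self.norm_mlp_out = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.norm_attn_out(h)
        x = x + self.norm_mlp_out(self.mlp(self.norm2(x)))
        return x

model = nn.Sequential(*[DeeplyNormalizedBlock(256, 8) for _ in range(4)])
# "Vanilla mSGDW" stand-in: note that torch.optim.SGD couples weight decay
# into the gradient (L2 regularization), whereas SGDW decouples it; this is
# a close approximation, not an exact SGDW implementation.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)

x = torch.randn(2, 16, 256)       # (batch, tokens, dim)
loss = model(x).pow(2).mean()     # dummy loss for a single training step
loss.backward()
optimizer.step()
```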