Poster

Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought

Jianhao Huang · Zixuan Wang · Jason Lee

Hall 3 + Hall 2B #446
Thu 24 Apr 7 p.m. PDT — 9:30 p.m. PDT

Abstract:

Chain of Thought (CoT) prompting has been shown to significantly improve the performance of large language models (LLMs), particularly in arithmetic and reasoning tasks, by instructing the model to produce intermediate reasoning steps. Despite the remarkable empirical success of CoT and its theoretical advantages in enhancing expressivity, the mechanisms underlying CoT training remain largely unexplored. In this paper, we study the training dynamics of transformers over a CoT objective on an in-context weight prediction task for linear regression. We prove that while a one-layer linear transformer without CoT can only implement a single step of gradient descent (GD) and fails to recover the ground-truth weight vector, a transformer with CoT prompting can learn to perform multi-step GD autoregressively, achieving near-exact recovery. Furthermore, we show that the trained transformer generalizes effectively to unseen data. Empirically, we demonstrate that CoT prompting yields substantial performance improvements.
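The gap between one GD step and many GD steps on a linear regression task can be illustrated numerically. The following is a minimal sketch in plain NumPy, not the paper's transformer construction: the helper `gd_steps`, the zero initialization, the noiseless Gaussian data, the dimensions, and the step size are all assumptions chosen for illustration.

```python
import numpy as np


def gd_steps(X, y, num_steps, lr):
    """Run gradient descent on the loss (1 / 2n) * ||X w - y||^2, starting from w = 0.

    Hypothetical helper for illustration; it is not taken from the paper.
    """
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(num_steps):
        grad = X.T @ (X @ w - y) / n  # gradient of the least-squares loss
        w = w - lr * grad
    return w


# Assumed setup: noiseless linear regression with Gaussian inputs.
rng = np.random.default_rng(0)
n, d = 200, 5
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)  # ground-truth weight vector
y = X @ w_star

lr = 0.5  # assumed step size, small enough for GD to contract here

# Distance to the ground truth after one step versus fifty steps.
err_one = np.linalg.norm(gd_steps(X, y, 1, lr) - w_star)
err_many = np.linalg.norm(gd_steps(X, y, 50, lr) - w_star)

print(f"error after 1 step:   {err_one:.3e}")
print(f"error after 50 steps: {err_many:.3e}")
```

Under these assumptions, a single step leaves a sizable error while repeated steps drive the estimate close to the ground-truth weights, which mirrors the contrast the abstract draws between the single-step (no CoT) and multi-step (CoT) settings.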
