Poster
Transformers Provably Solve Parity Efficiently with Chain of Thought
Juno Kim · Taiji Suzuki
Hall 3 + Hall 2B #445
Oral presentation: Oral Session 6C
Sat 26 Apr 12:30 a.m. PDT — 2 a.m. PDT
Poster session: Fri 25 Apr 7 p.m. PDT — 9:30 p.m. PDT
Abstract:
This work provides the first theoretical analysis of training transformers to solve complex problems by recursively generating intermediate states, analogous to fine-tuning for chain-of-thought (CoT) reasoning. We consider training a one-layer transformer to solve the fundamental k-parity problem, extending the work on RNNs by Wies et al. (2023). We establish three key results: (1) Any finite-precision gradient-based algorithm without intermediate supervision requires a substantial number of iterations to solve parity with finite samples. (2) In contrast, when intermediate parities are incorporated into the loss function, our model can learn parity in one gradient update when aided by teacher forcing, where ground-truth labels of the reasoning chain are provided at each generation step. (3) Even without teacher forcing, where the model must generate CoT chains end-to-end, parity can be learned efficiently if augmented data is employed to internally verify the soundness of intermediate steps. Our findings, supported by numerical experiments, show that task decomposition and stepwise reasoning naturally arise from optimizing transformers with CoT; moreover, self-consistency checking can improve multi-step reasoning ability, aligning with empirical studies of CoT.