STEM: SCALING TRANSFORMERS WITH EMBEDDING MODULES
Ranajoy Sadhukhan · Sheng Cao · Harry Dong · Changsheng Zhao · Attiano Purpura-Pontoniere · Yuandong Tian · Zechun Liu · Beidi Chen
Abstract
Fine-grained sparsity promises higher parametric capacity without proportional per-token compute, but often suffers from training instability, load-balancing issues, and communication overhead. We introduce \textbf{STEM} (\emph{Scaling Transformers with Embedding Modules}), a static, token-indexed approach that replaces the FFN up-projection with a layer-local embedding lookup while keeping the gate and down-projection dense. This removes runtime routing, enables CPU offload with asynchronous prefetch, and decouples capacity from both per-token FLOPs and cross-device communication. Empirically, STEM trains stably despite extreme sparsity. It improves downstream performance over dense baselines while reducing per-token FLOPs and parameter accesses (eliminating roughly one-third of FFN parameters). STEM learns embedding spaces with large angular spread, which enhances its knowledge storage capacity. In addition, STEM strengthens long-context performance: as sequence length grows, more distinct parameters are activated, yielding practical test-time capacity scaling. Across 350M and 1B model scales, STEM delivers up to $\sim$3--4\% improvements in average downstream performance, with notable gains on knowledge- and reasoning-heavy benchmarks (ARC-Challenge, OpenBookQA, GSM8K, MMLU). Overall, STEM is an effective way of scaling parametric memory while remaining simpler to train and deploy than existing fine-grained sparse models.
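To make the architectural change concrete, the following is a minimal sketch (not the authors' implementation) of a STEM-style FFN block, assuming a SwiGLU-like layout: the dense up-projection is replaced by a layer-local embedding table indexed by the input token ids, while the gate and down-projection remain dense. The class name \texttt{StemFFN} and the dimensions used here are illustrative assumptions.

\begin{verbatim}
import torch
import torch.nn as nn
import torch.nn.functional as F

class StemFFN(nn.Module):
    """Sketch of an FFN whose up-projection is a token-indexed embedding lookup."""
    def __init__(self, d_model: int, d_ff: int, vocab_size: int):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)   # dense gate projection
        self.down = nn.Linear(d_ff, d_model, bias=False)   # dense down-projection
        # Layer-local embedding table replacing the dense up-projection.
        # Rows are selected statically by token id, so no runtime routing is
        # needed and the table can be offloaded and prefetched asynchronously.
        self.up_table = nn.Embedding(vocab_size, d_ff)

    def forward(self, x: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model); token_ids: (batch, seq)
        up = self.up_table(token_ids)                       # static, token-indexed lookup
        return self.down(F.silu(self.gate(x)) * up)

# Example usage (hypothetical sizes):
# ffn = StemFFN(d_model=512, d_ff=2048, vocab_size=32000)
# out = ffn(hidden_states, input_ids)
\end{verbatim}

Because the lookup depends only on the token id, the per-token compute drops to the gate and down-projections, while total parametric capacity scales with the vocabulary-sized table.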