SNaX: sparse narrow accelerated mixture of experts
Wentao Guo · Mayank Mishra · Xinle Cheng · Ion Stoica · Tri Dao
Abstract
Mixture of Experts (MoE) models have emerged as the de facto architecture for scaling up language models without significantly increasing the computational cost. Existing MoE methods optimize system efficiency or model architecture independently. We show that as MoE models become more granular and sparser, they become more memory-bound, and jointly optimizing the algorithms and the kernel design leads to a major improvement in MoE training throughput. We first propose a memory-efficient algorithm that computes the forward and backward passes of MoE while saving minimal activations. We then design GPU kernels that overlap memory IO latency with compute, benefiting all MoE architectures. Finally, we propose a novel "token rounding" method that minimizes the wasted compute caused by tile quantization. As a result, our method SNaX reduces activation memory by 45% and achieves a 1.80x throughput improvement on NVIDIA H100 GPUs compared to ScatterMoE for a fine-grained 7B MoE. Moreover, SNaX on 64 H100s achieves a training throughput of 213 billion tokens per day, comparable to ScatterMoE's 225 billion tokens per day on 96 H100s, for a 7B MoE model with token-choice routing trained under FSDP-2. Under high MoE sparsity settings, our tile-aware token rounding algorithm yields an additional 1.18x speedup in kernel execution time over vanilla top-$K$ routing while maintaining similar downstream performance.
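To make the tile-quantization issue concrete: GPU GEMM kernels process the token dimension in fixed-size tiles, so an expert that receives a token count that is not a multiple of the tile size still pays for a full final tile. The abstract does not give the exact rounding rule SNaX uses, so the snippet below is only a minimal illustrative sketch of the general idea, rounding each expert's assignment count down to a tile multiple by dropping its lowest-scoring assignments; the function name `round_tokens_to_tile` and the parameter `tile_m` are hypothetical.

```python
# Illustrative sketch only: the exact token-rounding rule is not specified in the
# abstract. This hypothetical helper keeps, per expert, only the largest multiple
# of the GEMM tile size worth of assignments, dropping the lowest-scoring ones.
import torch


def round_tokens_to_tile(expert_ids, scores, num_experts, tile_m=128):
    """Return a boolean mask over token-expert assignments.

    expert_ids: (num_assignments,) expert index of each assignment
    scores:     (num_assignments,) router score of each assignment
    """
    keep = torch.zeros_like(scores, dtype=torch.bool)
    for e in range(num_experts):
        idx = (expert_ids == e).nonzero(as_tuple=True)[0]
        n_keep = (idx.numel() // tile_m) * tile_m  # largest tile multiple <= count
        if n_keep == 0:
            continue
        top = scores[idx].topk(n_keep).indices     # highest-scoring assignments
        keep[idx[top]] = True
    return keep
```

With such a mask, every expert's GEMM runs over full tiles only, so no compute is spent on padded rows; whether SNaX drops, pads, or reassigns the leftover tokens is a detail beyond what the abstract states.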