Poster

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

Dan Fu · Hermann Kumbong · Eric Nguyen · Christopher Re

Halle B #140

Abstract: Convolution models with long filters have demonstrated state-of-the-art reasoning abilities in many long-sequence tasks but lag behind the most optimized Transformers in wall-clock time. A major bottleneck is the Fast Fourier Transform (FFT)---which allows long convolutions to run in $O(N\log N)$ time in sequence length $N$ but has poor hardware utilization. In this paper, we study how to optimize the FFT convolution. We find two key bottlenecks: the FFT does not effectively use specialized matrix multiply units, and it incurs expensive I/O between layers of the memory hierarchy. In response, we propose FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT using matrix multiply units and enables kernel fusion for long sequences, reducing I/O. We also present two sparse convolution algorithms---1) partial convolutions and 2) frequency-sparse convolutions---which can be implemented simply by skipping blocks in the matrix decomposition, enabling further opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT convolutions by up to 8.7$\times$ over PyTorch and achieves up to 4.4$\times$ speedup end-to-end. Given the same compute budget, FlashFFTConv allows Hyena-GPT-s to achieve 2.3 points better perplexity and M2-BERT-base to achieve 3.3 points higher GLUE score---matching models with twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on Path-512, a high-resolution vision task where no model had previously achieved better than 50%. Furthermore, partial convolutions enable longer-sequence models---yielding the first DNA model that can process the longest human genes (2.3M base pairs)---and frequency-sparse convolutions speed up pretrained models while maintaining or improving model quality.
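To make the idea of computing an FFT with matrix multiplies concrete, here is a minimal NumPy sketch. It is not the FlashFFTConv kernel: it assumes a plain two-factor Cooley-Tukey split of a length-$N$ DFT with $N = N_1 N_2$, written as two dense matrix products with an elementwise twiddle correction in between, and then uses it for a causal FFT convolution. The function names (`dft_matrix`, `fft_via_matmul`, `fft_conv_matmul`) are hypothetical, and the fusion, sparsity, and tensor-core mapping described in the abstract are not reproduced.

```python
import numpy as np


def dft_matrix(n):
    """Dense n x n DFT matrix with entries exp(-2*pi*i*j*k/n)."""
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n)


def fft_via_matmul(x, n1, n2):
    """Length n1*n2 DFT of a 1-D array using two matmuls and a twiddle multiply.

    Illustrative assumption: a two-factor Cooley-Tukey split. Input index is
    n2*a + b and output index is c + n1*d, with a, c < n1 and b, d < n2.
    """
    n = n1 * n2
    assert x.shape == (n,)
    a = x.reshape(n1, n2)                       # a[a_idx, b_idx] = x[n2*a_idx + b_idx]
    b = dft_matrix(n1) @ a                      # length-n1 DFTs down each column
    rows = np.arange(n1)[:, None]
    cols = np.arange(n2)[None, :]
    b = b * np.exp(-2j * np.pi * rows * cols / n)  # twiddle factors W_N^(c*b)
    d = b @ dft_matrix(n2)                      # length-n2 DFTs along each row
    return d.T.reshape(n)                       # out[c + n1*d_idx] = d[c, d_idx]


def fft_conv_matmul(u, k, n1, n2):
    """Causal convolution of u with k (both length n) via the matmul-based DFT.

    Both inputs are zero-padded to length 2n = n1*n2 so that the circular
    convolution computed in the frequency domain equals the linear one.
    """
    n = u.shape[0]
    m = 2 * n
    assert k.shape == (n,) and n1 * n2 == m
    u_pad = np.concatenate([u, np.zeros(n)])
    k_pad = np.concatenate([k, np.zeros(n)])
    y_freq = fft_via_matmul(u_pad, n1, n2) * fft_via_matmul(k_pad, n1, n2)
    # Inverse DFT via the identity ifft(Y) = conj(fft(conj(Y))) / m.
    y = np.conj(fft_via_matmul(np.conj(y_freq), n1, n2)) / m
    return y.real[:n]


rng = np.random.default_rng(0)
u = rng.standard_normal(32)
k = rng.standard_normal(32)
y = fft_conv_matmul(u, k, 8, 8)                 # padded length 64 = 8 * 8
max_err = np.max(np.abs(y - np.convolve(u, k)[:32]))
```

With dense factor matrices this sketch costs more arithmetic than a recursive FFT; its purpose is only to show that every step is either a matrix product or an elementwise multiply, which is the kind of work matrix multiply units are built for.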
