Poster
FlashMask: Efficient and Rich Mask Extension of FlashAttention
Guoxia Wang · Jinle Zeng · Xiyuan Xiao · Siming Wu · Jiabin Yang · Lujing Zheng · Zeyu Chen · Jiang Bian · Dianhai Yu · Haifeng Wang
Hall 3 + Hall 2B #137
Fri 25 Apr, midnight — 2:30 a.m. PDT
Abstract:
The computational and memory demands of vanilla attention scale quadratically with the sequence length N, posing significant challenges for processing long sequences in Transformer models. FlashAttention alleviates these challenges by eliminating the O(N^2) memory dependency and reducing attention latency through IO-aware memory optimizations. However, its native support for certain attention mask types is limited, and it does not inherently accommodate more complex masking requirements. Previous approaches resort to using dense masks with O(N^2) memory complexity, leading to inefficiencies. In this paper, we propose FlashMask, an extension of FlashAttention that introduces a column-wise sparse representation of attention masks. This approach efficiently represents a wide range of mask types and facilitates the development of optimized kernel implementations. By adopting this novel representation, FlashMask achieves linear memory complexity O(N), making it suitable for modeling long-context sequences. Moreover, the representation enables kernel optimizations that skip unnecessary computations by leveraging sparsity in the attention mask, without sacrificing computational accuracy, resulting in higher computational efficiency. We evaluate FlashMask's performance in the fine-tuning and alignment training of LLMs, including SFT, LoRA, DPO, and RM. FlashMask achieves significant throughput improvements, with end-to-end speedups ranging from 1.65x to 3.22x over the existing FlashAttention dense-mask method. Kernel-level comparisons further show that FlashMask surpasses the latest counterpart, FlexAttention, by 12.1% to 60.7% in kernel TFLOPs/s, reaching 37.8% to 62.3% of the theoretical maximum FLOPs/s on the A100 GPU. The code is open-sourced in PaddlePaddle (https://github.com/PaddlePaddle/Paddle) and integrated into PaddleNLP (https://github.com/PaddlePaddle/PaddleNLP), supporting models with over 100 billion parameters for contexts extending up to 128K tokens.
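To illustrate the core idea of a column-wise sparse mask representation, the sketch below encodes, for each key column, a contiguous range of masked query rows using two length-N vectors, so mask storage is O(N) rather than O(N^2). This is a minimal illustrative example, not the FlashMask API; the names mask_start and mask_end are assumptions for this sketch.

```python
# Minimal sketch (not the official FlashMask API): encode an attention mask
# column-wise as per-column row ranges, so storage is O(N) instead of O(N^2).
import numpy as np

def dense_mask_from_column_ranges(mask_start, mask_end, N):
    """Expand a column-wise representation into a dense boolean mask.

    For each key column j, query rows in [mask_start[j], mask_end[j]) are
    masked out (attention disallowed); all other rows may attend. The
    column-wise form stores only two length-N integer vectors.
    """
    dense = np.zeros((N, N), dtype=bool)  # True = masked (disallowed)
    for j in range(N):
        dense[mask_start[j]:mask_end[j], j] = True
    return dense

# Example: a causal mask over N = 6 tokens. Query i may attend to key j only
# if j <= i, so column j masks rows 0..j-1: mask_start = 0, mask_end = j.
N = 6
mask_start = np.zeros(N, dtype=int)
mask_end = np.arange(N, dtype=int)
causal = dense_mask_from_column_ranges(mask_start, mask_end, N)
print(causal.astype(int))
```

Because whole key columns (or tiles of them) are either fully masked or fully unmasked for a given query block, a kernel can consult these per-column ranges to skip blocks entirely, which is the kind of sparsity-aware computation the abstract describes.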