Kvax: Fast and easy-to-use Flash Attention implementation for JAX
Expo Talk Panel
Garnet 212 & 219
Sergei Skvortsov
Kvax is a custom FlashAttention implementation for JAX, optimised for long-context training through efficient document mask computation and context parallelism. This talk walks through the key ideas behind the implementation, with a focus on the performance optimisations for document masks and on how context parallelism is achieved.
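The abstract does not describe Kvax's internals. As a generic illustration of why document masks can be computed efficiently in a blockwise (FlashAttention-style) kernel, the sketch below decides which (query block, key block) tiles contain any unmasked entries using only per-block segment ID ranges, without materialising the full [T, T] mask. It assumes packed documents with non-decreasing segment IDs, equal aligned blocks, and T divisible by the block size. The functions `document_causal_mask` and `block_mask` are hypothetical illustrations, not Kvax's API.

```python
import jax.numpy as jnp


def document_causal_mask(segment_ids):
    """Dense [T, T] reference mask: query i may attend key j only if
    both tokens belong to the same document and j <= i (causal)."""
    pos = jnp.arange(segment_ids.shape[0])
    same_doc = segment_ids[:, None] == segment_ids[None, :]
    causal = pos[:, None] >= pos[None, :]
    return same_doc & causal


def block_mask(segment_ids, block_size):
    """[n, n] map of which (query block, key block) tiles contain at least
    one unmasked entry, computed from per-block segment ranges only.

    Hypothetical helper. Assumes T % block_size == 0 and that segment IDs
    are non-decreasing (documents packed contiguously)."""
    n = segment_ids.shape[0] // block_size
    blocks = segment_ids.reshape(n, block_size)
    lo = blocks[:, 0]   # smallest segment ID in each block (input is sorted)
    hi = blocks[:, -1]  # largest segment ID in each block
    # Rows are query blocks, columns are key blocks; two blocks share a
    # document iff their segment ID ranges overlap.
    overlap = (lo[:, None] <= hi[None, :]) & (lo[None, :] <= hi[:, None])
    # With aligned equal-size blocks, tile (i, j) has causal entries iff j <= i.
    causal = jnp.arange(n)[:, None] >= jnp.arange(n)[None, :]
    return overlap & causal
```

A kernel could consult such a map to skip fully masked tiles entirely, so the cost scales with the number of tiles that actually intersect a document rather than with T². Under the sortedness assumption the range check is exact, which can be verified against the dense reference by reducing each tile of `document_causal_mask` with `any`.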