Expo Talk Panel
Kvax: Fast and easy-to-use Flash Attention implementation for JAX
Sergei Skvortsov
Abstract:
Kvax is a custom FlashAttention implementation for JAX, optimised for long-context training with efficient document mask computation and context parallelism. This talk explores the key ideas behind its implementation, focusing on how document masks are computed efficiently and how context parallelism is supported.