Poster
DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference
Jinwei Yao · Kaiqi Chen · Kexun Zhang · Jiaxuan You · Binhang Yuan · Zeke Wang · Tao Lin
Hall 3 + Hall 2B #197
Abstract:
Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared prefixes of tokens, such as few-shot prompting, multi-step reasoning, and speculative decoding. However, existing inference systems for tree-based applications are inefficient due to improper partitioning of queries and KV cache during attention calculation. This leads to two main issues: (1) a lack of memory access (IO) reuse for the KV cache of shared prefixes, and (2) poor load balancing. As a result, there is redundant KV cache IO between GPU global memory and shared memory, along with low GPU utilization. To address these challenges, we propose DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of read/write operations on the KV cache during attention calculation through **KV-Guided Grouping**, a method that avoids repeatedly loading the KV cache of shared prefixes in attention computation. Additionally, we propose **Flattened Tree KV Splitting**, a mechanism that ensures even distribution of the KV cache across partitions with little computation redundancy, enhancing GPU utilization during attention computation. By reducing KV cache IO by 73-99% and IO for partial results by nearly 100% during attention calculation, DeFT achieves speedups of up to 2.23× in end-to-end latency and 3.59× in attention latency across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
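The two ideas named in the abstract can be illustrated with a toy IO count. The sketch below is not the DeFT kernel and is not taken from the authors' code: the `Node` class, the function names, the example tree, and the chunk size are all hypothetical, and it assumes one decoding query per leaf. It compares how many KV tokens are loaded when each query loads its own prefix path with how many are loaded when each tree node's KV is loaded once for all queries that share it, and then cuts the flattened KV into equal-sized chunks.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Node:
    kv_len: int            # number of cached tokens stored in this tree node
    parent: Optional[int]  # index of the parent node, None for the root


def leaves(tree: List[Node]) -> List[int]:
    """Indices of nodes that have no children (one decoding query each)."""
    parents = {n.parent for n in tree if n.parent is not None}
    return [i for i in range(len(tree)) if i not in parents]


def path_to_root(tree: List[Node], i: Optional[int]) -> List[int]:
    """Node indices from node i up to the root."""
    path = []
    while i is not None:
        path.append(i)
        i = tree[i].parent
    return path


def per_query_kv_io(tree: List[Node]) -> int:
    """KV tokens loaded when every query loads its whole prefix path itself."""
    return sum(tree[j].kv_len for leaf in leaves(tree)
               for j in path_to_root(tree, leaf))


def kv_grouped_kv_io(tree: List[Node]) -> int:
    """KV tokens loaded when each node's KV is loaded once for all its queries."""
    return sum(n.kv_len for n in tree)


def flattened_split(tree: List[Node], chunk: int) -> List[List[Tuple[int, int, int]]]:
    """Flatten all node KV into one sequence and cut it into chunks of at most
    `chunk` tokens. Each chunk is a list of (node, start, end) slices."""
    chunks, cur, room = [], [], chunk
    for i, n in enumerate(tree):
        start = 0
        while start < n.kv_len:
            take = min(room, n.kv_len - start)
            cur.append((i, start, start + take))
            start += take
            room -= take
            if room == 0:          # chunk is full, start a new one
                chunks.append(cur)
                cur, room = [], chunk
    if cur:                        # last, possibly shorter, chunk
        chunks.append(cur)
    return chunks


# Hypothetical tree: a long shared prompt with three branches of uneven length.
tree = [
    Node(1000, None),  # 0: shared prefix
    Node(20, 0),       # 1: short shared continuation
    Node(300, 0),      # 2: leaf
    Node(10, 1),       # 3: leaf
    Node(10, 1),       # 4: leaf
]

print(per_query_kv_io(tree), kv_grouped_kv_io(tree))
chunk_sizes = [sum(e - s for _, s, e in c) for c in flattened_split(tree, 335)]
print(chunk_sizes)
```

In this example the shared 1000-token prefix is counted three times in the per-query scheme and once in the grouped scheme, which is where the saving comes from. Splitting by node alone would give partitions of 1000, 20, 300, 10, and 10 tokens, whereas the flattened split yields equal chunks; a real kernel would additionally need to track which queries attend to which slice of each chunk.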