Oral in Workshop: How Far Are We From AGI
DEFT: FLASH TREE-ATTENTION WITH IO-AWARENESS FOR EFFICIENT TREE-SEARCH-BASED LLM INFERENCE
Jinwei Yao · Kexun Zhang · Kaiqi Chen · Jiaxuan You · Zeke Wang · Binhang Yuan · Tao Lin
Keywords: [ LLM inference, Tree-based Decoding ] [ Tree Attention ] [ memory efficiency ]
Decoding using tree search can greatly enhance the inference quality of transformer-based Large Language Models (LLMs). Guided by a scoring signal, it organizes LLM outputs as a tree and searches for the best root-to-leaf path, improving controllability, reasoning ability, alignment, et cetera. However, current tree decoding strategies and their inference systems are poorly matched: they incur redundant computation, memory footprints, and IO for the tree of generated sequences, resulting in inefficient inference. To address this issue, we propose DEFT, an IO-aware tree attention algorithm that performs memory-efficient attention calculation with low memory footprints. DEFT achieves a speedup of 1.7-2.4× over SOTA attention algorithms across two practical reasoning tasks.
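To make the redundancy concrete, the following is a minimal, hypothetical sketch (not the authors' DEFT kernel) of attention over a decoding tree: branches that fork from a common prefix can attend to one shared copy of the prefix's keys/values rather than materializing a duplicate per branch, which is the kind of memory and IO redundancy an IO-aware tree attention targets. All names and shapes here are illustrative assumptions.

```python
import numpy as np

D = 8  # head dimension (illustrative)
rng = np.random.default_rng(0)

def attention(q, K, V):
    """Single-query softmax attention over a set of keys/values."""
    scores = K @ q / np.sqrt(D)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

# Shared prefix KV (e.g. the prompt), stored once.
K_prefix = rng.standard_normal((16, D))
V_prefix = rng.standard_normal((16, D))

# Two branches forked from the prefix, each with its own short suffix KV.
branches = [
    (rng.standard_normal((4, D)), rng.standard_normal((4, D))),
    (rng.standard_normal((6, D)), rng.standard_normal((6, D))),
]

# A naive per-sequence layout would copy the prefix KV for every branch.
# A tree-aware layout keeps one prefix copy and builds a logical view per branch.
for K_suffix, V_suffix in branches:
    q = rng.standard_normal(D)                  # current query token of this branch
    K = np.concatenate([K_prefix, K_suffix])    # logical view: shared prefix + branch suffix
    V = np.concatenate([V_prefix, V_suffix])
    out = attention(q, K, V)
    print(out.shape)  # (D,) attention output for this branch
```

An IO-aware kernel goes further than this sketch by loading the shared prefix KV from memory once per tree rather than once per branch; the sketch only shows the logical sharing of the KV cache.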