Abstract:Transformers, driven by attention mechanisms, form the foundation of large language models (LLMs). As these models scale up, efficient GPU attention kernels become essential for high-throughput and low-latency inference. Diverse LLM applications demand flexible and high-performance attention solutions. We present FlashInfer: a customizable and efficient attention engine for LLM serving. FlashInfer tackles KV-cache storage heterogeneity using a block-sparse format and composable formats to optimize memory access and reduce redundancy. It also offers a customizable attention template, enabling adaptation to various settings through Just-In-Time (JIT) compilation. Additionally, FlashInfer's load-balanced scheduling algorithm adapts to the dynamism of user requests while maintaining compatibility with CUDAGraph, which requires static configuration. FlashInfer has been integrated into leading LLM serving frameworks such as SGLang, vLLM, and MLC-Engine. Comprehensive kernel-level and end-to-end evaluations demonstrate FlashInfer's ability to significantly boost kernel performance across diverse inference scenarios: compared to state-of-the-art LLM serving solutions, FlashInfer achieves a 29-69% inter-token-latency reduction relative to compiler backends on an LLM serving benchmark, a 28-30% latency reduction for long-context inference, and a 13-17% speedup for LLM serving with parallel generation.
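To make the block-sparse view of a KV-cache more concrete, the following minimal sketch flattens per-request page lists into CSR-style index arrays, which is one common way to describe a paged cache as a block-sparse structure. It is an illustration only: the function names, the `page_tables` layout, and the example values are assumptions made here, not FlashInfer's API or internal format.

```python
from typing import List, Tuple


def build_block_sparse_index(page_tables: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Flatten per-request page lists into CSR-style (indptr, indices) arrays.

    Hypothetical helper: request i owns the page ids stored in
    indices[indptr[i]:indptr[i + 1]].
    """
    indptr = [0]
    indices: List[int] = []
    for pages in page_tables:
        indices.extend(pages)
        indptr.append(len(indices))
    return indptr, indices


def pages_for_request(indptr: List[int], indices: List[int], request_id: int) -> List[int]:
    """Return the KV page ids that belong to one request."""
    return indices[indptr[request_id]:indptr[request_id + 1]]


# Assumed example: three requests drawing pages from a shared pool.
page_tables = [[0, 3], [1], [2, 4, 5]]
indptr, indices = build_block_sparse_index(page_tables)
```

Requests of different lengths then share one pair of flat arrays, so a single kernel launch could, in principle, walk all of them without per-request padding.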
Abstract:We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.11\times$ with respect to the best-performing SPMD configuration.
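As a rough illustration of what a pipeline schedule over microbatches looks like, the sketch below enumerates a GPipe-style forward schedule and the resulting idle ("bubble") fraction. This is a generic textbook schedule chosen here for illustration, not JaxPP's programming model; the function names and parameters are hypothetical.

```python
from typing import List, Tuple


def gpipe_forward_schedule(num_stages: int, num_microbatches: int) -> List[List[Tuple[int, int]]]:
    """List, per time step, the (stage, microbatch) pairs that run a forward pass.

    In this simple schedule, microbatch m reaches stage s at time step m + s.
    """
    steps = []
    for t in range(num_stages + num_microbatches - 1):
        active = [(s, t - s) for s in range(num_stages) if 0 <= t - s < num_microbatches]
        steps.append(active)
    return steps


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of time steps in which a given stage sits idle."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


schedule = gpipe_forward_schedule(num_stages=2, num_microbatches=3)
```

Raising the number of microbatches shrinks the bubble, which is one reason gradient accumulation and pipeline scheduling are usually designed together.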
Abstract:PyPM is a Python-based domain-specific language (DSL) for building rewrite-based optimization passes on machine learning computation graphs. Users define individual optimizations by writing (a) patterns that match subgraphs of a computation graph and (b) corresponding rules that replace a matched subgraph with an optimized kernel. PyPM is distinguished from the many other DSLs for defining rewriting passes by its complex and novel pattern language, which borrows concepts from logic programming. PyPM patterns can be recursive and nondeterministic, and can require checking domain-specific constraints such as the shapes of tensors. The PyPM implementation is thus similarly complicated, consisting of thousands of lines of C++ code. In this paper, we present our work on building PyPM, as well as formalizing and distilling this complexity into an understandable mathematical core. We have developed a formal core calculus expressing the main operations of the PyPM pattern language. We define both a declarative semantics - describing which patterns match which terms - and an algorithmic semantics - an idealized version of the PyPM pattern interpreter - and prove their equivalence. The development is fully mechanized in the Coq proof assistant.
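The idea of nondeterministic pattern matching borrowed from logic programming can be sketched with a small matcher that yields every substitution under which a pattern matches a term. This is a toy written for illustration under assumed data types (`Var`, `Op`, `Alt`, tuple-encoded terms); it is not PyPM's pattern language, its core calculus, or its interpreter.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, Tuple, Union


@dataclass(frozen=True)
class Var:            # pattern variable; binds to any term
    name: str


@dataclass(frozen=True)
class Op:             # operator pattern with sub-patterns
    name: str
    args: Tuple["Pattern", ...]


@dataclass(frozen=True)
class Alt:            # nondeterministic choice between two patterns
    left: "Pattern"
    right: "Pattern"


Pattern = Union[Var, Op, Alt]
Term = Union[str, Tuple]  # leaf name, or (op_name, arg, ...)


def match(pattern: Pattern, term: Term, subst: Dict[str, Term]) -> Iterator[Dict[str, Term]]:
    """Yield every substitution extending `subst` under which pattern matches term."""
    if isinstance(pattern, Var):
        if pattern.name not in subst:
            yield {**subst, pattern.name: term}
        elif subst[pattern.name] == term:   # repeated variable must agree
            yield subst
    elif isinstance(pattern, Alt):
        yield from match(pattern.left, term, subst)
        yield from match(pattern.right, term, subst)
    elif isinstance(term, tuple) and term[0] == pattern.name and len(term) - 1 == len(pattern.args):
        yield from match_all(pattern.args, term[1:], subst)


def match_all(patterns, terms, subst) -> Iterator[Dict[str, Term]]:
    """Match a sequence of patterns against a sequence of terms, threading bindings."""
    if not patterns:
        yield subst
        return
    for s in match(patterns[0], terms[0], subst):
        yield from match_all(patterns[1:], terms[1:], s)


# Assumed example: a pattern that matches either any term or an "add" node.
pattern = Alt(Var("x"), Op("add", (Var("x"), Var("y"))))
results = list(match(pattern, ("add", "a", "b"), {}))
```

Because `match` is a generator, a single pattern can succeed in several ways; here `results` holds two substitutions, one from each branch of the `Alt`.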