2 min read

LLM推理解码:单次All-reduce融合核函数突破延迟瓶颈

大规模语言模型(LLM)在生产环境中的推理,尤其关注低延迟。在LLM的解码阶段,时间-到-下一个token是关键优化指标。为降低延迟,通常采用跨GPU进行张量并行化,尤其是在MLP和投影GEMM层。然而,在解码阶段,小消息尺寸和计算需求意味着静态开销(如核函数调用、通信设置)会主导总时间。

为应对此挑战,我们开发了一种创新的单次(single-shot)All-reduce算法,替代传统的环形(ring)All-reduce。环形算法在大消息尺寸时效率高,但在小消息尺寸(如约30 KB/s)时,其多次数据交换和同步开销会显著增加延迟。新算法通过一次性聚合和归约数据,并利用双向NVLink实现并行通信,有效降低了整体通信延迟。此外,通过`cudaDeviceEnablePeerAccess`,我们避免了额外的内存拷贝,直接访问对等GPU的缓冲区,特别适用于单节点多GPU环境。

该单次All-reduce核函数进一步与层归一化和逐点加法操作融合,形成单一的CUDA C++核函数。这种融合不仅最小化了核函数启动开销和HBM内存传输,还通过JAX的FFI(Foreign Function Interface)集成到模型中。实践证明,融合核函数比独立核函数带来约3倍的核函数时间加速,并使解码阶段的端到端延迟降低约27%。结合CUDA Graph,进一步实现了5%的延迟改进。

未来的优化方向包括NCCL 2.27+中的对称内存模型,以及利用NVIDIA OpenSHMEM库的GPU启动通信API,实现计算与通信的交错,以隐藏通信延迟。Mosaic-GPU DSL也支持表达此类融合模式,用于张量并行GEMM或MoE中的专家并行Grouped GEMM。

Optimizing for Low-Latency Communication in Inference Workloads with JAX and XLA | NVIDIA Technical Blog
Running inference with large language models (LLMs) in production requires meeting stringent latency constraints. A critical stage in the process is LLM decode, where time-to-next-token becomes a…
订阅情报