[vLLM vs TensorRT-LLM] #8. KV Cache Quantization
This article provides a comparative analysis of the effects of KV cache quantization on vLLM and TensorRT-LLM frameworks.
Nov 18, 2024
Why KV Cache Quantization?
As the demand for long-context support in LLMs continues to grow, managing the memory footprint of the KV cache has become increasingly challenging. In many cases, the memory required to store the KV cache exceeds the memory needed for the model parameters themselves. This heavy memory consumption limits how far the batch size can be increased, and batch size is critical for improving throughput.
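A back-of-the-envelope calculation shows why. The sketch below estimates the KV cache size for LLaMA-3.1-8B, the model used later in this post. The architecture values (32 layers, 8 KV heads under grouped-query attention, head dimension 128, about 8.03B parameters) come from the public model configuration rather than from this post, and the estimate ignores paging granularity and fragmentation. The workload of 256 concurrent sequences of 2,048 + 128 tokens mirrors the benchmark setup described below, taken at its worst case where every sequence has reached full length.

```python
# Rough KV cache sizing. Architecture values are assumed from the public
# LLaMA-3.1-8B config; they are not measured in this post.
NUM_LAYERS = 32
NUM_KV_HEADS = 8      # grouped-query attention: 8 KV heads shared by 32 query heads
HEAD_DIM = 128
NUM_PARAMS = 8.03e9   # approximate parameter count

def kv_bytes_per_token(bits: int) -> float:
    # Factor 2 = one K vector and one V vector per layer per KV head.
    return 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * bits / 8

def kv_cache_gib(num_seqs: int, seq_len: int, bits: int) -> float:
    # Total KV cache for num_seqs sequences that are all seq_len tokens long.
    return num_seqs * seq_len * kv_bytes_per_token(bits) / 2**30

weights_gib = NUM_PARAMS * 2 / 2**30  # BF16 weights: 2 bytes per parameter

print(f"BF16 weights: {weights_gib:.1f} GiB")
for bits in (16, 8, 4, 2):
    # 256 concurrent sequences of 2,176 tokens (2,048 input + 128 output)
    print(f"{bits:>2}-bit KV cache: {kv_cache_gib(256, 2176, bits):.1f} GiB")
```

Under these assumptions, 256 full-length sequences need 68 GiB of BF16 KV cache, roughly 4.5x the ~15 GiB of BF16 weights. 8-bit storage halves this to 34 GiB, while 4-bit and 2-bit storage would shrink it to 17 GiB and 8.5 GiB.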
KV cache compression has thus emerged as an important technique, and related approaches generally fall into two categories: pruning and quantization. Pruning removes the KV vectors of tokens deemed less important, eliminating both the memory and the computation for those tokens entirely. Quantization, on the other hand, converts the KV cache into a lower-precision format, reducing the memory needed to store it.
Similar to weight-only or weight-activation quantization, KV cache quantization involves a trade-off between throughput improvement and accuracy. Several research works have explored quantizing the KV cache to 4-bit or even 2-bit precision, but these often result in noticeable accuracy degradation, such as lower MMLU scores. Although quantizing a 16-bit KV cache to 4-bit or 2-bit shrinks it by 4x to 8x, the remaining quality issues make such schemes difficult to deploy in real service environments. For this reason, both vLLM and TensorRT-LLM currently support only 8-bit KV cache quantization, and this post explores 8-bit KV cache quantization in detail.
KV Cache Quantization Choices in vLLM and TensorRT-LLM
Although this post covers only 8-bit KV cache quantization, there are still several options to consider when quantizing KV caches in vLLM and TensorRT-LLM. The first is the target datatype. vLLM supports two FP8 datatypes, E4M3 and E5M2, but does not support an INT8 KV cache. TensorRT-LLM, on the other hand, supports both FP8 (E4M3) and INT8 KV caches. Depending on the format chosen, there may or may not be a throughput improvement.
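The practical difference between the two FP8 formats is range versus precision: E4M3 spends more bits on the mantissa, E5M2 on the exponent. The sketch below enumerates every non-negative finite value of each format and rounds to the nearest one. It is a pure-Python emulation for illustration only, assuming the common conventions (E4M3 in its "fn" variant with no infinities and a maximum of 448, E5M2 with IEEE-style special values). Real kernels use hardware conversion instructions and scaling factors, which are not modeled here, and ties are resolved toward the smaller magnitude rather than round-to-nearest-even.

```python
import bisect

def fp8_grid(exp_bits: int, man_bits: int) -> list[float]:
    """All non-negative finite values of an FP8 format (E4M3 or E5M2 only)."""
    assert (exp_bits, man_bits) in {(4, 3), (5, 2)}
    bias = 2 ** (exp_bits - 1) - 1
    max_exp, max_man = 2 ** exp_bits - 1, 2 ** man_bits - 1
    values = []
    for e in range(max_exp + 1):
        for m in range(max_man + 1):
            if exp_bits == 5 and e == max_exp:
                continue  # E5M2: all-ones exponent encodes inf/NaN (IEEE-style)
            if exp_bits == 4 and e == max_exp and m == max_man:
                continue  # E4M3 "fn" variant: only the all-ones pattern is NaN
            if e == 0:
                values.append(m / 2**man_bits * 2.0 ** (1 - bias))  # subnormals
            else:
                values.append((1 + m / 2**man_bits) * 2.0 ** (e - bias))
    return sorted(values)

def round_to_fp8(x: float, grid: list[float]) -> float:
    """Round x to the nearest grid value, saturating at the format's maximum."""
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), grid[-1])
    i = bisect.bisect_left(grid, a)
    if i == 0:
        return sign * grid[0]
    lo, hi = grid[i - 1], grid[i]
    return sign * (lo if a - lo <= hi - a else hi)

E4M3, E5M2 = fp8_grid(4, 3), fp8_grid(5, 2)
print("max E4M3:", E4M3[-1], "max E5M2:", E5M2[-1])
print("0.1 ->", round_to_fp8(0.1, E4M3), "(E4M3)", round_to_fp8(0.1, E5M2), "(E5M2)")
```

In this emulation, E4M3 represents 0.1 as 0.1015625 (about 1.6% error) but saturates at 448, while E5M2 reaches 57344 at the cost of coarser steps (0.1 becomes 0.09375, about 6.3% error).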
Another decision is whether to also perform the attention computation in low precision. Quantizing the KV cache means converting the KV vectors into a low-precision format for storage; it does not imply that every operation inside the attention mechanism runs in low precision. In practice, attention operations are usually executed in higher-precision floating-point formats (BF16, FP16, FP32). Thus, when KV caches are stored in a quantized format, they are dequantized back to a higher precision before being used.
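The store-low, compute-high pattern can be sketched in a few lines. This toy example assumes symmetric INT8 quantization with a single per-tensor scale (amax / 127); the actual scale granularity and calibration used by vLLM and TensorRT-LLM are not covered in this post and may differ. The key vector is stored as 1-byte integer codes, then converted back to floats before the query-key dot product, which is computed in ordinary Python floats.

```python
import math

def quantize_int8(values):
    """Symmetric per-tensor INT8 quantization: returns integer codes and the scale."""
    amax = max(abs(v) for v in values)
    scale = amax / 127 if amax > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale

def dequantize_int8(codes, scale):
    """Convert stored codes back to high precision before they are used."""
    return [c * scale for c in codes]

def attention_logit(query, key):
    """Scaled dot-product score q.k / sqrt(d), computed in high precision."""
    return sum(q * k for q, k in zip(query, key)) / math.sqrt(len(query))

key = [0.82, -1.37, 0.05, 2.41]    # a cached K vector (made-up values)
query = [0.30, 0.11, -0.75, 0.64]

codes, scale = quantize_int8(key)          # what is stored: 1 byte per element
restored = dequantize_int8(codes, scale)   # dequantized just before attention

exact = attention_logit(query, key)
approx = attention_logit(query, restored)
print(codes, f"scale={scale:.5f}")
print(f"exact={exact:.5f} approx={approx:.5f}")
```

Each restored element is within half a quantization step of the original, so the attention score changes only slightly, while the stored vector takes half the bytes of a BF16 copy.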
However, TensorRT-LLM provides an option to perform attention operations in FP8 precision if both the model and the KV cache are in FP8 format. This feature can be enabled by setting the flag --use_fp8_context_fmha=True when building the TensorRT-LLM engine. With this option enabled, KV caches are not only stored in a quantized format but also used in low precision for the attention computation, potentially increasing throughput.
vLLM, in contrast, does not support low-precision attention computation but offers multiple attention backends to choose from. Notably, the highly optimized FlashAttention-2 backend, which is vLLM's default, does not support a quantized KV cache. Users therefore have to choose the XFormers or FlashInfer backend instead.
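Taken together, these options form a small compatibility matrix. The sketch below encodes it as plain Python data so the constraints described in this section are explicit: vLLM offers FP8 E4M3/E5M2 but no INT8, FlashAttention-2 cannot read a quantized KV cache, and TensorRT-LLM's FP8 attention requires both the model and the KV cache to be FP8. The function name and the string labels are hypothetical; they are not configuration keys of either framework.

```python
# Support matrix as described in this post (vLLM v0.6.3, TensorRT-LLM v0.13.0).
# The string labels are illustrative, not configuration keys of either framework.
KV_DTYPES = {
    "vllm": {"bf16", "fp8_e4m3", "fp8_e5m2"},
    "trtllm": {"bf16", "fp8_e4m3", "int8"},
}
# vLLM backends that can read a quantized KV cache (FlashAttention-2 cannot).
VLLM_QUANT_KV_BACKENDS = {"xformers", "flashinfer"}

def check_config(framework, kv_dtype, model_dtype="bf16",
                 attention_backend=None, fp8_attention=False):
    """Return a list of problems; an empty list means the combination is supported."""
    problems = []
    if kv_dtype not in KV_DTYPES[framework]:
        problems.append(f"{framework} has no {kv_dtype} KV cache")
    quantized_kv = kv_dtype != "bf16"
    if framework == "vllm":
        if fp8_attention:
            problems.append("vllm does not run attention in FP8")
        # attention_backend=None means "let vLLM choose", which falls back to
        # XFormers when the KV cache is quantized.
        if (quantized_kv and attention_backend is not None
                and attention_backend not in VLLM_QUANT_KV_BACKENDS):
            problems.append(f"{attention_backend} cannot read a quantized KV cache")
    elif framework == "trtllm" and fp8_attention:
        if model_dtype != "fp8" or kv_dtype != "fp8_e4m3":
            problems.append("FP8 attention needs an FP8 model and an FP8 KV cache")
    return problems

print(check_config("vllm", "int8"))
print(check_config("vllm", "fp8_e4m3", attention_backend="flash-attn-2"))
print(check_config("trtllm", "fp8_e4m3", model_dtype="fp8", fp8_attention=True))
```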
Experiment Setup
Benchmark Dataset
For all experiments, we used two standardized datasets, each with predetermined input and output token lengths to ensure consistent token processing across different frameworks.
The two datasets are:
- Prefill-heavy: Input length of 2,048 tokens and output length of 128 tokens
- Decode-heavy: Input length of 128 tokens and output length of 2,048 tokens
All experiments were conducted with 1,024 requests, and the maximum batch size was set to 256.
Framework Version
We selected the most recent versions of both frameworks that completed the benchmarking process successfully.
- vLLM: v0.6.3
- TensorRT-LLM: v0.13.0 release / Triton Server: v2.50.0
Model and Hardware
- Model: LLaMA-3.1-8B-Instruct (BF16)
- Hardware: NVIDIA H100-PCIe 80G GPU, AMD EPYC 9554 64-Core Processor CPU
- KV Cache Quantization variants
  - BF16 (Baseline)
  - vLLM: FP8 (E4M3), FP8 (E5M2)
  - TensorRT-LLM: FP8 (E4M3), INT8
Unless otherwise specified, vLLM results with FP8 KV cache refer to results from FP8 (E4M3).
Results
In this article, we compare vLLM and TensorRT-LLM across the available KV cache quantization options. To clearly demonstrate the impact of KV cache quantization on inference speed, we conducted all experiments with a max batch size of 256, where self-attention accounts for a large portion of the overall inference latency.
BF16 vs. Quantized KV Cache
The effectiveness of KV cache quantization varied significantly between vLLM and TensorRT-LLM. For vLLM, the FP8 KV cache did not improve throughput; in fact, it slightly degraded throughput in the prefill-heavy scenario. In contrast, TensorRT-LLM's FP8 and INT8 KV caches showed notable throughput improvements: KV cache quantization provided up to 1.09x and 1.45x higher throughput in the prefill-heavy and decode-heavy scenarios, respectively. The benefits of KV cache quantization were more pronounced in the decode-heavy setting, where FP8 quantization also yielded larger throughput gains than INT8 quantization, likely due to the lower dequantization overhead of FP8 compared to INT8.
While quantizing the KV cache to FP8 or INT8 halves its memory consumption relative to BF16, it does not accelerate the attention computation itself, because the cached KV vectors are dequantized before computation. Since all experiments were conducted with a max batch size of 256, the workload is more likely compute-bound than memory-bound, so the memory savings may not translate directly into higher throughput. How, then, did KV cache quantization achieve an inference speedup?
The answer is that, much as weight-only quantization improved throughput in compute-bound scenarios with large batch sizes, KV cache quantization can increase the effective batch size. Because a low-precision KV cache needs less memory, more requests can be batched in a single iteration, which raises throughput and allows all requests to be processed in fewer iterations.
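The mechanism comes down to capacity arithmetic. The sketch below assumes a fixed, hypothetical KV cache budget of 50 GiB (the real figure depends on GPU memory, weights, activation workspace, and framework settings) and the LLaMA-3.1-8B per-token KV size taken from the public model config. It counts how many full-length sequences fit at once, capped by the configured max batch size, which roughly corresponds to the worst-case reservation a no-eviction scheduler must make.

```python
# Per-token KV elements for LLaMA-3.1-8B (public config, assumed):
# K and V  x  32 layers  x  8 KV heads  x  head dim 128
KV_ELEMS_PER_TOKEN = 2 * 32 * 8 * 128
KV_BUDGET_BYTES = 50 * 2**30  # hypothetical GPU memory left for the KV cache

def concurrent_sequences(seq_len, bytes_per_elem, max_batch_size):
    """Full-length sequences that fit in the KV budget, capped by the max batch size."""
    bytes_per_seq = seq_len * KV_ELEMS_PER_TOKEN * bytes_per_elem
    return min(max_batch_size, KV_BUDGET_BYTES // bytes_per_seq)

for label, nbytes in (("BF16", 2), ("FP8/INT8", 1)):
    n = concurrent_sequences(2048 + 128, nbytes, max_batch_size=256)
    print(f"{label} KV cache: {n} concurrent sequences")
```

With this made-up budget, a BF16 cache caps concurrency at 188 sequences of 2,176 tokens, while an 8-bit cache would fit 376 and is limited only by the max batch size of 256. Real schedulers admit requests dynamically as sequences grow, so actual effective batch sizes vary over time.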
vLLM: Different Attention Backends
KV cache quantization can increase the effective batch size in both vLLM and TensorRT-LLM, yet no significant throughput gain was observed for vLLM. vLLM lets users choose among several attention backends, and FlashAttention-2 is used by default. However, FlashAttention-2 does not support an FP8 KV cache, so when an FP8 KV cache is used, the XFormers backend is used as a fallback. For better performance with an FP8 KV cache, vLLM recommends the FlashInfer backend instead, which must be installed separately.
Figure 4 shows the throughput comparison between different attention backends with FP8 formats. With the XFormers backend, FP8 KV cache quantization had only a marginal impact, yielding throughput similar to the BF16 baseline. However, the BF16 throughput of XFormers was itself significantly lower than that of FlashAttention-2, making XFormers a less appealing option. The FlashInfer backend, on the other hand, showed some throughput improvement with FP8 KV cache quantization, although its BF16 throughput was still slightly lower than that of FlashAttention-2. Even with FP8 KV cache quantization, FlashInfer's throughput remained at roughly the same level as BF16 with FlashAttention-2.
Although the FlashAttention-2 backend performs better than FlashInfer, its incompatibility with FP8 KV cache prevents vLLM from using KV cache quantization for inference speedup. A new version, FlashAttention-3, was recently released, and efforts are underway to integrate it into vLLM. FlashAttention-3 offers several enhancements over previous versions, including support for FP8 attention and optimizations specific to it. Once integrated, it is expected to significantly improve performance in vLLM. In the following sections, we use the FlashInfer backend for FP8 KV cache results, as it is currently the only viable option.
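For reference, the choices in this section map onto two vLLM settings: the attention backend, selected with the VLLM_ATTENTION_BACKEND environment variable, and the KV cache datatype, passed as the kv_cache_dtype engine argument (--kv-cache-dtype on the command line). The helper below only assembles those settings and does not start an engine. The helper name, the Hugging Face model ID, and the decision to reject FLASH_ATTN up front are assumptions of this sketch; the accepted values should be checked against the installed vLLM version.

```python
import os

# Hypothetical helper: assembles settings for an FP8 KV cache run in vLLM.
def vllm_fp8_kv_settings(kv_cache_dtype="fp8_e4m3", backend="FLASHINFER"):
    """Return (environment variables, keyword arguments for vllm.LLM)."""
    if kv_cache_dtype not in ("fp8_e4m3", "fp8_e5m2"):
        raise ValueError(f"not an FP8 KV cache dtype: {kv_cache_dtype}")
    if backend == "FLASH_ATTN":
        # FlashAttention-2 cannot read an FP8 KV cache.
        raise ValueError("FLASH_ATTN does not support an FP8 KV cache")
    env = {"VLLM_ATTENTION_BACKEND": backend}
    llm_kwargs = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # assumed Hugging Face model ID
        "kv_cache_dtype": kv_cache_dtype,
        "max_num_seqs": 256,  # max batch size used in this post
    }
    return env, llm_kwargs

env, llm_kwargs = vllm_fp8_kv_settings()
os.environ.update(env)  # set before vLLM creates the engine
# On a GPU machine with vLLM and FlashInfer installed:
#   from vllm import LLM
#   llm = LLM(**llm_kwargs)
print(env, llm_kwargs)
```

Keeping the backend choice in the environment and the datatype in the engine arguments mirrors how vLLM separates the two; swapping backend="XFORMERS" reproduces the fallback configuration compared in Figure 4.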
TensorRT-LLM: GUARANTEED_NO_EVICT vs. MAX_UTILIZATION
In our previous post, we discussed two scheduling policies of TensorRT-LLM: GUARANTEED_NO_EVICT and MAX_UTILIZATION. To understand the impact of KV cache quantization on actual scheduling, we measured the effective batch size over iterations, as in that post. For this experiment, we increased the max batch size to 512 and the number of requests to 2,048.
With the GUARANTEED_NO_EVICT policy, neither the BF16 nor the FP8 KV cache configuration reached the maximum batch size (512) during inference, which means the effective batch size was limited by the memory usage of the KV caches. However, with the quantized KV cache, a higher effective batch size was maintained throughout inference. In contrast, under the MAX_UTILIZATION policy, both the quantized and BF16 KV cache configurations reached the batch size limit, although preemption occurred frequently. Notably, the FP8 KV cache allowed better memory efficiency, resulting in fewer preemptions and therefore longer periods at the max batch size compared to BF16.
The difference in average effective batch size between the two scheduler policies leads to differences in TPOT, even when throughput appears similar. Across all KV cache quantization options, the average TPOT under MAX_UTILIZATION was notably longer than under GUARANTEED_NO_EVICT due to the higher effective batch size.
Accuracy Impact
Quantization always comes with a trade-off between efficiency and accuracy. To examine the impact of KV cache quantization on model accuracy, we measured MMLU (5-shot) accuracy.
The results show that MMLU scores are nearly identical regardless of the framework and the datatype of the KV cache used. As mentioned earlier, lowering the KV cache precision to 4-bit or 2-bit significantly impacts accuracy, but 8-bit quantization has minimal effect. This is why both vLLM and TensorRT-LLM currently support only 8-bit KV cache quantization. However, as lower-precision KV cache quantization techniques continue to be actively researched, we can expect more efficient KV cache management in the future.
Synergy with Weight-Activation Quantization
KV cache quantization can be combined with weight-activation quantization to achieve further throughput improvements. We evaluated token throughput when both model and KV cache quantization were applied together. Only FP8 KV cache quantization was tested, as it consistently showed better or comparable throughput to INT8 KV cache under the same conditions.
For vLLM, FP8 KV cache slightly slowed down inference, even when combined with weight-activation quantization. In contrast, for TensorRT-LLM, combining FP8 KV cache with an FP8 or INT8 quantized model resulted in significant speed improvements compared to using a BF16 KV cache. This effect was more pronounced under decode-heavy settings, similar to the performance trends observed with BF16 model evaluation.
For TensorRT-LLM, selecting the optimal combination of KV cache precision and weight-activation quantization was essential. The INT8 quantized model delivered higher throughput than the BF16 model without KV cache quantization, but pairing it with an FP8 KV cache reduced its performance below that of the BF16 model. On the other hand, the FP8 quantized model showed improved throughput over BF16 regardless of whether it was paired with an FP8 KV cache. Since FP8 model quantization significantly outperformed INT8 even without KV cache quantization, using FP8 format for both the model and KV cache is the best practice for maximizing throughput on TensorRT-LLM.
FP8 Attention of TensorRT-LLM
So far, we have tested quantization of the KV cache, model weights, and activations. However, one more quantization option remains. As previously discussed, TensorRT-LLM offers the option to enable FP8 attention when building the inference engine. This option is only available when the model is quantized to FP8. Although FP8 attention might intuitively be expected to enhance throughput, performing attention operations at lower precision can compromise model accuracy. To determine whether throughput genuinely improves and whether accuracy suffers, we evaluated both throughput and accuracy with FP8 attention enabled.
With both the model and the KV cache quantized to FP8, the MMLU score was 0.6790. Further quantizing the attention computation to FP8 resulted in a score of 0.6750. This indicates that FP8 attention caused minimal degradation in model accuracy, as measured by MMLU.
Regarding throughput, FP8 attention surprisingly did not deliver any improvement in either the prefill-heavy or the decode-heavy scenario. Throughput increased step by step as FP8 KV cache and FP8 model quantization were applied, but not when FP8 attention was added on top. While there may be cases where FP8 attention is beneficial, it does not currently appear to be an attractive option. If further optimizations eventually let FP8 attention contribute to throughput, it could become an option worth considering.
Final Thoughts
KV cache quantization offers promising performance improvements for LLM serving, especially in decode-heavy scenarios, by increasing the effective batch size at each iteration. Our experiments demonstrated that KV cache quantization can bring substantial throughput gains without compromising model accuracy. KV cache quantization can also be applied orthogonally to model quantization schemes; in particular, combining it with weight-activation quantization (FP8 for both the model and the KV cache on TensorRT-LLM) gave the best throughput in our experiments.
However, since the throughput gain from KV cache quantization without FP8 attention comes primarily from the increase in effective batch size, its impact is only noticeable in serving scenarios with large maximum batch sizes. At smaller maximum batch sizes, KV cache quantization may only introduce dequantization overhead. Additionally, the performance improvement from KV cache quantization can vary significantly depending on its implementation, the extent of kernel optimization, and the specific user choices. For example, vLLM does not achieve a speedup with a quantized KV cache because its high-performance FlashAttention-2 kernel does not support one. In contrast, TensorRT-LLM does show some throughput gain; however, the results vary considerably depending on whether FP8 or INT8 is chosen for the KV cache.
Although we demonstrated cases where KV cache quantization improves throughput, there is clearly more room for improvement. As demand for long-context inference and large-batch serving continues to grow, KV cache quantization will only become more important. We therefore expect to see further optimizations for low-precision KV caches in vLLM and TensorRT-LLM, as well as the introduction of new quantization schemes.
Stay tuned for more insights in the vLLM vs TensorRT-LLM series!