[vLLM vs TensorRT-LLM] #7. Weight-Activation Quantization
This article provides a comparative analysis of the effects of weight-activation quantization on vLLM and TensorRT-LLM frameworks.
Nov 11, 2024
Introduction
Modern NVIDIA GPUs include low-precision units for fast INT8 and FP8 multiply-accumulate (MAC) operations. Recent NVIDIA GPUs deliver twice the tensor-core throughput with INT8 and FP8 compared to FP16. For example, the NVIDIA H100 PCIe achieves 756 TFLOPS with FP16, whereas it reaches 1513 TOPS with INT8 or 1513 TFLOPS with FP8, effectively doubling its peak performance (NVIDIA Hopper Architecture Whitepaper).
In our previous post, we explored weight-only quantization, which reduces memory usage by converting weights to low-precision formats for more efficient LLM serving. However, this approach leads to a mismatch in precision between the weights and activations. As a result, weights need to be dequantized back to high precision before multiplication, limiting the performance gains in compute-bound scenarios (e.g., large batch sizes or prefill-heavy workloads).
Since compute-bound scenarios can occur in real-world serving situations, as discussed in our previous posts, there has been growing interest in finding quantization schemes that perform well in such cases. This post explores weight-activation quantization, which quantizes both weights and activations, enabling the use of low-precision arithmetic units for faster computation. Weight-activation quantization not only reduces the amount of memory transfer needed during inference but also improves performance in compute-bound scenarios. In this post, we will examine how weight-activation quantization performs in vLLM and TensorRT-LLM serving environments.
How to Quantize Both Weights and Activations?
Quantization Scheme
Quantization maps the values of a tensor to discrete buckets, approximating the original continuous distribution. When Quantization-Aware Training (QAT) is feasible, the bucket size can be learned through back-propagation. However, in Post-Training Quantization (PTQ), which we focus on throughout this article, finding the optimal bucket size is the most important task of quantization. When quantizing both weights and activations, one of the simplest methods is min-max quantization. This method determines the size of the buckets by identifying the minimum and maximum values of weights and activations on a calibration dataset. While easy to implement, this approach is prone to relatively large quantization errors, which can result in significant degradation of model output quality.
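As a concrete illustration, the following minimal sketch shows how a symmetric per-tensor INT8 scale could be derived with min-max calibration and then applied; it is a toy example for intuition, not the kernel either framework actually uses.

```python
import torch

def minmax_int8_scale(calib_tensors):
    """Derive a symmetric per-tensor INT8 scale from calibration data (minimal sketch)."""
    # Track the largest magnitude observed across the calibration set.
    amax = max(t.abs().max().item() for t in calib_tensors)
    # Map [-amax, amax] onto the signed 8-bit range [-127, 127].
    return amax / 127.0

def quantize_int8(x, scale):
    # Round to the nearest bucket and clamp to the representable range.
    return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)

def dequantize(q, scale):
    return q.to(torch.float16) * scale

# Usage: calibrate on a few activation samples, then quantize new activations.
calib = [torch.randn(16, 4096) for _ in range(8)]
scale = minmax_int8_scale(calib)
x = torch.randn(16, 4096)
x_q = quantize_int8(x, scale)
x_hat = dequantize(x_q, scale)  # approximation error grows when outliers stretch the range
```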
Such degradation highlights the challenge of preserving model accuracy during quantization, especially for LLMs. LLMs often exhibit outliers in their activations, which complicate the quantization process and require more sophisticated strategies. To address the limitations of min-max quantization, researchers have developed several advanced quantization schemes.
Notable examples include:
- LLM.int8(): An early approach that uses mixed-precision decomposition to preserve model output quality by excluding outliers from the quantization process. This scheme is supported in vLLM through the bitsandbytes integration but is unavailable in TensorRT-LLM.
- SmoothQuant: A prominent weight-activation quantization method that exploits the fact that LLM weights are easier to quantize than activations. It smooths activation outliers by migrating part of the quantization difficulty into the weights, which remain manageable to quantize. This scheme is supported through llm-compressor in vLLM and ModelOpt in TensorRT-LLM.
In this post, we will use min-max and SmoothQuant approaches for weight-activation quantization, as both are supported by both vLLM and TensorRT-LLM.
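To make the smoothing step concrete, here is a minimal sketch of the SmoothQuant-style per-channel scaling, where the migration strength alpha (typically around 0.5-0.8) controls how much quantization difficulty is shifted from activations to weights; the actual llm-compressor and ModelOpt implementations fold the activation scaling into the preceding layers and handle many additional details.

```python
import torch

def smooth_linear(act_amax, weight, alpha=0.5):
    """
    SmoothQuant-style smoothing for one linear layer (minimal sketch).
    act_amax: per-input-channel max |activation| from calibration data, shape [in_features]
    weight:   linear weight, shape [out_features, in_features]
    Returns the per-channel smoothing scales and the smoothed weight.
    """
    w_amax = weight.abs().amax(dim=0)                        # per-input-channel weight range
    scales = act_amax.pow(alpha) / w_amax.pow(1.0 - alpha)   # s_j = max|X_j|^a / max|W_j|^(1-a)
    scales = scales.clamp(min=1e-5)
    # Activations are divided by s (folded into the preceding LayerNorm/linear in practice)
    # and weights are multiplied by s, so the product X @ W.T is mathematically unchanged,
    # but the activation outliers shrink and become easier to quantize.
    smoothed_weight = weight * scales.unsqueeze(0)
    return scales, smoothed_weight
```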
Where to Quantize
A transformer block includes multiple weights and activations, and selecting which ones to quantize can greatly influence both efficiency and model output quality. In the experiments conducted for this post, we chose to quantize only the linear layers, while keeping the attention and normalization layers in FP16, as shown in Figure 1. This is the default behavior of both vLLM and TensorRT-LLM and is based on findings from numerous studies indicating that quantizing attention layers can significantly degrade model output quality. vLLM does not currently (v0.6.2) support quantized attention, and the use_fp8_context_fmha option in TensorRT-LLM is disabled by default.
Even with weight-activation quantization, not all computations are performed in a quantized format, which means format conversions (quantization/dequantization) must occur continuously throughout inference. Additionally, since the attention layer is still executed in 16-bit operations, weight-activation quantization may be less effective in serving scenarios where the attention layer becomes the bottleneck. To better understand the effect of weight-activation quantization in various scenarios, we evaluate the approach through a series of experiments.
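As a rough illustration of this layer selection (not the actual mechanism in either framework), the sketch below walks a Hugging Face-style model and flags only the linear projections for W8A8, leaving every other leaf module, including the attention math, normalization, and embeddings, in FP16; excluding lm_head is a common default and an assumption here.

```python
import torch.nn as nn

def select_w8a8_targets(model: nn.Module):
    """List the linear layers that would be W8A8-quantized; all other leaf modules stay FP16 (sketch)."""
    targets, kept_fp16 = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip container modules, inspect only leaf modules
        if isinstance(module, nn.Linear) and "lm_head" not in name:
            targets.append(name)      # QKV/output projections and MLP projections
        else:
            kept_fp16.append(name)    # norms, embeddings, attention internals, etc.
    return targets, kept_fp16
```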
Experiment Setup
Benchmark Dataset
For all experiments, we used two standardized datasets, each with predetermined input and output token lengths to ensure consistent token processing across different frameworks.
The two datasets are:
- Prefill-heavy: Input length of 2,048 tokens and output length of 128 tokens
- Decode-heavy: Input length of 128 tokens and output length of 2,048 tokens
Most experiments were conducted with 256 requests, with the exception of scenarios using the max batch size of 256, where the sample size was increased to 1,024 to ensure more reliable measurements.
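For reference, a minimal sketch of how such fixed-length requests can be constructed offline with vLLM is shown below; the dummy-prompt construction, model path, and parameter values here are illustrative assumptions, not the exact benchmark client used for these experiments.

```python
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

# Prefill-heavy setting: 2,048 input tokens and 128 output tokens per request.
INPUT_LEN, OUTPUT_LEN, NUM_REQUESTS = 2048, 128, 256

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_num_seqs=256)
tokenizer = llm.get_tokenizer()

# Dummy prompts with a fixed token count so every request performs identical work.
prompts = [
    TokensPrompt(prompt_token_ids=[tokenizer.bos_token_id] * INPUT_LEN)
    for _ in range(NUM_REQUESTS)
]

# ignore_eos forces exactly OUTPUT_LEN decode steps regardless of what the model emits.
params = SamplingParams(max_tokens=OUTPUT_LEN, ignore_eos=True, temperature=0.0)
outputs = llm.generate(prompts, params)
```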
Framework Version
We selected the most recent versions of both frameworks that successfully completed the benchmarking process.
- vLLM: v0.6.2 (commit 7193774)
- TensorRT-LLM: 0.13.0 with C++ API.
Model and Hardware
- Model: LLaMA-3.1-8B-Instruct with various quantization settings
- Quantization variants:
  - FP16 (baseline)
  - INT8: SmoothQuant, per-channel weight, per-token dynamic activation
  - FP8: Min-max, per-channel weight, per-token dynamic activation
- Hardware: NVIDIA H100-PCIe GPU, AMD EPYC 9554 64-core CPU
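For reference, the sketch below shows one plausible way to produce the INT8 SmoothQuant variant with llm-compressor, following the library's published W8A8 examples; the dataset alias, sample counts, and other arguments are assumptions rather than the exact configuration used in these experiments.

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

recipe = [
    # Migrate activation outliers into the weights before quantization.
    SmoothQuantModifier(smoothing_strength=0.8),
    # INT8 weights (per-channel) and activations (per-token, dynamic) for all
    # linear layers; lm_head stays in FP16.
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

oneshot(
    model=MODEL_ID,
    # Calibration data; "ultrachat_200k" is assumed to be a registered dataset alias.
    # The library's examples alternatively load and preprocess
    # HuggingFaceH4/ultrachat_200k manually.
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    output_dir="Llama-3.1-8B-Instruct-W8A8-INT8",
)
```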
Results
Throughput with Infinite Request Rate
First, we set the request rate to infinity to measure the maximum achievable throughput for each quantization method.
Prefill-Heavy
Figure 2 shows the throughput of the models at various precisions on prefill-heavy workloads in both frameworks. For prefill-heavy workloads, 8-bit weight-activation quantization (W8A8) achieved approximately a 40% throughput improvement across all max batch sizes when using the better of INT8 and FP8. This is lower than the 80-100% improvement reported with 4-bit weight-only quantization (W4A16) at a max batch size of 1 in the previous post. The primary reason is that W8A8 requires more memory access than W4A16.
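A back-of-the-envelope comparison of weight traffic makes this concrete for the memory-bound, small-batch case; the parameter count is approximate, and activations and the KV cache are ignored.

```python
# Approximate weight bytes read per forward pass for an ~8B-parameter model (sketch).
PARAMS = 8.0e9
for fmt, bytes_per_weight in [("FP16", 2.0), ("W8A8", 1.0), ("W4A16", 0.5)]:
    print(f"{fmt:>6}: ~{PARAMS * bytes_per_weight / 1e9:.0f} GB of weights per pass")
# FP16 ~16 GB, W8A8 ~8 GB, W4A16 ~4 GB: at batch size 1, W4A16 moves half the weight
# bytes of W8A8, which is why its speedup over FP16 is larger in the memory-bound regime.
```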
However, at max batch size of 256, while W4A16 showed only about 10% throughput improvement, W8A8 achieved a 40% improvement. This is because workloads with larger batch sizes become more compute-bound, and W8A8 can leverage lower precision processing units.
We expected both INT8 and FP8 quantization to produce similar throughput, since they share the same granularity and the same peak arithmetic throughput. However, in TensorRT-LLM, INT8 performed better at smaller batch sizes and FP8 excelled at larger batch sizes. We speculate this difference is due to variations in kernel optimization.
Decode-Heavy
For decode-heavy workloads, the throughput improvement at max batch size of 1 was comparable to the 40% observed in the prefill-heavy case (see Figure 3). However, at the max batch size of 256, the improvement dropped to 15-20%. This may seem counter-intuitive since larger batch sizes are typically compute-bound, where W8A8 would be expected to boost throughput. To understand this phenomenon further, we conducted a detailed profiling of the inference latency.
There are two main reasons why the impact of W8A8 quantization is limited for decode-heavy workloads at large max batch sizes. Figure 4 illustrates the first, showing how the proportion of time spent in the attention layer changes as batch size increases. As seen in Figure 4, the larger the batch size, the greater the share of time required by the attention layer. This is because the linear layer (the dense layer in Figure 4) can reuse weights across the batch as it grows, whereas the attention layer cannot. Therefore, even if the latency of the linear layer is improved through W8A8 quantization, the effect is limited in scenarios where attention-layer latency is the bottleneck. In contrast, for prefill-heavy workloads, W8A8 quantization still achieved around a 40% improvement even at larger batch sizes, primarily because the prefill batch size is limited by the maximum number of batched tokens rather than the max batch size. Additionally, since the decode phase is much shorter in that case, the impact of max batch size is less significant, allowing improvements in linear-layer latency to translate into meaningful end-to-end performance gains.
The other reason the impact of W8A8 quantization is limited for decode-heavy workloads at large max batch sizes relates to sequence length. Figure 5 shows the latency breakdown of a single decode step at various precisions in vLLM at batch size 64. Here, context length refers to the number of context tokens the decode step must attend to. Note that the latencies of the QKV projection layers are included in the attention block latency in this figure. While W8A8 quantization improved the latency of the MLP blocks by 1.6x across all context lengths, the latency of the attention blocks showed only marginal improvement. As the context length increases, the attention block accounts for a growing share of the overall latency, limiting the throughput improvement achievable with W8A8 quantization.
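A rough model of the bytes moved per decode step illustrates both effects; the GQA configuration and parameter count below are assumptions for a LLaMA-3.1-8B-like model, and the KV cache is assumed to stay in FP16, as in these experiments.

```python
# Rough per-decode-step memory traffic for a LLaMA-3.1-8B-like model (illustrative sketch).
LAYERS, KV_HEADS, HEAD_DIM = 32, 8, 128        # assumed GQA configuration
PARAMS = 8.0e9                                  # approximate parameter count
KV_BYTES_PER_TOKEN = 2 * KV_HEADS * HEAD_DIM * 2 * LAYERS   # K and V, FP16, all layers

def traffic_gb(batch, context, weight_bytes):
    weights = PARAMS * weight_bytes                  # read once, reused across the whole batch
    kv_cache = batch * context * KV_BYTES_PER_TOKEN  # read per sequence, stays in FP16
    return weights / 1e9, kv_cache / 1e9

for batch, context in [(1, 2048), (64, 2048), (256, 2048)]:
    w, kv = traffic_gb(batch, context, weight_bytes=1.0)   # W8A8 weights
    print(f"batch={batch:3d}: linear weights ~{w:.0f} GB, KV cache ~{kv:.1f} GB")
# As batch size (or context length) grows, the FP16 KV-cache reads of the attention
# layer catch up with and eventually dominate the quantized linear-layer traffic.
```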
FP8: Dynamic vs. Static
Both vLLM and TensorRT-LLM offer various quantization granularities. Here, we evaluate the difference between FP8-static and FP8-dynamic quantization in more detail. In the FP8 quantization tested earlier, activation quantization scales are determined dynamically at runtime based on the token values. In contrast, the FP8-static scheme uses predetermined per-tensor quantization scales for both weights and activations. In general, a static quantization scheme can be faster than a dynamic one since it avoids the overhead of computing quantization scales at runtime, but it is known to result in poorer model output quality.
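The difference between the two schemes boils down to where the activation scale comes from, as in the minimal sketch below (using the FP8 E4M3 maximum of 448 as the quantization bound); this illustrates the scale computation only, not either framework's fused kernels.

```python
import torch

FP8_E4M3_MAX = 448.0   # largest representable magnitude in FP8 E4M3

def fp8_dynamic_per_token(x):
    """Per-token dynamic quantization: one scale per row, computed at runtime."""
    scale = (x.abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX).clamp(min=1e-12)
    return (x / scale).to(torch.float8_e4m3fn), scale

def fp8_static_per_tensor(x, scale):
    """Static quantization: a single precomputed scale from calibration, no runtime reduction."""
    return (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)

x = torch.randn(16, 4096, dtype=torch.float16)
x_dyn, scales = fp8_dynamic_per_token(x)                      # extra max-reduction every pass
x_stat = fp8_static_per_tensor(x, scale=torch.tensor(0.05))   # scale fixed ahead of time
```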
Figure 6 shows the throughput comparison between dynamic and static FP8 quantization. For vLLM, FP8-static quantization showed slightly better throughput than the dynamic scheme at all max batch sizes. However, TensorRT-LLM exhibited different behavior: at smaller batch sizes the FP8-dynamic scheme achieved higher throughput, whereas at larger batch sizes the FP8-static scheme did. Thus, careful consideration of the granularity of FP8 quantization is required.
Time-to-First-Token with Low Request Rate
With an infinite request rate, all requests are queued from the start and wait to be scheduled, so the measured Time-to-First-Token (TTFT) diverges. To address this, we also conducted experiments with a fixed request rate of 4.
As shown in Figure 7, W8A8 quantization improved TTFT on prefill-heavy workloads regardless of max batch size and framework (a similar trend was observed for decode-heavy workloads). In vLLM, we observed an improvement of at least 2x, while in TensorRT-LLM the improvement was around 25%. Note that even with W8A8 quantization, TTFT in vLLM was still significantly higher than with FP16 in TensorRT-LLM. As the max batch size decreases, throughput also drops, eventually causing the processing speed to fall behind the rate of incoming requests, leading to a backlog of requests and very high measured TTFT. With quantization, however, we found that a reasonable TTFT can be maintained at lower max batch sizes.
Impact on Output Quality
Previous experiments have confirmed that weight-activation quantization can improve the efficiency of LLM serving. However, quantization can significantly impact model output quality, which may limit its adoption in real-world applications.
To evaluate the model output quality degradation due to quantization, we measured 5-shot MMLU (Massive Multitask Language Understanding) scores with the instruct-eval tool. We used the "HuggingFaceH4/ultrachat_200k" dataset as the calibration dataset for quantization.
Table 1 shows the 5-shot MMLU scores of the models at various precisions. Overall, the FP8 versions scored better than INT8, with FP8-dynamic performing slightly better than FP8-static. These results may vary depending on the benchmark dataset, so it is important to assess quality according to the specific service scenario. Additionally, as discussed in the previous sections, it is essential to carefully compare INT8, FP8-dynamic, and FP8-static in terms of throughput, TTFT, and TPOT to make the best choice.
Final Thoughts
Weight-activation quantization offers promising performance improvements for LLM serving, particularly in compute-bound scenarios, by enabling low-precision arithmetic. Our experiments demonstrate that adopting INT8 or FP8 quantization schemes can yield significant throughput gains across various service scenarios. However, the impact is less pronounced in decode-heavy scenarios with large batch sizes due to the growing share of attention-layer overhead, since the attention layer remains in a higher-precision format. The next post will introduce KV-cache quantization, which can effectively reduce attention overhead and can be combined with weight-activation quantization for even more efficient serving.
While weight-activation quantization improves efficiency, it comes with a trade-off in model output quality. FP8 generally preserves quality better than INT8, with FP8-dynamic achieving the best balance between performance and accuracy. These results underscore the importance of selecting the appropriate quantization strategy based on the specific workload, batch size, and quality requirements of the application.
Ultimately, quantization strategies like weight-activation quantization present an effective approach for maximizing inference efficiency in large-scale LLM deployment. With ongoing research and advancements in low-precision hardware support, weight-activation quantization holds the potential to further elevate the performance of LLM serving frameworks like vLLM and TensorRT-LLM in diverse real-world scenarios.
Stay tuned for more insights in the vLLM vs TensorRT-LLM series!