The Missing Piece of TensorRT-LLM
This article is about an open-source library for direct conversion of PyTorch models to TensorRT-LLM.
Feb 10, 2025
Contents
- Introduction
- The TensorRT-LLM Conversion Procedure
- Challenge 1: Selecting the Proper Conversion Script
- Challenge 2: Inconsistent Command-Line Interfaces
- Challenge 3: Exclusive Requirements for Some Models
- Fundamental Limitations of the Conversion Procedure
- Torch-TensorRT to the Rescue
- Ditto: Direct Torch to TensorRT-LLM Optimizer
- How Does Ditto Work?
- Benchmarks
- Conclusion

Introduction
PyTorch has made significant strides in its compilation infrastructure with the release of PyTorch 2.0, introducing a comprehensive compilation stack with features such as `torch.compile` and `torch.export`. As observed in our previous blog post — The Rise and Fall of ONNX (feat. PyTorch 2.0) — these developments have opened new possibilities for deploying deep learning models through various backends. Torch-TensorRT is an exemplary project that leverages these advancements, providing a direct path to deploy PyTorch models on TensorRT.

Meanwhile, as the era of Large Language Models (LLMs) has dawned, TensorRT-LLM has emerged as a powerful toolkit built upon TensorRT for optimizing LLM inference on NVIDIA GPUs, offering impressive performance gains through state-of-the-art optimizations such as custom attention kernels, inflight batching, and paged KV caching. These advantages have been demonstrated throughout our series of blog posts — vLLM vs TensorRT-LLM. However, the path to leveraging these benefits isn't always smooth. The model conversion procedure remains a significant challenge, often requiring hand-picking model-specific checkpoint conversion scripts, dealing with inconsistent command-line interfaces, and being limited to pre-defined model architectures. This creates a bottleneck for teams wanting to quickly deploy new or custom models using TensorRT-LLM.
In this post, we share our approach — Direct Torch to TensorRT-LLM Optimizer (Ditto) — to streamlining TensorRT-LLM's model conversion by leveraging PyTorch's modern compilation stack, specifically via `torch.export` and Torch-TensorRT. We aim to eliminate the manual overhead and inconsistencies in the current workflow, making TensorRT-LLM's powerful optimizations more accessible to the broader AI community.

The TensorRT-LLM Conversion Procedure
Note: TensorRT-LLM version 0.16.0 is used throughout this blog post.
Unlike vLLM, TensorRT-LLM cannot directly deploy a model implemented in PyTorch. Instead, the model needs to be baked into the TensorRT engine format to leverage the powerful performance optimizations provided by TensorRT — see TensorRT's programming model for more details. The current TensorRT-LLM conversion procedure mainly consists of two steps:
- Checkpoint Conversion: converting a PyTorch checkpoint using one of the convert_checkpoint.py scripts provided in the TensorRT-LLM examples.
  - This step essentially remaps the PyTorch model's `state_dict` to make it compatible with one of the pre-defined TensorRT network definitions provided by TensorRT-LLM (a minimal sketch of this remapping follows after this list).
  - It also generates a build configuration file, which is required for the next step.
- TensorRT Engine Build: building the TensorRT engine out of the pair of the converted checkpoint and configuration using `trtllm-build`.
  - This step loads the converted checkpoint into the pre-defined TensorRT network definition specified in the build configuration file to build the TensorRT engine using the TensorRT APIs.
  - It also generates a configuration file for TensorRT-LLM's runtime based on the pre-defined architecture and some command-line arguments the user provides.
  - Depending on the optional features enabled by the user, it might generate a few more files, such as LoRA checkpoints.
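To make the first step concrete, here is a minimal sketch of what the `state_dict` remapping amounts to. The key patterns below are purely illustrative (they are not the exact keys used by TensorRT-LLM), and real conversion scripts additionally fuse weights, handle quantization, and split tensors for parallelism.

```python
import re

import torch


def remap_state_dict(hf_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename Hugging Face weight keys into the layout a pre-defined
    TensorRT-LLM network expects. The key patterns here are hypothetical."""
    rules = [
        (r"^model\.embed_tokens\.weight$", "transformer.vocab_embedding.weight"),
        (r"^model\.layers\.(\d+)\.mlp\.down_proj\.", r"transformer.layers.\1.mlp.proj."),
        (r"^model\.norm\.weight$", "transformer.ln_f.weight"),
    ]
    remapped = {}
    for name, tensor in hf_state_dict.items():
        for pattern, replacement in rules:
            name = re.sub(pattern, replacement, name)
        remapped[name] = tensor
    return remapped
```

The remapped weights are then written out together with a build configuration file describing the target architecture, which is what `trtllm-build` consumes in the second step.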
Though it might sound simple and elegant on the surface, each step heavily relies on a large amount of manually written, model-specific code.
The checkpoint conversion is built upon around 30 independent, specialized checkpoint conversion scripts for different models — to name a few:
- llama/convert_checkpoint.py
- gemma/convert_checkpoint.py
- commandr/convert_checkpoint.py
- and the list goes on …
Similarly, `trtllm-build` is built on a large number of pre-defined model architectures implemented with TensorRT-LLM's own "PyTorch-like Python API" — for example, with its own `Module` class mimicking `torch.nn.Module` and `PretrainedModel` class imitating `transformers.PreTrainedModel`.

You might already be concerned that these hand-written conversion infrastructures can cause practical issues. Indeed, as we have worked extensively with TensorRT-LLM, we've encountered several challenges that highlight the limitations of this approach.
Challenge 1: Selecting the Proper Conversion Script
First, selecting the proper conversion script can be tricky. If your model's name is on the list, it's straightforward. But if not, you need to hand-pick one by examining architectural similarity. For instance, upstage/SOLAR-10.7B-Instruct-v1.0 needs llama/convert_checkpoint.py, whereas CohereForAI/aya-expanse-8b needs commandr/convert_checkpoint.py. This might seem intuitive to people familiar with LLM architectures, but it can be a barrier for those who aren't. Furthermore, if none of the existing conversion scripts work, you need to write a custom conversion script — and, in the worst case, even the pre-defined network definition — by yourself.
Challenge 2: Inconsistent Command-Line Interfaces
Moreover, you may occasionally have to deal with inconsistent command-line interfaces. For example, there are subtle differences between the commands for converting llama and gemma checkpoints.
- Llama2-7B checkpoint conversion with 2-way tensor parallelism (other scripts have a CLI similar to this one):

```
python examples/llama/convert_checkpoint.py \
    --model_dir /path/to/meta-llama/Llama-2-7b-chat-hf \
    --output_dir /path/to/meta-llama/Llama-2-7b-chat-hf/trtllm-ckpt \
    --tp_size 2
```
- Gemma2-9B checkpoint conversion with 2-way tensor parallelism (this one stands out from the others):

```
python examples/gemma/convert_checkpoint.py \
    --model-dir /path/to/google/gemma-2-9b-it \
    --output-model-dir /path/to/google/gemma-2-9b-it/trtllm-ckpt \
    --world-size 2 \
    --ckpt-type hf
```
Converting Gemma models requires careful attention to different flag formats (note the difference between `--model_dir` and `--model-dir`) and model-specific options. It might seem like a minor issue, but it sometimes becomes a real headache. For example, it was a real pain in the neck when we tried to automate this conversion procedure just from a given Hugging Face model ID. Wondering why we wanted to automate the conversion procedure? Subscribe to the SqueezeBits newsletter for the next blog post 👉
Challenge 3: Exclusive Requirements for Some Models
Some models, such as EXAONE, have exclusive requirements. The following command fails to convert the checkpoint for EXAONE.
```
python examples/llama/convert_checkpoint.py \
    --model_dir /path/to/LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct \
    --output_dir /path/to/LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct/trtllm-ckpt \
    --dtype float16
```
Let's look at the output. Can you guess why it fails?
```
[TensorRT-LLM] TensorRT-LLM version: 0.16.0
[02/05/2025-14:31:55] [TRT-LLM] [W] Implicitly setting LLaMAConfig.tie_word_embeddings = False
5it [00:11,  2.33s/it]
Traceback (most recent call last):
  ... (omitted) ...
  File "/path/to/examples/llama/convert_checkpoint.py", line 502, in convert_and_save_rank
    llama = LLaMAForCausalLM.from_hugging_face(
  File "/another/path/to/lib/python3.10/site-packages/tensorrt_llm/models/llama/model.py", line 434, in from_hugging_face
    loader.generate_tllm_weights(model, arg_dict)
  ... (omitted) ...
  File "/another/path/to/lib/python3.10/site-packages/tensorrt_llm/layers/linear.py", line 417, in postprocess
    weights = torch.cat(weights)
TypeError: expected Tensor as element 0 in argument 0, but got NoneType
```
It turns out that TensorRT-LLM expects the model directory name to contain the lowercase string "exaone" — see line 410 in tensorrt_llm/models/llama/model.py — which was missing in our case, causing a problem in the subsequent flow as the error message indicates. A workaround that we found is to simply rename the model directory to `/path/to/LGAI-EXAONE/exaone` (note the lowercase "exaone"), manually resolving the inconsistency with the original uppercase model name LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct. While the workaround is simple, this is more than just a typo or a minor issue — the checkpoint conversion logic even depends on the names of the models! This implicit dependency on model naming might result in unpredictable behavior as more and more models emerge.

Fundamental Limitations of the Conversion Procedure
Besides the practical issues described above, TensorRT-LLM's approach to model conversion inherently restricts its flexibility. If you look closely at the implementation details of the checkpoint conversion, you will find that it heavily relies on hard-coded keyword matching internally — for example, see tensorrt_llm/models/gemma/convert.py and tensorrt_llm/models/phi3/convert.py. Considering that the keys of a PyTorch model's `state_dict` originate from the variable names used in the model implementation, the checkpoint conversion is potentially fragile.

More fundamentally, `trtllm-build` relies on prebuilt network definitions of some popular models, as stated in the TensorRT-LLM Overview:

> TensorRT-LLM comes with several popular models pre-defined. They can easily be modified and extended to fit custom needs via a PyTorch-like Python API.
Although the user can manually modify these pre-defined models via the PyTorch-like Python API, new models that don't match one of the prebuilt architectures aren't supported out of the box. Users must either wait for official support from the TensorRT-LLM team or manually implement the model with the PyTorch-like Python API — neither of which is ideal for the rapid deployment of new or custom models.
Torch-TensorRT to the Rescue
With that said, how can we streamline the TensorRT-LLM conversion procedure? Let's first look at how Torch-TensorRT works to understand the direct conversion from PyTorch to TensorRT — more details are available in Torch-TensorRT Explained.
- The Dynamo frontend captures the parts of the user's Python program that can be represented as "PyTorch-native static computation graphs", namely FX graphs consisting only of ATen operator nodes. It leaves dynamic or non-native parts — such as dynamic control flow, loops, or third-party library code — as they are. `torch.compile` then delegates these FX graphs to the Torch-TensorRT backend.
- Torch-TensorRT partitions each FX graph into subgraphs with supported and unsupported operators.
- For each supported subgraph, Torch-TensorRT lowers each ATen operator node to a TensorRT layer, accumulating the resulting layers into a TensorRT network behind the scenes to build a TensorRT engine out of the subgraph. The unsupported subgraphs, on the other hand, are left as they are.
- As a result, the model compiled by Torch-TensorRT is a hybrid program that consists of TensorRT engines (from the supported subgraphs), PyTorch ATen operators (from the unsupported subgraphs), and possibly some parts of the original Python program (from the dynamic or non-native parts).
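For reference, using Torch-TensorRT on its own looks roughly like the following. The toy model and input shapes are placeholders, and how much of the graph ends up inside TensorRT engines depends on operator support.

```python
import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyModel().eval().cuda()
example_input = torch.randn(8, 64, device="cuda")

# Route 1: torch.compile with the Torch-TensorRT backend (partial graph capture).
compiled = torch.compile(model, backend="torch_tensorrt")
compiled(example_input)  # compilation is triggered on the first call

# Route 2: the Dynamo path, which relies on torch.export under the hood.
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=[example_input])
print(trt_module(example_input).shape)
```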
As you can see, Torch-TensorRT provides a powerful compilation stack that can boost PyTorch models with TensorRT. Moreover, it supports a wide range of ATen operators, including those commonly used in modern LLM architectures. However, Torch-TensorRT alone cannot fully leverage the benefits of TensorRT-LLM's optimizations. For instance, naive conversion of ATen operators into native TensorRT layers cannot make use of TensorRT-LLM's custom kernels provided in the form of TensorRT plugins — e.g. the GPT attention plugin or the LoRA plugin. In addition, the existing auto-regressive generation pipeline must be delegated to TensorRT-LLM's runtime workflow in order to fully leverage TensorRT-LLM's optimized scheduling capabilities such as inflight batching and paged attention.
Ditto: Direct Torch to TensorRT-LLM Optimizer
To extend the capabilities of Torch-TensorRT and enable the direct conversion of PyTorch models into TensorRT-LLM engines, we have developed Direct Torch to TensorRT-LLM Optimizer (Ditto).
With Ditto, you can convert any language model from HuggingFace Hub into a TensorRT-LLM engine with a single command. For example, the following command converts meta-llama/Llama-3.1-8B-Instruct into a TensorRT-LLM engine.
```
ditto build meta-llama/Llama-3.1-8B-Instruct
```
The following is what you will see when you actually run the command.
```
(ditto) jiwoongchoi@kiwi:~/works$ ditto build meta-llama/Llama-3.1-8B-Instruct
ditto:0:00:00.740580 [INFO] Applied custom patch for transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask. To disable this patch, set the environment variable DISABLE_TRANSFORMERS_ATTENTION_MASK_CONVERTER_PATCH=1
[TensorRT-LLM] TensorRT-LLM version: 0.16.0
ditto:0:00:02.421850 [INFO] Using default output directory: ./engines/meta-llama/Llama-3.1-8B-Instruct
ditto:0:00:02.422080 [INFO] Loading model meta-llama/Llama-3.1-8B-Instruct
Loading checkpoint shards: 100%|███████████████████| 4/4 [00:02<00:00, 1.94it/s]
ditto:0:00:05.268058 [INFO] device: cuda:0 | dtype: torch.bfloat16
ditto:0:00:05.268713 [INFO] Running torch.export
ditto:0:00:22.018891 [INFO] Inlining the exported program into a graph module
ditto:0:00:38.136494 [INFO] Optimizing the graph module
ditto:0:01:08.742303 [INFO] Writing engine config at ./engines/meta-llama/Llama-3.1-8B-Instruct/config.json
ditto:0:01:08.742763 [INFO] Building TensorRT engine
Unused Input: position_ids
[RemoveDeadLayers] Input Tensor position_ids is unused or used only at compile-time, but is not being removed.
ditto:0:01:32.529489 [INFO] Writing serialized engine at ./engines/meta-llama/Llama-3.1-8B-Instruct/rank0.engine
ditto:0:01:39.517111 [INFO] Build completed in 01:34
```
That's it! Now, your model is transformed into a TensorRT-LLM engine with that single command. What's more, you can seamlessly integrate the engine into your existing TensorRT-LLM workflow since what Ditto gives you is just a standard TensorRT-LLM engine and its configuration file.
```
(ditto) jiwoongchoi@kiwi:~/works$ tree -h engines
[4.0K]  engines
└── [4.0K]  meta-llama
    └── [4.0K]  Llama-3.1-8B-Instruct
        ├── [1.4K]  config.json
        └── [ 15G]  rank0.engine
```
For example, the following is the output from the example script provided by TensorRT-LLM executed with the engine built by Ditto.
```
(ditto) jiwoongchoi@kiwi:~/works$ python tensorrt-llm/examples/run.py --engine_dir engines/meta-llama/Llama-3.1-8B-Instruct --tokenizer_dir meta-llama/Llama-3.1-8B-Instruct --input_text "What is heavier, a ton of bricks or a ton of feathers?" --max_output_len 100
[TensorRT-LLM] TensorRT-LLM version: 0.16.0
[TensorRT-LLM][INFO] Engine version 0.16.0 found in the config file, assuming engine(s) built by new builder API.
[02/10/2025-21:10:44] [TRT-LLM] [I] Using C++ session
... (omitted) ...
[02/10/2025-21:10:51] [TRT-LLM] [I] Load engine takes: 6.780348539352417 sec
Input [Text 0]: "<|begin_of_text|>What is heavier, a ton of bricks or a ton of feathers?"
Output [Text 0 Beam 0]: " The answer is neither. Both are the same weight. The difference is in their density. A ton of bricks is much more dense than a ton of feathers. Density is mass per unit volume. The more dense an object is, the heavier it will feel for a given volume. The density of an object is calculated by dividing its mass by its volume. The formula for density is: Density = Mass / Volume The units for density are typically grams per cubic centimeter (g/cm3) or"
[TensorRT-LLM][INFO] Refreshed the MPI local session
```
How Does Ditto Work?
Ditto applies a series of graph-level optimizations, integrating them with the Dynamo frontend and the Torch-TensorRT backend to provide end-to-end conversion from PyTorch to TensorRT-LLM without any model-specific logic. The following is the detailed workflow of Ditto.
1. torch.export
Since TensorRT-LLM requires a single TensorRT engine encapsulating the entire model, Ditto utilizes `torch.export` to capture the full graph representation of the model. (See Partial vs. Full Graph Capture for more details.) After this step, the model, namely a `transformers.PreTrainedModel`, is converted into an exported program, namely a `torch.export.ExportedProgram`.
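Outside of Ditto, this step might look like the sketch below. The model name and sequence-length bound are arbitrary choices for illustration, and exporting a real LLM end-to-end can require extra care (Ditto, for instance, patches transformers' attention-mask utilities, as visible in the build log above).

```python
import torch
from transformers import AutoModelForCausalLM

# Any causal LM works for illustration; "gpt2" is just a small, ungated example.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
seq_len = torch.export.Dim("seq_len", min=2, max=1024)

exported_program = torch.export.export(
    model,
    args=(input_ids,),
    dynamic_shapes={"input_ids": {1: seq_len}},  # mark the sequence dimension as dynamic
)
print(type(exported_program))  # torch.export.ExportedProgram
```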
2. Graph-level Optimization
The exported program mainly consists of the following components:
- The parameter-free graph of the model, represented as a `torch.fx.Graph`, all of whose parameters are lifted as inputs to the graph
- A `state_dict` containing the tensor values of all lifted parameters and buffers
- Various metadata, including dynamic range constraints
Ditto first inlines the exported program to obtain a self-contained graph module, namely a `torch.fx.GraphModule`, using the standard `torch.export.ExportedProgram.module()` API to unlift the parameters back into the graph. Ditto then optimizes this graph module by applying a series of graph-level optimizations.

Ditto's graph-level optimizations fall into roughly two categories:
- ATen-level optimizations: These optimizations tidy up the graph without changing the semantics of the model by
  - eliminating redundant operations with no effect;
  - fusing a group of operations into their equivalents to reduce the number of operations; and
  - folding constant operations into their consumer nodes.
- TensorRT-LLM-specific optimizations: These optimizations modify the graph semantically to make it compatible with TensorRT-LLM's runtime by
  - inserting new input tensors required by TensorRT-LLM plugins;
  - pattern-matching subgraphs that can be replaced by TensorRT-LLM plugins and replacing each of them with a single node representing the plugin call; and
  - modifying the graph for tensor parallelism when the user specifies the `--tp-size` flag with a value greater than 1.
There are many more optimizations that Ditto applies that we haven't described here - see the ditto.fx.optimize module for the implementation details. Through these optimizations, the graph module is transformed into a form that is compatible with TensorRT-LLM's runtime.
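To give a flavor of what an ATen-level pass looks like, here is an illustrative stand-in (not code from ditto.fx.optimize) that removes no-op aten.clone nodes from the inlined graph module obtained via `torch.export.ExportedProgram.module()`.

```python
import torch
from torch.fx import GraphModule


def eliminate_clone_nodes(gm: GraphModule) -> GraphModule:
    """Drop aten.clone.default nodes, rewiring their users to the cloned input."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.ops.aten.clone.default:
            node.replace_all_uses_with(node.args[0])  # users now consume the clone's input
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

A production-grade pass additionally has to preserve node metadata and handle corner cases (for example, clones that change the memory format), but the overall shape of the pass (walk the graph, match nodes, rewrite, recompile) stays the same.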
3. TensorRT Network Construction and Engine Build
This step is where Torch-TensorRT comes into play. Ditto internally utilizes the conversion rules provided by Torch-TensorRT for each ATen operator. In this procedure, each ATen operator node in the graph is converted into one or more native TensorRT layers, for example:

- `aten.reshape.default` into a `tensorrt.IShuffleLayer`;
- `aten.add.Tensor` into a `tensorrt.IElementWiseLayer`, possibly with an extra `tensorrt.ICastLayer` attached to either of the operands to replicate PyTorch's type promotion;
- and so on (a hand-written illustration of this lowering follows below).
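To illustrate what this lowering produces, the snippet below builds a tiny TensorRT network by hand using the same two layer types. It is an isolated example written against the public TensorRT Python API, not Ditto's converter code.

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 4, 8))
bias = network.add_constant((1, 32), np.ones((1, 32), dtype=np.float32)).get_output(0)

# aten.reshape.default -> tensorrt.IShuffleLayer
reshape = network.add_shuffle(x)
reshape.reshape_dims = (1, 32)

# aten.add.Tensor -> tensorrt.IElementWiseLayer
add = network.add_elementwise(reshape.get_output(0), bias, trt.ElementWiseOperation.SUM)
network.mark_output(add.get_output(0))

config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)
```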
On the other hand, the plugin nodes inserted by Ditto's graph-level optimizations are converted into `tensorrt.IPluginV2Layer`s based on custom conversion rules provided by Ditto, reproducing what happens under the hood in `trtllm-build`. Finally, all the nodes in the graph are converted into TensorRT layers, and the TensorRT network is constructed. The TensorRT engine is then built using the standard `tensorrt.Builder.build_serialized_network()` API.

4. Engine Configuration Generation via Graph-Level Analysis
Ditto also generates the engine configuration file by analyzing the graph. To list a few examples:
- The `hidden_size` and `vocab_size` are inferred by looking for the first `aten.embedding.default` node and reading the shape of its weight tensor.
- The `num_layers` is inferred by counting the number of GPT attention plugin nodes in the graph.
- Some configuration fields that cannot be inferred from the graph are acquired from the original Hugging Face `PretrainedConfig` object.
The generated configuration file is then saved along with the engine.
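As a simplified illustration of this kind of graph-level analysis (not Ditto's actual implementation), the vocabulary and hidden sizes could be read off the graph as follows, assuming the embedding weight appears as an ordinary get_attr parameter in the inlined graph module.

```python
import torch
from torch.fx import GraphModule


def infer_embedding_sizes(gm: GraphModule) -> tuple[int, int]:
    """Return (vocab_size, hidden_size) from the first aten.embedding.default node."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.embedding.default:
            weight_node = node.args[0]  # aten.embedding(weight, indices, ...)
            assert weight_node.op == "get_attr", "expected the weight to be a graph attribute"
            weight = gm.get_parameter(weight_node.target)
            vocab_size, hidden_size = weight.shape
            return int(vocab_size), int(hidden_size)
    raise ValueError("no aten.embedding.default node found in the graph")
```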
Benchmarks
As you have seen, Ditto builds a TensorRT-LLM engine in a way that is inherently different from the existing TensorRT-LLM conversion procedure. Accordingly, you might wonder whether the engines produced by Ditto are reliable. To address this concern, we measured several commonly used metrics to compare the quality and performance of Ditto's engines against those built by the existing approach.
We evaluated quality using tensorrt-llm/llmapi integrated with lm-evaluation-harness, and performance with gptManagerBenchmark from TensorRT-LLM. Both the GEMM plugin and the GPT attention plugin were enabled in all benchmarks.
The engines built by Ditto exhibit almost indistinguishable metrics from the ones built by the existing approach. In addition, thanks to Ditto's flexibility, we could convert the kyutai/helium-1-preview-2b model into a TensorRT-LLM engine without the official conversion scripts from TensorRT-LLM!
As of the publication date of this document (February 10, 2025), the Helium model is only available in the transformers nightly build, and TensorRT-LLM doesn't yet provide a checkpoint conversion script for it.
Conclusion
In this post, we introduced Ditto, a novel approach to facilitate the conversion of PyTorch models to TensorRT-LLM engines. By leveraging PyTorch's modern compilation stack and applying sophisticated graph-level optimizations, Ditto eliminates the need for model-specific conversion scripts and pre-defined network definitions. This makes TensorRT-LLM's powerful optimizations more accessible while maintaining comparable performance to the native conversion approach.
The benefits of Ditto are clear:
- Ease-of-use: Convert any PyTorch model to TensorRT-LLM with a single command
- Flexibility: Support for new architectures without waiting for official TensorRT-LLM support
- Reliability: Comparable quality and performance metrics to native TensorRT-LLM conversion
As the AI landscape continues to evolve, with new model architectures emerging regularly, tools like Ditto become increasingly valuable for teams looking to deploy and optimize their models quickly. We believe Ditto represents a significant step in making TensorRT-LLM’s high-performance inference more accessible to the broader AI community.
Try Ditto today at SqueezeBits/Torch-TRTLLM and let us know your experience!