TensorRT-LLM Goes Open Source!

With TensorRT-LLM now open source, we can finally take a deep dive into the secret sauce behind its impressive performance.
Huijong Jeong
Mar 25, 2025

Introduction

A couple of weeks ago, TensorRT-LLM finally open-sourced most of its code. The release includes:
  • code for managing and scheduling requests
  • code for managing KV caches
  • implementations of advanced features such as chunked prefill, speculative decoding, guided JSON, and more
However, the code for internal kernels, such as the core implementations of the attention plugin, is not included, as TensorRT itself remains largely closed-source.
When we were writing our previous blog posts, the code was not yet open, so we had to rely on external observations and speculation rather than a thorough explanation. For instance, in our blog post about request scheduling algorithms, we had to infer how requests are scheduled from token output timestamps. Now, with the relevant code fully open, we have a clear view of what is happening under the hood.
In this post, we revisit the batching and scheduling algorithms in TensorRT-LLM, but this time with the actual code. For those who haven’t read the previous posts yet, we recommend reading them first for more details.

How TensorRT-LLM Schedules Requests: Revisited

In each iteration, TensorRT-LLM executes the forwardAsync function. We won't include the full code here since the function is quite long, but it can be roughly divided into three parts:
  1. Schedule requests for the current iteration
  2. Execute the current iteration
  3. Post-process the results of the current iteration
Below is the code corresponding to the first part of the forwardAsync function body, where it schedules requests for the current iteration.
// https://github.com/NVIDIA/TensorRT-LLM/blob/da0b0e0ee307ae4f97accb75e3e5f6c31c2507c6/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp#L872-L904
...
// schedule requests according to scheduling policy
auto [fittingRequests, fittingDisaggGenInitRequests, requestsToPause]
    = (*mCapacityScheduler)(activeRequests, mKvCacheManager, mPeftCacheManager, mCrossKvCacheManager);
...
// finalize the schedule
std::tie(currRequests.contextRequests, currRequests.generationRequests)
    = (*mMicroBatchScheduler)(fittingRequests, mInflightReqIds, mMaxBatchSizeRuntime, mMaxNumTokensRuntime);
...
mCapacityScheduler first schedules active requests based on the scheduling policy, such as GUARANTEED_NO_EVICT or MAX_UTILIZATION, and then mMicroBatchScheduler finalizes the schedule.
When finalizing the schedule, mMicroBatchScheduler considers batching configurations such as max_num_tokens and max_batch_size to ensure that the current schedule does not exceed the configured limits.
// https://github.com/NVIDIA/TensorRT-LLM/blob/da0b0e0ee307ae4f97accb75e3e5f6c31c2507c6/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp#L189-L273
// Note that the code is slightly modified for ease of explanation
int batchNumTokens = 0;
int scheduledReqSize = 0;
for (auto& llmReq : fittingRequests)
{
    ...
    // check max_num_tokens
    if (maxNumTokensRuntime && batchNumTokens + reqNumTokens > maxNumTokensRuntime.value())
    {
        break;
    }
    batchNumTokens += reqNumTokens;

    // check max_batch_size
    if (++scheduledReqSize >= maxBatchSizeRuntime)
    {
        break;
    }
    ...
}
Due to this logic, the batch size increases gradually rather than all at once, as prefill requests include context tokens and are likely to be capped by max_num_tokens. Additionally, the maximum number of requests is constrained by max_batch_size. This behavior is clearly observed in Figure 2, taken from the previous post.
Figure 2. An illustration of how TensorRT-LLM with MAX_UTILIZATION policy schedules multiple requests. Each request is represented by a different color. Some requests are preempted and resumed.
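To build intuition for this ramp-up, below is a small, self-contained toy simulation of just these two limits. The Request struct and all numbers are hypothetical, and the loop ignores KV cache capacity and preemption entirely; it only models the max_num_tokens and max_batch_size checks shown above.

// A toy simulation of how max_num_tokens and max_batch_size shape the batch.
// Hypothetical numbers; KV cache limits and preemption are ignored.
#include <iostream>
#include <vector>

struct Request
{
    int promptLen;        // context tokens to be prefilled
    bool prefilled = false;
    int generated = 0;    // decode tokens produced so far
};

int main()
{
    int const maxNumTokens = 4096;
    int const maxBatchSize = 64;
    int const outputLen = 200;

    // 16 pending requests, each with a 2,048-token prompt.
    std::vector<Request> requests(16, Request{2048});

    for (int iter = 0; iter < 6; ++iter)
    {
        // Scheduling: add requests until one of the two limits is hit.
        int batchNumTokens = 0;
        int batchSize = 0;
        for (auto const& req : requests)
        {
            if (req.prefilled && req.generated >= outputLen)
            {
                continue; // finished
            }
            // A prefill request contributes its whole context, a decode request a single token.
            int const reqNumTokens = req.prefilled ? 1 : req.promptLen;
            if (batchNumTokens + reqNumTokens > maxNumTokens)
            {
                break; // max_num_tokens reached
            }
            batchNumTokens += reqNumTokens;
            if (++batchSize >= maxBatchSize)
            {
                break; // max_batch_size reached
            }
        }
        std::cout << "iteration " << iter << ": batch size " << batchSize
                  << ", tokens " << batchNumTokens << "\n";

        // Execution: the scheduled requests finish prefill or generate one token.
        int executed = 0;
        for (auto& req : requests)
        {
            if (req.prefilled && req.generated >= outputLen)
            {
                continue;
            }
            if (executed++ >= batchSize)
            {
                break;
            }
            if (!req.prefilled)
            {
                req.prefilled = true;
            }
            else
            {
                req.generated++;
            }
        }
    }
    return 0;
}

With these made-up numbers, the batch size grows by one request per iteration (2, 3, 4, ...), because once the running decode requests have claimed their share of the token budget, each iteration only has room for one more 2,048-token prefill.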
GUARANTEED_NO_EVICT and MAX_UTILIZATION are specific implementations of the CapacityScheduler. The implementation used depends on the chosen scheduling policy.
// https://github.com/NVIDIA/TensorRT-LLM/blob/da0b0e0ee307ae4f97accb75e3e5f6c31c2507c6/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp#L449-L474
CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
    executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager,
    std::optional<bool> manyMicroBatches, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
{
    ...
    else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kMAX_UTILIZATION)
    {
        mScheduler = MaxUtilizationScheduler{maxNumRequests, manyMicroBatches ? *manyMicroBatches : false,
            noScheduleUntilState, noScheduleAfterState};
    }
    else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
    {
        mScheduler = GuaranteedNoEvictScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
    }
    ...
The code below is the implementation of the GUARANTEED_NO_EVICT scheduling policy.
// https://github.com/NVIDIA/TensorRT-LLM/blob/da0b0e0ee307ae4f97accb75e3e5f6c31c2507c6/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp#L193-L325
// Note that the code is slightly modified for ease of explanation
template <bool StaticBatchScheduling>
std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const
{
    RequestVector scheduledRequests;
    auto const numFreeBlocks = kvCacheManager.getNumFreeBlocks();
    ...
    // If a request is already in progress, include it
    // If it's been allocated, it had resources to run to completion
    // Also keep track of blocks needed to drive all in-progress requests to completion
    SizeType32 reservedBlocks{0};
    RequestVector pendingRequests;
    pendingRequests.reserve(activeRequests.size());
    for (auto const& req : activeRequests)
    {
        if (scheduledRequests.size() >= static_cast<std::size_t>(mMaxNumRequests))
        {
            break;
        }
        else if (req->isGenerationInProgressState())
        {
            scheduledRequests.emplace_back(req);
            reservedBlocks += kvCacheManager.getRemainingBlocksToCompletion(*req);
            ...
        }
        else
        {
            pendingRequests.emplace_back(req);
        }
    }
    ...
    // Now check if we can add pending requests
    auto availableBlocks = numFreeBlocks - reservedBlocks;
    auto availableCrossBlocks = numFreeCrossBlocks - reservedCrossBlocks;
    auto availablePeftPages = maxPeftCachePages - claimedPeftPages;

    // Loop over pending requests and add them if they can be scheduled
    for (auto const& req : pendingRequests)
    {
        ...
        if (scheduledRequests.size() >= static_cast<std::size_t>(mMaxNumRequests))
        {
            break;
        }
        else if (req->isContextInitState())
        {
            auto const neededBlocks = kvCacheManager.getRemainingBlocksToCompletion(*req);
            ...
            if (neededBlocks <= availableBlocks)
            {
                scheduledRequests.emplace_back(req);
                availableBlocks -= neededBlocks;
                ...
            }
            else if (neededBlocks > availableBlocks)
            {
                // If one request fails to be scheduled, break
                break;
            }
        }
    }
    return {std::move(scheduledRequests), RequestVector{}};
}
In the code above, we can see that the logic first calculates the maximum number of KV cache blocks a request may use with the following line:
auto const neededBlocks = kvCacheManager.getRemainingBlocksToCompletion(*req);
The calculated value is then compared with the currently available blocks to decide whether the request can be scheduled.
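To make this check concrete, here is a minimal, self-contained sketch of the admission logic with purely hypothetical numbers. The remainingBlocksToCompletion helper is our own simplification of what kvCacheManager.getRemainingBlocksToCompletion computes: a worst-case block count derived from the prompt length and the maximum number of output tokens.

// A minimal sketch of the GUARANTEED_NO_EVICT admission check.
// All names and numbers are hypothetical; the real scheduler queries the KV cache manager instead.
#include <cstdint>
#include <iostream>
#include <vector>

struct Request
{
    int64_t promptLen;            // number of context tokens
    int64_t maxNewTokens;         // maximum number of tokens the request may generate
    int64_t allocatedBlocks = 0;  // blocks already held by this request
};

// Worst-case number of additional KV cache blocks the request needs to run to completion.
int64_t remainingBlocksToCompletion(Request const& req, int64_t tokensPerBlock)
{
    int64_t const maxTotalTokens = req.promptLen + req.maxNewTokens;
    int64_t const totalBlocks = (maxTotalTokens + tokensPerBlock - 1) / tokensPerBlock; // ceiling division
    return totalBlocks - req.allocatedBlocks;
}

int main()
{
    int64_t const tokensPerBlock = 64;
    int64_t availableBlocks = 1000;

    // 20 hypothetical pending requests: 2,048-token prompts, up to 2,048 output tokens each.
    std::vector<Request> pending(20, Request{2048, 2048});

    std::vector<Request> scheduled;
    for (auto const& req : pending)
    {
        auto const neededBlocks = remainingBlocksToCompletion(req, tokensPerBlock);
        if (neededBlocks <= availableBlocks)
        {
            // Reserve the worst-case blocks up front so the request can never be evicted later.
            availableBlocks -= neededBlocks;
            scheduled.push_back(req);
        }
        else
        {
            break; // like the scheduler above, stop at the first request that does not fit
        }
    }
    std::cout << "scheduled " << scheduled.size() << " of " << pending.size()
              << " requests, " << availableBlocks << " blocks left\n";
    return 0;
}

With these numbers, each request reserves (2048 + 2048) / 64 = 64 blocks up front, so only 15 of the 20 pending requests are admitted; the rest must wait until earlier requests finish and release their blocks.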
As a result, once a request is scheduled, it can run to completion without being preempted due to KV cache exhaustion. This is why requests that have already been scheduled are simply added to the schedule again in the code without checking for resource availability. This can also be observed in Figure 3 below.
Figure 3. An illustration of how TensorRT-LLM with GUARANTEED_NO_EVICT policy schedules multiple requests. Each request is represented by a different color.
Additionally, the maximum batch size is smaller than in Figure 2 (MAX_UTILIZATION). This is because the scheduler accounts for worst-case KV cache usage at scheduling time.
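To put purely hypothetical numbers on this: in the sketch above, each request reserves 64 blocks to run to completion even though its 2,048-token prompt alone occupies only 32 blocks. A scheduler that reserves worst-case usage therefore admits roughly half as many of these requests as one that only accounts for the blocks needed right now, which is the trade-off MAX_UTILIZATION makes.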
On the other hand, MAX_UTILIZATION has no such constraint. Instead of reserving the maximum possible KV cache usage, it only considers the KV cache needed for the next step and aims to maximize the batch size in each iteration.
// https://github.com/NVIDIA/TensorRT-LLM/blob/da0b0e0ee307ae4f97accb75e3e5f6c31c2507c6/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp#L327-L447
// Note that the code is slightly modified for ease of explanation
std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
    kv_cache_manager::BaseKVCacheManager& kvCacheManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const
{
    SizeType32 numScheduledBlocks{0};
    ...
    // Function to find last active in case we need to evict
    auto startedReqLambda = [this](std::shared_ptr<LlmRequest> const& req)
    {
        return (req->hasReachedState(getNoScheduleUntilState()) && !req->hasReachedState(getNoScheduleAfterState())
            && (req->isContextInitState() || req->isGenerationInProgressState()));
    };

    RequestVector scheduledRequests;
    RequestVector pausedRequests;
    auto reqItEnd = std::end(activeRequests);
    for (auto reqIt = std::begin(activeRequests); reqIt != reqItEnd;)
    {
        auto const& req = *reqIt;
        TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduling request ID %lu", req->mRequestId);

        // if request cannot be scheduled yet or request should no longer be scheduled, skip
        if (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState()))
        {
            TLLM_LOG_DEBUG(
                "MaxUtilizationScheduler: request ID %lu cannot / should not be scheduled", req->mRequestId);
            reqIt++;
            continue;
        }
        ...
        auto const [fitsKvCache, fitsPeftCache] = trySchedulingRequestMaxUtilization(kvCacheManager,
            peftCacheManager, req, scheduledRequests, numScheduledBlocks, numScheduledPeftPages, seenTaskIds);
        if (fitsKvCache)
        {
            TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> start", req->mRequestId);
            reqIt++;
        }
        else
        {
            auto const rbegin = std::reverse_iterator(reqItEnd);
            auto const rend = std::reverse_iterator(reqIt);
            auto const lastStartedReqIt = std::find_if(rbegin, rend, startedReqLambda);
            if (lastStartedReqIt != rend)
            {
                // If we can't allocate a started request, we need to start freeing started requests
                // from the end of the vector and try again
                // Here we simulate freeing the kvCache blocks associated with that sequence
                kvCacheManager.schedulingRemoveSequence((*lastStartedReqIt)->mRequestId);
                pausedRequests.emplace_back(*lastStartedReqIt);
                TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> pause", (*lastStartedReqIt)->mRequestId);
                reqItEnd = std::next(lastStartedReqIt).base();
            }
            else
            {
                break;
            }
        }
    }
    return {std::move(scheduledRequests), std::move(pausedRequests)};
}

std::pair<bool, bool> MaxUtilizationScheduler::trySchedulingRequestMaxUtilization(
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
    RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
    std::unordered_set<uint64_t>& seenTaskIds) const
{
    if (scheduledRequests.size() < static_cast<std::size_t>(mMaxNumRequests))
    {
        SizeType32 numRequiredBlocks = kvCacheManager.getNeededBlocksOneStep(*req, mManyMicroBatches);
        TLLM_LOG_DEBUG(
            "MaxUtilizationScheduler: request ID %lu required blocks: %i", req->mRequestId, numRequiredBlocks);
        ...
        bool const fitsKvCache
            = kvCacheManager.getBlockManager().schedulingHasFreeBlocks(numScheduledBlocks + numRequiredBlocks);
        if (fitsKvCache)
        {
            numScheduledBlocks += numRequiredBlocks;
            TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduled blocks: %i", numScheduledBlocks);
            scheduledRequests.emplace_back(req);
        }
        return std::make_pair(fitsKvCache, fitsPeft);
    }
    return std::make_pair(false, false);
}
As we can see from the line:
SizeType32 numRequiredBlocks = kvCacheManager.getNeededBlocksOneStep(*req, mManyMicroBatches);
MAX_UTILIZATION only considers the KV cache required for a single step. Thus, it requires additional checks and eviction logic to ensure that there is enough KV cache available to proceed with each iteration.
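The sketch below illustrates this per-step accounting. The neededBlocksOneStep function is a hypothetical simplification of what getNeededBlocksOneStep computes, ignoring details such as chunked prefill, beam width, and block reuse: a prefill step needs blocks for the entire prompt at once, while a decode step needs a new block only when the generated tokens cross a block boundary.

// A hypothetical sketch of per-step KV cache block accounting
// (not the actual getNeededBlocksOneStep implementation).
#include <cstdint>
#include <iostream>

int64_t neededBlocksOneStep(
    int64_t currentLen, int64_t allocatedBlocks, bool inPrefill, int64_t promptLen, int64_t tokensPerBlock)
{
    if (inPrefill)
    {
        // Prefill processes the whole prompt in one step (chunked prefill ignored here).
        int64_t const totalBlocks = (promptLen + tokensPerBlock - 1) / tokensPerBlock;
        return totalBlocks - allocatedBlocks;
    }
    // Decode generates one token; a new block is needed only if the allocated blocks are exactly full.
    return (currentLen % tokensPerBlock == 0) ? 1 : 0;
}

int main()
{
    int64_t const tokensPerBlock = 64;
    // A prefill request with a 2,048-token prompt needs 32 blocks for its next (first) step.
    std::cout << neededBlocksOneStep(0, 0, true, 2048, tokensPerBlock) << "\n";    // 32
    // A decode request with 100 tokens still has room in its second block.
    std::cout << neededBlocksOneStep(100, 2, false, 2048, tokensPerBlock) << "\n"; // 0
    // A decode request whose 128 tokens exactly fill two blocks needs a new block for the next token.
    std::cout << neededBlocksOneStep(128, 2, false, 2048, tokensPerBlock) << "\n"; // 1
    return 0;
}

Because only the next step is accounted for, the scheduler must re-check capacity every iteration and, as the code above shows, pause started requests when the blocks run out.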
Note that the code snippets in this post omit parts related to advanced features such as LoRA, disaggregated prefill, and chunked prefill to focus on explaining the core algorithms. However, the related code is also fully open-source, and the actual scheduling process is more complex as it takes all these factors into account.

Final Remarks

With the code open, we can now see exactly how the options we set take effect and how requests are selected for processing in each iteration. This should deepen TensorRT-LLM users' understanding of the system and help attract new users.
Frankly speaking, despite its outstanding performance, TensorRT-LLM has not been widely adopted due to its limited accessibility, and contributions and support for new models and features have also been limited. With this code release, we expect these limitations to improve significantly.
We are genuinely excited that the community can now collaborate to understand and enhance TensorRT-LLM, contributing to a richer AI ecosystem. We look forward to positive changes and more active contributions.
Subscribe and stay tuned for the latest changes in TensorRT-LLM!
 