Introducing rebellions’ ATOM™-Max

Introducing ATOM™-Max, rebellions’ next-generation NPU designed for high-performance AI inference. Learn how its runtime, profiling tools, and PyTorch-native integrations enable developers to run and serve models efficiently without sacrificing usability.
Huijong Jeong
Dec 24, 2025

Introduction

Modern AI workloads are growing rapidly in both scale and complexity. This growth is pushing the limits of traditional GPU architectures, and developers are increasingly facing challenges related to performance, efficiency, and deployment cost. As models become larger and inference demand rises, the need for new types of compute has become clear.
rebellions is a South Korea–based AI semiconductor company that set out to solve this problem. Since its founding in 2020, the company has focused on building NPUs that are purpose-built for large-scale and high-performance inference. Backed by more than 540 million dollars in funding and valued at roughly 1.4 billion dollars as of late 2025, rebellions has become one of Korea’s most prominent AI chip companies.
ATOM is the result of these efforts. It is a purpose-built NPU designed from the ground up for large-scale inference, high efficiency, and scalable deployment. The architecture reflects rebellions’ central goal: to provide a specialized accelerator that addresses the computational demands of modern AI workloads more effectively than traditional GPUs.
To ensure that this new hardware fits smoothly into real developer workflows, SqueezeBits and rebellions have formed a strategic partnership focused on building a complete software ecosystem around ATOM. The goal is not to treat the NPU as a niche or isolated subsystem, but to make ATOM feel like a natural extension of existing AI development practices. Our work aims to go beyond basic support and provide an environment where developers can write, debug, and optimize models on ATOM with the same level of confidence and comfort they expect when working on GPU.
In this post, we take a closer look at the current state of ATOM’s software ecosystem, how execution works on the device, and the directions in which development and collaboration are continuing to move.

ATOM™-Max

ATOM™-Max is rebellions’ answer to the growing demands of modern AI workloads. It is a purpose-built accelerator designed specifically for large-scale inference, with an architecture focused on both efficiency and scalability. Through this design, ATOM™-Max positions itself as a strategic alternative to general-purpose GPUs for building high-performance AI inference stacks.
Figure 1. Photo of ATOM™-Max (rebellions)
As shown in the product images, each card integrates five chips: four identical ATOM NPUs and a central PCIe controller. The controller aggregates the four ATOM NPU dies and manages communication with the host system through a single PCIe Gen5x16 interface, delivering 128 TFLOPS, 512 TOPS, and 1 TB/s of bandwidth within a 350W TDP per card.
Figure 2. Photo of ATOM™-Max Server (rebellions)
With this multi-die configuration, a standard server node equipped with eight PCIe slots can host up to 32 ATOM NPUs, enabling exceptionally dense compute configurations without requiring complex interposers or bridge solutions.
Figure 3. ATOM™-Max server rack deployed in the SqueezeBits IDC environment
At SqueezeBits, we recently brought an ATOM™-Max server into our internal development setup, and all experiments and analyses in this post were conducted on that system.

The Core of Execution on ATOM: RBLN Runtime

The rebellions software stack provides a mature and comprehensive solution that covers the entire workflow from compilation and execution to profiling. Beyond that, it also enables the development of practical and meaningful AI applications and solutions built on top of this foundation.
At the heart of rebellions’ software stack lies the RBLN Runtime API, which serves as the main interface between the host system and the NPU. The runtime exposes the necessary primitives for executing computations on ATOM, enabling users to run operations that have been compiled into RBLN binaries directly on the hardware. Through this API, developers can seamlessly offload workloads from the host to the NPU without needing to handle low-level scheduling or memory management details themselves.
Below is a simple example demonstrating how to use the RBLN SDK to compile a PyTorch model, create a runtime from the compiled binary, and perform inference on the ATOM.
First, we compile a standard torch.nn.Module into an RBLN compatible binary using the rebel.compile_from_torch API.
import rebel
import torch

# Compile the model
compiled_model = rebel.compile_from_torch(
    torch.nn.Linear(16, 16),
    example_inputs=[torch.randn(16, 16)]
)

# Save the compiled binary
compiled_model.save("compiled.rbln")
Code 1. Example showing how to compile a torch.nn.Module using the rebel.compile_from_torch API and export it as a binary
Once the model is compiled, we can load the resulting binary to create a runtime and execute inference directly on the ATOM.
import rebel
import torch

# Load the compiled model and create a runtime
runtime = rebel.Runtime("compiled.rbln", tensor_type="pt")

# Run inference on the NPU
x = torch.randn(16, 16)
result = runtime.run(x)
Code 2. Example loading a compiled model into a rebel.Runtime object and running inference

Profiling RBLN Runtime

So what actually happens when we create a runtime and execute the compiled binary? To find out, we can use the profiler provided by the RBLN SDK. The profiler allows us to observe and analyze what is happening inside the NPU during execution.
Enabling profiling is straightforward. We can turn it on by passing a single profiling-related option when creating the runtime, as shown below:
import rebel
import torch

# Load the compiled model and create a runtime with profiler enabled
module = rebel.Runtime("compiled.rbln", tensor_type="pt", activate_profiler=True)

# Run inference on the NPU
x = torch.randn(16, 16)
result = module.run(x)
Code 3. Example enabling the built-in profiler by initializing the runtime with activate_profiler=True
A trace file is generated automatically when this code runs. We can open and visualize this file with Perfetto, which allows us to explore the detailed execution flow and performance behavior of the model on the NPU.
Figure 4. Captured device‐level execution trace recorded during a simple linear operation
By examining the resulting trace, we can clearly understand how the execution unfolds on the NPU. In the trace above, four categories of hardware activity are captured: Host, External HDMA, Neural Engine Clusters, and Task DMA. Each category represents a distinct type of activity, as described below (a small trace-query sketch follows the list):
  • Host: Represents operations that are executed on the host CPU when they are either more efficient than running on the NPU or not supported by it. It also includes operations that adjust input shapes to better optimize execution on ATOM. This category also captures the act of launching computations to the device as well as waiting for those device executions to finish.
  • External HDMA: Represents data transfers between the host DRAM and the device DRAM.
  • Neural Engine Cluster: Represents operations that are running on the Neural Engines in ATOM.
  • Task DMA: Represents data transfers between the device DRAM and the Shared Memory in the SoC, including input tensors, intermediate tensors, and kernel weights of the target model.
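Beyond visual inspection in the Perfetto UI, the same trace can also be queried programmatically. Below is a minimal sketch using Perfetto's Python TraceProcessor to sum the time spent in each activity category; the trace file name is a placeholder for whatever file the RBLN profiler actually emits in your working directory, and the track naming may differ slightly from the categories listed above.

from perfetto.trace_processor import TraceProcessor

# Placeholder path; point this at the trace file generated by the RBLN profiler
tp = TraceProcessor(trace="profile.pftrace")

# Sum slice durations per track to see how much time each activity category
# (Host, External HDMA, Neural Engine Clusters, Task DMA) accounts for
query = """
SELECT t.name AS track, SUM(s.dur) AS total_dur_ns, COUNT(*) AS num_slices
FROM slice AS s JOIN track AS t ON s.track_id = t.id
GROUP BY t.name
ORDER BY total_dur_ns DESC
"""
for row in tp.query(query):
    print(f"{row.track}: {row.total_dur_ns / 1e6:.3f} ms across {row.num_slices} slices")

tp.close()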

Profiling Basic RBLN Runtime Flow

Figure 5. Captured device‐level trace showing input data movement during a simple linear operation
With the device activity categories in mind, let us revisit the trace shown above and walk through the execution flow in order. When the host launches an NPU execution (Host), it first copies the input data from host memory into the device DRAM (External HDMA). Once the data is on the device, it is transferred from DRAM into on-chip SRAM (Task DMA), after which the computation cores begin executing the workload (Neural Engine Clusters).
 
Figure 6. Captured device‐level trace showing movement of linear layer weight data during execution
Since the weights of the linear layer are already loaded into the device DRAM during runtime creation, they are copied directly from DRAM to on-chip SRAM via Task DMA, without requiring any External HDMA transfers.
 
Figure 7. Captured execution trace highlighting first activity on the neural engine cluster
Looking closely at the Neural Engine Cluster activity, the first operation is labeled 0_input_1. This corresponds to the data preprocessing step, which includes the type casting required because ATOM does not natively support FP32 arithmetic. The input tensor is cast to one of ATOM’s supported 16-bit formats before any computation can proceed.
 
Figure 8. Captured device-level trace showing ongoing matrix multiplication without additional DMAs
The subsequent trace segment captures the actual linear operation: the matrix multiplication. Since all data required for the computation has already been placed in the device’s SRAM through the preceding Task DMAs, the operation proceeds without any additional DMA transfers, and only the computation itself is executed. Among the activities labeled as linear, some are not actual matrix multiplications but data-handling operations, such as up-casting the result back to FP32. This aligns with what we observe when inspecting the dtype of the output tensor returned to the host (result.dtype).
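As a quick sanity check, the sketch below reuses only the APIs from Code 1 and Code 2, but keeps a handle to the original FP32 module so its CPU output can serve as a reference. The returned tensor is FP32, yet it matches the reference only up to roughly 16-bit precision; the tolerances are illustrative assumptions, not a specification.

import rebel
import torch

# Keep the original FP32 module around so its CPU output can serve as a reference
torch_linear = torch.nn.Linear(16, 16)
compiled_model = rebel.compile_from_torch(torch_linear, example_inputs=[torch.randn(16, 16)])
runtime = rebel.Runtime(compiled_model, tensor_type="pt")

x = torch.randn(16, 16)
result = runtime.run(x)

print(result.dtype)          # torch.float32 -- the output is up-cast back to FP32 on device
reference = torch_linear(x)  # FP32 reference computed on the host CPU

# Values agree only up to roughly 16-bit precision, since the matmul itself ran
# in a 16-bit format on ATOM; the tolerances below are illustrative
print(torch.allclose(result, reference, rtol=1e-2, atol=1e-2))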
 
Figure 9. Captured device‐level trace showing output data movement during execution
Once all preprocessing, computation, and postprocessing steps are complete, the resulting tensor is copied from the device’s SRAM back to DRAM (Task DMA). Finally, the data is transferred from device memory back to host memory (External HDMA), completing the full NPU execution flow.

Profiling Fused Operations

When the computation becomes more complex, the execution trace can exhibit different patterns. Consider the common case in AI models where a linear layer is immediately followed by an activation function. These two operations can often be fused so that the activation is applied as part of the linear computation rather than as a separate step. The RBLN SDK also applies this kind of optimization, enabling the NPU to process such sequences more efficiently.
import rebel
import torch

# Compile the model
compiled_model = rebel.compile_from_torch(
    torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU()
    ),
    example_inputs=[torch.randn(16, 16)]
)

# Create a runtime with profiler enabled
module = rebel.Runtime(compiled_model, tensor_type="pt", activate_profiler=True)

# Run inference on the NPU
x = torch.randn(16, 16)
output = module.run(x)
Code 4. Example profiling execution of a linear layer followed by a ReLU operation
Figure 10. Captured device‐level trace showing execution of a linear layer followed by ReLU
It may appear similar to the previous example at first glance, but a closer look at the second operation within the Neural Engine Cluster reveals that the linear and ReLU operations have been fused.
Figure 11. Captured device‐level trace showing execution of the fused operation

Profiling Host Fallbacks

This time, let’s examine how the RBLN SDK handles workloads that include operations not supported by the NPU. To do so, we will run and analyze a simple model composed of two linear layers. However, between these two layers, we insert torch.atan, an operation that the current RBLN SDK does not support for device-side acceleration. For reference, the list of supported Torch operations can be found here.
import rebel
import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 16)
        self.linear2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear2(torch.atan(self.linear1(x)))

# Compile the model
compiled_model = rebel.compile_from_torch(Foo(), example_inputs=[torch.randn(16, 16)])

# Create a runtime with profiler enabled
module = rebel.Runtime(compiled_model, tensor_type="pt", activate_profiler=True)

# Run inference on the NPU
x = torch.randn(16, 16)
output = module.run(x)
Code 5. Example profiling execution behavior when running an unaccelerated operation
Figure 12. Captured device‐level trace illustrating a workload that involves host fallback during execution
Unlike the previous examples, we can now observe that additional computation occurs on the host, in this case the execution of torch.atan. Note that the segment labeled profiling overhead represents postprocessing steps that organize device information for the profiler trace, and therefore does not appear in non-profiled runs. However, even setting the profiling overhead aside, the system must still perform extra data copies and synchronizations between the host and the device due to the host fallback. This additional interaction introduces performance penalties, confirming the negative impact on end-to-end execution efficiency. In real inference deployments, it is important to carefully inspect the model for unsupported operations and handle them appropriately to ensure optimal performance.
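One practical way to quantify this penalty is to time the fallback model against an equivalent model with the unsupported operation removed. The sketch below reuses only APIs shown in the earlier examples; the two-layer comparison model and the iteration count are illustrative, and the measured gap will depend on the model and system.

import time

import rebel
import torch

class WithAtan(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 16)
        self.linear2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear2(torch.atan(self.linear1(x)))  # torch.atan falls back to the host

class WithoutAtan(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 16)
        self.linear2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear2(self.linear1(x))  # runs fully on the device

def mean_latency_ms(runtime, x, iters=100):
    # Warm up once, then average wall-clock latency over several runs
    runtime.run(x)
    start = time.perf_counter()
    for _ in range(iters):
        runtime.run(x)
    return (time.perf_counter() - start) / iters * 1e3

x = torch.randn(16, 16)
for cls in (WithAtan, WithoutAtan):
    compiled = rebel.compile_from_torch(cls(), example_inputs=[x])
    runtime = rebel.Runtime(compiled, tensor_type="pt")
    print(f"{cls.__name__}: {mean_latency_ms(runtime, x):.3f} ms per run")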

Support For LLMs: optimum-rbln, vllm-rbln

Building on this foundation, rebellions is extending its software stack beyond basic compilation and execution to support large language model (LLM) workloads. To ensure broad model compatibility and a familiar development experience, rebellions has developed optimum-rbln, a HuggingFace integration layer that connects standard transformer workflows to RBLN hardware.
In addition to LLM modeling, optimum-rbln also implements key features required for efficient LLM serving. These include support for chunked prefill and paged attention, which are essential for high-throughput, memory-efficient inference.
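To make the chunked prefill idea concrete, here is a purely conceptual sketch; it is not the optimum-rbln implementation, and the model_step callable and chunk size are hypothetical. The point is simply that a long prompt is consumed in fixed-size chunks, so each prefill step fits the compiled static shapes while the KV cache grows incrementally.

import torch

def chunked_prefill(model_step, prompt_ids, chunk_size=256):
    # model_step is a hypothetical callable that processes one chunk of token ids
    # given the KV cache written so far and returns the updated KV cache
    kv_cache = None
    for start in range(0, prompt_ids.shape[-1], chunk_size):
        chunk = prompt_ids[:, start:start + chunk_size]
        kv_cache = model_step(chunk, kv_cache)
    return kv_cache  # ready for the decode phase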
To ensure that these capabilities carry over cleanly into real serving environments, rebellions has also developed vllm-rbln, a plugin that extends vLLM with ATOM support. Through this plugin, core vLLM serving features such as continuous batching and request scheduling can be used seamlessly on the ATOM NPU. It also enables advanced features like structured output and prefix caching, allowing developers to take full advantage of vLLM’s serving stack on ATOM without additional integration work.
Figure 13. High-level architecture for running vLLM on ATOM through the optimum-rbln integration layer (rebellions)
With this integration in place, developers can compile and deploy LLMs onto the NPU and build full LLM applications. This makes it possible to run these applications efficiently on the ATOM NPU stack while maintaining both scalability and stable performance.
from optimum.rbln import RBLNLlamaForCausalLM
from vllm import LLM, SamplingParams

# Compile model
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    export=True,
    rbln_batch_size=4,
    rbln_max_seq_len=8192,
    rbln_kvcache_partition_len=4096,  # for paged flash attention
    rbln_prefill_chunk_size=256,      # for chunked prefill
    rbln_tensor_parallel_size=4,      # for tensor parallelism
    rbln_create_runtimes=False,
)
model.save_pretrained("compiled_llama_3_8b_instruct")

# Create vLLM engine
llm = LLM(
    model="compiled_llama_3_8b_instruct",
    device="rbln",
    max_num_seqs=4,
    max_num_batched_tokens=8192,
    max_model_len=8192,
    block_size=4096,
)

# Generate text
print(llm.generate([
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is"
], SamplingParams(temperature=0.0)))
Code 6. Example compiling an LLM for vLLM using optimum-rbln

Ongoing Collaboration

Constraints of the Current Usage Flow of the RBLN SDK

However, the current user experience design of the RBLN SDK also presents several practical challenges across the broader development workflow.
  • Additional compilation step: As with most NPUs on the market, users are required to perform an explicit compilation step before running their models, exporting them outside the PyTorch ecosystem. This extra stage can introduce unexpected friction, particularly for developers who are already accustomed to the dynamic and flexible workflow that PyTorch provides.
  • Limited profiling visibility: The profiler currently provides trace information only from the device side, making it difficult to correlate NPU activity with host-side behavior during execution.
  • Slower development iteration cycle: In an iterative development workflow, having to recompile the model after every small modification slows down experimentation and debugging.
  • Limited debugging flexibility: Since the compiled graph must always be executed as a whole, fine-grained debugging or partial execution is not straightforward.
These challenges also extend to developers building open-source frameworks or applications on top of the RBLN NPU stack. For instance, projects like optimum-rbln and vllm-rbln are somewhat fragmented: certain functionalities that conceptually belong to the inference layer end up being implemented in the compiler layer instead. Take features like tensor parallelism or chunked prefill in the code snippet above as examples: although they are used during model inference in vllm-rbln, they must currently be reflected at export time within optimum-rbln.
This separation, introduced by the explicit “export” step to the RBLN runtime, leads to a non-intuitive and fragmented development experience, making it harder for contributors to understand where specific logic should reside and ultimately reducing usability for developers working on or with the stack.
Ultimately, developers want the same experience they enjoy on GPUs: simply installing PyTorch and having the model run seamlessly with hardware acceleration. The challenges discussed above are not fundamental limitations, but rather integration gaps that can be solved by bringing the already solid RBLN stack closer to that familiar user experience.

Shift Towards PyTorch Native UX

Recognizing these challenges, rebellions is evolving its SDK toward a more seamless and developer-friendly experience. Our collaboration with rebellions is built around the same vision: to make advanced NPU acceleration feel as effortless and accessible as GPU execution, empowering developers to focus on innovation rather than infrastructure.
As a concrete example of our collaboration, we are working together on the vllm-rbln project. Among the many areas of collaboration, we’d like to highlight two key aspects in this post:
  • supporting vLLM code as-is, without any external compilation or code modification, through the torch.compile integration
  • enabling eager mode execution through torch-rbln
These two directions share a common goal: bringing ATOM execution deeper into the PyTorch ecosystem and making it feel native to developers already familiar with the PyTorch workflow.

torch.compile Integration via TorchDynamo

TorchDynamo support is already integrated into the RBLN compiler, allowing PyTorch models to be compiled and executed natively on the ATOM NPU. Through this integration, models can run directly from PyTorch code with no separate export or ahead-of-time compilation step. A minimal example is shown below:
import rebel
import torch

@torch.compile(backend="rbln")
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        x = self.linear(x)
        return self.gelu(x)

model = Foo()

# TorchDynamo frontend traces the model and
# triggers compilation using the rebellions compiler
output = model(torch.randn(16, 16))
Code 7. Example compiling and executing a simple PyTorch module on ATOM using torch.compile with the RBLN backend
While this PyTorch-native path is already available, the current vllm-rbln still relies primarily on optimum-rbln as its bridge to the hardware. This dependency introduces an additional export step and limits how naturally vLLM code can run on the device.
To address this, we are shifting the vLLM workflow toward the TorchDynamo-based integration. By leveraging the RBLN compiler’s support for torch.compile, vLLM models can be executed on the NPU directly from PyTorch code, without any external conversion pipeline. This transition unifies the modeling and compilation stages and provides a smoother and more consistent developer experience on the RBLN platform.
Figure 14. High-level architecture for running vLLM through the torch.compile integration path (rebellions)
This shift toward a PyTorch-native compilation flow brings several immediate benefits:
  • It eliminates the disconnect between the model code and the compiled graph, improving usability and iteration speed.
  • It allows dynamic and partially evaluated graphs to be compiled on the fly, aligning the RBLN stack more closely with the PyTorch execution model.
  • It makes integration with higher-level frameworks like vLLM cleaner and easier to maintain, since much of the logic such as tensor parallelism and chunked prefill can now reside within the same runtime context.
import os

# Enable the torch.compile integration via an environment variable,
# as the optimum-based path is the default for now
os.environ["VLLM_RBLN_USE_VLLM_MODEL"] = "1"

from vllm import LLM, SamplingParams

# Create vLLM engine without an external compilation process
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    device="rbln",
    max_num_seqs=4,
    max_model_len=8192,
    block_size=4096,
    max_num_batched_tokens=256,  # chunked prefill is now handled at the vLLM level
    tensor_parallel_size=4,      # tensor parallelism is now handled at the vLLM level
)

# Generate text
print(llm.generate([
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is"
], SamplingParams(temperature=0.0)))
Code 8. Example running vLLM using torch.compile integration
A major advantage of this approach is that vLLM code can be supported as is. vLLM already provides mature implementations of tensor parallelism, expert parallelism, multi-LoRA support, diverse sampling strategies, and many other components that form the backbone of modern LLM inference. These features are continually evolving, and most are written in PyTorch-native code. Supporting them through a PyTorch-native compilation path creates meaningful synergy and preserves the natural flow of development for vLLM users.
vLLM has been at the center of attention in the LLM serving landscape for quite some time, and its significance continues to increase as the ecosystem matures. Its design philosophy emphasizes modularity and composability, enabling researchers and engineers to prototype new ideas quickly while maintaining performance suitable for production-scale deployments. As a result, vLLM has become a central point of innovation and a common foundation for experimentation across the community.
This influence is especially evident in emerging areas such as disaggregated inference. Many recent efforts in this space are built around vLLM’s abstractions, with external frameworks intentionally aligning their interfaces to match vLLM’s execution model. vLLM defines clear patterns for scheduling, request management, attention computation, and model partitioning, and these abstractions have started to function as informal standards across the wider ecosystem.
Strong native compatibility with vLLM enables these ecosystem-wide innovations to carry over to ATOM with minimal integration work. When vLLM introduces new features or experimental techniques, a PyTorch-native backend on the RBLN platform can adopt them immediately without requiring dedicated export pipelines or proprietary graph formats. This allows developers to benefit from vLLM’s continuous evolution while still leveraging the efficiency and scalability of ATOM.
In practical terms, this positions ATOM as a first-class hardware target within the vLLM ecosystem. The NPU can inherit advancements in areas such as scheduling algorithms, memory management techniques, speculative decoding, and other cutting-edge serving methods as soon as they land in vLLM. The combination of vLLM’s rapid innovation cycle and ATOM’s PyTorch-native execution path opens the door to fast iteration, smoother integration, and seamless deployment on specialized hardware.

Eager Mode: Towards True PyTorch Native Development

The next frontier of this effort is eager-mode execution, enabled through torch-rbln. torch-rbln is a PyTorch extension that brings rebellions’ NPU compute directly into standard PyTorch workflows. Implemented as an out-of-tree (OOT) extension, torch-rbln integrates with PyTorch’s dispatch system through the PrivateUse1 mechanism, allowing it to register custom kernels and device behaviors without modifying PyTorch itself. This design keeps the extension lightweight, maintainable, and fully compatible with the broader PyTorch ecosystem.
Figure 15. Dispatcher-based design architecture of torch-rbln (rebellions)
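For readers unfamiliar with the mechanism, the toy sketch below shows what the PrivateUse1 hooks look like from Python. This is not torch-rbln’s actual code: a real out-of-tree backend also registers its kernels (typically in C++, via TORCH_LIBRARY_IMPL) and a device runtime for this dispatch key, and the backend name used here is made up.

import torch

# Rename PyTorch's reserved PrivateUse1 dispatch key to a custom device name.
# Registering real kernels and an allocator for this key is what would make
# tensors on the device actually usable.
torch.utils.rename_privateuse1_backend("toy_npu")

# The custom device type is now recognized by torch.device
print(torch.device("toy_npu"))   # device(type='toy_npu')
print(torch.randn(2, 2).device)  # still cpu -- no kernels are registered in this sketch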
While torch.compile offers graph-level acceleration through just-in-time compilation, many developers still rely on eager execution for debugging, rapid prototyping, and interactive development. By supporting eager mode in a true define-by-run fashion, torch-rbln provides a familiar and seamless user experience across the full lifecycle of AI development.
Figure 16. Screenshot of running PyTorch in eager mode through torch-rbln
With eager-mode support, individual operations can now run natively on the device without any ahead-of-time compilation. torch-rbln exposes RBLN kernels through PyTorch’s dispatcher, allowing standard PyTorch operators to map transparently to their NPU implementations while preserving Python-level semantics. As a result, developers can write, debug, and benchmark models exactly as they would on CPU or GPU, without additional setup or changes to their code.
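As an illustration of what this looks like in practice, the sketch below assumes the extension is importable as torch_rbln and exposes the device under the name "rbln"; both the import path and the device string are our assumptions rather than confirmed details. The point is that ordinary eager-mode PyTorch code runs unmodified apart from device placement.

import torch
import torch_rbln  # hypothetical import name; assumed to register the "rbln" device with PyTorch

# Standard eager-mode PyTorch -- only the device string changes
device = torch.device("rbln")
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.GELU(),
).to(device)

x = torch.randn(4, 16, device=device)
y = model(x)              # each operator is dispatched to its NPU kernel as it is called
print(y.shape, y.device)  # intermediate results can be inspected at any point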
Beyond improving ergonomics, eager mode forms the foundation for advanced workflows such as hybrid execution, fine-grained profiling, and mixed backend strategies. Combined with torch.compile, it enables a fully PyTorch-native experience on the RBLN platform and brings NPU acceleration to developers without requiring them to rethink how they write or execute their models.

Making an Impact on the “Real World”

As we have explored throughout this post, bringing a new NPU architecture like ATOM into the modern AI software ecosystem requires effort across many layers—from compiler support, operator coverage, and eager-mode execution to profiling workflows, fallback behavior, and seamless PyTorch integration.
Our examination of the RBLN SDK’s execution flow illustrates both the strengths of the current stack and the challenges that naturally arise when introducing novel hardware into existing AI workflows. These insights make it clear that achieving a smooth developer experience is not simply a matter of hardware capability, but a combination of software maturity, ecosystem alignment, and careful engineering.
In addition to these engineering efforts, we are equally committed to ensuring that this work delivers real value in the market and becomes meaningfully connected to the broader community. In the AI world, community-driven innovation has repeatedly proven its strength, as demonstrated by projects like vLLM. To help ATOM integrate naturally into this ecosystem, we actively create opportunities for developers to engage with the hardware and provide feedback.
Figure 17. Photo of Taesu Kim, CTO of SqueezeBits, leading a session during the recent vLLM hands-on workshop
One example of this engagement is our regular vLLM hands-on meetups, where participants run workloads directly on ATOM systems and experience the stack firsthand. Through initiatives like these, we stay closely connected with developers and evolve the platform based on real-world needs.
We will keep pushing forward so that all of these efforts translate into meaningful, tangible impact in the field and the market. Thank you for your interest and support, and we look forward to sharing even more progress soon.
