Disaggregated Inference on Apple Silicon: NPU prefill and GPU decode

In this article, we show how to run LLMs efficiently on Apple Silicon with a disaggregated inference technique.
Jiwoong Choi
Aug 26, 2025

1. Introduction

As demand increases for running large language models (LLMs) on personal devices for better privacy and offline use, Apple Silicon has emerged as a robust hardware platform with its unified CPU, GPU, and Neural Engine (ANE) architecture. While standard frameworks such as PyTorch can run on Apple Silicon, fully leveraging Apple's hardware requires specialized tools such as MLX and Core ML. MLX is a flexible, developer-friendly framework inspired by PyTorch and NumPy. Its unified memory model closely aligns with Apple Silicon's design, enabling operations on the CPU or GPU without data transfers. However, MLX's main limitation is its lack of support for the ANE.
To harness the power-efficient, high-throughput performance of the ANE, developers need to use Apple's Core ML framework. Core ML is less flexible than MLX and requires models to have fixed input shapes for ANE compatibility. The multifunction model feature is crucial for deploying language models under this constraint: it lets you combine separate fixed-shape functions, such as one for prefill and another for decode, into a single package that shares weights, thereby keeping the model file small. Additionally, stateful models, introduced in iOS 18 and macOS 15, provide an efficient way to manage the key-value (KV) cache by allowing the model to maintain persistent, updatable buffers at runtime. Together, these features make it possible to deploy language models on the ANE despite the framework's inherent rigidity.
In this post, we will first explore the many undocumented challenges of running LLMs on the ANE and share practical solutions. We then provide a detailed performance comparison between MLX (GPU) and Core ML (ANE) to establish a baseline and highlight their unique benefits. Building on these insights, we introduce our Yetter Inference Engine, a new system that employs a disaggregated inference approach to leverage the strengths of both the ANE and GPU for better performance. Finally, we share comprehensive benchmarks for our engine and discuss future directions for this work.

2. Hidden Challenges Facing the ANE

Despite the Core ML features mentioned above, many undocumented challenges must still be addressed before the ANE can be fully utilized. Below are some of these challenges, along with the workarounds we identified through painstaking debugging.

Too Many States Will Kill You!

Although Core ML’s states make it easier to handle the KV cache efficiently, using them on the ANE is more complicated than you might think. For example, a model with too many states may fail to compile for the ANE. To fix this, you need to implement custom Hugging Face cache interfaces (Cache and CacheLayerMixin) that internally manage the key/value caches of all layers as a single concatenated tensor of shape (num_hidden_layers, num_key_value_heads, max_cache_len, head_dim), assuming batch_size=1. For instance, with the custom cache interface, the converted Core ML model of Qwen3-0.6B has only two states, key_cache and value_cache, as shown in Figures 1 and 2, compared to 56 states (one key state and one value state for each of its 28 layers) when each layer’s cache is treated as a separate state.
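To make the single-tensor layout concrete, here is a minimal NumPy sketch of a cache that stores every layer's keys and values in two stacked tensors. The class name StackedKVCache and its update signature are hypothetical; a real implementation would subclass the Hugging Face Cache/CacheLayerMixin interfaces and hold torch tensors that the converter exposes as states. The sizes in the usage assume Qwen3-0.6B's configuration (28 layers, 8 KV heads, head_dim=128).

```python
import numpy as np


class StackedKVCache:
    """Hypothetical sketch: all layers share one key tensor and one value tensor.

    Shape: (num_hidden_layers, num_key_value_heads, max_cache_len, head_dim),
    assuming batch_size=1, so a converted model would expose only two states.
    """

    def __init__(self, num_layers, num_kv_heads, max_cache_len, head_dim, dtype=np.float16):
        shape = (num_layers, num_kv_heads, max_cache_len, head_dim)
        self.key_cache = np.zeros(shape, dtype=dtype)
        self.value_cache = np.zeros(shape, dtype=dtype)

    def update(self, layer_idx, key_states, value_states, start):
        """Write new (num_kv_heads, seq_len, head_dim) states at [start, start + seq_len)."""
        end = start + key_states.shape[1]
        self.key_cache[layer_idx, :, start:end, :] = key_states
        self.value_cache[layer_idx, :, start:end, :] = value_states
        # Return the whole per-layer slice; the attention mask hides unused positions.
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Qwen3-0.6B-like sizes (assumed): 28 layers, 8 KV heads, head_dim=128.
cache = StackedKVCache(num_layers=28, num_kv_heads=8, max_cache_len=512, head_dim=128)
new_k = np.ones((8, 4, 128), dtype=np.float16)
k, v = cache.update(layer_idx=3, key_states=new_k, value_states=2 * new_k, start=10)
print(cache.key_cache.shape)  # (28, 8, 512, 128)
```

Whatever the number of layers, the model sees exactly two buffers; the per-layer view is just an index into the leading dimension.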
📎
Apple’s official blog post, “On Device Llama 3.1 with Core ML,” provides an example implementation of the cache interface. However, it relies on outdated transformers APIs and is therefore incompatible with recent versions. Most critically, it is designed for GPU acceleration rather than for the ANE.
Figure 1. The input, output, and state for the Core ML package representing the prefill stage of Qwen3-0.6B with max_seq_len=512. Only two states are included by using the custom cache interface.
Figure 2. The input, output, and state for the Core ML package representing the decode stage of Qwen3-0.6B with max_seq_len=512. Only two states are included by using the custom cache interface.
However, this custom implementation conflicts with the recent change in Hugging Face’s per-layer cache interface. As a result, this approach needs further adjustments for advanced caching mechanisms with heterogeneous cache layers, such as sliding window attention used in google/gemma-3-1b-it.

Make State Dimension Sizes Powers of Two

Furthermore, if any state has a dimension (other than the first) whose size is not a power of two, the model may encounter a runtime error. For example, when LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct is packaged as a stateful model, its state shape is (20, 8, 512, 80); note that 80 is not a power of two. When you run this model with the Core ML Swift API, specifically MLModel's prediction method, you’ll see the following error:
Error Domain=com.apple.CoreML Code=0 "Unable to compute the prediction using ML Program. It can be an invalid input data or broken/unsupported model."
The runtime error can be fixed by padding the states to shape (20, 8, 512, 128). To prevent similar problems in the future, the cache interface should internally pad every non-first dimension of the cache tensors up to the next power of two and then use only the unpadded parts of these tensors.
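A minimal sketch of the padding rule, using NumPy as a stand-in for the tensors a cache interface would hold; the helper names are hypothetical. It rounds up to the next power of two (80 becomes 128, not 64) so that every original entry still fits, and it keeps a view of the unpadded region for the actual attention computation.

```python
import numpy as np


def next_power_of_two(n):
    """Smallest power of two >= n, for n >= 1."""
    if n < 1:
        raise ValueError("dimension sizes must be positive")
    return 1 << (n - 1).bit_length()


def padded_state_shape(shape):
    """Round every dimension except the first up to the next power of two."""
    first, *rest = shape
    return (first, *(next_power_of_two(d) for d in rest))


def allocate_padded_state(shape, dtype=np.float16):
    """Allocate the padded buffer plus a view of its unpadded region."""
    buffer = np.zeros(padded_state_shape(shape), dtype=dtype)
    view = buffer[tuple(slice(0, d) for d in shape)]  # basic slicing shares memory
    return buffer, view


# EXAONE-3.5-2.4B-Instruct state shape from the text above.
buffer, view = allocate_padded_state((20, 8, 512, 80))
print(buffer.shape, view.shape)  # (20, 8, 512, 128) (20, 8, 512, 80)
```

Because the view shares memory with the buffer, cache writes through the view land in the padded state without any copy.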

Update States with Care

Even models that already meet the power-of-two constraint mentioned above may still face compilation failures with the following unhelpful error message.
Error Domain=com.apple.CoreML Code=0 "Failed to build the model execution plan using a model architecture file"
It turns out that adding a small constant right before updating the value caches, as shown in Figure 3, can effectively address this issue. In theory, zero is the ideal constant because it preserves numerical correctness. However, an addition of zero is automatically eliminated by the graph optimizer of the PyTorch JIT tracer. Hence, torch.finfo(torch.float32).smallest_normal, for example, is a sensible choice: it is truncated to zero when Core ML’s graph optimization casts it to float16 precision. Notably, this trick is only required for the value cache update in the prefill graph.
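The arithmetic behind the trick can be checked with NumPy: the float32 smallest normal number is nonzero, so a tracer has no reason to fold the add away, yet it vanishes once cast to float16. The prefill_value_update helper below is a hypothetical illustration of where the add sits, not code taken from any model implementation.

```python
import numpy as np

# Same value as torch.finfo(torch.float32).smallest_normal (about 1.18e-38).
EPS = np.finfo(np.float32).tiny


def prefill_value_update(value_cache, value_states, start):
    """Hypothetical sketch of the prefill value-cache write with the extra add.

    In a traced PyTorch graph, adding EPS keeps a real add node in front of the
    cache update (x + 0 would be folded away); once cast to float16, EPS is 0.
    """
    value_states = value_states + EPS
    end = start + value_states.shape[-2]
    value_cache[..., start:end, :] = value_states  # float16 cache: EPS rounds away
    return value_cache


print(np.float16(EPS))  # 0.0
```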
Figure 3. [Left] The value cache updating subgraph visualized by Netron. The add node (highlighted by the red box) is inserted to prevent the compilation failure. [Right] The key cache updating subgraph. In most models, a rotary positional embedding (RoPE) is applied to the key states before updating the key caches. The add node created by the RoPE (highlighted by the orange box) likely prevents the compilation failure.

Double-check The ANE Delegation

After a few more adjustments, beyond what could be included in this post, the model's prefill and decode phases can finally be compiled for the ANE. Nonetheless, it remains essential to verify that each operation in the model is actually delegated to the ANE. This can be done by generating a performance report in Xcode, which shows each operation's device delegation status and performance metrics. As illustrated in Figures 4 and 5, nearly all operations in both the prefill and decode phases are executed on the ANE, with only a few exceptions: the token embedding layer (represented as a gather operation), the LM head (represented as a linear operation), and a few minor element-wise operations.
Figure 4. Core ML performance report for the prefill phase of Llama-3.2-1B-Instruct generated by Xcode. The model weights are quantized channel-wise with INT4, while activations are kept in FP16.
Figure 5. Core ML performance report for the decode phase of Llama-3.2-1B-Instruct generated by Xcode. The model weights are quantized channel-wise with INT4, while activations are kept in FP16.
📎
Notice that the Xcode performance reports above are comparable to those in Apple’s official blog post “Deploying Transformers on the Apple Neural Engine”, even though the models discussed there are predominantly traditional sequence classification models, such as distilbert/distilbert-base-uncased-finetuned-sst-2-english.

Don’t Let The pow In

Hold on! It’s not over yet. Even when the model is up and running on the ANE, it might generate broken text due to numerical instability, as shown in Figure 6.
Figure 6. [Left] Llama-3.2-1B-Instruct produces only newline characters before the graph optimizations are applied. [Right] The model generates valid sentences after the graph optimizations.
Surprisingly, this numerical instability is primarily caused by pow operations derived from the torch.Tensor.pow method used in RMSNorm or LayerNorm layers, as shown in Figure 7. Instead of manually amending Hugging Face’s model implementations, it is more practical to add a custom Model Intermediate Language (MIL) graph pass to the Core ML converter that rewrites pow(x, 2) as mul(x, x). Furthermore, an additional graph pass can fold the add operation within the RMSNorm subgraph into the epsilon parameter of the subsequent rsqrt operation.
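The two rewrites can be sketched on a hypothetical toy IR; the Op dataclass and the pass functions below are illustrative and are not Core ML Tools APIs. In Core ML Tools itself, the same logic would live in an AbstractGraphPass subclass registered with @register_pass and operate on MIL program blocks. The toy version only shows the pattern matching and the algebra behind it: pow(x, 2) equals mul(x, x), and since MIL's rsqrt computes 1/sqrt(x + epsilon), a constant added right before it can move into epsilon.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Toy SSA op, out = op_type(*inputs, **attrs). Hypothetical; not the MIL API."""
    op_type: str
    inputs: list
    out: str
    attrs: dict = field(default_factory=dict)


def replace_pow2_with_mul(ops):
    """Rewrite pow(x, 2) as mul(x, x)."""
    new_ops = []
    for op in ops:
        if op.op_type == "pow" and op.attrs.get("y") == 2.0:
            x = op.inputs[0]
            op = Op("mul", [x, x], op.out)
        new_ops.append(op)
    return new_ops


def fold_add_into_rsqrt_epsilon(ops):
    """Fold add(v, const c) into rsqrt(v, epsilon=eps + c) if rsqrt is the add's only user."""
    producer = {op.out: op for op in ops}
    use_count = {}
    for op in ops:
        for name in op.inputs:
            use_count[name] = use_count.get(name, 0) + 1
    folded = set()
    new_ops = []
    for op in ops:
        if op.op_type == "rsqrt":
            src = producer.get(op.inputs[0])
            if (src is not None and src.op_type == "add" and "y" in src.attrs
                    and use_count[src.out] == 1):
                eps = op.attrs.get("epsilon", 0.0) + src.attrs["y"]
                op = Op("rsqrt", [src.inputs[0]], op.out, {"epsilon": eps})
                folded.add(src.out)
        new_ops.append(op)
    return [op for op in new_ops if op.out not in folded]


# RMSNorm: y = x * rsqrt(mean(x^2) + 1e-5)
rms_norm = [
    Op("pow", ["x"], "x_sq", {"y": 2.0}),
    Op("reduce_mean", ["x_sq"], "mean_sq"),
    Op("add", ["mean_sq"], "mean_eps", {"y": 1e-5}),
    Op("rsqrt", ["mean_eps"], "inv_rms", {"epsilon": 0.0}),
    Op("mul", ["x", "inv_rms"], "y"),
]
optimized = fold_add_into_rsqrt_epsilon(replace_pow2_with_mul(rms_norm))
print([op.op_type for op in optimized])  # ['mul', 'reduce_mean', 'rsqrt', 'mul']
```

The single-consumer check matters: if another op also read the output of the add, removing it would change that op's input.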
Figure 7. An RMSNorm subgraph in the Llama-3.2-1B-Instruct model. [Left] The original computation graph involves a numerically unstable pow node and a redundant add node. [Right] The resulting computation graph after applying the custom MIL graph pass that replaces the pow node with the equivalent mul node and fuses the add node into its subsequent rsqrt node’s parameter.

Break Down SDPA for Long Sequences

Finally, the model runs on the ANE and generates valid outputs. However, if the model is compiled with a larger maximum sequence length (e.g., 1024 or 2048), you may encounter the familiar compilation error: “Failed to build the model execution plan using a model architecture file”. In this context, a likely cause of the failure is the large memory requirement of the scaled_dot_product_attention operation at long sequence lengths.
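A back-of-the-envelope estimate shows why long sequences are a plausible culprit: the score tensor materialized inside scaled_dot_product_attention grows quadratically with sequence length. The helper below is illustrative and assumes FP16 activations and 32 attention heads (Llama-3.2-1B's configuration). It ignores the softmax output and any intermediate buffers the compiler may allocate, so actual memory use is higher.

```python
def attention_score_bytes(num_heads, q_len, kv_len, bytes_per_element=2):
    """Size of one layer's attention score tensor (num_heads, q_len, kv_len)."""
    return num_heads * q_len * kv_len * bytes_per_element


MIB = 2 ** 20
# Assumed: 32 attention heads, FP16 (2 bytes per element).
for seq_len in (512, 1024, 2048):
    size = attention_score_bytes(32, seq_len, seq_len) / MIB
    print(f"seq_len={seq_len}: {size:.0f} MiB")
# seq_len=512: 16 MiB
# seq_len=1024: 64 MiB
# seq_len=2048: 256 MiB
```

Doubling the sequence length quadruples this buffer, whereas the KV cache itself only doubles.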
Unexpectedly, scaled_dot_product_attention_sliced_q, a MIL graph pass that ships with Core ML Tools but is unused by default, is a hidden gem capable of addressing this issue. Its docstring describes its functionality as follows:
```python
@register_pass(namespace="common")
class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
    """
    Replace the ios18.scaled_dot_product_attention operation with a memory efficient
    implementation of attention calculation based on slicing Q. The benefits are clearly
    visible for higher Q sequence lengths, though.

    Graph pass options:
        - min_seq_length: int
            Only operations working with Q of sequence length greater or equal to this value will be transformed.
        - seq_length_divider: int
            Defines the size of the chunks of Q being processed in SDPA (chunk_size = seq_length / seq_length_divider)
    """
```
Adding this pass (with a few modifications) to the Core ML conversion pipeline should eliminate the compilation error for long sequence lengths.
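The idea behind the pass can be illustrated with a NumPy sketch: softmax is applied row by row, so attention for each block of query rows depends only on that block and the full K and V. Q can therefore be processed in chunks and the results concatenated, with only a (chunk_size, kv_len) score block alive at a time. This illustrates the technique, not the Core ML Tools implementation; the seq_length_divider argument merely mirrors the pass option of the same name.

```python
import numpy as np


def sdpa(q, k, v, mask=None):
    """Reference attention: softmax(q @ k^T / sqrt(d) + mask) @ v."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def sdpa_sliced_q(q, k, v, mask=None, seq_length_divider=4):
    """Same result, but only a (chunk_size, kv_len) score block exists at a time."""
    q_len = q.shape[-2]
    chunk_size = max(1, q_len // seq_length_divider)
    outputs = []
    for start in range(0, q_len, chunk_size):
        stop = min(start + chunk_size, q_len)
        chunk_mask = None if mask is None else mask[..., start:stop, :]
        outputs.append(sdpa(q[..., start:stop, :], k, v, chunk_mask))
    return np.concatenate(outputs, axis=-2)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 64, 16)) for _ in range(3))
causal = np.triu(np.full((64, 64), -np.inf), k=1)  # block attention to future tokens
print(np.allclose(sdpa(q, k, v, causal), sdpa_sliced_q(q, k, v, causal, seq_length_divider=8)))  # True
```

The trade-off is more, smaller matrix multiplications in exchange for a peak score buffer that is seq_length_divider times smaller.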

3. MLX vs. Core ML: When to Use Each

Is it worth the enormous effort to run language models on the ANE with Core ML? While MLX is certainly the more convenient choice for deployment on Apple Silicon, the following benchmark results show that each framework has its own strengths and weaknesses.

Benchmarks

We measured the Time To First Token (TTFT) and Time Per Output Token (TPOT) for several models utilizing both MLX (GPU) and Core ML (ANE) under two different scenarios.
  1. Prefill-heavy scenario (448+64): the model generates 64 tokens given 448 input tokens
  2. Decode-heavy scenario (64+448): the model generates 448 tokens given 64 input tokens
For each scenario, the time taken by the tokenizer and the model is measured separately to understand their contributions to the TTFT since the tokenizer’s encoding overhead is quite significant compared to the model’s prefill latency in most cases. Meanwhile, only the overall TPOT is measured because the tokenizer’s decoding overhead is negligible.
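To show how these measurements combine, here is a small sketch of the latency bookkeeping. The LatencyBreakdown container is hypothetical, and the end-to-end formula (TTFT plus one TPOT per remaining output token) is a common convention rather than something this post specifies.

```python
from dataclasses import dataclass


@dataclass
class LatencyBreakdown:
    """Per-request timings in seconds (hypothetical container)."""
    tokenizer_encode: float  # tokenizer encoding, including the chat template
    prefill: float           # model forward pass over the prompt
    tpot: float              # average time per output token after the first

    @property
    def ttft(self) -> float:
        return self.tokenizer_encode + self.prefill

    def end_to_end(self, num_output_tokens: int) -> float:
        # The first token comes out of prefill; each remaining token costs one decode step.
        return self.ttft + (num_output_tokens - 1) * self.tpot


# (input tokens, output tokens) for the two benchmark scenarios.
SCENARIOS = {"prefill-heavy": (448, 64), "decode-heavy": (64, 448)}
```

With this bookkeeping, the TPOT term is multiplied by 63 in the prefill-heavy scenario and by 447 in the decode-heavy one, which is why TTFT dominates the former and TPOT the latter.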
Regarding the Core ML models, input sequences are padded to the maximum sequence length of 512 and the model weights are quantized to INT4 precision with per-channel granularity, while activations are maintained in FP16. Similarly, for the MLX models, the weights are quantized to INT4 precision with per-group granularity, using a group size of 64, whereas activations are preserved in either FP16 or BF16, depending on the models’ configurations.
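The difference between the two granularities can be seen in a small NumPy sketch of symmetric INT4 quantization: per-channel keeps one scale per output row, whereas per-group keeps one scale for every 64 consecutive weights in a row, so a single outlier inflates the scale of fewer weights. This is a simplified illustration, not either framework's code; MLX's per-group scheme is affine (it also stores a per-group bias), and Core ML Tools supports both symmetric and asymmetric linear modes.

```python
import numpy as np


def quantize_int4_symmetric(w, group_size=None):
    """Symmetric INT4 quantization of a (out_features, in_features) weight matrix.

    group_size=None -> per-channel: one scale per output row.
    group_size=g    -> per-group: one scale per g consecutive weights of a row.
    """
    out_features, in_features = w.shape
    g = in_features if group_size is None else group_size
    if in_features % g != 0:
        raise ValueError("in_features must be divisible by group_size")
    groups = w.reshape(out_features, in_features // g, g)
    scale = np.abs(groups).max(axis=-1, keepdims=True) / 7.0  # map max |w| to 7
    scale = np.where(scale == 0, 1.0, scale)                  # avoid division by zero
    q = np.clip(np.round(groups / scale), -8, 7).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    return (q * scale).reshape(q.shape[0], -1)


rng = np.random.default_rng(0)
w = rng.standard_normal((16, 256)).astype(np.float32)
q_c, s_c = quantize_int4_symmetric(w)                 # per-channel
q_g, s_g = quantize_int4_symmetric(w, group_size=64)  # per-group, group size 64
print(s_c.shape, s_g.shape)  # (16, 1, 1) (16, 4, 1)
```

Each per-group scale is at most the corresponding per-channel scale, so the worst-case rounding error per weight (half a scale step) can only shrink, at the cost of storing more scales.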
In all cases, the tokenizers apply the chat templates to better simulate realistic deployment settings. All measurements were taken on an iPhone 15 Pro running iOS 18.6.1.

Qwen3-0.6B

Figure 8. Comparison of TTFT and TPOT for Qwen3-0.6B between MLX (GPU) and Core ML (ANE).

Llama-3.2-1B-Instruct

Figure 9. Comparison of TTFT and TPOT for Llama-3.2-1B-Instruct between MLX (GPU) and Core ML (ANE).

Key Takeaways

In prefill-heavy scenarios, Core ML on the ANE can substantially improve TTFT, with minimal drawbacks in decode-heavy scenarios. However, MLX on the GPU consistently achieves a lower TPOT in both prefill-heavy and decode-heavy scenarios. It is also worth noting that tokenizer encoding latency varies across models, which is another factor in overall performance.

4. Yetter Inference Engine - Disaggregated Inference for Apple Silicon

Comparing the performance of MLX and Core ML reveals a clear opportunity for improvement. MLX on the GPU excels at decoding, while Core ML on the ANE is better at prefill. This suggests a simple idea: why not use both? This is the core concept of our Yetter Inference Engine for Apple Silicon. We disaggregate the inference process, assigning each stage to the hardware best suited for it.
However, building a foundation that integrates Core ML and MLX comes at a cost. Although several popular open-source projects run LLMs on Apple Silicon, most focus either on the CPU and GPU (e.g., llama.cpp) or solely on the ANE (e.g., Anemll). Therefore, in addition to pure ANE inference with Core ML and pure GPU inference with MLX, our Yetter Inference Engine supports disaggregated inference that uses both the ANE and the GPU. This method maximizes LLM performance on Apple Silicon by assigning the prefill stage to Core ML, which runs it on the ANE to exploit its high throughput, and then handing the decode stage to MLX, which runs the iterative generation of new tokens on the GPU to take advantage of its fast decoding.
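The control flow of disaggregated inference can be sketched in a few lines. Everything below is hypothetical: ToyStatelessPrefill and ToyDecoder are NumPy stand-ins for a stateless Core ML prefill function and an MLX decode loop, the toy next-token rules replace a real model, and the hand-off is a plain array assignment rather than whatever conversion the Yetter Inference Engine actually performs. The point is the structure: the prefill function returns the KV cache as ordinary outputs, which is why it is packaged as stateless, and the decoder takes ownership of that cache for the token-by-token loop.

```python
import numpy as np

VOCAB = 10  # toy vocabulary size


class ToyStatelessPrefill:
    """Stand-in for a stateless Core ML prefill function (ANE): returns the
    logits and the KV cache as ordinary outputs. Hypothetical toy model."""

    def __init__(self, num_layers=2, num_kv_heads=2, head_dim=4):
        self.cache_dims = (num_layers, num_kv_heads, head_dim)

    def __call__(self, input_ids):
        layers, heads, dim = self.cache_dims
        keys = np.zeros((layers, heads, len(input_ids), dim), dtype=np.float16)
        values = np.zeros_like(keys)
        logits = np.zeros((len(input_ids), VOCAB))
        logits[-1, sum(input_ids) % VOCAB] = 1.0  # toy rule for the first new token
        return logits, keys, values


class ToyDecoder:
    """Stand-in for an MLX decode loop on the GPU. Hypothetical toy model."""

    def load_cache(self, keys, values):
        self.keys, self.values = keys, values  # a real system would copy to a GPU cache

    def step(self, token):
        # Append this token's (toy, all-zero) KV entry along the sequence axis.
        new = np.zeros(self.keys.shape[:2] + (1,) + self.keys.shape[3:], dtype=self.keys.dtype)
        self.keys = np.concatenate([self.keys, new], axis=2)
        self.values = np.concatenate([self.values, new], axis=2)
        logits = np.zeros(VOCAB)
        logits[(token + 1) % VOCAB] = 1.0  # toy rule for the next token
        return logits


def disaggregated_generate(prefill, decoder, input_ids, max_new_tokens):
    logits, keys, values = prefill(input_ids)  # prefill stage (ANE via Core ML)
    decoder.load_cache(keys, values)           # hand the KV cache over
    token = int(np.argmax(logits[-1]))
    generated = [token]
    for _ in range(max_new_tokens - 1):        # decode stage (GPU via MLX)
        token = int(np.argmax(decoder.step(token)))
        generated.append(token)
    return generated


decoder = ToyDecoder()
print(disaggregated_generate(ToyStatelessPrefill(), decoder, [1, 2, 3], max_new_tokens=4))  # [6, 7, 8, 9]
```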
To support various language models with minimal effort, the Yetter Inference Engine features a streamlined conversion tool that can transform any Hugging Face language model into a single, multifunction Core ML package. The conversion tool can package the model as either stateful (for pure ANE inference) or stateless (for disaggregated prefill, where the KV cache tensors are treated as outputs).

End-to-end Latency Comparisons

To evaluate our approach against the existing ones, we measured end-to-end latencies for various models with MLX, Core ML, and our Yetter Inference Engine (i.e., disaggregated inference) under both prefill-heavy and decode-heavy scenarios. All other settings are the same as in the previous experiments.

Qwen3-0.6B

Figure 10. Comparison of end-to-end latency for Qwen3-0.6B across MLX, Core ML, and Yetter.

Llama-3.2-1B-Instruct

Figure 11. Comparison of end-to-end latency for Llama-3.2-1B-Instruct across MLX, Core ML, and Yetter.

EXAONE-4.0-1.2B

Figure 12. Comparison of end-to-end latency for EXAONE-4.0-1.2B across MLX, Core ML, and Yetter.

HyperCLOVAX-SEED-Text-Instruct-1.5B

Figure 13. Comparison of end-to-end latency for HyperCLOVAX-SEED-Text-Instruct-1.5B across MLX, Core ML, and Yetter.

kanana-1.5-2.1b-instruct-2505

Figure 14. Comparison of end-to-end latency for kanana-1.5-2.1b-instruct-2505 across MLX, Core ML, and Yetter.

Key Takeaways

As shown above, the Yetter Inference Engine supports various models on Apple Silicon with consistent performance profiles. Compared to MLX, it significantly reduces prefill latency; its prefill latency is only slightly higher than Core ML’s because Yetter uses stateless prefill models, whereas Core ML can use stateful ones. Its decode latencies are nearly identical to MLX's. As a result, Yetter consistently outperforms both MLX and Core ML in prefill-heavy scenarios. In decode-heavy scenarios, Yetter performs similarly to MLX while notably surpassing Core ML.

5. Future Works

The success of this disaggregated inference approach opens up exciting possibilities for future work. We are exploring ways to extend it by using different levels of numerical precision for the prefill and decode stages, which could provide further performance gains. Another direction is a system that dynamically routes workloads between the GPU and the ANE based on the specific task and device conditions. We also plan to expand support for more model architectures and intend to open-source the engine to encourage community collaboration and feedback.

 

SqueezeBits