Ethos-U and Beyond: How ExecuTorch 1.0 powers AI at the edge

Per Åstrand
October 22, 2025
8 minute read time.

This blog post is published on behalf of Per Åstrand and Fredrik Knutsson.


Introduction: A new era for embedded AI

AI is getting leaner. No longer confined to the cloud or powerful smartphones, the next generation of intelligence is moving into the tiniest devices: smart sensors, wearables, and industrial systems that run on milliwatts and kilobytes. These environments are unforgiving because every cycle and every byte counts. Bringing modern AI into this space has long required trade-offs between accuracy, efficiency, and developer productivity.

With the General Availability (GA) release of ExecuTorch 1.0, those trade-offs are beginning to disappear. As part of the PyTorch ecosystem, ExecuTorch bridges the gap between innovation and embedded deployment, empowering developers to run state-of-the-art models on Arm-based edge devices, from power-efficient microcontrollers paired with Arm Ethos-U NPUs to high-performance industrial solutions built on Arm CPUs.

ExecuTorch: Portable, performant, productive

ExecuTorch brings PyTorch's strengths directly to the edge, built on three guiding principles:

  • Portable: A unified execution stack that targets Arm CPUs, GPUs, and NPUs without rewriting models.
  • Performant: Optimized runtimes and backends ensure models fit and run within the strict limits of embedded devices.
  • Productive: Developers stay within the familiar PyTorch flow, with native lowering and quantization tooling that makes deployment straightforward.

With ExecuTorch 1.0, deploying AI at the edge is not just possible; it is PyTorch end-to-end. No new frameworks, no conversion, just a direct, optimized path from research to production.

The Arm advantage: Seamless edge IP support

From PyTorch to the edge: The whole flow

ExecuTorch streamlines the journey from PyTorch models to efficient embedded execution. As shown in the diagram below, models can be exported and lowered through different backend delegates that map onto the full spectrum of Arm edge IPs, from NPUs to Cortex-A and Cortex-M CPUs. This unified stack means developers can begin in PyTorch and seamlessly deploy across diverse Arm-based edge devices with consistent tooling, predictable performance, and a clear path from research to production.


Figure 1. How Arm's edge IPs are supported via ExecuTorch backend delegation.

TOSA: A unified foundation for Arm NPUs

The TOSA (Tensor Operator Set Architecture) 1.0 specification and tooling, released earlier this year, establish a consistent foundation for deploying AI workloads across Arm NPUs and neural technology. Within ExecuTorch, the TOSA backend ensures predictable behavior by lowering the majority of edge operators (int8 and float32) into a common, portable form. Any operators not yet covered fall back to CPU execution through reference kernels. This unified path makes TOSA the backbone for scalable embedded AI, enabling seamless acceleration across Arm’s NPU family.
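To make the fallback behavior concrete, here is a minimal, hypothetical sketch of how a partitioner might split a linear sequence of operators into segments delegated to a TOSA-based backend and segments that stay on the CPU. The operator names and the SUPPORTED_OPS set are illustrative assumptions, not the actual ExecuTorch operator support list, and real partitioners work on graphs rather than flat lists.

```python
from itertools import groupby

# Hypothetical set of operators the TOSA lowering can handle (illustrative only).
SUPPORTED_OPS = {"conv2d", "relu", "add", "linear", "softmax"}


def partition(ops: list[str]) -> list[tuple[str, list[str]]]:
    """Group consecutive operators by whether they can be delegated.

    Returns (target, ops) pairs, where target is "TOSA" for delegated
    segments and "CPU" for segments that fall back to reference kernels.
    """
    segments = []
    for supported, group in groupby(ops, key=lambda op: op in SUPPORTED_OPS):
        segments.append(("TOSA" if supported else "CPU", list(group)))
    return segments


if __name__ == "__main__":
    model_ops = ["conv2d", "relu", "custom_norm", "linear", "softmax"]
    for target, segment in partition(model_ops):
        print(f"{target}: {segment}")
```

Keeping delegated segments as large as possible generally matters in practice, since each switch between the accelerator and the CPU adds data movement and scheduling overhead.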

Ethos-U: Production-ready AI acceleration

ExecuTorch 1.0 delivers production-quality support for the Ethos-U family of NPUs, designed specifically for ultra-low-power AI acceleration. Key highlights include:

  • ~80% operator coverage for edge AI workloads.
  • Support for int8, with int16 experimental.
  • 100+ models from popular sources such as TorchVision, TorchAudio, and HuggingFace running end-to-end on Ethos-U.
  • Performance on par with other AI frameworks across representative edge workloads, validating the efficiency of the TOSA legalization and execution flow.

This makes Ethos-U the most complete path today for running advanced PyTorch models at the microcontroller level. For example, transformer-based architectures like the Conformer model can now run on Ethos-U85-based edge devices, something unthinkable just a few years ago.
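To illustrate what int8 (and experimental int16) support means numerically, here is a minimal sketch of symmetric per-channel weight quantization, the kind of scheme requested by get_symmetric_quantization_config(is_per_channel=True) in the developer-journey snippet later in this post. The function names and the restricted range [-127, 127] are our own assumptions for illustration; this is not the ExecuTorch quantizer implementation.

```python
def symmetric_per_channel_quantize(weights, num_bits=8):
    """Quantize each output channel (row) of `weights` symmetrically.

    Each row gets its own scale so that its largest absolute value maps to
    the largest representable integer; zero always maps exactly to zero.
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8, 32767 for int16
    qmin = -qmax  # assumed restricted symmetric range, e.g. [-127, 127]
    q_rows, scales = [], []
    for row in weights:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / qmax if max_abs > 0 else 1.0
        q_rows.append([max(qmin, min(qmax, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def dequantize(q_rows, scales):
    """Map integers back to real values using each row's scale."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


weights = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
q8, s8 = symmetric_per_channel_quantize(weights)
q16, s16 = symmetric_per_channel_quantize(weights, num_bits=16)
```

For the same value range, the 16-bit step size is roughly 258 times smaller than the 8-bit one, which is the general motivation for wider integer types, at the cost of twice the storage per value.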

Neural technology with the VGF backend

ExecuTorch 1.0 also introduces support for Arm’s upcoming neural technology, which will feature in Arm GPUs arriving in 2026, through the new VGF backend. This enables ahead-of-time export and execution of neural networks that will power use cases like Neural Super Sampling (NSS), denoising, and ML-driven rendering on future Arm GPUs. While this is covered in detail in a separate Arm Community blog post, the important takeaway is that the same ExecuTorch and TOSA infrastructure supporting Ethos-U today also extends to next-generation neural graphics acceleration. This ensures that developers can start experimenting now and be ready for what is coming next, all in the same Python and PyTorch-based development flow.

Cortex-M + CMSIS-NN: Efficient everywhere

Arm’s Cortex-M CPUs are the backbone of the embedded world, shipping in billions of devices annually. With ExecuTorch integrating CMSIS-NN, even CPU-only inference benefits from optimized kernels. Accelerated support for Cortex-M is already available, with further improvements underway to expand efficiency across the family.
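Integer-only CPU kernels such as those in CMSIS-NN typically avoid floating-point arithmetic at inference time by expressing the rescale factor between an int32 accumulator and the int8 output as a fixed-point multiplier plus a shift. The sketch below illustrates that idea in Python under our own simplifying assumptions: the function names are hypothetical, the multiplier is assumed to be below 1, and the rounding and saturation details are not taken from the CMSIS-NN sources.

```python
import math


def quantize_multiplier(real_multiplier: float) -> tuple[int, int]:
    """Approximate real_multiplier as q31 * 2**shift / 2**31.

    Assumes 0 <= real_multiplier < 1, as is typical for requantization
    (input_scale * weight_scale / output_scale).
    """
    if not 0.0 <= real_multiplier < 1.0:
        raise ValueError("expected a multiplier in [0, 1)")
    if real_multiplier == 0.0:
        return 0, 0
    # real = mantissa * 2**shift with 0.5 <= mantissa < 1
    mantissa, shift = math.frexp(real_multiplier)
    q31 = round(mantissa * (1 << 31))
    if q31 == 1 << 31:  # rounding pushed the mantissa up to 1.0
        q31 //= 2
        shift += 1
    return q31, shift


def requantize(acc: int, q31: int, shift: int, zero_point: int = 0) -> int:
    """Rescale an int32 accumulator to int8 using only integer operations."""
    total_shift = 31 - shift  # divide by 2**31, then multiply by 2**shift
    product = acc * q31
    rounding = 1 << (total_shift - 1)
    # Round to nearest, with ties away from zero, before shifting right.
    if product >= 0:
        result = (product + rounding) >> total_shift
    else:
        result = -((-product + rounding) >> total_shift)
    result += zero_point
    return max(-128, min(127, result))  # saturate to the int8 range


q31, shift = quantize_multiplier(0.0123)
print(requantize(4000, q31, shift))  # 4000 * 0.0123 = 49.2, rounded
```

The multiplier and shift are computed once, ahead of time, so the per-element work on the device is a multiply, an add, and a shift.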

Cortex-A + XNNPACK: Scaling performance with Arm KleidiAI

For higher-performance Linux-based platforms, ExecuTorch 1.0 brings the same seamless flow to Cortex-A CPUs. Using the XNNPACK backend, accelerated with Arm KleidiAI micro-kernels, developers can achieve peak performance for edge workloads. This allows models to scale naturally from microcontrollers up to Cortex-A-based edge compute without changing the workflow.

Evaluate without hardware

Not every developer has hardware on hand, and getting started with edge AI should not depend on waiting for silicon. Across different Arm backends, developers can evaluate and validate their models even before hardware is available:

  • Ethos-U: Through Arm Corstone subsystems, developers can use the Fixed Virtual Platform (FVP) to emulate Ethos-U targets. From there, the path to deployment is straightforward, moving from FVP to FPGA prototyping, and eventually onto silicon implementations built on Arm Corstone.
  • Neural technology–enabled GPUs: For future GPU acceleration, ExecuTorch provides an emulation path that mimics the behavior of the hardware, enabling early development and testing long before first silicon arrives.

This layered approach makes it possible to prototype, validate, and refine AI workloads today, without needing physical hardware in hand. This is well aligned with the productivity goals of ExecuTorch.

Python-native backends: Hackable and extensible

One of the strengths of the Arm integration in ExecuTorch 1.0 is that the entire flow – from model lowering down to the backend compiler – is implemented in Python. This makes the backends transparent and hackable:

  • Developers can easily inspect and modify the lowering path.
  • Missing functionality can be added without diving into opaque C++ or proprietary toolchains.
  • Contributions back to the project are straightforward, empowering the community to expand operator support and optimize performance.

This design choice keeps the developer experience close to PyTorch, while still unlocking the efficiency of Arm hardware. It means that if you need to adapt ExecuTorch for your own model, workflow, or hardware configuration, the tools are right there at your fingertips.

For deployment, ExecuTorch provides a lightweight C++ runtime that is easy to integrate into any application. The runtime is designed for portability and efficiency, delivering predictable performance on resource-constrained devices while keeping the deployment footprint small. Together, the Python development flow and the C++ runtime create a seamless bridge from experimentation to production.


Figure 2. ExecuTorch has an efficient C++-based runtime and a hackable AoT compilation flow.

A developer journey: From PyTorch to Ethos-U

Here is what an Ethos-U–focused workflow looks like with ExecuTorch:

  1. Start in PyTorch with a model from HuggingFace, TorchAudio, or TorchVision.
  2. Quantize and lower directly in Python using ExecuTorch’s native tooling.
  3. Deploy on Ethos-U for milliwatt-scale inference.
  4. Validate on Corstone FVP without hardware or run fallback paths on Cortex-M with CMSIS-NN.

Turning a PyTorch model into a .pte artifact ready for deployment is straightforward. The snippet below captures the representative steps involved.

# Representative post-training quantization (PTQ) flow for Ethos-U.
# Conformer, vocab_size, collate_fn, path_to_checkpoint and example_inputs
# come from the ASR example; dataset_root is a placeholder for a local
# LibriSpeech directory.
import random

import torch
import torchaudio
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Conformer model with the same hyperparameters it was trained with
model = Conformer(num_classes=vocab_size)

dataset = torchaudio.datasets.LIBRISPEECH(dataset_root)
# Pick 100 random samples for calibration
calibration_set = torch.utils.data.Subset(dataset, random.sample(range(len(dataset)), 100))
calibration_loader = torch.utils.data.DataLoader(
    calibration_set, batch_size=1, shuffle=False, collate_fn=collate_fn
)

# Load the checkpoint data for the model weights
checkpoint = torch.load(path_to_checkpoint, weights_only=True)
model.load_state_dict(checkpoint["model"])
model.eval()

exported_program = torch.export.export(model, example_inputs, strict=True)
graph_module = exported_program.module()

# Create the quantizer and use the PT2E flow to insert observers into the model
compile_spec = EthosUCompileSpec("ethos-u85-256")
quantizer = EthosUQuantizer(compile_spec)
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
prepared_graph_module = prepare_pt2e(graph_module, quantizer)

# Post-training quantization calibration: run every calibration batch
# through the prepared model so the observers record value ranges
with torch.no_grad():
    for feats, feat_lens, *_ in calibration_loader:
        prepared_graph_module(feats, feat_lens)

# Quantization parameters are captured and the model is re-exported
quantized_exported_program = torch.export.export(
    convert_pt2e(prepared_graph_module), example_inputs, strict=True
)

# The partitioner delegates the parts it can accelerate to the Ethos-U backend
edge_program_manager = to_edge_transform_and_lower(
    quantized_exported_program,
    partitioner=[EthosUPartitioner(compile_spec)],
)
# Create the artifact representation of the quantized model
executorch_program_manager = edge_program_manager.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)
# And save to disk for deployment on the Ethos-U85 target
with open("conformer_quantized_ethos-u85-256.pte", "wb") as f:
    executorch_program_manager.write_to_file(f)

This end-to-end flow demonstrates how ExecuTorch 1.0 makes Ethos-U deployment production-ready while keeping developers in a familiar and flexible PyTorch environment. For complete examples, see the ExecuTorch documentation and the post-training quantization (PTQ) part of the ASR example.
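During the calibration loop in the snippet above, the observers that prepare_pt2e inserts record the value ranges seen by each activation, and convert_pt2e turns those ranges into quantization parameters. The sketch below is a simplified, hypothetical range observer (the RangeObserver name is ours, not a PyTorch class) that derives a per-tensor affine int8 scale and zero point from plain Python lists; real PyTorch observers operate on tensors and support more strategies.

```python
class RangeObserver:
    """Track the running min/max of values across calibration batches."""

    def __init__(self, qmin: int = -128, qmax: int = 127):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values) -> None:
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self) -> tuple[float, int]:
        # Extend the range to include 0 so that zero is exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin) or 1.0
        zero_point = round(self.qmin - lo / scale)
        zero_point = max(self.qmin, min(self.qmax, zero_point))
        return scale, zero_point


obs = RangeObserver()
for batch in ([0.0, 1.0, 2.0], [-0.55, 0.3]):  # stand-ins for calibration batches
    obs.observe(batch)
scale, zero_point = obs.qparams()
```

This is also why the calibration set should be representative: values outside the observed range are clipped once the model is quantized.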

Looking ahead: The next frontier

ExecuTorch 1.0 is just the beginning. The Arm roadmap includes:

  • Smarter quantization: Dynamic Range Quantization (DRQ) and selective quantization for better accuracy control.
  • Broader CPU acceleration: Expanded CMSIS-NN support to further optimize Cortex-M.
  • Richer examples: New use cases and hands-on tutorials to showcase ExecuTorch in practical deployments.
  • Expanded datatypes: Moving beyond int8, float32, and experimental int16 to the full range of TOSA extensions.
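Dynamic Range Quantization is on the roadmap rather than in ExecuTorch 1.0, so the following is only a generic illustration of the technique, not the planned Arm implementation: weights are quantized to int8 ahead of time, while the activation scale is computed at runtime from each actual input. The helper names are hypothetical.

```python
def quantize_symmetric(values, qmax=127):
    """Symmetric quantization of a list with a single scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return [max(-qmax, min(qmax, round(v / scale))) for v in values], scale


def drq_linear(weight_rows, x):
    """Hypothetical dynamic-range-quantized linear layer computing W @ x."""
    # Weights: quantized once per output row (done inline here for brevity;
    # in a real deployment this happens ahead of time).
    q_weights = [quantize_symmetric(row) for row in weight_rows]
    # Activations: the scale is derived at runtime from this specific input.
    q_x, x_scale = quantize_symmetric(x)
    outputs = []
    for q_row, w_scale in q_weights:
        acc = sum(qw * qx for qw, qx in zip(q_row, q_x))  # integer accumulate
        outputs.append(acc * w_scale * x_scale)  # dequantize the result
    return outputs


W = [[1.0, 2.0], [-0.5, 0.25]]
x = [0.1, -0.3]
y = drq_linear(W, x)  # close to the float result [-0.5, -0.125]
```

Because the activation range is measured per input, no calibration set is needed for activations, at the cost of computing a scale at runtime.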

Together, these advances will make AI on Arm not only more efficient but also more accessible to developers everywhere. ExecuTorch 1.0 already marks a milestone, delivering efficient CPU paths, a seamless PyTorch-to-Arm workflow, and simple prototyping on emulated hardware. The future of AI at the edge is here. Head over to the ExecuTorch site to start deploying your PyTorch models with ExecuTorch today.

Further reading

  • Arm Newsroom blog.
  • ExecuTorch documentation.
  • ASR training example on Ethos-U using ExecuTorch.
  • Learning paths:
    • Visualize Ethos-U NPU performance with ExecuTorch on Arm FVPs.
    • Build an Android chat app with Llama, KleidiAI, ExecuTorch, and XNNPACK.
    • Run Llama 3 on a Raspberry Pi 5 using ExecuTorch.
    • Introduction to TinyML on Arm using PyTorch and ExecuTorch.
  • Arm Community blog post on ExecuTorch support for Arm neural technology.
  • ExecuTorch Ethos-U minimal example notebook.