This blog post shares our experience of running PyTorch Mobile with NNAPI on various mobile devices. I hope it gives developers a sense of how models are executed on mobile devices through PyTorch with NNAPI.
On-device machine learning (ML) enables low latency, better power efficiency, robust security, and new use cases for the end user. Currently, there are several ways to run inference on mobile, with many developers wondering which one they should use.
The Android Neural Networks API (NNAPI) is designed by Google for running computationally intensive ML operations on Android mobile devices. It provides a single set of APIs to benefit from available hardware accelerators, including GPUs, DSPs, and NPUs. At Arm, we fully support this development: it provides the ability to target the wide array of accelerators and IP available in the Arm ecosystem. For example, Arm Cortex-A CPUs and Mali GPUs are available via our inference engine, Arm NN, which translates networks to its internal format and then deploys them efficiently on the underlying IP.
Figure 1. System architecture for Android Neural Networks API from Android Developer blog on Neural Networks
The NNAPI can be accessed directly via an Android C API or through higher-level frameworks such as TensorFlow Lite. Recently, PyTorch Mobile announced a new prototype feature supporting NNAPI that enables developers to use hardware-accelerated inference with the PyTorch framework. This will make it easier than ever for PyTorch developers to benefit from significant performance improvements. (GPU support on Android via Vulkan was announced at the same time, along with runtime binary size reduction via the Lite interpreter.)
Figure 2. PyTorch Mobile supports Android NNAPI (From: https://medium.com/pytorch/pytorch-mobile-now-supports-android-nnapi-e2a2aeb74534)
PyTorch is a very popular ML framework, especially among researchers, thanks to its wide range of supported operators and its ease of use. The PyTorch team has also been working to bridge the gap between research and production. As part of the effort to make the transition from training a model to deploying it in a mobile environment seamless, they introduced PyTorch Mobile. PyTorch Mobile provides an end-to-end workflow that simplifies the path from research to production, while staying entirely within the PyTorch ecosystem.
One essential step before the model can be used on mobile devices is to convert the Python-dependent model to TorchScript format. TorchScript is an intermediate representation of a PyTorch model that can be run in a high-performance environment, such as C++. The TorchScript format includes code, parameters, attributes, and debug information.
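As a minimal sketch of this conversion step (the file name and the use of tracing here are only illustrative; the full, NNAPI-aware script is given in the Appendix), a model can be turned into TorchScript with the standard PyTorch API:

import torch
import torchvision

# Load a pretrained MobileNetV2 from torchvision and switch to inference mode.
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()

# Trace the model with a representative input to produce a TorchScript module.
example_input = torch.zeros(1, 3, 224, 224)
with torch.no_grad():
    traced = torch.jit.trace(model, example_input)

# The saved file is self-contained and can be loaded without the Python source.
traced.save("mobilenetv2-cpu.pt")  # illustrative file name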
PyTorch Mobile used to run only on the CPU, but NNAPI support now makes it easy to take advantage of hardware acceleration. In order to run a PyTorch model with NNAPI support, you need to convert an ordinary TorchScript model into an NNAPI-compatible TorchScript model. The model can then be loaded and run in an application using PyTorch Mobile's Java API or the libtorch C++ API. For applications already using PyTorch Mobile, no code changes are required: developers simply need to replace their TorchScript model with an NNAPI-compatible one.
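Condensed from the full script in the Appendix, the extra NNAPI conversion step looks roughly like the sketch below. It relies on the prototype converter module torch.backends._nnapi.prepare; the quantization, input bundling, and saving details are left to the Appendix.

import torch
import torch.backends._nnapi.prepare
import torchvision

model = torchvision.models.mobilenet_v2(pretrained=True).eval()

# Many NNAPI backends prefer NHWC tensors, so the example input is converted
# to channels_last and flagged for the converter.
input_tensor = torch.zeros(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
input_tensor.nnapi_nhwc = True

# Trace first (NNAPI conversion works on TorchScript), then convert.
with torch.no_grad():
    traced = torch.jit.trace(model, input_tensor)
nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)

# The Appendix script additionally bundles sample inputs and saves the result;
# the saved file simply replaces the ordinary TorchScript model in the app.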
PyTorch provides a tutorial for converting the well-known classification model MobileNetV2 to use the Android NNAPI. We followed this tutorial and executed the models for our experiments. The following figure summarizes the flow of creating ordinary TorchScript models (CPU models) and NNAPI-compatible TorchScript models (NNAPI models).
Figure 3. PyTorch Mobile conversion flow
First, we need to prepare a normal Python-dependent PyTorch model; the sample code uses the MobileNetV2 model from torchvision. We then generate a non-quantized model (Float32) and a quantized model (Int8), and convert both to TorchScript format. We end up with the four models listed in the flow. For more information, please refer to the Appendix of this blog post.
The next step is to benchmark these models. PyTorch also provides a benchmarking program, speed_benchmark_torch, which makes it easy to measure a model's execution speed; the Appendix explains how to build it and run it on a device.
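The numbers in this post were measured on-device with that program. Purely as a rough host-side sanity check (not the benchmark used here, and not representative of mobile performance), a CPU TorchScript model can also be timed directly in Python:

import time
import torch
import torchvision

# Trace a float MobileNetV2 as a stand-in for the CPU TorchScript model.
model = torch.jit.trace(torchvision.models.mobilenet_v2(pretrained=True).eval(),
                        torch.zeros(1, 3, 224, 224))
x = torch.zeros(1, 3, 224, 224)

with torch.no_grad():
    for _ in range(5):              # warm-up runs
        model(x)
    iters = 50
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    elapsed = time.perf_counter() - start

print("{:.2f} ms per iteration".format(elapsed / iters * 1e3))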
The following graph shows the speed increase of the NNAPI models on one mobile device. This result is the average time for 200 runs. As you can see, the models using NNAPI run about 25-30% faster for both Float32 and Int8 compared with the CPU models.
Figure 4. MobileNetV2 computational speed-up in different models from one mobile device
Next, let us look at how the hardware in the mobile device works when running the NNAPI models. Arm Streamline allows you to see CPU and GPU activity inside the device.
The screenshots below show the CPU/GPU activity on another mobile device. The screenshot on the left shows the NNAPI model with Float32 running, and the one on the right shows Int8. You can see that the GPU is used for the Float32 model, while the Int8 model uses a multi-core CPU instead. With the NNAPI, ML frameworks such as PyTorch query a mobile device for available hardware and select the most performant hardware for each operation. NNAPI can fall back to the default CPU implementation from Google for operations that are not supported. The support status of NNAPI varies from one mobile device to another, but mobile manufacturers are expected to support it more and more in the future.
Figure 5. Differences in CPU/GPU activity depending on model precision when running via NNAPI
Let us look at another example. The following shows the same models running on a different mobile device. Here, the Float32 model uses the GPU as before, but the Int8 model shows no CPU or GPU activity at all. This suggests that another hardware accelerator, which is not visible in Streamline, is being used. As you can see, the same model can behave differently on different devices when using the NNAPI.
Figure 6. Different NNAPI behavior on another mobile device
Finally, I would like to share the insights we gained by running these models on various mobile devices. The following table summarizes which hardware is selected when the models are executed on the mobile devices we experimented with (8 in total). This shows that the GPU is used in all mobile devices to accelerate the models with Float32. On the other hand, the hardware used depends on the mobile device for the models with Int8. Strictly speaking, the hardware is selected for each operation, but all operations can be executed on the listed hardware for MobileNetV2.
Table 1. Hardware selected via NNAPI
With this in mind, let us look at the graph below. This graph shows how much faster the NNAPI models are for Float32 and Int8. The speed increase varies from device to device. We can also see that Int8 tends to show higher speed increases, since quantized models are generally more likely to benefit from hardware accelerators. Please note that these mobile devices span several generations, which means there are some differences in their hardware and performance. Still, I would like to emphasize that on many of these mobile devices the models are accelerated by the underlying Arm NN.
Figure 7. Speed-up via NNAPI on various mobile devices
In this blog post, we have seen how PyTorch with support for NNAPI works. Also, we have seen how the hardware used via NNAPI changes depending on the mobile device. Support for NNAPI from both the mobile device side and the ML framework side will progress in the future. This will enable developers to run models without having to consider differences between mobile devices. From a developer's point of view, this means that a single model is all that is needed to bring out the full potential of each mobile device. It makes it easier than ever to deploy PyTorch models to mobile devices with high performance.
On the other hand, converting ML models between frameworks or APIs is not always easy. We will follow the developments and continue to work with our partners to optimize performance on devices that support NNAPI and Arm NN.
I hope this gives you an idea of how ML models work with the NNAPI and you will give this feature a try! Also, you can learn more about underlying Arm NN here: https://developer.arm.com/ip-products/processors/machine-learning/arm-nn.
This appendix explains the detailed steps to prepare PyTorch models with NNAPI support and to benchmark them.
The tutorial provided by PyTorch is currently incompatible with the latest trunk, and the versions specified in the tutorial (torch==1.8.0.dev20201106+cpu, torchvision==0.9.0.dev20201107+cpu) are not available from https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html. Therefore, you need to revert the PyTorch GitHub repo to the commits specified below and build from source until the official release is made public.
Code 1. Install procedure of PyTorch and torchvision
# Setup virtual env and install dependencies
python3 -m venv .venv
source .venv/bin/activate
pip3 install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
export USE_CUDA=0

# Install PyTorch
git clone https://github.com/pytorch/pytorch
cd pytorch
# revert to 20201106 version
git reset --hard c19eb4ad73ebf16c7dc73229729ed95692472f6e
git submodule sync
git submodule update --init --recursive
python3 setup.py install

# Install torchvision
cd ..
git clone https://github.com/pytorch/vision
cd vision
# revert to 20201107 version
git reset --hard 052edcecef3eb0ae9fe9e4b256fa2a488f9f395b
python3 setup.py install
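Before moving on, it can be worth checking that the source build produced the expected versions and that the prototype NNAPI converter (used by the model preparation script below) is importable. This is only a quick sanity check, not part of the official procedure:

import torch
import torchvision

# Versions should correspond to the November 2020 nightly commits checked out above.
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)

# The prototype NNAPI converter used by the model preparation script.
import torch.backends._nnapi.prepare
print("NNAPI converter available:",
      hasattr(torch.backends._nnapi.prepare, "convert_model_to_nnapi"))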
After installing PyTorch and torchvision, the models can be generated by running the model preparation script below. When you run the script, the models are created under the $HOME/mobilenetv2-nnapi directory.
Code 2. Python script for model preparation
#!/usr/bin/env python
import sys
import os
import torch
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
import torch.backends._nnapi.prepare
import torchvision.models.quantization.mobilenet
from pathlib import Path


# This script supports 3 modes of quantization:
# - "none": Fully floating-point model.
# - "core": Quantize the core of the model, but wrap it in a
#   quantizer/dequantizer pair, so the interface uses floating point.
# - "full": Quantize the model, and use quantized tensors
#   for input and output.
#
# "none" maintains maximum accuracy.
# "core" sacrifices some accuracy for performance,
# but maintains the same interface.
# "full" maximizes performance (with the same accuracy as "core"),
# but requires the application to use quantized tensors.
#
# There is a fourth option, not supported by this script,
# where we include the quant/dequant steps as NNAPI operators.
def make_mobilenetv2_nnapi(output_dir_path, quantize_mode):
    quantize_core, quantize_iface = {
        "none": (False, False),
        "core": (True, False),
        "full": (True, True),
    }[quantize_mode]

    model = torchvision.models.quantization.mobilenet.mobilenet_v2(pretrained=True, quantize=quantize_core)
    model.eval()

    # Fuse BatchNorm operators in the floating point model.
    # (Quantized models already have this done.)
    # Remove dropout for this inference-only use case.
    if not quantize_core:
        model.fuse_model()
    assert type(model.classifier[0]) == torch.nn.Dropout
    model.classifier[0] = torch.nn.Identity()

    input_float = torch.zeros(1, 3, 224, 224)
    input_tensor = input_float

    # If we're doing a quantized model, we need to trace only the quantized core.
    # So capture the quantizer and dequantizer, use them to prepare the input,
    # and replace them with identity modules so we can trace without them.
    if quantize_core:
        quantizer = model.quant
        dequantizer = model.dequant
        model.quant = torch.nn.Identity()
        model.dequant = torch.nn.Identity()
        input_tensor = quantizer(input_float)

    # Many NNAPI backends prefer NHWC tensors, so convert our input to channels_last,
    # and set the "nnapi_nhwc" attribute for the converter.
    input_tensor = input_tensor.contiguous(memory_format=torch.channels_last)
    input_tensor.nnapi_nhwc = True

    # Trace the model.  NNAPI conversion only works with TorchScript models,
    # and traced models are more likely to convert successfully than scripted.
    with torch.no_grad():
        traced = torch.jit.trace(model, input_tensor)
    nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)

    # If we're not using a quantized interface, wrap a quant/dequant around the core.
    if quantize_core and not quantize_iface:
        nnapi_model = torch.nn.Sequential(quantizer, nnapi_model, dequantizer)
        model.quant = quantizer
        model.dequant = dequantizer
        # Switch back to float input for benchmarking.
        input_tensor = input_float.contiguous(memory_format=torch.channels_last)

    # Optimize the CPU model to make CPU-vs-NNAPI benchmarks fair.
    model = torch.utils.mobile_optimizer.optimize_for_mobile(torch.jit.script(model))

    # Bundle sample inputs with the models for easier benchmarking.
    # This step is optional.
    class BundleWrapper(torch.nn.Module):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod

        def forward(self, arg):
            return self.mod(arg)

    nnapi_model = torch.jit.script(BundleWrapper(nnapi_model))
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        nnapi_model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])

    # Save both models.
    model.save(output_dir_path / ("mobilenetv2-quant_{}-cpu.pt".format(quantize_mode)))
    nnapi_model.save(output_dir_path / ("mobilenetv2-quant_{}-nnapi.pt".format(quantize_mode)))


if __name__ == "__main__":
    for quantize_mode in ["none", "core", "full"]:
        make_mobilenetv2_nnapi(Path(os.environ["HOME"]) / "mobilenetv2-nnapi", quantize_mode)
Code 3. Run the model preparation script
mkdir ~/mobilenetv2-nnapi
python3 prepare_model.py
The next step is to build a benchmarking program to measure the performance of the model.
Code 4. Build benchmarking program
cd <your-root-pytorch-dir>
rm -rf build_android
ANDROID_NDK=$NDK ANDROID_NATIVE_API_LEVEL=29 BUILD_PYTORCH_MOBILE=1 \
ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON
Once you have done that, you can try to run the built program on your mobile device.
Code 5. Run benchmarking program on a mobile device
adb connect <your-mobile-device>
adb push <pytorch-dir>/build_android/bin/speed_benchmark_torch \
  /data/local/tmp
adb push $HOME/mobilenetv2-nnapi/mobilenetv2-quant_* /data/local/tmp
adb shell /data/local/tmp/speed_benchmark_torch --pthreadpool_size=1 \
  --model=/data/local/tmp/mobilenetv2-quant_full-nnapi.pt \
  --use_bundled_input=0 --warmup=5 --iter=200
If you get a message like the one below, you have succeeded. The program measures and reports the execution time needed to run the model and how many times per second the model can be executed. You can refer to this page for more information about the benchmarking program.
Output 1. Benchmark console output example
Starting benchmark.
Running warmup runs.
Main runs.
Main run finished. Microseconds per iter: 36012.7. Iters per second: 27.768
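The two reported figures are consistent with each other: iterations per second is simply the reciprocal of the time per iteration. Using the example values above:

# 36012.7 microseconds per iteration corresponds to about 27.768 iterations per second.
us_per_iter = 36012.7
iters_per_second = 1_000_000 / us_per_iter
print(round(iters_per_second, 3))  # 27.768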
To profile a mobile device’s internal behavior in more detail, you can use Streamline.
Code 6. Profile with Streamline
# Push gatord to a mobile device
adb push <armds-dir>/streamline/bin/arm64/gatord /data/local/tmp

# Run gatord with --app option for command line programs
adb shell /data/local/tmp/gatord \
  --app /data/local/tmp/speed_benchmark_torch --pthreadpool_size=1 \
  --model=/data/local/tmp/mobilenetv2-quant_full-nnapi.pt \
  --use_bundled_input=0 --warmup=5 --iter=200

# Launch Streamline and follow below:
# https://developer.arm.com/documentation/101813/0702/Application-profiling-on-an-Android-device/Profile-your-application