Currently, if you want to deploy your machine learning model on Arm Ethos-U55 or Arm Ethos-U65, you must use TensorFlow Lite for Microcontrollers. Doing so requires that your model is in TensorFlow Lite format first. If you have trained your model using TensorFlow, then the process to convert to TensorFlow Lite is already well documented.
However, you may have trained your model using PyTorch instead, which does not natively support exporting to TensorFlow Lite format. This guide shows how to convert your trained PyTorch model to TensorFlow Lite via ONNX. During conversion to TensorFlow Lite we also quantize the model, before finally optimizing it with Vela ready for deployment on Arm Ethos-U55 or Arm Ethos-U65.
Note: Although this guide shows how to convert a PyTorch trained model to TensorFlow Lite, we recommend you use TensorFlow for model training if possible. This avoids any translation errors that can occur when converting from PyTorch and you should also see better performance and layer support.
In this guide we will use PyTorch to train a small convolutional neural network, perform optimizations on the graph and then export it to ONNX format. Next, the ONNX format model will be converted to TensorFlow saved model format, before finally being loaded into TensorFlow for quantization and conversion to TensorFlow Lite. Lastly Vela is used to optimize the model ready for deploying on Arm Ethos-U55 or Arm Ethos-U65.
Figure 1. Flow of operations to get your model from PyTorch and ready for Arm Ethos-U55 or Arm Ethos-U65.
The complete code sample, which you can run from the command line, is available to download from: https://github.com/ARM-software/ML-examples/tree/master/pytorch-to-tflite.
Make sure you have Python 3 installed on your machine.
Running the following commands creates a Python environment and installs the libraries necessary to run the code sample.
$ python3 -m venv env
$ source env/bin/activate
$ pip3 install --upgrade pip
$ pip3 install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 -f https://download.pytorch.org/whl/torch_stable.html
$ pip3 install onnx==1.9.0
$ pip3 install tensorflow==2.4.1
$ pip3 install onnx-tf==1.9.0
$ pip3 install ethos-u-vela==3.0.0
1. Use the following code to define a simple convolutional neural network in PyTorch that is going to train on the CIFAR10 dataset:
import torch.nn as nn
import torch.nn.functional as F


# A small convolutional network to test PyTorch to TFLite conversion.
class SimpleNetwork(nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3))
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3))
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3))
        self.bn3 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1))  # Test padding conversion.
        self.bn4 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # An affine operation: y = Wx + b
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=10, kernel_size=(6, 6))  # Feature size is 6*6 here.
        # dim=1 is what PyTorch infers for a 2D input; stating it avoids the implicit-dim warning.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        x = self.conv5(x)
        x = x.view(-1, 10)  # Flatten the 10x1x1 output to one score per class.
        x = self.softmax(x)
        return x
In this example, the model is trained on the CIFAR10 dataset, so this dataset is loaded and a basic training loop is then used to train the model. Refer to the complete code sample to see how this is implemented.
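The `Feature size is 6*6 here` comment in the network definition can be checked by hand. The following is a minimal sketch, assuming the standard 32x32 CIFAR10 image size and the usual output-size formula for convolution and pooling layers without dilation. The helper name `out_size` is illustrative and is not part of the original code sample.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                    # CIFAR10 images are 32x32 (assumed input size)
size = out_size(size, kernel=3)              # conv1: 32 -> 30
size = out_size(size, kernel=3)              # conv2: 30 -> 28
size = out_size(size, kernel=2, stride=2)    # pool1: 28 -> 14
size = out_size(size, kernel=3)              # conv3: 14 -> 12
size = out_size(size, kernel=3, padding=1)   # conv4: 12 -> 12 (padding keeps the size)
size = out_size(size, kernel=2, stride=2)    # pool2: 12 -> 6
feature_size = size
final_size = out_size(feature_size, kernel=6)  # conv5: 6 -> 1

print(feature_size, final_size)
```

Because the final 6x6 convolution covers the whole remaining feature map, it produces a single value per output channel, which is why it can stand in for a fully connected layer.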
2. After the model is trained, put it into evaluation mode, so it is ready for exporting using the following code:
model.eval()
We also test the accuracy of the model at this point so we can compare later after conversion. With only two epochs of training, we obtain around 65% accuracy – much better than random chance as we can see in the following output:
$ Training finished...
$ Accuracy of PyTorch model on test set: 65.97%
3. As we used batch normalization layers in our model, one optimization we can do is to fold or fuse these layers into the preceding convolution operation. Folding or fusing can be done by calling torch.quantization.fuse_modules on a list of layer names in the model that can be fused together, as in the following code:
torch.quantization.fuse_modules(model, [['conv1', 'bn1'], ['conv2', 'bn2'], ['conv3', 'bn3'], ['conv4', 'bn4']], inplace=True)
The batch normalization layers are now fused into their preceding convolution layer. Doing this helps to reduce the amount of unnecessary computation carried out by the network, reduces the number of weights, and helps make conversion smoother.
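To see why this fusion does not change the model output, recall that in evaluation mode batch normalization is a fixed per-channel scale and shift, which can be absorbed into the weights and bias of the preceding layer. The sketch below illustrates the arithmetic with a single scalar weight standing in for a convolution kernel. The function name `fold_batchnorm` and all numeric values are hypothetical and chosen only for illustration; this is not how PyTorch implements the fusion internally.

```python
import math


def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold batch norm parameters into the weight and bias of the preceding op."""
    scale = gamma / math.sqrt(var + eps)
    return weight * scale, (bias - mean) * scale + beta


# Hypothetical single-channel parameters.
w, b = 0.5, 0.1                             # "convolution" weight and bias
gamma, beta, mean, var = 1.2, -0.3, 0.05, 0.8  # batch norm parameters and running statistics

x = 2.0

# Unfused: convolution followed by batch normalization.
conv_out = w * x + b
unfused = gamma * (conv_out - mean) / math.sqrt(var + 1e-5) + beta

# Fused: a single operation with adjusted weight and bias.
w_fused, b_fused = fold_batchnorm(w, b, gamma, beta, mean, var)
fused = w_fused * x + b_fused

print(abs(fused - unfused) < 1e-9)
```

For a real convolution the same scale is applied to every weight belonging to a given output channel.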
4. Next, we export our trained PyTorch model to ONNX format. ONNX is a model exchange format focused on inferencing that acts as an intermediate format between different neural network frameworks. PyTorch natively supports exporting to ONNX format.
We can export our model to ONNX by calling the following function:
torch.onnx.export(model, test_input, "model.onnx", input_names=['input'], output_names=['output'])
To export, we must provide some sample input data so that PyTorch can identify exactly what parts of the graph need to be exported. The export function also allows us to change the names of input and output nodes in the model if we wish.
5. After exporting to ONNX format, we can inspect the model file with Netron and see that it is virtually identical to the original PyTorch model. Note that the identity operations in the PyTorch model are left over from fusing batch normalization layers. In the ONNX model, the input and output nodes now have the new names we assigned when exporting.
Figure 2: PyTorch model on the left and its exported ONNX version on the right
Once the model is in ONNX format, we can use ONNX and the available ONNX converters to load and convert the model to TensorFlow format.
6. Load the ONNX model, prepare it to be converted to TensorFlow, and then save it to file in the TensorFlow saved model format using the following code:
import onnx
import onnx_tf.backend

onnx_model = onnx.load("model.onnx")
tf_rep = onnx_tf.backend.prepare(onnx_model, device='CPU')
tf_rep.export_graph("model_tf")
We choose to set the device to ‘CPU’ to force operations to be in NHWC format which is required by TensorFlow Lite.
7. Now that the model is in TensorFlow saved model format, load it into TensorFlow using the TFLite converter with the following code:
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("model_tf")
Note: The Arm Ethos-U55 and Arm Ethos-U65 are designed to accelerate neural network inferences and support only 8-bit weights with 8-bit or 16-bit activations. Therefore, to take advantage of the Arm Ethos-U55 or Arm Ethos-U65, your model must be quantized from 32-bit floating-point to 8-bit fixed-point format.
8. After loading the converted model with the TFLiteConverter, we perform post-training quantization. The result of this process is a model with both weights and activations fully quantized, so that it can be deployed on Arm Ethos-U55 or Arm Ethos-U65. The following code shows how to perform this quantization:
import numpy as np


def rep_dataset():
    """Generator function to produce representative dataset for post-training quantization."""
    # Use a few samples from the training set.
    # train_loader is the PyTorch DataLoader used during training.
    for _ in range(100):
        img = next(iter(train_loader))[0].numpy()
        yield [img.astype(np.float32)]


converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = rep_dataset
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
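The representative dataset is needed because the converter has to observe the range of values flowing through the model before it can map them to 8-bit integers. The following is a simplified sketch of affine (scale and zero-point) quantization, the general scheme that 8-bit TensorFlow Lite models use. The helper names and the example range are hypothetical, and the real converter applies further rules (for example, symmetric quantization for weights) that are not modeled here.

```python
def choose_params(x_min, x_max, qmin=-128, qmax=127):
    """Pick a scale and zero point so that [x_min, x_max] maps onto [qmin, qmax]."""
    # The range must include 0.0 so that zero is represented exactly.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a float value to an int8 value, clamping values outside the range."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map an int8 value back to an approximate float value."""
    return scale * (q - zero_point)


# Hypothetical activation range, as might be observed on the representative dataset.
scale, zero_point = choose_params(-1.0, 3.0)

q = quantize(1.234, scale, zero_point)
approx = dequantize(q, scale, zero_point)
print(scale, zero_point, q, approx)
```

Values outside the observed range are clamped, which is one reason the representative samples should resemble the data the model will see after deployment.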
9. Save the quantized model to file, using the following code:
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
10. Test the accuracy of the model again because it can change slightly after quantization. We can see that in this case quantizing our model has done little to change the accuracy of our model. This test also confirms to us that the conversion process has not affected the quality of the model either. The following shows the reported accuracy after verifying the model on the test set:
$ Testing quantized TFLite model...
$ Accuracy of quantized TFLite model on test set: 66.02%
If we inspect the resulting TensorFlow Lite file with Netron, we can see that a transpose operator has been added at the start of our graph. PyTorch uses a channel first (NCHW) data layout for its operations, while TensorFlow and TensorFlow Lite primarily use channel last (NHWC). The converter maintains the same input shape as the original PyTorch model so that the same input data can be reused in the converted model without alteration. As a result, the converter adds a transpose or reshape operation so that the converted convolution operations can correctly work with this data.
Figure 3. A transpose operation is added when the number of input channels is greater than one, and a reshape operation is added otherwise.
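The difference between the two layouts, and the reason a single-channel input only needs a reshape, can be illustrated by looking at how elements are ordered in memory. The following pure-Python sketch is for illustration only; the function names are hypothetical and the converter does not use this code.

```python
def offset_nchw(n, c, h, w, C, H, W):
    """Flat memory offset of element (n, c, h, w) in channel-first layout."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, C, H, W):
    """Flat memory offset of element (n, c, h, w) in channel-last layout."""
    return ((n * H + h) * W + w) * C + c


def nchw_to_nhwc(flat, N, C, H, W):
    """Reorder a flat NCHW buffer into NHWC order (what a transpose has to do)."""
    out = [None] * len(flat)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    src = offset_nchw(n, c, h, w, C, H, W)
                    dst = offset_nhwc(n, c, h, w, C, H, W)
                    out[dst] = flat[src]
    return out


# One image with 3 channels of 2x2 pixels: channel 0 holds 0..3, channel 1 holds 4..7, and so on.
three_channel = nchw_to_nhwc(list(range(12)), N=1, C=3, H=2, W=2)
print(three_channel)  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]

# With a single channel the element order is unchanged, so only the shape differs.
one_channel = nchw_to_nhwc(list(range(4)), N=1, C=1, H=2, W=2)
print(one_channel)  # [0, 1, 2, 3]
```

With three channels the data has to be physically reordered, which is real work at inference time. With one channel the buffer is already in the right order, so changing the shape metadata is enough.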
Depending on the shape of the model outputs, the converter can also add a transpose operation at the end of the model. This operation returns the output shape to the original PyTorch model output shape.
Also, depending on your model, you may notice that the converter adds transpose operations before and after certain layers in the TensorFlow Lite model. These additions are also a result of the conversion process, and occur because certain operators still expect NCHW inputs even after conversion.
For example, these additions can happen when using some activation functions such as ReLU6. The resulting TensorFlow Lite graph looks like the following:
Figure 4. ReLU6 operation wrapped in transpose operators after conversion to TensorFlow Lite.
One notable change to the TensorFlow Lite graph is the addition of padding operations before convolution layers. TensorFlow has the concept of 'VALID' and 'SAME' padding, while PyTorch traditionally only allows explicit padding. Newer versions of PyTorch support 'VALID' and 'SAME' padding, but your model cannot currently be exported to ONNX if you use these options. When using explicit padding, the conversion process keeps the padding as a separate operation when converting to TensorFlow, even if 'SAME' padding in the convolution layer would be equivalent.
Figure 5. Explicit padding operation added in before a convolution layer after conversion to TensorFlow Lite.
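To check whether an explicit PyTorch padding is equivalent to TensorFlow's 'SAME' padding, you can compute the amount of padding that 'SAME' would apply. The sketch below follows the formula given in the TensorFlow documentation for one spatial dimension; the function name `same_padding` is hypothetical.

```python
import math


def same_padding(size, kernel, stride=1):
    """Return (pad_before, pad_after) that 'SAME' padding applies along one dimension."""
    out = math.ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    before = total // 2
    return before, total - before


# conv4 in this guide: 12x12 input, 3x3 kernel, stride 1, explicit padding of 1.
conv4_pad = same_padding(12, kernel=3)
print(conv4_pad)  # (1, 1), which matches padding=(1, 1) in the PyTorch model

# With stride 2, 'SAME' padding becomes asymmetric for this input size.
strided_pad = same_padding(12, kernel=3, stride=2)
print(strided_pad)  # (0, 1)
```

For conv4 the explicit padding and 'SAME' padding coincide, so the separate pad operation is redundant in principle. In the strided case 'SAME' pads one side only, which a symmetric PyTorch padding argument cannot express.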
Fortunately, if we deploy the model on Arm Ethos-U55 or Arm Ethos-U65, these added pad operations will be fused again after running the model through Vela.
Note: Depending on the version of TensorFlow used when converting to TensorFlow Lite, you can see that the converter adds transpose operations around these pad operations as well. In this case, no optimization can occur.
Due to current limitations of the conversion process, the converted TensorFlow Lite model will likely be less than optimal. When compared to a model natively trained in TensorFlow and then converted to TensorFlow Lite, it can be slower to perform inference. The main cause of this is the additional transpose operations that the converter can add to the graph when converting between ONNX and TensorFlow.
We can see this more concretely by benchmarking the model converted from PyTorch against the same model architecture trained natively in TensorFlow. One way to do this is using an MPS3 FPGA board loaded with an Arm Ethos-U55 and Arm Cortex-M55 bitfile. You can then profile inference speeds of the models using the Arm ML Embedded Evaluation Kit’s Generic Inference Runner application.
The following table compares the model converted from PyTorch to a natively trained TensorFlow one. The PyTorch model has CPU cycle counts that are an order of magnitude higher, while both models have similar NPU cycle counts. The added transpose operator accounts for this increase in CPU cycle counts, as it has to run on the CPU.
This extra CPU operation also shows up in the additional wall clock inference time for the converted model. As the model used in this guide is very small, a large proportion of the total wall clock time is spent on this transpose operation. For larger and deeper models, we would expect the proportion of time spent on this operation to decrease substantially.
One other limitation is that the ONNX to TensorFlow converter does not focus on making your model fully compatible with TensorFlow Lite. As a result, you may come across cases where a model converts to TensorFlow but fails to convert to TensorFlow Lite.
Unfortunately, there seems to be no quick solution to many of these limitations in the conversion process. Until current tooling improves, they are an unavoidable consequence of converting a PyTorch model to TensorFlow Lite. The best advice to overcome these issues is to train your model natively in TensorFlow and convert to TensorFlow Lite from there.
After quantizing the TensorFlow Lite file, we can run it through the Arm Vela compiler to optimize the model for deployment on Arm Ethos-U55 or Arm Ethos-U65.
This optimization can be done from the command line by running the following command:
$ vela model.tflite
Here we use Vela's default optimization parameters and default Ethos-U configuration, which is enough to see what the model looks like after Vela has optimized it. See https://pypi.org/project/ethos-u-vela/ for more information on the different command line options that you can use with Vela.
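If you want to target a specific Ethos-U configuration rather than the default, you can pass extra options to Vela. The sketch below builds such a command from Python, assuming the `--accelerator-config` and `--output-dir` options and the `ethos-u55-128` configuration name are available in your installed Vela version (check `vela --help` to confirm). The helper `build_vela_command` is hypothetical and only assembles the argument list.

```python
import shlex


def build_vela_command(model_path, accelerator=None, output_dir=None):
    """Assemble a Vela command line; options left as None fall back to Vela's defaults."""
    cmd = ["vela", model_path]
    if accelerator is not None:
        cmd += ["--accelerator-config", accelerator]
    if output_dir is not None:
        cmd += ["--output-dir", output_dir]
    return cmd


cmd = build_vela_command("model.tflite", accelerator="ethos-u55-128", output_dir="output")
print(shlex.join(cmd))

# To run the compiler (requires ethos-u-vela to be installed):
#   import subprocess
#   subprocess.run(cmd, check=True)
```

The accelerator configuration should match the number of MACs in the Ethos-U hardware you are deploying to, since Vela schedules the model for that specific configuration.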
Figure 6. The converted TensorFlow Lite model after being run through the Vela compiler.
In the output folder that is created, we find the Vela optimized model in TensorFlow Lite format. Inspecting the optimized model, we can see that every node, except for the transpose layer, has been compiled into one ethos-u operator. This means that all those layers can be successfully run and accelerated on Arm Ethos-U55 or Arm Ethos-U65 NPU. The transpose layer is instead run on the accompanying Arm Cortex-M CPU.
By following this guide, you should have successfully trained and converted a PyTorch model to TensorFlow format. The model was then quantized in TensorFlow and converted to TensorFlow Lite format, before finally being run through Vela ready to deploy on Arm Ethos-U55 or Arm Ethos-U65.
We have also highlighted some of the limitations of this conversion process, the main one being the additional transpose operations introduced. Additionally, we saw how these additional operations can affect final performance when running on Arm Ethos-U55 or Arm Ethos-U65.
Ultimately, if you have the choice, we highly recommend you train your model natively using the TensorFlow framework and convert to TensorFlow Lite from there. This should result in the best-performing model when you deploy it, and present you with the fewest problems along the way. However, if that is not possible then, as we have shown in this guide, conversion from PyTorch can still produce a usable model.
Try it for yourself with the Jupyter Notebook found here:
PyTorch to TensorFlow Lite Jupyter Notebook: https://github.com/ARM-software/ML-examples/blob/master/pytorch-to-tflite/PyTorch_to_TensorFlow_Lite.ipynb
It is possible to remove the first transpose layer from your converted model by wrapping up your PyTorch model. First create a new model class that acts as the wrapper:
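class NetPermuteWrapper(nn.Module):
    def __init__(self, model_to_wrap):
        super(NetPermuteWrapper, self).__init__()
        self.model_to_wrap = model_to_wrap

    def forward(self, x):
        x = torch.permute(x, (0, 3, 1, 2))  # Permute input from NHWC to NCHW.
        # The source listing is cut off here; passing the permuted input
        # to the wrapped model is the assumed remainder.
        return self.model_to_wrap(x)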
Our model wrapper will expect to receive NHWC format input data which will then get permuted, before feeding it to the rest of our original model. After you have trained your model and are ready to export it to ONNX, create an instance of this new model wrapper by passing in your current PyTorch model. Then, pass this wrapped model object through to PyTorch's export functions, instead of the original model:
wrapped_model = NetPermuteWrapper(model)
torch.onnx.export(wrapped_model, test_input, "model.onnx", input_names=['input'], output_names=['output'])
Remember to also transpose your test input data, as input is now NHWC format instead of NCHW. You can now continue the rest of the steps in the guide as before. After converting and generating your TFLite file, the input will be NHWC format and the transpose operation you had before has disappeared.
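The two permutations involved are easy to mix up, so it can help to check them on shapes alone. The sketch below is pure Python and assumes a CIFAR10-sized test input of shape (1, 3, 32, 32); the helper `permute_shape` is hypothetical and mirrors how a permute reorders dimensions. On a real tensor, the equivalent of the first step would be `test_input.permute(0, 2, 3, 1)`.

```python
def permute_shape(shape, order):
    """Return the shape that results from permuting dimensions in the given order."""
    return tuple(shape[i] for i in order)


nchw_shape = (1, 3, 32, 32)  # Original PyTorch test input: batch, channels, height, width.

# Permutation to apply to the test input before export: NCHW -> NHWC.
nhwc_shape = permute_shape(nchw_shape, (0, 2, 3, 1))
print(nhwc_shape)  # (1, 32, 32, 3)

# Permutation applied inside the wrapper: NHWC -> NCHW, restoring the original layout.
restored_shape = permute_shape(nhwc_shape, (0, 3, 1, 2))
print(restored_shape)  # (1, 3, 32, 32)
```

The order (0, 2, 3, 1) used on the test input is the inverse of the (0, 3, 1, 2) order used in the wrapper, so the wrapped model sees exactly the layout it was trained with.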
Since this blog was originally posted, several other projects have appeared that aim to convert a PyTorch model to TensorFlow Lite format. One such project that seems promising, and overcomes some of the issues mentioned in this blog post, is TinyNeuralNetwork by Alibaba. If the ONNX method described in this blog is not good enough for your needs, then it might be worth giving this tool a try as well.