
Analyzing Machine Learning models on a layer-by-layer basis

George Gekov
October 31, 2022
7 minute read time.

Overview

When you deploy a Machine Learning model, you may want to know how well your neural network uses the capabilities of the hardware during inference. If you target an Arm Ethos-U55 or Ethos-U65 Machine Learning accelerator, you must optimize your model with the Vela compiler. This blog explains how to analyze a neural network on a layer-by-layer basis, and builds on the earlier blog post explaining how to use Vela. Let us define two example deep learning models and analyze both on a layer-by-layer basis.

Model definition

Consider a Machine Learning model that takes as input a 284x284 image with 3 channels. The model has four convolutions followed by a fully connected layer.

import tensorflow as tf
import numpy as np

def define_model(num_chan):
    # Four convolutions followed by a fully connected layer
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                                     input_shape=(284, 284, 3)))
    model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu', strides=2))
    model.add(tf.keras.layers.Conv2D(num_chan, (3, 3), activation='relu', strides=2))
    model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', strides=2))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(10, activation='softmax'))
    model.summary()
    return model

This is a simplified version of a model that could perform image classification; the same per-layer profiling reasoning applies to any other model as well.

We use two variants of the same model: one with 90 channels in the third convolution, and one with 96 channels in that convolution. The Ethos-U operates on quantized input, so we also use the TFLite converter to quantize both networks to int8.

model1 = define_model(90)
model2 = define_model(96)

def rep_dataset():
    # Representative dataset used to calibrate the quantization ranges
    for i in range(50):
        img = np.random.rand(1, 284, 284, 3)
        yield [img.astype(np.float32)]

# Model 1
converter_quant = tf.lite.TFLiteConverter.from_keras_model(model1)
converter_quant.optimizations = [tf.lite.Optimize.DEFAULT]
converter_quant.representative_dataset = rep_dataset
converter_quant.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter_quant.inference_input_type = tf.int8
converter_quant.inference_output_type = tf.int8
tflite_model = converter_quant.convert()
open('Model1.tflite', 'wb').write(tflite_model)

# Model 2
converter_quant2 = tf.lite.TFLiteConverter.from_keras_model(model2)
converter_quant2.optimizations = [tf.lite.Optimize.DEFAULT]
converter_quant2.representative_dataset = rep_dataset
converter_quant2.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter_quant2.inference_input_type = tf.int8
converter_quant2.inference_output_type = tf.int8
tflite_model = converter_quant2.convert()
open('Model2.tflite', 'wb').write(tflite_model)

Below you can see a visualization of Model1 after it has been quantized.
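If you do not have a model viewer to hand, you can also confirm the quantization programmatically. A minimal sketch, assuming the Model1.tflite file written above, loads the model with the TFLite interpreter and prints the input and output tensor types, which should be int8:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='Model1.tflite')
interpreter.allocate_tensors()
# Expect int8 dtypes plus (scale, zero_point) quantization parameters
for detail in interpreter.get_input_details() + interpreter.get_output_details():
    print(detail['name'], detail['dtype'], detail['quantization'])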

Remember that the biggest optimization you can make to your model is to design the network using operators supported by the Vela compiler. That way, the microNPU accelerates the entirety of your workload. When your model maps fully to the Ethos-U, you can use Vela's per-layer estimates to find additional bottlenecks in the workload. Both Model1 and Model2 map fully to the Ethos-U (in other words, there are no fallbacks to the CPU), so let us compile them with per-layer estimation.
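As an aside, Vela can list the TFLite operators it supports and the constraints on each of them. Assuming a recent ethos-u-vela release, running the command below generates a Markdown report of the supported operators:

vela --supported-ops-report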

Compile a network with per-layer estimates

To analyze a network on a per-layer basis, add the --verbose-performance CLI option to your Vela command during compilation. Let us compile Model1 from the previous section:

vela Model1.tflite --accelerator-config=ethos-u65-256 --config vela.ini --memory-mode=Dedicated_Sram --system-config=Ethos_U65_High_End --verbose-performance

With the previous command, Vela maps the operators of Model1.tflite to the block-based MAC engines of an Ethos-U65 configured with 256 multiply-accumulates per clock cycle, and you obtain a per-layer estimate report. Note that Vela does not have a cycle-accurate model of the microNPU under the hood, so the numbers reported by the compiler are estimates and will not exactly match the performance of your workload in silicon. Having said that, you can still use the report to identify the most compute-intensive layers and understand whether your model makes good use of the capabilities of the hardware.

Now, let us examine the per-layer report you obtain: 

Performance for NPU Subgraph main_split_1
 TFLite Operator:     NNG Operator:        SRAM Usage (  Peak%): Op Cycles (Netwrk%) [        NPU    SRAM AC    DRAM AC OnFlash AC OffFlashAC ]: MAC Count (Netwrk% /   Util%):Name:
 CONV_2D              Conv2DBias               326336 (100.00%)    2687386 ( 43.91%) [    1387150     159048    2687386          0          0 ]   68708736 (  9.72% /   9.99%) sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d_3/BiasAdd/ReadVariableOp;sequential/conv2d/Conv2D
 CONV_2D              Conv2DBias               326336 (100.00%)    1489121 ( 24.33%) [    1489121     419904     334507          0          0 ]  361267200 ( 51.09% /  94.77%) sequential/conv2d_1/Relu;sequential/conv2d_1/BiasAdd;sequential/conv2d_1/BiasAdd/ReadVariableOp;sequential/conv2d_1/Conv2D
 CONV_2D              Conv2DBias               155280 ( 47.58%)    1566720 ( 25.60%) [    1230143      28566    1566720          0          0 ]  246810240 ( 34.90% /  61.54%) sequential/conv2d_2/Relu;sequential/conv2d_2/BiasAdd;sequential/conv2d_2/BiasAdd/ReadVariableOp;sequential/conv2d_2/Conv2D
 CONV_2D              Conv2DBias               155280 ( 47.58%)     215668 (  3.52%) [     215668      75888      78916          0          0 ]   29963520 (  4.24% /  54.27%) sequential/conv2d_3/Relu;sequential/conv2d_3/BiasAdd;sequential/conv2d_3/BiasAdd/ReadVariableOp;sequential/conv2d_3/Conv2D
 FULLY_CONNECTED      FullyConnected                0 (  0.00%)     157491 (  2.57%) [      92730          5     157491          0          0 ]     369920 (  0.05% /   0.92%) sequential/dense/MatMul;sequential/dense/BiasAdd

What is the meaning of each column in the report? 

NNG Operator: the neural network graph operator that runs on the hardware
SRAM Usage: the amount of SRAM needed to run the operator
Peak%: SRAM Usage as a percentage of the peak SRAM usage of the whole network
Op Cycles: the estimated number of cycles needed to run the operator
Netwrk% (after Op Cycles): Op Cycles as a percentage of the total estimated cycle count of the network
NPU: the estimated number of NPU cycles
SRAM AC: the estimated number of SRAM access cycles
DRAM AC: the estimated number of DRAM access cycles
MAC Count: the number of multiply-accumulates needed to perform the operation
Netwrk% (after MAC Count): MAC Count as a percentage of the total MAC count of the network
Util%: MAC Count as a percentage of the MACs the hardware could have performed in the same cycles at full capability
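To make the relationship between these columns concrete, Util% appears to equal MAC Count divided by the number of MACs the configuration could have performed in the reported Op Cycles. A quick sanity check against the second convolution of Model1, assuming the 256 MACs per cycle of the ethos-u65-256 configuration:

macs_per_cycle = 256          # ethos-u65-256 configuration
mac_count = 361_267_200       # MAC Count of the second convolution
op_cycles = 1_489_121         # Op Cycles of the same layer
print(f'{mac_count / (op_cycles * macs_per_cycle):.2%}')   # 94.77%, matching the report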

Analyzing a neural network on a layer-by-layer basis

First, let us focus on the per-layer report for Model1.

The first convolution is estimated to consume 43% of all cycles and has a MAC utilization of nearly 10%. This is the layer expected to consume the most cycles in the whole network. The Ethos-U65 hardware with 256 MAC units works efficiently when the number of output channels is a multiple of 8, as the OFM depth column in the Ethos-U documentation illustrates. Note that other variants of the microNPU work efficiently with different numbers of output channels. For example, an Ethos-U65 with 512 MAC units configured with PARALLEL_MODE = 1 works efficiently when the number of output channels is a multiple of 16. If we were to modify the number of channels in the first convolution, the utilization of the hardware would improve. However, the input is an image with 3 channels, and we cannot change the input data to something else.
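As a cross-check on the report, the MAC Count of the first convolution follows directly from the layer shape: a 3x3 kernel over 3 input channels produces a 282x282x32 output feature map (284 minus 2 for the 'valid' padding):

ofm_h = ofm_w = 284 - 2                        # 3x3 kernel, 'valid' padding, stride 1
ofm_depth = 32
macs = ofm_h * ofm_w * ofm_depth * 3 * 3 * 3   # kernel area times input channels
print(macs)                                    # 68708736, matching the MAC Count column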

The third convolution is estimated to consume 25% of all cycles and has a utilization of 61%. As per the model definition, the third convolution has 90 output channels, while the Ethos-U65-256 hardware works efficiently when the number of channels in a layer is a multiple of 8. In the next section, we change the number of channels in that convolution to 96 and analyze the impact.

Furthermore, note that, as per the Util% column relative to MAC Count, the Fully Connected layer has a low efficiency of less than 1%. That is expected, because Fully Connected is a highly memory-bound operation: each weight is read only once, so, as the table shows, the expected number of DRAM access cycles far exceeds the number of NPU cycles. This is one reason modern neural network architectures rarely rely on Fully Connected layers. The Netwrk% column (relative to Op Cycles) shows that the Fully Connected layer is estimated to be responsible for only about 2.5% of all cycles of the inference. So even if we somehow optimized the Fully Connected operation to make better use of the hardware, the net improvement would be minimal.
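The arithmetic bears this out: a Fully Connected layer performs exactly one multiply-accumulate per weight, so each weight fetched from DRAM does a single MAC of work. Using the 34x34x32 output of the fourth convolution implied by the model definition ('valid' padding and the given strides):

flat_inputs = 34 * 34 * 32    # output of the fourth convolution, flattened
fc_macs = flat_inputs * 10    # one MAC per weight of the 10-unit Dense layer
print(fc_macs)                # 369920, matching the MAC Count column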

Note also that Model1 contained a RESHAPE layer before compilation, but the per-layer table does not contain a RESHAPE entry. This behavior is expected: Vela optimizes away memory-only operations such as RESHAPE by modifying the shape of the IFM/OFM of the preceding or following tensors.

Now, let us move on to the second model, which we compile with the following command.

vela Model2.tflite --accelerator-config=ethos-u65-256 --config vela.ini --memory-mode=Dedicated_Sram --system-config=Ethos_U65_High_End --verbose-performance

And obtain the per-layer report: 

Performance for NPU Subgraph main_split_1
 TFLite Operator:     NNG Operator:        SRAM Usage (  Peak%): Op Cycles (Netwrk%) [        NPU    SRAM AC    DRAM AC OnFlash AC OffFlashAC ]: MAC Count (Netwrk% /   Util%):Name:
 CONV_2D              Conv2DBias               326352 (100.00%)    2687386 ( 43.91%) [    1387150     159048    2687386          0          0 ]   68708736 (  9.47% /   9.99%) sequential_1/conv2d_4/Relu;sequential_1/conv2d_4/BiasAdd;sequential_1/conv2d_4/BiasAdd/ReadVariableOp;sequential_1/conv2d_4/Conv2D
 CONV_2D              Conv2DBias               326352 (100.00%)    1489121 ( 24.33%) [    1489121     419904     334507          0          0 ]  361267200 ( 49.79% /  94.77%) sequential_1/conv2d_5/Relu;sequential_1/conv2d_5/BiasAdd;sequential_1/conv2d_5/BiasAdd/ReadVariableOp;sequential_1/conv2d_5/Conv2D
 CONV_2D              Conv2DBias               150496 ( 46.11%)    1566720 ( 25.60%) [    1230984      28566    1566720          0          0 ]  263264256 ( 36.28% /  65.64%) sequential_1/conv2d_6/Relu;sequential_1/conv2d_6/BiasAdd;sequential_1/conv2d_6/BiasAdd/ReadVariableOp;sequential_1/conv2d_6/Conv2D
 CONV_2D              Conv2DBias               150496 ( 46.11%)     216103 (  3.53%) [     216103      75888      78916          0          0 ]   31961088 (  4.40% /  57.77%) sequential_1/conv2d_7/Relu;sequential_1/conv2d_7/BiasAdd;sequential_1/conv2d_4/BiasAdd/ReadVariableOp;sequential_1/conv2d_7/Conv2D
 FULLY_CONNECTED      FullyConnected                0 (  0.00%)     157478 (  2.57%) [      92730          5     157478          0          0 ]     369920 (  0.05% /   0.92%) sequential_1/dense_1/MatMul;sequential_1/dense_1/BiasAdd

When using 96 channels, we improve the efficiency of the third convolution to 65%. What is more, the estimated cycle count for the third convolution remains unchanged despite the increased number of channels. Why is that? The convolution dimensions are rounded up to fit the 256 MACs of the microNPU, and with both 90 and 96 channels the dimensions round up to the same closest multiple of 8. In other words, we are able to process more data on the Ethos-U65-256 hardware with minimal penalty on the total cycle count. Equally, as the output of the third convolution feeds into the last convolution, we also improve the efficiency of the fourth convolution. More generally, you can use the per-layer performance table to check the optimal shapes of the Feature Maps of your model.
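A back-of-the-envelope check, assuming the channel dimension is padded up to the next multiple of 8:

import math

for ch in (90, 96):
    print(ch, '->', math.ceil(ch / 8) * 8)   # both pad up to 96 channels
# Model2 therefore performs 96/90, about 6.7%, more useful MACs in roughly the same cycles:
print(263_264_256 / 246_810_240)             # 1.0666..., exactly 96/90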

Conclusion

In this blog, we analyzed two deep learning models on a layer-by-layer basis. We identified the layer that consumes the most cycles, and made a change to the model to improve the overall use of the hardware. Vela's per-layer report is a powerful tool for analyzing the MAC engine utilization of every layer of a neural network. Once your model maps fully to the Ethos-U, you can build a fine-grained understanding of the bottlenecks in your workload and make changes tailored to your model.
