Deep Neural Networks (DNNs) are typically trained using 32-bit IEEE single-precision floating point to represent model weights and activation tensors. This compute budget is usually acceptable during training, which can take advantage of CPUs and especially GPUs with large compute capabilities. However, these models often need to run on systems with limited memory and compute capability, such as embedded devices. Running DNN inference on resource-constrained embedded systems using a 32-bit representation is therefore not always practical because of the massive number of multiply-accumulate operations (MACs) required.
TensorFlow Lite is an open-source deep learning framework enabling on-device machine learning. It allows you to convert a pre-trained TensorFlow model into a TensorFlow Lite flat buffer file (.tflite), which is optimized for speed and storage. During conversion, optimization techniques can be applied to accelerate inference and reduce model size.
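As a minimal sketch of this conversion step (assuming a trained model has already been exported to a SavedModel directory, referred to here as saved_model_dir), a float TensorFlow model can be converted and written to a .tflite file as follows:

import tensorflow as tf

# Convert a trained SavedModel into a TensorFlow Lite flat buffer.
# 'saved_model_dir' is a placeholder for the path to your own trained model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Write the flat buffer to disk for on-device deployment.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

The optimization techniques described in the rest of this section are applied by configuring this converter before calling convert().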
The TensorFlow Model Optimization Toolkit provides optimization techniques such as quantization, pruning, and clustering that are compatible with TensorFlow Lite. Depending on the optimization technique, the complexity and size of the model can be reduced, which results in lower memory usage and smaller storage and download sizes.
Optimization is also necessary for some hardware accelerators, such as the Arm Ethos-U microNPU, which performs calculations in 8-bit integer precision. To deploy a model on an Arm Ethos-U microNPU, the model must therefore be optimized first.
The Ethos-U55 is a first-generation microNPU designed to accelerate neural network inference in a small area with low power consumption. Paired with a Cortex-M processor, the Ethos-U microNPU delivers up to a 480x ML performance uplift compared to previous Cortex-M generations. Its configurability allows developers to target a wide range of AI applications.
Quantization reduces the precision of the model's parameter values (the weights), which by default are 32-bit floating-point numbers. This results in a smaller model size and faster computation, often with minimal or no loss in accuracy. However, depending on the model architecture and the quantization method, the impact on accuracy may be significant. Therefore, the trade-off between model accuracy and size should be considered during application development.
Quantization can take place during model training or after model training. Based on when it is applied, quantization is classified into two principal techniques:
You can quantize a trained 32-bit float TensorFlow model during conversion to a TensorFlow Lite model using post-training integer quantization. Post-training integer quantization not only increases inference speed on microcontrollers but is also compatible with fixed-point hardware accelerators such as Arm Ethos-U microNPUs. It converts the model's parameters from 32-bit floating-point values to the nearest 8-bit fixed-point values, typically with reasonable quantized model accuracy and a 3-4x reduction in model size.
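To illustrate the idea behind this float-to-int8 mapping, the following sketch applies a simple affine (scale and zero-point) quantization to a small weight array. The exact scheme used by the TensorFlow Lite converter differs in detail (for example, weights are quantized symmetrically, often per channel), so treat this purely as an illustration:

import numpy as np

# Illustrative affine quantization: map float values onto the int8 grid
# using a scale and zero-point derived from the observed value range.
weights = np.array([-0.93, 0.0, 0.41, 1.27], dtype=np.float32)

scale = (weights.max() - weights.min()) / 255.0           # 256 representable levels
zero_point = int(np.round(-128 - weights.min() / scale))  # int8 value that maps back to 0.0

quantized = np.clip(np.round(weights / scale) + zero_point, -128, 127).astype(np.int8)
dequantized = (quantized.astype(np.float32) - zero_point) * scale

print(quantized)    # [-128  -20   28  127]
print(dequantized)  # close to the original weights, within one quantization step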
There are two modes of post-training integer quantization:
The integer-only mode converts weights, variables, and input and output tensors to 8-bit integers. The 16x8 mode instead uses int16 activations with int8 weights, which can give better accuracy at the expense of slower computation while keeping nearly the same model size as int8. Some examples of models that benefit from the 16x8 mode of post-training quantization include:
How does full integer quantization work?
For full integer quantization, the weights of the model are first quantized to 8-bit integer values. Then the variable tensors, such as layer activations, are quantized. To calculate the potential range of values that these tensors can take, a small subset of the training or validation data is required. This representative dataset can be created using the following representative_data_gen() generator function.
def representative_data_gen():
    for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
        yield [input_value]
Model inference is then performed using this representative dataset to calculate the minimum and maximum values of the variable tensors.
Integer with float fallback: To convert float32 activations and model weights to int8, while falling back to float operators for operations that do not have an integer implementation, use the following code snippet:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
tflite_quant_model = converter.convert()
Alternatively, to quantize the model using the 16x8 quantization mode, set the optimizations flag to use the default optimizations and then specify the 16x8 quantization mode in the target specification, as follows:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8]
converter.representative_dataset = representative_data_gen
tflite_quant_model = converter.convert()
As a result, operators that have a quantized 16x8 implementation are converted to their quantized versions, while unsupported operators are kept in floating point.
Integer only: To enforce full integer quantization for all operations, including the input and output tensors, and raise an error for any operation that cannot be quantized:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

# Set the input and output tensors to int8 (APIs added in r2.3)
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model_quant = converter.convert()
With these post-training optimization methods, it is important to ensure that the model design is as efficient and optimized as possible while maintaining zero or minimal accuracy loss.
After post-training integer quantization you will have a fully quantized TensorFlow Lite model that is compatible with integer-only devices such as the Ethos-U. However, to deploy your model on a system using an Ethos-U NPU, the quantized TensorFlow Lite file must be compiled with Vela for further optimization.
You can learn how to run your model through the Vela optimizer on the Vela page.
The accuracy of the model can drop when moving from 32-bit float to lower precision (for example, 8-bit) using post-training quantization. For minimal or even zero accuracy loss in critical applications, such as security systems, the quantization-aware training technique may be required.
Quantization-aware training simulates inference-time quantization errors during training, so the model learns parameters that are robust to that loss. The quantization error is the error introduced by quantizing the weights and activations to lower precision and then converting them back to 32-bit float. Note that quantization is only simulated in the forward pass to induce the quantization error, while the backward pass remains unchanged and only the floating-point weights are updated.
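As an illustration of this "fake quantization" idea (this is not the tfmot implementation), the forward pass can be thought of as rounding values to the int8 grid and immediately converting them back to float, so the network trains in the presence of the resulting error. The scale value below is an arbitrary example; in practice the scales are derived from the observed value ranges.

import numpy as np

def fake_quantize(x, scale, zero_point=0):
    # Quantize to the int8 grid, then dequantize straight back to float.
    # The difference between x and the result is the simulated quantization error.
    q = np.clip(np.round(x / scale) + zero_point, -128, 127)
    return (q - zero_point) * scale

w = np.array([0.013, -0.402, 0.255])
print(fake_quantize(w, scale=0.01))       # [ 0.01 -0.4   0.26]
print(fake_quantize(w, scale=0.01) - w)   # the induced quantization error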
Define a model and apply quantization-aware training to the trained model:
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()
You can train the model for an epoch with quantization-aware training on only a subset of the training data, and then evaluate it against the baseline model:

train_input_subset = train_input[0:1000]
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_input_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)

_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)
_, q_aware_model_accuracy = q_aware_model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Finally, create a fully quantized model with int8 weights and int8 activations for TensorFlow Lite:
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
After quantization-aware training you will have a fully quantized TensorFlow Lite model that is compatible with integer-only devices such as the Ethos-U. However, to deploy your model on a system using an Ethos-U NPU, the quantized TensorFlow Lite file must be compiled with Vela for further optimization.
Which quantization method to choose for your ML model?
Choosing a quantization method for your application is a trade-off between model performance and size. For example, with the 16x8 quantization mode you sacrifice some inference speed in exchange for better accuracy compared to full integer quantization, while the model size stays roughly the same. If inference speed is the main concern, full int8 quantization is the better choice.
Your choice of quantization method also depends on your ML application and the availability of training data. For example, for a critical application such as a security system, where zero or minimal accuracy loss is required, quantization-aware training is beneficial.
After training a deep learning model, it is common to find that the model is over-parameterized, with many parameters whose values are zero or close to zero. Setting a certain percentage of these values to zero during training, so that only a subset of the trained parameters is used, introduces sparsity into the model. The sparse model preserves the high-dimensional features of the original network after those parameters are pruned.
Training the model parameters with the network pruning technique helps achieve high compression rates with minimal accuracy loss and enables the model to run on embedded devices with only a few kilobytes of memory. Model sparsity can also further accelerate inference on the Arm Ethos-U NPU.
Note that, similar to quantization, there is a trade-off between the size of the model and the accuracy of the optimized model.
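As a minimal NumPy illustration of what magnitude-based pruning does (this is not the tfmot implementation, which applies a sparsity schedule during training), the smallest-magnitude weights are simply set to zero:

import numpy as np

# Zero out the 50% of weights with the smallest magnitude.
weights = np.array([0.01, -0.8, 0.003, 0.5, -0.02, 0.9])
threshold = np.percentile(np.abs(weights), 50)

pruned = np.where(np.abs(weights) < threshold, 0.0, weights)
print(pruned)                  # [ 0.  -0.8  0.   0.5  0.   0.9]
print(np.mean(pruned == 0.0))  # 0.5 -> 50% sparsity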
Define a model to fine-tune a pre-trained model with pruning, starting at 50% sparsity and ending at 80% sparsity:
import numpy as np
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

model_for_pruning.summary()
Train and evaluate the model against the baseline:
import tempfile

logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs,
                      validation_split=validation_split,
                      callbacks=callbacks)

_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)
Weight pruning can be combined with quantization to compound the memory-footprint savings of both techniques and speed up inference. Quantization also allows the pruned model to be used with Ethos-U machine learning processors.
# Strip the pruning wrappers before export (assumes model_for_pruning from the previous step).
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
    f.write(quantized_and_pruned_tflite_model)
To deploy your model on a system using an Ethos-U NPU, the quantized TensorFlow Lite file should be compiled with Vela for further optimization.
Another model optimization technique is weight clustering, which was proposed and contributed by the Arm ML team to the TensorFlow Model Optimization Toolkit. Clustering reduces the storage and size of the model, which benefits deployments on resource-constrained embedded systems. With this technique, a fixed number of cluster centroids is first defined for each layer. Next, the weights of each layer are grouped into N clusters and are then replaced by their closest centroid.
The size of the model is therefore reduced by replacing similar weights in each layer with the same value. These values are found by running a clustering algorithm over the weights of a trained model. Depending on the model and the number of chosen clusters, the accuracy of the model can drop after clustering. To reduce the impact on accuracy, you must pass a pre-trained model with acceptable accuracy to the clustering step.
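As a small illustration of this replacement step (independent of the tfmot API shown below), each weight is mapped to its nearest centroid, so only N distinct weight values remain per layer:

import numpy as np

# Illustrative only: replace each weight with its nearest cluster centroid.
weights = np.array([0.12, 0.98, 0.11, 1.02, 0.49, 0.51])
centroids = np.array([0.1, 0.5, 1.0])   # e.g. found by running k-means over the weights

nearest = np.argmin(np.abs(weights[:, None] - centroids[None, :]), axis=1)
clustered = centroids[nearest]

print(clustered)             # [0.1 1.  0.1 1.  0.5 0.5]
print(np.unique(clustered))  # only 3 distinct weight values remain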
Before passing the model to the clustering API, it needs to be fully trained with acceptable accuracy.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
    'number_of_clusters': 3,
    'cluster_centroids_init': CentroidInitialization.LINEAR
}

# Cluster a whole model
clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning clustered model
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy'])

clustered_model.summary()
Cluster sequential and functional models
Tips to get better model accuracy after clustering:
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights)

clustered_model = tf.keras.Sequential([
    Dense(...),
    cluster_weights(Dense(...,
                          kernel_initializer=pretrained_weights,
                          bias_initializer=pretrained_bias),
                    **clustering_params),
    Dense(...)
])
Weight clustering can also be combined with quantization to compound the memory-footprint savings of both techniques and speed up inference. Quantization also allows the clustered model to be used with Ethos-U machine learning processors.
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
tflite_quant_model = converter.convert()

quantized_and_clustered_tflite_file = 'quantized_clustered.tflite'

with open(quantized_and_clustered_tflite_file, 'wb') as f:
    f.write(tflite_quant_model)
However, to deploy your model on a system using an Ethos-U NPU, the quantized TensorFlow Lite file should be compiled with Vela for further optimization.
You can download the complete code sample combining weight clustering with quantization from here: https://github.com/ARM-software/ML-examples/tree/master/ethos-u-microspeech
Collaborative optimization is the process of stacking different optimization techniques to improve inference speed on specialized hardware accelerators such as Ethos-U microNPUs.
This approach balances compression and accuracy for deployment by taking advantage of the accumulated optimization effect. Various combinations of the optimization techniques are possible, such as:
You can therefore apply pruning, clustering, or both, followed by post-training quantization or QAT. However, naively combining these techniques does not preserve the results of the preceding technique, which loses the overall benefit of applying them together.
For example, the sparsity of a pruned model is not preserved after clustering is applied. To address this problem, the following collaborative optimization techniques can be used.
Sparsity preserving clustering example
To apply sparsity preserving clustering, first prune the model using the pruning API. Next, cluster the pruned model using the sparsity-preserving clustering API. Finally, quantize the model with post-training quantization for deployment on the Ethos-U microNPU.
Prune and fine-tune the model to 50% sparsity:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}

callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy'])

pruned_model.summary()

# Fine-tune model
pruned_model.fit(
    train_data, train_labels,
    epochs=3,
    validation_split=0.1,
    callbacks=callbacks)
To check that the model kernels were pruned correctly, we first need to strip the pruning wrapper:
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
print_model_weights_sparsity(stripped_pruned_model)
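The print_model_weights_sparsity helper used above is not defined in this snippet; a minimal sketch of such a helper (its name and output format are assumptions) could look like this:

import numpy as np

def print_model_weights_sparsity(model):
    # Report the fraction of zero-valued entries in each layer's kernel weights.
    for layer in model.layers:
        for weight in layer.weights:
            if 'kernel' not in weight.name:
                continue
            values = weight.numpy()
            sparsity = 1.0 - np.count_nonzero(values) / values.size
            print(f'{weight.name}: {sparsity:.2%} sparsity')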
Apply sparsity preserving clustering:
# Sparsity preserving clustering
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = cluster.cluster_weights

clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
                                 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                                 metrics=['accuracy'])

# Train sparsity preserving clustering model
sparsity_clustered_model.fit(train_data, train_labels, epochs=3, validation_split=0.1)
Create a TFLite model by combining sparsity-preserving weight clustering with post-training quantization:
stripped_sparsity_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_sparsity_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
sparsity_clustered_quant_model = converter.convert()

_, pruned_and_clustered_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_and_clustered_tflite_file, 'wb') as f:
    f.write(sparsity_clustered_quant_model)
The previous example demonstrates the process of training to get a sparsity-preserving optimized model. For the other techniques, refer to the CQAT, PQAT, and PCQAT examples.
The following table shows the results of experiments on DS-CNN-L and MobileNet-V2, demonstrating the compression benefits versus the accuracy loss incurred. It also summarizes the number of microNPU cycles used for computation on an Ethos-U55 NPU configured with 128 MAC units.