Recurrent neural networks (RNNs) are a type of neural network that is well suited to sequential data, or data that can be expressed as a sequence. They process a series of inputs one time step at a time and keep an internal state that encodes information about the inputs they have already seen. Importantly, the same RNN weights are shared across every time step. Popular use cases for recurrent neural networks include machine translation, speech synthesis and time series prediction. Many of these use cases suit small embedded platforms, and they can also benefit from running on specialist hardware such as a Neural Processing Unit (NPU).
Figure 1: Layout of a simple Recurrent Neural Network for t timesteps.
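To make the ideas of an internal state and shared weights concrete, here is a minimal NumPy sketch of a plain (Elman-style) RNN forward pass. It illustrates the concept only and is not how TensorFlow implements its recurrent layers; the function and weight names are our own.

```python
import numpy as np


def simple_rnn(inputs, w_x, w_h, b, h0=None):
    """Run a plain tanh RNN over `inputs` of shape (time_steps, input_dim).

    The same weights w_x, w_h and b are reused at every time step; only the
    hidden state h changes from one step to the next.
    """
    hidden_dim = w_h.shape[0]
    h = np.zeros(hidden_dim) if h0 is None else h0
    outputs = []
    for x_t in inputs:  # one iteration per time step
        h = np.tanh(x_t @ w_x + h @ w_h + b)  # shared weights at every step
        outputs.append(h)
    return np.stack(outputs), h  # all hidden states, final state


rng = np.random.default_rng(0)
time_steps, input_dim, hidden_dim = 5, 3, 4
w_x = rng.normal(size=(input_dim, hidden_dim))
w_h = rng.normal(size=(hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)
sequence = rng.normal(size=(time_steps, input_dim))

all_states, final_state = simple_rnn(sequence, w_x, w_h, b)
print(all_states.shape)  # (5, 4)
```

The number of weights does not depend on the sequence length: the same `w_x`, `w_h` and `b` process a 5-step or a 500-step sequence.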
In this guide, we show you how to use TensorFlow to train and quantize a simple recurrent neural network compatible with Arm's embedded NPUs such as the Arm Ethos-U55 and Arm Ethos-U65.
The complete code sample, which is runnable from the command line, is available to download from here: https://github.com/ARM-software/ML-examples/tree/master/rnn-unrolling-tflite
Make sure you have Python 3 installed on your machine.
Running the following commands creates a Python virtual environment and installs the libraries needed to run the code sample.
$ python3 -m venv env
$ source env/bin/activate
$ pip3 install --upgrade pip
$ pip3 install tensorflow==2.5.0
$ pip3 install numpy==1.19.5
$ pip3 install ethos-u-vela==3.0.0
To start with, we train a type of RNN called a Gated Recurrent Unit, or GRU for short. This is a slightly more complex type of RNN whose gates allow it to hold on to important information from the past for longer than a standard RNN.
The task that we use our GRU for is classifying handwritten digits from the MNIST dataset. RNNs are not normally used for image classification, but we do so here for simplicity. To classify an image with our GRU, we treat each row of the 28x28 input image as one time step's worth of input. This gives us a sequence of 28 image rows to feed to the GRU. After 28 time steps, the GRU has 'seen' the whole image, and we pass its final output to a fully connected layer that performs the classification.
Figure 2: Layout of the GRU model we train.
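For reference, the sketch below writes out one step of a GRU cell in NumPy, using the classic formulation with an update gate z and a reset gate r, and then runs it over a 28x28 array row by row, mirroring the scheme described above. It is illustrative only: the weight names are hypothetical, the random array stands in for an MNIST image, and Keras's `GRU` layer defaults to a slightly different variant (`reset_after=True`) that applies the reset gate after multiplying the previous state by the recurrent weights.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x_t, h_prev, params):
    """One step of a GRU cell (classic 'reset before' formulation).

    params holds input weights W_*, recurrent weights U_* and biases b_*
    for the update gate (z), reset gate (r) and candidate state (h).
    """
    p = params
    z = sigmoid(x_t @ p["W_z"] + h_prev @ p["U_z"] + p["b_z"])  # update gate
    r = sigmoid(x_t @ p["W_r"] + h_prev @ p["U_r"] + p["b_r"])  # reset gate
    # Candidate state: the reset gate controls how much of the past is used.
    h_cand = np.tanh(x_t @ p["W_h"] + (r * h_prev) @ p["U_h"] + p["b_h"])
    # When z is close to 1 the old state passes through almost unchanged,
    # which is what lets a GRU keep information over many time steps.
    return z * h_prev + (1.0 - z) * h_cand


rng = np.random.default_rng(0)
input_dim, units = 28, 16
params = {}
for gate in ("z", "r", "h"):
    params[f"W_{gate}"] = rng.normal(scale=0.1, size=(input_dim, units))
    params[f"U_{gate}"] = rng.normal(scale=0.1, size=(units, units))
    params[f"b_{gate}"] = np.zeros(units)

image = rng.random((28, 28))  # stand-in for one normalized MNIST image
h = np.zeros(units)
for row in image:  # each image row is one time step
    h = gru_step(row, h, params)
print(h.shape)  # (16,)
```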
1. Define a model consisting of a single GRU layer followed by a fully connected layer in TensorFlow. The following code shows how to do this:
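import tensorflow as tf

def rnn_model(time_steps):
    # Each time step is one 28-pixel row of an MNIST image.
    model_input = tf.keras.Input(shape=(time_steps, 28), name='input')
    gru_out = tf.keras.layers.GRU(units=256)(model_input)
    # Fully connected layer that classifies the final GRU output into 10 digits.
    model_output = tf.keras.layers.Dense(10, activation='softmax')(gru_out)
    model = tf.keras.Model([model_input], [model_output])
    return model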
2. Next, load the data for training and testing the RNN model and normalize it with the following code:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Normalize between [0, 1] to help training.
x_train = x_train / 255.0
x_test = x_test / 255.0
3. Create an instance of our previously defined model using the following code:
model = rnn_model(time_steps=28)
The number of time steps is set to 28, the number of rows in each image.
4. Compile the model and then train it with the standard 'fit' method. To save time we train for only one epoch; training for longer usually gives a more accurate model. This is shown in the following code:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])
model.fit(x=x_train, y=y_train, epochs=1, validation_data=(x_test, y_test))
5. After only one epoch of training, we reach an accuracy of almost 98% on the validation dataset, which is not a bad result considering that this is not an ideal task for an RNN. Once training has finished, we save the model weights to a file so that we can reuse them later, with the following code:
model.save_weights('gru_mnist')
6. Now that we have trained model weights, we define a slightly different graph to use for deployment. This graph is better suited to inference and to conversion to TFLite format, and ultimately allows the model to be quantized and run on Arm Ethos-U55 or Arm Ethos-U65. The following code example shows the new graph definition in TensorFlow:
def rnn_model_tflite(time_steps):
    # We need to specify batch_size=1 to get an optimal graph for inference.
    input_node = tf.keras.Input(shape=(time_steps, 28), batch_size=1, name='input')
    gru_out = tf.keras.layers.GRU(units=256, unroll=True)(input_node)
    prediction = tf.keras.layers.Dense(10, activation='softmax')(gru_out)
    model = tf.keras.Model([input_node], [prediction])
    return model
It looks nearly identical to the previous definition but has two important changes. The first is setting the batch_size parameter of the input node to one. As we only plan to run one inference at a time with our trained model, we can safely do this. Doing so also removes any extra operations that might otherwise be added automatically to read the batch size or to handle batch sizes greater than one.
The second change is in the parameters used when we define the GRU layer. This is the most important change and it involves setting unroll to be True.
Normally an RNN has no fixed limit on the number of time steps it can loop for, so it can work with an input of any length. However, in many real-life use cases we know ahead of time how many time steps our RNN needs to loop for. In our example, we know that we need 28 time steps to account for the 28 rows in each input image. Setting unroll to True removes the loop from the RNN, as shown in Figure 3, and instead unrolls it to the number of time steps 't' given by the layer's input shape. The result is a strictly feedforward neural network that can then be quantized and converted to TensorFlow Lite easily.
Figure 3: Unrolling of an RNN.
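The following NumPy sketch illustrates what unrolling means, using a plain tanh RNN cell and t = 3. The looped version works for any sequence length; the unrolled version is just three copies of the same cell, sharing the same weights, chained together, so it accepts exactly three time steps. Keras builds the unrolled graph for you when `unroll=True`; the hand-written functions here are hypothetical and only for illustration.

```python
import numpy as np


def rnn_cell(x_t, h, w_x, w_h, b):
    """A single step of a plain tanh RNN; the same weights are used every step."""
    return np.tanh(x_t @ w_x + h @ w_h + b)


def rnn_looped(inputs, h0, w_x, w_h, b):
    # Looped form: works for any number of time steps.
    h = h0
    for x_t in inputs:
        h = rnn_cell(x_t, h, w_x, w_h, b)
    return h


def rnn_unrolled_3(inputs, h0, w_x, w_h, b):
    # Unrolled form for exactly 3 time steps: no loop, just three copies of
    # the cell wired one after another, i.e. a purely feedforward graph.
    h1 = rnn_cell(inputs[0], h0, w_x, w_h, b)
    h2 = rnn_cell(inputs[1], h1, w_x, w_h, b)
    h3 = rnn_cell(inputs[2], h2, w_x, w_h, b)
    return h3


rng = np.random.default_rng(0)
w_x = rng.normal(size=(2, 3))
w_h = rng.normal(size=(3, 3))
b = np.zeros(3)
inputs = rng.normal(size=(3, 2))  # exactly t = 3 time steps
h0 = np.zeros(3)

h_loop = rnn_looped(inputs, h0, w_x, w_h, b)
h_unrolled = rnn_unrolled_3(inputs, h0, w_x, w_h, b)
print(np.allclose(h_loop, h_unrolled))  # True
```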
7. We create an instance of our new 'deployment ready' model and load the previously saved weights into this model with the following code:
model_for_tflite = rnn_model_tflite(time_steps=28)
model_for_tflite.load_weights('gru_mnist')
Now that we have created our deployment graph and loaded the weights from the previous training, we can go ahead and quantize it. Quantization is an essential step when deploying a model on small edge devices, as it reduces the model size and can speed up inference. Moreover, NPUs like the Arm Ethos-U55 only support inference on models that have been quantized to 8 bits.
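As a rough illustration of what 8-bit quantization involves, the sketch below maps a float range onto int8 with an affine scale and zero point, following the convention `real_value = scale * (int8_value - zero_point)` used by TensorFlow Lite. In post-training quantization, the converter observes tensor ranges by running the representative dataset through the model. This is a simplified sketch under our own assumptions: the helper names are hypothetical, and TensorFlow Lite's actual calibration and its handling of weights differ in detail.

```python
import numpy as np


def int8_affine_params(x_min, x_max):
    """Scale and zero point mapping [x_min, x_max] onto the int8 range.

    Simplified: the range is widened to include 0.0 so that zero is
    exactly representable.
    """
    qmin, qmax = -128, 127
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range; any positive scale works
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point):
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale


x = np.linspace(0.0, 1.0, 11)  # e.g. normalized pixel values
scale, zero_point = int8_affine_params(x.min(), x.max())
q = quantize(x, scale, zero_point)
x_back = dequantize(q, scale, zero_point)
print(zero_point, q[0], q[-1])  # -128 -128 127
```

The round trip through int8 loses at most half a quantization step per value, which is why a well-calibrated range usually costs little accuracy.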
The following quantization steps follow the public TensorFlow tutorials on post-training quantization.
8. Post-training quantization requires a representative dataset, so we define a generator function that yields images from the training dataset. The following code shows how this is done:
import numpy as np

def rep_dataset():
    for i in range(50):  # Only need a few examples.
        img = x_train[i].astype(np.float32)
        # Add a batch dimension of 1 to match the model input shape (1, 28, 28).
        img = np.expand_dims(img, 0)
        yield [img]
Then we load the new Keras model into the TFLiteConverter, set all the attributes needed for post-training quantization, and convert the model to TensorFlow Lite format. Finally, we save the model to a file, ready for optimization with Vela and deployment on Arm Ethos-U55 or Arm Ethos-U65. These steps are shown in the code below:
converter_quant = tf.lite.TFLiteConverter.from_keras_model(model_for_tflite)
converter_quant.optimizations = [tf.lite.Optimize.DEFAULT]
converter_quant.representative_dataset = rep_dataset
converter_quant.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter_quant.inference_input_type = tf.int8
converter_quant.inference_output_type = tf.int8
tflite_model = converter_quant.convert()

with open('gru_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
As always, quantizing a neural network model can cause some loss of accuracy, so you should check the accuracy of your model after quantization. Doing this for our GRU model, we see that the accuracy has stayed virtually the same, which is great news.
Testing quantized TFLite model...
Accuracy of quantized TFLite model on test set: 97.57%
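One way to run such a check is with the TensorFlow Lite `Interpreter`, as in the sketch below. It assumes the int8 input and output settings used in the conversion step above and the batch size of one baked into the deployment model; the function names are hypothetical, and this is not necessarily how the complete code sample computes the figure above.

```python
import numpy as np


def quantize_input(x, scale, zero_point):
    """Map float inputs to int8 using the model's input quantization parameters."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def evaluate_tflite(model_path, x_test, y_test):
    """Return the top-1 accuracy of an int8 TFLite classifier on (x_test, y_test)."""
    import tensorflow as tf  # only needed to run the interpreter

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    scale, zero_point = input_details["quantization"]

    correct = 0
    for image, label in zip(x_test, y_test):
        # The model was built with batch_size=1, so add a batch dimension.
        batch = image[np.newaxis, ...].astype(np.float32)
        interpreter.set_tensor(input_details["index"],
                               quantize_input(batch, scale, zero_point))
        interpreter.invoke()
        output = interpreter.get_tensor(output_details["index"])[0]
        # argmax works directly on the int8 scores: dequantizing uses a
        # positive scale, so it does not change which class scores highest.
        correct += int(np.argmax(output) == label)
    return correct / len(y_test)
```

For example, `evaluate_tflite('gru_mnist.tflite', x_test, y_test)` would evaluate the file written above on the normalized test set loaded in step 2.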
Depending on your model, you might want to set the initial hidden states manually or retrieve the final hidden states from your RNN layers. Luckily this is simple to do and only requires a change to how you define the model when exporting to TensorFlow Lite. For instance, to get a TensorFlow Lite model that has access to these input and output states, we can define our model as in the following code snippet:
def rnn_model_tflite_with_state_transfer(time_steps):
    # We need to specify batch_size=1 to get an optimal graph for inference.
    input_node = tf.keras.Input(shape=(time_steps, 28), batch_size=1, name='input')
    # Input tensor for the initial GRU hidden state.
    initial_gru_state = tf.keras.Input(shape=(256,), batch_size=1, name='initial_gru_state')
    # Return the final GRU hidden state.
    gru_out, final_gru_state = tf.keras.layers.GRU(units=256, unroll=True, return_state=True)(
        input_node, initial_state=initial_gru_state)
    prediction = tf.keras.layers.Dense(10, activation='softmax')(gru_out)
    model = tf.keras.Model([input_node, initial_gru_state], [prediction, final_gru_state])
    return model
In this model definition we have defined a new Input tensor that is the size of the GRU hidden layer; this is how we feed in an initial GRU state. In the GRU layer we set return_state to True so that we can capture the final hidden state. When we call the layer, we pass in the initial state defined in the previous line. Finally, when making the Keras Model, we add this additional input and output to the corresponding lists of input and output nodes.
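One use of exposing the states is to process a long sequence in chunks, feeding each final state back in as the next initial state. The NumPy sketch below shows why this works, using a plain tanh RNN as a stand-in for the GRU (the names are hypothetical): running 28 steps in one call gives the same final state as two 14-step calls with the state passed between them.

```python
import numpy as np


def run_rnn(inputs, initial_state, w_x, w_h, b):
    """Plain tanh RNN that, like the model above, takes an initial hidden
    state as input and returns the final hidden state as output."""
    h = initial_state
    for x_t in inputs:
        h = np.tanh(x_t @ w_x + h @ w_h + b)
    return h


rng = np.random.default_rng(1)
input_dim, units = 28, 8
w_x = rng.normal(scale=0.1, size=(input_dim, units))
w_h = rng.normal(scale=0.1, size=(units, units))
b = np.zeros(units)
sequence = rng.random((28, input_dim))  # e.g. the 28 rows of one image
h0 = np.zeros(units)

# One call over all 28 time steps.
h_full = run_rnn(sequence, h0, w_x, w_h, b)

# Two calls of 14 steps each, feeding the first call's final state back
# in as the second call's initial state.
h_half = run_rnn(sequence[:14], h0, w_x, w_h, b)
h_chunked = run_rnn(sequence[14:], h_half, w_x, w_h, b)

print(np.allclose(h_full, h_chunked))  # True
```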
If you were using an LSTM layer instead of a basic RNN or GRU layer, the code above would need only a small modification. An LSTM has a cell state in addition to its hidden state, so you would need an additional Input tensor for the initial cell state, and you would also need to capture the additional final cell state. Both would then be added to the lists of inputs and outputs when making the Keras Model.
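A minimal sketch of the LSTM variant, assuming the same input shape and layer sizes as the GRU example above, might look like the following. The function name and the `units` parameter are our own additions. With `return_state=True`, a Keras `LSTM` layer returns the output followed by the final hidden state and the final cell state, and `initial_state` takes both states as a list.

```python
import tensorflow as tf


def lstm_model_tflite_with_state_transfer(time_steps, units=256):
    """Hypothetical LSTM counterpart of rnn_model_tflite_with_state_transfer."""
    input_node = tf.keras.Input(shape=(time_steps, 28), batch_size=1, name='input')
    # An LSTM has two states: the hidden state h and the cell state c.
    initial_h = tf.keras.Input(shape=(units,), batch_size=1, name='initial_lstm_h')
    initial_c = tf.keras.Input(shape=(units,), batch_size=1, name='initial_lstm_c')
    # With return_state=True the layer returns [output, final_h, final_c].
    lstm_out, final_h, final_c = tf.keras.layers.LSTM(
        units=units, unroll=True, return_state=True)(
            input_node, initial_state=[initial_h, initial_c])
    prediction = tf.keras.layers.Dense(10, activation='softmax')(lstm_out)
    model = tf.keras.Model([input_node, initial_h, initial_c],
                           [prediction, final_h, final_c])
    return model
```

Because `return_sequences` is left at its default of False, `lstm_out` holds the same values as `final_h`; both are kept so that the state outputs mirror the GRU example.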
Note: When doing post-training quantization on a model with this state transfer enabled you will need to supply example data for these input hidden states in your representative data generator function. If your use case involves feeding final states as new initial states to your RNN, then it is advised to collect examples of these final states and iteratively add them to your representative dataset of input states.
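The sketch below shows one possible way to do this with NumPy. `collect_states` and `make_rep_dataset` are hypothetical helpers: `step_fn` is assumed to run the float model on one image and one state and return the prediction and the final state, and the order of each yielded sample is assumed to match the model's inputs (image first, then initial state). If in doubt, check the input order reported by the converted model's input details.

```python
import numpy as np


def collect_states(step_fn, images, initial_state):
    """Run the float model over `images` in order, feeding each final state
    back in as the next initial state, and return the state that each image
    actually receives as input."""
    states = [initial_state]
    h = initial_state
    for img in images[:-1]:
        _, h = step_fn(img, h)  # step_fn returns (prediction, final_state)
        states.append(h)
    return states


def make_rep_dataset(images, states):
    """Return a representative-dataset generator for a two-input model.

    Each sample is a list [image, initial_state], each with a leading batch
    dimension of 1 and dtype float32.
    """
    def rep_dataset():
        for img, h in zip(images, states):
            yield [img[np.newaxis].astype(np.float32),
                   np.reshape(h, (1, -1)).astype(np.float32)]
    return rep_dataset
```

With the float Keras model, `step_fn` could be a small wrapper around `model.predict([img[np.newaxis], h[np.newaxis]])` that returns the prediction and the final state without its batch dimension. The resulting generator would then be assigned to `converter_quant.representative_dataset` as before.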
The final step in preparing the model for deployment on Arm Ethos-U55 or Arm Ethos-U65 is to run the quantized TensorFlow Lite file through the Vela compiler.
Please see the instructions on the Vela page for how to run your model through the Vela compiler. After following the steps in this guide and running your model through Vela, you should end up with a model that runs entirely on Arm Ethos-U55 or Arm Ethos-U65.
By following this guide, you have trained a simple RNN-based model in TensorFlow. You have then seen how to unroll your RNN model so that it can be quantized and converted easily to TensorFlow Lite format, ready for deployment on Arm Ethos-U55. You have also seen how to create a TFLite model that accepts an initial RNN state as input and how to capture the final RNN states from the model.