Style transfer for graphics post-processing on mobile

Pavel Rudko
March 24, 2022
7 minute read time.

Overview

Using Machine Learning (ML) in graphics is a very promising area of study. A neural network (NN) can be used in various ways to modify or improve the quality of rendered frames. For example:

  • Supersampling – Rendering the scene at a lower resolution and then upscaling it
  • Denoising – Removing noise, for example when ray tracing is used
  • Style transfer – Adding interesting visual effects

In this blog, we are going to take a closer look at neural style transfer. We will explain how to use Keras to build a lightweight model for mobile and how to use Arm Compute Library (ACL) to get this model running on a device within a graphics application. This blog is mainly targeted at people who are familiar with ML (including inference on mobile) and working with Vulkan.

Style transfer

Style transfer is a technique that copies style features, such as colours and texture, from the reference image and adds them to the content image. The result looks like the content image, but “painted” in a certain way.

Style transfer processing for images

As you can see above, neural style transfer is a way to achieve artistic effects that would not be possible using classical post-processing algorithms. Today, it is already widely used on mobile in video chat and photo editing apps. However, game developers can also benefit from it significantly by adding a certain atmosphere to their games or making a game look like a cartoon.

Neural network model

The first thing you need to apply ML to photos, video, or graphics on mobile is an NN model.

Architecture for mobile

Running an ML workload on mobile, especially in real time, requires a lightweight architecture. In the case of style transfer, the task is usually split into two steps:

  • Style prediction
  • Style transformation

This two-step approach gives you arbitrary style transfer: any style reference image can be used, so it is a flexible solution.

But for mobile, it makes sense to use a less flexible, but more robust, option that takes only the content image as an input. In this case, the style information is “embedded” into the model.

The architecture we have used is very small and relies on separable convolution blocks. They provide better performance than normal convolutions, which is why the same approach is used in MobileNet. These layers are combined into larger residual blocks that have transposed convolutions at the end. The whole model contains two residual blocks, plus an additional separable convolution at the end.

Here is a code snippet that demonstrates how to set up the model in Keras:

import tensorflow as tf

def separable_conv(x, filters, dilations):
    # Depthwise 3x3 convolution followed by a pointwise 1x1 convolution
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.DepthwiseConv2D((3, 3), strides=(1, 1), dilation_rate=dilations, padding='same', activation='relu')(x)
    x = tf.keras.layers.Conv2D(filters, (1, 1), strides=(1, 1), dilation_rate=(1, 1), padding='valid')(x)
    return x

def residual_block(x):
    # Downsample, apply separable convolutions with skip connections,
    # then upsample again with a transposed convolution
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(8, (3, 3), strides=(2, 2), dilation_rate=(1, 1), padding='same')(x)
    x1 = separable_conv(x, 8, (1, 1))
    x = tf.keras.layers.Add()([x, x1])
    x1 = separable_conv(x, 8, (2, 2))
    x = tf.keras.layers.Add()([x, x1])
    x1 = separable_conv(x, 8, (1, 1))
    x = tf.keras.layers.Add()([x, x1])
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(16, (3, 3), strides=(1, 1), dilation_rate=(1, 1), padding='same', activation='relu')(x)
    x = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), dilation_rate=(1, 1), padding='valid')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    return x

def create_style_transfer_model():
    # Two residual blocks followed by a final separable convolution
    input = tf.keras.Input(shape=(512, 256, 3), batch_size=1, name='input')

    x = tf.keras.layers.Conv2D(16, (3, 3), strides=(1, 1), dilation_rate=(1, 1), padding='same')(input)
    x1 = residual_block(x)
    x = tf.keras.layers.Add()([x, x1])
    x1 = residual_block(x)
    x = tf.keras.layers.Add()([x, x1])
    x1 = separable_conv(x, 16, (1, 1))
    x = tf.keras.layers.Add()([x, x1])

    output = tf.keras.layers.Conv2D(3, (3, 3), strides=(1, 1), dilation_rate=(1, 1), padding='same')(x)

    model = tf.keras.Model(inputs=[input], outputs=[output], name='style_transfer')
    return model

model = create_style_transfer_model()

Training style transfer networks

The network must produce pictures that look like both the content image and the style reference image at the same time. This means that during training we need to compare the output to both images and calculate the total loss as a weighted sum. The weight coefficients define how strongly the result is stylized.

One way to compare them would be pixel by pixel. But with style transfer, visible image features are more important than the actual pixel colours. Consider the following: you take an image and shift it one pixel to the right. With a pixelwise comparison, the difference can be huge, even though the image remains virtually the same.

For a better comparison, we can use another NN that extracts features from images to calculate a perceptual loss during training. In this case, the loss is calculated as the difference between image features.

Training using another network
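As an illustration only, not the exact setup used in this work, a perceptual loss can be built on top of a pre-trained classification network such as VGG16. The layer names and loss weights below are assumptions, input preprocessing is omitted for brevity, and a classic style loss would normally compare Gram matrices of the features rather than the raw feature maps:

import tensorflow as tf

# Feature extractor: a frozen, pre-trained VGG16 truncated to a few intermediate layers
vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet')
vgg.trainable = False
feature_extractor = tf.keras.Model(
    inputs=vgg.input,
    outputs=[vgg.get_layer(name).output for name in ['block2_conv2', 'block3_conv3']])

def perceptual_loss(output_image, content_image, style_image,
                    content_weight=1.0, style_weight=0.5):
    # Compare images in feature space rather than pixel space
    out_feats = feature_extractor(output_image)
    content_feats = feature_extractor(content_image)
    style_feats = feature_extractor(style_image)
    content_loss = tf.add_n([tf.reduce_mean(tf.square(o - c))
                             for o, c in zip(out_feats, content_feats)])
    style_loss = tf.add_n([tf.reduce_mean(tf.square(o - s))
                           for o, s in zip(out_feats, style_feats)])
    # The total loss is a weighted sum; the weights control how strongly the result is stylized
    return content_weight * content_loss + style_weight * style_loss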

Setting up training with a perceptual loss can be rather complicated, and the training process itself can take a long time. We have used another approach to train our small NN for mobile.

Suppose there is already a pre-trained style transfer network that produces good results but is too big to run on mobile in real time. We can then use it to produce dataset outputs and simplify the training. Below you can see how dataset inputs were converted into stylized outputs using the bigger pre-trained network. With these two sets of images (original and stylized), we were able to use a simple mean squared error loss to train the smaller model.
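As a rough sketch of this simplified training setup: the larger pre-trained network stylizes the dataset offline, and the small mobile model is then fitted to those outputs with a plain mean squared error loss. The model file name and the dataset-loading helper below are hypothetical placeholders, not assets from this project:

import tensorflow as tf

# Hypothetical teacher: a large pre-trained style transfer model (placeholder file name)
teacher_model = tf.keras.models.load_model('big_style_transfer_model.h5')

# Placeholder dataset loader returning content images of shape (N, 512, 256, 3) in [0, 1]
content_images = load_content_dataset()
stylized_targets = teacher_model.predict(content_images, batch_size=8)

# Fit the lightweight mobile model defined earlier directly to the teacher's outputs
model = create_style_transfer_model()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='mse')
model.fit(content_images, stylized_targets, batch_size=8, epochs=20)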

Combining ML and graphics

The most important part is embedding NN inference into the graphics pipeline. There are a few important things to consider:

  • The choice of ML framework
  • The choice of the inference device (CPU, GPU, NPU)
  • Establishing interoperability between graphics API and ML framework

We have created a demo project that combines graphics rendering with ML post-processing (style transfer in this case). The result looks like this:

Performance on a Samsung Galaxy S21 smartphone, which uses the Exynos 2100 chipset with the Arm Mali-G78 GPU, was good: ~26 milliseconds per frame, or about 38 frames per second.

In this project, Vulkan is used for graphics rendering and ACL for ML inference. ACL supports inference on GPU using OpenCL for acceleration.

Using the OpenCL backend for inference together with OpenCL and Vulkan external memory extensions allowed us to achieve zero-copy data sharing between graphics and ML.

In the first render pass, the scene is rendered into an offscreen Vulkan image. The image memory is exported as an Android Hardware Buffer using the VK_KHR_external_memory extension. The code snippet below demonstrates how the image is created.

// Mark the image as shareable through an Android hardware buffer
VkExternalMemoryImageCreateInfo external_memory_image_create_info = {};
external_memory_image_create_info.sType       = VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO;
external_memory_image_create_info.handleTypes = VK_EXTERNAL_MEMORY_HANDLE_TYPE_ANDROID_HARDWARE_BUFFER_BIT_ANDROID;

VkImageCreateInfo image_create_info = {};
image_create_info.sType             = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
image_create_info.pNext             = &external_memory_image_create_info;
image_create_info.imageType         = VK_IMAGE_TYPE_2D;
image_create_info.format            = VK_FORMAT_R8G8B8A8_UNORM;
image_create_info.mipLevels         = 1;
image_create_info.arrayLayers       = 1;
image_create_info.samples           = VK_SAMPLE_COUNT_1_BIT;
image_create_info.tiling            = VK_IMAGE_TILING_LINEAR;
image_create_info.sharingMode       = VK_SHARING_MODE_EXCLUSIVE;
image_create_info.initialLayout     = VK_IMAGE_LAYOUT_UNDEFINED;
image_create_info.extent            = extent;   // VkExtent3D holding the offscreen render target size
image_create_info.usage             = VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_SAMPLED_BIT;

VkImage offscreen_image_handle;
auto result = vkCreateImage(device.get_handle(), &image_create_info, nullptr, &offscreen_image_handle);

The AHardwareBuffer backing the image memory is first retrieved from Vulkan:

// Retrieve the AHardwareBuffer that backs the offscreen image memory
AHardwareBuffer *hardware_buffer = nullptr;

VkMemoryGetAndroidHardwareBufferInfoANDROID get_hardware_buffer_info = {};
get_hardware_buffer_info.sType  = VK_STRUCTURE_TYPE_MEMORY_GET_ANDROID_HARDWARE_BUFFER_INFO_ANDROID;
get_hardware_buffer_info.pNext  = nullptr;
get_hardware_buffer_info.memory = memory;   // VkDeviceMemory bound to the offscreen image
auto result = vkGetMemoryAndroidHardwareBufferANDROID(get_device().get_handle(), &get_hardware_buffer_info, &hardware_buffer);

That AHardwareBuffer is then imported into OpenCL as a cl_mem handle using the cl_arm_import_memory extension, and the resulting buffer is used as the input and output memory for style transfer inference in ACL:

// Import the AHardwareBuffer into OpenCL as a cl_mem object (zero-copy)
const cl_import_properties_arm cl_import_properties[] = { CL_IMPORT_TYPE_ARM, CL_IMPORT_TYPE_ANDROID_HARDWARE_BUFFER_ARM, 0 };
cl_mem imported_memory = clImportMemoryARM(context,
                                           CL_MEM_READ_WRITE,
                                           cl_import_properties,
                                           hardware_buffer,
                                           CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM,
                                           &error);

// Wrap the imported cl_mem in a cl::Buffer and hand it to the ACL tensor allocator
auto status = input_tensor->allocator()->import_memory(cl::Buffer(imported_memory));

The output of the inference stage is sampled in the final render pass to display the result onto the screen.

You can find more info about Vulkan-OpenCL interop in the Khronos Vulkan Samples repository.

An overview of the pipeline is shown below:

Upcoming Khronos extensions for OpenCL

In the future, it will be even easier for developers to achieve smooth interop between Vulkan or OpenGL ES rendering and OpenCL inference, once the following Khronos extensions are supported on mobile devices:

  • cl_khr_external_memory (for data sharing)
  • cl_khr_semaphore (for synchronization)

Conclusions

ML-based post-processing is coming to mobile devices and opens many new opportunities for graphics developers.

We have covered one particular use case: using style transfer as part of the graphics pipeline. But there are other areas to explore, such as supersampling.

In any use case, efficient data sharing between Graphics and ML is important. Check out our code sample on Vulkan-OpenCL interop.

Following on from the GDC talk about style transfer presented by Roberto Lopez Mendez, the GitHub repo with the source code, models, and data is available, so you can build the demo showcased in the talk and also experiment with new demos of your own.
