How do I use the recent Arm Neon MMLA instructions efficiently (required data layout)?

What is the appropriate way to load and store data for the recent Arm Neon MMLA instructions?

For example, the description of SMMLA is:

> Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies the 2x8 matrix of signed 8-bit integer values in the first source vector by the 8x2 matrix of signed 8-bit integer values in the second source vector. The resulting 2x2 32-bit integer matrix [...]
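To make the register layout concrete, here is a minimal scalar sketch of what I understand one SMMLA to compute. The name `smmla_ref` is my own, and the layout is my reading of the description: the first operand holds its two rows in the low and high 64 bits, the second operand holds its two columns in the low and high 64 bits, and the accumulator is a row-major 2x2.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of a single SMMLA (hypothetical helper, assumed layout):
 *   a   : 2x8 matrix, row 0 in bytes 0..7, row 1 in bytes 8..15
 *   b   : 8x2 matrix, column 0 in bytes 0..7, column 1 in bytes 8..15
 *   acc : 2x2 accumulator, row-major {c00, c01, c10, c11}            */
static void smmla_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16])
{
    assert(acc && a && b);
    for (int i = 0; i < 2; ++i) {           /* row of a       */
        for (int j = 0; j < 2; ++j) {       /* column of b    */
            int32_t sum = 0;
            for (int k = 0; k < 8; ++k) {   /* reduction dim  */
                sum += (int32_t)a[8 * i + k] * (int32_t)b[8 * j + k];
            }
            acc[2 * i + j] += sum;          /* accumulate     */
        }
    }
}
```

Under that reading, both operands are indexed identically (`8 * index + k`), so a row-major A and a column-major B already have the K dimension contiguous; the awkward part is only that each register needs two rows (or two columns) at once.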

The instructions take a 2x8 and an 8x2 input and produce a 2x2 output, each packed into a single 128-bit vector. How can I efficiently load and store data for these instructions? With matrices that are either row- or column-major, I see three broad ways to use them, none of them appealing:

  • [Split and Combine] I load two consecutive 128-bit vectors and split them. Storing is equally "complicated" since I need to extract the relevant parts of the 2x2 matrix before storing it.
  • [Smaller Vectors] Alternatively, I could load 64-bit vectors, but that seems inefficient too.
  • [Reorder Input Data] Of course, I could re-pack the input data so the vectorized loads already span multiple rows or columns. Is that the intended use?
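For the re-packing option, this is roughly what I have in mind, written as a scalar sketch. `pack_2x8_tiles` is a hypothetical helper of mine, and it assumes an even number of rows and a K that is a multiple of 8. After packing, every 2x8 tile occupies 16 consecutive bytes, so a single 128-bit load (for example `vld1q_s8`) would fetch a complete operand.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: re-pack a row-major m x k int8 matrix so that each
 * 2x8 tile (two rows, eight K elements) is stored as 16 contiguous bytes.
 * Assumes m is even and k is a multiple of 8; dst must hold m * k bytes. */
static void pack_2x8_tiles(int8_t *dst, const int8_t *src, size_t m, size_t k)
{
    assert(m % 2 == 0 && k % 8 == 0);
    for (size_t i = 0; i < m; i += 2) {         /* pair of rows          */
        for (size_t p = 0; p < k; p += 8) {     /* 8-wide slice along K  */
            for (size_t r = 0; r < 2; ++r) {    /* row within the tile   */
                for (size_t c = 0; c < 8; ++c) {
                    *dst++ = src[(i + r) * k + p + c];
                }
            }
        }
    }
}
```

Since a column-major Kx4 matrix B has the same memory layout as a row-major 4xK matrix, the same routine would pack pairs of columns of B. The cost is an extra pass over the inputs, which only pays off if the packed data is reused often enough.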

Example code for the inner loop (reduction over K) with a small 4xK matrix A (row-major) and a Kx4 matrix B (column-major) follows. It loads 64-bit vectors and combines them:

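```
// row0..row3 point into rows of A, col0..col3 into columns of B;
// out* are uint32x4_t accumulators, each holding one 2x2 tile of the result.
for (size_t k = 0; k < 64; k += 8) {
    // Rows 0 and 1 of A: two 64-bit loads combined into one 2x8 operand.
    uint8x8_t low = vld1_u8(row0);
    uint8x8_t high = vld1_u8(row1);
    uint8x16_t row01x01234567 = vcombine_u8(low, high);
    row0 += 8;
    row1 += 8;

    // Rows 2 and 3 of A.
    low = vld1_u8(row2);
    high = vld1_u8(row3);
    uint8x16_t row23x01234567 = vcombine_u8(low, high);
    row2 += 8;
    row3 += 8;

    // Columns 0 and 1 of B: combined into one 8x2 operand.
    low = vld1_u8(col0);
    high = vld1_u8(col1);
    uint8x16_t col01x01234567 = vcombine_u8(low, high);
    col0 += 8;
    col1 += 8;

    // Columns 2 and 3 of B.
    low = vld1_u8(col2);
    high = vld1_u8(col3);
    uint8x16_t col23x01234567 = vcombine_u8(low, high);
    col2 += 8;
    col3 += 8;

    // Four 2x2 tiles of the 4x4 output (vmmlaq_u32 is the unsigned variant, UMMLA).
    out01x01 = vmmlaq_u32(out01x01, row01x01234567, col01x01234567);
    out01x23 = vmmlaq_u32(out01x23, row01x01234567, col23x01234567);
    out23x01 = vmmlaq_u32(out23x01, row23x01234567, col01x01234567);
    out23x23 = vmmlaq_u32(out23x23, row23x01234567, col23x01234567);
}
```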

The result is correct, but the approach seems terribly inefficient. The code above is only an example; in practice I would use larger tile sizes to maximize register usage.
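On the store side, the rearrangement I currently expect to need looks like the scalar sketch below. `store_2x4_from_tiles` is a hypothetical name, and it assumes each accumulator is a row-major 2x2 tile and that the output C is row-major with leading dimension `ldc`. Two horizontally adjacent tiles are interleaved so that each output row can be written as one full 128-bit vector.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: write two horizontally adjacent 2x2 tiles
 * (left covers columns j, j+1; right covers columns j+2, j+3) into a
 * row-major int32 matrix c with leading dimension ldc.
 * Each tile is assumed to be row-major: {c00, c01, c10, c11}. */
static void store_2x4_from_tiles(int32_t *c, size_t ldc,
                                 const int32_t left[4], const int32_t right[4])
{
    assert(ldc >= 4);
    /* Low halves of both tiles form the upper output row,
     * high halves form the lower output row. */
    int32_t row0[4] = { left[0], left[1], right[0], right[1] };
    int32_t row1[4] = { left[2], left[3], right[2], right[3] };
    memcpy(c, row0, sizeof row0);
    memcpy(c + ldc, row1, sizeof row1);
}
```

With intrinsics, I assume this would become something like `vcombine_s32(vget_low_s32(left), vget_low_s32(right))` for the upper row and the `vget_high_s32` equivalent for the lower row, which is still two extra permutes per pair of tiles after the reduction loop.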