CodeGen Performant Convolution Kernels for Mobile GPUs

This blog post talks about how to generate performant code for convolution ops using MLIR’s multiple levels of abstraction and transformations. I initially developed this approach to target ARM Mali GPUs in IREE, but since it is just direct tiling and vectorization, it should be widely applicable.

I will walk through the lowering steps, so if you are interested in how to organize MLIR’s various dialects/patterns to achieve similar tasks, this blog post might also be useful.

Background and Scope

The goal is to generate performant code for convolution ops targeting mobile GPUs. The problem spaces on both sides are huge, so it’s worth limiting our scope. That way we can find the best solution for each well-defined subproblem and tackle the big problem piece by piece.

Convolution ops

There are many flavors of convolution ops. They can have vastly different data access patterns and some variants (e.g., depthwise/dilated convolution) are inherently adversarial for GPU utilization (without layout adjustment and other optimizations).

In this blog post, I will focus on the basic version of convolution, which is still widely used in various mobile vision models. To make sure we are on the same page regarding terminology: we apply a 4-D filter tensor Filter to a 4-D input tensor Input to generate a 4-D output tensor Output. Input follows the layout (N, IH, IW, IC), Filter follows (FH, FW, IC, OC), and Output follows (N, OH, OW, OC). For now we assume there is no padding or dilation. (Handling padding is complicated enough to merit its own blog post.)
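As a concrete reference for this terminology, here is a naive Python version of the convolution. It is an illustrative sketch, not code from IREE or the benchmark, and the helper names are made up for this post. It includes strides because the example later in the post uses stride 2. As stated above, it assumes no padding or dilation.

```python
def conv_output_size(in_size, filter_size, stride):
    """Output extent of a convolution without padding or dilation."""
    return (in_size - filter_size) // stride + 1


def conv2d_nhwc_hwcf(inp, flt, stride=(1, 1)):
    """Naive reference convolution on nested Python lists.

    inp has layout (N, IH, IW, IC), flt has layout (FH, FW, IC, OC);
    the result has layout (N, OH, OW, OC).
    """
    n_, ih, iw, ic_ = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    fh_, fw_, oc_ = len(flt), len(flt[0]), len(flt[0][0][0])
    sh, sw = stride
    oh_ = conv_output_size(ih, fh_, sh)
    ow_ = conv_output_size(iw, fw_, sw)
    out = [[[[0.0] * oc_ for _ in range(ow_)] for _ in range(oh_)] for _ in range(n_)]
    # N, OH, OW, OC are parallel dimensions; FH, FW, IC are reductions.
    for n in range(n_):
        for oh in range(oh_):
            for ow in range(ow_):
                for oc in range(oc_):
                    acc = 0.0
                    for fh in range(fh_):
                        for fw in range(fw_):
                            for ic in range(ic_):
                                acc += (inp[n][oh * sh + fh][ow * sw + fw][ic]
                                        * flt[fh][fw][ic][oc])
                    out[n][oh][ow][oc] = acc
    return out


# 1x3x3x1 input of ones, 2x2x1x1 filter of ones, stride 1 -> 1x2x2x1 of 4.0.
ones_in = [[[[1.0] for _ in range(3)] for _ in range(3)]]
ones_flt = [[[[1.0]] for _ in range(2)] for _ in range(2)]
print(conv2d_nhwc_hwcf(ones_in, ones_flt))
```

For example, conv_output_size(225, 3, 2) gives 112, the output extent of the MLIR example later in this post.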

Mobile GPUs

In the mobile world, GPUs come from several vendors. Apple has its own solution for iOS devices. For Android, Qualcomm Adreno and ARM Mali are the notable ones. There are also others like Imagination PowerVR, and AMD is entering the mobile world too. They all use tiled architectures, so we can target them uniformly to a degree, but they differ in important characteristics that require special treatment here and there.

For example, here we will focus on Mali GPUs, which do not have dedicated on-chip shared memory. You can still use shared memory, but it is just plain system memory. So optimizations like promoting to shared memory (which we commonly see when targeting NVIDIA/AMD GPUs) won’t help and can even hurt performance.

General GPU kernel optimization

The goal of GPU kernel optimization is to approach theoretical peak performance as closely as possible. Reaching the theoretical peak itself is hard, if it is possible at all. In practice, 50%-80% is arguably already good enough, and returns diminish as we get closer and closer to the peak.

In general, approaching the theoretical peak means saturating the computation units, which requires us to 1) prepare data fast and 2) issue enough computation.

For the first, we commonly want to 1a) exploit data locality, which typically means tiling. On top of that, we want 1b) good load/store patterns. For global memory, adjacent GPU threads should load/store adjacent 4-element chunks, with the distribution wrapping around cyclically. This is particularly important for Mali GPUs. Then we want to 1c) reuse data in fast memory. Mali GPUs don’t have dedicated shared memory, which leaves us with only registers. We can preload shareable data and cache intermediate data in registers. But we also want to control the number of registers each thread uses to avoid spilling. There is a trade-off here (as always).

For the second, there are lots of tricks. But in general we prefer streamlined vectorized code with as little control flow as possible. This way we can fill the GPU compute pipelines as much as possible without many bubbles. GPUs are parallel machines that easily run thousands of threads, but each one of them is just a very simple, slow, in-order “CPU thread”[1].

Overall CodeGen Strategy

Based on the previous section, there is clearly a large gap between the source convolution op and what the target hardware excels at executing. The former is quite high-level, abstract, and dynamic. The latter wants low-level, concrete, and static code. Our task here is to bridge the gap. It is doable in one step: we can hand-write a GPU kernel for whatever source convolution op we have. The difficulty is composability, extensibility, and maintainability, because such kernels are painful for others to understand or modify.

This is where compiler code generation, and particularly the multi-level nature of MLIR, shines. As we typically do with all computer problems, we can create different abstractions to break the problem down and solve isolated tasks, making the solutions composable, extensible, and easy to maintain.

Looking at the problem at hand again, a convolution op is a standalone entity computing on n-D tensors. We want to break it down to the level that GPUs favor: computation over 4-element 1-D vectors. With just one thread, we would need to wrap the 1-D vector computation inside some loops to cover the whole original workload. But GPUs provide many threads, so we can divide and conquer by partitioning the original workload and distributing the pieces to different workgroups/workitems.

Tiling and distribution

Therefore, the general idea is to tile and distribute the original convolution op. We have three levels in the compute hierarchy: workgroup, subgroup/warp, and workitem. As mentioned before, Mali GPUs have no dedicated on-chip shared memory, so it is not helpful to use it as a way to exchange data among different subgroups. So we will just tile and distribute to workgroups and workitems.

This tiling and distribution should partition the convolution Output, as all its dimensions are parallel ones—each output element can be computed independently from others. This is great for GPU, which is just a massive parallel machine that was initially designed to shade pixels on screens.

Vectorization and unrolling

After tiling and distribution, each GPU thread just handles a subset of the Output elements, effectively a much smaller scale convolution here. Considering the general kernel optimizations, tiling helps to exploit data locality. At the workitem level, we need to consider the rest. We’d want vectorized code here, both for load/store and computation. We’d want to use registers to cache commonly used inputs and intermediate results. We’d want to unroll loops to generate enough streamlined code. Let’s see how we can achieve these goals.

  • Each workitem should fully compute all convolution Output elements it is responsible for in one run. All Output elements are independent so the workitem doesn’t need to wait for anything. Fully computing allows us to use registers to hold intermediate results and only write out once finally.
  • Convolution Filter is needed for computing all Output elements. We can preload it to registers to reduce memory requests and boost reuse.
  • In order to have a good vectorized memory load/store pattern, the distribution should be cyclic and each thread should be in charge of 4 consecutive elements in memory. That is, thread (0/1/2/3/etc., y, z) should handle elements 0-3/4-7/8-11/12-15/etc. This is actually where the convolution layout matters. As said before, our convolution Input/Filter/Output follows the (N, IH, IW, IC)/(FH, FW, IC, OC)/(N, OH, OW, OC) layout. We partition the Output. So we can achieve the nice access pattern for Output and therefore Filter (which also has OC as its innermost dimension), but not Input. Input has IC as its innermost dimension. That’s a reduction dimension; a thread needs to read through the full memory span consecutively. Often it’s not as small as 4. (For vision models, it’s typically larger and larger as we extract more and more high-level features.) So we don’t have consecutive threads always reading consecutive 4-element chunks. That breaks the pattern.
  • Vectorized computation is sort of natural if we can distribute properly (making sure each thread handles 4x elements for the innermost dimension) and perform vectorized load/store.
  • For unrolling, it is sufficient to unroll along OH, OW, and OC. We partitioned along these dimensions, so they have small, statically known sizes, which makes unrolling them both feasible and controllable. We can materialize loops for the other dimensions (FH, FW, and IC).
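The index arithmetic behind the cyclic distribution can be sketched in Python. The helpers are hypothetical and only meant for illustration. thread_oc_chunks mirrors the oc = (threadID.x + threadCount.x * j) * 4 + wgBaseOC expression in the kernel sketch below. nhwc_offset computes a dense NHWC offset, which shows why Input does not get the same pattern: at a fixed ic, threads handling adjacent OW positions read addresses IC elements apart (times the stride).

```python
def thread_oc_chunks(tx, count_x, ivc_oc, wg_base_oc=0):
    """4-element OC chunks owned by thread tx, one per unrolled index j."""
    chunks = []
    for j in range(ivc_oc):
        start = (tx + count_x * j) * 4 + wg_base_oc
        chunks.append(list(range(start, start + 4)))
    return chunks


def nhwc_offset(n, h, w, c, H, W, C):
    """Linear element offset of (n, h, w, c) in a dense NHWC tensor."""
    return ((n * H + h) * W + w) * C + c


# With 4 threads along x, at each j the threads together cover one contiguous span.
for tx in range(4):
    print(tx, thread_oc_chunks(tx, count_x=4, ivc_oc=2))

# Input: adjacent OW positions (stride 1) at the same ic are IC elements apart.
print(nhwc_offset(0, 0, 1, 0, 258, 258, 16) - nhwc_offset(0, 0, 0, 0, 258, 258, 16))
```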

Putting the above together, here is a sketch of the vectorized kernel for one workitem (taken from here):

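// Sketch: IVC_*, WG_TILE_*, FH/FW/IC, SH/SW are constants and *CoordToOffset()
// are helpers from the full source, mapping coordinates to VEC4TYPE offsets.
// Each thread/invocation calculates (IVC_OH * IVC_OW * IVC_OC * 4) output elements.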
VEC4TYPE O[IVC_OH][IVC_OW][IVC_OC];

// Use registers to keep the filter for this tile to increase data reuse.
VEC4TYPE F[4][IVC_OC];

uvec3 wgID = gl_WorkGroupID;
uvec3 threadID = gl_LocalInvocationID;
uvec3 threadCount = gl_WorkGroupSize;

uint wgBaseOC = wgID.x * WG_TILE_OC; // Workgroup base output channel
uint wgBaseOW = wgID.y * WG_TILE_OW; // Workgroup base output width
uint wgBaseOH = wgID.z * WG_TILE_OH; // Workgroup base output height

// Initialize the output for this batch to zero.
[[unroll]] for (uint i = 0; i < IVC_OH; ++i)
  [[unroll]] for (uint j = 0; j < IVC_OW; ++j)
    [[unroll]] for (uint k = 0; k < IVC_OC; ++k)
      O[i][j][k] = VEC4TYPE(0.f, 0.f, 0.f, 0.f);

for (uint fh = 0; fh < FH; ++fh) {
  for (uint fw = 0; fw < FW; ++fw) {
    // Tile input channel with each tile having 4 elements.
    for (uint ic = 0; ic < IC; ic += 4) {
      // Load the filter for this input channel tile.
      [[unroll]] for (uint i = 0; i < 4; ++i) {
        [[unroll]] for (uint j = 0; j < IVC_OC; ++j) {
          // Braces needed: both statements belong to the j loop.
          uint oc = (threadID.x + threadCount.x * j) * 4 + wgBaseOC;
          F[i][j] = Filter.data[filterCoordToOffset(fh, fw, ic + i, oc)];
        }
      }

      // Load this input channel tile and perform dot product with filters
      // for different output elements.
      [[unroll]] for (uint i = 0; i < IVC_OH; ++i) {
        uint oh = i + threadID.z * IVC_OH + wgBaseOH;
        [[unroll]] for (uint j = 0; j < IVC_OW; ++j) {
          uint ow = j + threadID.y * IVC_OW + wgBaseOW;
          VEC4TYPE feature = Input.data[inputCoordToOffset(oh * SH + fh, ow * SW + fw, ic)];
          [[unroll]] for (uint k = 0; k < IVC_OC; ++k) {
            O[i][j][k] += VEC4TYPE(feature.x, feature.x, feature.x, feature.x) * F[0][k];
            O[i][j][k] += VEC4TYPE(feature.y, feature.y, feature.y, feature.y) * F[1][k];
            O[i][j][k] += VEC4TYPE(feature.z, feature.z, feature.z, feature.z) * F[2][k];
            O[i][j][k] += VEC4TYPE(feature.w, feature.w, feature.w, feature.w) * F[3][k];
          }
        }
      }
    }
  }
}

// Write out the computed output elements.
[[unroll]] for (uint i = 0; i < IVC_OH; ++i) {
  uint oh = i + threadID.z * IVC_OH + wgBaseOH;
  [[unroll]] for (uint j = 0; j < IVC_OW; ++j) {
    uint ow = j + threadID.y * IVC_OW + wgBaseOW;
    [[unroll]] for (uint k = 0; k < IVC_OC; ++k) {
      uint oc = (threadID.x + threadCount.x * k) * 4 + wgBaseOC;
      Output.data[outputCoordToOffset(oh, ow, oc)] = O[i][j][k];
    }
  }
}

In the above, we keep the current Output batch in O and preload the Filter batch in F. We loop over FH and FW, and perform FMAs along IC. For IC, although we cannot achieve the perfect memory load pattern, we still try to handle 4 elements each time. This isn’t strictly required though: we could read one scalar element at a time and FMA it into each Output element. That also supports cases like IC == 3, which happens for the initial image, where we typically have 3 channels (RGB).
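To check the loop structure independently of a GPU, the Python sketch below replays it for a single workitem. It is an illustration, not µVkCompute code. Python lists stand in for VEC4TYPE, O and F play the role of registers, and each vec4 feature load is followed by four broadcast FMAs. For simplicity, it assumes one thread along x (so the workitem's OC chunks are contiguous), batch 0, and IC divisible by 4.

```python
def fma4(acc, scalar, vec):
    """acc + broadcast(scalar) * vec on 4-wide 'vectors' (Python lists)."""
    return [a + scalar * v for a, v in zip(acc, vec)]


def workitem_conv(inp, flt, oh0, ow0, oc0, ivc_oh, ivc_ow, ivc_oc, sh=1, sw=1):
    """Computes an ivc_oh x ivc_ow x (ivc_oc * 4) Output patch of batch 0
    starting at (oh0, ow0, oc0), following the structure of the GLSL sketch.
    inp is (N, IH, IW, IC) and flt is (FH, FW, IC, OC) as nested lists."""
    FH, FW, IC = len(flt), len(flt[0]), len(flt[0][0])
    assert IC % 4 == 0, "this sketch only handles IC divisible by 4"
    # Accumulators ("registers") for the whole Output patch, zero-initialized.
    O = [[[[0.0] * 4 for _ in range(ivc_oc)] for _ in range(ivc_ow)]
         for _ in range(ivc_oh)]
    for fh in range(FH):
        for fw in range(FW):
            for ic in range(0, IC, 4):
                # Preload the filter vec4s for this 4-channel IC tile.
                F = [[flt[fh][fw][ic + i][oc0 + 4 * k: oc0 + 4 * k + 4]
                      for k in range(ivc_oc)] for i in range(4)]
                for i in range(ivc_oh):
                    for j in range(ivc_ow):
                        # One vec4 load of 4 consecutive input channels.
                        feature = inp[0][(oh0 + i) * sh + fh][(ow0 + j) * sw + fw][ic:ic + 4]
                        for k in range(ivc_oc):
                            for c in range(4):  # feature.x, .y, .z, .w
                                O[i][j][k] = fma4(O[i][j][k], feature[c], F[c][k])
    return O
```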

The above code is taken from µVkCompute, where one can directly write kernels with simple Vulkan compute pipelines to try out different CodeGen strategies. You can find the runnable code there, which also has a packed fp16 version. Running it on a Samsung Galaxy S21 (Exynos 2100, Mali-G78 MP14):

> adb shell /data/local/tmp/conv2d_mali_valhall --latency_measure_mode=gpu_timestamp

2021-09-12T16:03:26-04:00
Running /data/local/tmp/conv2d_mali_valhall
Run on (8 X 2210 MHz CPU s)
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                         Time             CPU   Iterations UserCounters...
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[1x16x16]/WGSize[4x4x1]/f32/manual_time       13431 us          837 us           53 FLOps=359.761G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x8x16]/WGSize[4x4x1]/f32/manual_time        13489 us          711 us           52 FLOps=358.216G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[4x4x16]/WGSize[4x4x1]/f32/manual_time        14216 us          757 us           49 FLOps=339.894G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x8x16]/WGSize[4x2x2]/f32/manual_time        13337 us          433 us           53 FLOps=362.281G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[4x4x16]/WGSize[4x2x2]/f32/manual_time        13681 us          678 us           51 FLOps=353.184G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[8x2x16]/WGSize[4x2x2]/f32/manual_time        13544 us          550 us           52 FLOps=356.754G/s
...
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x4x32]/WGSize[4x2x2]/f16/manual_time         8405 us          520 us           83 FLOps=574.908G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x8x32]/WGSize[4x2x2]/f16/manual_time         6019 us          412 us          119 FLOps=802.734G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[4x4x32]/WGSize[4x2x2]/f16/manual_time         6243 us          846 us          117 FLOps=773.904G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[8x2x32]/WGSize[4x2x2]/f16/manual_time         6169 us          794 us          119 FLOps=783.25G/s
...

The theoretical peak is roughly 760/1520 GFLOPs for fp32/fp16[2]. So this achieves roughly 50% utilization. Not bad for straightforward tiling and direct vectorization.
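As a sanity check on these numbers, the Python sketch below recomputes throughput and utilization from two benchmark rows. It assumes the benchmark counts each multiply-add as 2 FLOPs, which reproduces the reported FLOps counters to within the rounding of the printed times. The peak figures are the fp32/fp16 ones quoted above (760 and 1520 GFLOPs).

```python
def conv_flops(n, oh, ow, oc, fh, fw, ic):
    """FLOPs of a direct convolution, counting one multiply-add as 2 FLOPs."""
    return 2 * n * oh * ow * oc * fh * fw * ic


def gflops_per_second(flops, time_us):
    """Throughput in GFLOP/s given a kernel time in microseconds."""
    return flops / (time_us * 1e-6) / 1e9


# Input[1x258x258x16] x Filter[3x3x16x256], stride 1 -> Output 1x256x256x256.
flops = conv_flops(1, 256, 256, 256, 3, 3, 16)
fp32 = gflops_per_second(flops, 13337)  # Tile[2x8x16]/WGSize[4x2x2]/f32 row
fp16 = gflops_per_second(flops, 6019)   # Tile[2x8x32]/WGSize[4x2x2]/f16 row
print(f"fp32: {fp32:.1f} GFLOP/s, {fp32 / 760:.0%} of the 760 GFLOPs peak")
print(f"fp16: {fp16:.1f} GFLOP/s, {fp16 / 1520:.0%} of the 1520 GFLOPs peak")
```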

Okay, so far we have a clear overall CodeGen strategy and know what the inner kernel should look like. Let’s see how we can generate such code with MLIR.

The MLIR/IREE CodeGen Flow

In MLIR we have many dialects, modelling abstractions at different levels. The ones involved here are mhlo, linalg, vector, scf, and spirv. I won’t go into details about them; you can reference their documentation by following the embedded links. Just mentioning how they are used in our flow:

  • mhlo: input dialect for models authored in TensorFlow.
  • linalg: core dialect that models linear algebra computations on n-D tensors or buffers. We use it to perform tiling and distribution.
  • vector: core dialect for modelling n-D static-shaped vectors. We use it to perform vectorization and vector level optimizations like load-store forwarding.
  • scf: structured control flow, here the loops around the computation.
  • spirv: final output dialect for generated kernels.

The journey starts with the following source convolution:

func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>)
          -> tensor<1x112x112x32xf32> {
  %0 = mhlo.convolution(%input, %filter)
            dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
            window = {stride = [2, 2], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]}
            {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
            : (tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) -> tensor<1x112x112x32xf32>
  return %0 : tensor<1x112x112x32xf32>
}

If you’d like to see the flow for yourself, you can build iree-translate and invoke it on the above function with the following command:

iree/tools/iree-translate \
  -iree-input-type=mhlo \
  -iree-mlir-to-vm-bytecode-module \
  -iree-hal-target-backends=vulkan-spirv \
  -iree-vulkan-target-triple=valhall-unknown-android11 \
  -print-ir-after-all \
  -mlir-print-local-scope \
  conv.mlir -o iree.vmfb &> conv-conversion.mlir

The IREE codebase changes very quickly. I use IREE@080cbc46 in this blog post (for both IR snippets and code pointers). If you’d like to see the exact same result, you can check that commit out. I also put the full dump in a Gist that you can use as a reference.

mhlo to linalg conversion

The first major step is to convert mhlo.convolution into a linalg.conv op. This is done via the ConvertMHLOToLinalgOnTensors pass.

// -----// IR Dump After ConvertMHLOToLinalgOnTensors //----- //
func @conv(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
          -> !hal.buffer_view attributes {iree.abi.stub} {
  %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<1x225x225x3xf32>
  %1 = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<3x3x3x32xf32>
  %2 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
  %cst = constant 0.000000e+00 : f32
  %3 = linalg.fill(%cst, %2) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
  %4 = linalg.conv_2d_nhwc_hwcf {
         dilations = dense<1> : tensor<2xi64>,
         strides = dense<2> : tensor<2xi64>
       }
       ins(%0, %1 : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
       outs(%3 : tensor<1x112x112x32xf32>)
       -> tensor<1x112x112x32xf32>
  %5 = hal.tensor.cast %4 : tensor<1x112x112x32xf32> -> !hal.buffer_view
  return %5 : !hal.buffer_view
}

There isn’t much to say here. It’s basically pattern matching against the source mhlo convolution op, checking its various attributes, and emitting the suitable linalg convolution op.

Unlike mhlo, where we have one single convolution op supporting many different configurations, in linalg, there are many different ops for different flavors of convolution, e.g., linalg.conv_1d_nwc_wcf, linalg.conv_2d_nhwc_hwcf. They are all called named ops. It’s simple to define a new named op in linalg. They serve as the anchor for pattern matching and specialization for CodeGen or library calls.

But named ops aren’t core to linalg; they are just named versions of particular forms of the linalg.generic op. linalg.generic has an implicit loop nest in it and carries an affine map for each operand/result. The affine map defines the access pattern of the corresponding operand/result in the implicit loop nest. The payload in the loop nest is explicitly captured as an MLIR region of the linalg.generic op.
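To make these semantics concrete, the Python sketch below interprets a generic op in the simplest possible way. It is an illustration, not MLIR's API or IREE code: iterate over the full implicit loop nest, read each operand through its indexing map, and fold the payload into the output. Tensors are modeled as dicts keyed by index tuples. The maps are the ones a 2-D NHWC/HWCF convolution with strides and no dilation would use, written as Python lambdas over the loops (n, oh, ow, oc, fh, fw, ic).

```python
import itertools


def conv_2d_nhwc_hwcf_maps(sh, sw):
    """Indexing maps from loop indices (n, oh, ow, oc, fh, fw, ic) to operand
    indices. The first four loops are parallel, the last three reductions."""
    return [
        lambda n, oh, ow, oc, fh, fw, ic: (n, oh * sh + fh, ow * sw + fw, ic),  # Input
        lambda n, oh, ow, oc, fh, fw, ic: (fh, fw, ic, oc),                      # Filter
        lambda n, oh, ow, oc, fh, fw, ic: (n, oh, ow, oc),                       # Output
    ]


def generic(loop_bounds, maps, ins, out, payload):
    """Runs the implicit perfect loop nest; maps[-1] indexes the output."""
    for ivs in itertools.product(*(range(b) for b in loop_bounds)):
        args = [operand[m(*ivs)] for operand, m in zip(ins, maps)]
        out_idx = maps[-1](*ivs)
        out[out_idx] = payload(*args, out[out_idx])
    return out


def full(shape, value):
    """A dense 'tensor' as a dict from index tuples to values."""
    return {idx: value for idx in itertools.product(*(range(s) for s in shape))}


inp = full((1, 3, 3, 1), 1.0)
flt = full((2, 2, 1, 1), 1.0)
out = full((1, 2, 2, 1), 0.0)
# Loop bounds for (n, oh, ow, oc, fh, fw, ic); the payload is a multiply-add.
generic((1, 2, 2, 1, 2, 2, 1), conv_2d_nhwc_hwcf_maps(1, 1), [inp, flt], out,
        lambda x, f, acc: acc + x * f)
print(out)
```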

Due to this structured representation, loop transformations are quite straightforward because by definition the op is a perfect loop nest and there is no explicit loop nest to match/manipulate for transformations. You can read more about this in the documentation.

Tiling and distributing to workgroups

Now we can start the first level tiling and distribution. In IREE this is done via the DispatchLinalgOnTensors pass. What we need to do is to utilize the LinalgBaseTilingPattern with a proper LinalgTilingOptions, which controls aspects like

  • whether we want to tile/partition along each loop and what the tile size is,
  • if we want to distribute, what the processor IDs are and how to distribute,
  • what kind of loops to generate,
  • etc.

You can find the example configuration for IREE’s workgroup tiling here. In IREE, tiling to workgroups is initially done in an abstract way, where we use symbolic values like flow.dispatch.workgroup.size for tile sizes, and flow.dispatch.workgroup.id/flow.dispatch.workgroup.count for processor IDs/counts. Right now we generate scf.for loops. So with the linked configuration, the tiled code looks like:

// -----// IR Dump After DispatchLinalgOnTensors //----- //
func @conv(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
          -> !hal.buffer_view attributes {iree.abi.stub} {
  %c32 = constant 32 : index
  %c112 = constant 112 : index
  %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<1x225x225x3xf32>
  %1 = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<3x3x3x32xf32>
  %2 = flow.dispatch.workgroups[%c32, %c112, %c112](%0, %1)
       : (tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) -> tensor<1x112x112x32xf32> =
       (%arg2: !flow.dispatch.tensor<readonly:1x225x225x3xf32>,
        %arg3: !flow.dispatch.tensor<readonly:3x3x3x32xf32>,
        %arg4: !flow.dispatch.tensor<writeonly:1x112x112x32xf32>) {
    %cst = constant 0.000000e+00 : f32
    %c112_0 = constant 112 : index
    %c32_1 = constant 32 : index
    %4 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
    %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
    %workgroup_size_1 = flow.dispatch.workgroup.size[1] : index
    %workgroup_size_2 = flow.dispatch.workgroup.size[2] : index
    %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
    %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
    %workgroup_id_1 = flow.dispatch.workgroup.id[1] : index
    %workgroup_count_1 = flow.dispatch.workgroup.count[1] : index
    %workgroup_id_2 = flow.dispatch.workgroup.id[2] : index
    %workgroup_count_2 = flow.dispatch.workgroup.count[2] : index
    %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_2, %workgroup_size_2]
    %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_2, %workgroup_size_2]
    scf.for %arg5 = %5 to %c112_0 step %6 {
      %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_1, %workgroup_size_1]
      %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_1, %workgroup_size_1]
      scf.for %arg6 = %7 to %c112_0 step %8 {
        %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
        %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
        scf.for %arg7 = %9 to %c32_1 step %10 {
          %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg5)
          %12 = affine.min affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 227)>(%workgroup_size_2, %arg5)
          %13 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg6)
          %14 = affine.min affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 227)>(%workgroup_size_1, %arg6)
          %15 = flow.dispatch.tensor.load %arg2,
                  offsets = [0, %11, %13, 0], sizes = [1, %12, %14, 3], strides = [1, 1, 1, 1]
          %16 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 32)>(%workgroup_size_0, %arg7)
          %17 = flow.dispatch.tensor.load %arg3,
                  offsets = [0, 0, 0, %arg7], sizes = [3, 3, 3, %16], strides = [1, 1, 1, 1]
          %18 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 112)>(%workgroup_size_2, %arg5)
          %19 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 112)>(%workgroup_size_1, %arg6)
          %20 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 32)>(%workgroup_size_0, %arg7)
          %21 = affine.min affine_map<(d0, d1) -> (-d0 + 112, d1)>(%arg5, %workgroup_size_2)
          %22 = affine.min affine_map<(d0, d1) -> (-d0 + 112, d1)>(%arg6, %workgroup_size_1)
          %23 = affine.min affine_map<(d0, d1) -> (-d0 + 32, d1)>(%arg7, %workgroup_size_0)
          %24 = tensor.extract_slice %4[0, %arg5, %arg6, %arg7] [1, %21, %22, %23] [1, 1, 1, 1]
                : tensor<1x112x112x32xf32> to tensor<1x?x?x?xf32>
          %25 = linalg.fill(%cst, %24) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
          %26 = linalg.conv_2d_nhwc_hwcf {
                  __internal_linalg_transform__ = "workgroup",
                  dilations = dense<1> : tensor<2xi64>,
                  strides = dense<2> : tensor<2xi64>
                }
                ins(%15, %17 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>)
                outs(%25 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
          flow.dispatch.tensor.store %26, %arg4,
            offsets = [0, %arg5, %arg6, %arg7], sizes = [1, %18, %19, %20], strides = [1, 1, 1, 1]
        }
      }
    }
    flow.return
  }
  %3 = hal.tensor.cast %2 : tensor<1x112x112x32xf32> -> !hal.buffer_view
  return %3 : !hal.buffer_view
}

I’ve trimmed the raw output down a bit (by removing some type annotations), but the IR snippets will still get lengthy from now on, as we are going to lower abstractions.

In the above, we can see three loops that partition along the OH, OW, and OC dimensions and distribute them to workgroup z, y, and x dimensions. Inside the loop nest, a linalg.conv_2d_nhwc_hwcf op works on a smaller-scale tile, together with the linalg.fill for its output initialization. The affine.apply/affine.min ops inside the loop nest calculate the tile indices and sizes to make sure we don’t go out of bounds.
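The distribution scheme in these loops can be modeled in a few lines. The Python sketch below uses a hypothetical helper, not IREE code, to mirror one such loop along a single dimension. A workgroup starts at id * size, strides by count * size, and clamps the last tile with a min, like the affine.min ops above. The usage checks that every Output index along the dimension is covered exactly once, even when the tile size does not divide the extent. When count * size equals the extent, every loop runs exactly one trip.

```python
def workgroup_tiles(wg_id, wg_count, tile, ub):
    """(offset, size) of each tile visited by workgroup wg_id, mirroring
    scf.for %iv = wg_id * tile to ub step wg_count * tile, with the tile
    size clamped to min(tile, ub - iv)."""
    tiles = []
    iv = wg_id * tile
    while iv < ub:
        tiles.append((iv, min(tile, ub - iv)))
        iv += wg_count * tile
    return tiles


# 4 workgroups with tile size 5 along OH (extent 112): last tile is partial.
covered = []
for wg in range(4):
    for iv, size in workgroup_tiles(wg, 4, 5, 112):
        covered.extend(range(iv, iv + size))
print(sorted(covered) == list(range(112)))
```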

This performs the first-level tiling. The abstract tiling is meant to handle both CPU and GPU uniformly, so at this stage we cannot determine the concrete configuration.

Afterwards, we are down to the GPU-specific CodeGen path. We need to start injecting static information to simplify the IR (particularly those affine.apply/affine.min ops). We can find a concrete tiling and workgroup size scheme that perfectly partitions the convolution Output. Here, that is workgroup size (8, 2, 1), with each workgroup processing a 1x8x32 (OHxOWxOC) Output patch[3].
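The claim that this scheme perfectly partitions the Output can be checked with a bit of arithmetic. The Python sketch below assumes the mapping described above: workgroup/thread x, y, z map to OC, OW, OH. It derives the number of workgroups along each dimension and the per-thread tile.

```python
OUTPUT = {"OH": 112, "OW": 112, "OC": 32}   # Output is 1x112x112x32
WG_TILE = {"OH": 1, "OW": 8, "OC": 32}      # per-workgroup Output patch
WG_SIZE = {"x": 8, "y": 2, "z": 1}          # threads per workgroup
DIM_OF = {"x": "OC", "y": "OW", "z": "OH"}  # workgroup/thread axis -> Output dim

wg_count = {}
thread_tile = {}
for axis, dim in DIM_OF.items():
    # Perfect partitioning: no partial tiles at either level.
    assert OUTPUT[dim] % WG_TILE[dim] == 0
    assert WG_TILE[dim] % WG_SIZE[axis] == 0
    wg_count[axis] = OUTPUT[dim] // WG_TILE[dim]
    thread_tile[dim] = WG_TILE[dim] // WG_SIZE[axis]

print("workgroup count:", wg_count)
print("per-thread Output tile:", thread_tile)
```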

Performing bufferization and injecting such information, we can get:

// -----// IR Dump After SPIRVRemoveOneTripTiledLoop //----- //
func @conv_dispatch_0() {
  %c0 = constant 0 : index
  %c112 = constant 112 : index
  %cst = constant 0.000000e+00 : f32
  %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
  %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
  %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
    %3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
    %4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
    %5 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
    %6 = affine.min affine_map<(d0) -> (3, d0 * -2 + 227)>(%arg0)
    %7 = affine.apply affine_map<(d0) -> (d0 * 2)>(%3)
    %8 = affine.min affine_map<(d0) -> (17, d0 * -2 + 227)>(%3)
    %9 = memref.subview %0[0, %5, %7, 0] [1, %6, %8, 3] [1, 1, 1, 1] : ...
    %10 = affine.min affine_map<(d0) -> (32, -d0 + 32)>(%4)
    %11 = memref.subview %1[0, 0, 0, %4] [3, 3, 3, %10] [1, 1, 1, 1] : ...
    %12 = affine.min affine_map<(d0) -> (1, -d0 + 112)>(%arg0)
    %13 = affine.min affine_map<(d0) -> (8, -d0 + 112)>(%3)
    %14 = memref.subview %2[0, %arg0, %3, %4] [1, %12, %13, %10] [1, 1, 1, 1] : ...
    linalg.fill(%cst, %14) ...
    linalg.conv_2d_nhwc_hwcf {
      __internal_linalg_transform__ = "workgroup",
      dilations = dense<1> : tensor<2xi64>,
      lowering.config = {tileSizes = [[0, 1, 8, 32], [], [0, 1, 4, 4]]},
      strides = dense<2> : tensor<2xi64>
    } ins(%9, %11 : ...) outs(%14 : ...>)
  }
  return
}

It’s much cleaner now. The outer 1-trip loops are basically gone. Next up is the second-level tiling and distribution to workitems.

Tiling and distributing to workitems

This is done by the SPIRVTileAndDistribute pass. It’s quite similar to the first level tiling and distribution. It’s just a matter of setting the proper LinalgTilingOptions.

In the above step, we chose workgroup size (8, 2, 1) and let each workgroup handle a 1x8x32 (OHxOWxOC) Output patch. So naturally, each thread handles a 1x4x4 Output patch. Using such static information, we can fold quite a few affine ops away and generate the following IR:

// -----// IR Dump After SPIRVTileAndDistribute //----- //
func @conv_dispatch_0() {
  %c0 = constant 0 : index
  %c112 = constant 112 : index
  %cst = constant 0.000000e+00 : f32
  %c3 = constant 3 : index
  %c1 = constant 1 : index
  %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
  %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
  %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
    %3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
    %4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
    %5 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
    %6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
    %7 = memref.subview %0[0, %5, %6, 0] [1, 3, 17, 3] [1, 1, 1, 1]
         : memref<1x225x225x3xf32> to memref<1x3x17x3xf32, ...>
    %8 = memref.subview %1[0, 0, 0, %4] [3, 3, 3, 32] [1, 1, 1, 1]
         : memref<3x3x3x32xf32> to memref<3x3x3x32xf32, ...>
    %9 = memref.subview %2[0, %arg0, %3, %4] [1, 1, 8, 32] [1, 1, 1, 1]
         : memref<1x112x112x32xf32> to memref<1x1x8x32xf32, ...>
    %10 = "gpu.thread_id"() {dimension = "x"} : () -> index
    %11 = "gpu.thread_id"() {dimension = "y"} : () -> index
    %12 = "gpu.thread_id"() {dimension = "z"} : () -> index
    %13 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%11]
    %14 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%10]
    %15 = memref.subview %9[0, %12, %13, %14] [1, 1, 4, 4] [1, 1, 1, 1]
          : memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
    linalg.fill(%cst, %15) ...
    %16 = "gpu.thread_id"() {dimension = "x"} : () -> index
    %17 = "gpu.thread_id"() {dimension = "y"} : () -> index
    %18 = "gpu.thread_id"() {dimension = "z"} : () -> index
    %19 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%17]
    %20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%16]
    %21 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%18]
    %22 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%17]
    %23 = memref.subview %7[0, %21, %22, 0] [1, 3, 9, 3] [1, 1, 1, 1]
          : memref<1x3x17x3xf32, ...> to memref<1x3x9x3xf32, ...>
    %24 = memref.subview %8[0, 0, 0, %20] [3, 3, 3, 4] [1, 1, 1, 1]
          : memref<3x3x3x32xf32, ...> to memref<3x3x3x4xf32, ...>
    %25 = memref.subview %9[0, %18, %19, %20] [1, 1, 4, 4] [1, 1, 1, 1]
          : memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
    scf.for %arg1 = %c0 to %c3 step %c1 {
      scf.for %arg2 = %c0 to %c3 step %c1 {
        %26 = memref.subview %23[0, %arg1, %arg2, 0] [1, 1, 7, 3] [1, 1, 1, 1]
              : memref<1x3x9x3xf32, ...> to memref<1x1x7x3xf32, ...>
        %27 = memref.subview %24[%arg1, %arg2, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1]
              : memref<3x3x3x4xf32, ...> to memref<1x1x3x4xf32, ...>
        linalg.conv_2d_nhwc_hwcf {
          __internal_linalg_transform__ = "vectorize",
          dilations = dense<1> : tensor<2xi64>,
          lowering.config = {tileSizes = [[0, 1, 8, 32], [], [0, 1, 4, 4]]},
          strides = dense<2> : tensor<2xi64>
        }
        ins(%26, %27 : memref<1x1x7x3xf32, ...>, memref<1x1x3x4xf32, ...>)
        outs(%25 : memref<1x1x4x4xf32, ...>)
      }
    }
  }
  return
}

Unlike the first level, no loops are generated for the distributed OH/OW/OC dimensions, since each workitem handles exactly one tile. The remaining scf.for loops iterate over the FH and FW filter dimensions.

Okay, now at the innermost level we have a linalg.conv_2d_nhwc_hwcf op working on a 1x1x4x4 Output batch. That’s exactly the smaller-scale case we want to vectorize.
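The subview sizes and offsets in the dump follow from simple arithmetic. The Python sketch below recomputes them using hypothetical helpers, not IREE code. An Output tile of extent out along a dimension needs (out - 1) * stride + filter Input elements. That gives the 17-wide workgroup Input window, the 9-wide workitem window, and the 7-wide window once FW is tiled to 1. The per-thread offsets reproduce the s0 * 4, s0 * 2, and s0 * 8 affine maps applied to the thread IDs.

```python
def input_extent(out_size, stride, filter_size):
    """Input rows/columns needed to produce out_size Output rows/columns
    (no padding, no dilation)."""
    return (out_size - 1) * stride + filter_size


def thread_offsets(tx, ty, tz, stride=2, thread_tile=(1, 4, 4)):
    """Offsets of one workitem's tiles inside its workgroup's subviews.

    thread_tile is (OH, OW, OC). Returns ((oh, ow, oc), (ih, iw)): the
    Output tile offset and the matching Input window offset."""
    t_oh, t_ow, t_oc = thread_tile
    oh, ow, oc = tz * t_oh, ty * t_ow, tx * t_oc
    return (oh, ow, oc), (oh * stride, ow * stride)


# Stride 2 and a 3x3 filter, as in the example.
print(input_extent(8, 2, 3), input_extent(4, 2, 3), input_extent(1, 2, 3))
print(thread_offsets(3, 1, 0))
```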

Final vectorization and unrolling

We can now directly vectorize the inner linalg.conv_2d_nhwc_hwcf op, following the logic in the Vectorization and unrolling section.
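To illustrate what vectorizing this tile amounts to, here is a plain-Python model of the decomposition from the Vectorization and unrolling section. It is not the IREE pattern itself. Inside the FH/FW loops, each iteration sees a 1x1x7x3 Input slice, a 1x1x3x4 Filter slice, and the 1x1x4x4 Output tile. Each of the 4 Output columns is a 4-wide OC vector that gets one broadcast-scalar FMA per input channel. With IC == 3, the fully unrolled body is therefore 12 vector FMAs. Lists stand in for vector<4xf32>.

```python
def vectorized_inner_conv(inp, flt, acc, stride=2):
    """inp: 7x3 (OW window x IC) slice of the 1x1x7x3 Input subview.
    flt: 3x4 (IC x OC) slice of the 1x1x3x4 Filter subview.
    acc: 4 rows (OW) of 4-wide OC vectors, i.e. the 1x1x4x4 Output tile.
    Returns the updated tile and the number of 4-wide FMAs issued."""
    num_fma = 0
    out = [row[:] for row in acc]
    for ow in range(len(out)):       # unrolled: 4 Output columns
        for ic in range(len(flt)):   # unrolled: IC == 3 here
            x = inp[ow * stride][ic]  # scalar, broadcast to all 4 lanes
            out[ow] = [a + x * f for a, f in zip(out[ow], flt[ic])]
            num_fma += 1
    return out, num_fma


inp = [[float(w * 3 + c) for c in range(3)] for w in range(7)]
flt = [[1.0 if oc == ic else 0.0 for oc in range(4)] for ic in range(3)]
tile, fmas = vectorized_inner_conv(inp, flt, [[0.0] * 4 for _ in range(4)])
print(tile, fmas)
```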

Right now the vectorization pattern is implemented in IREE; I’ll upstream it to the MLIR repo later. It’s relatively straightforward given that we are handling a well-defined small-scale problem here. With it and a bunch of other patterns pulled into the SPIRVVectorize pass, we have:

// -----// IR Dump After SPIRVVectorize //----- //
func @conv_dispatch_0() {
  %c0 = constant 0 : index
  %c112 = constant 112 : index
  %c3 = constant 3 : index
  %c1 = constant 1 : index
  %cst = constant dense<0.000000e+00> : vector<1x1x4x4xf32>
  %c4 = constant 4 : index
  %c2 = constant 2 : index
  %c6 = constant 6 : index
  %cst_0 = constant 0.000000e+00 : f32
  %cst_1 = constant dense<0.000000e+00> : vector<1x4xf32>
  %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
  %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
  %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
  %4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
  %5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
  %6 = memref.subview %1[0, 0, 0, %4] [3, 3, 3, 32] [1, 1, 1, 1] : memref<3x3x3x32xf32> to memref<3x3x3x32xf32, ...>
  %7 = "gpu.thread_id"() {dimension = "x"} : () -> index
  %8 = "gpu.thread_id"() {dimension = "y"} : () -> index
  %9 = "gpu.thread_id"() {dimension = "z"} : () -> index
  %10 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%8]
  %11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%7]
  %12 = vector.extract_strided_slice %cst {offsets = [0, 0, 0, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
        : vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
  %13 = vector.extract_strided_slice %cst {offsets = [0, 0, 1, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
        : vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
  %14 = vector.extract_strided_slice %cst {offsets = [0, 0, 2, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
        : vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
  %15 = vector.extract_strided_slice %cst {offsets = [0, 0, 3, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
        : vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
  %16 = "gpu.thread_id"() {dimension = "x"} : () -> index
  %17 = "gpu.thread_id"() {dimension = "y"} : () -> index
  %18 = "gpu.thread_id"() {dimension = "z"} : () -> index
  %19 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%17]
  %20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%16]
  %21 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%18]
  %22 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%17]
  %23 = memref.subview %6[0, 0, 0, %20] [3, 3, 3, 4] [1, 1, 1, 1] : memref<3x3x3x32xf32, ...> to memref<3x3x3x4xf32, ...>
  scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
    %24 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
    %25 = memref.subview %0[0, %24, %5, 0] [1, 3, 17, 3] [1, 1, 1, 1] : memref<1x225x225x3xf32> to memref<1x3x17x3xf32, ...>
    %26 = memref.subview %2[0, %arg0, %3, %4] [1, 1, 8, 32] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x1x8x32xf32, ...>
    %27 = memref.subview %26[0, %9, %10, %11] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
    vector.transfer_write %12, %27[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
    vector.transfer_write %13, %27[%c0, %c0, %c1, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
    vector.transfer_write %14, %27[%c0, %c0, %c2, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
    vector.transfer_write %15, %27[%c0, %c0, %c3, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
    %28 = memref.subview %25[0, %21, %22, 0] [1, 3, 9, 3] [1, 1, 1, 1] : memref<1x3x17x3xf32, ...> to memref<1x3x9x3xf32, ...>
    %29 = memref.subview %26[0, %18, %19, %20] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
    %30 = vector.transfer_read %29[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
    %31 = vector.transfer_read %29[%c0, %c0, %c1, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
    %32 = vector.transfer_read %29[%c0, %c0, %c2, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
    %33 = vector.transfer_read %29[%c0, %c0, %c3, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
    %34:4 = scf.for %arg1 = %c0 to %c3 step %c1 iter_args(%arg2 = %30, %arg3 = %31, %arg4 = %32, %arg5 = %33)
            -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
      %35:4 = scf.for %arg6 = %c0 to %c3 step %c1 iter_args(%arg7 = %arg2, %arg8 = %arg3, %arg9 = %arg4, %arg10 = %arg5)
              -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
        %36 = memref.subview %28[0, %arg1, %arg6, 0] [1, 1, 7, 3] [1, 1, 1, 1] : memref<1x3x9x3xf32, ...> to memref<1x1x7x3xf32, ...>
        %37 = memref.subview %23[%arg1, %arg6, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1] : memref<3x3x3x4xf32, ...> to memref<1x1x3x4xf32, ...>
        %38 = vector.transfer_read %37[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x3x4xf32, ...>, vector<1x4xf32>
        %39 = vector.transfer_read %37[%c0, %c0, %c1, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x3x4xf32, ...>, vector<1x4xf32>
        %40 = vector.transfer_read %37[%c0, %c0, %c2, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x3x4xf32, ...>, vector<1x4xf32>
        %41 = vector.transfer_read %36[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
        %42 = vector.extract_strided_slice %41 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %43 = vector.extract %38[0] : vector<1x4xf32>
        %44 = vector.extract %42[0, 0] : vector<1x1xf32>
        %45 = splat %44 : vector<4xf32>
        %46 = vector.extract %arg7[0] : vector<1x4xf32>
        %47 = vector.fma %45, %43, %46 : vector<4xf32>
        %48 = vector.extract_strided_slice %41 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %49 = vector.extract %39[0] : vector<1x4xf32>
        %50 = vector.extract %48[0, 0] : vector<1x1xf32>
        %51 = splat %50 : vector<4xf32>
        %52 = vector.fma %51, %49, %47 : vector<4xf32>
        %53 = vector.extract_strided_slice %41 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %54 = vector.extract %40[0] : vector<1x4xf32>
        %55 = vector.extract %53[0, 0] : vector<1x1xf32>
        %56 = splat %55 : vector<4xf32>
        %57 = vector.fma %56, %54, %52 : vector<4xf32>
        %58 = vector.insert %57, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
        %59 = vector.transfer_read %36[%c0, %c0, %c2, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
        %60 = vector.extract_strided_slice %59 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %61 = vector.extract %38[0] : vector<1x4xf32>
        %62 = vector.extract %60[0, 0] : vector<1x1xf32>
        %63 = splat %62 : vector<4xf32>
        %64 = vector.extract %arg8[0] : vector<1x4xf32>
        %65 = vector.fma %63, %61, %64 : vector<4xf32>
        %66 = vector.extract_strided_slice %59 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %67 = vector.extract %39[0] : vector<1x4xf32>
        %68 = vector.extract %66[0, 0] : vector<1x1xf32>
        %69 = splat %68 : vector<4xf32>
        %70 = vector.fma %69, %67, %65 : vector<4xf32>
        %71 = vector.extract_strided_slice %59 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %72 = vector.extract %40[0] : vector<1x4xf32>
        %73 = vector.extract %71[0, 0] : vector<1x1xf32>
        %74 = splat %73 : vector<4xf32>
        %75 = vector.fma %74, %72, %70 : vector<4xf32>
        %76 = vector.insert %75, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
        %77 = vector.transfer_read %36[%c0, %c0, %c4, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
        %78 = vector.extract_strided_slice %77 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %79 = vector.extract %38[0] : vector<1x4xf32>
        %80 = vector.extract %78[0, 0] : vector<1x1xf32>
        %81 = splat %80 : vector<4xf32>
        %82 = vector.extract %arg9[0] : vector<1x4xf32>
        %83 = vector.fma %81, %79, %82 : vector<4xf32>
        %84 = vector.extract_strided_slice %77 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %85 = vector.extract %39[0] : vector<1x4xf32>
        %86 = vector.extract %84[0, 0] : vector<1x1xf32>
        %87 = splat %86 : vector<4xf32>
        %88 = vector.fma %87, %85, %83 : vector<4xf32>
        %89 = vector.extract_strided_slice %77 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %90 = vector.extract %40[0] : vector<1x4xf32>
        %91 = vector.extract %89[0, 0] : vector<1x1xf32>
        %92 = splat %91 : vector<4xf32>
        %93 = vector.fma %92, %90, %88 : vector<4xf32>
        %94 = vector.insert %93, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
        %95 = vector.transfer_read %36[%c0, %c0, %c6, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
        %96 = vector.extract_strided_slice %95 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %97 = vector.extract %38[0] : vector<1x4xf32>
        %98 = vector.extract %96[0, 0] : vector<1x1xf32>
        %99 = splat %98 : vector<4xf32>
        %100 = vector.extract %arg10[0] : vector<1x4xf32>
        %101 = vector.fma %99, %97, %100 : vector<4xf32>
        %102 = vector.extract_strided_slice %95 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %103 = vector.extract %39[0] : vector<1x4xf32>
        %104 = vector.extract %102[0, 0] : vector<1x1xf32>
        %105 = splat %104 : vector<4xf32>
        %106 = vector.fma %105, %103, %101 : vector<4xf32>
        %107 = vector.extract_strided_slice %95 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
        %108 = vector.extract %40[0] : vector<1x4xf32>
        %109 = vector.extract %107[0, 0] : vector<1x1xf32>
        %110 = splat %109 : vector<4xf32>
        %111 = vector.fma %110, %108, %106 : vector<4xf32>
        %112 = vector.insert %111, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
        scf.yield %58, %76, %94, %112 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
      }
      scf.yield %35#0, %35#1, %35#2, %35#3 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
    }
    vector.transfer_write %34#3, %29[%c0, %c0, %c3, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
    vector.transfer_write %34#2, %29[%c0, %c0, %c2, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
    vector.transfer_write %34#1, %29[%c0, %c0, %c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
    vector.transfer_write %34#0, %29[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
  }
  return
}

This is almost it! But the loads and stores still go through memrefs with scalar f32 elements, so they are not explicitly vectorized yet. We just need one additional step, done via the SPIRVVectorizeLoadStore pass:

// -----// IR Dump After SPIRVVectorizeLoadStore //----- //
module  {
  func @conv_dispatch_0() {
    %cst = constant dense<0.000000e+00> : vector<1x4xf32>
    %c1 = constant 1 : index
    %c3 = constant 3 : index
    %c112 = constant 112 : index
    %c0 = constant 0 : index
    %c32 = constant 32 : index
    %c16 = constant 16 : index
    %c4 = constant 4 : index
    %c2 = constant 2 : index
    %c8 = constant 8 : index
    %cst_0 = constant dense<0.000000e+00> : vector<3xf32>
    %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
    %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x8xvector<4xf32>>
    %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x8xvector<4xf32>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_id_z = hal.interface.workgroup.id[2] : index
    %3 = muli %workgroup_id_y, %c8 : index
    %4 = muli %workgroup_id_x, %c32 : index
    %5 = muli %workgroup_id_y, %c16 : index
    %6 = "gpu.thread_id"() {dimension = "x"} : () -> index
    %7 = "gpu.thread_id"() {dimension = "y"} : () -> index
    %8 = "gpu.thread_id"() {dimension = "z"} : () -> index
    %9 = muli %7, %c4 : index
    %10 = muli %6, %c4 : index
    %11 = muli %8, %c2 : index
    %12 = muli %7, %c8 : index
    scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
      %13 = muli %arg0, %c2 : index
      %14:4 = scf.for %arg1 = %c0 to %c3 step %c1 iter_args(%arg2 = %cst, %arg3 = %cst, %arg4 = %cst, %arg5 = %cst)
              -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
        %29:4 = scf.for %arg6 = %c0 to %c3 step %c1 iter_args(%arg7 = %arg2, %arg8 = %arg3, %arg9 = %arg4, %arg10 = %arg5)
                -> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
          %30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%4, %10]
          %31 = divi_signed %30, %c4 : index
          %32 = memref.load %1[%arg1, %arg6, %c0, %31] : memref<3x3x3x8xvector<4xf32>>
          %33 = divi_signed %30, %c4 : index
          %34 = memref.load %1[%arg1, %arg6, %c1, %33] : memref<3x3x3x8xvector<4xf32>>
          %35 = divi_signed %30, %c4 : index
          %36 = memref.load %1[%arg1, %arg6, %c2, %35] : memref<3x3x3x8xvector<4xf32>>
          %37 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%arg1)[%13, %11]
          %38 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%arg6)[%5, %12]
          %39 = memref.load %0[%c0, %37, %38, %c0] : memref<1x225x225x3xf32>
          %40 = vector.insert %39, %cst_0 [0] : f32 into vector<3xf32>
          %41 = memref.load %0[%c0, %37, %38, %c1] : memref<1x225x225x3xf32>
          %42 = vector.insert %41, %40 [1] : f32 into vector<3xf32>
          %43 = memref.load %0[%c0, %37, %38, %c2] : memref<1x225x225x3xf32>
          %44 = vector.insert %43, %42 [2] : f32 into vector<3xf32>
          %45 = vector.extract_strided_slice %44 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %46 = vector.extract %45[0] : vector<1xf32>
          %47 = splat %46 : vector<4xf32>
          %48 = vector.shape_cast %arg7 : vector<1x4xf32> to vector<4xf32>
          %49 = vector.fma %47, %32, %48 : vector<4xf32>
          %50 = vector.extract_strided_slice %44 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %51 = vector.extract %50[0] : vector<1xf32>
          %52 = splat %51 : vector<4xf32>
          %53 = vector.fma %52, %34, %49 : vector<4xf32>
          %54 = vector.extract_strided_slice %44 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %55 = vector.extract %54[0] : vector<1xf32>
          %56 = splat %55 : vector<4xf32>
          %57 = vector.fma %56, %36, %53 : vector<4xf32>
          %58 = vector.shape_cast %57 : vector<4xf32> to vector<1x4xf32>
          %59 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2 + 2)>()[%5, %12, %arg6]
          %60 = memref.load %0[%c0, %37, %59, %c0] : memref<1x225x225x3xf32>
          %61 = vector.insert %60, %cst_0 [0] : f32 into vector<3xf32>
          %62 = memref.load %0[%c0, %37, %59, %c1] : memref<1x225x225x3xf32>
          %63 = vector.insert %62, %61 [1] : f32 into vector<3xf32>
          %64 = memref.load %0[%c0, %37, %59, %c2] : memref<1x225x225x3xf32>
          %65 = vector.insert %64, %63 [2] : f32 into vector<3xf32>
          %66 = vector.extract_strided_slice %65 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %67 = vector.extract %66[0] : vector<1xf32>
          %68 = splat %67 : vector<4xf32>
          %69 = vector.shape_cast %arg8 : vector<1x4xf32> to vector<4xf32>
          %70 = vector.fma %68, %32, %69 : vector<4xf32>
          %71 = vector.extract_strided_slice %65 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %72 = vector.extract %71[0] : vector<1xf32>
          %73 = splat %72 : vector<4xf32>
          %74 = vector.fma %73, %34, %70 : vector<4xf32>
          %75 = vector.extract_strided_slice %65 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %76 = vector.extract %75[0] : vector<1xf32>
          %77 = splat %76 : vector<4xf32>
          %78 = vector.fma %77, %36, %74 : vector<4xf32>
          %79 = vector.shape_cast %78 : vector<4xf32> to vector<1x4xf32>
          %80 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2 + 4)>()[%5, %12, %arg6]
          %81 = memref.load %0[%c0, %37, %80, %c0] : memref<1x225x225x3xf32>
          %82 = vector.insert %81, %cst_0 [0] : f32 into vector<3xf32>
          %83 = memref.load %0[%c0, %37, %80, %c1] : memref<1x225x225x3xf32>
          %84 = vector.insert %83, %82 [1] : f32 into vector<3xf32>
          %85 = memref.load %0[%c0, %37, %80, %c2] : memref<1x225x225x3xf32>
          %86 = vector.insert %85, %84 [2] : f32 into vector<3xf32>
          %87 = vector.extract_strided_slice %86 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %88 = vector.extract %87[0] : vector<1xf32>
          %89 = splat %88 : vector<4xf32>
          %90 = vector.shape_cast %arg9 : vector<1x4xf32> to vector<4xf32>
          %91 = vector.fma %89, %32, %90 : vector<4xf32>
          %92 = vector.extract_strided_slice %86 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %93 = vector.extract %92[0] : vector<1xf32>
          %94 = splat %93 : vector<4xf32>
          %95 = vector.fma %94, %34, %91 : vector<4xf32>
          %96 = vector.extract_strided_slice %86 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %97 = vector.extract %96[0] : vector<1xf32>
          %98 = splat %97 : vector<4xf32>
          %99 = vector.fma %98, %36, %95 : vector<4xf32>
          %100 = vector.shape_cast %99 : vector<4xf32> to vector<1x4xf32>
          %101 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2 + 6)>()[%5, %12, %arg6]
          %102 = memref.load %0[%c0, %37, %101, %c0] : memref<1x225x225x3xf32>
          %103 = vector.insert %102, %cst_0 [0] : f32 into vector<3xf32>
          %104 = memref.load %0[%c0, %37, %101, %c1] : memref<1x225x225x3xf32>
          %105 = vector.insert %104, %103 [1] : f32 into vector<3xf32>
          %106 = memref.load %0[%c0, %37, %101, %c2] : memref<1x225x225x3xf32>
          %107 = vector.insert %106, %105 [2] : f32 into vector<3xf32>
          %108 = vector.extract_strided_slice %107 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %109 = vector.extract %108[0] : vector<1xf32>
          %110 = splat %109 : vector<4xf32>
          %111 = vector.shape_cast %arg10 : vector<1x4xf32> to vector<4xf32>
          %112 = vector.fma %110, %32, %111 : vector<4xf32>
          %113 = vector.extract_strided_slice %107 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %114 = vector.extract %113[0] : vector<1xf32>
          %115 = splat %114 : vector<4xf32>
          %116 = vector.fma %115, %34, %112 : vector<4xf32>
          %117 = vector.extract_strided_slice %107 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
          %118 = vector.extract %117[0] : vector<1xf32>
          %119 = splat %118 : vector<4xf32>
          %120 = vector.fma %119, %36, %116 : vector<4xf32>
          %121 = vector.shape_cast %120 : vector<4xf32> to vector<1x4xf32>
          scf.yield %58, %79, %100, %121 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
        }
        scf.yield %29#0, %29#1, %29#2, %29#3 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
      }
      %15 = vector.shape_cast %14#3 : vector<1x4xf32> to vector<4xf32>
      %16 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %8]
      %17 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 3)>()[%3, %9]
      %18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%4, %10]
      %19 = divi_signed %18, %c4 : index
      memref.store %15, %2[%c0, %16, %17, %19] : memref<1x112x112x8xvector<4xf32>>
      %20 = vector.shape_cast %14#2 : vector<1x4xf32> to vector<4xf32>
      %21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 2)>()[%3, %9]
      %22 = divi_signed %18, %c4 : index
      memref.store %20, %2[%c0, %16, %21, %22] : memref<1x112x112x8xvector<4xf32>>
      %23 = vector.shape_cast %14#1 : vector<1x4xf32> to vector<4xf32>
      %24 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 1)>()[%3, %9]
      %25 = divi_signed %18, %c4 : index
      memref.store %23, %2[%c0, %16, %24, %25] : memref<1x112x112x8xvector<4xf32>>
      %26 = vector.shape_cast %14#0 : vector<1x4xf32> to vector<4xf32>
      %27 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%3, %9]
      %28 = divi_signed %18, %c4 : index
      memref.store %26, %2[%c0, %16, %27, %28] : memref<1x112x112x8xvector<4xf32>>
    }
    return
  }
  hal.interface private @io {
    hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
    hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
  }
}

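Now the Filter loads, the Output stores, and the computation are all vectorized! The Input is still read with scalar loads, as its innermost dimension (IC = 3) is not a multiple of 4.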
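
The SPIRVVectorizeLoadStore step changes the element type of the Filter and Output memrefs from f32 to vector<4xf32> and divides the innermost index by 4 (the divi_signed ops in the dump). The Python sketch below models only that shape and index rewrite. The helper names are hypothetical, and two rules are this sketch’s assumptions rather than the pass’s documented behavior: accesses must start on a vector boundary, and the innermost dimension must be a multiple of 4. The second rule is consistent with the Input staying scalar.

```python
from typing import Optional, Tuple


def vectorized_shape(shape: Tuple[int, ...], width: int = 4) -> Optional[Tuple[int, ...]]:
    """Shape of the memref after packing the innermost dim into vector<width>.

    Returns None when the innermost dim is not a multiple of width, as for
    the 1x225x225x3 Input, which keeps its scalar f32 element type.
    """
    *outer, inner = shape
    if inner % width != 0:
        return None
    return (*outer, inner // width)


def vectorized_index(indices: Tuple[int, ...], width: int = 4) -> Tuple[int, ...]:
    """Rewrite a scalar access into an index on the packed memref.

    Mirrors the divi_signed %x, %c4 on the innermost index in the IR; the
    sketch requires the access to start on a vector boundary.
    """
    *outer, inner = indices
    if inner % width != 0:
        raise ValueError("innermost index is not aligned to the vector width")
    return (*outer, inner // width)


print(vectorized_shape((3, 3, 3, 32)))      # (3, 3, 3, 8)
print(vectorized_shape((1, 112, 112, 32)))  # (1, 112, 112, 8)
print(vectorized_shape((1, 225, 225, 3)))   # None
print(vectorized_index((0, 5, 12, 12)))     # (0, 5, 12, 3)
```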

The final step is to convert the IR to the spirv dialect so that we can serialize the kernel and send it to the GPU for execution. It’s mostly mechanical, so I’ll omit that here for simplicity.

Summary

In this blog post I introduced an approach to directly tile and vectorize convolution ops. It suits GPU architectures in general and ARM Mali GPUs in particular. Hopefully the IR snippet walkthrough can serve as a good example of how one can leverage various dialects/patterns/transforms/utilities in MLIR to implement CodeGen flows for high-level ML ops.


  1. I know I’m dancing around the cliff here by making this analogy. 😊 But I’m assuming familiarity with CPU/GPU thread differences and nuances. If that’s not the case, please feel free to search the Internet as there are lots of great articles about this topic.

  2. 14 (MP) * 2 (SIMD) * 16 (SIMD width) * 2 (FMA) * 850M (Hz) = 761600M ~= 760 GFLOPS

  3. The configurations are selected and attached to the IR as attributes by the SPIRVLowerExecutableTargetPass pass, if you are curious. I’ve omitted that here for simplicity.
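
The peak-throughput figure in footnote 2 is a plain product of per-cycle FMA lanes and clock rate, counting each FMA as two floating-point operations and assuming every lane issues one FMA per cycle. A quick Python check (the function name is hypothetical):

```python
def peak_gflops(cores: int, simd_units: int, simd_width: int,
                flops_per_fma: int, clock_hz: float) -> float:
    """Theoretical peak in GFLOPS, assuming every lane issues one FMA per cycle."""
    return cores * simd_units * simd_width * flops_per_fma * clock_hz / 1e9


print(round(peak_gflops(14, 2, 16, 2, 850e6), 1))  # 761.6
```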