This blog post discusses how to generate performant code for convolution ops using MLIR’s multiple levels of abstractions and transformations. I initially wrote it for targeting ARM Mali GPUs in IREE, but since the approach is just direct tiling and vectorization, it should be widely applicable.
I will walk through the lowering steps, so if you are interested in how to organize MLIR’s various dialects/patterns to achieve similar tasks, this post might also be useful.
Background and Scope
The goal is to generate performant code for convolution ops targeting mobile GPUs. The problem spaces on both sides are huge, so it’s worth limiting our scope: that way we can have the best solution for each well-defined subproblem and tackle the big one piece by piece.
Convolution ops
There are many flavors of convolution ops. They can have vastly different data access patterns and some variants (e.g., depthwise/dilated convolution) are inherently adversarial for GPU utilization (without layout adjustment and other optimizations).
In this blog post, I will be focusing on the basic version of convolution,
which is still widely used in various mobile vision models.
To make sure we are on the same page regarding terminology: we apply a 4-D filter tensor Filter on a 4-D input tensor Input to generate the 4-D output tensor Output. Input follows the (N, IH, IW, IC) layout, Filter follows the (FH, FW, IC, OC) layout, and Output follows the (N, OH, OW, OC) layout. For now we assume there is no padding and no dilation. (Handling padding is complicated enough to merit its own blog posts.)
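As a quick shape sanity check, using the concrete sizes that appear later in this post (a 225x225 spatial input, a 3x3 filter, and stride 2 in both dimensions), and assuming integer division:
OH = (IH - FH) / SH + 1 = (225 - 3) / 2 + 1 = 112
OW = (IW - FW) / SW + 1 = (225 - 3) / 2 + 1 = 112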
Mobile GPUs
In the mobile world, GPUs can come from several vendors. Apple has its own solution for iOS devices. For Android, Qualcomm Adreno and ARM Mali are notable ones. There are also others like Imagination PowerVR, and AMD is entering the mobile world too. Although they are all tiled architectures, so we can target them somewhat uniformly, they differ in important characteristics that require special treatment here and there.
For example, here we will focus on Mali GPUs, which do not have dedicated on-chip shared memory. You can still use shared memory, but it’s just plain system memory, so optimizations like promoting to shared memory (which we commonly see when targeting NVIDIA/AMD GPUs) won’t help and can even hurt performance.
General GPU kernel optimization
The goal of GPU kernel optimization is to approach the theoretical peak performance as closely as possible. The pure theoretical peak is hard to achieve, if achievable at all. In reality, 50%-80% is arguably already good enough, and there are diminishing returns as we get closer and closer to the peak.
In general, approaching the theoretical peak means saturating the computation units, which requires us to 1) prepare data fast enough and 2) issue enough computation.
For the first, we commonly want to 1a) exploit data locality, which typically means tiling. On top of that we want 1b) good load/store patterns. For global memory, that means adjacent GPU threads performing 4-element loads/stores in a cyclic fashion; this is particularly important for Mali GPUs. Then we want to 1c) reuse data in fast memory. Mali GPUs have no dedicated shared memory, which leaves only registers: we can preload shareable data and cache intermediate results there. But we also want to control the number of registers each thread uses to avoid spilling. There is a trade-off here (as always).
For the second, there are lots of tricks, but in general we prefer streamlined, vectorized code with as little flow control as possible. This way we can keep the GPU compute pipelines full without many bubbles. GPUs are parallel machines with easily thousands of threads, but each of them is just a very naive, slow, in-order “CPU thread”[1].
Overall CodeGen Strategy
Based on the previous section, there is clearly a large gap between what we see as the source convolution op and what the target hardware excels at executing. The former is quite high-level, abstract, and dynamic; the latter wants low-level, concrete, and static. Our task here is to bridge the gap. It’s doable in one step: we can hand-write GPU kernel code for each source convolution op. The difficulty is composability, extensibility, and maintainability; such code is painful for others to understand or modify.
This is where compiler code generation, and particularly the multi-level nature of MLIR, shines. As with how we typically approach computing problems, we can create different abstractions to break down the problem and create solutions for isolated tasks, making them composable, extensible, and easy to maintain.
Looking at the problem at hand again: a convolution op is a standalone entity computing on n-D tensors. We want to break it down to the level that the GPU favors—computation over 4-element 1-D vectors. With just one thread, we would need to wrap the 1-D vector computation inside loops to cover the whole original workload. But GPUs provide many threads, so we can divide and conquer by partitioning the original workload and distributing the pieces to different workgroups/workitems.
Tiling and distribution
Therefore, the general idea is to tile and distribute the original convolution op. We have three levels in the compute hierarchy: workgroup, subgroup/warp, and workitem. As said before, Mali GPUs have no dedicated on-chip shared memory, so it’s not helpful to use it as a way to exchange data among different subgroups. So we will just tile and distribute to workgroups and workitems.
This tiling and distribution should partition the convolution Output, as all its dimensions are parallel ones—each output element can be computed independently of the others. This is great for GPUs, which are massively parallel machines initially designed to shade pixels on screens.
Vectorization and unrolling
After tiling and distribution, each GPU thread handles just a subset of the Output elements, effectively a much smaller-scale convolution. Of the general kernel optimizations, tiling addresses data locality; at the workitem level, we need to consider the rest.
We’d want vectorized code here, both for load/store and computation.
We’d want to use registers to cache commonly used inputs and intermediate
results. We’d want to unroll loops to generate enough streamlined code.
Let’s see how we can achieve these goals.
- Each workitem should fully compute all the convolution Output elements it is responsible for in one run. All Output elements are independent, so the workitem doesn’t need to wait for anything. Fully computing them allows us to use registers to hold intermediate results and only write out once at the end.
- The convolution Filter is needed for computing all Output elements. We can preload it into registers to reduce memory requests and boost reuse.
- In order to have a good vectorized memory load/store pattern, the distribution should be cyclic and each thread should be in charge of 4 consecutive elements in memory. That is, thread (0/1/2/3/etc., y, z) should handle elements 0-3/4-7/8-11/12-15/etc. This is actually where the convolution layout matters. As said before, our convolution Input/Filter/Output follows the (N, IH, IW, IC)/(FH, FW, IC, OC)/(N, OH, OW, OC) layout. We partition the Output, so we can achieve the nice access pattern for Output and therefore Filter (which also has OC as its innermost dimension), but not Input. Input has IC as its innermost dimension. That’s a reduction dimension: a thread needs to read through the full memory span consecutively, and often it’s not as small as 4. (For vision models, it typically gets larger and larger as we extract more and more high-level features.) So consecutive threads don’t always read consecutive 4-element chunks, which breaks the pattern.
- Vectorized computation comes naturally if we distribute properly (making sure each thread handles 4x elements for the innermost dimension) and perform vectorized load/store.
- For unrolling, among all the dimensions, OH, OW, and OC are sufficient. We partitioned along these dimensions so they have known small static values, which makes unrolling both feasible and controllable. We can materialize loops for the other dimensions (FH, FW, and IC).
Putting the above together, here is a sketch of the vectorized kernel for one workitem (taken from here):
// Each thread/invocation calculates (IVC_OH * IVC_OW * IVC_OC * 4) output elements.
VEC4TYPE O[IVC_OH][IVC_OW][IVC_OC];
// Use registers to keep the filter for this tile to increase data reuse.
VEC4TYPE F[4][IVC_OC];
uvec3 wgID = gl_WorkGroupID;
uvec3 threadID = gl_LocalInvocationID;
uvec3 threadCount = gl_WorkGroupSize;
uint wgBaseOC = wgID.x * WG_TILE_OC; // Workgroup base output channel
uint wgBaseOW = wgID.y * WG_TILE_OW; // Workgroup base output width
uint wgBaseOH = wgID.z * WG_TILE_OH; // Workgroup base output height
// Initialize the output for this batch to zero.
[[unroll]] for (uint i = 0; i < IVC_OH; ++i)
[[unroll]] for (uint j = 0; j < IVC_OW; ++j)
[[unroll]] for (uint k = 0; k < IVC_OC; ++k)
O[i][j][k] = VEC4TYPE(0.f, 0.f, 0.f, 0.f);
for (uint fh = 0; fh < FH; ++fh) {
for (uint fw = 0; fw < FW; ++fw) {
// Tile input channel with each tile having 4 elements.
for (uint ic = 0; ic < IC; ic += 4) {
// Load the filter for this input channel tile.
[[unroll]] for (uint i = 0; i < 4; ++i) {
[[unroll]] for (uint j = 0; j < IVC_OC; ++j) {
uint oc = (threadID.x + threadCount.x * j) * 4 + wgBaseOC;
F[i][j] = Filter.data[filterCoordToOffset(fh, fw, ic + i, oc)];
}
}
// Load this input channel tile and perform dot product with filters
// for different output elements.
[[unroll]] for (uint i = 0; i < IVC_OH; ++i) {
uint oh = i + threadID.z * IVC_OH + wgBaseOH;
[[unroll]] for (uint j = 0; j < IVC_OW; ++j) {
uint ow = j + threadID.y * IVC_OW + wgBaseOW;
VEC4TYPE feature = Input.data[inputCoordToOffset(oh * SH + fh, ow * SW + fw, ic)];
[[unroll]] for (uint k = 0; k < IVC_OC; ++k) {
O[i][j][k] += VEC4TYPE(feature.x, feature.x, feature.x, feature.x) * F[0][k];
O[i][j][k] += VEC4TYPE(feature.y, feature.y, feature.y, feature.y) * F[1][k];
O[i][j][k] += VEC4TYPE(feature.z, feature.z, feature.z, feature.z) * F[2][k];
O[i][j][k] += VEC4TYPE(feature.w, feature.w, feature.w, feature.w) * F[3][k];
}
}
}
}
}
}
// Write out the computed output elements.
[[unroll]] for (uint i = 0; i < IVC_OH; ++i) {
uint oh = i + threadID.z * IVC_OH + wgBaseOH;
[[unroll]] for (uint j = 0; j < IVC_OW; ++j) {
uint ow = j + threadID.y * IVC_OW + wgBaseOW;
[[unroll]] for (uint k = 0; k < IVC_OC; ++k) {
uint oc = (threadID.x + threadCount.x * k) * 4 + wgBaseOC;
Output.data[outputCoordToOffset(oh, ow, oc)] = O[i][j][k];
}
}
}
In the above, we keep the current Output batch in O and preload the Filter batch in F. We loop over FH and FW, and perform FMA along IC. For IC, although we cannot achieve the perfect memory load pattern, we still try to handle 4 elements each time. This isn’t strictly required, though: effectively we could read one scalar element each time and FMA it into each Output element. That supports cases like IC == 3, which happens for the initial image where we typically have 3 channels (RGB).
The above code is taken from µVkCompute, where one can directly write kernels with simple Vulkan compute pipelines to try out different CodeGen strategies. You can find the runnable code there, which also has a packed fp16 version. Invoking it on a Samsung Galaxy S21 (Exynos 2100, Mali G78 MP14) gives:
> adb shell /data/local/tmp/conv2d_mali_valhall --latency_measure_mode=gpu_timestamp
2021-09-12T16:03:26-04:00
Running /data/local/tmp/conv2d_mali_valhall
Run on (8 X 2210 MHz CPU s)
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations UserCounters...
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[1x16x16]/WGSize[4x4x1]/f32/manual_time 13431 us 837 us 53 FLOps=359.761G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x8x16]/WGSize[4x4x1]/f32/manual_time 13489 us 711 us 52 FLOps=358.216G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[4x4x16]/WGSize[4x4x1]/f32/manual_time 14216 us 757 us 49 FLOps=339.894G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x8x16]/WGSize[4x2x2]/f32/manual_time 13337 us 433 us 53 FLOps=362.281G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[4x4x16]/WGSize[4x2x2]/f32/manual_time 13681 us 678 us 51 FLOps=353.184G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[8x2x16]/WGSize[4x2x2]/f32/manual_time 13544 us 550 us 52 FLOps=356.754G/s
...
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x4x32]/WGSize[4x2x2]/f16/manual_time 8405 us 520 us 83 FLOps=574.908G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[2x8x32]/WGSize[4x2x2]/f16/manual_time 6019 us 412 us 119 FLOps=802.734G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[4x4x32]/WGSize[4x2x2]/f16/manual_time 6243 us 846 us 117 FLOps=773.904G/s
Mali-G78/Input[1x258x258x16]xFilter[3x3x16x256]/Stride[1x1]/Tile[8x2x32]/WGSize[4x2x2]/f16/manual_time 6169 us 794 us 119 FLOps=783.25G/s
...
The theoretical peak is roughly 760[2]/1520 GFLOPs for fp32/fp16, so we are achieving roughly 50% utilization (about 360/760 for fp32 and 800/1520 for fp16). Not bad for straightforward tiling and direct vectorization.
Okay, thus far we have a clear overall CodeGen strategy and know what the inner kernel should look like. Let’s see how we can generate code in such a way with MLIR.
The MLIR/IREE CodeGen Flow
In MLIR we have many dialects, modelling abstractions at different levels. The ones involved here are mhlo, linalg, vector, scf, and spirv. I won’t go into details about them; you can refer to their documentation by following the embedded links. Here is just how they are used in our flow:
- mhlo: the input dialect for models authored in TensorFlow.
- linalg: the core dialect that models linear algebra computations on n-D tensors or buffers. We use it to perform tiling and distribution.
- vector: the core dialect for modelling n-D static-shaped vectors. We use it to perform vectorization and vector-level optimizations like load-store forwarding.
- scf: for the loops around the computation.
- spirv: the final output dialect for the generated kernels.
The journey starts with the following source convolution:
func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>)
-> tensor<1x112x112x32xf32> {
%0 = mhlo.convolution(%input, %filter)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [2, 2], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
: (tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) -> tensor<1x112x112x32xf32>
return %0 : tensor<1x112x112x32xf32>
}
If you’d like to see the flow for yourself, you can compile iree-translate and invoke it on the above function with the following command:
iree/tools/iree-translate \
-iree-input-type=mhlo \
-iree-mlir-to-vm-bytecode-module \
-iree-hal-target-backends=vulkan-spirv \
-iree-vulkan-target-triple=valhall-unknown-android11 \
-print-ir-after-all \
-mlir-print-local-scope \
conv.mlir -o iree.vmfb &> conv-conversion.mlir
The IREE codebase changes very quickly. I use IREE@080cbc46 in this blog post (for both IR snippets and code pointers). If you’d like to see the exact same result, you can check that commit out.
I also put the full dump in a Gist that you can use as a reference.
mhlo to linalg conversion
The first major step is to convert mhlo.convolution into a linalg.conv op. This is done via the ConvertMHLOToLinalgOnTensors pass.
// -----// IR Dump After ConvertMHLOToLinalgOnTensors //----- //
func @conv(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
-> !hal.buffer_view attributes {iree.abi.stub} {
%0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<1x225x225x3xf32>
%1 = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<3x3x3x32xf32>
%2 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
%cst = constant 0.000000e+00 : f32
%3 = linalg.fill(%cst, %2) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
%4 = linalg.conv_2d_nhwc_hwcf {
dilations = dense<1> : tensor<2xi64>,
strides = dense<2> : tensor<2xi64>
}
ins(%0, %1 : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
outs(%3 : tensor<1x112x112x32xf32>)
-> tensor<1x112x112x32xf32>
%5 = hal.tensor.cast %4 : tensor<1x112x112x32xf32> -> !hal.buffer_view
return %5 : !hal.buffer_view
}
There isn’t much to say here. It’s basically pattern matching against the source mhlo convolution op, checking its various attributes, and emitting the suitable linalg convolution op.
Unlike mhlo, where one single convolution op supports many different configurations, linalg has many different ops for different flavors of convolution, e.g., linalg.conv_1d_nwc_wcf and linalg.conv_2d_nhwc_hwcf. They are all called named ops, and it’s simple to define a new named op in linalg. Named ops serve as anchors for pattern matching and for specialization to CodeGen or library calls.
But named ops aren’t core to linalg; they are really just named versions of the linalg.generic op in particular forms. linalg.generic has an implicit loop nest in it and carries an affine map for each operand/result; the affine map defines that operand/result’s access pattern within the implicit loop nest. The payload of the loop nest is explicitly captured as an MLIR region attached to the linalg.generic op. Thanks to this structured representation, loop transformations are quite straightforward: by definition the op is a perfect loop nest, so there is no explicit loop structure to match and manipulate. You can read more about this in the documentation.
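To make this concrete, here is a minimal, hand-written sketch of an elementwise addition expressed directly as a linalg.generic op. It is not taken from the dump in this post; %lhs, %rhs, and %init are hypothetical values, and the op spelling follows the pre-arith-dialect MLIR version used throughout this post:
%sum = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                   affine_map<(d0, d1) -> (d0, d1)>,
                   affine_map<(d0, d1) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel"]
} ins(%lhs, %rhs : tensor<8x16xf32>, tensor<8x16xf32>)
  outs(%init : tensor<8x16xf32>) {
^bb0(%l: f32, %r: f32, %out: f32):  // Payload region: runs once per point of the implicit 8x16 loop nest.
  %add = addf %l, %r : f32
  linalg.yield %add : f32
} -> tensor<8x16xf32>
Tiling such an op then just materializes (part of) the implicit loop nest and shrinks the op to operate on the resulting slices, which is exactly what the following steps do to our convolution op.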
Tiling and distributing to workgroups
Now we can start the first-level tiling and distribution. In IREE this is done via the DispatchLinalgOnTensors pass. What we need to do is utilize the LinalgBaseTilingPattern with a proper LinalgTilingOptions, which controls aspects like
- whether we want to tile/partition along each loop and what the tile size is,
- if we want to distribute, what the processor IDs are and how to distribute,
- what kind of loops to generate,
- etc.
You can find the example configuration for IREE’s workgroup tiling
here.
In IREE, tiling to workgroups is initially done in an abstract way, where we use symbolic values like flow.dispatch.workgroup.size for tile sizes, and flow.dispatch.workgroup.id/flow.dispatch.workgroup.count for processor IDs/counts. Right now we generate scf.for loops. With the linked configuration, the tiled code looks like:
// -----// IR Dump After DispatchLinalgOnTensors //----- //
func @conv(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
-> !hal.buffer_view attributes {iree.abi.stub} {
%c32 = constant 32 : index
%c112 = constant 112 : index
%0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<1x225x225x3xf32>
%1 = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<3x3x3x32xf32>
%2 = flow.dispatch.workgroups[%c32, %c112, %c112](%0, %1)
: (tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) -> tensor<1x112x112x32xf32> =
(%arg2: !flow.dispatch.tensor<readonly:1x225x225x3xf32>,
%arg3: !flow.dispatch.tensor<readonly:3x3x3x32xf32>,
%arg4: !flow.dispatch.tensor<writeonly:1x112x112x32xf32>) {
%cst = constant 0.000000e+00 : f32
%c112_0 = constant 112 : index
%c32_1 = constant 32 : index
%4 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
%workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
%workgroup_size_1 = flow.dispatch.workgroup.size[1] : index
%workgroup_size_2 = flow.dispatch.workgroup.size[2] : index
%workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
%workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
%workgroup_id_1 = flow.dispatch.workgroup.id[1] : index
%workgroup_count_1 = flow.dispatch.workgroup.count[1] : index
%workgroup_id_2 = flow.dispatch.workgroup.id[2] : index
%workgroup_count_2 = flow.dispatch.workgroup.count[2] : index
%5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_2, %workgroup_size_2]
%6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_2, %workgroup_size_2]
scf.for %arg5 = %5 to %c112_0 step %6 {
%7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_1, %workgroup_size_1]
%8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_1, %workgroup_size_1]
scf.for %arg6 = %7 to %c112_0 step %8 {
%9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
%10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
scf.for %arg7 = %9 to %c32_1 step %10 {
%11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg5)
%12 = affine.min affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 227)>(%workgroup_size_2, %arg5)
%13 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg6)
%14 = affine.min affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 227)>(%workgroup_size_1, %arg6)
%15 = flow.dispatch.tensor.load %arg2,
offsets = [0, %11, %13, 0], sizes = [1, %12, %14, 3], strides = [1, 1, 1, 1]
%16 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 32)>(%workgroup_size_0, %arg7)
%17 = flow.dispatch.tensor.load %arg3,
offsets = [0, 0, 0, %arg7], sizes = [3, 3, 3, %16], strides = [1, 1, 1, 1]
%18 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 112)>(%workgroup_size_2, %arg5)
%19 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 112)>(%workgroup_size_1, %arg6)
%20 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 32)>(%workgroup_size_0, %arg7)
%21 = affine.min affine_map<(d0, d1) -> (-d0 + 112, d1)>(%arg5, %workgroup_size_2)
%22 = affine.min affine_map<(d0, d1) -> (-d0 + 112, d1)>(%arg6, %workgroup_size_1)
%23 = affine.min affine_map<(d0, d1) -> (-d0 + 32, d1)>(%arg7, %workgroup_size_0)
%24 = tensor.extract_slice %4[0, %arg5, %arg6, %arg7] [1, %21, %22, %23] [1, 1, 1, 1]
: tensor<1x112x112x32xf32> to tensor<1x?x?x?xf32>
%25 = linalg.fill(%cst, %24) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
%26 = linalg.conv_2d_nhwc_hwcf {
__internal_linalg_transform__ = "workgroup",
dilations = dense<1> : tensor<2xi64>,
strides = dense<2> : tensor<2xi64>
}
ins(%15, %17 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>)
outs(%25 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
flow.dispatch.tensor.store %26, %arg4,
offsets = [0, %arg5, %arg6, %arg7], sizes = [1, %18, %19, %20], strides = [1, 1, 1, 1]
}
}
}
flow.return
}
%3 = hal.tensor.cast %2 : tensor<1x112x112x32xf32> -> !hal.buffer_view
return %3 : !hal.buffer_view
}
I’ve trimmed the raw output down a bit (by removing some type annotations), but the IR snippets still get lengthy from here on, given that we keep lowering through abstraction levels.
We can see in the above that there are three loops partitioning along the OH, OW, and OC dimensions and distributing to the workgroup z, y, and x dimensions. Inside the loop nest we have a linalg.conv_2d_nhwc_hwcf op working on a smaller-scale tile, together with a linalg.fill op initializing its output. The affine.apply/affine.min ops inside the loop nest calculate the tile offsets and sizes to make sure we don’t go out of bounds.
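As a concrete instance, %16 = affine.min affine_map<(d0, d1) -> (d0, -d1 + 32)>(%workgroup_size_0, %arg7) above evaluates to min(workgroup_size_0, 32 - arg7): a full tile along OC when we are far from the boundary, or whatever remains otherwise.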
This performs the first-level tiling. The tiling is abstract because this step handles both CPU and GPU in a uniform way; at this point, we cannot determine the concrete configuration yet.
Afterwards, we are down to the CodeGen path for GPU specifically. We need to start injecting static information to simplify the IR (particularly those affine.apply/affine.min ops).
We can find a concrete tiling and workgroup size scheme that perfectly partitions the convolution Output. Here, that’s using workgroup size (8, 2, 1) and letting each workgroup process a 1x8x32 (OHxOWxOC) Output patch.[3]
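To spell out the arithmetic: the Output is 1x112x112x32 (NxOHxOWxOC), so a 1x8x32 workgroup tile yields 112 x (112 / 8) x (32 / 32) = 112x14x1 workgroups. Within a workgroup of size (8, 2, 1), each of the 8 threads along x covers 32 / 8 = 4 consecutive OC elements and each of the 2 threads along y covers 8 / 2 = 4 OW elements, so every thread ends up with a 1x4x4 Output patch, which is exactly what we will see in the second-level tiling below.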
Performing bufferization and injecting such information, we can get:
// -----// IR Dump After SPIRVRemoveOneTripTiledLoop //----- //
func @conv_dispatch_0() {
%c0 = constant 0 : index
%c112 = constant 112 : index
%cst = constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
%2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
%3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%5 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
%6 = affine.min affine_map<(d0) -> (3, d0 * -2 + 227)>(%arg0)
%7 = affine.apply affine_map<(d0) -> (d0 * 2)>(%3)
%8 = affine.min affine_map<(d0) -> (17, d0 * -2 + 227)>(%3)
%9 = memref.subview %0[0, %5, %7, 0] [1, %6, %8, 3] [1, 1, 1, 1] : ...
%10 = affine.min affine_map<(d0) -> (32, -d0 + 32)>(%4)
%11 = memref.subview %1[0, 0, 0, %4] [3, 3, 3, %10] [1, 1, 1, 1] : ...
%12 = affine.min affine_map<(d0) -> (1, -d0 + 112)>(%arg0)
%13 = affine.min affine_map<(d0) -> (8, -d0 + 112)>(%3)
%14 = memref.subview %2[0, %arg0, %3, %4] [1, %12, %13, %10] [1, 1, 1, 1] : ...
linalg.fill(%cst, %14) ...
linalg.conv_2d_nhwc_hwcf {
__internal_linalg_transform__ = "workgroup",
dilations = dense<1> : tensor<2xi64>,
lowering.config = {tileSizes = [[0, 1, 8, 32], [], [0, 1, 4, 4]]},
strides = dense<2> : tensor<2xi64>
} ins(%9, %11 : ...) outs(%14 : ...>)
}
return
}
It’s much cleaner right now: the outer one-trip loops are basically gone. Next up is the second-level tiling and distribution to workitems.
Tiling and distributing to workitems
This is done by the SPIRVTileAndDistribute
pass.
It’s quite similar to the first-level tiling and distribution; it’s just a matter of setting the proper LinalgTilingOptions. In the previous step we chose workgroup size (8, 2, 1) and let each workgroup handle a 1x8x32 (OHxOWxOC) Output patch. So naturally, a thread handles a 1x4x4 Output patch. Using such static information, we can fold quite a few affine ops away and generate the following IR:
// -----// IR Dump After SPIRVTileAndDistribute //----- //
func @conv_dispatch_0() {
%c0 = constant 0 : index
%c112 = constant 112 : index
%cst = constant 0.000000e+00 : f32
%c3 = constant 3 : index
%c1 = constant 1 : index
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
%2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
%3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%5 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
%6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%7 = memref.subview %0[0, %5, %6, 0] [1, 3, 17, 3] [1, 1, 1, 1]
: memref<1x225x225x3xf32> to memref<1x3x17x3xf32, ...>
%8 = memref.subview %1[0, 0, 0, %4] [3, 3, 3, 32] [1, 1, 1, 1]
: memref<3x3x3x32xf32> to memref<3x3x3x32xf32, ...>
%9 = memref.subview %2[0, %arg0, %3, %4] [1, 1, 8, 32] [1, 1, 1, 1]
: memref<1x112x112x32xf32> to memref<1x1x8x32xf32, ...>
%10 = "gpu.thread_id"() {dimension = "x"} : () -> index
%11 = "gpu.thread_id"() {dimension = "y"} : () -> index
%12 = "gpu.thread_id"() {dimension = "z"} : () -> index
%13 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%11]
%14 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%10]
%15 = memref.subview %9[0, %12, %13, %14] [1, 1, 4, 4] [1, 1, 1, 1]
: memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
linalg.fill(%cst, %15) ...
%16 = "gpu.thread_id"() {dimension = "x"} : () -> index
%17 = "gpu.thread_id"() {dimension = "y"} : () -> index
%18 = "gpu.thread_id"() {dimension = "z"} : () -> index
%19 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%17]
%20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%16]
%21 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%18]
%22 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%17]
%23 = memref.subview %7[0, %21, %22, 0] [1, 3, 9, 3] [1, 1, 1, 1]
: memref<1x3x17x3xf32, ...> to memref<1x3x9x3xf32, ...>
%24 = memref.subview %8[0, 0, 0, %20] [3, 3, 3, 4] [1, 1, 1, 1]
: memref<3x3x3x32xf32, ...> to memref<3x3x3x4xf32, ...>
%25 = memref.subview %9[0, %18, %19, %20] [1, 1, 4, 4] [1, 1, 1, 1]
: memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
scf.for %arg1 = %c0 to %c3 step %c1 {
scf.for %arg2 = %c0 to %c3 step %c1 {
%26 = memref.subview %23[0, %arg1, %arg2, 0] [1, 1, 7, 3] [1, 1, 1, 1]
: memref<1x3x9x3xf32, ...> to memref<1x1x7x3xf32, ...>
%27 = memref.subview %24[%arg1, %arg2, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1]
: memref<3x3x3x4xf32, ...> to memref<1x1x3x4xf32, ...>
linalg.conv_2d_nhwc_hwcf {
__internal_linalg_transform__ = "vectorize",
dilations = dense<1> : tensor<2xi64>,
lowering.config = {tileSizes = [[0, 1, 8, 32], [], [0, 1, 4, 4]]},
strides = dense<2> : tensor<2xi64>
}
ins(%26, %27 : memref<1x1x7x3xf32, ...>, memref<1x1x3x4xf32, ...>)
outs(%25 : memref<1x1x4x4xf32, ...>)
}
}
}
return
}
Unlike the first level, no loops are generated here, since the tiles are directly distributed to workitems.
Okay, now at the innermost level we have a linalg.conv_2d_nhwc_hwcf op working on a 1x1x4x4 Output patch. That’s exactly the smaller-scale case we are looking to vectorize.
Final vectorization and unrolling
We can now directly vectorize the inner linalg.conv_2d_nhwc_hwcf op, following the logic in the vectorization and unrolling section above. Right now the pattern is implemented in IREE; I’ll upstream it to the MLIR repo later. It’s relatively straightforward, given that we are handling a well-defined, small-scale problem here. With it and a bunch of other patterns pulled into the SPIRVVectorize pass, we have:
// -----// IR Dump After SPIRVVectorize //----- //
func @conv_dispatch_0() {
%c0 = constant 0 : index
%c112 = constant 112 : index
%c3 = constant 3 : index
%c1 = constant 1 : index
%cst = constant dense<0.000000e+00> : vector<1x1x4x4xf32>
%c4 = constant 4 : index
%c2 = constant 2 : index
%c6 = constant 6 : index
%cst_0 = constant 0.000000e+00 : f32
%cst_1 = constant dense<0.000000e+00> : vector<1x4xf32>
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x32xf32>
%2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x32xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
%6 = memref.subview %1[0, 0, 0, %4] [3, 3, 3, 32] [1, 1, 1, 1] : memref<3x3x3x32xf32> to memref<3x3x3x32xf32, ...>
%7 = "gpu.thread_id"() {dimension = "x"} : () -> index
%8 = "gpu.thread_id"() {dimension = "y"} : () -> index
%9 = "gpu.thread_id"() {dimension = "z"} : () -> index
%10 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%8]
%11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%7]
%12 = vector.extract_strided_slice %cst {offsets = [0, 0, 0, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
: vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
%13 = vector.extract_strided_slice %cst {offsets = [0, 0, 1, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
: vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
%14 = vector.extract_strided_slice %cst {offsets = [0, 0, 2, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
: vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
%15 = vector.extract_strided_slice %cst {offsets = [0, 0, 3, 0], sizes = [1, 1, 1, 4], strides = [1, 1, 1, 1]}
: vector<1x1x4x4xf32> to vector<1x1x1x4xf32>
%16 = "gpu.thread_id"() {dimension = "x"} : () -> index
%17 = "gpu.thread_id"() {dimension = "y"} : () -> index
%18 = "gpu.thread_id"() {dimension = "z"} : () -> index
%19 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%17]
%20 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%16]
%21 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%18]
%22 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%17]
%23 = memref.subview %6[0, 0, 0, %20] [3, 3, 3, 4] [1, 1, 1, 1] : memref<3x3x3x32xf32, ...> to memref<3x3x3x4xf32, ...>
scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
%24 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
%25 = memref.subview %0[0, %24, %5, 0] [1, 3, 17, 3] [1, 1, 1, 1] : memref<1x225x225x3xf32> to memref<1x3x17x3xf32, ...>
%26 = memref.subview %2[0, %arg0, %3, %4] [1, 1, 8, 32] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x1x8x32xf32, ...>
%27 = memref.subview %26[0, %9, %10, %11] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
vector.transfer_write %12, %27[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
vector.transfer_write %13, %27[%c0, %c0, %c1, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
vector.transfer_write %14, %27[%c0, %c0, %c2, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
vector.transfer_write %15, %27[%c0, %c0, %c3, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x4xf32>, memref<1x1x4x4xf32, ...>
%28 = memref.subview %25[0, %21, %22, 0] [1, 3, 9, 3] [1, 1, 1, 1] : memref<1x3x17x3xf32, ...> to memref<1x3x9x3xf32, ...>
%29 = memref.subview %26[0, %18, %19, %20] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x8x32xf32, ...> to memref<1x1x4x4xf32, ...>
%30 = vector.transfer_read %29[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
%31 = vector.transfer_read %29[%c0, %c0, %c1, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
%32 = vector.transfer_read %29[%c0, %c0, %c2, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
%33 = vector.transfer_read %29[%c0, %c0, %c3, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x4x4xf32, ...>, vector<1x4xf32>
%34:4 = scf.for %arg1 = %c0 to %c3 step %c1 iter_args(%arg2 = %30, %arg3 = %31, %arg4 = %32, %arg5 = %33)
-> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
%35:4 = scf.for %arg6 = %c0 to %c3 step %c1 iter_args(%arg7 = %arg2, %arg8 = %arg3, %arg9 = %arg4, %arg10 = %arg5)
-> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
%36 = memref.subview %28[0, %arg1, %arg6, 0] [1, 1, 7, 3] [1, 1, 1, 1] : memref<1x3x9x3xf32, ...> to memref<1x1x7x3xf32, ...>
%37 = memref.subview %23[%arg1, %arg6, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1] : memref<3x3x3x4xf32, ...> to memref<1x1x3x4xf32, ...>
%38 = vector.transfer_read %37[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x3x4xf32, ...>, vector<1x4xf32>
%39 = vector.transfer_read %37[%c0, %c0, %c1, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x3x4xf32, ...>, vector<1x4xf32>
%40 = vector.transfer_read %37[%c0, %c0, %c2, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x3x4xf32, ...>, vector<1x4xf32>
%41 = vector.transfer_read %36[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
%42 = vector.extract_strided_slice %41 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%43 = vector.extract %38[0] : vector<1x4xf32>
%44 = vector.extract %42[0, 0] : vector<1x1xf32>
%45 = splat %44 : vector<4xf32>
%46 = vector.extract %arg7[0] : vector<1x4xf32>
%47 = vector.fma %45, %43, %46 : vector<4xf32>
%48 = vector.extract_strided_slice %41 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%49 = vector.extract %39[0] : vector<1x4xf32>
%50 = vector.extract %48[0, 0] : vector<1x1xf32>
%51 = splat %50 : vector<4xf32>
%52 = vector.fma %51, %49, %47 : vector<4xf32>
%53 = vector.extract_strided_slice %41 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%54 = vector.extract %40[0] : vector<1x4xf32>
%55 = vector.extract %53[0, 0] : vector<1x1xf32>
%56 = splat %55 : vector<4xf32>
%57 = vector.fma %56, %54, %52 : vector<4xf32>
%58 = vector.insert %57, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
%59 = vector.transfer_read %36[%c0, %c0, %c2, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
%60 = vector.extract_strided_slice %59 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%61 = vector.extract %38[0] : vector<1x4xf32>
%62 = vector.extract %60[0, 0] : vector<1x1xf32>
%63 = splat %62 : vector<4xf32>
%64 = vector.extract %arg8[0] : vector<1x4xf32>
%65 = vector.fma %63, %61, %64 : vector<4xf32>
%66 = vector.extract_strided_slice %59 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%67 = vector.extract %39[0] : vector<1x4xf32>
%68 = vector.extract %66[0, 0] : vector<1x1xf32>
%69 = splat %68 : vector<4xf32>
%70 = vector.fma %69, %67, %65 : vector<4xf32>
%71 = vector.extract_strided_slice %59 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%72 = vector.extract %40[0] : vector<1x4xf32>
%73 = vector.extract %71[0, 0] : vector<1x1xf32>
%74 = splat %73 : vector<4xf32>
%75 = vector.fma %74, %72, %70 : vector<4xf32>
%76 = vector.insert %75, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
%77 = vector.transfer_read %36[%c0, %c0, %c4, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
%78 = vector.extract_strided_slice %77 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%79 = vector.extract %38[0] : vector<1x4xf32>
%80 = vector.extract %78[0, 0] : vector<1x1xf32>
%81 = splat %80 : vector<4xf32>
%82 = vector.extract %arg9[0] : vector<1x4xf32>
%83 = vector.fma %81, %79, %82 : vector<4xf32>
%84 = vector.extract_strided_slice %77 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%85 = vector.extract %39[0] : vector<1x4xf32>
%86 = vector.extract %84[0, 0] : vector<1x1xf32>
%87 = splat %86 : vector<4xf32>
%88 = vector.fma %87, %85, %83 : vector<4xf32>
%89 = vector.extract_strided_slice %77 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%90 = vector.extract %40[0] : vector<1x4xf32>
%91 = vector.extract %89[0, 0] : vector<1x1xf32>
%92 = splat %91 : vector<4xf32>
%93 = vector.fma %92, %90, %88 : vector<4xf32>
%94 = vector.insert %93, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
%95 = vector.transfer_read %36[%c0, %c0, %c6, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1x7x3xf32, ...>, vector<1x3xf32>
%96 = vector.extract_strided_slice %95 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%97 = vector.extract %38[0] : vector<1x4xf32>
%98 = vector.extract %96[0, 0] : vector<1x1xf32>
%99 = splat %98 : vector<4xf32>
%100 = vector.extract %arg10[0] : vector<1x4xf32>
%101 = vector.fma %99, %97, %100 : vector<4xf32>
%102 = vector.extract_strided_slice %95 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%103 = vector.extract %39[0] : vector<1x4xf32>
%104 = vector.extract %102[0, 0] : vector<1x1xf32>
%105 = splat %104 : vector<4xf32>
%106 = vector.fma %105, %103, %101 : vector<4xf32>
%107 = vector.extract_strided_slice %95 {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
%108 = vector.extract %40[0] : vector<1x4xf32>
%109 = vector.extract %107[0, 0] : vector<1x1xf32>
%110 = splat %109 : vector<4xf32>
%111 = vector.fma %110, %108, %106 : vector<4xf32>
%112 = vector.insert %111, %cst_1 [0] : vector<4xf32> into vector<1x4xf32>
scf.yield %58, %76, %94, %112 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
}
scf.yield %35#0, %35#1, %35#2, %35#3 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
}
vector.transfer_write %34#3, %29[%c0, %c0, %c3, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
vector.transfer_write %34#2, %29[%c0, %c0, %c2, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
vector.transfer_write %34#1, %29[%c0, %c0, %c1, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
vector.transfer_write %34#0, %29[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x1x4x4xf32, ...>
}
return
}
This is almost it! But the input/output accesses are still not explicitly vectorized. We just need one additional step, done via the SPIRVVectorizeLoadStore pass:
// -----// IR Dump After SPIRVVectorizeLoadStore //----- //
module {
func @conv_dispatch_0() {
%cst = constant dense<0.000000e+00> : vector<1x4xf32>
%c1 = constant 1 : index
%c3 = constant 3 : index
%c112 = constant 112 : index
%c0 = constant 0 : index
%c32 = constant 32 : index
%c16 = constant 16 : index
%c4 = constant 4 : index
%c2 = constant 2 : index
%c8 = constant 8 : index
%cst_0 = constant dense<0.000000e+00> : vector<3xf32>
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<1x225x225x3xf32>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<3x3x3x8xvector<4xf32>>
%2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<1x112x112x8xvector<4xf32>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%3 = muli %workgroup_id_y, %c8 : index
%4 = muli %workgroup_id_x, %c32 : index
%5 = muli %workgroup_id_y, %c16 : index
%6 = "gpu.thread_id"() {dimension = "x"} : () -> index
%7 = "gpu.thread_id"() {dimension = "y"} : () -> index
%8 = "gpu.thread_id"() {dimension = "z"} : () -> index
%9 = muli %7, %c4 : index
%10 = muli %6, %c4 : index
%11 = muli %8, %c2 : index
%12 = muli %7, %c8 : index
scf.for %arg0 = %workgroup_id_z to %c112 step %c112 {
%13 = muli %arg0, %c2 : index
%14:4 = scf.for %arg1 = %c0 to %c3 step %c1 iter_args(%arg2 = %cst, %arg3 = %cst, %arg4 = %cst, %arg5 = %cst)
-> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
%29:4 = scf.for %arg6 = %c0 to %c3 step %c1 iter_args(%arg7 = %arg2, %arg8 = %arg3, %arg9 = %arg4, %arg10 = %arg5)
-> (vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>) {
%30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%4, %10]
%31 = divi_signed %30, %c4 : index
%32 = memref.load %1[%arg1, %arg6, %c0, %31] : memref<3x3x3x8xvector<4xf32>>
%33 = divi_signed %30, %c4 : index
%34 = memref.load %1[%arg1, %arg6, %c1, %33] : memref<3x3x3x8xvector<4xf32>>
%35 = divi_signed %30, %c4 : index
%36 = memref.load %1[%arg1, %arg6, %c2, %35] : memref<3x3x3x8xvector<4xf32>>
%37 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%arg1)[%13, %11]
%38 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%arg6)[%5, %12]
%39 = memref.load %0[%c0, %37, %38, %c0] : memref<1x225x225x3xf32>
%40 = vector.insert %39, %cst_0 [0] : f32 into vector<3xf32>
%41 = memref.load %0[%c0, %37, %38, %c1] : memref<1x225x225x3xf32>
%42 = vector.insert %41, %40 [1] : f32 into vector<3xf32>
%43 = memref.load %0[%c0, %37, %38, %c2] : memref<1x225x225x3xf32>
%44 = vector.insert %43, %42 [2] : f32 into vector<3xf32>
%45 = vector.extract_strided_slice %44 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%46 = vector.extract %45[0] : vector<1xf32>
%47 = splat %46 : vector<4xf32>
%48 = vector.shape_cast %arg7 : vector<1x4xf32> to vector<4xf32>
%49 = vector.fma %47, %32, %48 : vector<4xf32>
%50 = vector.extract_strided_slice %44 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%51 = vector.extract %50[0] : vector<1xf32>
%52 = splat %51 : vector<4xf32>
%53 = vector.fma %52, %34, %49 : vector<4xf32>
%54 = vector.extract_strided_slice %44 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%55 = vector.extract %54[0] : vector<1xf32>
%56 = splat %55 : vector<4xf32>
%57 = vector.fma %56, %36, %53 : vector<4xf32>
%58 = vector.shape_cast %57 : vector<4xf32> to vector<1x4xf32>
%59 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2 + 2)>()[%5, %12, %arg6]
%60 = memref.load %0[%c0, %37, %59, %c0] : memref<1x225x225x3xf32>
%61 = vector.insert %60, %cst_0 [0] : f32 into vector<3xf32>
%62 = memref.load %0[%c0, %37, %59, %c1] : memref<1x225x225x3xf32>
%63 = vector.insert %62, %61 [1] : f32 into vector<3xf32>
%64 = memref.load %0[%c0, %37, %59, %c2] : memref<1x225x225x3xf32>
%65 = vector.insert %64, %63 [2] : f32 into vector<3xf32>
%66 = vector.extract_strided_slice %65 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%67 = vector.extract %66[0] : vector<1xf32>
%68 = splat %67 : vector<4xf32>
%69 = vector.shape_cast %arg8 : vector<1x4xf32> to vector<4xf32>
%70 = vector.fma %68, %32, %69 : vector<4xf32>
%71 = vector.extract_strided_slice %65 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%72 = vector.extract %71[0] : vector<1xf32>
%73 = splat %72 : vector<4xf32>
%74 = vector.fma %73, %34, %70 : vector<4xf32>
%75 = vector.extract_strided_slice %65 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%76 = vector.extract %75[0] : vector<1xf32>
%77 = splat %76 : vector<4xf32>
%78 = vector.fma %77, %36, %74 : vector<4xf32>
%79 = vector.shape_cast %78 : vector<4xf32> to vector<1x4xf32>
%80 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2 + 4)>()[%5, %12, %arg6]
%81 = memref.load %0[%c0, %37, %80, %c0] : memref<1x225x225x3xf32>
%82 = vector.insert %81, %cst_0 [0] : f32 into vector<3xf32>
%83 = memref.load %0[%c0, %37, %80, %c1] : memref<1x225x225x3xf32>
%84 = vector.insert %83, %82 [1] : f32 into vector<3xf32>
%85 = memref.load %0[%c0, %37, %80, %c2] : memref<1x225x225x3xf32>
%86 = vector.insert %85, %84 [2] : f32 into vector<3xf32>
%87 = vector.extract_strided_slice %86 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%88 = vector.extract %87[0] : vector<1xf32>
%89 = splat %88 : vector<4xf32>
%90 = vector.shape_cast %arg9 : vector<1x4xf32> to vector<4xf32>
%91 = vector.fma %89, %32, %90 : vector<4xf32>
%92 = vector.extract_strided_slice %86 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%93 = vector.extract %92[0] : vector<1xf32>
%94 = splat %93 : vector<4xf32>
%95 = vector.fma %94, %34, %91 : vector<4xf32>
%96 = vector.extract_strided_slice %86 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%97 = vector.extract %96[0] : vector<1xf32>
%98 = splat %97 : vector<4xf32>
%99 = vector.fma %98, %36, %95 : vector<4xf32>
%100 = vector.shape_cast %99 : vector<4xf32> to vector<1x4xf32>
%101 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2 + 6)>()[%5, %12, %arg6]
%102 = memref.load %0[%c0, %37, %101, %c0] : memref<1x225x225x3xf32>
%103 = vector.insert %102, %cst_0 [0] : f32 into vector<3xf32>
%104 = memref.load %0[%c0, %37, %101, %c1] : memref<1x225x225x3xf32>
%105 = vector.insert %104, %103 [1] : f32 into vector<3xf32>
%106 = memref.load %0[%c0, %37, %101, %c2] : memref<1x225x225x3xf32>
%107 = vector.insert %106, %105 [2] : f32 into vector<3xf32>
%108 = vector.extract_strided_slice %107 {offsets = [0], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%109 = vector.extract %108[0] : vector<1xf32>
%110 = splat %109 : vector<4xf32>
%111 = vector.shape_cast %arg10 : vector<1x4xf32> to vector<4xf32>
%112 = vector.fma %110, %32, %111 : vector<4xf32>
%113 = vector.extract_strided_slice %107 {offsets = [1], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%114 = vector.extract %113[0] : vector<1xf32>
%115 = splat %114 : vector<4xf32>
%116 = vector.fma %115, %34, %112 : vector<4xf32>
%117 = vector.extract_strided_slice %107 {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32>
%118 = vector.extract %117[0] : vector<1xf32>
%119 = splat %118 : vector<4xf32>
%120 = vector.fma %119, %36, %116 : vector<4xf32>
%121 = vector.shape_cast %120 : vector<4xf32> to vector<1x4xf32>
scf.yield %58, %79, %100, %121 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
}
scf.yield %29#0, %29#1, %29#2, %29#3 : vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>, vector<1x4xf32>
}
%15 = vector.shape_cast %14#3 : vector<1x4xf32> to vector<4xf32>
%16 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %8]
%17 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 3)>()[%3, %9]
%18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%4, %10]
%19 = divi_signed %18, %c4 : index
memref.store %15, %2[%c0, %16, %17, %19] : memref<1x112x112x8xvector<4xf32>>
%20 = vector.shape_cast %14#2 : vector<1x4xf32> to vector<4xf32>
%21 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 2)>()[%3, %9]
%22 = divi_signed %18, %c4 : index
memref.store %20, %2[%c0, %16, %21, %22] : memref<1x112x112x8xvector<4xf32>>
%23 = vector.shape_cast %14#1 : vector<1x4xf32> to vector<4xf32>
%24 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 + 1)>()[%3, %9]
%25 = divi_signed %18, %c4 : index
memref.store %23, %2[%c0, %16, %24, %25] : memref<1x112x112x8xvector<4xf32>>
%26 = vector.shape_cast %14#0 : vector<1x4xf32> to vector<4xf32>
%27 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%3, %9]
%28 = divi_signed %18, %c4 : index
memref.store %26, %2[%c0, %16, %27, %28] : memref<1x112x112x8xvector<4xf32>>
}
return
}
hal.interface private @io {
hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b2_xw_external, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
}
Now all input/output loads/stores and the computation are fully vectorized!
The final step is to convert it to the spirv
dialect so that we can serialize
the kernel and send it to the GPU for execution. It’s mostly mechanical so I’ll
omit that here for simplicity.
Summary
In this blog post I introduced an approach that directly tiles and vectorizes convolution ops. It suits GPU architectures, particularly ARM Mali GPUs. Hopefully the IR walkthrough serves as a good example of how one can leverage the various dialects/patterns/transformations/utilities in MLIR to implement CodeGen flows for high-level ML ops.
1. I know I’m dancing around the cliff here by making this analogy. 😊 But I’m assuming familiarity with CPU/GPU thread differences and nuances. If that’s not the case, please feel free to search the Internet as there are lots of great articles about this topic.
2. 14 (MP) * 2 (SIMD) * 16 (SIMD width) * 2 (FMA) * 850M (Hz) = 761600M ~= 760 GFLOPs.
3. The configuration is selected and annotated as attributes on the IR by the SPIRVLowerExecutableTargetPass pass, if you are curious. I’ve omitted it here for simplicity.