I explained the Vector dialect and related patterns in the previous blog post. In this one, let’s move one layer higher and talk about the Linalg dialect and the transformations around it.
Problem Space
Machine learning (ML), especially deep learning, brings workloads with quite regular patterns. A whole ML model typically repeats certain known-good building blocks, optionally in a nested fashion.
The primitive operations in those building blocks include dense computation like matrix multiplication (matmul) or convolution, which, in their plain definition, are perfect loop nests over multiply-accumulate (MAC) operations. Given the internal MAC operation, matmul and convolution are just a form of reduction. It’s also common to see direct reductions to find the maximal/minimal/average values of a tensor. Aside from reduction, another category of computation, which is more “lightweight”, is elementwise operations. They are even more well formed and easy to handle.
Computation is one way to produce tensors in ML model dataflow; we also see pure data shuffling operations. Examples are broadcast, transpose, and concatenate, which in essence perform affine transformations over data element indices. More sophisticated data shuffling operations include gather, scatter, topK, and others. These operations involve non-affine mappings over data element indices, and are more “irregular” and difficult to handle.
We know that general programmable CPUs/GPUs adopt tile-based architectures, and have different levels of memory/cache hierarchies to feed data into SIMD/SIMT compute units. Even with the memory/cache hierarchy, data access is still much slower than compute; so we need to increase data reuse to achieve high ops per second in reality. To best utilize the hardware, we need to tile the problem, distribute the workload to different compute units, and maintain a high data reuse rate in the faster memory/cache levels.
With the above in mind, each of the previous ML operation categories poses a different problem for code generation (CodeGen). For example, a square matmul by definition performs O(N²) data element accesses and O(N³) MACs, so each loaded element is reused roughly N times. Convolution offers even higher reuse intrinsically. Such ops are inherently easier to drive to high ops per second. On the other hand, there is no data reuse in elementwise operations; done alone, they are always memory bound. However, given the simple 1:1 element mapping of elementwise operations, we typically fuse them with matmul or convolution to compute them “for free”. Pure reduction is an even more challenging problem, especially on massively parallel GPUs, and requires dedicated approaches.
In general, it’s clear that different categories need different strategies. However, we don’t necessarily want to write a separate compiler flow for each ML operation and target architecture, for software complexity management reasons; we would like to unify problems with similar characteristics into one flow. We cannot go too far down this route either—if we unify everything into the same flow, we can assume nothing and must do everything in the most general sense, which would make analyses and transformations extremely hard and conservative. So, practical trade-offs are needed in the design.
That’s effectively what the Linalg dialect tries to achieve. It strikes a balance between IR generality and transformation simplicity/effectiveness. It focuses on addressing problems with perfect loop nests and affine element indexing, by providing powerful tiling, fusion, and vectorization mechanisms.
This covers the CodeGen needs of compute ops like matmul/convolution/reduction, and data shuffling ops like broadcast/transpose, which is already a large portion of ML ops. It purposefully leaves problems like gather/scatter aside, because accommodating them would mean relaxing many of the assumptions (that is, perfect loop nests and affine element indexing) and thus forgoing some very powerful mechanisms (e.g., affine map composition, backward slice analysis) and bringing more complexity into the compiler.
Enough about the rationale. Now let’s move on to the design considerations.
Design Considerations
Positioning
The Linalg dialect presents one of the core abstractions for progressive MLIR CodeGen in ML compilers. Here is the CodeGen hierarchy introduced previously again, with the Linalg layer highlighted:
Each layer in the above flow serves its own purpose:
- At the top level, dialects like TF, TFLite, and Torch are meant for ML framework integration, while dialects like MHLO and TOSA are meant for consolidating flexible framework op sets into (stable) input ML programs.
- The Vector dialect, which I discussed in the previous blog post, aims to compile a tile of the original problem to a single compute unit, by mapping to hardware registers and native vector compute instructions.
- Dialects like MemRef handle memory planning and concrete data accesses. Their position in the flow is relatively flexible, as this step can happen either before or after the vector abstractions.
- At the bottom of the stack are dialects like LLVM and SPIR-V, used to exit the MLIR system for even lower-level CodeGen and/or final program serialization.
The Linalg dialect is actually the entry layer for structured MLIR CodeGen—dialects before it are for ML program representation; from Linalg, we start to perform transformations to gradually fit hardware targets.
These transformations include tiling, fusion, distribution, and vectorization. Their collective goal is to divide the original problem into tiles, assign those tiles to different compute units, and later hand the inner tile off to the Vector dialect for further CodeGen aimed at a single compute unit.
Op structure and categories
The Linalg dialect’s documentation provides a high-level description of Linalg op structure and categories, in addition to detailed semantics for each op. It’s highly worth a read. I won’t repeat what’s explained there; here is just a quick summary to lay the foundation for further discussion.
Linalg ops can operate both on tensors and buffers.
In general, there are two categories of Linalg ops—structured ones and non-structured ones.

There are very few non-structured Linalg ops, e.g., linalg.index and linalg.yield. They are auxiliary; each one is distinct.

The majority of Linalg ops are structured ones, including linalg.matmul, linalg.conv_*, linalg.generic, and other compute ops. Among them, linalg.generic is really the core op; the other ops are called named ops and are effectively syntactic sugar over particular instances of linalg.generic. They all implement the LinalgOp interface and have uniform representations as described in the documentation (a concrete example follows the list below):
- Each op itself is an implicit perfect loop nest, where each loop has an explicitly defined iterator type (parallel/reduction/etc.).
- If with tensor semantics, each op result has an associated output operand, providing the initial value of the output. If with buffer semantics, Linalg ops don’t have results; the output operands are directly read-writable buffers.1
- Each input/output operand has an associated indexing map from the implicit loop nest to the operand’s tensor/buffer, specifying how the operand’s elements are accessed. The loop nest’s iteration space is fully derived from the op’s operands.
- The computation done is specified as a region, which allows great flexibility.
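To make the structure concrete, here is a matmul written out as a linalg.generic, roughly what the linalg.matmul named op is sugar for (a sketch; operand names are illustrative):

#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>

// d0/d1 are the parallel M/N loops; d2 is the K reduction loop.
%0 = linalg.generic {
       indexing_maps = [#mapA, #mapB, #mapC],
       iterator_types = ["parallel", "parallel", "reduction"]}
     ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%init : tensor<?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  // The region spells out the per-element multiply-accumulate.
  %mul = arith.mulf %a, %b : f32
  %acc = arith.addf %c, %mul : f32
  linalg.yield %acc : f32
} -> tensor<?x?xf32>

The implicit loop nest, the iterator types, and the per-operand indexing maps are all carried directly on the op; that is exactly what the transformations discussed later rely on.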
These characteristics greatly simplify transformations over Linalg ops. Often we just need to write a single pattern targeting the LinalgOp interface. Implicit loop nests mean we can typically avoid analyzing and transforming explicit loop nests.
As said in the Problem Space section, the Linalg dialect
purposefully only represents ML ops which by definition are perfect loop
nests with affine element indexing.
Ops not fitting this pattern would need their own special flows.
In IREE we have the LinalgExt dialect to experiment with modeling those ops (including gather, scatter, scan, topk, fft, etc.). They are gradually being upstreamed and placed in more suitable holding dialects; e.g., we now have tensor.gather/tensor.scatter ops in the upstream MLIR codebase.
Transformations
The Linalg ops are designed from the beginning to facilitate transformations. By purposefully restricting the problems to address, we can make assumptions that lead to powerful analysis and transformations.
Actually, the whole structured CodeGen paradigm in MLIR, of which the Linalg dialect is a core component, favors this kind of co-design of op semantics and transformations: let ops structurally encode and guarantee certain properties, to simplify analyses and transformations.
Important transformations happening at the Linalg level include tiling, fusion, distribution, vectorization, and lowering into plain loops.
Tiling
Tiling is the process of dividing the original problem into smaller ones. It is a key step in mapping the original problem onto multiple compute units, which are common in today’s CPUs/GPUs. Tiling can happen multiple times, depending on the compute hierarchy of the hardware target. With tiling, we can also turn a dynamically shaped problem into a statically shaped one by choosing static tile sizes. This is important for enabling further Vector level transformations.
Although we typically perform tiling at the Linalg level, it’s not fundamentally limited to Linalg ops. So in MLIR, tiling is moving to an interface, unsurprisingly the TilingInterface, to allow dividing other dialects' ops to map to compute hierarchies. (Notably, the ops in the IREE LinalgExt dialect implement the tiling interface.) Here is Linalg ops' TilingInterface implementation.

Tiling materializes a loop nest. The interface just provides op-specific information about how the particular op should be tiled; we also need loop ops for the materialized loop nest. This part is pluggable, as we can use different kinds of loop ops, e.g., scf.for ops (and here is the code connecting them together).
It’s worth discussing the IR after tiling a bit. For the following linalg.matmul:
func.func @matmul(%lhs : tensor<?x?xf32>, %rhs : tensor<?x?xf32>,
                  %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul
         ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
         outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
Tiling with tile sizes (M, N) = (16, 32):
func.func @matmul(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>,
                  %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %dimM = tensor.dim %lhs, %c0 : tensor<?x?xf32>
  %dimK = tensor.dim %lhs, %c1 : tensor<?x?xf32>
  %dimN = tensor.dim %rhs, %c1 : tensor<?x?xf32>
  %0 = scf.for %arg3 = %c0 to %dimM step %c16 iter_args(%arg4 = %init) -> (tensor<?x?xf32>) {
    %1 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%arg3)[%dimM]
    %2 = scf.for %arg5 = %c0 to %dimN step %c32 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {
      %3 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 32)>(%arg5)[%dimN]
      %sliceA = tensor.extract_slice %lhs[%arg3, 0] [%1, %dimK] [1, 1]...
      %sliceB = tensor.extract_slice %rhs[0, %arg5] [%dimK, %3] [1, 1]...
      %sliceC = tensor.extract_slice %arg6[%arg3, %arg5] [%1, %3] [1, 1]...
      %4 = linalg.matmul
             ins(%sliceA, %sliceB : tensor<?x?xf32>, tensor<?x?xf32>)
             outs(%sliceC : tensor<?x?xf32>) -> tensor<?x?xf32>
      %insert = tensor.insert_slice %4 into %arg6[%arg3, %arg5] [%1, %3] [1, 1]...
      scf.yield %insert : tensor<?x?xf32>
    }
    scf.yield %2 : tensor<?x?xf32>
  }
  return %0 : tensor<?x?xf32>
}
We have a materialized scf.for loop nest, with steps equal to the tile sizes we specified for the matmul M and N dimensions. In the innermost loop, representing the current tile, tensor.extract_slice ops define the slices of the original input tensors that the current tile will read, while the tensor.insert_slice op defines the slice of the output tensor that the current tile will write. affine.min ops are generated to compute the size bound of each tile, so that boundary tiles are handled correctly when a dimension is not a multiple of the tile size.
There are a few aspects worth calling out:
Firstly, tensor.extract_slice and tensor.insert_slice define the input and output slices for the tile. They maintain the original dimensionality and capture an (offset, size, stride) tuple for each of the two dimensions. Such semantics make it very straightforward to take a tile of the original problem and compose further later.

Tiling itself generates loop nests, so we face more details and it is harder to go back to the original form; this is a form of lowering. However, the tensor.*_slice ops are also an instance of retaining high-level information as much as possible and not lowering prematurely (and linearizing the indexing).
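As a side note on that (offset, size, stride) form, a standalone tensor.extract_slice with its types fully spelled out looks like the following (operand names here are just for illustration):

// Take a %tileM x %tileK slice of %lhs starting at (%i, 0), with unit strides.
// Offsets and sizes can be dynamic SSA values; the result keeps the 2-D shape.
%slice = tensor.extract_slice %lhs[%i, 0] [%tileM, %tileK] [1, 1]
           : tensor<?x?xf32> to tensor<?x?xf32>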
Secondly, we have the same problem, linalg.matmul, in the tiled loop nest, just on a smaller scale defined by tensor.*_slice ops. If we want to tile again, it would just mean reapplying the same transformation.
Thirdly, we now have a whole structure representing the full tiled problem—the loop nest, the tensor.*_slice ops, and the inner linalg op. They must work together to preserve the full semantics. Such structures can be fragile and prone to breakage in IR transformations, so we face the risk of breaking them with, for example, seemingly innocent canonicalization patterns. (This is right now the case for bufferization—it requires preserving such a structure.) This problem is why we have Linalg structured ops and want to embed structure in op semantics implicitly from the beginning. More needs to be done for this case though.
Lastly, the tensor.*_slice ops can represent data movement to faster levels of the memory/cache hierarchy.
Fusion
We have seen that tensor.*_slice ops help define the input/output slices for tiling. Fusion just goes one step further on top of that—instead of only pulling the consumer op’s input slices into the tiled loop nest around the consumer op, we also pull in the producer op and all the input slices it needs to compute the consumed output slice.

Recall that Linalg ops use an affine indexing map to encode the access scheme for each input/output operand, and the loop bounds are derived from operand tensor/buffer shapes. Calculating the producer input slices is just a matter of composing these affine maps and deriving offsets and sizes for each dimension separately. This can be achieved via inverse(producerIndexingMap).compose(consumerIndexingMap) for permutation producer indexing maps; elementwise op fusion uses this approach.
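A tiny worked example of that composition (maps chosen purely for illustration): suppose the producer is a transpose-like linalg.generic whose result indexing map is (d0, d1) -> (d1, d0), and the consumer reads that result with the identity map (d0, d1) -> (d0, d1).

// producer result map    : affine_map<(d0, d1) -> (d1, d0)>
// consumer operand map   : affine_map<(d0, d1) -> (d0, d1)>
//
// inverse(producer map)                       = (d0, d1) -> (d1, d0)
// inverse(producer map).compose(consumer map) = (d0, d1) -> (d1, d0)
//
// So the consumer tile at offsets (%i, %j) with sizes (%tm, %tn) requires the
// producer to be computed over offsets (%j, %i) with sizes (%tn, %tm); the
// producer's own operand maps then turn that iteration range into input slices.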
If going through the TilingInterface, the implementation is different and uses the makeTiledShapes utility. I won’t expand on this; you can see the full scf.for tiling and fusion procedure here.
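To make the effect visible, here is a rough sketch of tiling-plus-fusion output for a linalg.matmul feeding an elementwise add, tiled along M only for brevity (names and shapes are assumed, and slice types are abbreviated with "..." as in the earlier example):

#map = affine_map<(d0, d1) -> (d0, d1)>

%0 = scf.for %iv = %c0 to %dimM step %c16 iter_args(%acc = %init) -> (tensor<?x?xf32>) {
  %tileM = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 16)>(%iv)[%dimM]
  // Producer (the matmul) is pulled into the loop and computes only the rows
  // needed by this tile.
  %sliceA = tensor.extract_slice %lhs[%iv, 0] [%tileM, %dimK] [1, 1]...
  %sliceB = tensor.extract_slice %rhs[0, 0] [%dimK, %dimN] [1, 1]...
  %sliceC = tensor.extract_slice %fill[%iv, 0] [%tileM, %dimN] [1, 1]...
  %mm = linalg.matmul
          ins(%sliceA, %sliceB : tensor<?x?xf32>, tensor<?x?xf32>)
          outs(%sliceC : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Consumer (the elementwise add) is tiled as usual and consumes the
  // producer's tile directly, without materializing the full matmul result.
  %sliceBias = tensor.extract_slice %bias[%iv, 0] [%tileM, %dimN] [1, 1]...
  %sliceOut = tensor.extract_slice %acc[%iv, 0] [%tileM, %dimN] [1, 1]...
  %add = linalg.generic {
           indexing_maps = [#map, #map, #map],
           iterator_types = ["parallel", "parallel"]}
         ins(%mm, %sliceBias : tensor<?x?xf32>, tensor<?x?xf32>)
         outs(%sliceOut : tensor<?x?xf32>) {
  ^bb0(%x: f32, %y: f32, %o: f32):
    %s = arith.addf %x, %y : f32
    linalg.yield %s : f32
  } -> tensor<?x?xf32>
  %ins = tensor.insert_slice %add into %acc[%iv, 0] [%tileM, %dimN] [1, 1]...
  scf.yield %ins : tensor<?x?xf32>
}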
Distribution
Distribution is the process of assigning different tiles to different compute units. It is particularly important for GPUs, where we need to utilize all those workgroups and workitems for parallelism.
The commonly used approach to perform distribution is to provide processor ID and count SSA values when tiling (and fusing). Then, for the materialized loop nest, we update the loop lower bounds and steps using those processor ID and count values, which effectively assigns tiles to processors in a cyclic fashion:
%idy = gpu.thread_id y
%dimy = gpu.block_dim y
%lby = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%idy, %c4]
%stepy = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dimy, %c4]
scf.for %ivy = %lby to %c8 step %stepy {
  %idx = gpu.thread_id x
  %dimx = gpu.block_dim x
  %lbx = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%idx, %c4]
  %stepx = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dimx, %c4]
  scf.for %arg3 = %lbx to %c32 step %stepx {
    ...
  }
}
The above approach faces an issue, though. It’s well defined if we distribute memref tiles—buffers allow read/write to individual elements, so we can perform concurrent updates to different parts of the whole memref. For tensor tiles, it becomes difficult—a tensor, regardless of its dimensionality and shape, is a single integral value. There is no partial update semantics; touching even a single element conceptually generates a whole new tensor of the same shape. Therefore, concurrent updates via tensor.insert_slice ops are not well defined. This problem sort of gets “resolved” after we bufferize those tensors, but still, at the tensor level, we have such a semantic gap. So we see upstream experimenting with moving to the new scf.foreach_thread op. In its region, we leverage scf.foreach_thread.perform_concurrently and tensor.parallel_insert_slice to address the issue.
Vectorization
Vectorization is the last step in the Linalg lowering flow. It bridges the Linalg layer and the Vector layer. In MLIR, vectorization is not about discovering parallelism by turning scalar computation into vector form; it mechanically generates vector ops of the same shapes as the static tile, and then relies on later in-dialect lowering to convert those high-dimensional vectors into low-dimensional native ones. I’ve written about this in the previous blog post, so I won’t repeat it here.
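Just as a quick reminder of the shape of the output, vectorizing a statically shaped matmul tile (M=16, N=32, K=8 assumed, with illustrative value names) produces something along these lines; the attribute spelling follows the Vector dialect of that time:

%f0 = arith.constant 0.0 : f32
%vA = vector.transfer_read %sliceA[%c0, %c0], %f0 : tensor<16x8xf32>, vector<16x8xf32>
%vB = vector.transfer_read %sliceB[%c0, %c0], %f0 : tensor<8x32xf32>, vector<8x32xf32>
%vC = vector.transfer_read %sliceC[%c0, %c0], %f0 : tensor<16x32xf32>, vector<16x32xf32>
// Same implicit loop structure as the linalg op, now on vectors covering the whole tile.
%vR = vector.contract {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
      %vA, %vB, %vC : vector<16x8xf32>, vector<8x32xf32> into vector<16x32xf32>
%vOut = vector.transfer_write %vR, %sliceC[%c0, %c0] : vector<16x32xf32>, tensor<16x32xf32>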
Lowering to loops
Converting to plain loops gives us a fallback lowering path and a reference implementation for Linalg ops, which can be quite helpful in certain cases. Given that a Linalg op is just an implicit loop nest, lowering to loops is trivial: we just need to materialize the loop nest and inline the compute region into it.
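For instance, the matmul from earlier, once bufferized to memrefs and lowered to loops, looks roughly like this sketch (names are illustrative):

func.func @matmul_loops(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>,
                        %out: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dimM = memref.dim %lhs, %c0 : memref<?x?xf32>
  %dimK = memref.dim %lhs, %c1 : memref<?x?xf32>
  %dimN = memref.dim %rhs, %c1 : memref<?x?xf32>
  // The implicit (parallel, parallel, reduction) loop nest is materialized
  // as explicit scf.for loops.
  scf.for %m = %c0 to %dimM step %c1 {
    scf.for %n = %c0 to %dimN step %c1 {
      scf.for %k = %c0 to %dimK step %c1 {
        // The op's region is inlined as scalar loads, a MAC, and a store.
        %a = memref.load %lhs[%m, %k] : memref<?x?xf32>
        %b = memref.load %rhs[%k, %n] : memref<?x?xf32>
        %c = memref.load %out[%m, %n] : memref<?x?xf32>
        %mul = arith.mulf %a, %b : f32
        %acc = arith.addf %c, %mul : f32
        memref.store %acc, %out[%m, %n] : memref<?x?xf32>
      }
    }
  }
  return
}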
Closing Words
The Linalg dialect presents one of the core abstractions for progressive and structured MLIR CodeGen in ML compilers. Hopefully this blog post shed some light on its design and key transformations.
The Linalg dialect is the de facto “testbed” for quite a few new CodeGen techniques, including the transform dialect. It evolves fast and can often graduate features to more suitable holding dialects and directories.
1. So this unifies the IR representation for tensors and buffers. More importantly, it makes bufferization easier—we can perform analysis and reuse buffers via the output operand and result binding in many cases, which avoids excessive copies. ↩︎