Triton provides an elegant way to program GPU kernels in Python, positioning itself as a critical component in the modern AI software stack. To deliver performance and portability, it leverages a compiler, whose capability determines the potential. Hacking on the compiler internals is not a simple task; here are some tips that will hopefully be useful. I'll try to keep this blog post updated periodically.
Building and Installation
Triton itself focuses on the GPU kernel; runtime bits like tensor allocation and resource management are handled by PyTorch. PyTorch itself also uses Triton for TorchInductor. So Triton is typically installed together with PyTorch.
Using wheels from PyTorch
PyTorch has a clearly defined release scheme, and Triton's versioning plan is effectively managed by PyTorch. For every two minor releases, PyTorch chooses a recent commit from Triton's main branch, fixes all PyTorch usage breakages, and runs extensive regression tests in PyTorch for quality control. In the meantime it may cherry-pick some further commits and push them to a release/M.N.x branch in the Triton repo. For release/M.N+1.x, the base commit remains the same; only important fixes are cherry-picked in.
PyTorch names the Triton wheel differently depending on the build type (stable or nightly) and GPU vendor (NVIDIA or AMD). The canonical PyPI wheel triton is only used for stable PyTorch releases targeting NVIDIA CUDA. We can find other combinations on this page. To find which Triton commit a specific PyTorch version depends on, for example v2.5.1, see PyTorch's source code at .ci/docker/ci_commit_pins/triton.txt.
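For a quick check against a local PyTorch checkout (the pytorch directory name and the v2.5.1 tag here are just examples), git can print the pinned commit directly:
git -C pytorch show v2.5.1:.ci/docker/ci_commit_pins/triton.txt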
Building from source via Python
Though if interested in working on the Triton compiler itself, we need to install it from source. Triton's README.md has clear steps to follow. I'd typically use a Python virtualenv there for better environment isolation.
pip install builds the whole Triton project from Triton's setup.py. The script first downloads build dependencies like LLVM, the NVIDIA toolchain, and pybind11 under $HOME/.triton; we can use the TRITON_HOME environment variable to redirect that, though. It then invokes CMake to build the whole C++ project from the top-level CMakeLists.txt and packages it up as a Python wheel.
The packaged Python wheel has the name triton. This can cause problems for pip dependency management if we have another wheel installed from PyTorch under a different name, i.e., pytorch-triton or pytorch-triton-rocm: given the overlapping codebase, the locally built triton will overwrite contents but may not cover all of them due to version differences. So it's suggested to purge all existing triton, pytorch-triton, and pytorch-triton-rocm wheels beforehand.
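For example, a minimal purge sketch (adjust to whichever of these wheels are actually installed):
pip uninstall -y triton pytorch-triton pytorch-triton-rocm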
Starting from this patch, we can query the installed triton commit with pip show triton, which will show something like Version: 3.2.0+gitf8b5301a.
There are a few nice tips for building mentioned in Triton's README.md. A shell function to bake them in comes in quite handy:
triton-pip-install() {
  REPO_BASE_DIR=$(git rev-parse --show-toplevel)
  TRITON_BUILD_WITH_CCACHE=true TRITON_BUILD_WITH_CLANG_LLD=true \
    pip install --no-build-isolation ${REPO_BASE_DIR}/python
}
Building from source via CMake
Developing the Triton compiler mostly involves touching the C++ codebase for various MLIR passes and patterns. So I also directly build a Debug version from the top-level CMakeLists.txt. That involves building LLVM/MLIR in Debug too; there are steps in the README.md, which I also put in a shell function:
# <source-dir> should be the local checkout directory for
# https://github.com/llvm/llvm-project/tree/main/llvm
# <target-dir> is where to put the compiled LLVM/MLIR artifacts
triton-configure-mlir() {
  if (( $# < 3 )); then
    echo "usage: $0 <source-dir> <target-dir> <build-type>"
    return 1
  fi
  SOURCE_DIR=$1; shift
  TARGET_DIR=$1; shift
  BUILD_TYPE=$1; shift

  cmake -GNinja \
    -S ${SOURCE_DIR} -B ${TARGET_DIR} \
    -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DCMAKE_C_COMPILER=$(which clang) -DCMAKE_CXX_COMPILER=$(which clang++) \
    -DLLVM_ENABLE_PROJECTS="llvm;mlir" \
    -DLLVM_TARGETS_TO_BUILD="AMDGPU;NVPTX;X86;AArch64"
}
Then the CMake configuration step for Triton can be captured with a shell function:
# <source-dir> should be the local checkout directory for
# https://github.com/triton-lang/triton
# <target-dir> is where to put the compiled Triton artifacts
# <mlir-dir> should be the LLVM/MLIR artifacts directory
triton-cmake() {
  if (( $# < 4 )); then
    echo "usage: $0 <source-dir> <target-dir> <build-type> <mlir-dir>"
    return 1
  fi
  SOURCE_DIR=$1; shift
  TARGET_DIR=$1; shift
  BUILD_TYPE=$1; shift
  MLIR_DIR=$1; shift

  if [[ "$(uname)" == "Darwin" ]]; then
    LINKER_FLAGS=()
  else
    LINKER_FLAGS=(
      "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld"
      "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld"
      "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"
    )
  fi

  # Resolve the Triton repo root from <source-dir> so this works from anywhere
  REPO_BASE_DIR=$(git -C ${SOURCE_DIR} rev-parse --show-toplevel)
  cmake -GNinja \
    -S ${SOURCE_DIR} -B ${TARGET_DIR} \
    -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DTRITON_CODEGEN_BACKENDS="amd;nvidia" \
    -DLLVM_INCLUDE_DIRS=${MLIR_DIR}/include \
    -DLLVM_LIBRARY_DIR=${MLIR_DIR}/lib \
    -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
    -DCMAKE_LINKER=lld "${LINKER_FLAGS[@]}" \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
    -DTRITON_BUILD_PYTHON_MODULE=ON \
    -DTRITON_BUILD_PROTON=ON \
    -DCUPTI_INCLUDE_DIR=${REPO_BASE_DIR}/third_party/nvidia/backend/include \
    -DROCTRACER_INCLUDE_DIR=${REPO_BASE_DIR}/third_party/amd/backend/include \
    -DJSON_INCLUDE_DIR=$HOME/.triton/json/include
}
With the above shell functions, setting up a new environment looks like:
git clone git@github.com:llvm/llvm-project.git
triton-configure-mlir llvm-project/llvm build/mlir-debug Debug
cmake --build build/mlir-debug

git clone git@github.com:triton-lang/triton.git
# Use triton-pip-install to download dependencies like the NVIDIA toolchain
cd triton && triton-pip-install && cd ..
triton-cmake triton build/triton-debug Debug build/mlir-debug
cmake --build build/triton-debug
We can also symlink the compile_commands.json file under build/triton-debug to the Triton repo root directory to enable better code completion, if using Vim and YouCompleteMe for example.
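For example, a minimal sketch assuming the directory layout from the setup steps above:
ln -sf $(pwd)/build/triton-debug/compile_commands.json triton/compile_commands.json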
Developing
Triton uses MLIR for its internal compiler passes and patterns, so it follows the MLIR source code structure and adopts the naming conventions there.
Source code structure
The major components are:
- python/: Python frontend codebase
  - python/triton: Triton's Python API and wheel package source code
  - python/src: pybind11 stubs to connect with the C++ codebase
- include/triton/: main C++ declaration codebase
  - include/triton/Dialect/: Triton MLIR dialect declarations
  - include/triton/Conversion/: Triton dialect conversion declarations
  - include/triton/Analysis/: Triton analysis utility declarations
- lib/: main C++ definition codebase
  - lib/Dialect/: Triton MLIR dialect definitions
  - lib/Conversion/: Triton dialect conversion definitions
  - lib/Analysis/: Triton analysis utility definitions
- third_party/nvidia: NVIDIA backend
  - Similar nested directory structure for python/, include/, and lib/
  - third_party/nvidia/backend/: entry points for the NVIDIA backend
- third_party/amd: AMD backend
  - Similar nested directory structure for python/, include/, and lib/
  - third_party/amd/backend/: entry points for the AMD backend
If you are just getting started with the codebase, third_party/*/backend/ might be the directory you'd like to check out first. Particularly, the compiler.py file there is the backend compiler entry point, which contains all compilation stages and pass invocations, for NVIDIA and AMD respectively.
IR printing
IR printing is a basic yet powerful development technique for compilers. Generally, for development or debugging, we first try to inspect the full compilation flow and check what each pass did, in order to isolate a specific problematic pass. There is a bunch of TRITON_* environment variables to help with IR printing; a good initial combination is:
# For NVIDIA CUDA:
TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=1 TRITON_DISABLE_LINE_INFO=1 NVPTX_ENABLE_DUMP=1
# For AMD HIP:
TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=1 TRITON_DISABLE_LINE_INFO=1 AMDGCN_ENABLE_DUMP=1
The above would print the input IR before each pass, for example,
// -----// IR Dump Before TritonCombineOps (triton-combine) ('builtin.module' operation) //----- //
...
// -----// IR Dump Before ConvertTritonToTritonGPU (convert-triton-to-tritongpu) ('builtin.module' operation) //----- //
...
// -----// IR Dump Before TritonGPUCoalesce (tritongpu-coalesce) ('builtin.module' operation) //----- //
...
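As a usage sketch (matmul.py here is just a placeholder for whatever script launches the kernel), the dumps go to stderr, so they can be redirected to a file for inspection:
TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=1 TRITON_DISABLE_LINE_INFO=1 \
  python matmul.py 2> ir_dump.log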
Once a certain pass is isolated, common MLIR tricks apply. We can then use the triton-opt binary under build/triton-debug/bin/ to invoke the pass, like triton-opt -tritongpu-coalesce, on the input IR, and iterate from there (a sketch follows the list below):
- Adding -debug to triton-opt enables logs of all sorts, from the MLIR infra itself and from all passes.
- To filter logs from a specific pass, we can find the source code of that pass and see if it has a #define DEBUG_TYPE label, then use -debug-only= with it. The DEBUG_TYPE label is typically the same as the pass name, for example for the coalescing pass, but it's not guaranteed.
- Sometimes there might be issues in the MLIR infra. Some useful options here are -debug-only=dialect-conversion to print dialect conversion logs, -debug-only=greedy-rewriter to print greedy rewriter logs, and -debug-only=pattern-application to print pattern application details. -debug-only= accepts multiple DEBUG_TYPE labels concatenated with ,.
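For example, a rough iteration loop on a saved TTGIR snippet (input.ttgir is a placeholder, and the -debug-only= label assumes the coalescing pass's DEBUG_TYPE matches its pass name):
build/triton-debug/bin/triton-opt input.ttgir \
  -tritongpu-coalesce -debug-only=tritongpu-coalesce \
  -o output.ttgir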
JIT compilation artifacts
Triton compiles and caches kernels under $HOME/.triton/cache. For each kernel, there is a directory whose hashed name is computed from the kernel source code, various parameters, and environment details. This directory contains various artifacts:
# For NVIDIA CUDA
> tree ~/.triton/cache/2DU3OI4VZNGLI3YQB3XY4QSTCMCBAFXIMJYYTJHQWOXONXVQ4QTQ
├── __grp__matmul_kernel.json
├── matmul_kernel.cubin
├── matmul_kernel.json
├── matmul_kernel.llir
├── matmul_kernel.ptx
├── matmul_kernel.ttgir
└── matmul_kernel.ttir
# For AMD HIP
> tree ~/.triton/cache/D2VUWRJXYTN4VEIO2EUQMCQFXPQKOLTOR2OGZQ5553XLRWTEN6DQ
├── __grp__matmul_kernel.json
├── matmul_kernel.amdgcn
├── matmul_kernel.hsaco
├── matmul_kernel.json
├── matmul_kernel.llir
├── matmul_kernel.ttgir
└── matmul_kernel.ttir
The file suffixes are self-explanatory:
- matmul_kernel.json contains metadata for the kernel compilation, e.g., the compilation target and the associated CUDAOptions or HIPOptions values.
- *.ttir, *.ttgir, and *.llir are the IR at the Triton dialect, TritonGPU dialect, and LLVM IR stages, respectively.
- *.ptx and *.cubin are the NVIDIA PTX assembly and final binary blob.
- *.amdgcn and *.hsaco are the AMD GCN assembly and final binary blob.
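As a quick sketch, the metadata file can be pretty-printed with Python's built-in JSON tool (using the NVIDIA cache directory from above as an example):
python -m json.tool ~/.triton/cache/2DU3OI4VZNGLI3YQB3XY4QSTCMCBAFXIMJYYTJHQWOXONXVQ4QTQ/matmul_kernel.json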
AMD GCN assembly
For AMD, after this patch, the *.amdgcn file contains useful kernel resource usage information for performance debugging in broad strokes:
; Kernel info:
; codeLenInByte = 7732
; NumSgprs: 24
; NumVgprs: 154
; NumAgprs: 128
; TotalNumVgprs: 284
; ScratchSize: 0
; MemoryBound: 1
; FloatMode: 240
; IeeeMode: 1
; LDSByteSize: 0 bytes/workgroup (compile time only)
; SGPRBlocks: 2
; VGPRBlocks: 35
; NumSGPRsForWavesPerEU: 24
; NumVGPRsForWavesPerEU: 284
; AccumOffset: 156
; Occupancy: 1
; WaveLimiterHint : 0
We can also find other useful metadata like .sgpr_spill_count and .vgpr_spill_count in the *.amdgcn file. Such information can help quickly identify performance issues due to register pressure/spills without resorting to profilers.
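For instance, a quick grep pulls out the fields that matter most for register pressure (the cache path is the AMD hash directory from above):
grep -E 'NumSgprs|NumVgprs|spill_count|Occupancy' \
  ~/.triton/cache/D2VUWRJXYTN4VEIO2EUQMCQFXPQKOLTOR2OGZQ5553XLRWTEN6DQ/matmul_kernel.amdgcn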
Cross compilation
Triton normally assumes JIT compilation: it reads the current active GPU's architecture information and JIT compiles at runtime. Sometimes we may want to compile towards different GPU architectures, and finding a machine with the corresponding GPU may be difficult. We can actually AOT compile a specific kernel if we only want to inspect the compilation flow or artifacts.
import torch
import triton
import triton.language as tl
from triton.backends.compiler import GPUTarget

@triton.jit
def kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # A kernel to compile
    ...

src = triton.compiler.ASTSource(
    fn=kernel,
    signature={0: "*fp32", 1: "*fp32", 2: "*fp32", 3: "i32"},
    constants={"BLOCK_SIZE": 64},
)

target = GPUTarget("cuda", 80, 32)  # For NVIDIA CUDA SM_80 (e.g., A100 GPU)
# target = GPUTarget("hip", "gfx942", 64)  # For AMD HIP gfx942 (e.g., MI300X GPU)
# Or other targets...

output = triton.compile(src, target=target)
Then output.asm is a dictionary containing the various IRs, e.g., ttir, ttgir, llir, etc., which we can print out with print(output.asm['ttgir']).
Debugging
The difficulty of debugging typically lies in isolating the problematic component and pinpointing the culprit; once that is done, the solution typically follows naturally.
Once we have a more isolated case, the general methodology to pinpoint the exact culprit is to
- collect and inspect the symptoms,
- form a hypothesis and run experiments to prove/refute it, and
- iterate.
Sanitizing development environment
One important step before engaging in full-blown debugging is to sanitize the development environment. It can be quite frustrating to learn that the "bug" is actually due to environmental issues after several hours of effort!
- Purge all existing Triton installations and reinstall from scratch.
- Purge $HOME/.triton/cache and recompile the kernel (a sketch follows this list).
- Double check which TRITON_* environment variables are set.
- Are others able to reproduce the same issue?
- Which versions of the driver and other components in the stack are installed?
- Did the driver stack get updated recently?
- Does the issue persist after resetting the GPU / rebooting the machine?
- …
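A rough sketch of the cache purge and environment-variable check, for example:
rm -rf ~/.triton/cache
env | grep '^TRITON_'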
Functionality issues
If you hit functionality issues in the Triton compiler codebase itself, like segfaults, it's typically easier to pinpoint the exact place and figure out a fix. Various general software and MLIR debugging tips apply here:
- Turn on a debug build to get asserts and other additional checks. If compiling from Python with pip, we can export DEBUG=1.
- Use Clang sanitizers to help catch memory or threading issues.
- Use a general debugger to step through the codebase (see the sketch after this list).
- …
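For example, a debugger sketch against the triton-opt binary built above (the input file and pass here are placeholders):
gdb --args build/triton-debug/bin/triton-opt input.ttgir -tritongpu-coalesce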
Correctness issues
If the compiled kernel has correctness issues, there are a few tips to collect symptoms and form hypotheses:
- Mutate the kernel source code to get a sense of which parts are likely causing issues.
- Disable features in the compiler to use a simpler compilation flow. For example, set num_stages=1 to disable software pipelining and see if it's causing the problem. Comment out non-essential passes in compiler.py.
- Use strict math, like disabling denormal flushing with allow_flush_denorm=False.
- …
Performance issues
Performance issues typically require using a profiler to see instruction timings and identify bottlenecks. But by reading the assembly, we can sometimes immediately identify obvious issues.