Triton Compiler Development Tips

2024-12-25
10 min read

Triton provides an elegant way to program GPU kernels in Python, positioning itself as a critical component in the modern AI software stack. To deliver performance and portability, it relies on a compiler, whose capability determines how much of that potential is realized. Hacking on the compiler internals is not a simple task, so here are some tips that will hopefully be useful. I’ll try to keep this blog post updated periodically.

Building and Installation

Triton itself focuses on the GPU kernel; runtime bits like tensor allocation and resource management are handled by PyTorch. PyTorch in turn uses Triton for TorchInductor, so Triton is typically installed together with PyTorch.

Using wheels from PyTorch

PyTorch has a clearly defined release scheme, and Triton’s versioning is effectively managed by PyTorch. Every two minor releases, PyTorch chooses a recent commit from Triton’s main branch, fixes all PyTorch usage breakages, and runs extensive regression tests in PyTorch for quality control. In the meantime it may cherry-pick further commits and push them to a release/M.N.x branch in the Triton repo. For release/M.N+1.x, the base commit remains the same; only important fixes are cherry-picked in.

PyTorch names the Triton wheel differently depending on the build type (stable or nightly) and the GPU vendor (NVIDIA or AMD). The canonical PyPI wheel triton is only used for stable PyTorch releases targeting NVIDIA CUDA. We can find the other combinations on this page.

To find which Triton commit a specific PyTorch version, for example v2.5.1, depends on, see .ci/docker/ci_commit_pins/triton.txt in PyTorch’s source tree.
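
For example, with a local PyTorch checkout:

cd pytorch && git checkout v2.5.1
cat .ci/docker/ci_commit_pins/triton.txt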

Building from source via Python

If we are interested in working on the Triton compiler itself, though, we need to install from source. Triton’s README.md has clear steps to follow. I’d typically use a Python virtualenv there for better environment isolation.

pip install builds the whole Triton project from Triton’s setup.py. The script first downloads build dependencies like LLVM, the NVIDIA toolchain, and pybind11 under $HOME/.triton; we can use the TRITON_HOME environment variable to redirect that, though. It then invokes CMake to build the whole C++ project from the top-level CMakeLists.txt and packages it up as a Python wheel.
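
For example, to keep the downloaded dependencies and caches on a bigger disk (the path is just a placeholder; the pip invocation mirrors the helper function below):

TRITON_HOME=/mnt/scratch/triton-home \
  pip install --no-build-isolation ./python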

The packaged Python wheel is named triton. This can cause problems for pip dependency management if we have another wheel installed from PyTorch under a different name, i.e., pytorch-triton or pytorch-triton-rocm: given the overlapping codebase, the locally built triton will overwrite contents but may not cover all of them due to version differences. So it’s suggested to purge all existing triton, pytorch-triton, and pytorch-triton-rocm wheels beforehand.
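
A minimal cleanup, covering the wheel names mentioned above:

pip uninstall -y triton pytorch-triton pytorch-triton-rocm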

Starting with this patch, we can query the installed triton commit with pip show triton, which will show something like Version: 3.2.0+gitf8b5301a.

Triton’s README.md mentions a few nice tips for building. A shell function that bakes them in comes in quite handy:

triton-pip-install () {
  REPO_BASE_DIR=$(git rev-parse --show-toplevel)
  TRITON_BUILD_WITH_CCACHE=true TRITON_BUILD_WITH_CLANG_LLD=true \
    pip install --no-build-isolation ${REPO_BASE_DIR}/python
}

Building from source via CMake

Developing the Triton compiler mostly involves touching the C++ codebase for various MLIR passes and patterns, so I also directly build a Debug version from the top-level CMakeLists.txt.

That involves building LLVM/MLIR in Debug too; the steps are in the README.md, and I also put them in a shell function:

# <source-dir> should be the local checkout directory for
#   https://github.com/llvm/llvm-project/tree/main/llvm
# <target-dir> is where to put the compiled LLVM/MLIR artifacts
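# <build-type> is the CMake build type, e.g., Debug or Release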
triton-configure-mlir() {
  if (( $# < 3 ))
  then
    echo "usage: $0 <source-dir> <target-dir> <build-type>"
    return 1
  fi

  SOURCE_DIR=$1; shift
  TARGET_DIR=$1; shift
  BUILD_TYPE=$1; shift

  cmake -GNinja \
    -S ${SOURCE_DIR} -B ${TARGET_DIR} \
    -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DCMAKE_C_COMPILER=$(which clang) -DCMAKE_CXX_COMPILER=$(which clang++) \
    -DLLVM_ENABLE_PROJECTS="llvm;mlir" \
    -DLLVM_TARGETS_TO_BUILD="AMDGPU;NVPTX;X86;AArch64"
}

Then the CMake configuration step for Triton can be captured with a shell function:

# <source-dir> should be the local checkout directory for
#   https://github.com/triton-lang/triton
# <target-dir> is where to put the compiled Triton artifacts
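# <build-type> is the CMake build type, e.g., Debug or Release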
# <mlir-dir> should be the LLVM/MLIR artifacts directory
triton-cmake() {
  if (( $# < 4 ))
  then
    echo "usage: $0 <source-dir> <target-dir> <build-type> <mlir-dir>"
    return 1
  fi

  SOURCE_DIR=$1; shift
  TARGET_DIR=$1; shift
  BUILD_TYPE=$1; shift
  MLIR_DIR=$1;   shift

  if [[ "$(uname)" == "Darwin" ]]; then
    LINKER_FLAGS=()
  else
    LINKER_FLAGS=(
      "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld"
      "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld"
      "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"
    )
  fi

  REPO_BASE_DIR=$(git -C ${SOURCE_DIR} rev-parse --show-toplevel)

  cmake -GNinja \
    -S ${SOURCE_DIR} -B ${TARGET_DIR} \
    -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DTRITON_CODEGEN_BACKENDS="amd;nvidia" \
    -DLLVM_INCLUDE_DIRS=${MLIR_DIR}/include \
    -DLLVM_LIBRARY_DIR=${MLIR_DIR}/lib \
    -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
    -DCMAKE_LINKER=lld ${LINKER_FLAGS[@]} \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
    -DTRITON_BUILD_PYTHON_MODULE=ON \
    -DTRITON_BUILD_PROTON=ON \
    -DCUPTI_INCLUDE_DIR=${REPO_BASE_DIR}/third_party/nvidia/backend/include \
    -DROCTRACER_INCLUDE_DIR=${REPO_BASE_DIR}/third_party/amd/backend/include \
    -DJSON_INCLUDE_DIR=$HOME/.triton/json/include
}

With the above shell functions, setting up a new environment looks like this:

git clone git@github.com:llvm/llvm-project.git
triton-configure-mlir llvm-project/llvm build/mlir-debug Debug
cmake --build build/mlir-debug

git clone git@github.com:triton-lang/triton.git
# Use triton-pip-install to download dependencies like the NVIDIA toolchain
cd triton && triton-pip-install && cd ..
triton-cmake triton build/triton-debug Debug build/mlir-debug
cmake --build build/triton-debug

We can also symlink the compile_commands.json file under build/triton-debug into the Triton repo root directory to enable better code completion, for example when using Vim and YouCompleteMe.
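
Assuming the directory layout from the commands above, that is just:

ln -sf "$PWD/build/triton-debug/compile_commands.json" triton/compile_commands.json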

Developing

Triton uses MLIR for its internal compiler passes and patterns, so it follows MLIR’s source code structure and adopts its naming conventions.

Source code structure

The major components are:

  • python/: Python frontend codebase
    • python/triton: Triton’s Python API and wheel package source code
    • python/src: pybind11 stubs to connect with the C++ codebase
  • include/triton/: main C++ declaration codebase
    • include/triton/Dialect/: Triton MLIR dialect declarations
    • include/triton/Conversion/: Triton dialect conversion declarations
    • include/triton/Analysis/: Triton analysis utility declarations
  • lib/: main C++ definition codebase
    • lib/Dialect/: Triton MLIR dialect definitions
    • lib/Conversion/: Triton dialect conversion definitions
    • lib/Analysis/: Triton analysis utility definitions
  • third_party/nvidia: NVIDIA backend
    • Similar nested directory structure for python/, include/, and lib/
    • third_party/nvidia/backend/: entry points for the NVIDIA backend
  • third_party/amd: AMD backend
    • Similar nested directory structure for python/, include/, and lib/
    • third_party/amd/backend/: entry points for the AMD backend

If you are just getting started with the codebase, third_party/*/backend/ might be the directory to check out first. In particular, the compiler.py file there is the backend compiler entry point, containing all compilation stages and pass invocations, for NVIDIA and AMD respectively.

IR printing

IR printing is a basic yet powerful development technique for compilers. Generally for development or debugging, we first inspect the full compilation flow and check what each pass did, in order to isolate a specific problematic pass. There is a bunch of TRITON_* environment variables to help with IR printing; a good initial combination is:

# For NVIDIA CUDA:
TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=1 TRITON_DISABLE_LINE_INFO=1 NVPTX_ENABLE_DUMP=1

# For AMD HIP:
TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=1 TRITON_DISABLE_LINE_INFO=1 AMDGCN_ENABLE_DUMP=1

The above would print the input IR before each pass, for example,

// -----// IR Dump Before TritonCombineOps (triton-combine) ('builtin.module' operation) //----- //
...
// -----// IR Dump Before ConvertTritonToTritonGPU (convert-triton-to-tritongpu) ('builtin.module' operation) //----- //
...
// -----// IR Dump Before TritonGPUCoalesce (tritongpu-coalesce) ('builtin.module' operation) //----- //
...
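
These dumps go to stderr, so it is convenient to redirect them into a file for later inspection; matmul.py here is just a placeholder for whatever script launches the kernel:

TRITON_ALWAYS_COMPILE=1 MLIR_ENABLE_DUMP=1 TRITON_DISABLE_LINE_INFO=1 \
  python matmul.py 2> ir-dump.log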

Once a certain pass is isolated, common MLIR tricks apply. We can use the triton-opt binary under build/triton-debug/bin/ to invoke just that pass on the input IR, like triton-opt -tritongpu-coalesce, and iterate from there:

  • Adding -debug to triton-opt enables logs of all sorts, from the MLIR infra itself and from all passes.
  • To filter logs from a specific pass, we can find the source code of that pass, check whether it has a #define DEBUG_TYPE label, and use -debug-only= with it. The DEBUG_TYPE label is typically the same as the pass name, as for the coalescing pass for example, but that’s not guaranteed (see the sample invocation after this list).
  • Sometimes there might be issues in the MLIR infra. Some useful options here are -debug-only=dialect-conversion to print dialect conversion logs, -debug-only=greedy-rewriter to print greedy rewriter logs, and -debug-only=pattern-application to print pattern application details.
  • -debug-only= accepts multiple DEBUG_TYPE labels separated by commas.
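
Putting this together, a typical iteration on an isolated pass might look like the following, where input.mlir is the IR dumped right before the pass and the DEBUG_TYPE is assumed to match the pass name:

build/triton-debug/bin/triton-opt -tritongpu-coalesce \
  -debug-only=tritongpu-coalesce input.mlir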

JIT compilation artifacts

Triton compiles and caches kernels under $HOME/.triton/cache. For each kernel, there is a directory whose hash-derived name is computed from the kernel source code, various parameters, and environment details. This directory contains various artifacts:

# For NVIDIA CUDA
> tree ~/.triton/cache/2DU3OI4VZNGLI3YQB3XY4QSTCMCBAFXIMJYYTJHQWOXONXVQ4QTQ
├── __grp__matmul_kernel.json
├── matmul_kernel.cubin
├── matmul_kernel.json
├── matmul_kernel.llir
├── matmul_kernel.ptx
├── matmul_kernel.ttgir
└── matmul_kernel.ttir

# For AMD HIP
> tree ~/.triton/cache/D2VUWRJXYTN4VEIO2EUQMCQFXPQKOLTOR2OGZQ5553XLRWTEN6DQ
├── __grp__matmul_kernel.json
├── matmul_kernel.amdgcn
├── matmul_kernel.hsaco
├── matmul_kernel.json
├── matmul_kernel.llir
├── matmul_kernel.ttgir
└── matmul_kernel.ttir

The file suffixes indicate their contents:

  • matmul_kernel.json contains metadata for the kernel compilation, e.g., the compilation target and the associated CUDAOptions or HIPOptions values (see the example after this list).
  • *.ttir, *.ttgir, and *.llir are the kernel IR at the Triton dialect, TritonGPU dialect, and LLVM IR stages, respectively.
  • *.ptx and *.cubin are NVIDIA PTX assembly and final binary blob.
  • *.amdgcn and *.hsaco are AMD GCN assembly and final binary blob.
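
To peek at the metadata, we can pretty-print the JSON file directly with the standard library, using the CUDA cache directory from the example above:

python -m json.tool \
  ~/.triton/cache/2DU3OI4VZNGLI3YQB3XY4QSTCMCBAFXIMJYYTJHQWOXONXVQ4QTQ/matmul_kernel.json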

AMD GCN assembly

For AMD, after this patch, the *.amdgcn file contains useful kernel resource usage information for broad-strokes performance debugging:

; Kernel info:
; codeLenInByte = 7732
; NumSgprs: 24
; NumVgprs: 154
; NumAgprs: 128
; TotalNumVgprs: 284
; ScratchSize: 0
; MemoryBound: 1
; FloatMode: 240
; IeeeMode: 1
; LDSByteSize: 0 bytes/workgroup (compile time only)
; SGPRBlocks: 2
; VGPRBlocks: 35
; NumSGPRsForWavesPerEU: 24
; NumVGPRsForWavesPerEU: 284
; AccumOffset: 156
; Occupancy: 1
; WaveLimiterHint : 0

We can also find other useful metadata like .sgpr_spill_count and .vgpr_spill_count in the *.amdgcn file. Such information helps to quickly identify performance issues caused by register pressure and spills, without resorting to profilers.
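
A quick grep over the assembly surfaces the most relevant counters, using the matmul_kernel.amdgcn file from the cache example above:

grep -E 'spill_count|NumVgprs|ScratchSize|Occupancy' matmul_kernel.amdgcn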

Cross compilation

Triton normally assumes JIT compilation: it reads the current active GPU’s architecture information and compiles at runtime. Sometimes we may want to compile for a different GPU architecture, and finding a machine with the corresponding GPU may be difficult.

If we only want to inspect the compilation flow or its artifacts, we can actually AOT compile a specific kernel:

import torch
import triton
import triton.language as tl
from triton.backends.compiler import GPUTarget

@triton.jit
def kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
  # A kernel to compile
  ...

src = triton.compiler.ASTSource(
  fn=kernel,
  signature={0: "*fp32", 1: "*fp32", 2: "*fp32", 3: "i32"},
  constants={"BLOCK_SIZE": 64}
)

target = GPUTarget("cuda", 80, 32)        # For NVIDIA CUDA SM_80 (e.g., A100 GPU)
# target = GPUTarget("hip", 'gfx942', 64) # For AMD HIP gfx942 (e.g., MI300X GPU)
# Or other targets..
output = triton.compile(src, target=target)

Then output.asm is a dictionary containing the various IRs (ttir, ttgir, llir, etc.), which we can print out with print(output.asm['ttgir']).

Debugging

The difficulty of debugging typically lies in isolating the problematic component and pinpointing the culprit; once that is done, the solution usually follows naturally.

Once we have a more isolated case, the general methodology to pinpoint the exact culprit is to

  1. collect and inspect the symptoms,
  2. form hypothesis and run experiments to prove/refute the hypothesis, and
  3. iterate.

Sanitizing development environment

One important step before engaging in full-blown debugging is to sanitize the development environment. It can be quite frustrating to learn, after several hours of effort, that the “bug” is actually due to environmental issues!

  • Purge all existing Triton installations and reinstall from scratch.
  • Purge $HOME/.triton/cache and recompile the kernel (see the commands after this list).
  • Double check which TRITON_* environment variables are set.
  • Are others able to reproduce the same issue?
  • What versions are the driver and the other components in the stack?
  • Did the driver stack get updated recently?
  • Does the issue persist after resetting the GPU / rebooting the machine?
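
One way to quickly run through the first few checks; nvidia-smi and rocm-smi are mentioned just as the usual tools for inspecting the driver and GPU state:

pip list | grep -i triton   # which Triton wheels are currently installed?
rm -rf ~/.triton/cache      # drop all cached kernel compilation artifacts
nvidia-smi                  # or rocm-smi: check driver version and GPU health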

Functionality issues

If you hit functionality issues like segfaults in the Triton compiler codebase itself, it’s typically easier to pinpoint the exact place and figure things out. Various general software and MLIR debugging tips apply here:

  • Turn on a debug build to get asserts and other additional checks. If compiling from Python with pip, we can export DEBUG=1.
  • Use Clang sanitizers to help catch memory or threading issues.
  • Use a general debugger to step through the codebase (see the example after this list).
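
For example, once a problematic pass and its input IR are isolated via IR printing, stepping through just that pass with gdb on the Debug triton-opt binary tends to be the quickest loop; input.mlir is a placeholder for the dumped IR:

gdb --args build/triton-debug/bin/triton-opt -tritongpu-coalesce input.mlir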

Correctness issues

If the compiled kernel has correctness issues, there are a few tips for collecting symptoms and forming hypotheses:

  • Mutate the kernel source code to get a sense of which parts are likely causing issues.
  • Disable features in the compiler to get a simpler compilation flow. For example, set num_stages=1 to disable software pipelining and see if it’s causing the problem, or comment out non-essential passes in compiler.py.
  • Use stricter math, e.g., disable flushing denormals to zero with allow_flush_denorm=False.

Performance issues

Performance issues typically require a profiler to see instruction timings and identify the bottlenecks. But sometimes we can spot obvious issues immediately just by reading the assembly.