Debugging Triton
This tutorial provides guidance for debugging Triton programs. It is written mainly for Triton users. Developers interested in exploring Triton's backend, including MLIR code transformation and LLVM code generation, can refer to this section to explore debugging options.
Using Triton’s Debugging Operations
Triton includes four debugging operators that allow users to check and inspect tensor values:
- static_print and static_assert are intended for compile-time debugging.
- device_print and device_assert are used for runtime debugging.
device_assert executes only when TRITON_DEBUG is set to 1. The other debugging operators execute regardless of the value of TRITON_DEBUG.
Using the Interpreter
The interpreter is a straightforward and helpful tool for debugging Triton programs.
It allows Triton users to run Triton programs on the CPU and inspect the intermediate results of each operation.
To enable the interpreter mode, set the environment variable TRITON_INTERPRET to 1.
This setting causes all Triton kernels to bypass compilation and be simulated by the interpreter using numpy equivalents of Triton operations.
The interpreter processes each Triton program instance sequentially, executing operations one at a time.
There are three primary ways to use the interpreter:
1. Print the intermediate results of each operation using the Python print function. To inspect an entire tensor, use print(tensor). To examine the individual tensor value at index idx, use print(tensor.handle.data[idx]).
2. Attach pdb for step-by-step debugging of the Triton program. Here b sets a breakpoint and c continues until it is hit:

    TRITON_INTERPRET=1 python -m pdb main.py
    b main.py:<line number>
    c

3. Import the pdb package and set breakpoints in the Triton program (TRITON_INTERPRET must still be set to 1):

    import pdb

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
        pdb.set_trace()  # pauses here once per program instance
        offs = tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + offs)
        tl.store(y_ptr + offs, x)
Limitations
The interpreter has several known limitations:
- It does not support operations on bfloat16 numeric types. To perform operations on bfloat16 tensors, first convert them to float32 with tl.cast(tensor, tl.float32).
- It does not support indirect memory access patterns such as:

    ptr = tl.load(ptr)  # load a pointer value stored in memory
    x = tl.load(ptr)    # then dereference it
Using Third-party Tools
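For debugging on NVIDIA GPUs, compute-sanitizer is an effective tool for checking data races and memory access issues.
To use it, prepend compute-sanitizer to the command that runs your Triton program, for example compute-sanitizer python main.py. By default it runs the memcheck tool, which reports memory access errors; pass --tool racecheck to check for shared-memory data races.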
For debugging on AMD GPUs, you may want to try the LLVM AddressSanitizer for ROCm.
For detailed visualization of memory access in Triton programs, consider using the triton-viz tool, which is agnostic to the underlying GPUs.