Fused Softmax¶
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations¶
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                    'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch.

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(2MN\) elements in total (X is read once and Y is written once), so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
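As a quick sanity check on that ratio, here is a back-of-the-envelope sketch; the shape below is purely illustrative and not part of the tutorial's benchmark:

M, N = 4096, 4096  # hypothetical matrix shape, for illustration only
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N  # a fused kernel reads X once and writes Y once
print(naive_traffic / fused_traffic)  # ~4.0, matching the (8MN + 4M) / 2MN estimate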
Compute Kernel¶
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
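The masked load fills the padded lanes with -inf, so they never win the row maximum and they contribute exp(-inf) = 0 to the denominator; the padding therefore does not change the result. A minimal PyTorch check of that invariant (the sizes 781 and 1024 simply mirror the unit test below, where 1024 is the next power of two above 781):

# Padding a row with -inf leaves its softmax unchanged.
row = torch.randn(781, device=DEVICE)
padded = torch.full((1024,), float('-inf'), device=DEVICE)
padded[:781] = row
ref = torch.softmax(row, dim=0)
out = torch.softmax(padded, dim=0)[:781]  # padded lanes get probability 0
assert torch.allclose(ref, out)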
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
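To make the occupancy arithmetic on the CUDA path concrete, here is a rough worked example; the device limits and per-thread register count below are hypothetical and only illustrate the formula:

# Hypothetical device: 65536 registers per SM, 32-lane warps, 8 warps per program;
# suppose the compiled kernel reports 32 registers per thread.
NUM_REGS_EX, WARP_SIZE_EX, num_warps_ex, n_regs_ex = 65536, 32, 8, 32
occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)
print(occupancy_ex)  # -> 8 programs per SM, before the shared-memory cap is applied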
Unit Test¶
We test our kernel on a matrix with an irregular number of rows and columns, which lets us verify that the padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
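If you also want to exercise the case where no padding is needed (a power-of-two column count, so BLOCK_SIZE == n_cols and the mask is all-true), a small optional check along these lines would do; this is a sketch, not part of the original test:

# Optional extra check with a power-of-two number of columns.
x_pow2 = torch.randn(1000, 1024, device=DEVICE)
assert torch.allclose(softmax(x_pow2), torch.softmax(x_pow2, axis=1))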
Benchmark¶
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch
0 256.0 466.364565 697.334601
1 384.0 650.818893 817.308710
2 512.0 800.797842 911.754582
3 640.0 802.169752 959.639932
4 768.0 882.094429 1017.303340
5 896.0 937.070856 1074.489355
6 1024.0 1009.039332 1107.706231
7 1152.0 1104.272729 1034.414615
8 1280.0 1142.811903 1078.972327
9 1408.0 1166.462589 1113.068593
10 1536.0 1184.357741 1139.453928
11 1664.0 1209.753880 1159.962015
12 1792.0 1238.147763 1195.260716
13 1920.0 1251.237712 1197.260076
14 2048.0 1270.146597 1227.116459
15 2176.0 1243.839878 958.276887
16 2304.0 1240.547606 1000.883769
17 2432.0 1266.856304 1032.416770
18 2560.0 1285.268236 1063.677866
19 2688.0 1289.105210 1096.634250
20 2816.0 1292.994410 1120.523157
21 2944.0 1307.737326 1146.880018
22 3072.0 1325.606229 1167.525615
23 3200.0 1325.029087 1175.538624
24 3328.0 1342.840232 1202.726888
25 3456.0 1352.594982 1224.965486
26 3584.0 1346.724265 1245.554565
27 3712.0 1366.991436 1264.789115
28 3840.0 1369.289856 1285.408614
29 3968.0 1372.231831 1295.605875
30 4096.0 1372.481667 1317.170795
31 4224.0 1334.927257 1290.115436
32 4352.0 1335.387621 1315.784686
33 4480.0 1349.596728 1333.533314
34 4608.0 1359.269568 1351.233157
35 4736.0 1358.152306 1366.726616
36 4864.0 1378.601053 1382.478386
37 4992.0 1368.874500 1392.486813
38 5120.0 1374.322355 1406.676849
39 5248.0 1376.228596 1357.665507
40 5376.0 1379.030189 1384.667838
41 5504.0 1381.412849 1393.023666
42 5632.0 1384.807598 1409.037237
43 5760.0 1390.220001 1422.183301
44 5888.0 1389.242411 1434.552412
45 6016.0 1396.127682 1433.920012
46 6144.0 1413.288506 1447.100857
47 6272.0 1411.016265 1407.094404
48 6400.0 1415.103857 1421.164796
49 6528.0 1413.242316 1427.466102
50 6656.0 1422.046727 1433.858745
51 6784.0 1414.958636 1455.001110
52 6912.0 1429.146988 1458.376153
53 7040.0 1418.504192 1460.792465
54 7168.0 1428.515456 1467.977055
55 7296.0 1430.681425 1085.766664
56 7424.0 1427.078482 1101.834555
57 7552.0 1426.245465 1113.776095
58 7680.0 1437.368097 1127.952688
59 7808.0 1432.683862 1136.157330
60 7936.0 1436.643328 1145.693467
61 8064.0 1439.171712 1151.663838
62 8192.0 1438.339987 1156.014178
63 8320.0 1397.822447 1113.864384
64 8448.0 1382.990073 1123.100066
65 8576.0 1398.372849 1125.691611
66 8704.0 1395.704378 1126.867959
67 8832.0 1386.248120 1128.248706
68 8960.0 1401.473910 1135.203613
69 9088.0 1413.768491 1131.505463
70 9216.0 1407.978843 1127.626824
71 9344.0 1411.909147 1423.134257
72 9472.0 1406.283856 1430.488357
73 9600.0 1401.760796 1432.104370
74 9728.0 1407.178381 1440.650606
75 9856.0 1421.384574 1441.313389
76 9984.0 1402.296627 1448.714441
77 10112.0 1412.858505 1451.885339
78 10240.0 1419.102508 1462.548677
79 10368.0 1416.232210 1463.111199
80 10496.0 1419.267579 1463.917231
81 10624.0 1417.292944 1464.452488
82 10752.0 1410.445695 1473.632699
83 10880.0 1408.355557 1477.357192
84 11008.0 1426.457127 1478.589244
85 11136.0 1427.076958 1481.097270
86 11264.0 1432.259854 1488.126310
87 11392.0 1419.375491 1489.392703
88 11520.0 1428.307676 1497.033987
89 11648.0 1429.634207 1500.164526
90 11776.0 1436.568935 1503.278345
91 11904.0 1444.700931 1509.006937
92 12032.0 1429.864182 1508.444236
93 12160.0 1427.230047 1516.440878
94 12288.0 1440.351650 1422.074616
95 12416.0 1457.126376 1397.884923
96 12544.0 1445.445651 1394.588480
97 12672.0 1457.101641 1393.874664
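The GB/s figures above come from the gbps lambda in the benchmark: every element of x is read once and every element of y is written once, hence the factor of 2 in the traffic model. A quick sketch of that arithmetic (the kernel time below is a made-up value, only to show how a row of the table is derived):

# Bandwidth model used by the benchmark: (bytes read + bytes written) / time.
M, N, element_size = 4096, 4096, 4  # float32 elements are 4 bytes
ms = 0.1  # hypothetical kernel time in milliseconds
gbs = 2 * M * N * element_size * 1e-9 / (ms * 1e-3)
print(gbs)  # ~1342 GB/s under these assumed numbers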
In the above plot, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 23.353 seconds)