Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
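To make the docstring's claim concrete, here is a minimal pure-Python sketch (not part of the tutorial's code) showing why the maximum is subtracted. The helper names `softmax_unstable` and `softmax_stable` are hypothetical, and the sketch works on a single row stored as a list of floats rather than on a tensor.

```python
import math


def softmax_unstable(row):
    # Direct definition: exp(x_i) / sum_j exp(x_j). Overflows for large x_i.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(row):
    # Shift by the row maximum first; the largest exponent is then exp(0) = 1.
    row_max = max(row)
    exps = [math.exp(v - row_max) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
large = [v + 999.0 for v in small]  # same row, shifted by a constant

# On moderate inputs both versions agree up to rounding.
agree_on_small = all(
    math.isclose(a, b, rel_tol=1e-12)
    for a, b in zip(softmax_unstable(small), softmax_stable(small))
)

# On large inputs the direct definition overflows in double precision.
try:
    softmax_unstable(large)
    overflowed = False
except OverflowError:
    overflowed = True

# The stabilized version is unaffected by the constant shift.
shift_invariant = softmax_stable(large) == softmax_stable(small)
```

Python's `math.exp` raises `OverflowError` here, whereas a floating-point tensor library would typically produce `inf` and then `nan` after the division; either way the unshifted form is unusable for large inputs.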
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that reads
X only once and does all the necessary computations on-chip.
Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.
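The ~4x figure can be checked with a few lines of arithmetic. The sketch below simply counts elements moved to and from DRAM using the totals quoted above; the function names are hypothetical, and the model ignores caches and any other hardware effects.

```python
def naive_traffic(M, N):
    # Elements read plus elements written by naive_softmax.
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes  # 8MN + 4M


def fused_traffic(M, N):
    # A fused kernel reads X once and writes Y once.
    return 2 * M * N


def ideal_speedup(M, N):
    # (8MN + 4M) / (2MN) simplifies to 4 + 2/N, independent of M.
    return naive_traffic(M, N) / fused_traffic(M, N)


ratio = ideal_speedup(4096, 1024)
```

Because the ratio simplifies to \(4 + 2/N\), it approaches 4 from above as the rows get longer and does not depend on the number of rows.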
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X strided by number of programs, normalizes it and writes back the result to the output Y.
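As a rough illustration of the strided row assignment, the pure-Python sketch below lists which rows each program would visit. The helper `rows_for_program` and the sizes used are hypothetical; in the kernel itself the program index and count come from `tl.program_id(0)` and `tl.num_programs(0)`.

```python
def rows_for_program(pid, num_programs, n_rows):
    # Program `pid` starts at row `pid` and advances by `num_programs` each step.
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4  # illustrative sizes only
assignment = {pid: rows_for_program(pid, num_programs, n_rows) for pid in range(num_programs)}

# Every row is handled by exactly one program.
covered = sorted(r for rows in assignment.values() for r in rows)
```

With these sizes, program 0 handles rows 0, 4 and 8, program 1 handles rows 1, 5 and 9, and so on, so a fixed number of programs can cover any number of rows.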
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
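The padding and masking logic can be mimicked on the CPU for a single row. The following is a pure-Python sketch, not Triton code: `next_power_of_2` and `padded_softmax_row` are hypothetical stand-ins that imitate the masked load (filling out-of-range lanes with negative infinity) and the masked store (dropping those lanes again).

```python
import math


def next_power_of_2(n):
    # Smallest power of two greater than or equal to n (for n >= 1).
    return 1 << (n - 1).bit_length()


def padded_softmax_row(row):
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    mask = [c < n_cols for c in range(block_size)]
    # "Masked load": lanes past the end of the row are filled with -inf.
    block = [row[c] if mask[c] else -math.inf for c in range(block_size)]
    row_max = max(block)  # -inf lanes never win the max
    numerator = [math.exp(v - row_max) for v in block]  # exp(-inf) == 0.0
    denominator = sum(numerator)  # padded lanes contribute nothing
    out = [v / denominator for v in numerator]
    # "Masked store": only the valid lanes are written back.
    return [out[c] for c in range(block_size) if mask[c]]


result = padded_softmax_row([1.0, 2.0, 3.0, 4.0, 5.0])  # padded from 5 to 8 lanes
```

Negative infinity is a convenient fill value because it is neutral for both reductions: it cannot be the maximum, and its exponential is exactly zero in the sum.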
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or
    # equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
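To see how the launch size falls out of the occupancy calculation, here is a standalone sketch of the non-HIP branch with plain integers. The function `estimate_num_programs` is hypothetical, the device and kernel numbers are illustrative values rather than measurements from any particular GPU, and the guard against a zero shared-memory footprint is an addition that the helper above does not have.

```python
def estimate_num_programs(num_sm, num_regs, size_smem_device, warp_size,
                          n_regs, size_smem_kernel, num_warps, n_rows):
    # Register limit: how many programs fit in one multiprocessor's register file.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared-memory limit (guarded here so a zero footprint does not divide by zero).
    if size_smem_kernel > 0:
        occupancy = min(occupancy, size_smem_device // size_smem_kernel)
    # One persistent program per available slot, but never more programs than rows.
    return min(num_sm * occupancy, n_rows)


# Illustrative numbers only (assumptions, not read from a real device).
example = dict(num_sm=108, num_regs=65536, size_smem_device=166912, warp_size=32,
               n_regs=32, size_smem_kernel=4096, num_warps=8)

many_rows = estimate_num_programs(n_rows=4096, **example)
few_rows = estimate_num_programs(n_rows=100, **example)
```

With these made-up numbers the register file is the limiting resource (8 programs per multiprocessor), giving 864 persistent programs for a 4096-row input, while a 100-row input is capped at 100 programs.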
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results match to within floating-point tolerance (torch.allclose does not require bitwise equality).
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 500.056313 691.626019 203.192830
1 384.0 696.062625 813.845754 260.071501
2 512.0 835.022220 918.642249 300.099694
3 640.0 837.570612 914.495068 331.498427
4 768.0 910.458857 990.967742 351.839647
5 896.0 967.244058 1029.561299 355.314277
6 1024.0 1030.878009 1081.342046 353.232439
7 1152.0 1027.968971 1073.265719 347.298383
8 1280.0 1071.140288 1101.689922 347.551781
9 1408.0 1113.470090 1134.705382 342.451459
10 1536.0 1145.985502 1154.428188 332.555958
11 1664.0 1188.469599 1181.463388 328.990865
12 1792.0 1203.988488 1200.802825 325.804163
13 1920.0 1237.792045 1219.498761 326.324876
14 2048.0 1249.280876 1244.770839 325.679763
15 2176.0 1180.409644 960.102574 325.528351
16 2304.0 1197.911336 1004.126619 326.041386
17 2432.0 1220.757470 1038.762584 327.458095
18 2560.0 1236.683111 1067.281502 327.915274
19 2688.0 1260.046411 1096.970480 329.348194
20 2816.0 1278.139925 1122.202008 328.969537
21 2944.0 1293.466512 1148.473081 331.520087
22 3072.0 1313.583386 1174.572486 333.644160
23 3200.0 1326.784009 1168.825188 335.086250
24 3328.0 1326.725616 1198.894453 336.854932
25 3456.0 1338.935214 1223.314109 337.391154
26 3584.0 1341.160594 1247.464691 338.296674
27 3712.0 1336.926377 1263.590072 340.288237
28 3840.0 1361.291851 1286.526136 340.584794
29 3968.0 1362.983610 1295.021825 341.080402
30 4096.0 1368.590779 1315.977071 338.893680
31 4224.0 1337.126543 1274.053960 342.404101
32 4352.0 1345.669905 1299.844559 344.947548
33 4480.0 1358.561032 1318.110175 345.802896
34 4608.0 1366.813941 1335.028025 347.393896
35 4736.0 1361.060790 1344.125600 348.439905
36 4864.0 1379.597313 1356.449819 349.398080
37 4992.0 1369.783555 1369.853953 350.438165
38 5120.0 1387.651729 1388.767977 351.257295
39 5248.0 1382.438090 1357.973029 351.840500
40 5376.0 1381.536173 1362.970925 351.367761
41 5504.0 1389.537390 1382.731570 353.535239
42 5632.0 1400.149133 1398.331249 353.445511
43 5760.0 1403.211451 1405.312739 355.014219
44 5888.0 1397.281267 1415.510248 354.841607
45 6016.0 1408.297269 1417.921250 356.919696
46 6144.0 1416.889903 1431.642685 357.348429
47 6272.0 1412.015045 1403.142553 357.516387
48 6400.0 1416.556478 1409.997479 358.622822
49 6528.0 1418.322293 1424.740618 359.079778
50 6656.0 1418.749015 1436.639885 359.062388
51 6784.0 1425.013066 1432.909455 360.045417
52 6912.0 1430.740630 1452.567868 360.500573
53 7040.0 1424.275770 1454.312962 361.066645
54 7168.0 1425.342721 1469.095332 361.673141
55 7296.0 1431.835618 1088.316987 362.983241
56 7424.0 1430.429088 1100.058589 362.830912
57 7552.0 1433.148163 1111.496335 363.696198
58 7680.0 1434.464478 1123.085351 363.381741
59 7808.0 1433.595854 1135.670435 364.662283
60 7936.0 1435.272831 1144.102307 364.645772
61 8064.0 1440.980920 1150.457259 365.252940
62 8192.0 1431.998891 1152.154218 363.737060
63 8320.0 1391.066545 1114.616869 361.288413
64 8448.0 1381.762661 1126.645098 362.376240
65 8576.0 1389.555569 1127.866061 363.174319
66 8704.0 1385.894967 1135.795364 364.422793
67 8832.0 1390.679493 1132.780439 364.831566
68 8960.0 1386.789222 1140.230247 365.972137
69 9088.0 1401.070158 1138.054658 366.571799
70 9216.0 1404.422469 1141.420110 367.618343
71 9344.0 1394.129021 1423.721125 367.206633
72 9472.0 1405.002507 1431.404564 368.566525
73 9600.0 1397.438979 1427.451848 368.975432
74 9728.0 1397.824134 1439.531795 369.550616
75 9856.0 1398.684989 1440.355758 369.952433
76 9984.0 1391.374582 1448.018068 370.367195
77 10112.0 1405.495863 1454.754718 371.513056
78 10240.0 1405.140854 1465.643055 371.653758
79 10368.0 1412.885493 1462.872500 369.770961
80 10496.0 1403.994888 1467.297766 370.448541
81 10624.0 1404.777868 1467.325354 370.850471
82 10752.0 1399.186122 1471.662808 371.400726
83 10880.0 1397.558453 1478.554339 371.979058
84 11008.0 1420.242113 1479.438072 371.935129
85 11136.0 1424.050606 1485.278001 373.290036
86 11264.0 1414.994760 1488.049081 373.378928
87 11392.0 1419.858756 1487.304962 374.394446
88 11520.0 1407.667639 1493.888328 374.226985
89 11648.0 1415.753170 1501.613686 374.290402
90 11776.0 1431.472624 1502.882488 375.250591
91 11904.0 1431.230552 1510.543154 376.335801
92 12032.0 1425.182627 1508.575051 376.327660
93 12160.0 1420.287355 1516.435817 376.533679
94 12288.0 1429.188170 1421.846300 376.538108
95 12416.0 1432.113344 1393.922980 374.425254
96 12544.0 1440.117980 1396.146810 375.209031
97 12672.0 1437.111999 1396.447843 375.405073
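In the above plot, we can see that:
Triton is roughly 4x faster than naive_softmax (about 1,400 GB/s versus about 370 GB/s for large N). This is in line with the theoretical estimate above and confirms that the naive implementation does not benefit from any fusion.
Triton is competitive with torch.softmax: it is faster for N from 2176 to 4096 and from 7296 to 9216, the two are within roughly 10% of each other for the other sizes with N ≥ 512, and torch.softmax is clearly ahead only for the smallest rows (N ≤ 384). The Triton kernel is also arguably easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.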
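The GB/s values reported by the benchmark follow the `gbps` lambda: one read and one write of the whole tensor, divided by the measured time. The sketch below reproduces that conversion on its own; the function name and the 0.1 ms timing are hypothetical and were not measured.

```python
def effective_bandwidth_gbps(n_elements, element_size_bytes, ms):
    # One read of the input plus one write of the output.
    bytes_moved = 2 * n_elements * element_size_bytes
    # Convert bytes to gigabytes and milliseconds to seconds.
    return bytes_moved * 1e-9 / (ms * 1e-3)


# A 4096 x 4096 float32 tensor processed in a hypothetical 0.1 ms.
gbps_example = effective_bandwidth_gbps(4096 * 4096, 4, 0.1)
```

Because the same formula is applied to every provider, the naive implementation's extra memory traffic is not counted in its favor: the metric measures useful bytes per second, so its much lower GB/s directly reflects its longer run time.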