Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
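The docstring makes two claims: subtracting the row maximum prevents overflow, and it does not change the result. The pure-Python sketch below illustrates both on a single row. `softmax_row` is a hypothetical helper that works on a Python list rather than a tensor. It uses `math.exp`, which raises `OverflowError` when the result is too large, whereas `torch.exp` would silently return `inf`.

```python
import math


def softmax_row(row, stabilize=True):
    """Softmax of one row (a list of floats), optionally subtracting the row max first."""
    shift = max(row) if stabilize else 0.0
    numerator = [math.exp(v - shift) for v in row]
    denominator = sum(numerator)
    return [n / denominator for n in numerator]


row = [1000.0, 1001.0, 1002.0]

# With the max subtracted, every exponent is <= 0, so math.exp cannot overflow.
stable = softmax_row(row)
print([round(p, 4) for p in stable])  # [0.09, 0.2447, 0.6652]

# Without the shift, math.exp(1000.0) exceeds the float64 range.
try:
    softmax_row(row, stabilize=False)
except OverflowError:
    print("overflow without the max shift")

# Shift invariance: subtracting a constant from every element leaves the result unchanged.
shifted = softmax_row([v - 1000.0 for v in row])
assert all(math.isclose(a, b) for a, b in zip(stable, shifted))
```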
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only \(MN\) elements each, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically
but is still far from ideal.
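The ~4x estimate can be checked with a few lines of arithmetic. In this sketch, `naive_traffic` and `fused_traffic` are hypothetical helpers. The first encodes the element counts from the `naive_softmax` comments, and the second encodes an ideal fused kernel that reads X once and writes Y once.

```python
def naive_traffic(M, N):
    """(reads, writes) in elements for naive_softmax, summed from its comments."""
    return 5 * M * N + 2 * M, 3 * M * N + 2 * M


def fused_traffic(M, N):
    """(reads, writes) in elements for an ideal fused kernel: read X once, write Y once."""
    return M * N, M * N


M, N = 4096, 1024
ratio = sum(naive_traffic(M, N)) / sum(fused_traffic(M, N))
print(ratio)  # 4.001953125, i.e. (8MN + 4M) / 2MN = 4 + 2/N
```

The ratio simplifies to \(4 + 2/N\), so it approaches 4 as rows get wider.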
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
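The row-to-program assignment can be sketched in plain Python. `rows_for_program` is a hypothetical helper. It assumes that `tl.range(start, end, step)` visits the same indices as Python's `range`, which is how the kernel uses it, with the program id as the start and the number of programs as the step.

```python
def rows_for_program(pid, num_programs, n_rows):
    """Rows handled by program `pid`: start at its id and stride by the number of programs."""
    return list(range(pid, n_rows, num_programs))


# Hypothetical small launch: 10 rows shared among 4 persistent programs.
for pid in range(4):
    print(pid, rows_for_program(pid, 4, 10))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

Every row is visited by exactly one program, so a fixed number of programs can cover any number of rows.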
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be larger than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
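The pure-Python emulation below shows why padding with -inf and masking the store are safe. It models one iteration of the kernel's loop. `next_power_of_2` and `padded_softmax` are hypothetical helpers, not Triton APIs. Padded lanes load -inf, so they never win the max and contribute exp(-inf) = 0 to the denominator. As a result, the first n_cols outputs equal an unpadded softmax.

```python
import math


def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1 (the value BLOCK_SIZE is set to)."""
    return 1 << (n - 1).bit_length()


def padded_softmax(row):
    """Emulate one iteration of softmax_kernel's loop on a single row (a plain list)."""
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    mask = [c < n_cols for c in range(block_size)]
    # Masked load: out-of-bounds lanes take the `other` value, -inf.
    block = [row[c] if mask[c] else -math.inf for c in range(block_size)]
    # -inf lanes never win the max...
    row_max = max(block)
    # ...and exp(-inf) == 0.0, so they add nothing to the denominator.
    numerator = [math.exp(v - row_max) for v in block]
    denominator = sum(numerator)
    out = [v / denominator for v in numerator]
    # Masked store: only the first n_cols lanes are written back.
    return [out[c] for c in range(block_size) if mask[c]]


# An irregular width, as in the unit test: 781 columns are padded to a 1024-wide block.
row = [float(i % 7) for i in range(781)]
print(next_power_of_2(len(row)))  # 1024
print(len(padded_softmax(row)))  # 781
```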
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape
    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8
    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2
    # Allocate output
    y = torch.empty_like(x)
    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2
        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy
    num_programs = min(num_programs, n_rows)
    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
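On the non-HIP path, the launch-size heuristic reduces to a little integer arithmetic. The sketch below mirrors it in plain Python. The function name and all device and kernel numbers are hypothetical placeholders, not measured values. Unlike `softmax()`, it skips the shared-memory cap when the kernel reports zero shared memory, which avoids a division by zero.

```python
def num_persistent_programs(num_sm, num_regs, size_smem, warp_size,
                            n_regs, kernel_smem, num_warps, n_rows):
    """Plain-Python mirror of the non-HIP occupancy heuristic in softmax()."""
    # How many kernel instances fit in the register budget:
    # each needs n_regs registers/thread * warp_size threads/warp * num_warps warps.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Further capped by shared memory (skipped here if the kernel uses none).
    if kernel_smem > 0:
        occupancy = min(occupancy, size_smem // kernel_smem)
    # One persistent program per resident slot, but never more programs than rows.
    return min(num_sm * occupancy, n_rows)


# Hypothetical device: 132 SMs, a 65536-register budget, 232448 bytes of shared memory, 32-wide warps.
# Hypothetical kernel: 32 registers per thread, 2048 bytes of shared memory, 8 warps.
print(num_persistent_programs(132, 65536, 232448, 32, 32, 2048, 8, n_rows=4096))  # 1056
print(num_persistent_programs(132, 65536, 232448, 32, 32, 2048, 8, n_rows=500))  # 500
```

Here the register limit (65536 // 8192 = 8 instances) is tighter than the shared-memory limit (232448 // 2048 = 113). Each SM therefore hosts 8 programs.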
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results match within the default tolerances of torch.allclose.
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_names`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    # Effective bandwidth: one read and one write of x
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 514.264267 710.308844 205.684566
1 384.0 709.130739 820.099233 262.047238
2 512.0 831.113104 926.637333 302.423697
3 640.0 845.012367 910.118495 331.453903
4 768.0 920.260997 984.541553 351.905361
5 896.0 968.800342 1038.174820 355.361785
6 1024.0 1023.589910 1075.672814 354.399517
7 1152.0 1027.876081 1068.003854 348.679859
8 1280.0 1082.334979 1109.563610 348.234217
9 1408.0 1124.233870 1136.594676 343.090526
10 1536.0 1146.434784 1165.091116 333.637247
11 1664.0 1180.542121 1178.782176 329.831516
12 1792.0 1202.236290 1200.369240 326.294104
13 1920.0 1234.823545 1226.698613 326.392908
14 2048.0 1250.537876 1246.703012 325.807920
15 2176.0 1185.344056 964.733820 326.001449
16 2304.0 1197.742271 1004.568615 326.107284
17 2432.0 1220.500784 1035.553574 327.270171
18 2560.0 1245.425385 1071.372786 328.644020
19 2688.0 1268.226957 1097.909776 329.540370
20 2816.0 1276.486904 1127.178672 329.376307
21 2944.0 1293.509830 1144.780139 331.862932
22 3072.0 1308.364934 1171.347845 334.144493
23 3200.0 1323.065987 1170.899844 334.975634
24 3328.0 1321.977898 1199.030711 336.250247
25 3456.0 1344.006019 1220.282338 337.369985
26 3584.0 1348.584291 1244.280350 338.318047
27 3712.0 1350.214641 1265.563111 340.154016
28 3840.0 1363.937470 1281.479263 341.097767
29 3968.0 1369.911621 1296.414309 341.116968
30 4096.0 1372.520927 1320.005111 339.147794
31 4224.0 1337.937950 1275.396837 342.882778
32 4352.0 1344.612825 1298.449218 344.933851
33 4480.0 1356.305577 1314.354152 346.108963
34 4608.0 1368.563420 1335.764560 346.683385
35 4736.0 1363.832377 1344.591131 347.952697
36 4864.0 1381.945394 1356.335887 349.543536
37 4992.0 1377.411409 1368.175999 350.511724
38 5120.0 1386.624547 1384.651024 351.021800
39 5248.0 1382.458606 1357.826632 352.076863
40 5376.0 1383.900399 1368.195770 351.716787
41 5504.0 1390.987944 1379.897300 353.628359
42 5632.0 1398.559211 1399.547777 353.706218
43 5760.0 1405.556787 1405.382160 354.606333
44 5888.0 1392.290556 1412.974571 354.832389
45 6016.0 1405.294169 1423.245445 356.396805
46 6144.0 1417.238568 1432.749175 356.889724
47 6272.0 1416.502069 1398.574706 357.848966
48 6400.0 1417.938216 1411.810886 358.821483
49 6528.0 1424.625112 1424.500724 359.517905
50 6656.0 1421.624639 1436.761401 359.348221
51 6784.0 1423.619997 1443.753304 360.206259
52 6912.0 1433.749533 1447.611932 360.675091
53 7040.0 1423.930225 1443.871189 361.002863
54 7168.0 1431.701915 1460.786881 361.755280
55 7296.0 1433.029000 1086.976665 362.956860
56 7424.0 1436.842862 1096.046149 362.599411
57 7552.0 1436.033943 1111.040374 363.573085
58 7680.0 1433.409213 1119.168489 363.586455
59 7808.0 1429.387914 1134.193956 364.789697
60 7936.0 1441.907121 1144.670112 364.986581
61 8064.0 1439.699457 1151.941119 365.048978
62 8192.0 1429.125916 1148.652331 363.692001
63 8320.0 1389.156504 1114.892884 361.292870
64 8448.0 1385.946510 1122.729839 362.158007
65 8576.0 1391.605442 1126.640472 363.071994
66 8704.0 1388.162918 1129.806384 364.084440
67 8832.0 1394.561608 1127.800743 364.693888
68 8960.0 1385.069692 1136.568658 365.749829
69 9088.0 1401.081398 1134.905847 366.804488
70 9216.0 1402.843751 1138.959660 367.303125
71 9344.0 1393.951105 1418.160746 367.653672
72 9472.0 1406.857297 1433.454705 368.515502
73 9600.0 1393.235044 1429.553272 368.760685
74 9728.0 1401.458065 1441.288254 369.683333
75 9856.0 1396.620020 1441.643412 369.945633
76 9984.0 1391.710160 1447.342120 370.201790
77 10112.0 1408.177239 1454.496136 371.450252
78 10240.0 1408.013746 1463.856631 371.765898
79 10368.0 1412.831752 1464.424151 370.076976
80 10496.0 1410.730869 1464.739874 370.413013
81 10624.0 1407.995487 1467.148953 370.872699
82 10752.0 1401.726694 1473.363476 371.501446
83 10880.0 1394.905228 1480.363461 371.858680
84 11008.0 1416.782862 1477.417634 372.979405
85 11136.0 1424.858054 1486.373701 372.988118
86 11264.0 1413.743334 1485.504816 373.405603
87 11392.0 1424.490259 1488.262365 374.089737
88 11520.0 1415.587495 1494.503457 373.667128
89 11648.0 1417.579578 1498.980549 374.409829
90 11776.0 1433.665511 1500.585282 374.637699
91 11904.0 1432.593872 1509.243634 375.294065
92 12032.0 1429.695313 1508.840525 376.420749
93 12160.0 1422.011454 1517.582988 375.855074
94 12288.0 1431.659257 1418.086413 375.914640
95 12416.0 1433.010428 1393.989466 374.006705
96 12544.0 1443.285842 1393.493167 375.195879
97 12672.0 1436.689774 1392.959954 375.022730
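In the benchmark results above, we can see that:
Triton is roughly 4x faster than naive_softmax (about 3.9x at N = 8192, for instance), in line with the theoretical estimate for an implementation that performs no fusion.
Triton is competitive with torch.softmax, in addition to being easier to read, understand and maintain. It is clearly faster in some ranges (e.g., N = 2176 to 4096 and N = 7296 to 9216), slightly slower in others (e.g., N of 1536 or less and N = 9344 to 12160), and roughly on par elsewhere. Note, however, that the PyTorch softmax operation is more general and will work on tensors of any shape.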
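The GB/s figures are effective bandwidth: `benchmark()` charges every provider for exactly one read and one write of x, regardless of how much memory it actually moves. The sketch below reproduces that formula. `effective_gbps` is a hypothetical helper, and the timing is a made-up value chosen to give a round number. Because every provider is charged the same bytes, the ratio of two GB/s figures equals their speed-up; the last lines apply this to the N = 8192 row of the table.

```python
def effective_gbps(M, N, element_size, ms):
    """Effective bandwidth as in benchmark(): one read plus one write of an M x N tensor."""
    bytes_moved = 2 * M * N * element_size
    return bytes_moved * 1e-9 / (ms * 1e-3)


# Hypothetical timing: a 4096 x 1024 float32 input (4 bytes per element) processed in 0.032768 ms.
print(round(effective_gbps(4096, 1024, 4, 0.032768), 3))  # 1024.0

# Same bytes for every provider, so the GB/s ratio is the speed-up.
triton_gbps, naive_gbps = 1429.125916, 363.692001  # the N = 8192 row of the table above
print(round(triton_gbps / naive_gbps, 2))  # 3.93
```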
Total running time of the script: (0 minutes 37.341 seconds)