Fused Softmax¶
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations¶
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
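To make the arithmetic concrete, here is a quick back-of-the-envelope check (plain Python; the matrix size is an arbitrary illustration, not taken from the benchmark below) that reproduces the ~4x figure:

M, N = 4096, 8192  # arbitrary example sizes

# Naive softmax: 5MN + 2M elements read, 3MN + 2M elements written.
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)
# Fused kernel: read X once, write Y once.
fused_traffic = 2 * M * N

print(naive_traffic / fused_traffic)  # ~4.0 for any reasonably large N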
Compute Kernel¶
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
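As a mental model for the padding-and-mask trick, here is a plain-PyTorch sketch (not part of the kernel, and reusing the torch and triton imports from above): masking the out-of-bounds lanes with -inf is equivalent to padding each row with -inf up to the next power of two, and since exp(-inf) is 0 the padded tail contributes nothing to the reduction.

n_cols = 781
BLOCK_SIZE = triton.next_power_of_2(n_cols)  # 1024

row = torch.randn(n_cols)
padded = torch.full((BLOCK_SIZE,), -float('inf'))
padded[:n_cols] = row

# The padded softmax agrees with the unpadded one on the first n_cols entries.
assert torch.allclose(torch.softmax(padded, dim=0)[:n_cols], torch.softmax(row, dim=0))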
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
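To give a feel for the occupancy arithmetic on the CUDA path, here is the same formula evaluated with made-up device and kernel numbers (illustrative values only; the real figures come from get_device_properties and the compiled kernel):

# Hypothetical per-SM limits and per-kernel usage, for illustration only.
regs_per_sm, warp_size, smem_per_sm = 65536, 32, 101376
kernel_regs, kernel_smem, n_warps = 32, 16384, 8

occ = regs_per_sm // (kernel_regs * warp_size * n_warps)  # 8 programs fit by register usage
occ = min(occ, smem_per_sm // kernel_smem)                # but only 6 fit by shared memory
print(occ)  # -> 6 persistent programs per SM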
Unit Test¶
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
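If you want extra confidence in the masking logic, an optional sweep over a few more irregular shapes works the same way (the shapes below are arbitrary choices, not part of the original test):

for shape in [(5, 17), (255, 1000), (1023, 2049)]:
    x = torch.randn(*shape, device=DEVICE)
    assert torch.allclose(softmax(x), torch.softmax(x, dim=1))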
Benchmark¶
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
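Because the operation is bandwidth-bound, we report effective memory throughput rather than raw runtime: a perfectly fused softmax reads each element of x once and writes each element of y once, hence the 2 * numel * element_size bytes in the GB/s conversion used below. As a quick sanity check with made-up numbers (illustrative only):

# Illustrative only: a 4096 x 8192 float32 matrix processed in 0.2 ms.
M, N, elem_bytes, ms = 4096, 8192, 4, 0.2
gbps = 2 * M * N * elem_bytes * 1e-9 / (ms * 1e-3)
print(gbps)  # ~1342 GB/s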
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 470.649762 700.817182 206.202414
1 384.0 665.802183 824.408682 260.148190
2 512.0 801.862448 931.228685 302.507488
3 640.0 915.307256 931.417421 330.687754
4 768.0 977.271387 989.579833 350.998878
5 896.0 1040.031950 1030.076948 355.016508
6 1024.0 1071.696803 1068.521748 354.729148
7 1152.0 1093.441164 1074.746394 348.231019
8 1280.0 1131.573182 1111.253166 346.947063
9 1408.0 1161.282231 1132.801877 340.319450
10 1536.0 1186.796074 1168.022114 333.271551
11 1664.0 1214.549436 1190.689087 329.723156
12 1792.0 1233.596056 1191.868127 325.687480
13 1920.0 1265.317331 1218.887082 324.989199
14 2048.0 1270.269777 1246.482213 324.392242
15 2176.0 1239.035963 960.202709 325.310211
16 2304.0 1252.211755 1000.493773 326.253373
17 2432.0 1274.143996 1035.685042 326.855399
18 2560.0 1282.200577 1067.700715 328.022038
19 2688.0 1290.978981 1095.271461 329.596309
20 2816.0 1315.235709 1122.071462 329.133315
21 2944.0 1317.180629 1145.012259 331.550763
22 3072.0 1319.233956 1167.182948 333.001801
23 3200.0 1342.279506 1169.285843 335.116677
24 3328.0 1350.865585 1200.268281 336.388627
25 3456.0 1355.806049 1225.959260 337.489703
26 3584.0 1361.862122 1247.649271 338.532892
27 3712.0 1368.762294 1268.848621 340.884374
28 3840.0 1374.583808 1282.665277 340.752389
29 3968.0 1372.231833 1301.804441 341.054304
30 4096.0 1388.683049 1314.359800 338.761105
31 4224.0 1330.487821 1274.618820 343.511607
32 4352.0 1347.563827 1295.479070 345.695683
33 4480.0 1346.809681 1319.590236 345.818831
34 4608.0 1358.191917 1333.991521 347.080194
35 4736.0 1361.040033 1343.106037 348.184513
36 4864.0 1363.522118 1359.966632 349.167947
37 4992.0 1369.702544 1370.803051 350.456854
38 5120.0 1374.582795 1385.667875 350.887760
39 5248.0 1374.271683 1349.286562 351.359184
40 5376.0 1381.406730 1369.016430 351.586089
41 5504.0 1379.966897 1383.871244 353.329936
42 5632.0 1393.559823 1392.822484 353.259526
43 5760.0 1395.423256 1402.666349 354.744631
44 5888.0 1388.677987 1413.079302 354.702503
45 6016.0 1396.930463 1425.527471 356.857398
46 6144.0 1407.058343 1425.629150 356.866586
47 6272.0 1410.561443 1402.726720 357.890582
48 6400.0 1412.159158 1410.444825 358.608969
49 6528.0 1417.551815 1417.327986 359.257360
50 6656.0 1414.882871 1426.324873 359.195420
51 6784.0 1415.387784 1437.879807 360.371846
52 6912.0 1423.198362 1439.976732 360.794595
53 7040.0 1420.293252 1453.742787 360.699916
54 7168.0 1420.668580 1460.348105 361.495303
55 7296.0 1429.509711 1088.475858 362.761296
56 7424.0 1432.045253 1096.902001 363.012689
57 7552.0 1425.631090 1111.856195 363.792010
58 7680.0 1429.377949 1125.389247 364.238529
59 7808.0 1433.352665 1135.808385 364.239728
60 7936.0 1431.839078 1142.796993 364.887240
61 8064.0 1435.160582 1145.680746 365.003683
62 8192.0 1430.108525 1153.209998 364.236624
63 8320.0 1378.704137 1118.335948 361.938493
64 8448.0 1387.603514 1123.385591 362.765299
65 8576.0 1388.057191 1129.782728 363.227729
66 8704.0 1380.015420 1134.285404 364.699283
67 8832.0 1395.608068 1133.238153 365.182897
68 8960.0 1382.051082 1139.272133 366.208079
69 9088.0 1392.894054 1137.629419 366.840078
70 9216.0 1401.954908 1141.179532 367.945593
71 9344.0 1387.520847 1420.240647 367.559703
72 9472.0 1398.659764 1432.594786 368.914691
73 9600.0 1401.228323 1432.820464 369.084378
74 9728.0 1397.644151 1441.338598 369.656999
75 9856.0 1400.863546 1442.443755 369.801939
76 9984.0 1391.201476 1451.761345 370.389558
77 10112.0 1405.789264 1454.953735 371.251094
78 10240.0 1410.057077 1461.769519 371.918522
79 10368.0 1416.731529 1464.896550 370.001534
80 10496.0 1406.557332 1467.198439 370.790823
81 10624.0 1409.377877 1463.994107 371.201995
82 10752.0 1395.665765 1471.363991 371.260926
83 10880.0 1393.253976 1478.918165 371.626769
84 11008.0 1422.356301 1475.684654 372.899551
85 11136.0 1416.770068 1482.433896 372.819612
86 11264.0 1410.988554 1487.414068 372.598124
87 11392.0 1421.901707 1489.911729 374.107389
88 11520.0 1413.596349 1498.586935 374.116891
89 11648.0 1420.903985 1499.436604 374.564754
90 11776.0 1432.964055 1500.297313 374.677610
91 11904.0 1433.214730 1508.526163 375.329335
92 12032.0 1413.734740 1508.388290 376.199185
93 12160.0 1413.402203 1513.046680 375.624993
94 12288.0 1429.781400 1419.123522 375.285685
95 12416.0 1436.838607 1394.885639 374.672414
96 12544.0 1443.189303 1394.950661 375.546905
97 12672.0 1436.251966 1389.530552 375.246765
In the above plot, we can see that:

- Triton is roughly 4x faster than the naive PyTorch implementation (naive_softmax). This confirms our suspicion that the naive version performs no fusion and is limited by redundant DRAM traffic.
- Triton is competitive with torch.softmax, and noticeably faster for the column counts where the Torch kernel's throughput drops, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 35.058 seconds)