Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading \(MN\) elements and writing back \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / (2MN)\)).
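As a quick sanity check on that figure, we can plug in an example shape (chosen purely for illustration; any reasonably large M and N give essentially the same ratio):

M, N = 4096, 1024  # hypothetical example shape, not tied to the benchmark below
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                                  # elements read + written by a fused kernel
print(naive_traffic / fused_traffic)                       # ~4.002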
The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
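For completeness, here is roughly what relying on the Torch JIT looks like (a sketch, not part of this tutorial’s benchmark code); it is a one-line change, but it does not give us the single fused kernel we are after:

# Sketch: ask TorchScript to compile (and, ideally, fuse) the naive implementation.
# As discussed above, this does not yield the fusion we want for this pattern.
jit_softmax = torch.jit.script(naive_softmax)
y_jit = jit_softmax(torch.randn(1823, 781, device=DEVICE))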
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes back the results to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
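Before wiring the kernel up, it is worth convincing ourselves that the padding trick is harmless. The small check below (illustrative only, not part of the kernel) pads a row with -inf up to the next power of two and verifies that the softmax of the original entries is unchanged, since exp(-inf) = 0 contributes nothing to the denominator:

row = torch.randn(781, device=DEVICE)  # an arbitrary, non-power-of-two row length
padded = torch.full((triton.next_power_of_2(781), ), float('-inf'), device=DEVICE)
padded[:781] = row
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:781])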
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
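To make the occupancy arithmetic in the non-HIP branch concrete, here is a worked example with hypothetical numbers (the real values come from the driver query above and from the compiled kernel’s metadata):

# Hypothetical values, for illustration only; not read from a real device.
num_regs_total, regs_per_thread, warp_size, warps = 65536, 40, 32, 8
smem_total, smem_per_program = 166912, 16384
occ = min(num_regs_total // (regs_per_thread * warp_size * warps),  # register-limited: 6
          smem_total // smem_per_program)                           # shared-memory-limited: 10
print(occ, 108 * occ)  # 6 programs per SM -> up to 648 persistent programs on a 108-SM GPU (capped by n_rows)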
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
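If you want additional coverage, a few more odd shapes (including a single row, which exercises the persistent-program loop) can be checked the same way; these extra cases are not part of the original test:

for shape in [(1, 781), (7, 1025), (1823, 4097)]:  # arbitrary irregular shapes
    x = torch.randn(*shape, device=DEVICE)
    assert torch.allclose(softmax(x), torch.softmax(x, axis=1))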
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will compare its performance against torch.softmax; a quick point comparison against the naive_softmax defined above is sketched after the plot discussion below.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch
0 256.0 475.586631 692.157095
1 384.0 611.051785 808.699702
2 512.0 762.678891 925.326796
3 640.0 822.274457 965.944912
4 768.0 886.224978 1015.784524
5 896.0 945.713281 1077.909025
6 1024.0 1010.767050 1113.138008
7 1152.0 1099.146418 611.715954
8 1280.0 1153.143540 670.297276
9 1408.0 1165.429899 722.166542
10 1536.0 1191.784200 780.627506
11 1664.0 1219.958145 813.077504
12 1792.0 1233.126736 861.197467
13 1920.0 1258.013602 910.222208
14 2048.0 1275.289595 955.609685
15 2176.0 1242.407449 975.753874
16 2304.0 1252.329481 1012.569383
17 2432.0 1278.668225 1058.570388
18 2560.0 1285.639593 1087.911734
19 2688.0 1297.413186 1105.179623
20 2816.0 1300.162925 1135.020012
21 2944.0 1309.796778 1168.738688
22 3072.0 1332.224805 1182.738756
23 3200.0 1332.148775 1193.860263
24 3328.0 1346.485823 1223.470004
25 3456.0 1358.373790 1251.767583
26 3584.0 1356.498759 1263.244879
27 3712.0 1372.213712 1268.742050
28 3840.0 1377.363554 1303.907713
29 3968.0 1372.946617 1314.766749
30 4096.0 1377.409884 1325.599210
31 4224.0 1337.863399 1156.911946
32 4352.0 1335.541849 1175.155703
33 4480.0 1354.617656 1182.995987
34 4608.0 1362.844596 1197.643893
35 4736.0 1360.106666 1198.106973
36 4864.0 1375.461496 1225.108909
37 4992.0 1371.606221 1233.161330
38 5120.0 1381.242754 1253.476705
39 5248.0 1375.990692 1260.819547
40 5376.0 1378.979904 1288.587020
41 5504.0 1378.620710 1301.990365
42 5632.0 1388.118114 1311.909073
43 5760.0 1395.694858 1327.468564
44 5888.0 1395.517468 1343.523082
45 6016.0 1402.903549 1356.581561
46 6144.0 1406.995799 1377.027459
47 6272.0 1417.749381 1372.618327
48 6400.0 1419.290307 1391.055790
49 6528.0 1414.921985 1393.006569
50 6656.0 1421.444524 1404.901593
51 6784.0 1413.345453 1416.048681
52 6912.0 1431.673437 1426.846516
53 7040.0 1421.295431 1433.377945
54 7168.0 1425.225028 1435.829200
55 7296.0 1429.739075 1442.312533
56 7424.0 1431.368755 1444.667094
57 7552.0 1427.816041 1451.826502
58 7680.0 1439.304454 1458.003964
59 7808.0 1434.253588 1466.777968
60 7936.0 1436.958188 1467.959121
61 8064.0 1440.581548 1475.614033
62 8192.0 1435.831512 1485.574078
63 8320.0 1379.934034 1402.104189
64 8448.0 1377.738352 1407.554899
65 8576.0 1393.595855 1395.451743
66 8704.0 1385.066238 1398.704504
67 8832.0 1374.671337 1407.119058
68 8960.0 1396.574538 1414.371774
69 9088.0 1404.289136 1415.702828
70 9216.0 1398.652796 1424.417360
71 9344.0 1394.309383 1422.263055
72 9472.0 1396.388869 1438.260282
73 9600.0 1392.540486 1434.334182
74 9728.0 1395.628384 1445.040886
75 9856.0 1405.740776 1441.417621
76 9984.0 1395.367936 1453.700537
77 10112.0 1406.191064 1456.147611
78 10240.0 1415.767613 1468.447168
79 10368.0 1407.580369 1465.139864
80 10496.0 1410.486093 1467.379233
81 10624.0 1407.490111 1468.979738
82 10752.0 1401.359667 1474.434605
83 10880.0 1394.243678 1481.430634
84 11008.0 1414.490981 1477.102277
85 11136.0 1418.118266 1485.731108
86 11264.0 1423.721222 1487.016576
87 11392.0 1416.254858 1487.185295
88 11520.0 1420.554241 1496.610054
89 11648.0 1418.382543 1497.178144
90 11776.0 1422.178307 1501.891397
91 11904.0 1437.630712 1506.344719
92 12032.0 1417.783439 1508.794889
93 12160.0 1417.024373 1513.149493
94 12288.0 1430.544500 1393.853407
95 12416.0 1445.442923 1389.850756
96 12544.0 1438.052691 1392.989890
97 12672.0 1447.611958 1390.703994
In the above plot, we can see that:

- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than torch.softmax – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
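For reference, a quick point comparison against the naive_softmax from the Motivations section can be obtained with the same timing utility (a sketch; the shape is an arbitrary example and the numbers will depend on your GPU):

x = torch.randn(4096, 4096, device=DEVICE)
ms_naive = triton.testing.do_bench(lambda: naive_softmax(x))  # eager, unfused implementation
ms_fused = triton.testing.do_bench(lambda: softmax(x))        # our Triton kernel
print(f"naive: {ms_naive:.3f} ms, fused: {ms_fused:.3f} ms")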
Total running time of the script: (0 minutes 23.249 seconds)