Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
When implemented naively in PyTorch, computing y = naive_softmax(x)
for \(x \in R^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading \(MN\) elements and writing back \(MN\) elements, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / (2MN)\)).
The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically,
but it is still far from ideal.
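To make this arithmetic concrete, here is a small sketch (not part of the tutorial code) that plugs an example shape into the formula above; the 4-byte element size assumes float32 inputs:

M, N = 4096, 1024   # example shape, chosen only for illustration
bytes_per_elem = 4  # assuming float32
naive_traffic = ((5 * M * N + 2 * M) + (3 * M * N + 2 * M)) * bytes_per_elem  # DRAM reads + writes
fused_traffic = 2 * M * N * bytes_per_elem  # read X once, write Y once
print(naive_traffic / fused_traffic)        # ~4.0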
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output matrix Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols, so we can fit each
        # row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
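To see why the -inf padding is harmless, here is a plain PyTorch sketch (an illustration only, not part of the kernel) of the same trick applied to a single row that does not fill a power-of-two block:

row = torch.randn(781, device=DEVICE)        # a row shorter than the block
BLOCK = triton.next_power_of_2(row.numel())  # 1024
padded = torch.full((BLOCK,), -float('inf'), device=DEVICE)
padded[:row.numel()] = row                   # emulates the masked load with other=-inf
p = torch.exp(padded - padded.max())         # exp(-inf) == 0, so the padding contributes nothing
out = (p / p.sum())[:row.numel()]            # emulates the masked store: keep only the real columns
assert torch.allclose(out, torch.softmax(row, dim=0))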
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
n_rows, n_cols = x.shape
    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8
# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2
# Allocate output
y = torch.empty_like(x)
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
NUM_GPRS = NUM_REGS
if is_cdna():
NUM_GPRS = NUM_REGS * 2
# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
# When we divide this number with WARP_SIZE we get maximum number of waves that can
# execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
return y
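For intuition, here is a rough worked example of the CUDA branch of the occupancy heuristic above. All numbers below are assumptions made up for illustration (they depend on the GPU and on what the compiler reports after warmup), not values queried from a real device:

# assumed device / kernel numbers, for illustration only
NUM_SM_ex, NUM_REGS_ex, WARP_SIZE_ex = 108, 65536, 32
SIZE_SMEM_ex = 166912      # shared memory per SM, in bytes
n_regs_ex, num_warps_ex = 32, 8
size_smem_ex = 16384       # shared memory used by the compiled kernel, in bytes

# programs per SM, limited by the register file ...
occ = NUM_REGS_ex // (n_regs_ex * WARP_SIZE_ex * num_warps_ex)  # 65536 // 8192 = 8
# ... and by shared memory
occ = min(occ, SIZE_SMEM_ex // size_smem_ex)                    # min(8, 10) = 8
print(NUM_SM_ex * occ)  # 864 persistent programs, before capping at n_rows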
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax
and (2) the naive_softmax
defined above.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
line_names=["Triton", "Torch", "Naive Softmax"], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('red', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'naive_softmax':
ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch Naive Softmax
0 256.0 475.166542 675.218075 206.112331
1 384.0 655.123930 829.318041 261.964284
2 512.0 820.240936 920.826918 300.119709
3 640.0 919.818146 910.222212 329.405789
4 768.0 988.500953 980.283489 348.762065
5 896.0 1046.180499 1034.781675 353.955498
6 1024.0 1085.341715 1082.789026 352.988781
7 1152.0 1088.758877 1077.305928 347.923159
8 1280.0 1125.320758 1109.671632 349.082970
9 1408.0 1156.547509 1140.081001 340.524440
10 1536.0 1195.061126 1159.202653 333.435594
11 1664.0 1214.181932 1183.183909 329.563218
12 1792.0 1225.788360 1191.537175 325.590999
13 1920.0 1260.929975 1224.866509 324.211874
14 2048.0 1269.980596 1243.374771 324.300319
15 2176.0 1239.481839 960.011808 325.278138
16 2304.0 1256.071190 1004.069000 325.824058
17 2432.0 1271.804802 1034.925234 326.790521
18 2560.0 1285.339836 1067.259125 327.731295
19 2688.0 1294.067511 1100.757789 328.664549
20 2816.0 1306.664337 1124.630729 329.752174
21 2944.0 1321.303487 1147.969438 331.210784
22 3072.0 1322.431384 1175.069674 332.841686
23 3200.0 1334.604891 1169.863007 334.714863
24 3328.0 1347.131181 1199.093927 336.254812
25 3456.0 1356.477369 1226.493753 336.719658
26 3584.0 1358.626995 1247.479158 337.674973
27 3712.0 1371.647801 1263.625509 340.304252
28 3840.0 1369.816780 1280.972247 340.121546
29 3968.0 1375.528414 1300.882402 340.857667
30 4096.0 1386.007011 1319.005245 338.494617
31 4224.0 1328.530186 1279.987447 343.407694
32 4352.0 1342.233320 1298.143303 345.206784
33 4480.0 1347.923755 1320.795082 345.862487
34 4608.0 1359.617221 1336.947806 346.801983
35 4736.0 1353.313026 1348.619056 347.983447
36 4864.0 1368.299874 1359.355064 348.815461
37 4992.0 1371.427217 1374.804306 349.770660
38 5120.0 1374.992917 1386.795634 350.996853
39 5248.0 1374.226440 1355.796792 351.547410
40 5376.0 1378.554977 1369.527929 351.521772
41 5504.0 1376.802714 1387.294885 353.572592
42 5632.0 1392.038419 1395.118961 353.055168
43 5760.0 1394.849569 1407.915898 354.943088
44 5888.0 1393.439800 1406.941706 354.985594
45 6016.0 1399.191286 1424.789814 356.708391
46 6144.0 1406.135293 1434.852743 357.084182
47 6272.0 1407.924231 1389.846673 357.488702
48 6400.0 1410.331764 1415.362757 358.710882
49 6528.0 1413.901202 1417.482148 358.771423
50 6656.0 1417.531940 1430.927169 359.590513
51 6784.0 1418.330524 1431.930771 360.224651
52 6912.0 1423.713121 1445.207062 360.969399
53 7040.0 1419.455565 1449.944026 360.943004
54 7168.0 1420.082358 1466.447063 361.613843
55 7296.0 1427.423121 1086.122102 362.204935
56 7424.0 1429.852124 1101.918526 362.868641
57 7552.0 1428.321119 1112.382291 363.327109
58 7680.0 1432.570851 1124.138155 363.699493
59 7808.0 1430.566466 1132.011378 364.685028
60 7936.0 1431.831105 1146.323226 364.691308
61 8064.0 1431.492392 1149.162852 364.668855
62 8192.0 1434.785207 1155.191268 363.917400
63 8320.0 1383.931619 1117.103520 362.098143
64 8448.0 1386.548368 1127.731977 362.750969
65 8576.0 1385.551962 1126.687030 363.575272
66 8704.0 1382.688836 1134.045907 364.324784
67 8832.0 1390.409971 1134.128897 365.182896
68 8960.0 1384.893552 1138.200061 365.852057
69 9088.0 1399.933157 1135.963177 367.165173
70 9216.0 1401.714158 1144.239871 367.645039
71 9344.0 1390.608061 1419.232831 367.894427
72 9472.0 1398.644747 1429.941794 368.439350
73 9600.0 1404.299488 1428.562451 369.194911
74 9728.0 1402.438849 1441.918501 370.136485
75 9856.0 1401.180528 1443.792075 370.079887
76 9984.0 1391.217459 1452.395937 370.604705
77 10112.0 1403.336103 1451.577118 370.635158
78 10240.0 1408.840210 1463.563501 371.169325
79 10368.0 1414.884589 1462.547608 369.757665
80 10496.0 1409.902423 1468.434117 370.484073
81 10624.0 1401.148458 1466.643644 370.384282
82 10752.0 1396.693696 1467.963231 370.900729
83 10880.0 1395.752229 1477.250378 371.346187
84 11008.0 1417.250086 1477.100080 372.010157
85 11136.0 1419.920844 1481.750833 372.824045
86 11264.0 1410.755379 1486.420145 372.460934
87 11392.0 1424.095915 1487.538180 373.904506
88 11520.0 1416.377672 1496.383799 373.583461
89 11648.0 1422.967528 1500.403335 374.609043
90 11776.0 1430.630993 1502.691116 374.757456
91 11904.0 1432.965945 1510.090107 375.334162
92 12032.0 1414.554312 1508.802320 375.739168
93 12160.0 1414.560959 1513.937107 375.532153
94 12288.0 1424.370203 1417.793602 375.967891
95 12416.0 1437.123832 1396.183056 374.712168
96 12544.0 1444.976847 1395.603119 375.687499
97 12672.0 1439.138774 1392.988592 375.176450
In the above plot, we can see that:
- Triton is roughly 4x faster than naive_softmax. This is consistent with our earlier analysis: the naive implementation moves about four times as much data to and from DRAM.
- Triton is noticeably faster than torch.softmax – in addition to being easier to read, understand and maintain. Note, however, that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 37.774 seconds)