Fused Softmax
In this tutorial, you will write a fused softmax operation that can match or outperform PyTorch's native op for a particular class of matrices: those whose rows fit in the GPU's SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch.

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
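To see why the shift by the row maximum matters, here is a minimal pure-Python sketch on a single row. It uses plain lists instead of tensors, which is an illustrative simplification, and `stable_softmax_row` and `unshifted_softmax_row` are hypothetical names. The sketch checks that subtracting the maximum leaves the result unchanged and also keeps `math.exp` from overflowing:

```python
import math


def stable_softmax_row(row):
    """Softmax of one row (a list of floats), shifted by the row maximum as in naive_softmax."""
    row_max = max(row)
    # After the shift every exponent is <= 0, so math.exp cannot overflow.
    numerators = [math.exp(v - row_max) for v in row]
    denominator = sum(numerators)
    return [n / denominator for n in numerators]


def unshifted_softmax_row(row):
    """Textbook softmax without the shift; math.exp raises OverflowError for large inputs."""
    numerators = [math.exp(v) for v in row]
    denominator = sum(numerators)
    return [n / denominator for n in numerators]


small = [1.0, 2.0, 3.0]
# Shift invariance: both versions agree when nothing overflows.
print(all(math.isclose(a, b) for a, b in zip(stable_softmax_row(small), unshifted_softmax_row(small))))  # True

large = [1000.0, 1001.0, 1002.0]
try:
    unshifted_softmax_row(large)
except OverflowError:
    print("the unshifted version overflows")
# Shifting by the maximum maps `large` onto the same exponents as `small`.
print(stable_softmax_row(large) == stable_softmax_row(small))  # True
```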
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading only \(MN\) elements and writing back \(MN\) elements (\(2MN\) in total), so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically
but it is still far from ideal.
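The element counts and the ~4x estimate can be tallied mechanically. This is a minimal sketch that assumes the same model as the comments in naive_softmax: each PyTorch op reads all of its inputs from DRAM and writes all of its outputs back, with no caching and no fusion. `naive_softmax_traffic` and `expected_speedup` are hypothetical helpers, and the results are element counts, not bytes:

```python
def naive_softmax_traffic(M, N):
    """Elements read and written by the five PyTorch ops in naive_softmax (no-caching model)."""
    steps = [
        # (op, elements read, elements written)
        ("x.max(dim=1)",         M * N,     M),
        ("x - x_max[:, None]",   M * N + M, M * N),
        ("torch.exp(z)",         M * N,     M * N),
        ("numerator.sum(dim=1)", M * N,     M),
        ("numerator / denom",    M * N + M, M * N),
    ]
    reads = sum(r for _, r, _ in steps)
    writes = sum(w for _, _, w in steps)
    return reads, writes


def expected_speedup(M, N):
    """Ratio of naive traffic to a fused kernel that reads x once and writes y once."""
    reads, writes = naive_softmax_traffic(M, N)
    fused = 2 * M * N
    return (reads + writes) / fused


M, N = 4096, 1024
reads, writes = naive_softmax_traffic(M, N)
print(reads == 5 * M * N + 2 * M, writes == 3 * M * N + 2 * M)  # True True
print(round(expected_speedup(M, N), 4))  # 4.002
```

The ratio works out to \(4 + 2/N\), so it approaches 4 for wide rows.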
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
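The row-to-program assignment is a grid-stride loop: program `pid` handles rows `pid`, `pid + num_programs`, `pid + 2 * num_programs`, and so on. The following plain-Python sketch models only that indexing, not the GPU execution. `rows_for_program` is a hypothetical helper:

```python
def rows_for_program(pid, num_programs, n_rows):
    """Rows visited by one program; mirrors tl.range(row_start, n_rows, row_step)."""
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4
assignment = {pid: rows_for_program(pid, num_programs, n_rows) for pid in range(num_programs)}
print(assignment)
# {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
```

Every row is visited exactly once, so the kernel can be launched with fewer programs than rows.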
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be larger than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
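The padding and masking logic can be checked without a GPU. The following sketch emulates one iteration of the kernel's row loop, with plain Python lists standing in for Triton blocks. It ignores parallelism and uses `math.exp` rather than Triton's approximate `tl.exp`. `emulate_softmax_row` is a hypothetical helper, and `next_power_of_2` is a local stand-in for `triton.next_power_of_2`:

```python
import math


def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1 (a local stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def emulate_softmax_row(row):
    """Plain-Python model of one iteration of softmax_kernel's row loop."""
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    col_offsets = range(block_size)               # like tl.arange(0, BLOCK_SIZE)
    mask = [c < n_cols for c in col_offsets]      # like col_offsets < n_cols
    # Masked load: lanes past the end of the row get other=-inf
    block = [row[c] if m else -math.inf for c, m in zip(col_offsets, mask)]
    row_max = max(block)                          # -inf padding never wins the max
    # exp(-inf - row_max) == 0.0, so padded lanes add nothing to the denominator
    numerator = [math.exp(v - row_max) for v in block]
    denominator = sum(numerator)
    softmax_output = [v / denominator for v in numerator]
    # Masked store: only the first n_cols lanes are written back
    return [v for v, m in zip(softmax_output, mask) if m]


row = [0.5, -1.0, 2.0, 3.5, 0.0]    # n_cols = 5, so BLOCK_SIZE = 8
y = emulate_softmax_row(row)
print(len(y), next_power_of_2(len(row)))  # 5 8
print(abs(sum(y) - 1.0) < 1e-12)          # True
```

Filling with `-inf` rather than 0 is what makes the padding neutral. A 0 could become the row maximum for all-negative rows, and `exp(0)` would add spurious mass to the denominator.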
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get its register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of general-purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as general-purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multiprocessor.
        # Dividing this number by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multiprocessor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
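The occupancy arithmetic in `softmax` can be followed with concrete numbers. This sketch mirrors that arithmetic as a pure function of plain integers. `estimate_num_programs` is a hypothetical helper, and the device figures in the example calls are made up for illustration, not taken from any specific GPU. The guard for `size_smem == 0` is an addition that the original code does not have:

```python
def estimate_num_programs(n_rows, num_sm, num_regs, max_shared_mem, warp_size,
                          n_regs, size_smem, num_warps,
                          hip=False, cdna=False, max_threads_per_sm=None):
    """Mirror of the occupancy arithmetic in softmax(); all inputs are plain integers."""
    if hip:
        num_gprs = num_regs * 2 if cdna else num_regs
        max_num_waves = max_threads_per_sm // warp_size
        occupancy = min(num_gprs // warp_size // n_regs, max_num_waves) // num_warps
    else:
        # Programs per SM limited by the register file
        occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Programs per SM limited by shared memory (guard added for kernels using none)
    if size_smem > 0:
        occupancy = min(occupancy, max_shared_mem // size_smem)
    # Never launch more persistent programs than there are rows
    return min(num_sm * occupancy, n_rows)


# Made-up NVIDIA-style figures: 8 programs per SM fit in the register file
print(estimate_num_programs(1823, num_sm=132, num_regs=65536, max_shared_mem=232448,
                            warp_size=32, n_regs=32, size_smem=2048, num_warps=8))  # 1056
# Made-up CDNA-style figures: wave limit of 2048 // 64 = 32, then // 8 warps = 4
print(estimate_num_programs(1823, num_sm=304, num_regs=65536, max_shared_mem=65536,
                            warp_size=64, n_regs=32, size_smem=2048, num_warps=8,
                            hip=True, cdna=True, max_threads_per_sm=2048))  # 1216
```

A kernel that uses more registers per thread lowers `occupancy`, and therefore the number of persistent programs. Each program then loops over more rows.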
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
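As expected, the results match within the default tolerances of torch.allclose. They are not necessarily bit-identical, since tl.exp is approximate.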
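`torch.allclose` checks \(|a - b| \le \text{atol} + \text{rtol} \cdot |b|\) elementwise, with defaults `rtol=1e-05` and `atol=1e-08`. The following pure-Python sketch applies the same test to flat lists. It is a hypothetical stand-in for illustration, not the PyTorch implementation, and it shows how a result can pass without being bit-identical:

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Elementwise |a - b| <= atol + rtol * |b|, the check torch.allclose documents."""
    return len(a) == len(b) and all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


reference = [0.25, 0.5, 0.25]
approx = [0.2500001, 0.4999999, 0.25]   # off by ~1e-7, as an approximate exp might produce
print(approx == reference)          # False: not bit-identical
print(allclose(approx, reference))  # True: within tolerance
```

The tolerance is asymmetric because `rtol` scales with `|b|`, so the reference goes in the second argument.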
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_names`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    # Effective bandwidth: counts one read of x and one write of y, regardless of provider
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 513.553953 705.993678 207.259011
1 384.0 711.604994 830.548541 263.723631
2 512.0 842.043427 916.395801 301.183126
3 640.0 838.766311 927.336907 328.716621
4 768.0 910.869857 979.515101 349.812091
5 896.0 972.666872 1038.733805 356.388785
6 1024.0 1021.513300 1069.659145 355.017166
7 1152.0 1024.376508 1075.076007 348.141300
8 1280.0 1069.730486 1110.850165 348.073208
9 1408.0 1121.114768 1137.373739 340.074290
10 1536.0 1147.178163 1157.774964 333.914521
11 1664.0 1187.443983 1182.807076 330.396613
12 1792.0 1210.502978 1194.465971 325.844729
13 1920.0 1234.170462 1218.306945 324.507039
14 2048.0 1257.687069 1242.098709 323.774930
15 2176.0 1176.845044 963.429848 326.136419
16 2304.0 1194.298064 1004.818403 326.520182
17 2432.0 1227.351631 1034.710142 327.607106
18 2560.0 1241.169757 1067.485789 328.485503
19 2688.0 1257.869624 1100.523848 328.387807
20 2816.0 1272.016245 1120.599277 329.565511
21 2944.0 1295.121887 1143.962413 331.823523
22 3072.0 1306.791078 1175.146043 333.343602
23 3200.0 1322.202140 1170.060696 334.661969
24 3328.0 1328.910243 1194.121214 335.772987
25 3456.0 1331.739690 1221.699503 336.697365
26 3584.0 1338.116324 1247.386341 338.115283
27 3712.0 1349.510696 1260.140971 340.637710
28 3840.0 1357.174795 1284.717300 340.742391
29 3968.0 1356.389428 1295.368470 340.889928
30 4096.0 1373.454430 1320.221631 338.657676
31 4224.0 1330.751294 1278.602932 343.015958
32 4352.0 1350.090620 1298.312223 345.563391
33 4480.0 1352.041381 1317.487721 345.673393
34 4608.0 1364.491052 1333.230326 346.926652
35 4736.0 1359.037612 1348.012812 347.663178
36 4864.0 1374.349472 1356.746699 349.247568
37 4992.0 1367.503962 1369.028388 350.196541
38 5120.0 1376.006968 1386.421589 350.825667
39 5248.0 1380.752448 1352.453686 351.566280
40 5376.0 1382.644766 1371.448199 351.448990
41 5504.0 1390.484205 1378.812858 353.882453
42 5632.0 1397.499495 1392.493828 352.936183
43 5760.0 1400.511173 1397.317610 354.846839
44 5888.0 1397.451305 1416.813857 354.795270
45 6016.0 1404.273940 1417.410346 356.629280
46 6144.0 1416.886287 1433.578367 356.519914
47 6272.0 1416.437632 1390.474368 358.047883
48 6400.0 1420.292465 1418.246493 358.599735
49 6528.0 1417.582256 1415.694763 359.536375
50 6656.0 1422.489486 1431.438261 359.351516
51 6784.0 1427.462827 1426.877772 360.367245
52 6912.0 1426.843636 1443.795049 360.665902
53 7040.0 1425.011078 1445.656679 360.811650
54 7168.0 1428.392323 1457.057281 361.755280
55 7296.0 1431.692932 1088.400403 362.415205
56 7424.0 1434.203804 1096.197116 362.909872
57 7552.0 1432.674934 1111.565164 363.331661
58 7680.0 1436.100058 1123.027464 363.869607
59 7808.0 1436.656123 1130.664539 364.671382
60 7936.0 1437.007513 1142.121399 364.878083
61 8064.0 1437.087527 1145.515947 364.682418
62 8192.0 1436.585011 1150.821016 363.156678
63 8320.0 1388.842378 1118.598388 361.466780
64 8448.0 1386.392122 1123.187191 362.287133
65 8576.0 1388.683250 1124.589088 363.321232
66 8704.0 1388.983401 1132.689932 363.995505
67 8832.0 1389.791363 1131.955380 364.609554
68 8960.0 1391.618639 1141.339135 365.909865
69 9088.0 1398.414433 1136.018886 366.751111
70 9216.0 1404.923505 1143.332667 367.413801
71 9344.0 1393.327123 1417.835236 367.530891
72 9472.0 1402.775369 1428.558483 368.424586
73 9600.0 1401.352026 1431.586387 369.199393
74 9728.0 1398.919026 1441.056898 369.750134
75 9856.0 1401.571893 1439.870095 369.970147
76 9984.0 1393.862685 1447.428722 370.322476
77 10112.0 1406.218622 1455.015349 371.140813
78 10240.0 1405.085033 1465.524481 371.585796
79 10368.0 1419.283313 1461.285038 369.921687
80 10496.0 1407.678119 1465.779164 370.213310
81 10624.0 1406.154573 1464.702864 370.563669
82 10752.0 1396.644629 1472.364425 371.305444
83 10880.0 1395.607874 1478.924326 370.852025
84 11008.0 1416.729017 1476.664784 372.961657
85 11136.0 1419.823756 1482.877193 373.396713
86 11264.0 1410.170710 1486.523407 373.050237
87 11392.0 1421.706148 1490.990049 373.463702
88 11520.0 1417.822847 1495.347789 373.614283
89 11648.0 1420.872249 1498.742335 374.003206
90 11776.0 1434.239668 1503.089772 374.873840
91 11904.0 1432.133239 1510.243590 374.893994
92 12032.0 1418.522865 1511.354202 375.664079
93 12160.0 1413.719567 1516.596745 375.801953
94 12288.0 1425.761990 1420.258762 376.132164
95 12416.0 1433.500151 1396.736398 374.350286
96 12544.0 1443.365889 1396.245500 375.388862
97 12672.0 1435.577076 1394.351791 375.237976
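In the results above, we can see that:
- Triton reaches roughly 4x the throughput of naive_softmax (about 1400 GB/s versus about 370 GB/s for large N). This matches the ~4x reduction in memory traffic estimated earlier and suggests that PyTorch does not fuse the five operations in naive_softmax.
- Triton is competitive with torch.softmax, in addition to being easier to read, understand and maintain. In these measurements it is clearly faster for N between 2176 and 3072 and between 7296 and 9216, while torch.softmax is faster for N up to 1536 and for N between 9344 and 12160. Note, however, that the PyTorch softmax operation is more general and will work on tensors of any shape.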
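The GB/s column is an effective bandwidth. The benchmark's `gbps` lambda always counts one read and one write of x (2·M·N·4 bytes for float32), whatever the implementation actually moves. Using the traffic estimate from the Motivations section, which is a model and not a measurement, naive_softmax's effective figure can be rescaled to approximate the rate at which it really moves bytes. The helper names below are hypothetical:

```python
def effective_gbps(M, N, ms, element_size=4):
    """The benchmark's metric: bytes for one read of x plus one write of y, divided by time."""
    return 2 * M * N * element_size * 1e-9 / (ms * 1e-3)


def naive_actual_gbps(effective, M, N):
    """Rescale an effective rate by naive_softmax's estimated traffic, (8MN + 4M) / (2MN)."""
    return effective * (8 * M * N + 4 * M) / (2 * M * N)


# The metric itself: a float32 4096 x 4096 softmax finishing in 0.1 ms (made-up timing)
print(round(effective_gbps(4096, 4096, 0.1), 2))  # 1342.18

# Row N = 4096 of the table above: naive_softmax ~338.66 GB/s, Triton ~1373.45 GB/s
print(round(naive_actual_gbps(338.657676, 4096, 4096), 1))  # 1354.8
```

Under this model, naive_softmax moves data at about 1355 GB/s for N = 4096, close to Triton's 1373 GB/s. This is consistent with both being limited by memory bandwidth, with naive_softmax moving about four times as many bytes.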
Total running time of the script: (0 minutes 34.943 seconds)