Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
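To make the docstring's point about overflow concrete, here is a minimal pure-Python sketch (no GPU or PyTorch required). The helper names `unstable_softmax` and `stable_softmax` are illustrative and not part of the tutorial. Python's `math.exp` raises `OverflowError` on overflow, whereas in PyTorch the same situation would typically surface as `inf` and `nan` values instead.

```python
import math


def unstable_softmax(row):
    # Direct definition: exp(x_i) / sum_j exp(x_j); overflows for large inputs.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(row):
    # Subtract the row maximum first, so every exponent is <= 0.
    row_max = max(row)
    exps = [math.exp(v - row_max) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
large = [v + 1000.0 for v in small]  # the same row shifted by a constant

try:
    unstable_softmax(large)
    overflowed = False
except OverflowError:
    overflowed = True  # math.exp(1001.0) exceeds the float64 range

# Softmax is invariant to the shift, so both rows give the same probabilities.
shift_invariant = all(
    math.isclose(a, b, rel_tol=1e-12)
    for a, b in zip(stable_softmax(small), stable_softmax(large))
)
```

The stabilized version only ever exponentiates non-positive numbers, which is why it is safe for any finite input.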
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
torch.jit.script aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.
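The speed-up estimate above can be checked with a few lines of arithmetic. This is a sketch that only counts elements moved to and from DRAM, under the assumption that the fused kernel reads X once and writes Y once; the function names are illustrative.

```python
def naive_traffic(M, N):
    # Element counts taken from the comments in naive_softmax.
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    # Assumed ideal fused kernel: read X once, write Y once.
    return 2 * M * N


M, N = 4096, 1024
ratio = naive_traffic(M, N) / fused_traffic(M, N)  # algebraically 4 + 2 / N
```

Because the ratio simplifies to \(4 + 2/N\), it approaches 4 as the rows get longer, independent of the number of rows.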
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
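The row assignment described above can be sketched in plain Python. The helper `rows_for_program` is hypothetical; it simply models a loop that starts at the program's index and advances by the total number of programs.

```python
def rows_for_program(pid, num_programs, n_rows):
    # Start at this program's index and step by the number of programs.
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4  # illustrative values
assignment = {
    pid: rows_for_program(pid, num_programs, n_rows)
    for pid in range(num_programs)
}
# assignment == {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
```

Every row is handled by exactly one program, and the work stays balanced to within one row per program.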
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
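The padding and masking in the kernel can be emulated on the CPU to see why filling the unused lanes with `-inf` is harmless. This is a pure-Python sketch, not Triton code: `next_power_of_2` here is a hypothetical stand-in for `triton.next_power_of_2`, and `padded_softmax` is an illustrative name.

```python
import math


def next_power_of_2(n):
    # Hypothetical stand-in: smallest power of two greater than or equal to n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def padded_softmax(row):
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    mask = [i < n_cols for i in range(block_size)]
    # Masked-off lanes are filled with -inf, like `other=-float('inf')` in the load.
    block = [row[i] if mask[i] else -math.inf for i in range(block_size)]
    block_max = max(block)  # -inf never wins the max
    numerator = [math.exp(v - block_max) for v in block]  # exp(-inf) == 0.0
    denominator = sum(numerator)  # padded lanes contribute nothing
    # Only lanes where the mask is set are written back, like the masked store.
    return [numerator[i] / denominator for i in range(block_size) if mask[i]]


out = padded_softmax([0.5, -1.0, 2.0, 0.0, 1.5])  # 5 columns, padded to 8 lanes
```

Since `exp(-inf)` is exactly zero, the padded lanes affect neither the maximum nor the sum, so the result equals the softmax of the unpadded row.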
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to
    # the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
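The occupancy arithmetic in the non-HIP branch of the helper can be followed with concrete numbers. The sketch below is plain Python; the device and kernel figures passed in are purely illustrative assumptions and were not measured on any GPU, and `programs_to_launch` is a hypothetical name. It also guards against a kernel that reports zero shared memory, which is an addition for the sketch only.

```python
def programs_to_launch(num_sm, num_regs, size_smem, warp_size,
                       n_regs, kernel_smem, num_warps, n_rows):
    # Register limit: how many programs fit in one multiprocessor's register file.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared-memory limit (guarded in this sketch in case the kernel uses none).
    if kernel_smem > 0:
        occupancy = min(occupancy, size_smem // kernel_smem)
    # One persistent program per available slot, but never more programs than rows.
    return min(num_sm * occupancy, n_rows)


# Illustrative values only: 65536 // (32 * 32 * 8) = 8 programs per multiprocessor.
num_programs = programs_to_launch(
    num_sm=108, num_regs=65536, size_smem=166912, warp_size=32,
    n_regs=32, kernel_smem=4096, num_warps=8, n_rows=1823)
```

With these assumed numbers the register file is the binding constraint, so 108 multiprocessors times 8 programs gives 864 persistent programs, each of which loops over several of the 1823 rows.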
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results match to within the default tolerances of torch.allclose.
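Note that `torch.allclose` checks closeness within a tolerance rather than bitwise equality, which matters here because the Triton exponential is approximate. The criterion it documents is \(|a - b| \le \text{atol} + \text{rtol} \cdot |b|\), with defaults `rtol=1e-05` and `atol=1e-08`. A minimal pure-Python sketch of that check on flat lists, using the illustrative name `allclose_sketch`:

```python
def allclose_sketch(a, b, rtol=1e-5, atol=1e-8):
    # Element-wise |a - b| <= atol + rtol * |b| must hold for every pair.
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


reference = [0.25, 0.50, 0.25]
close = [0.25 + 1e-7, 0.50 - 1e-7, 0.25]  # differences well inside the tolerance
off = [0.26, 0.49, 0.25]                  # differences of 1e-2, far outside it

close_ok = allclose_sketch(close, reference)
off_ok = allclose_sketch(off, reference)
```

Small rounding differences between the two implementations pass this check, while a genuinely wrong result does not.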
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 479.781182 688.625207 208.987586
1 384.0 657.673842 815.728814 263.644206
2 512.0 805.414257 927.425586 299.593141
3 640.0 911.149048 918.399745 329.779601
4 768.0 988.283253 989.660727 348.472431
5 896.0 1053.691792 1037.975260 353.985226
6 1024.0 1080.694808 1067.017140 352.434521
7 1152.0 1098.452974 1075.291077 348.620721
8 1280.0 1137.718677 1099.305531 348.750081
9 1408.0 1169.334802 1137.996755 339.954948
10 1536.0 1185.699169 1157.250752 333.555329
11 1664.0 1212.157901 1193.285535 330.303689
12 1792.0 1228.453606 1193.109986 326.204874
13 1920.0 1252.607669 1225.363032 324.517779
14 2048.0 1276.729177 1243.581848 324.087924
15 2176.0 1236.741339 959.130061 325.771214
16 2304.0 1252.735515 1004.425475 326.114892
17 2432.0 1265.407257 1034.088036 325.984479
18 2560.0 1281.347128 1071.737649 328.055721
19 2688.0 1295.767387 1100.693101 329.109278
20 2816.0 1314.557607 1126.034001 329.184528
21 2944.0 1317.774173 1143.547856 331.630184
22 3072.0 1325.597177 1170.120676 333.249958
23 3200.0 1341.818311 1174.498366 334.323162
24 3328.0 1344.092006 1196.551544 336.119174
25 3456.0 1349.377115 1223.790312 337.011476
26 3584.0 1358.690087 1239.514034 337.985459
27 3712.0 1369.597754 1261.849190 340.105162
28 3840.0 1372.722466 1282.793638 340.247436
29 3968.0 1377.408638 1296.502510 340.819870
30 4096.0 1390.341948 1315.666402 338.885231
31 4224.0 1326.133645 1279.086809 343.011886
32 4352.0 1340.593447 1295.210272 345.225582
33 4480.0 1348.797177 1318.509138 346.260399
34 4608.0 1356.320739 1336.182709 346.586377
35 4736.0 1361.628777 1343.964607 347.939725
36 4864.0 1365.299096 1356.739776 348.838700
37 4992.0 1370.849053 1373.882160 349.871840
38 5120.0 1372.581946 1384.780686 350.334995
39 5248.0 1376.758297 1353.924955 351.914773
40 5376.0 1380.467876 1368.374138 351.418178
41 5504.0 1383.996045 1379.726486 353.600149
42 5632.0 1393.321318 1393.712962 353.208412
43 5760.0 1395.118355 1405.432369 354.953758
44 5888.0 1391.266428 1417.560204 354.915940
45 6016.0 1402.692226 1419.285124 356.508356
46 6144.0 1409.565803 1427.441797 357.176855
47 6272.0 1408.761964 1397.427764 357.493315
48 6400.0 1410.296723 1411.735196 357.995939
49 6528.0 1412.017019 1412.889957 358.872617
50 6656.0 1413.653525 1431.494328 359.323961
51 6784.0 1418.629389 1433.680566 360.233845
52 6912.0 1421.956900 1440.619506 360.633741
53 7040.0 1420.684692 1449.915226 360.798001
54 7168.0 1422.670018 1455.174501 362.035869
55 7296.0 1423.580122 1082.049192 362.296326
56 7424.0 1427.298193 1096.763620 362.772475
57 7552.0 1426.841788 1108.034556 363.213348
58 7680.0 1430.103091 1122.907699 364.016380
59 7808.0 1431.374146 1128.539145 364.298735
60 7936.0 1432.964422 1139.890770 364.209214
61 8064.0 1434.165083 1147.524841 364.610096
62 8192.0 1433.908065 1148.726968 363.842927
63 8320.0 1380.826916 1115.978689 361.770409
64 8448.0 1381.316460 1123.545476 362.889457
65 8576.0 1389.308485 1123.863512 363.468264
66 8704.0 1381.412793 1132.992060 364.302519
67 8832.0 1396.614561 1131.225215 364.853782
68 8960.0 1388.630961 1138.040113 365.598816
69 9088.0 1396.935100 1135.232854 366.608855
70 9216.0 1405.272267 1141.438830 367.532000
71 9344.0 1387.280022 1421.428375 367.681536
72 9472.0 1396.486184 1427.347615 368.815180
73 9600.0 1400.616524 1431.691084 368.787515
74 9728.0 1398.398902 1443.438327 369.563910
75 9856.0 1399.888402 1440.888556 370.435724
76 9984.0 1393.050699 1451.403084 370.835312
77 10112.0 1402.376974 1456.504987 371.891105
78 10240.0 1409.858981 1463.827179 371.555127
79 10368.0 1415.596316 1461.388737 370.085854
80 10496.0 1413.322769 1468.165544 370.497401
81 10624.0 1408.137115 1463.134142 370.774913
82 10752.0 1397.303252 1472.617419 371.418349
83 10880.0 1394.759547 1481.661099 371.881936
84 11008.0 1422.446060 1479.103316 372.589329
85 11136.0 1419.561328 1486.240756 373.401159
86 11264.0 1410.283620 1485.205713 373.147894
87 11392.0 1422.751370 1490.516449 374.385608
88 11520.0 1410.305717 1499.644971 374.333342
89 11648.0 1422.099167 1499.055070 375.003669
90 11776.0 1438.241289 1501.677608 375.638014
91 11904.0 1432.181417 1508.180933 375.156017
92 12032.0 1417.078289 1508.179311 375.920386
93 12160.0 1413.788364 1516.435822 376.404866
94 12288.0 1426.985668 1421.984003 375.822737
95 12416.0 1436.504123 1400.556055 374.880106
96 12544.0 1444.269509 1394.610410 375.441527
97 12672.0 1437.226001 1394.385728 375.158875
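The GB/s values reported above come from the `gbps` lambda in the benchmark, which counts one read and one write per element and divides by the elapsed time. The sketch below restates that formula as a standalone function; the timing used is a hypothetical value chosen for illustration, not a measurement.

```python
def gbps(n_elements, element_size, ms):
    # Bytes moved (one read plus one write per element) per second, in GB/s.
    return 2 * n_elements * element_size * 1e-9 / (ms * 1e-3)


M, N = 4096, 1024
hypothetical_ms = 0.031  # illustrative timing, not a measurement
bandwidth = gbps(M * N, 4, hypothetical_ms)  # float32 has 4 bytes per element
```

With these assumed inputs the function returns roughly 1082 GB/s. Note that the same byte count is used for every provider, so for naive_softmax the figure is an effective bandwidth for the whole operation rather than the actual DRAM traffic it generates.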
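In the results above, we can see that:
Triton is roughly 3x to 4x faster than naive_softmax for N of 1024 and above, in line with the theoretical estimate and consistent with the unfused PyTorch code performing no kernel fusion.
Triton is competitive with torch.softmax: it is faster over some ranges of N (for example, 2176 to 4096 and 7296 to 9216) and somewhat slower over others (for example, 9344 to 12160), in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.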
Total running time of the script: (0 minutes 35.006 seconds)