Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading \(MN\) elements and writing back \(MN\) elements, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / (2MN)\)).
The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically,
but in practice it is still far from ideal.
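To make the ~4x figure concrete, here is a small back-of-the-envelope sketch (illustrative only, not part of the tutorial’s code) that plugs the benchmark shape used later, M = 4096 and N = 1024, into the element counts above, assuming float32 inputs:

# Illustrative only: DRAM traffic implied by the element counts above, for float32 (4 bytes per element).
M, N = 4096, 1024
bytes_per_elem = 4
naive_traffic = ((5 * M * N + 2 * M) + (3 * M * N + 2 * M)) * bytes_per_elem  # reads + writes
fused_traffic = 2 * M * N * bytes_per_elem  # read X once, write Y once
print(f"naive: {naive_traffic / 1e6:.1f} MB, fused: {fused_traffic / 1e6:.1f} MB, "
      f"speed-up bound: {naive_traffic / fused_traffic:.2f}x")  # ~4.00x

For large N the \(4M\) term is negligible, so the bound settles at essentially 4x.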
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
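Before building the occupancy-aware launcher in the next section, it may help to see the simplest way this kernel could be enqueued: one program per row, in which case the tl.range loop executes exactly one iteration per program. The helper below (softmax_simple is our name, not the tutorial’s) is a minimal sketch assuming a contiguous 2D float32 input:

def softmax_simple(x):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # Each row must fit in one block, so round the row length up to the next power of two.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # One program per row: row_step == n_rows, so each program processes exactly one row.
    softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                               BLOCK_SIZE=BLOCK_SIZE, num_stages=2)
    return y

This launches as many programs as there are rows. The helper below instead pre-compiles the kernel to measure its register and shared-memory usage, sizes the grid to the GPU’s occupancy, and lets each persistent program loop over several rows.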
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
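To get a feel for the occupancy arithmetic on the non-HIP path, here is a worked example with made-up but plausible numbers; the variables below are hypothetical stand-ins for the real NUM_REGS, SIZE_SMEM, NUM_SM, n_regs and size_smem, which come from your device and the compiled kernel:

# Hypothetical values, for illustration only (deliberately not reusing the queried globals above).
regs_per_sm, smem_per_sm, warp_size, num_sm = 65536, 98304, 32, 108
kernel_regs, kernel_smem, warps = 32, 16384, 8
regs_limit = regs_per_sm // (kernel_regs * warp_size * warps)  # 65536 // 8192 = 8 programs per SM (register limit)
smem_limit = smem_per_sm // kernel_smem                        # 98304 // 16384 = 6 programs per SM (shared-memory limit)
occupancy_example = min(regs_limit, smem_limit)                # 6
programs_example = num_sm * occupancy_example                  # 648 persistent programs, later capped at n_rows

Each of those programs then walks over rows with a stride of num_programs, which is why softmax_kernel loops with tl.range rather than assuming one row per program.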
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
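As an optional extra check (not part of the original tutorial), you can exercise the padding and masking on another awkward shape; torch.testing.assert_close is used here with its default tolerances:

# Optional: a second irregular shape to exercise the power-of-two padding and the load/store masks.
x_small = torch.randn(7, 33, device=DEVICE)
torch.testing.assert_close(softmax(x_small), torch.softmax(x_small, dim=1))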
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 479.392738 688.644471 208.291026
1 384.0 666.701587 819.647163 265.336130
2 512.0 814.381594 911.821022 300.529921
3 640.0 918.974351 915.359612 328.833999
4 768.0 988.334036 974.393732 349.393180
5 896.0 1052.354057 1040.946241 356.511876
6 1024.0 1083.969623 1079.789974 355.806450
7 1152.0 1090.015481 1068.152457 349.438616
8 1280.0 1130.875041 1114.771225 349.595371
9 1408.0 1172.829092 1130.575448 341.536844
10 1536.0 1189.198121 1158.037248 332.885244
11 1664.0 1216.394887 1192.698726 330.042768
12 1792.0 1238.646653 1192.269470 325.922148
13 1920.0 1253.133199 1218.957510 324.107610
14 2048.0 1267.869246 1242.002648 323.808636
15 2176.0 1229.521425 962.382388 325.760834
16 2304.0 1258.626822 1004.627491 326.386060
17 2432.0 1268.672580 1034.881447 327.632528
18 2560.0 1283.067203 1068.026464 328.269200
19 2688.0 1295.403693 1097.625663 329.638462
20 2816.0 1312.918348 1126.770677 329.714936
21 2944.0 1321.249303 1147.475445 331.509081
22 3072.0 1321.962839 1169.654304 334.329260
23 3200.0 1337.790602 1175.557858 334.799862
24 3328.0 1344.825668 1198.484977 336.495734
25 3456.0 1347.881116 1225.761493 336.590062
26 3584.0 1364.855453 1248.444230 338.768782
27 3712.0 1362.249774 1263.894840 340.475992
28 3840.0 1368.319726 1285.649284 340.682410
29 3968.0 1374.910593 1297.324039 341.353283
30 4096.0 1387.623536 1315.103714 339.172059
31 4224.0 1332.575837 1278.375278 343.225698
32 4352.0 1346.440717 1300.796895 345.284246
33 4480.0 1349.759194 1316.351223 346.088020
34 4608.0 1358.267471 1336.244469 346.559639
35 4736.0 1361.059207 1345.500888 348.218048
36 4864.0 1366.759780 1359.102100 349.544293
37 4992.0 1367.526915 1372.345818 350.040602
38 5120.0 1376.970152 1384.510671 351.151968
39 5248.0 1374.523076 1352.943718 351.769267
40 5376.0 1382.906553 1369.164980 352.029528
41 5504.0 1378.327418 1380.601194 354.040175
42 5632.0 1394.310535 1394.658896 353.027318
43 5760.0 1393.677258 1400.080528 355.088660
44 5888.0 1396.089025 1420.363069 355.002988
45 6016.0 1400.837311 1420.025852 356.277630
46 6144.0 1410.312431 1422.664145 356.843455
47 6272.0 1408.401746 1400.618076 358.108065
48 6400.0 1412.474849 1404.704408 358.469040
49 6528.0 1414.434732 1423.880117 359.374819
50 6656.0 1413.966761 1424.257280 359.411236
51 6784.0 1415.473668 1438.572041 360.073442
52 6912.0 1419.738428 1439.471078 360.812988
53 7040.0 1420.920065 1460.479008 361.016527
54 7168.0 1424.308423 1456.577167 361.664017
55 7296.0 1422.336124 1081.533201 362.191231
56 7424.0 1431.017576 1097.218633 362.608483
57 7552.0 1429.993367 1111.797656 363.267944
58 7680.0 1433.983217 1121.882492 363.773756
59 7808.0 1427.894284 1131.042270 364.521322
60 7936.0 1436.483822 1144.706250 364.819339
61 8064.0 1434.042612 1147.319259 365.144132
62 8192.0 1431.391406 1148.758413 363.552395
63 8320.0 1380.422222 1116.801198 361.582812
64 8448.0 1384.886388 1125.678422 362.446951
65 8576.0 1388.245213 1129.163929 363.392506
66 8704.0 1380.456433 1132.834062 364.396059
67 8832.0 1393.213229 1131.710511 364.782700
68 8960.0 1386.567015 1139.804692 365.718730
69 9088.0 1399.766941 1137.059306 366.337985
70 9216.0 1405.970683 1146.345696 367.356037
71 9344.0 1392.151314 1419.416623 367.318290
72 9472.0 1396.291894 1431.596048 368.397985
73 9600.0 1404.701711 1432.142547 368.738328
74 9728.0 1396.771624 1439.911869 369.807814
75 9856.0 1400.908786 1442.263883 369.686348
76 9984.0 1394.004239 1452.396622 370.318004
77 10112.0 1403.358341 1452.022288 371.468629
78 10240.0 1409.026240 1463.730080 371.576925
79 10368.0 1413.428249 1460.566847 370.121370
80 10496.0 1409.937957 1468.626491 370.559601
81 10624.0 1408.787923 1466.562983 371.144105
82 10752.0 1392.216982 1474.168759 371.594674
83 10880.0 1395.957592 1481.201501 372.166852
84 11008.0 1420.387097 1476.240577 372.567191
85 11136.0 1413.925945 1482.428129 373.365592
86 11264.0 1413.352153 1484.408905 373.258931
87 11392.0 1421.908114 1491.101463 374.553622
88 11520.0 1410.223970 1496.615377 374.010968
89 11648.0 1422.333078 1498.795433 375.070260
90 11776.0 1430.976099 1501.197578 374.921690
91 11904.0 1432.826248 1507.580245 375.178277
92 12032.0 1415.458450 1510.852878 375.858488
93 12160.0 1415.097196 1516.700226 376.089869
94 12288.0 1428.788911 1417.846148 375.896896
95 12416.0 1440.006274 1392.656451 374.681248
96 12544.0 1447.985834 1394.750678 375.274802
97 12672.0 1432.899657 1392.639999 375.141306
In the above plot, we can see that:
- Triton is about 4x faster than the naive PyTorch implementation (naive_softmax), in line with the bandwidth estimate above: the unfused version spends most of its time moving intermediate results to and from DRAM.
- Triton is competitive with torch.softmax, and faster for many shapes, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 35.022 seconds)