Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading \(MN\) elements and writing back \(MN\) elements, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
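To make the arithmetic concrete, here is a quick back-of-envelope check of that ratio; the shape is arbitrary and chosen only for illustration:

# Back-of-envelope check of the expected speed-up for an arbitrary example shape.
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes of naive_softmax
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)                       # ~4.0, i.e. (8MN + 4M) / 2MN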
The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically,
but in practice it still falls well short of a hand-written fused kernel.
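For reference, this is roughly how one would ask TorchScript to attempt that fusion. It is only a sketch, it is not benchmarked below, and how much fusion actually happens depends on the PyTorch version and fuser backend:

# Sketch: compile the naive softmax with TorchScript (not benchmarked in this tutorial).
# The degree of fusion obtained depends on the PyTorch version and fuser backend.
scripted_softmax = torch.jit.script(naive_softmax)
y_scripted = scripted_softmax(torch.randn(4096, 1024, device=DEVICE))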
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes. A small plain-Python sketch of these two ideas (strided row assignment and power-of-two padding) is shown below, followed by the kernel itself.
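The sizes in this sketch are hypothetical and chosen only for illustration:

# Plain-Python illustration (not Triton) of the strided row assignment and power-of-two padding.
n_rows, n_cols, num_programs = 10, 781, 4          # hypothetical sizes
BLOCK_SIZE = triton.next_power_of_2(n_cols)        # 1024: every row is padded to this width
for pid in range(num_programs):
    rows = list(range(pid, n_rows, num_programs))  # same striding as tl.range in the kernel
    print(f"program {pid} handles rows {rows}")
# Out-of-bounds columns are masked on load with -inf, so they leave the row maximum
# unchanged and contribute exp(-inf) = 0 to the sum.

The kernel itself: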
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
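The calls to tl.max and tl.sum above are Triton's reduction operators: they collapse an axis of a block into a single value per program. As a standalone illustration only (a hypothetical helper that is not used anywhere else in this tutorial), the same load/mask/reduce pattern gives a row-sum kernel:

@triton.jit
def row_sum_kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program per row of a contiguous matrix; reduce the padded row with tl.sum.
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
    tl.store(out_ptr + row_idx, tl.sum(x, axis=0))

# Hypothetical usage, assuming a contiguous 2D tensor x:
#   out = torch.empty(x.shape[0], device=x.device, dtype=x.dtype)
#   row_sum_kernel[(x.shape[0], )](x, out, x.shape[1], BLOCK_SIZE=triton.next_power_of_2(x.shape[1]))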
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
Unit Test
We test our kernel on a matrix with an irregular number of rows and a number of columns that is not a power of two. This allows us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
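The numbers reported below are effective memory bandwidths: each run is charged for one read and one write of the tensor (2 * numel * element_size bytes) and divided by the measured time, mirroring the gbps lambda in the code. A back-of-envelope example with a made-up timing:

# Back-of-envelope GB/s computation matching the `gbps` lambda below (made-up timing).
M, N = 4096, 1024
bytes_moved = 2 * M * N * 4               # one read + one write of a float32 tensor
ms = 0.025                                # hypothetical kernel time in milliseconds
print(bytes_moved * 1e-9 / (ms * 1e-3))   # ~1342 GB/s for these illustrative numbers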
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)
benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch Naive Softmax
0 256.0 476.804403 687.969985 205.930110
1 384.0 662.670682 829.193537 261.757997
2 512.0 808.033722 919.400934 303.900680
3 640.0 888.341542 920.608567 332.271392
4 768.0 954.119287 987.451652 348.446607
5 896.0 1023.806188 1037.667343 352.495114
6 1024.0 1048.242990 1066.704213 353.397081
7 1152.0 1094.665835 1074.035779 347.082055
8 1280.0 1124.055581 1101.221035 349.420908
9 1408.0 1169.273383 1131.067580 340.899233
10 1536.0 1195.552279 1163.666096 333.180475
11 1664.0 1215.910110 1181.097305 328.762315
12 1792.0 1226.884782 1192.264547 325.771241
13 1920.0 1264.653860 1226.129247 325.122991
14 2048.0 1268.681502 1241.906960 324.411494
15 2176.0 1238.507833 962.538377 325.509433
16 2304.0 1255.158313 1005.811349 325.721077
17 2432.0 1270.893003 1033.507096 326.466147
18 2560.0 1279.614577 1065.187282 327.636496
19 2688.0 1290.785388 1095.859151 329.518529
20 2816.0 1311.566864 1125.159044 329.534915
21 2944.0 1310.162737 1143.963961 331.324179
22 3072.0 1312.527536 1167.995535 332.852419
23 3200.0 1337.116341 1172.826700 334.788098
24 3328.0 1346.586148 1199.001468 336.636970
25 3456.0 1346.676162 1220.157198 336.702317
26 3584.0 1364.103092 1242.841693 338.083304
27 3712.0 1362.545426 1264.904260 339.966692
28 3840.0 1373.942879 1279.181508 339.964787
29 3968.0 1368.919799 1300.300416 341.138220
30 4096.0 1388.865886 1313.394213 338.692256
31 4224.0 1331.704457 1276.336083 342.825111
32 4352.0 1340.376297 1300.092201 345.067957
33 4480.0 1344.071802 1319.910247 345.603724
34 4608.0 1354.015043 1336.353790 347.230797
35 4736.0 1353.346496 1343.917595 348.036578
36 4864.0 1362.262450 1357.989710 348.806429
37 4992.0 1363.396634 1371.548881 350.055164
38 5120.0 1375.202582 1387.763751 350.669824
39 5248.0 1366.159743 1357.733046 351.448374
40 5376.0 1368.595054 1362.176400 351.862417
41 5504.0 1376.655806 1378.700354 353.535239
42 5632.0 1383.544387 1396.457752 353.052972
43 5760.0 1388.236925 1399.013311 354.972358
44 5888.0 1386.077117 1416.166428 354.762310
45 6016.0 1394.702569 1424.574443 356.753184
46 6144.0 1404.625920 1424.060157 357.037863
47 6272.0 1405.415672 1401.067819 357.581008
48 6400.0 1410.072659 1399.259809 358.479735
49 6528.0 1406.401551 1417.757846 359.125846
50 6656.0 1408.531117 1424.125215 359.705531
51 6784.0 1414.092203 1437.120904 360.233846
52 6912.0 1418.601549 1439.060272 360.702662
53 7040.0 1413.901247 1449.855982 360.842062
54 7168.0 1414.536903 1455.311409 361.385952
55 7296.0 1420.665040 1085.782388 362.602826
56 7424.0 1425.042099 1097.941083 363.062709
57 7552.0 1423.359376 1111.645080 363.673392
58 7680.0 1428.690973 1121.946940 363.604994
59 7808.0 1429.748230 1131.377427 364.253615
60 7936.0 1427.677584 1143.830720 364.436456
61 8064.0 1431.443880 1148.489850 365.107877
62 8192.0 1428.136078 1151.297830 364.276937
63 8320.0 1386.578230 1116.797272 361.596207
64 8448.0 1387.161314 1125.208706 362.175811
65 8576.0 1389.714296 1126.433901 363.269797
66 8704.0 1385.456160 1132.979949 364.311422
67 8832.0 1393.997222 1132.447857 365.316491
68 8960.0 1389.622228 1140.095951 365.976587
69 9088.0 1396.294872 1138.921960 365.806288
70 9216.0 1397.913058 1142.548783 366.956630
71 9344.0 1386.410653 1421.447516 367.548607
72 9472.0 1400.005997 1433.539448 367.969698
73 9600.0 1402.187341 1431.859742 369.109778
74 9728.0 1398.747708 1439.656739 370.074259
75 9856.0 1399.900574 1436.280938 370.160670
76 9984.0 1386.425965 1445.165496 370.848624
77 10112.0 1405.950309 1451.225216 371.118407
78 10240.0 1403.642239 1464.517191 371.293283
79 10368.0 1412.309445 1462.836826 370.090294
80 10496.0 1406.176637 1466.335621 370.884285
81 10624.0 1403.565608 1462.936749 370.175987
82 10752.0 1391.931663 1472.926151 370.890380
83 10880.0 1388.920109 1477.640662 372.235355
84 11008.0 1417.285442 1474.023103 372.620328
85 11136.0 1417.869958 1483.418101 372.664543
86 11264.0 1411.596121 1486.976921 373.067988
87 11392.0 1415.276256 1488.629105 373.829584
88 11520.0 1408.548880 1495.046341 373.896283
89 11648.0 1414.934492 1501.105326 374.458507
90 11776.0 1430.041625 1503.253870 374.704219
91 11904.0 1426.228070 1507.221215 375.490174
92 12032.0 1405.693560 1508.463198 375.721497
93 12160.0 1405.857783 1513.659558 375.673640
94 12288.0 1419.781173 1415.224099 375.875309
95 12416.0 1431.049630 1394.635227 375.065895
96 12544.0 1440.051762 1391.502811 375.095080
97 12672.0 1430.071531 1393.572006 375.312714
In the results above, we can see that:
- Triton is about 4x faster than the naive PyTorch implementation (naive_softmax). This matches our earlier analysis: the naive version performs no fusion and therefore moves roughly 4x more data through DRAM.
- Triton is competitive with torch.softmax, and faster for many of the column counts shown, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 35.159 seconds)