Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"

def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x)
for \(x \in \mathbb{R}^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only \(MN\) elements each way (\(2MN\) in total), so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically
but, as we will see later, it is still far from ideal.
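As a quick sanity check on that estimate, the ratio can be computed directly; the concrete M and N below are illustrative, and the small O(M) terms barely matter:

# Back-of-the-envelope check of the ~4x figure (illustrative values of M and N).
M, N = 4096, 4096
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)                       # ~4.0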
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle arbitrary input shapes:
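To make the padding idea concrete, here is a small host-side illustration (reusing the imports above; the row length 781 is arbitrary) of how a row is rounded up to the next power of two and how the mask marks the real elements:

# Host-side illustration of power-of-two padding and masking (example row length only).
n_cols = 781
BLOCK_SIZE = triton.next_power_of_2(n_cols)   # 1024
col_offsets = torch.arange(BLOCK_SIZE)
mask = col_offsets < n_cols                   # True for real elements, False for the padded tail
print(BLOCK_SIZE, int(mask.sum()))            # 1024 781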
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
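Before moving on to the persistent launch used below, note that the same kernel could also be launched in a simpler, non-persistent way, with one program per row. This is only a sketch for illustration, not how this tutorial launches the kernel:

# Minimal sketch: one program per row, no persistence or occupancy tuning.
# Assumes `x` is a contiguous 2D float32 tensor on the GPU.
def softmax_simple(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # With a grid of n_rows programs, each program's loop body executes exactly once.
    softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                               BLOCK_SIZE=BLOCK_SIZE, num_stages=2)
    return y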
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
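To make the occupancy arithmetic in the non-HIP branch concrete, here is a worked example with made-up numbers; in practice n_regs and size_smem come from the compiled kernel, and the device limits from get_device_properties:

# Worked occupancy example with hypothetical numbers (illustration only).
NUM_REGS_EX, WARP_SIZE_EX, NUM_SM_EX = 65536, 32, 108    # hypothetical device limits
n_regs_ex, num_warps_ex = 32, 8                          # hypothetical per-thread registers and warps per program
SIZE_SMEM_EX, size_smem_ex = 101376, 16384               # hypothetical shared memory limit / usage (bytes)

occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 65536 // 8192 = 8
occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)           # min(8, 6) = 6
print(NUM_SM_EX * occupancy_ex)                                          # 648 persistent programs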
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results match within floating-point tolerance.
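An additional, optional sanity check (not part of the original test) is that every output row sums to one, as any softmax should:

# Optional extra check: each softmax row should sum to 1 (within floating-point tolerance).
row_sums = y_triton.sum(dim=1)
assert torch.allclose(row_sums, torch.ones_like(row_sums))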
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows.
We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
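The metric plotted below is effective memory bandwidth: the fused kernel reads X once and writes Y once, so each call moves 2 * numel * element_size bytes, and dividing by the measured time gives GB/s. A tiny illustration with made-up numbers:

# How the GB/s numbers below are derived (illustrative shape and timing only).
M_ex, N_ex = 4096, 1024
bytes_moved = 2 * M_ex * N_ex * 4            # one float32 read plus one float32 write per element
ms_ex = 0.05                                 # hypothetical kernel time in milliseconds
print(bytes_moved * 1e-9 / (ms_ex * 1e-3))   # ~671 GB/s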
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch Naive Softmax
0 256.0 468.452833 689.354972 206.729612
1 384.0 664.989565 831.386445 261.157176
2 512.0 799.942147 915.989562 300.951619
3 640.0 908.348602 914.048137 328.450257
4 768.0 974.912881 986.068762 347.559270
5 896.0 1048.817418 1039.267858 354.104773
6 1024.0 1081.471589 1078.798045 352.362205
7 1152.0 1091.274992 1074.480018 348.015856
8 1280.0 1125.538595 1111.471096 348.531951
9 1408.0 1157.858831 1136.119683 340.733713
10 1536.0 1195.657043 1167.718468 332.915206
11 1664.0 1207.879399 1189.899426 329.786972
12 1792.0 1230.815870 1198.938522 325.148237
13 1920.0 1262.632066 1227.638062 324.015053
14 2048.0 1267.102964 1244.418997 323.964303
15 2176.0 1237.686487 957.730105 325.602275
16 2304.0 1259.949514 998.896887 325.942816
17 2432.0 1269.163095 1033.808767 325.907522
18 2560.0 1289.635212 1067.200670 327.538818
19 2688.0 1296.258330 1100.951208 328.620297
20 2816.0 1310.418663 1124.871670 329.249392
21 2944.0 1317.155750 1147.104803 331.351009
22 3072.0 1324.398011 1170.516868 333.100111
23 3200.0 1340.366390 1169.371643 334.841402
24 3328.0 1347.172950 1199.214184 336.316971
25 3456.0 1347.760514 1223.137307 337.037815
26 3584.0 1362.155307 1242.285015 337.959922
27 3712.0 1365.493484 1260.091486 340.365623
28 3840.0 1367.256816 1284.609492 340.067124
29 3968.0 1371.577329 1297.517040 340.548659
30 4096.0 1382.868047 1314.365476 338.537517
31 4224.0 1330.608799 1273.767340 343.406837
32 4352.0 1340.744877 1298.498835 345.221623
33 4480.0 1342.439833 1315.847594 345.315166
34 4608.0 1356.865069 1333.747696 346.584171
35 4736.0 1357.688674 1344.857231 347.934884
36 4864.0 1368.607496 1363.129797 349.020782
37 4992.0 1369.836397 1373.485521 349.970426
38 5120.0 1377.704960 1388.604593 351.143594
39 5248.0 1373.898009 1352.177877 351.133694
40 5376.0 1373.541617 1369.068368 351.567181
41 5504.0 1376.542504 1372.477797 353.371911
42 5632.0 1391.723901 1396.560963 352.813959
43 5760.0 1394.458110 1406.584675 354.851486
44 5888.0 1389.453850 1413.580976 354.443016
45 6016.0 1402.479845 1413.567158 356.792191
46 6144.0 1408.258224 1420.565745 356.806450
47 6272.0 1412.107307 1400.299806 357.567158
48 6400.0 1411.947123 1409.931759 358.338950
49 6528.0 1416.289773 1419.071086 359.447925
50 6656.0 1415.305316 1433.894799 359.311200
51 6784.0 1417.893859 1430.714533 359.591245
52 6912.0 1423.587410 1437.680395 360.780803
53 7040.0 1414.719923 1450.212370 360.603686
54 7168.0 1422.184473 1449.774478 361.810060
55 7296.0 1423.085616 1086.855909 362.552323
56 7424.0 1430.063088 1095.659103 362.886964
57 7552.0 1423.776684 1111.656088 363.550295
58 7680.0 1427.860558 1120.869133 363.645226
59 7808.0 1433.065707 1133.450126 364.289655
60 7936.0 1433.949817 1141.201886 364.928276
61 8064.0 1429.935656 1152.136903 365.234802
62 8192.0 1429.637557 1149.870765 363.981518
63 8320.0 1384.127461 1117.448146 362.266815
64 8448.0 1381.693456 1125.435675 362.402980
65 8576.0 1385.757097 1128.787876 363.405873
66 8704.0 1383.150348 1134.673944 364.360417
67 8832.0 1398.098452 1132.714809 364.831567
68 8960.0 1383.778719 1137.135833 366.074490
69 9088.0 1394.381390 1134.308542 366.701941
70 9216.0 1407.059183 1142.613601 367.604997
71 9344.0 1389.330103 1422.350931 367.411346
72 9472.0 1400.804707 1428.663727 368.461745
73 9600.0 1398.967302 1431.269603 369.102115
74 9728.0 1394.465843 1441.434411 369.513179
75 9856.0 1400.970135 1439.068066 370.102273
76 9984.0 1394.213251 1448.962947 370.427513
77 10112.0 1404.414216 1453.625817 371.459231
78 10240.0 1407.812931 1464.750700 371.586504
79 10368.0 1417.341396 1461.264155 370.112489
80 10496.0 1405.468574 1466.430691 370.697412
81 10624.0 1405.869157 1466.424267 370.779356
82 10752.0 1396.925769 1473.176760 371.452425
83 10880.0 1388.155189 1477.137031 372.058562
84 11008.0 1422.063596 1476.099900 372.886245
85 11136.0 1418.511102 1483.006054 373.059114
86 11264.0 1415.307291 1489.404519 373.205625
87 11392.0 1421.515852 1486.380987 374.083120
88 11520.0 1407.813784 1494.151362 374.015381
89 11648.0 1420.312897 1500.019746 374.874991
90 11776.0 1433.157377 1499.672759 374.912809
91 11904.0 1433.460806 1505.101197 375.646318
92 12032.0 1414.247040 1509.848282 375.752422
93 12160.0 1409.691067 1517.099969 375.691335
94 12288.0 1429.344513 1420.003995 376.902178
95 12416.0 1438.311463 1397.545112 374.770257
96 12544.0 1442.326405 1397.559343 375.749041
97 12672.0 1433.880325 1395.162573 375.308313
In the above plot, we can see that:
- Triton is roughly 4x faster than the naive PyTorch implementation (naive_softmax), which confirms our expectation that the naive version performs no fusion and pays for every intermediate round-trip to DRAM.
- Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.