Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading X and writing Y only once, i.e. moving only \(2MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
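To make that estimate concrete, here is a quick back-of-the-envelope check (a sketch only; the shape is an arbitrary example and the counts simply restate the element tallies above):

M, N = 4096, 1024  # arbitrary example shape
naive_elems = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # naive_softmax: reads + writes
fused_elems = 2 * M * N  # fused kernel: read X once, write Y once
print(naive_elems / fused_elems)  # ~4.0, i.e. the (8MN + 4M) / 2MN estimate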
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
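Before wiring the kernel up, it is worth convincing ourselves that the -inf padding is harmless: padded positions never win the max and contribute exp(-inf) = 0 to the denominator. A quick PyTorch check (a small illustrative sketch, not part of the kernel):

row = torch.randn(781, device=DEVICE)
padded = torch.full((1024, ), float('-inf'), device=DEVICE)  # 1024 = next power of two above 781
padded[:781] = row
# The softmax over the original columns is unchanged by the -inf padding.
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:781])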
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
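To make the occupancy computation above more tangible, here is the same arithmetic with hypothetical, roughly A100-like numbers (illustration only; the real values are queried from the device properties and the compiled kernel):

# Hypothetical values for illustration: 65536 registers and ~164 KB of shared memory per SM,
# 108 SMs, warps of 32 threads; assume the compiled kernel uses 32 registers and 16 KB of
# shared memory per program, with num_warps = 8.
regs_per_sm, smem_per_sm, num_sm, warp_size = 65536, 167936, 108, 32
n_regs_ex, size_smem_ex, num_warps_ex = 32, 16384, 8
occupancy_ex = regs_per_sm // (n_regs_ex * warp_size * num_warps_ex)  # register-limited: 8 programs per SM
occupancy_ex = min(occupancy_ex, smem_per_sm // size_smem_ex)         # shared-memory-limited: min(8, 10) = 8
print(num_sm * occupancy_ex)                                          # 864 persistent programs (before the n_rows cap)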
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
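As an optional extra check (not part of the original test), the naive PyTorch implementation defined earlier should agree as well:

# naive_softmax matches torch.softmax up to floating-point rounding, so it should also match the Triton kernel.
assert torch.allclose(y_triton, naive_softmax(x))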
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against that of torch.softmax.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch
0 256.0 464.688720 692.938580
1 384.0 654.126790 825.132325
2 512.0 818.300029 924.469771
3 640.0 871.229190 957.468235
4 768.0 962.878166 1018.703574
5 896.0 1011.995003 1073.750907
6 1024.0 1056.050178 1108.067636
7 1152.0 1105.463698 1031.631006
8 1280.0 1146.149830 1070.598375
9 1408.0 1159.351170 1105.330973
10 1536.0 1191.687465 1132.787967
11 1664.0 1217.097136 1173.909889
12 1792.0 1240.603501 1195.301412
13 1920.0 1253.470703 1203.031238
14 2048.0 1283.268784 1226.030734
15 2176.0 1242.632716 962.070986
16 2304.0 1248.049885 999.609239
17 2432.0 1274.926347 1041.340876
18 2560.0 1281.169611 1072.370670
19 2688.0 1296.814748 1094.364012
20 2816.0 1298.172169 1121.491896
21 2944.0 1319.227565 1146.803676
22 3072.0 1330.097481 1171.360905
23 3200.0 1332.876867 1171.775419
24 3328.0 1339.956344 1203.119119
25 3456.0 1358.062302 1221.891395
26 3584.0 1354.210716 1246.541922
27 3712.0 1360.723908 1268.156550
28 3840.0 1373.000826 1285.773025
29 3968.0 1372.416746 1300.309059
30 4096.0 1378.063787 1317.890421
31 4224.0 1330.286526 1292.293692
32 4352.0 1334.506708 1318.067018
33 4480.0 1339.301499 1335.662865
34 4608.0 1350.503850 1354.550351
35 4736.0 1350.860060 1365.431206
36 4864.0 1363.951501 1378.884124
37 4992.0 1362.394672 1391.994112
38 5120.0 1359.223254 1409.694996
39 5248.0 1370.180858 1363.435643
40 5376.0 1368.095426 1376.192149
41 5504.0 1377.760742 1391.979957
42 5632.0 1385.059956 1400.521743
43 5760.0 1392.100964 1420.562245
44 5888.0 1391.466418 1425.303077
45 6016.0 1392.184857 1442.840343
46 6144.0 1396.006571 1450.822235
47 6272.0 1402.748404 1413.408445
48 6400.0 1413.048644 1423.461539
49 6528.0 1398.573334 1432.581200
50 6656.0 1411.208438 1438.838713
51 6784.0 1412.941262 1432.763064
52 6912.0 1418.667368 1447.672571
53 7040.0 1413.166605 1465.075031
54 7168.0 1411.101564 1469.341518
55 7296.0 1422.599712 1085.623578
56 7424.0 1423.367495 1100.738953
57 7552.0 1420.856437 1112.232121
58 7680.0 1429.991917 1125.328773
59 7808.0 1420.944846 1137.191547
60 7936.0 1432.304065 1146.355675
61 8064.0 1429.253738 1152.925304
62 8192.0 1436.852184 1156.043964
63 8320.0 1382.700179 1111.722960
64 8448.0 1381.481276 1124.085935
65 8576.0 1385.031773 1122.551912
66 8704.0 1383.165577 1128.864932
67 8832.0 1387.880932 1129.187059
68 8960.0 1393.675899 1136.138317
69 9088.0 1406.009666 1131.415370
70 9216.0 1388.189750 1130.316407
71 9344.0 1397.507824 1420.154356
72 9472.0 1386.354354 1432.468978
73 9600.0 1392.742611 1430.037456
74 9728.0 1400.073400 1439.349554
75 9856.0 1396.557020 1441.031911
76 9984.0 1389.653683 1447.830828
77 10112.0 1400.360057 1455.332909
78 10240.0 1419.363776 1461.132698
79 10368.0 1405.464398 1464.213902
80 10496.0 1411.916134 1460.689565
81 10624.0 1396.051014 1467.058987
82 10752.0 1401.623079 1469.733368
83 10880.0 1398.659544 1482.285795
84 11008.0 1408.951122 1474.462044
85 11136.0 1418.395161 1486.416762
86 11264.0 1414.135126 1483.125812
87 11392.0 1412.345093 1487.860328
88 11520.0 1411.162452 1495.056745
89 11648.0 1419.753509 1498.729929
90 11776.0 1426.601391 1503.525685
91 11904.0 1426.265459 1509.781408
92 12032.0 1415.529415 1512.161779
93 12160.0 1401.842903 1516.788589
94 12288.0 1431.658606 1424.244959
95 12416.0 1432.576769 1394.620911
96 12544.0 1445.248479 1392.938551
97 12672.0 1430.663156 1395.926252
In the above plot, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than torch.softmax – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.