Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
```python
import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch.

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
```
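The docstring's claim that softmax is unchanged by subtracting the row maximum can be checked numerically. Below is a minimal pure-Python sketch (the helpers `softmax_unshifted` and `softmax_shifted` are hypothetical, written only for this illustration): without the shift, `math.exp` overflows for large inputs; with it, every exponent is at most 0 and the result stays finite.

```python
import math


def softmax_unshifted(row):
    # Direct formula exp(x_i) / sum_j exp(x_j); overflows for large x_i.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_shifted(row):
    # Subtract the row maximum first, so every exponent is <= 0 and exp() <= 1.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
# The common factor exp(-m) cancels between numerator and denominator,
# so both versions agree whenever the unshifted one does not overflow.
agree = all(math.isclose(a, b) for a, b in zip(softmax_unshifted(small), softmax_shifted(small)))

large = [1000.0, 1001.0, 1002.0]
try:
    softmax_unshifted(large)
    overflowed = False
except OverflowError:
    # math.exp(1000.0) exceeds the float range
    overflowed = True

print(agree, overflowed)  # True True
```

Because `large` and `small` differ by a constant, `softmax_shifted(large)` produces exactly the same probabilities as `softmax_shifted(small)`.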
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that reads x only once and does all the necessary computations on-chip.
Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
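The ~4x estimate can be reproduced by counting element transfers. A small sketch, assuming the fused kernel reads each element of x once and writes each element of y once (the function names are hypothetical):

```python
def naive_traffic(M, N):
    # Element counts taken from the comments in naive_softmax.
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    # A fused kernel reads x once and writes y once.
    return 2 * M * N


def traffic_ratio(M, N):
    # (8MN + 4M) / 2MN = 4 + 2/N, which approaches 4 as N grows.
    return naive_traffic(M, N) / fused_traffic(M, N)


print(traffic_ratio(4096, 1024))  # 4.001953125
```

The \(2/N\) term comes from the per-row maxima and sums, which is why it vanishes for wide rows.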
The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically, but it is still far from ideal.
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shape:
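To see why padding is safe, here is a pure-Python emulation of what one program does for a single row. It is a sketch of the kernel's logic, not of how Triton executes it, and the helper names are hypothetical: the row is padded to the next power of two (a stand-in for `triton.next_power_of_2`), masked lanes are filled with -inf as the kernel does via `other=-float('inf')`, and only unmasked lanes are kept on the way out.

```python
import math


def next_power_of_2(n):
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()


def padded_row_softmax(row):
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    # Lanes past the end of the row are masked off, like `col_offsets < n_cols`.
    mask = [i < n_cols for i in range(block_size)]
    # Masked lanes load -inf, like tl.load(..., other=-float('inf')).
    block = [row[i] if mask[i] else -math.inf for i in range(block_size)]
    row_max = max(block)                                  # -inf never wins the max
    numerator = [math.exp(v - row_max) for v in block]    # exp(-inf) == 0.0
    denominator = sum(numerator)                          # padded lanes add nothing
    out = [v / denominator for v in numerator]
    # Only unmasked lanes are written back, like tl.store(..., mask=mask).
    return [out[i] for i in range(block_size) if mask[i]], block_size


probs, block_size = padded_row_softmax([0.5, -1.0, 2.0])
print(block_size, len(probs))  # 4 3
```

Choosing -inf as the fill value is what makes both reductions ignore the padding: it cannot be the maximum, and its exponential contributes exactly zero to the sum.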
```python
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
```
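The loop `tl.range(row_start, n_rows, row_step)` makes each program persistent: program p handles rows p, p + num_programs, p + 2 * num_programs, and so on. A small Python sketch (the helpers `rows_for_program` and `check_coverage` are hypothetical) shows that this assignment covers every row exactly once, even when n_rows is not a multiple of the number of programs:

```python
def rows_for_program(pid, num_programs, n_rows):
    # Same iteration space as tl.range(row_start, n_rows, row_step)
    # with row_start = tl.program_id(0) and row_step = tl.num_programs(0).
    return list(range(pid, n_rows, num_programs))


def check_coverage(num_programs, n_rows):
    seen = []
    for pid in range(num_programs):
        seen.extend(rows_for_program(pid, num_programs, n_rows))
    # Every row index must appear exactly once across all programs.
    return sorted(seen) == list(range(n_rows))


print(rows_for_program(0, 4, 10))  # [0, 4, 8]
print(rows_for_program(3, 4, 10))  # [3, 7]
print(check_coverage(4, 10))       # True
```

Striding by the program count, rather than giving each program a contiguous chunk, keeps the work balanced to within one row per program.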
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
```python
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to
    # the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get its register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of general-purpose registers. On CDNA architectures this is half of
        # all registers available. However, this is not always the case. In most cases all registers can be
        # used as general-purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
```
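The non-HIP branch of the occupancy heuristic is plain integer arithmetic. The sketch below replays only that branch with made-up numbers; every value (SM count, register file size, per-thread register count, shared-memory sizes) is hypothetical and chosen to illustrate the calculation, not read from a real device or compiled kernel. The zero-shared-memory guard is an addition of this sketch, not part of the helper above.

```python
def persistent_grid_size(num_sm, num_regs, size_smem, warp_size,
                         n_regs, kernel_smem, num_warps, n_rows):
    # Programs per SM limited by registers: each program needs
    # n_regs registers/thread * warp_size threads/warp * num_warps warps.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # ... and by shared memory (guard added here for kernels that use none).
    if kernel_smem > 0:
        occupancy = min(occupancy, size_smem // kernel_smem)
    # Never launch more persistent programs than there are rows.
    return min(num_sm * occupancy, n_rows)


# Hypothetical device: 100 SMs, 65536 registers and 196608 bytes of shared memory per SM.
# Hypothetical kernel: 32 registers per thread, 16384 bytes of shared memory, 8 warps.
print(persistent_grid_size(100, 65536, 196608, 32, 32, 16384, 8, 4096))  # 800
```

Here registers allow 65536 // 8192 = 8 programs per SM and shared memory allows 196608 // 16384 = 12, so registers are the binding limit and the grid has 100 * 8 = 800 programs.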
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
```python
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, dim=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
```
As expected, the results match torch.softmax within the default tolerances of torch.allclose.
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows.
We will then compare its performance against torch.softmax; the naive_softmax defined above is not among the benchmarked providers.
```python
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_names`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `line_arg`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    # Effective bandwidth: x is read once and y is written once.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
```

Printed output of one run (throughput in GB/s for M = 4096 rows; the leftmost column is the row index):
softmax-performance:
N Triton Torch
0 256.0 470.978332 701.687163
1 384.0 657.046463 823.769029
2 512.0 807.623624 914.109169
3 640.0 882.785457 947.893530
4 768.0 966.310849 1020.015563
5 896.0 1016.609970 1062.339363
6 1024.0 1066.900700 1124.095905
7 1152.0 1108.738464 1030.960172
8 1280.0 1138.822542 1069.541463
9 1408.0 1164.249191 1105.239250
10 1536.0 1179.214278 1133.391458
11 1664.0 1206.280581 1172.741617
12 1792.0 1229.846580 1191.847293
13 1920.0 1250.989273 1193.521627
14 2048.0 1279.660788 1224.995187
15 2176.0 1246.470077 963.575395
16 2304.0 1250.444109 1002.801480
17 2432.0 1274.255923 1042.264524
18 2560.0 1287.134793 1070.865953
19 2688.0 1296.740664 1101.098457
20 2816.0 1299.414948 1121.499955
21 2944.0 1323.460189 1150.705621
22 3072.0 1328.255127 1171.606015
23 3200.0 1336.806082 1179.513950
24 3328.0 1346.874872 1203.505746
25 3456.0 1352.139083 1221.506282
26 3584.0 1348.640976 1247.681154
27 3712.0 1362.861991 1264.200202
28 3840.0 1376.467408 1283.433035
29 3968.0 1371.757957 1301.804441
30 4096.0 1385.874169 1316.756551
31 4224.0 1337.132338 1295.909618
32 4352.0 1337.910554 1319.074892
33 4480.0 1352.801533 1333.709298
34 4608.0 1360.171779 1356.057065
35 4736.0 1361.794351 1367.855576
36 4864.0 1376.078475 1384.464041
37 4992.0 1376.901102 1396.357132
38 5120.0 1373.494378 1405.912769
39 5248.0 1378.435147 1367.206376
40 5376.0 1378.126308 1382.576327
41 5504.0 1387.285763 1392.836125
42 5632.0 1388.112728 1409.269053
43 5760.0 1396.261620 1423.088800
44 5888.0 1395.452816 1425.882295
45 6016.0 1396.320079 1433.370366
46 6144.0 1401.437664 1438.605996
47 6272.0 1414.430063 1410.725117
48 6400.0 1416.906908 1419.780491
49 6528.0 1411.271244 1432.799779
50 6656.0 1413.156172 1443.615925
51 6784.0 1413.261177 1451.404443
52 6912.0 1424.422773 1452.297443
53 7040.0 1421.637812 1454.983874
54 7168.0 1417.603576 1461.404738
55 7296.0 1429.798839 1084.869735
56 7424.0 1431.221102 1100.295661
57 7552.0 1425.874625 1113.697017
58 7680.0 1435.798310 1126.375496
59 7808.0 1425.099180 1135.422377
60 7936.0 1435.708060 1144.793951
61 8064.0 1435.586981 1153.504373
62 8192.0 1436.923848 1156.269126
63 8320.0 1379.171171 1114.501104
64 8448.0 1379.030784 1123.865804
65 8576.0 1390.105003 1124.515235
66 8704.0 1380.243375 1129.949517
67 8832.0 1381.072715 1130.295460
68 8960.0 1394.414844 1136.036844
69 9088.0 1410.892251 1130.223800
70 9216.0 1397.362322 1128.744948
71 9344.0 1400.354490 1421.210205
72 9472.0 1393.530763 1433.360631
73 9600.0 1397.761406 1431.225555
74 9728.0 1405.523492 1437.142920
75 9856.0 1396.429582 1440.369504
76 9984.0 1401.238070 1444.727769
77 10112.0 1401.803841 1452.943359
78 10240.0 1423.042407 1466.950202
79 10368.0 1409.037797 1462.317338
80 10496.0 1415.249472 1463.784458
81 10624.0 1400.139297 1464.812037
82 10752.0 1404.510763 1472.090890
83 10880.0 1404.918279 1477.783502
84 11008.0 1411.919631 1479.354753
85 11136.0 1424.070611 1480.523601
86 11264.0 1421.457930 1487.843155
87 11392.0 1420.358185 1485.709549
88 11520.0 1421.296138 1496.295659
89 11648.0 1432.400225 1499.858134
90 11776.0 1435.562742 1501.171293
91 11904.0 1437.085261 1509.589623
92 12032.0 1424.839440 1510.485480
93 12160.0 1403.574468 1513.071395
94 12288.0 1440.074017 1424.375949
95 12416.0 1441.300454 1396.538354
96 12544.0 1453.757131 1389.912639
97 12672.0 1438.051017 1393.591433
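The throughput columns are derived from runtime by the `gbps` lambda in the benchmark, which counts one read and one write of every element of x. A worked sketch of that formula with a hypothetical runtime (0.032 ms, not taken from the table) for a 4096 x 1024 float32 input (4 bytes per element):

```python
def effective_bandwidth_gbps(M, N, element_size, ms):
    # Same formula as the benchmark's `gbps` lambda:
    # bytes moved = 2 * numel * element_size (read x once, write y once).
    bytes_moved = 2 * M * N * element_size
    return bytes_moved * 1e-9 / (ms * 1e-3)


# 2 * 4096 * 1024 * 4 = 33,554,432 bytes in a hypothetical 0.032 ms.
print(effective_bandwidth_gbps(4096, 1024, 4, 0.032))  # about 1048.6
```

Because the formula counts only the minimum traffic, it reports effective bandwidth: an implementation that moves more bytes than necessary, such as naive_softmax, would be credited with the same byte count and so show a proportionally lower number.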
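In the above results, we can see that:
torch.softmax is faster for short rows (N ≤ 1024), by about 50% at N = 256.
Triton is ahead for 1152 ≤ N ≤ 4608, by the widest margin (up to about 29%) from N = 2176 to 3072, where torch.softmax's throughput drops.
torch.softmax is ahead for most of 4736 ≤ N ≤ 7168 and 9344 ≤ N ≤ 12160, by at most about 8%, while Triton is ahead from N = 7296 to 9216, where torch.softmax falls to roughly 1,100 GB/s, and again for N ≥ 12288.
naive_softmax was not benchmarked here, so the ~4x advantage of fusion over it remains the theoretical estimate from the Motivations section rather than a measured result.
Beyond raw speed, the Triton kernel is easy to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.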
Total running time of the script: (0 minutes 23.372 seconds)