Fused Softmax¶
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations¶
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch.

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only \(MN\) elements each, so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically
but, in practice, it is still far from ideal.
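As a quick sanity check of this arithmetic, the short snippet below evaluates both traffic counts for a concrete shape (purely illustrative: the shape and the float32 element size are arbitrary choices, not part of the tutorial code):

# Illustrative only: DRAM traffic of naive vs. fused softmax for an arbitrary shape, assuming float32 (4 bytes/element)
M, N = 4096, 1024
naive_elems = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + elements written by naive_softmax
fused_elems = 2 * M * N                                  # fused kernel: read X once, write Y once
print(f"naive softmax moves ~{naive_elems * 4 / 1e6:.1f} MB")
print(f"fused softmax moves ~{fused_elems * 4 / 1e6:.1f} MB")
print(f"theoretical speed-up: ~{naive_elems / fused_elems:.2f}x")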
Compute Kernel¶
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
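As an aside, the masked load above fills out-of-bounds positions with -inf. The small check below (an illustration, not part of the kernel; it reuses the imports from the top of this script and an arbitrary row length) shows why this is the right padding value: padded lanes do not affect the row maximum, and exp(-inf) is 0, so they contribute nothing to the normalizing sum either.

# Illustration only: padding a row to the next power of two with -inf does not change its softmax
row = torch.randn(781)
BLOCK = triton.next_power_of_2(781)  # 1024
padded = torch.full((BLOCK,), -float('inf'))
padded[:781] = row
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:781])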
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to
    # the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
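To make the occupancy arithmetic above more concrete, here is the same calculation spelled out with made-up numbers (every value below is hypothetical; in the real helper they come from the device-property query and from the warmed-up kernel’s metadata):

# Purely illustrative numbers -- the real ones come from `properties`, `kernel.n_regs` and `kernel.metadata.shared`
example_NUM_REGS = 65536    # registers per multiprocessor (hypothetical)
example_SIZE_SMEM = 101376  # shared memory per multiprocessor, in bytes (hypothetical)
example_WARP_SIZE = 32
example_NUM_SM = 108        # multiprocessors on the device (hypothetical)
example_n_regs = 32         # registers per thread used by the compiled kernel (hypothetical)
example_size_smem = 4096    # shared memory per program, in bytes (hypothetical)
example_num_warps = 8

# How many programs fit on one multiprocessor, limited by registers ...
occ_regs = example_NUM_REGS // (example_n_regs * example_WARP_SIZE * example_num_warps)  # -> 8
# ... and by shared memory.
occ_smem = example_SIZE_SMEM // example_size_smem                                        # -> 24
example_occupancy = min(occ_regs, occ_smem)                                              # -> 8
# One persistent program per "slot" on each multiprocessor, capped later by n_rows.
example_num_programs = example_NUM_SM * example_occupancy                                # -> 864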
Unit Test¶
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
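Optionally, the kernel can also be cross-checked against the naive_softmax reference defined at the beginning of this tutorial (this extra assert is not part of the original test; the tolerance is loosened slightly because tl.exp is approximate):

# Optional extra check against the naive PyTorch reference defined above (loose tolerance for the approximate exp)
assert torch.allclose(softmax(x), naive_softmax(x), atol=1e-6)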
Benchmark¶
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against torch.softmax; a sketch at the end of this tutorial shows how to also include the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)
benchmark.run(show_plots=True, print_data=True)
softmax-performance:
N Triton Torch
0 256.0 468.525462 706.593640
1 384.0 612.033628 820.719852
2 512.0 762.467266 918.101833
3 640.0 811.147301 961.043509
4 768.0 886.393504 1020.536303
5 896.0 952.535379 1075.271109
6 1024.0 1009.066185 1115.058111
7 1152.0 1099.714890 614.148310
8 1280.0 1152.360083 667.119022
9 1408.0 1157.898410 726.046832
10 1536.0 1186.686291 779.237050
11 1664.0 1221.174560 810.666662
12 1792.0 1233.178123 860.123075
13 1920.0 1248.900661 910.654618
14 2048.0 1278.688234 960.565219
15 2176.0 1241.460066 974.425788
16 2304.0 1249.326045 1012.433184
17 2432.0 1282.541506 1058.271199
18 2560.0 1282.890237 1089.501329
19 2688.0 1297.276822 1101.514696
20 2816.0 1305.926868 1130.480108
21 2944.0 1314.874736 1169.750034
22 3072.0 1335.013006 1184.625721
23 3200.0 1333.402456 1192.356121
24 3328.0 1339.395946 1228.571170
25 3456.0 1354.827941 1246.445136
26 3584.0 1356.554660 1259.200679
27 3712.0 1367.376117 1274.364256
28 3840.0 1370.577873 1298.164587
29 3968.0 1377.633261 1314.526756
30 4096.0 1382.684029 1328.490755
31 4224.0 1340.454745 1160.213260
32 4352.0 1332.913820 1176.221686
33 4480.0 1351.998559 1185.319171
34 4608.0 1362.085333 1197.152115
35 4736.0 1364.513792 1201.355460
36 4864.0 1376.758264 1223.927407
37 4992.0 1368.632676 1235.950819
38 5120.0 1374.251484 1252.467141
39 5248.0 1375.614748 1260.808643
40 5376.0 1379.294446 1286.907848
41 5504.0 1380.427699 1300.938687
42 5632.0 1386.688725 1316.266628
43 5760.0 1395.847816 1323.983693
44 5888.0 1388.921393 1342.447678
45 6016.0 1399.401344 1354.084438
46 6144.0 1409.490053 1376.076502
47 6272.0 1416.754813 1373.426830
48 6400.0 1418.377589 1389.615622
49 6528.0 1414.806780 1397.413756
50 6656.0 1422.453318 1403.039382
51 6784.0 1412.176698 1416.345928
52 6912.0 1428.399821 1426.025129
53 7040.0 1424.151392 1432.381591
54 7168.0 1427.671537 1436.063840
55 7296.0 1433.001417 1444.299869
56 7424.0 1430.516974 1444.240501
57 7552.0 1430.051127 1453.501460
58 7680.0 1435.939215 1462.386600
59 7808.0 1436.211969 1464.431365
60 7936.0 1436.482017 1465.563204
61 8064.0 1443.378996 1472.689937
62 8192.0 1438.849835 1483.776844
63 8320.0 1383.442444 1400.922030
64 8448.0 1372.289864 1402.527918
65 8576.0 1387.675346 1397.870853
66 8704.0 1384.611369 1402.143431
67 8832.0 1378.601729 1404.727793
68 8960.0 1391.895847 1413.714786
69 9088.0 1402.579864 1417.100907
70 9216.0 1396.239680 1425.894734
71 9344.0 1395.964542 1421.808900
72 9472.0 1396.542217 1434.766971
73 9600.0 1391.339223 1432.800387
74 9728.0 1400.931261 1441.334590
75 9856.0 1407.491223 1445.324325
76 9984.0 1396.718897 1447.569104
77 10112.0 1410.516498 1458.217184
78 10240.0 1411.098022 1466.663896
79 10368.0 1409.218281 1466.624640
80 10496.0 1411.795365 1468.286548
81 10624.0 1410.777555 1466.253615
82 10752.0 1400.341882 1474.331605
83 10880.0 1397.879109 1480.252035
84 11008.0 1417.128770 1481.411705
85 11136.0 1417.145322 1483.387262
86 11264.0 1423.102333 1488.682387
87 11392.0 1410.360430 1490.513388
88 11520.0 1418.324676 1494.304419
89 11648.0 1419.401204 1501.700954
90 11776.0 1425.657314 1500.849330
91 11904.0 1436.706558 1508.994477
92 12032.0 1417.533839 1509.353186
93 12160.0 1419.888050 1513.509235
94 12288.0 1427.373825 1392.649035
95 12416.0 1448.379662 1389.611285
96 12544.0 1438.620630 1395.112607
97 12672.0 1444.140489 1391.291386
In the above plot, we can see that:

Triton is noticeably faster than torch.softmax for a wide range of column counts, and roughly on par with it for the largest ones – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.

As the memory-traffic analysis above predicts, a fused kernel like this one can be expected to be roughly 4x faster than the naive implementation, since the Torch JIT does not perform this kind of fusion on its own.
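If you would also like to see the naive PyTorch version on the same kind of plot, one option is to add a third provider to the benchmark. The sketch below follows the structure of the benchmark above; the name benchmark_with_naive and the 'naive' provider are additions of this sketch, not part of the original script:

# Sketch: extending the benchmark with the naive PyTorch reference as a third line
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[128 * i for i in range(2, 100)],
        line_arg='provider',
        line_vals=['triton', 'torch', 'naive'],
        line_names=["Triton", "Torch", "Naive"],
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],
        ylabel="GB/s",
        plot_name="softmax-performance-with-naive",
        args={'M': 4096},
    ))
def benchmark_with_naive(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

# Uncomment to run the extended benchmark:
# benchmark_with_naive.run(show_plots=True, print_data=True)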
Total running time of the script: (0 minutes 23.151 seconds)