Fused Softmax¶
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations¶
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x)
for \(x \in \mathbb{R}^{M \times N}\)
requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements.
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only \(MN\) elements each way (\(2MN\) in total), so we could
expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)).
The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
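To make this arithmetic concrete, here is a quick back-of-the-envelope check; the matrix shape is just an example and only element counts are compared (caches are ignored):
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes of naive_softmax
fused_traffic = 2 * M * N  # fused kernel: read X once, write Y once
print(f"expected speed-up: ~{naive_traffic / fused_traffic:.2f}x")  # ~4x when N >> 1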
Compute Kernel¶
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
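As a supplementary PyTorch-only sanity check (not part of the kernel above), we can confirm that padding a row to a power-of-two length with -inf leaves the softmax of the real elements unchanged, since the padded entries contribute exp(-inf) = 0 to the denominator:
row = torch.randn(1000)
BLOCK = triton.next_power_of_2(row.numel())  # 1024
padded = torch.full((BLOCK,), float('-inf'))
padded[:row.numel()] = row
# softmax of the padded row, restricted to the real elements, matches the unpadded softmax
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:row.numel()])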
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
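For reference, you can print the queried properties to see what your particular GPU reports (the values naturally vary from device to device):
print(f"SMs: {NUM_SM}, registers/SM: {NUM_REGS}, shared memory/SM: {SIZE_SMEM} B, warp size: {WARP_SIZE}")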
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number with WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
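To get a feel for the occupancy heuristic above, here is a small worked example with made-up numbers; the real values come from the driver properties and from the compiled kernel's metadata:
# Hypothetical per-SM limits and per-kernel resource usage (illustrative only).
NUM_REGS_EX, WARP_SIZE_EX, SIZE_SMEM_EX, NUM_SM_EX = 65536, 32, 166912, 108
n_regs_ex, num_warps_ex, size_smem_ex = 32, 8, 16384
# Register limit: how many programs of `num_warps_ex` warps fit in one SM's register file.
occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 65536 // 8192 = 8
# Shared-memory limit: each program also needs `size_smem_ex` bytes of shared memory.
occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)  # min(8, 10) = 8
print("programs per SM:", occupancy_ex, "| total persistent programs:", NUM_SM_EX * occupancy_ex)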
Unit Test¶
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
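As an extra, optional check we can also compare against the naive_softmax implementation from the beginning of the tutorial, using a slightly looser absolute tolerance than the default just to be safe:
assert torch.allclose(y_triton, naive_softmax(x), atol=1e-6), "mismatch vs. naive_softmax"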
Benchmark¶
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against the native torch.softmax. (The naive_softmax defined above is not included in the benchmark; a quick way to time it separately is sketched after the results.)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
N Triton Torch
0 256.0 482.498358 711.511792
1 384.0 612.573735 808.472705
2 512.0 758.495050 923.641285
3 640.0 790.603195 959.212554
4 768.0 871.435116 1031.371211
5 896.0 940.857273 1070.966370
6 1024.0 989.547335 1122.259997
7 1152.0 1096.626914 616.809145
8 1280.0 1138.241978 669.410766
9 1408.0 1169.432455 728.789692
10 1536.0 1196.443618 784.047309
11 1664.0 1211.041930 814.812512
12 1792.0 1240.661445 865.213254
13 1920.0 1257.413544 909.936735
14 2048.0 1273.694749 964.366886
15 2176.0 1259.845494 979.716721
16 2304.0 1269.187113 1018.440821
17 2432.0 1292.623351 1055.037690
18 2560.0 1301.035556 1088.540121
19 2688.0 1317.901162 1101.029024
20 2816.0 1319.432039 1132.792641
21 2944.0 1323.424460 1167.601032
22 3072.0 1346.964776 1180.961447
23 3200.0 1349.377686 1193.279129
24 3328.0 1356.050334 1222.964736
25 3456.0 1374.131214 1247.809509
26 3584.0 1379.397621 1264.466261
27 3712.0 1381.918164 1270.351388
28 3840.0 1388.263698 1304.229462
29 3968.0 1391.378186 1314.281169
30 4096.0 1398.618657 1327.886263
31 4224.0 1336.338279 1158.148950
32 4352.0 1335.215145 1174.842390
33 4480.0 1351.342805 1181.971415
34 4608.0 1361.834294 1197.288075
35 4736.0 1361.353857 1200.956438
36 4864.0 1374.497889 1220.513747
37 4992.0 1368.643927 1236.516519
38 5120.0 1372.897146 1253.604190
39 5248.0 1376.623384 1257.893014
40 5376.0 1378.088534 1284.726063
41 5504.0 1380.019527 1302.427824
42 5632.0 1388.763714 1315.360764
43 5760.0 1394.556124 1326.662645
44 5888.0 1394.383023 1341.235638
45 6016.0 1403.664729 1353.180854
46 6144.0 1407.436019 1376.194895
47 6272.0 1413.194071 1372.058381
48 6400.0 1415.938479 1389.893773
49 6528.0 1418.507806 1395.789421
50 6656.0 1421.536117 1402.685194
51 6784.0 1415.175974 1413.458355
52 6912.0 1424.702832 1425.027544
53 7040.0 1424.315563 1432.601697
54 7168.0 1427.985948 1432.608231
55 7296.0 1430.571219 1444.109612
56 7424.0 1432.281141 1447.565928
57 7552.0 1428.105407 1452.964603
58 7680.0 1435.930638 1460.016635
59 7808.0 1433.476972 1463.888915
60 7936.0 1436.901944 1466.849972
61 8064.0 1437.661622 1476.050106
62 8192.0 1439.143527 1486.742421
63 8320.0 1390.721845 1403.209510
64 8448.0 1384.042960 1404.619595
65 8576.0 1393.183027 1394.430819
66 8704.0 1391.872494 1398.063166
67 8832.0 1382.844820 1403.086522
68 8960.0 1395.034310 1411.388368
69 9088.0 1408.734709 1416.608377
70 9216.0 1408.445694 1426.463240
71 9344.0 1401.157296 1424.317926
72 9472.0 1400.749716 1433.610932
73 9600.0 1396.245319 1434.113743
74 9728.0 1403.927134 1443.520543
75 9856.0 1414.144503 1440.300241
76 9984.0 1399.650339 1453.908146
77 10112.0 1412.445473 1453.935987
78 10240.0 1420.807912 1469.366252
79 10368.0 1414.477218 1464.718900
80 10496.0 1410.961565 1467.770788
81 10624.0 1415.024626 1471.596971
82 10752.0 1404.459460 1472.029760
83 10880.0 1402.048748 1480.074671
84 11008.0 1420.339111 1478.079955
85 11136.0 1423.227181 1487.202802
86 11264.0 1429.997443 1486.543637
87 11392.0 1415.355421 1491.343381
88 11520.0 1420.799226 1496.770146
89 11648.0 1428.473755 1499.078455
90 11776.0 1432.495562 1501.496171
91 11904.0 1445.336890 1506.354697
92 12032.0 1421.376653 1509.076700
93 12160.0 1415.346263 1512.239286
94 12288.0 1434.695978 1393.017085
95 12416.0 1446.796202 1390.780577
96 12544.0 1442.861623 1393.799419
97 12672.0 1447.425889 1392.905132
In the above plot, we can see that:
- Triton is noticeably faster than torch.softmax for a wide range of column counts, and roughly on par with it for the largest ones, in addition to being easier to read, understand and maintain.
- Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
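Finally, if you are curious about the unfused baseline, you can time naive_softmax directly. A quick sketch (not reflected in the results table above; the shape is arbitrary):
x_probe = torch.randn(4096, 4096, device='cuda', dtype=torch.float32)
ms_naive = triton.testing.do_bench(lambda: naive_softmax(x_probe))  # unfused PyTorch baseline
ms_fused = triton.testing.do_bench(lambda: softmax(x_probe))  # our persistent Triton kernel
print(f"naive: {ms_naive:.3f} ms | fused: {ms_fused:.3f} ms")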
Total running time of the script: (0 minutes 23.439 seconds)