Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading \(MN\) elements and writing back \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, in practice, it is still far from ideal.
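To make the arithmetic concrete, here is a small sketch (the helper name is ours, purely for illustration) that evaluates the element counts above for an example matrix size:
# Illustrative only: tally the DRAM traffic of naive_softmax versus an ideal fused kernel.
def traffic_ratio(M, N):
    naive_reads = 5 * M * N + 2 * M
    naive_writes = 3 * M * N + 2 * M
    fused_reads = M * N   # read X once
    fused_writes = M * N  # write Y once
    return (naive_reads + naive_writes) / (fused_reads + fused_writes)

print(traffic_ratio(4096, 4096))  # ~4.0 (exactly 4 + 2/N)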
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes back the results to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
        # BLOCK_SIZE is the smallest power of two greater than or equal to n_cols, so each
        # row fits in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
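The mask together with the -float('inf') fill value is what makes the power-of-two padding safe: padded positions satisfy exp(-inf) = 0, so they change neither the row maximum nor the sum. The quick PyTorch check below (ours, not part of the tutorial) illustrates this invariance for an irregular row length:
# Padding a row with -inf and applying softmax leaves the original entries unchanged.
row = torch.randn(781)
padded = torch.cat([row, torch.full((1024 - 781,), float('-inf'))])
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:781])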
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
n_rows, n_cols = x.shape
    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8
# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2
# Allocate output
y = torch.empty_like(x)
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2
        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
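        # On CUDA: each program uses num_warps * WARP_SIZE threads and each thread uses n_regs
        # registers, so the register file (NUM_REGS per SM) bounds how many programs fit on an SM.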
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
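    # The shared memory available per SM is a second limit on occupancy.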
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](
y,
x,
x.stride(0),
y.stride(0),
n_rows,
n_cols,
)
return y
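For reference, a much simpler launcher also works: start one program per row and let the loop inside the kernel run exactly once. This variant is our illustration only; it ignores occupancy and is not how the tutorial proceeds:
def softmax_simple(x):
    # One program per row; no occupancy tuning, no persistent programs.
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                               BLOCK_SIZE=BLOCK_SIZE, num_stages=2)
    return y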
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
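As an extra sanity check (this assert is ours, not part of the original test), the naive PyTorch implementation defined earlier should agree as well:
assert torch.allclose(y_triton, naive_softmax(x))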
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows.
We will then compare its performance against torch.softmax; the naive_softmax defined above is omitted from the benchmark, since the memory-traffic analysis already tells us it cannot be competitive with a fused kernel.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
line_names=[
"Triton",
"Torch",
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: softmax(x))
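    # Each element of x is read once and written once by softmax, hence the factor of 2;
    # (bytes * 1e-9) / (ms * 1e-3) converts the measurement to GB/s.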
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)
softmax-performance:
N Triton Torch
0 256.0 479.770468 706.593695
1 384.0 615.194080 802.647814
2 512.0 749.793535 913.933299
3 640.0 818.737095 945.501049
4 768.0 886.409087 1024.809768
5 896.0 942.563271 1073.070865
6 1024.0 1009.972601 1110.411882
7 1152.0 1109.736593 611.205032
8 1280.0 1149.246622 666.327066
9 1408.0 1157.357624 726.057777
10 1536.0 1190.335967 779.857145
11 1664.0 1221.566308 813.264965
12 1792.0 1238.883517 855.313901
13 1920.0 1254.335500 907.627576
14 2048.0 1271.282413 958.374183
15 2176.0 1257.844709 978.227191
16 2304.0 1271.626941 1006.718727
17 2432.0 1292.767475 1052.760118
18 2560.0 1303.757909 1084.434702
19 2688.0 1312.450787 1103.231752
20 2816.0 1328.292570 1132.581209
21 2944.0 1327.909870 1164.526753
22 3072.0 1347.945645 1181.496918
23 3200.0 1348.779582 1190.596340
24 3328.0 1360.227059 1223.920806
25 3456.0 1375.430883 1248.010834
26 3584.0 1378.427657 1264.011794
27 3712.0 1388.429405 1268.832684
28 3840.0 1382.800219 1296.894993
29 3968.0 1391.941397 1313.533013
30 4096.0 1393.499116 1323.582592
31 4224.0 1335.517467 1161.617607
32 4352.0 1334.017367 1172.854445
33 4480.0 1351.983075 1183.345587
34 4608.0 1360.569315 1192.949501
35 4736.0 1358.615548 1196.273097
36 4864.0 1376.250879 1220.194342
37 4992.0 1369.804740 1239.342870
38 5120.0 1377.846766 1252.855515
39 5248.0 1372.630100 1257.878069
40 5376.0 1377.091013 1285.385943
41 5504.0 1374.677851 1300.987536
42 5632.0 1388.559586 1312.186073
43 5760.0 1395.323014 1321.232161
44 5888.0 1385.634133 1343.129357
45 6016.0 1400.878236 1351.783092
46 6144.0 1409.504116 1374.706138
47 6272.0 1412.582123 1374.830129
48 6400.0 1416.308105 1387.112226
49 6528.0 1414.357232 1392.843700
50 6656.0 1421.300347 1403.243546
51 6784.0 1410.686937 1414.572790
52 6912.0 1425.062816 1422.551799
53 7040.0 1420.179480 1432.863440
54 7168.0 1424.120120 1432.807212
55 7296.0 1433.424482 1445.159856
56 7424.0 1428.817493 1445.484714
57 7552.0 1425.279046 1457.395667
58 7680.0 1435.609372 1462.367202
59 7808.0 1435.521339 1464.095503
60 7936.0 1432.813211 1468.242783
61 8064.0 1436.259778 1474.360094
62 8192.0 1436.594199 1485.192349
63 8320.0 1387.616716 1401.658167
64 8448.0 1375.007926 1404.878882
65 8576.0 1396.826860 1395.295374
66 8704.0 1393.167748 1400.171516
67 8832.0 1383.429424 1406.578434
68 8960.0 1394.484567 1410.950535
69 9088.0 1408.025848 1414.321893
70 9216.0 1403.744547 1423.150707
71 9344.0 1396.534523 1425.285846
72 9472.0 1400.110110 1434.254558
73 9600.0 1397.109482 1434.560084
74 9728.0 1402.462574 1443.775893
75 9856.0 1417.608266 1443.670254
76 9984.0 1395.868631 1452.966168
77 10112.0 1411.732670 1456.107865
78 10240.0 1421.074575 1464.200791
79 10368.0 1414.810323 1465.693470
80 10496.0 1413.422695 1468.669258
81 10624.0 1416.583788 1467.789451
82 10752.0 1403.638865 1474.370820
83 10880.0 1396.411411 1483.058775
84 11008.0 1416.968751 1476.007517
85 11136.0 1423.401402 1485.623799
86 11264.0 1427.614400 1485.060508
87 11392.0 1417.122457 1491.116666
88 11520.0 1420.941401 1496.517085
89 11648.0 1423.856096 1496.308209
90 11776.0 1431.926034 1503.709021
91 11904.0 1445.548237 1504.740838
92 12032.0 1422.829078 1508.462405
93 12160.0 1418.697132 1510.021879
94 12288.0 1433.864941 1393.731558
95 12416.0 1447.475169 1392.292481
96 12544.0 1443.415647 1393.140260
97 12672.0 1447.414912 1392.943009
In the above plot, we can see that:
- For most of the measured range, Triton's throughput is on par with or higher than that of torch.softmax (noticeably higher for intermediate row lengths), in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
- The naive_softmax defined above is not shown; the memory-traffic analysis in the Motivations section already predicts an unfused implementation to be roughly 4x slower than a fused kernel.