Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                    'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
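As a quick sanity check on that arithmetic, here is a tiny standalone calculation (illustrative only: it just re-uses the element counts from the comments in naive_softmax with an arbitrarily chosen example shape, it is not a measurement):

M, N = 4096, 1024                                          # example shape, chosen arbitrarily
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + elements written by naive_softmax
fused_traffic = 2 * M * N                                  # a fused kernel reads X once and writes Y once
print(naive_traffic / fused_traffic)                       # ~4.0, i.e. (8MN + 4M) / 2MN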
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle arbitrary input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
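A quick way to convince yourself that the -float('inf') padding used in the masked load above cannot change the result: exp(-inf) = 0, so the padded positions contribute nothing to the row maximum or to the denominator sum. The following standalone PyTorch check (not part of the kernel, just an illustration) pads a row of 5 elements to a block of 8:

import torch

row = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])   # n_cols = 5
BLOCK = 8                                       # next power of two >= 5
padded = torch.full((BLOCK, ), float('-inf'))
padded[:5] = row
ref = torch.softmax(row, dim=0)
out = torch.softmax(padded, dim=0)[:5]          # the padded tail is exactly 0
assert torch.allclose(ref, out)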
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number with WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy

        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
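The benchmark below reports effective bandwidth rather than raw runtime: every provider has to read the M x N input and write the M x N output at least once, so each run is credited with 2 * M * N * element_size bytes and that is divided by the measured time (this is what the gbps lambda in the code computes). For example, an illustrative calculation with a made-up timing:

M, N, bytes_per_elem = 4096, 1024, 4            # float32 input
ms = 0.05                                       # hypothetical runtime in milliseconds
gbps = 2 * M * N * bytes_per_elem * 1e-9 / (ms * 1e-3)
print(gbps)                                     # ~671 GB/s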
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance, effective bandwidth (GB/s) vs. number of columns N for Triton and Torch]
softmax-performance:
N Triton Torch
0 256.0 479.971415 702.139663
1 384.0 613.489806 804.742475
2 512.0 757.308365 923.473008
3 640.0 792.051811 961.830972
4 768.0 872.367065 1033.113093
5 896.0 932.895877 1067.405188
6 1024.0 989.438959 1129.162127
7 1152.0 1112.417652 614.940583
8 1280.0 1156.447348 669.547786
9 1408.0 1156.167100 725.404231
10 1536.0 1187.712733 782.654510
11 1664.0 1218.862046 815.667263
12 1792.0 1233.590604 854.532308
13 1920.0 1255.456581 908.113861
14 2048.0 1270.281625 958.344204
15 2176.0 1257.045464 976.763217
16 2304.0 1266.731133 1011.394436
17 2432.0 1290.127580 1053.298897
18 2560.0 1296.583449 1082.990627
19 2688.0 1314.693382 1104.309499
20 2816.0 1332.546636 1133.581467
21 2944.0 1327.624345 1169.870031
22 3072.0 1353.640841 1187.736957
23 3200.0 1353.220618 1192.174610
24 3328.0 1363.641333 1227.636505
25 3456.0 1373.281889 1247.479233
26 3584.0 1375.770007 1258.537236
27 3712.0 1378.913165 1270.766004
28 3840.0 1389.251747 1298.249544
29 3968.0 1390.697837 1313.064552
30 4096.0 1402.824581 1325.247436
31 4224.0 1330.811484 1160.612852
32 4352.0 1334.043866 1173.129910
33 4480.0 1349.110557 1184.574665
34 4608.0 1359.838974 1191.423570
35 4736.0 1363.032668 1197.348060
36 4864.0 1379.382583 1218.583994
37 4992.0 1375.980047 1237.385693
38 5120.0 1371.203812 1250.252107
39 5248.0 1375.119936 1257.147176
40 5376.0 1376.502713 1287.848724
41 5504.0 1381.458560 1298.379544
42 5632.0 1390.390942 1315.963848
43 5760.0 1391.360785 1321.663121
44 5888.0 1386.578913 1342.204717
45 6016.0 1399.833504 1354.020567
46 6144.0 1411.381737 1374.478180
47 6272.0 1417.616876 1376.650571
48 6400.0 1415.051396 1385.060043
49 6528.0 1412.316190 1392.812077
50 6656.0 1423.865574 1403.358983
51 6784.0 1412.134443 1417.108951
52 6912.0 1431.047456 1426.964999
53 7040.0 1424.970812 1430.834852
54 7168.0 1428.934183 1436.517200
55 7296.0 1435.660930 1442.566029
56 7424.0 1433.571928 1446.973560
57 7552.0 1428.555765 1456.004828
58 7680.0 1435.309231 1461.972743
59 7808.0 1434.510651 1466.803134
60 7936.0 1437.678716 1465.534748
61 8064.0 1436.084725 1473.054894
62 8192.0 1434.386853 1482.412155
63 8320.0 1391.891694 1399.981688
64 8448.0 1380.358548 1406.076925
65 8576.0 1394.888859 1396.690410
66 8704.0 1390.768797 1396.130286
67 8832.0 1382.413603 1404.592061
68 8960.0 1395.243342 1412.769341
69 9088.0 1411.527259 1417.466346
70 9216.0 1400.909976 1426.596086
71 9344.0 1402.486354 1426.510050
72 9472.0 1398.834902 1430.743271
73 9600.0 1393.100801 1433.705458
74 9728.0 1399.536009 1444.842892
75 9856.0 1415.974631 1442.096158
76 9984.0 1400.929888 1451.247395
77 10112.0 1410.581457 1452.466523
78 10240.0 1416.177667 1469.316875
79 10368.0 1417.322739 1460.591910
80 10496.0 1412.850281 1468.017150
81 10624.0 1412.680008 1468.076393
82 10752.0 1404.906930 1471.349144
83 10880.0 1398.881206 1481.058594
84 11008.0 1421.418331 1476.153870
85 11136.0 1422.667248 1483.317770
86 11264.0 1429.545510 1489.454579
87 11392.0 1414.557658 1488.380936
88 11520.0 1420.718429 1497.344137
89 11648.0 1428.797265 1500.208617
90 11776.0 1433.342691 1500.593549
91 11904.0 1444.444772 1507.989222
92 12032.0 1422.775520 1505.912506
93 12160.0 1416.091598 1512.305483
94 12288.0 1433.751886 1389.649613
95 12416.0 1449.014173 1390.461567
96 12544.0 1438.416199 1393.102407
97 12672.0 1446.080469 1391.976252
In the above plot, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than torch.softmax – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
Total running time of the script: (0 minutes 23.440 seconds)