Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"

def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
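To make this estimate concrete, here is the same arithmetic for a hypothetical \(M = 4096\), \(N = 1024\) input (the shape is chosen purely for illustration):

M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements moved by naive_softmax (reads + writes)
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)                       # ~4.0, the theoretical speed-up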
Compute Kernel
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes back the results to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
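Before wrapping the kernel, it may help to convince yourself that padding a row with -inf (as the masked tl.load above does) leaves its softmax unchanged: padded positions never win the max and contribute exp(-inf) = 0 to the sum. A quick plain-PyTorch check (not part of the tutorial script) is:

row = torch.tensor([1.0, 2.0, 3.0])
padded = torch.cat([row, torch.full((5,), -float('inf'))])  # "pad" to a power-of-two length of 8
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:3])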
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number with WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
Unit Test
We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
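The assertion above uses torch.allclose, so "identical" here means equal to within floating-point tolerance (recall that tl.exp is a fast, approximate exponential). If you are curious about the actual gap, you can inspect it directly using the tensors from the test above:

print(f"max absolute difference: {(y_triton - y_torch).abs().max().item():.2e}")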
Benchmark
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax. (The naive_softmax defined above is not included in the plot below; a sketch for adding it as a third provider appears at the end of this section.)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
N Triton Torch
0 256.0 475.072072 695.065373
1 384.0 610.079401 811.471723
2 512.0 754.342040 919.712541
3 640.0 789.265309 962.246259
4 768.0 874.759796 1017.241938
5 896.0 929.108709 1076.338957
6 1024.0 998.736970 1115.438764
7 1152.0 1116.458513 610.971093
8 1280.0 1145.734286 669.488146
9 1408.0 1164.656389 725.404316
10 1536.0 1183.981189 780.113270
11 1664.0 1216.760870 813.923431
12 1792.0 1241.212053 855.583078
13 1920.0 1247.110758 908.744014
14 2048.0 1273.713210 959.642350
15 2176.0 1267.868850 976.754644
16 2304.0 1265.367204 1009.461646
17 2432.0 1296.738518 1058.765063
18 2560.0 1307.478463 1089.619289
19 2688.0 1316.725541 1108.282126
20 2816.0 1320.519372 1128.686120
21 2944.0 1328.761729 1166.223767
22 3072.0 1354.626436 1184.251325
23 3200.0 1351.696997 1195.375616
24 3328.0 1349.881360 1222.289058
25 3456.0 1373.813758 1251.674668
26 3584.0 1379.000592 1261.875056
27 3712.0 1380.726388 1270.162908
28 3840.0 1383.564658 1299.695441
29 3968.0 1386.354177 1316.159411
30 4096.0 1395.362556 1324.336015
31 4224.0 1339.441650 1161.949714
32 4352.0 1336.112162 1176.210802
33 4480.0 1352.418767 1182.050267
34 4608.0 1360.456061 1197.728907
35 4736.0 1361.613092 1197.270977
36 4864.0 1373.932199 1224.471962
37 4992.0 1372.907083 1233.295258
38 5120.0 1372.078005 1253.689897
39 5248.0 1374.711405 1258.907336
40 5376.0 1379.053395 1284.359818
41 5504.0 1379.105309 1299.663321
42 5632.0 1392.329967 1317.736600
43 5760.0 1395.696932 1322.492882
44 5888.0 1386.137747 1342.631723
45 6016.0 1398.575315 1356.818526
46 6144.0 1410.758058 1372.205820
47 6272.0 1415.495495 1375.174746
48 6400.0 1419.705022 1386.753307
49 6528.0 1410.242515 1393.505541
50 6656.0 1419.182757 1405.481808
51 6784.0 1410.043857 1415.268880
52 6912.0 1424.175406 1422.246445
53 7040.0 1421.442904 1431.216081
54 7168.0 1430.812616 1433.340617
55 7296.0 1428.225472 1439.866352
56 7424.0 1427.681062 1444.854790
57 7552.0 1425.888352 1455.406250
58 7680.0 1436.259743 1459.073758
59 7808.0 1434.987093 1467.046144
60 7936.0 1436.230019 1466.565096
61 8064.0 1433.349451 1474.872297
62 8192.0 1437.077708 1483.296854
63 8320.0 1387.622820 1402.245362
64 8448.0 1380.392190 1405.798260
65 8576.0 1394.235566 1397.321986
66 8704.0 1394.214828 1401.841116
67 8832.0 1383.898549 1403.830618
68 8960.0 1396.464308 1411.119895
69 9088.0 1407.855438 1416.562388
70 9216.0 1404.644508 1427.684388
71 9344.0 1397.506763 1424.505933
72 9472.0 1401.047450 1436.146824
73 9600.0 1393.785291 1432.004113
74 9728.0 1404.978379 1443.802885
75 9856.0 1416.323565 1441.569281
76 9984.0 1400.441084 1452.598944
77 10112.0 1416.025004 1451.695136
78 10240.0 1418.700816 1466.587721
79 10368.0 1413.426443 1461.543005
80 10496.0 1418.130103 1466.513470
81 10624.0 1410.584558 1467.802730
82 10752.0 1406.810652 1473.657636
83 10880.0 1401.680762 1477.683590
84 11008.0 1418.090370 1477.688996
85 11136.0 1422.693532 1487.034826
86 11264.0 1428.657991 1487.786980
87 11392.0 1420.074659 1491.363268
88 11520.0 1423.918123 1492.221961
89 11648.0 1424.807957 1499.729021
90 11776.0 1430.914832 1501.918359
91 11904.0 1442.172808 1507.845030
92 12032.0 1421.892405 1509.618623
93 12160.0 1417.047651 1511.446571
94 12288.0 1436.674819 1392.492332
95 12416.0 1447.834085 1388.422628
96 12544.0 1441.263398 1393.344051
97 12672.0 1447.460023 1391.969298
In the above plot, we can see that:

- Kernel fusion pays off: the fused kernel avoids the extra DRAM round trips of naive_softmax, which is where the ~4x theoretical speed-up estimated above comes from; torch.jit.script does not perform this kind of fusion for us.
- Triton is competitive with torch.softmax across the tested column counts, in addition to being easier to read, understand and maintain. Note, however, that the PyTorch softmax operation is more general and will work on tensors of any shape.
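The benchmark above only plots the 'triton' and 'torch' providers. If you also want to see how the unfused naive_softmax fares, one possible extension (a sketch reusing the setup above; the function and plot names are hypothetical) adds it as a third provider:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[128 * i for i in range(2, 100)],
        line_arg='provider',
        line_vals=['triton', 'torch', 'naive'],  # add the unfused implementation as a third line
        line_names=["Triton", "Torch", "Naive"],
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],
        ylabel="GB/s",
        plot_name="softmax-performance-with-naive",  # hypothetical plot name
        args={'M': 4096},
    ))
def benchmark_with_naive(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark_with_naive.run(show_plots=True, print_data=True)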