Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
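To see why the docstring above subtracts the row maximum, here is a minimal pure-Python sketch (no GPU or PyTorch required). The helper names `softmax_unshifted` and `softmax_shifted` are hypothetical and exist only for this illustration. Note that Python's `math.exp` raises `OverflowError` on large inputs, whereas floating-point tensor code would typically produce `inf` and then `nan` instead.

```python
import math


def softmax_unshifted(row):
    # Direct definition: exp(x_i) / sum_j exp(x_j). Overflows for large inputs.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_shifted(row):
    # Subtract the row maximum first, so every exponent is <= 0.
    row_max = max(row)
    exps = [math.exp(v - row_max) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
large = [v + 999.0 for v in small]  # same row, shifted by a constant

try:
    softmax_unshifted(large)
    overflowed = False
except OverflowError:
    overflowed = True

# The shifted version handles both rows and gives the same answer for each,
# because softmax is invariant to adding a constant to every element.
same_result = softmax_shifted(small) == softmax_shifted(large)
print(overflowed, same_result)
```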

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
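The traffic estimate can be checked with a few lines of arithmetic. The sketch below simply adds up the element counts from the comments in naive_softmax and compares them with the ideal fused kernel; the function names and the example shape are assumptions chosen for illustration.

```python
def naive_traffic(M, N):
    # Element counts taken from the comments in naive_softmax.
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    # An ideal fused kernel reads X once and writes Y once.
    return 2 * M * N


M, N = 4096, 1024  # illustrative shape, not taken from the tutorial
ratio = naive_traffic(M, N) / fused_traffic(M, N)
print(f"{ratio:.3f}")  # equals 4 + 2/N, i.e. just over 4
```

The ratio simplifies to \(4 + 2/N\), so the bound is essentially 4x for any realistic row length. It is a bound on memory traffic only, not a prediction of measured speed-up.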

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
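To make the control flow of the kernel easier to follow, here is a rough CPU emulation in plain Python. It is only a sketch: `emulate_softmax_kernel` and `next_power_of_2` are hypothetical helpers written for this illustration, programs run one after another rather than in parallel, and lists stand in for GPU memory. It mirrors three ideas from the kernel: rows are assigned to programs with a stride of `num_programs`, each row is padded to a power-of-two block with `-inf`, and only the first `n_cols` entries are stored.

```python
import math


def next_power_of_2(n):
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()


def emulate_softmax_kernel(x, num_programs):
    n_rows, n_cols = len(x), len(x[0])
    block_size = next_power_of_2(n_cols)
    y = [[0.0] * n_cols for _ in range(n_rows)]
    rows_per_program = []
    for pid in range(num_programs):  # on a GPU these run concurrently
        handled = []
        # Same iteration pattern as tl.range(row_start, n_rows, row_step).
        for row_idx in range(pid, n_rows, num_programs):
            # Masked load: out-of-range columns read as -inf.
            row = [x[row_idx][c] if c < n_cols else -math.inf
                   for c in range(block_size)]
            row_max = max(row)
            numerator = [math.exp(v - row_max) for v in row]  # exp(-inf) == 0.0
            denominator = sum(numerator)
            # Masked store: only the first n_cols entries are written.
            for c in range(n_cols):
                y[row_idx][c] = numerator[c] / denominator
            handled.append(row_idx)
        rows_per_program.append(handled)
    return y, rows_per_program


x = [[float(r + c) for c in range(5)] for r in range(7)]
y, rows_per_program = emulate_softmax_kernel(x, num_programs=3)
print(rows_per_program)  # [[0, 3, 6], [1, 4], [2, 5]]
```

Because the padded entries are `-inf`, they never win the maximum and contribute exactly zero to the sum, so padding does not change the result.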

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
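            # Default to the reported register count so NUM_GPRS is defined on non-CDNA HIP targets too.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2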

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number with WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor)  in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
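The occupancy arithmetic in the non-HIP branch of the helper can be tried out on its own. The sketch below reproduces only that arithmetic in plain Python; `estimate_num_programs` is a hypothetical function, and every hardware and kernel number in the example is made up for illustration rather than read from a real device.

```python
def estimate_num_programs(num_sm, num_regs, size_smem_total, warp_size,
                          n_regs, size_smem_kernel, num_warps, n_rows):
    # Register limit: how many programs of num_warps warps fit in the register file.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared-memory limit: how many programs fit in shared memory.
    occupancy = min(occupancy, size_smem_total // size_smem_kernel)
    # One batch of persistent programs per multiprocessor, capped by the row count.
    return min(num_sm * occupancy, n_rows)


# Hypothetical device and kernel figures, chosen only to make the arithmetic concrete.
estimate = estimate_num_programs(
    num_sm=108, num_regs=65536, size_smem_total=232448, warp_size=32,
    n_regs=32, size_smem_kernel=16384, num_warps=8, n_rows=1823)
print(estimate)  # 864: register-limited occupancy of 8 on each of 108 SMs
```

With these numbers the register limit (8 programs per SM) is tighter than the shared-memory limit (14), so registers determine the grid size. For a matrix with fewer rows than that, the row count becomes the cap.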

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within the default tolerances of torch.allclose.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Figure: softmax-performance plot, GB/s as a function of N for Triton and Torch]
softmax-performance:
          N       Triton        Torch
0     256.0   475.867106   685.041521
1     384.0   613.032958   822.409391
2     512.0   750.296568   930.035577
3     640.0   790.475418   960.227448
4     768.0   883.304113  1036.036596
5     896.0   941.239815  1074.005203
6    1024.0   998.411163  1127.358004
7    1152.0  1107.388312   614.316841
8    1280.0  1152.085426   669.896840
9    1408.0  1165.018145   725.193381
10   1536.0  1191.426267   787.302034
11   1664.0  1219.815572   817.880919
12   1792.0  1241.430732   855.826960
13   1920.0  1244.872450   909.264196
14   2048.0  1282.354351   959.172880
15   2176.0  1257.858266   977.060489
16   2304.0  1261.732018  1011.007650
17   2432.0  1296.925154  1052.185869
18   2560.0  1300.174369  1080.226421
19   2688.0  1306.572215  1104.797021
20   2816.0  1323.858092  1130.708708
21   2944.0  1324.525524  1164.800810
22   3072.0  1353.594189  1186.090959
23   3200.0  1356.083467  1196.341652
24   3328.0  1355.409148  1228.936082
25   3456.0  1373.449908  1244.932516
26   3584.0  1374.465036  1262.930784
27   3712.0  1384.367178  1271.709169
28   3840.0  1390.638549  1298.586952
29   3968.0  1395.001557  1313.332740
30   4096.0  1395.242811  1327.553451
31   4224.0  1338.042971  1159.426661
32   4352.0  1337.867830  1173.499677
33   4480.0  1351.256710  1186.240790
34   4608.0  1362.245580  1193.560744
35   4736.0  1361.895043  1200.267993
36   4864.0  1375.085700  1224.542615
37   4992.0  1374.815808  1233.950428
38   5120.0  1371.161147  1252.235018
39   5248.0  1373.591566  1260.546573
40   5376.0  1372.060019  1289.972913
41   5504.0  1382.841033  1298.268685
42   5632.0  1388.820712  1314.290414
43   5760.0  1394.451111  1323.098922
44   5888.0  1393.764471  1345.659711
45   6016.0  1403.136516  1352.609039
46   6144.0  1413.038519  1376.960691
47   6272.0  1415.663449  1377.669242
48   6400.0  1420.420451  1388.910568
49   6528.0  1414.501754  1392.639934
50   6656.0  1421.213816  1404.372093
51   6784.0  1409.944841  1413.202611
52   6912.0  1425.560949  1426.204882
53   7040.0  1418.873044  1432.946465
54   7168.0  1428.092036  1436.630975
55   7296.0  1432.287131  1442.002162
56   7424.0  1430.511057  1448.902940
57   7552.0  1427.121774  1453.109409
58   7680.0  1429.168391  1459.199951
59   7808.0  1430.275708  1463.962681
60   7936.0  1438.461546  1468.927692
61   8064.0  1439.384861  1473.002401
62   8192.0  1437.314252  1485.506245
63   8320.0  1387.137297  1400.745800
64   8448.0  1382.388319  1402.241194
65   8576.0  1392.371869  1396.019308
66   8704.0  1389.304526  1399.251918
67   8832.0  1383.881687  1406.039248
68   8960.0  1396.330310  1410.869304
69   9088.0  1409.346182  1414.738999
70   9216.0  1401.166839  1425.284474
71   9344.0  1401.304834  1421.682265
72   9472.0  1400.880383  1433.803861
73   9600.0  1393.515811  1432.734961
74   9728.0  1398.206330  1444.328796
75   9856.0  1416.068255  1441.556529
76   9984.0  1400.436350  1452.193004
77  10112.0  1411.846864  1457.736544
78  10240.0  1417.073333  1464.301337
79  10368.0  1410.655528  1465.980532
80  10496.0  1415.096305  1467.142139
81  10624.0  1409.992543  1465.182604
82  10752.0  1407.951976  1468.473886
83  10880.0  1397.792114  1481.210527
84  11008.0  1420.737915  1479.885215
85  11136.0  1421.767254  1485.501686
86  11264.0  1427.069414  1489.076358
87  11392.0  1416.867578  1488.931526
88  11520.0  1422.872495  1495.385639
89  11648.0  1424.532358  1498.630444
90  11776.0  1429.061198  1500.870467
91  11904.0  1446.545948  1505.074098
92  12032.0  1423.495363  1508.793236
93  12160.0  1418.899333  1510.691556
94  12288.0  1435.003027  1395.010009
95  12416.0  1448.855159  1388.660747
96  12544.0  1441.054534  1395.206547
97  12672.0  1445.622702  1391.422853
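In the above plot, we can see that:
  • Triton is faster than torch.softmax for mid-sized rows (N from 1152 to 6656) and again for N ≥ 12288, by up to roughly 1.8x at N = 1152. For N ≤ 1024 torch.softmax is faster, and for N between 6784 and 12160 it leads by at most about 7%.

  • This benchmark does not measure naive_softmax or a torch.jit.script version, so the ~4x estimate from the Motivations section is not verified by these numbers.

  • The Triton kernel is also easy to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.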
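The GB/s values reported by the benchmark come from the `gbps` lambda, which counts every element once for the read and once for the write. As a rough illustration, the same formula is written out below as a standalone function; `effective_bandwidth_gbps` is a hypothetical name, and the 0.1 ms timing is an invented input, not a measurement.

```python
def effective_bandwidth_gbps(n_elements, element_size_bytes, ms):
    # Bytes moved: one read plus one write of every element.
    total_bytes = 2 * n_elements * element_size_bytes
    # Convert bytes to GB and milliseconds to seconds.
    return total_bytes * 1e-9 / (ms * 1e-3)


# A 4096 x 4096 float32 matrix processed in a hypothetical 0.1 ms.
bandwidth = effective_bandwidth_gbps(4096 * 4096, 4, ms=0.1)
print(f"{bandwidth:.2f} GB/s")  # about 1342.18 GB/s
```

Because the formula assumes the ideal \(2MN\) traffic, it reports effective bandwidth: an implementation that makes extra passes over memory shows up as a lower number even if the hardware is saturated.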

Total running time of the script: (0 minutes 24.675 seconds)
