Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
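As a quick sanity check of that estimate, the arithmetic can be reproduced with a couple of lines of plain Python (the shape below is arbitrary and only serves as an example):

# Back-of-the-envelope check of the expected speed-up from fusion.
# M and N are arbitrary example values, not parameters used elsewhere in this tutorial.
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # DRAM reads + writes of naive_softmax
fused_traffic = 2 * M * N  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.0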

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
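As a plain-Python illustration of that strided assignment (the sizes below are made up, and this is not Triton code):

# Illustration only: each program walks over the rows assigned to it with a stride
# equal to the total number of programs. Sizes are made up for the example.
n_rows, num_programs = 7, 3
for program_id in range(num_programs):
    rows = list(range(program_id, n_rows, num_programs))
    print(f"program {program_id} handles rows {rows}")
# program 0 handles rows [0, 3, 6]
# program 1 handles rows [1, 4]
# program 2 handles rows [2, 5]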

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
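For reference, triton.next_power_of_2 – which we will use below to pick BLOCK_SIZE – simply rounds its argument up to the nearest power of two, for example:

# BLOCK_SIZE is the smallest power of two that can hold a full row.
print(triton.next_power_of_2(781))   # 1024
print(triton.next_power_of_2(1024))  # 1024 (already a power of two)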

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
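To get a feel for the occupancy heuristic above, here is a worked example of the CUDA branch with made-up numbers; the real values come from the driver (NUM_REGS, WARP_SIZE, SIZE_SMEM) and from the compiled kernel (n_regs, size_smem):

# Worked example of the occupancy computation, with hypothetical numbers.
NUM_REGS_EX, WARP_SIZE_EX, SIZE_SMEM_EX = 65536, 32, 101376  # per-SM registers, warp size, shared memory (bytes)
n_regs_ex, num_warps_ex, size_smem_ex = 32, 8, 16384         # reported by the compiled kernel / chosen above
occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 8 programs fit by register usage
occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)           # but only 6 fit by shared memory
print(occupancy_ex)  # 6 programs per SM; the grid size is then NUM_SM * occupancy, capped at n_rows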

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within floating-point tolerance.
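Since naive_softmax is defined above, we can also cross-check the Triton kernel against it, for example:

# Optional cross-check against the naive PyTorch implementation defined earlier.
y_naive = naive_softmax(x)
assert torch.allclose(y_triton, y_naive), (y_triton, y_naive)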

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against PyTorch’s native torch.softmax.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N       Triton        Torch
0     256.0   479.770468   706.593695
1     384.0   615.194080   802.647814
2     512.0   749.793535   913.933299
3     640.0   818.737095   945.501049
4     768.0   886.409087  1024.809768
5     896.0   942.563271  1073.070865
6    1024.0  1009.972601  1110.411882
7    1152.0  1109.736593   611.205032
8    1280.0  1149.246622   666.327066
9    1408.0  1157.357624   726.057777
10   1536.0  1190.335967   779.857145
11   1664.0  1221.566308   813.264965
12   1792.0  1238.883517   855.313901
13   1920.0  1254.335500   907.627576
14   2048.0  1271.282413   958.374183
15   2176.0  1257.844709   978.227191
16   2304.0  1271.626941  1006.718727
17   2432.0  1292.767475  1052.760118
18   2560.0  1303.757909  1084.434702
19   2688.0  1312.450787  1103.231752
20   2816.0  1328.292570  1132.581209
21   2944.0  1327.909870  1164.526753
22   3072.0  1347.945645  1181.496918
23   3200.0  1348.779582  1190.596340
24   3328.0  1360.227059  1223.920806
25   3456.0  1375.430883  1248.010834
26   3584.0  1378.427657  1264.011794
27   3712.0  1388.429405  1268.832684
28   3840.0  1382.800219  1296.894993
29   3968.0  1391.941397  1313.533013
30   4096.0  1393.499116  1323.582592
31   4224.0  1335.517467  1161.617607
32   4352.0  1334.017367  1172.854445
33   4480.0  1351.983075  1183.345587
34   4608.0  1360.569315  1192.949501
35   4736.0  1358.615548  1196.273097
36   4864.0  1376.250879  1220.194342
37   4992.0  1369.804740  1239.342870
38   5120.0  1377.846766  1252.855515
39   5248.0  1372.630100  1257.878069
40   5376.0  1377.091013  1285.385943
41   5504.0  1374.677851  1300.987536
42   5632.0  1388.559586  1312.186073
43   5760.0  1395.323014  1321.232161
44   5888.0  1385.634133  1343.129357
45   6016.0  1400.878236  1351.783092
46   6144.0  1409.504116  1374.706138
47   6272.0  1412.582123  1374.830129
48   6400.0  1416.308105  1387.112226
49   6528.0  1414.357232  1392.843700
50   6656.0  1421.300347  1403.243546
51   6784.0  1410.686937  1414.572790
52   6912.0  1425.062816  1422.551799
53   7040.0  1420.179480  1432.863440
54   7168.0  1424.120120  1432.807212
55   7296.0  1433.424482  1445.159856
56   7424.0  1428.817493  1445.484714
57   7552.0  1425.279046  1457.395667
58   7680.0  1435.609372  1462.367202
59   7808.0  1435.521339  1464.095503
60   7936.0  1432.813211  1468.242783
61   8064.0  1436.259778  1474.360094
62   8192.0  1436.594199  1485.192349
63   8320.0  1387.616716  1401.658167
64   8448.0  1375.007926  1404.878882
65   8576.0  1396.826860  1395.295374
66   8704.0  1393.167748  1400.171516
67   8832.0  1383.429424  1406.578434
68   8960.0  1394.484567  1410.950535
69   9088.0  1408.025848  1414.321893
70   9216.0  1403.744547  1423.150707
71   9344.0  1396.534523  1425.285846
72   9472.0  1400.110110  1434.254558
73   9600.0  1397.109482  1434.560084
74   9728.0  1402.462574  1443.775893
75   9856.0  1417.608266  1443.670254
76   9984.0  1395.868631  1452.966168
77  10112.0  1411.732670  1456.107865
78  10240.0  1421.074575  1464.200791
79  10368.0  1414.810323  1465.693470
80  10496.0  1413.422695  1468.669258
81  10624.0  1416.583788  1467.789451
82  10752.0  1403.638865  1474.370820
83  10880.0  1396.411411  1483.058775
84  11008.0  1416.968751  1476.007517
85  11136.0  1423.401402  1485.623799
86  11264.0  1427.614400  1485.060508
87  11392.0  1417.122457  1491.116666
88  11520.0  1420.941401  1496.517085
89  11648.0  1423.856096  1496.308209
90  11776.0  1431.926034  1503.709021
91  11904.0  1445.548237  1504.740838
92  12032.0  1422.829078  1508.462405
93  12160.0  1418.697132  1510.021879
94  12288.0  1433.864941  1393.731558
95  12416.0  1447.475169  1392.292481
96  12544.0  1443.415647  1393.140260
97  12672.0  1447.414912  1392.943009
In the above plot, we can see that Triton is noticeably faster than torch.softmax for many of the shapes tested – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
