Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
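To see why the maximum is subtracted, here is a minimal pure-Python sketch (standard library only, no GPU needed). The helper names `softmax_unshifted` and `softmax_shifted` are illustrative and not part of the tutorial code. Note that `math.exp` raises `OverflowError` on overflow, whereas a floating-point tensor library would typically produce `inf` and then `nan` instead.

```python
import math


def softmax_unshifted(row):
    # Direct definition: exp(x_i) / sum_j exp(x_j). Overflows for large inputs.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_shifted(row):
    # Subtract the row maximum first, so the largest exponent is exp(0) == 1.
    row_max = max(row)
    exps = [math.exp(v - row_max) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # the same row shifted by 999

# On moderate inputs both versions agree up to rounding.
agree_on_small = all(
    math.isclose(a, b, rel_tol=1e-9)
    for a, b in zip(softmax_unshifted(small), softmax_shifted(small))
)

# On large inputs the direct version overflows ...
try:
    softmax_unshifted(large)
    overflowed = False
except OverflowError:
    overflowed = True

# ... while the shifted version still works and gives the same probabilities,
# because softmax is invariant to adding a constant to every element.
shifted_large = softmax_shifted(large)
```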

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). torch.jit.script aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
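The element counts above can be checked with a few lines of arithmetic. This is a small sketch; `naive_traffic` and `fused_traffic` are illustrative names, and the shape is an arbitrary example.

```python
def naive_traffic(M, N):
    # Sum of the per-step counts annotated in naive_softmax.
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    # A fused kernel reads X once and writes Y once.
    return 2 * M * N


M, N = 4096, 1024  # example shape, chosen only for illustration
speedup = naive_traffic(M, N) / fused_traffic(M, N)
# The ratio simplifies to 4 + 2 / N, so it approaches 4 as N grows.
```

This is an upper bound based on memory traffic alone; it ignores caching, launch overhead and compute time.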

Compute Kernel

Our softmax kernel works as follows: each program processes a set of rows of the input matrix X, starting at its own program index and striding by the total number of programs. For each row, it loads the row, normalizes it, and writes the result back to the output Y.
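The row assignment can be pictured on the host with plain Python. This sketch only mirrors the loop bounds used by the kernel (start at the program index, step by the number of programs); `rows_for_program` is a hypothetical helper, not a Triton function.

```python
def rows_for_program(program_id, num_programs, n_rows):
    # Same (start, stop, step) pattern as the row loop in the kernel.
    return list(range(program_id, n_rows, num_programs))


n_rows, num_programs = 10, 4  # small illustrative sizes
assignment = {
    pid: rows_for_program(pid, num_programs, n_rows)
    for pid in range(num_programs)
}
# Every row is handled by exactly one program.
covered = sorted(r for rows in assignment.values() for r in rows)
```

Because the programs together cover every row exactly once, the number of programs can be chosen from hardware limits rather than from the number of rows.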

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
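To see why padding with -inf is safe, the following pure-Python sketch emulates the masked load and store on a single row. It is a host-side illustration only: `next_power_of_2` here is a local stand-in written for this example, and `padded_softmax` is a hypothetical helper, not the kernel itself.

```python
import math


def next_power_of_2(n):
    # Smallest power of two >= n (local stand-in for illustration).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def padded_softmax(row):
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    mask = [i < n_cols for i in range(block_size)]
    # Masked load: out-of-range lanes receive -inf.
    padded = [row[i] if mask[i] else -math.inf for i in range(block_size)]
    row_max = max(padded)  # -inf never wins the max
    numerator = [math.exp(v - row_max) for v in padded]  # exp(-inf) == 0.0
    denominator = sum(numerator)  # padded lanes add 0
    out = [v / denominator for v in numerator]
    # Masked store: only the first n_cols lanes are written back.
    return [v for v, keep in zip(out, mask) if keep]


row = [1.0, 2.0, 3.0, 4.0, 5.0]  # 5 columns, padded to a block of 8
result = padded_softmax(row)

# Reference computed without any padding.
ref_num = [math.exp(v - max(row)) for v in row]
reference = [v / sum(ref_num) for v in ref_num]
```

The padded lanes contribute nothing to either the maximum or the sum, so the result matches the unpadded reference.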

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
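            # Default to the regular register count so NUM_GPRS is defined on non-CDNA HIP targets.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2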

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number with WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor)  in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy
        kernels[BLOCK_SIZE] = (kernel, num_programs)

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
    )
    return y
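The occupancy arithmetic in the non-HIP branch can be worked through with made-up numbers. In this sketch every device and kernel value is an assumption chosen for illustration, not a measurement from any particular GPU. `estimate_num_programs` is a hypothetical helper, and its guard against a kernel that uses no shared memory is an addition of this sketch.

```python
def estimate_num_programs(num_sm, num_regs, size_smem, warp_size,
                          n_regs, kernel_smem, num_warps, n_rows):
    # Register limit: how many programs fit on one multiprocessor.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared-memory limit (skipped here if the kernel uses none).
    if kernel_smem > 0:
        occupancy = min(occupancy, size_smem // kernel_smem)
    # Never launch more persistent programs than there are rows.
    return min(num_sm * occupancy, n_rows)


# Illustrative values only.
device = dict(num_sm=108, num_regs=65536, size_smem=163840, warp_size=32)
kernel_usage = dict(n_regs=32, kernel_smem=16384, num_warps=8)

# Registers allow 65536 // (32 * 32 * 8) = 8 programs per multiprocessor,
# shared memory allows 163840 // 16384 = 10, so registers are the limit.
many_rows = estimate_num_programs(**device, **kernel_usage, n_rows=1823)
few_rows = estimate_num_programs(**device, **kernel_usage, n_rows=500)
```

With these numbers the launch uses 108 * 8 = 864 programs for the 1823-row input, and is capped at 500 programs when there are only 500 rows.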

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match: torch.allclose passes with its default tolerances.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Figure: softmax-performance plot, throughput in GB/s versus N for Triton and Torch]
softmax-performance:
          N       Triton        Torch
0     256.0   482.498358   711.511792
1     384.0   612.573735   808.472705
2     512.0   758.495050   923.641285
3     640.0   790.603195   959.212554
4     768.0   871.435116  1031.371211
5     896.0   940.857273  1070.966370
6    1024.0   989.547335  1122.259997
7    1152.0  1096.626914   616.809145
8    1280.0  1138.241978   669.410766
9    1408.0  1169.432455   728.789692
10   1536.0  1196.443618   784.047309
11   1664.0  1211.041930   814.812512
12   1792.0  1240.661445   865.213254
13   1920.0  1257.413544   909.936735
14   2048.0  1273.694749   964.366886
15   2176.0  1259.845494   979.716721
16   2304.0  1269.187113  1018.440821
17   2432.0  1292.623351  1055.037690
18   2560.0  1301.035556  1088.540121
19   2688.0  1317.901162  1101.029024
20   2816.0  1319.432039  1132.792641
21   2944.0  1323.424460  1167.601032
22   3072.0  1346.964776  1180.961447
23   3200.0  1349.377686  1193.279129
24   3328.0  1356.050334  1222.964736
25   3456.0  1374.131214  1247.809509
26   3584.0  1379.397621  1264.466261
27   3712.0  1381.918164  1270.351388
28   3840.0  1388.263698  1304.229462
29   3968.0  1391.378186  1314.281169
30   4096.0  1398.618657  1327.886263
31   4224.0  1336.338279  1158.148950
32   4352.0  1335.215145  1174.842390
33   4480.0  1351.342805  1181.971415
34   4608.0  1361.834294  1197.288075
35   4736.0  1361.353857  1200.956438
36   4864.0  1374.497889  1220.513747
37   4992.0  1368.643927  1236.516519
38   5120.0  1372.897146  1253.604190
39   5248.0  1376.623384  1257.893014
40   5376.0  1378.088534  1284.726063
41   5504.0  1380.019527  1302.427824
42   5632.0  1388.763714  1315.360764
43   5760.0  1394.556124  1326.662645
44   5888.0  1394.383023  1341.235638
45   6016.0  1403.664729  1353.180854
46   6144.0  1407.436019  1376.194895
47   6272.0  1413.194071  1372.058381
48   6400.0  1415.938479  1389.893773
49   6528.0  1418.507806  1395.789421
50   6656.0  1421.536117  1402.685194
51   6784.0  1415.175974  1413.458355
52   6912.0  1424.702832  1425.027544
53   7040.0  1424.315563  1432.601697
54   7168.0  1427.985948  1432.608231
55   7296.0  1430.571219  1444.109612
56   7424.0  1432.281141  1447.565928
57   7552.0  1428.105407  1452.964603
58   7680.0  1435.930638  1460.016635
59   7808.0  1433.476972  1463.888915
60   7936.0  1436.901944  1466.849972
61   8064.0  1437.661622  1476.050106
62   8192.0  1439.143527  1486.742421
63   8320.0  1390.721845  1403.209510
64   8448.0  1384.042960  1404.619595
65   8576.0  1393.183027  1394.430819
66   8704.0  1391.872494  1398.063166
67   8832.0  1382.844820  1403.086522
68   8960.0  1395.034310  1411.388368
69   9088.0  1408.734709  1416.608377
70   9216.0  1408.445694  1426.463240
71   9344.0  1401.157296  1424.317926
72   9472.0  1400.749716  1433.610932
73   9600.0  1396.245319  1434.113743
74   9728.0  1403.927134  1443.520543
75   9856.0  1414.144503  1440.300241
76   9984.0  1399.650339  1453.908146
77  10112.0  1412.445473  1453.935987
78  10240.0  1420.807912  1469.366252
79  10368.0  1414.477218  1464.718900
80  10496.0  1410.961565  1467.770788
81  10624.0  1415.024626  1471.596971
82  10752.0  1404.459460  1472.029760
83  10880.0  1402.048748  1480.074671
84  11008.0  1420.339111  1478.079955
85  11136.0  1423.227181  1487.202802
86  11264.0  1429.997443  1486.543637
87  11392.0  1415.355421  1491.343381
88  11520.0  1420.799226  1496.770146
89  11648.0  1428.473755  1499.078455
90  11776.0  1432.495562  1501.496171
91  11904.0  1445.336890  1506.354697
92  12032.0  1421.376653  1509.076700
93  12160.0  1415.346263  1512.239286
94  12288.0  1434.695978  1393.017085
95  12416.0  1446.796202  1390.780577
96  12544.0  1442.861623  1393.799419
97  12672.0  1447.425889  1392.905132
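In the above plot, we can see that:

  • Triton is faster than torch.softmax for mid-sized rows, from N = 1152 up to about N = 6784. The gap is largest just past N = 1024, where Torch drops from about 1122 GB/s to about 617 GB/s while Triton reaches about 1097 GB/s.

  • torch.softmax is ahead for N ≤ 1024 (by roughly 230 GB/s at N = 256) and, by a few percent, from N = 6912 to N = 12160. From N = 12288 onward, Triton leads again.

  • This run does not include naive_softmax or the Torch JIT, so the ~4x figure from the Motivations section remains a theoretical estimate here.

  • Note that the PyTorch softmax operation is more general and will work on tensors of any shape.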
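The GB/s values in the table follow from the formula used in the benchmark function: each element is read once and written once, and the time is reported in milliseconds. Here is a small sketch of that conversion. `softmax_gbps` is an illustrative name, and the 0.1 ms timing is a hypothetical input, not a measured value.

```python
def softmax_gbps(M, N, element_size, ms):
    # 2x because the input is read once and the output is written once.
    bytes_moved = 2 * M * N * element_size
    # bytes -> gigabytes, milliseconds -> seconds.
    return bytes_moved * 1e-9 / (ms * 1e-3)


# float32 (4 bytes per element), 4096 x 4096 matrix, hypothetical 0.1 ms runtime.
gbps = softmax_gbps(4096, 4096, 4, 0.1)
# 2 * 4096 * 4096 * 4 = 134217728 bytes moved in 1e-4 s, about 1342 GB/s.
```

Reading the table this way, throughput is an effective bandwidth: it counts only the traffic an ideal fused kernel needs, regardless of how much data an implementation actually moves.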

Total running time of the script: (0 minutes 23.439 seconds)
