Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
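
To make the arithmetic concrete, here is a quick back-of-the-envelope check of that ratio (a sketch only; the dimensions M and N below are arbitrary example values, not ones used later in the tutorial):

# Rough estimate of the memory-traffic reduction we expect from fusion.
M, N = 4096, 4096
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.0, i.e. (8MN + 4M) / 2MN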

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so each row fits in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS is the number of regular-purpose registers. On CDNA architectures this is half of all the registers
        # available; however, this is not always the case, and in most cases all registers can be used as
        # regular-purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS is the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can execute
        # on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
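        # Register pressure limits occupancy: each program runs num_warps * WARP_SIZE threads,
        # each of which needs n_regs registers out of the NUM_REGS available per SM.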
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within floating-point tolerance.
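
If we wanted extra confidence, we could repeat the comparison for a few more irregular shapes. This is an optional extension, not part of the original test, and the shapes below are arbitrary:

for rows, cols in [(7, 130), (128, 1000), (500, 2049)]:
    x_extra = torch.randn(rows, cols, device=DEVICE)
    assert torch.allclose(softmax(x_extra), torch.softmax(x_extra, axis=1))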

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
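    # Effective bandwidth: we count one read and one write of the tensor, i.e. 2 * numel * element_size bytes.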
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N       Triton        Torch  Naive Softmax
0     256.0   476.804403   687.969985     205.930110
1     384.0   662.670682   829.193537     261.757997
2     512.0   808.033722   919.400934     303.900680
3     640.0   888.341542   920.608567     332.271392
4     768.0   954.119287   987.451652     348.446607
5     896.0  1023.806188  1037.667343     352.495114
6    1024.0  1048.242990  1066.704213     353.397081
7    1152.0  1094.665835  1074.035779     347.082055
8    1280.0  1124.055581  1101.221035     349.420908
9    1408.0  1169.273383  1131.067580     340.899233
10   1536.0  1195.552279  1163.666096     333.180475
11   1664.0  1215.910110  1181.097305     328.762315
12   1792.0  1226.884782  1192.264547     325.771241
13   1920.0  1264.653860  1226.129247     325.122991
14   2048.0  1268.681502  1241.906960     324.411494
15   2176.0  1238.507833   962.538377     325.509433
16   2304.0  1255.158313  1005.811349     325.721077
17   2432.0  1270.893003  1033.507096     326.466147
18   2560.0  1279.614577  1065.187282     327.636496
19   2688.0  1290.785388  1095.859151     329.518529
20   2816.0  1311.566864  1125.159044     329.534915
21   2944.0  1310.162737  1143.963961     331.324179
22   3072.0  1312.527536  1167.995535     332.852419
23   3200.0  1337.116341  1172.826700     334.788098
24   3328.0  1346.586148  1199.001468     336.636970
25   3456.0  1346.676162  1220.157198     336.702317
26   3584.0  1364.103092  1242.841693     338.083304
27   3712.0  1362.545426  1264.904260     339.966692
28   3840.0  1373.942879  1279.181508     339.964787
29   3968.0  1368.919799  1300.300416     341.138220
30   4096.0  1388.865886  1313.394213     338.692256
31   4224.0  1331.704457  1276.336083     342.825111
32   4352.0  1340.376297  1300.092201     345.067957
33   4480.0  1344.071802  1319.910247     345.603724
34   4608.0  1354.015043  1336.353790     347.230797
35   4736.0  1353.346496  1343.917595     348.036578
36   4864.0  1362.262450  1357.989710     348.806429
37   4992.0  1363.396634  1371.548881     350.055164
38   5120.0  1375.202582  1387.763751     350.669824
39   5248.0  1366.159743  1357.733046     351.448374
40   5376.0  1368.595054  1362.176400     351.862417
41   5504.0  1376.655806  1378.700354     353.535239
42   5632.0  1383.544387  1396.457752     353.052972
43   5760.0  1388.236925  1399.013311     354.972358
44   5888.0  1386.077117  1416.166428     354.762310
45   6016.0  1394.702569  1424.574443     356.753184
46   6144.0  1404.625920  1424.060157     357.037863
47   6272.0  1405.415672  1401.067819     357.581008
48   6400.0  1410.072659  1399.259809     358.479735
49   6528.0  1406.401551  1417.757846     359.125846
50   6656.0  1408.531117  1424.125215     359.705531
51   6784.0  1414.092203  1437.120904     360.233846
52   6912.0  1418.601549  1439.060272     360.702662
53   7040.0  1413.901247  1449.855982     360.842062
54   7168.0  1414.536903  1455.311409     361.385952
55   7296.0  1420.665040  1085.782388     362.602826
56   7424.0  1425.042099  1097.941083     363.062709
57   7552.0  1423.359376  1111.645080     363.673392
58   7680.0  1428.690973  1121.946940     363.604994
59   7808.0  1429.748230  1131.377427     364.253615
60   7936.0  1427.677584  1143.830720     364.436456
61   8064.0  1431.443880  1148.489850     365.107877
62   8192.0  1428.136078  1151.297830     364.276937
63   8320.0  1386.578230  1116.797272     361.596207
64   8448.0  1387.161314  1125.208706     362.175811
65   8576.0  1389.714296  1126.433901     363.269797
66   8704.0  1385.456160  1132.979949     364.311422
67   8832.0  1393.997222  1132.447857     365.316491
68   8960.0  1389.622228  1140.095951     365.976587
69   9088.0  1396.294872  1138.921960     365.806288
70   9216.0  1397.913058  1142.548783     366.956630
71   9344.0  1386.410653  1421.447516     367.548607
72   9472.0  1400.005997  1433.539448     367.969698
73   9600.0  1402.187341  1431.859742     369.109778
74   9728.0  1398.747708  1439.656739     370.074259
75   9856.0  1399.900574  1436.280938     370.160670
76   9984.0  1386.425965  1445.165496     370.848624
77  10112.0  1405.950309  1451.225216     371.118407
78  10240.0  1403.642239  1464.517191     371.293283
79  10368.0  1412.309445  1462.836826     370.090294
80  10496.0  1406.176637  1466.335621     370.884285
81  10624.0  1403.565608  1462.936749     370.175987
82  10752.0  1391.931663  1472.926151     370.890380
83  10880.0  1388.920109  1477.640662     372.235355
84  11008.0  1417.285442  1474.023103     372.620328
85  11136.0  1417.869958  1483.418101     372.664543
86  11264.0  1411.596121  1486.976921     373.067988
87  11392.0  1415.276256  1488.629105     373.829584
88  11520.0  1408.548880  1495.046341     373.896283
89  11648.0  1414.934492  1501.105326     374.458507
90  11776.0  1430.041625  1503.253870     374.704219
91  11904.0  1426.228070  1507.221215     375.490174
92  12032.0  1405.693560  1508.463198     375.721497
93  12160.0  1405.857783  1513.659558     375.673640
94  12288.0  1419.781173  1415.224099     375.875309
95  12416.0  1431.049630  1394.635227     375.065895
96  12544.0  1440.051762  1391.502811     375.095080
97  12672.0  1430.071531  1393.572006     375.312714
In the results above, we can see that:
  • Triton is about 4x faster than naive_softmax. This matches our earlier analysis: the unfused implementation moves roughly four times as much data between DRAM and the GPU as is strictly necessary.

  • Triton is on par with torch.softmax overall, and noticeably faster for several ranges of N – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape; a thin wrapper around our 2D kernel, sketched below, could offer similar flexibility.
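
For completeness, here is a minimal sketch of how such a wrapper could look, assuming the reduction is along the last dimension; softmax_lastdim is a hypothetical helper, not part of the tutorial:

def softmax_lastdim(x):
    # Flatten all leading dimensions into rows, run the 2D row-wise Triton softmax,
    # and restore the original shape afterwards.
    shape = x.shape
    out = softmax(x.reshape(-1, shape[-1]).contiguous())
    return out.reshape(shape)


x3 = torch.randn(8, 16, 781, device=DEVICE)
assert torch.allclose(softmax_lastdim(x3), torch.softmax(x3, dim=-1))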

