Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
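The docstring states that subtracting the row maximum avoids overflow and leaves the result unchanged. A minimal pure-Python sketch of that claim follows; the helper names `softmax_unstable` and `softmax_stable` are illustrative and not part of the tutorial code.

```python
import math


def softmax_unstable(row):
    # Direct definition: exp(v) overflows for large v.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(row):
    # Shift by the row maximum so the largest exponent is exp(0) == 1.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # the same row shifted by 999

try:
    softmax_unstable(large)
    overflowed = False
except OverflowError:
    # math.exp raises for arguments this large.
    overflowed = True

# After the shift both rows reduce to [-2, -1, 0], so the outputs coincide.
same_result = softmax_stable(large) == softmax_stable(small)
```

Note that this sketch raises an exception on overflow because it uses `math.exp`; a floating-point tensor library would typically produce `inf` and then `nan` instead.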

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
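The traffic estimate can be checked with a few lines of arithmetic. This is a minimal sketch that only counts elements moved, using the per-step counts from the comments in naive_softmax; the function names are illustrative.

```python
def naive_traffic(M, N):
    # Element counts taken from the comments in naive_softmax.
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    # A fused kernel reads X once and writes Y once.
    return 2 * M * N


def ideal_speedup(M, N):
    # Equals (8MN + 4M) / (2MN) = 4 + 2/N for a purely bandwidth-bound op.
    return naive_traffic(M, N) / fused_traffic(M, N)


speedup = ideal_speedup(4096, 1024)
```

The ratio simplifies to 4 + 2/N, so it is independent of M and approaches 4 as the rows get longer.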

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
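The row assignment described above can be sketched in plain Python. The helper `rows_for_program` is hypothetical; it mirrors a loop that starts at the program index and steps by the number of programs.

```python
def rows_for_program(pid, n_rows, num_programs):
    # Program `pid` handles rows pid, pid + num_programs, pid + 2 * num_programs, ...
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4
assignment = {pid: rows_for_program(pid, n_rows, num_programs) for pid in range(num_programs)}

# Every row is covered exactly once across all programs.
covered = sorted(r for rows in assignment.values() for r in rows)
```

With 10 rows and 4 programs, program 0 handles rows 0, 4 and 8, while program 3 handles rows 3 and 7.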

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
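To see why padding with negative infinity is safe, here is a CPU-only sketch of what happens to one row. It is not Triton code: `next_power_of_2` and `padded_row_softmax` are illustrative stand-ins, and lists play the role of blocks.

```python
import math


def next_power_of_2(n):
    # Smallest power of two greater than or equal to n (for n >= 1).
    return 1 << (n - 1).bit_length()


def padded_row_softmax(row):
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    # Lanes at or beyond n_cols are masked out and filled with -inf on load.
    mask = [i < n_cols for i in range(block_size)]
    padded = [row[i] if mask[i] else -math.inf for i in range(block_size)]
    row_max = max(padded)  # -inf never wins the max
    numerator = [math.exp(v - row_max) for v in padded]  # exp(-inf) == 0.0
    denominator = sum(numerator)  # padded lanes contribute nothing
    out = [v / denominator for v in numerator]
    # Only the unmasked lanes are stored.
    return [out[i] for i in range(block_size) if mask[i]], block_size


result, block_size = padded_row_softmax([1.0, 2.0, 3.0, 4.0, 5.0])
```

A row of 5 elements is processed in a block of 8 lanes; the 3 padded lanes affect neither the maximum nor the sum, so the 5 stored values still sum to 1.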

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
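The occupancy arithmetic in the non-HIP branch can be followed with concrete numbers. In this sketch every hardware and kernel value is hypothetical and chosen only to make the arithmetic easy to follow; `estimate_num_programs` is an illustrative helper, not a Triton API.

```python
def estimate_num_programs(num_sm, num_regs, size_smem, warp_size,
                          n_regs, kernel_smem, num_warps, n_rows):
    # Register limit: how many programs fit in one SM's register file.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared-memory limit: how many programs fit in one SM's shared memory.
    occupancy = min(occupancy, size_smem // kernel_smem)
    # One persistent program per available slot, but never more than there are rows.
    return min(num_sm * occupancy, n_rows)


# Hypothetical device and kernel figures (not measured on any real GPU).
hypothetical = dict(num_sm=100, num_regs=65536, size_smem=100000, warp_size=32,
                    n_regs=32, kernel_smem=30000, num_warps=8)

many_rows = estimate_num_programs(n_rows=1823, **hypothetical)
few_rows = estimate_num_programs(n_rows=200, **hypothetical)
```

With these figures registers would allow 8 programs per SM but shared memory allows only 3, so shared memory is the binding limit; the final `min` matters only when the matrix has fewer rows than available slots.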

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within the default tolerances of torch.allclose.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
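The benchmark reports effective bandwidth rather than time. A minimal sketch of that conversion follows, assuming one read and one write of every element; the standalone `gbps` function and the timing value are illustrative and not measured.

```python
def gbps(n_elements, element_size, ms):
    # One read plus one write of every element, converted from bytes to GB.
    gigabytes = 2 * n_elements * element_size * 1e-9
    # Convert milliseconds to seconds.
    seconds = ms * 1e-3
    return gigabytes / seconds


n_elements = 4096 * 1024  # M = 4096 rows, N = 1024 columns
element_size = 4          # bytes per float32 element
throughput = gbps(n_elements, element_size, ms=1.0)  # hypothetical 1 ms runtime
```

Because the same formula is applied to every provider, an implementation that moves more data than the ideal 2MN elements shows up as lower effective GB/s rather than as higher traffic.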

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance, throughput in GB/s versus N for Triton, Torch and Naive Softmax]
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     510.900817    684.034723            204.606980
1     384.0     697.341343    827.492956            262.267095
2     512.0     819.820050    929.864816            300.446734
3     640.0     835.708543    930.618855            331.149453
4     768.0     903.284280    975.700808            348.676197
5     896.0     970.243281   1037.334523            353.036415
6    1024.0    1029.429730   1066.317051            353.459779
7    1152.0    1021.228447   1062.397233            349.354566
8    1280.0    1074.207536   1103.025853            349.596662
9    1408.0    1118.738166   1141.892497            342.499494
10   1536.0    1156.150036   1167.388179            333.139471
11   1664.0    1177.368198   1181.346324            328.853473
12   1792.0    1201.453607   1198.966224            325.326989
13   1920.0    1222.657756   1225.247070            324.789659
14   2048.0    1251.832512   1243.778438            324.716321
15   2176.0    1187.590501    962.926221            325.437062
16   2304.0    1200.578854    998.837668            325.573179
17   2432.0    1224.217129   1035.799528            326.828673
18   2560.0    1246.987692   1067.717125            328.437933
19   2688.0    1258.053034   1097.233603            329.478688
20   2816.0    1273.428691   1126.433081            329.382636
21   2944.0    1295.477074   1143.504393            331.718311
22   3072.0    1312.095487   1168.694605            333.457785
23   3200.0    1322.751477   1172.883208            334.758503
24   3328.0    1323.886893   1204.058894            336.533953
25   3456.0    1336.665266   1220.810824            336.590263
26   3584.0    1336.650032   1246.945826            338.498935
27   3712.0    1353.907427   1267.192670            340.228367
28   3840.0    1360.666119   1283.879385            340.195283
29   3968.0    1363.450642   1298.027416            341.419063
30   4096.0    1372.404397   1321.535674            338.801099
31   4224.0    1341.455517   1275.263321            343.544313
32   4352.0    1349.559378   1297.257440            345.341969
33   4480.0    1354.570700   1320.356703            345.791746
34   4608.0    1362.976118   1329.363548            346.906994
35   4736.0    1362.540380   1344.572430            347.802967
36   4864.0    1374.739426   1359.741254            349.057695
37   4992.0    1373.412212   1372.086346            350.069549
38   5120.0    1379.941722   1388.540488            351.019484
39   5248.0    1384.835223   1357.321931            351.789325
40   5376.0    1382.788336   1360.687883            351.765816
41   5504.0    1392.651841   1378.545040            353.675385
42   5632.0    1400.844248   1401.026157            352.712010
43   5760.0    1398.069873   1406.929544            354.891598
44   5888.0    1394.843532   1413.544917            354.581980
45   6016.0    1408.350446   1412.993852            356.494473
46   6144.0    1411.796440   1433.942043            356.598434
47   6272.0    1413.284037   1386.352490            357.839720
48   6400.0    1418.551013   1415.706793            358.213597
49   6528.0    1416.136001   1409.005441            359.145905
50   6656.0    1421.964313   1428.806108            359.431544
51   6784.0    1424.313896   1438.168603            360.481321
52   6912.0    1427.551740   1444.532882            361.015430
53   7040.0    1426.426674   1444.339218            361.071563
54   7168.0    1426.845003   1463.155522            361.723334
55   7296.0    1432.915566   1084.862758            362.460948
56   7424.0    1436.745035   1101.708302            362.543707
57   7552.0    1431.827821   1109.641208            363.345322
58   7680.0    1438.284936   1120.618794            363.878739
59   7808.0    1433.864859   1129.847942            364.313138
60   7936.0    1437.386955   1139.758932            364.296003
61   8064.0    1439.381699   1150.755196            365.547956
62   8192.0    1436.434208   1150.502658            363.570523
63   8320.0    1385.396717   1118.469724            361.596208
64   8448.0    1388.395345   1123.937199            362.304952
65   8576.0    1386.844240   1129.162430            363.080889
66   8704.0    1382.325324   1131.834878            364.391603
67   8832.0    1392.464260   1134.597639            365.151737
68   8960.0    1387.346098   1138.629592            365.616575
69   9088.0    1400.731812   1136.415582            366.719898
70   9216.0    1404.470486   1144.726716            367.590386
71   9344.0    1395.009607   1417.024056            367.680527
72   9472.0    1399.315942   1432.924081            369.041928
73   9600.0    1399.850545   1432.065909            369.078421
74   9728.0    1401.354024   1437.668889            369.759007
75   9856.0    1400.713023   1440.987743            369.476505
76   9984.0    1395.607271   1450.217538            370.653465
77  10112.0    1404.332229   1453.534832            371.402008
78  10240.0    1405.756048   1465.745492            371.864641
79  10368.0    1413.872458   1458.620045            369.855177
80  10496.0    1408.593674   1465.480603            370.120186
81  10624.0    1406.706339   1466.428601            370.269026
82  10752.0    1396.862710   1474.077927            371.599526
83  10880.0    1398.381375   1479.584913            371.030487
84  11008.0    1417.070076   1476.265525            372.195648
85  11136.0    1421.887138   1483.348694            373.121257
86  11264.0    1413.711924   1487.645065            372.651256
87  11392.0    1418.718648   1487.639699            374.181346
88  11520.0    1414.981824   1496.738394            374.280315
89  11648.0    1421.404529   1499.440829            374.719809
90  11776.0    1433.882755   1500.211287            375.023855
91  11904.0    1432.183082   1510.145846            375.201705
92  12032.0    1419.039070   1508.660312            375.483108
93  12160.0    1415.781095   1517.099441            376.112036
94  12288.0    1431.410014   1423.862108            376.256578
95  12416.0    1434.436147   1397.538602            375.300615
96  12544.0    1440.064246   1396.179869            375.252875
97  12672.0    1435.177197   1393.550136            375.114948
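In the above plot, we can see that:
  • Triton is much faster than naive_softmax: about 2.5x at N = 256 and between roughly 3.6x and 4x for N ≥ 2048. This is close to the theoretical ~4x estimate and is consistent with the naive PyTorch version performing no fusion.

  • Triton is competitive with torch.softmax. In this run it is faster for some ranges of N (for example 2176 to 4992 and 7296 to 9216) and slower for others (for example N ≤ 1664 and 9344 to 12160), while being arguably easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.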
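To put numbers on the comparison, the speed ratios can be computed directly from the printed table. This is a minimal sketch; the four rows are copied from the output above and the variable names are illustrative.

```python
# N -> (Triton, Torch, Naive Softmax) throughput in GB/s, copied from the table.
rows = {
    256: (510.900817, 684.034723, 204.606980),
    2048: (1251.832512, 1243.778438, 324.716321),
    8192: (1436.434208, 1150.502658, 363.570523),
    12672: (1435.177197, 1393.550136, 375.114948),
}

ratios = {}
for n, (triton_gbps, torch_gbps, naive_gbps) in rows.items():
    # First entry: speed-up over naive_softmax; second: speed-up over torch.softmax.
    ratios[n] = (triton_gbps / naive_gbps, triton_gbps / torch_gbps)

for n, (vs_naive, vs_torch) in ratios.items():
    print(f"N={n}: {vs_naive:.2f}x vs naive, {vs_torch:.2f}x vs torch.softmax")
```

A ratio above 1 in the second column means the Triton kernel was faster than torch.softmax for that N, and a ratio below 1 means it was slower.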

Total running time of the script: (0 minutes 35.083 seconds)
