Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each way, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
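
As a quick sanity check of this estimate, the ratio can be evaluated directly. The helper below (expected_speedup is a name introduced here purely for illustration, not part of the tutorial code) just evaluates \((8MN + 4M) / 2MN\) for a concrete shape:

def expected_speedup(M: int, N: int) -> float:
    # Naive softmax: reads 5MN + 2M elements and writes 3MN + 2M elements
    # (see the per-line accounting in naive_softmax above).
    naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)
    # Fused kernel: reads the MN input elements once and writes the MN outputs once.
    fused_traffic = 2 * M * N
    return naive_traffic / fused_traffic


print(expected_speedup(4096, 1024))  # ~4.002, i.e. roughly the 4x quoted above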

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
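
Before moving on, it may help to see why loading the out-of-bounds positions as -inf is harmless: -inf never wins the max, and exp(-inf) == 0 contributes nothing to the sum, so the padded reduction matches the unpadded one. The following plain-PyTorch sketch (illustration only, not part of the kernel, using the torch and triton modules imported above) demonstrates this:

row = torch.randn(781)
BLOCK_SIZE = triton.next_power_of_2(781)  # 1024
padded = torch.full((BLOCK_SIZE,), float('-inf'))
padded[:781] = row
# Same max-subtracted softmax the kernel computes, but over the padded row
padded_softmax = torch.exp(padded - padded.max()) / torch.exp(padded - padded.max()).sum()
# The first 781 entries match the reference softmax of the unpadded row
assert torch.allclose(padded_softmax[:781], torch.softmax(row, dim=0))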

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # When we divide this number by WARP_SIZE we get the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
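
To make the launch sizing above concrete, here is the same occupancy arithmetic with made-up numbers. Every value below is an assumption chosen for illustration, not a property of any particular GPU; in the real helper these come from the device properties and the compiled kernel.

# Hypothetical register-file and compiled-kernel statistics, for illustration only.
NUM_REGS_EX, WARP_SIZE_EX, NUM_SM_EX = 65536, 32, 108
n_regs_ex, num_warps_ex = 32, 8
SIZE_SMEM_EX, size_smem_ex = 166912, 8192

# Register-limited occupancy: how many program instances fit per SM.
occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 8
# The shared-memory limit would allow 166912 // 8192 = 20, so registers bind here.
occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)           # still 8
# Persistent grid size before clamping to n_rows.
num_programs_ex = NUM_SM_EX * occupancy_ex                               # 864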

Unit Test

We test our kernel on a matrix with an irregular number of rows and columns, which lets us verify that the padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
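    # GB/s: softmax reads MN elements and writes MN elements once each, hence the factor of 2;
    # multiply by bytes per element, convert to GB (1e-9) and to seconds (ms * 1e-3).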
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N       Triton        Torch  Naive Softmax
0     256.0   468.452833   689.354972     206.729612
1     384.0   664.989565   831.386445     261.157176
2     512.0   799.942147   915.989562     300.951619
3     640.0   908.348602   914.048137     328.450257
4     768.0   974.912881   986.068762     347.559270
5     896.0  1048.817418  1039.267858     354.104773
6    1024.0  1081.471589  1078.798045     352.362205
7    1152.0  1091.274992  1074.480018     348.015856
8    1280.0  1125.538595  1111.471096     348.531951
9    1408.0  1157.858831  1136.119683     340.733713
10   1536.0  1195.657043  1167.718468     332.915206
11   1664.0  1207.879399  1189.899426     329.786972
12   1792.0  1230.815870  1198.938522     325.148237
13   1920.0  1262.632066  1227.638062     324.015053
14   2048.0  1267.102964  1244.418997     323.964303
15   2176.0  1237.686487   957.730105     325.602275
16   2304.0  1259.949514   998.896887     325.942816
17   2432.0  1269.163095  1033.808767     325.907522
18   2560.0  1289.635212  1067.200670     327.538818
19   2688.0  1296.258330  1100.951208     328.620297
20   2816.0  1310.418663  1124.871670     329.249392
21   2944.0  1317.155750  1147.104803     331.351009
22   3072.0  1324.398011  1170.516868     333.100111
23   3200.0  1340.366390  1169.371643     334.841402
24   3328.0  1347.172950  1199.214184     336.316971
25   3456.0  1347.760514  1223.137307     337.037815
26   3584.0  1362.155307  1242.285015     337.959922
27   3712.0  1365.493484  1260.091486     340.365623
28   3840.0  1367.256816  1284.609492     340.067124
29   3968.0  1371.577329  1297.517040     340.548659
30   4096.0  1382.868047  1314.365476     338.537517
31   4224.0  1330.608799  1273.767340     343.406837
32   4352.0  1340.744877  1298.498835     345.221623
33   4480.0  1342.439833  1315.847594     345.315166
34   4608.0  1356.865069  1333.747696     346.584171
35   4736.0  1357.688674  1344.857231     347.934884
36   4864.0  1368.607496  1363.129797     349.020782
37   4992.0  1369.836397  1373.485521     349.970426
38   5120.0  1377.704960  1388.604593     351.143594
39   5248.0  1373.898009  1352.177877     351.133694
40   5376.0  1373.541617  1369.068368     351.567181
41   5504.0  1376.542504  1372.477797     353.371911
42   5632.0  1391.723901  1396.560963     352.813959
43   5760.0  1394.458110  1406.584675     354.851486
44   5888.0  1389.453850  1413.580976     354.443016
45   6016.0  1402.479845  1413.567158     356.792191
46   6144.0  1408.258224  1420.565745     356.806450
47   6272.0  1412.107307  1400.299806     357.567158
48   6400.0  1411.947123  1409.931759     358.338950
49   6528.0  1416.289773  1419.071086     359.447925
50   6656.0  1415.305316  1433.894799     359.311200
51   6784.0  1417.893859  1430.714533     359.591245
52   6912.0  1423.587410  1437.680395     360.780803
53   7040.0  1414.719923  1450.212370     360.603686
54   7168.0  1422.184473  1449.774478     361.810060
55   7296.0  1423.085616  1086.855909     362.552323
56   7424.0  1430.063088  1095.659103     362.886964
57   7552.0  1423.776684  1111.656088     363.550295
58   7680.0  1427.860558  1120.869133     363.645226
59   7808.0  1433.065707  1133.450126     364.289655
60   7936.0  1433.949817  1141.201886     364.928276
61   8064.0  1429.935656  1152.136903     365.234802
62   8192.0  1429.637557  1149.870765     363.981518
63   8320.0  1384.127461  1117.448146     362.266815
64   8448.0  1381.693456  1125.435675     362.402980
65   8576.0  1385.757097  1128.787876     363.405873
66   8704.0  1383.150348  1134.673944     364.360417
67   8832.0  1398.098452  1132.714809     364.831567
68   8960.0  1383.778719  1137.135833     366.074490
69   9088.0  1394.381390  1134.308542     366.701941
70   9216.0  1407.059183  1142.613601     367.604997
71   9344.0  1389.330103  1422.350931     367.411346
72   9472.0  1400.804707  1428.663727     368.461745
73   9600.0  1398.967302  1431.269603     369.102115
74   9728.0  1394.465843  1441.434411     369.513179
75   9856.0  1400.970135  1439.068066     370.102273
76   9984.0  1394.213251  1448.962947     370.427513
77  10112.0  1404.414216  1453.625817     371.459231
78  10240.0  1407.812931  1464.750700     371.586504
79  10368.0  1417.341396  1461.264155     370.112489
80  10496.0  1405.468574  1466.430691     370.697412
81  10624.0  1405.869157  1466.424267     370.779356
82  10752.0  1396.925769  1473.176760     371.452425
83  10880.0  1388.155189  1477.137031     372.058562
84  11008.0  1422.063596  1476.099900     372.886245
85  11136.0  1418.511102  1483.006054     373.059114
86  11264.0  1415.307291  1489.404519     373.205625
87  11392.0  1421.515852  1486.380987     374.083120
88  11520.0  1407.813784  1494.151362     374.015381
89  11648.0  1420.312897  1500.019746     374.874991
90  11776.0  1433.157377  1499.672759     374.912809
91  11904.0  1433.460806  1505.101197     375.646318
92  12032.0  1414.247040  1509.848282     375.752422
93  12160.0  1409.691067  1517.099969     375.691335
94  12288.0  1429.344513  1420.003995     376.902178
95  12416.0  1438.311463  1397.545112     374.770257
96  12544.0  1442.326405  1397.559343     375.749041
97  12672.0  1433.880325  1395.162573     375.308313
In the above plot, we can see that:
  • Triton is roughly 4x faster than naive_softmax. This is in line with the ~4x theoretical speed-up we estimated above, and confirms that the naive implementation is limited by the redundant DRAM traffic that kernel fusion eliminates.

  • Triton is on par with or faster than torch.softmax across most of the tested column counts – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
