Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
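
The max-subtraction trick described in the docstring can be checked without a GPU. The sketch below is a pure-Python stand-in (plain lists of floats instead of tensors; the helper names `stable_softmax` and `unstable_softmax` are hypothetical). It shows that `math.exp` overflows on large inputs unless the row maximum is subtracted first, and that shifting every input by the same constant leaves the result unchanged.

```python
import math


def stable_softmax(row):
    """Softmax of a list of floats, subtracting the row maximum first as naive_softmax does."""
    row_max = max(row)
    numerator = [math.exp(v - row_max) for v in row]  # every exponent is <= 0, so no overflow
    denominator = sum(numerator)
    return [n / denominator for n in numerator]


def unstable_softmax(row):
    """Softmax without the shift; math.exp raises OverflowError once an input exceeds ~709."""
    numerator = [math.exp(v) for v in row]
    denominator = sum(numerator)
    return [n / denominator for n in numerator]


# Shifting every input by the same constant leaves the result unchanged.
print(stable_softmax([1000.0, 1001.0, 1002.0]) == stable_softmax([0.0, 1.0, 2.0]))  # True

try:
    unstable_softmax([1000.0, 1001.0, 1002.0])
except OverflowError as err:
    print("unshifted softmax overflowed:", err)
```
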

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
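
The traffic figures can be re-derived mechanically. The following plain-Python sketch (the function names are hypothetical) sums the per-step read/write counts from the comments in naive_softmax and compares them with an ideal fused kernel that reads X once and writes Y once. Algebraically the ratio is \((8MN + 4M) / 2MN = 4 + 2/N\), which approaches 4 as rows get longer.

```python
def naive_softmax_traffic(M, N):
    """(elements read, elements written) for naive_softmax, step by step as in its comments."""
    steps = [
        (M * N, M),          # x.max(dim=1)
        (M * N + M, M * N),  # x - x_max[:, None]
        (M * N, M * N),      # torch.exp(z)
        (M * N, M),          # numerator.sum(dim=1)
        (M * N + M, M * N),  # numerator / denominator[:, None]
    ]
    reads = sum(r for r, _ in steps)
    writes = sum(w for _, w in steps)
    return reads, writes


def fused_softmax_traffic(M, N):
    # An ideal fused kernel reads X once and writes Y once
    return M * N, M * N


def expected_speedup(M, N):
    return sum(naive_softmax_traffic(M, N)) / sum(fused_softmax_traffic(M, N))


print(naive_softmax_traffic(4096, 1024))  # (20979712, 12591104)
print(expected_speedup(4096, 1024))       # 4.001953125
```
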

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
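
The row-to-program assignment is a grid-stride loop: program `pid` handles rows `pid`, `pid + num_programs`, `pid + 2 * num_programs`, and so on. The plain-Python sketch below (hypothetical helper names; no Triton involved) reproduces the index sequence that the kernel's `tl.range(row_start, n_rows, row_step)` loop walks and checks that every row is processed exactly once.

```python
def rows_for_program(pid, num_programs, n_rows):
    # Same sequence as `for row_idx in tl.range(row_start, n_rows, row_step)`
    # with row_start = pid and row_step = num_programs
    return list(range(pid, n_rows, num_programs))


def covers_each_row_once(num_programs, n_rows):
    visited = []
    for pid in range(num_programs):
        visited.extend(rows_for_program(pid, num_programs, n_rows))
    # Sorting the visited rows must yield 0..n_rows-1 with no gaps and no duplicates
    return sorted(visited) == list(range(n_rows))


print(rows_for_program(1, 4, 10))       # [1, 5, 9]
print(covers_each_row_once(4, 10))      # True
print(covers_each_row_once(864, 1823))  # True
```

This is what makes the programs "persistent": a fixed number of them loops over however many rows there are.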

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols, so we
        # can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols;
        # masked-out lanes are filled with -inf, so they contribute exp(-inf) = 0 to the sum
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
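
The padding and masking in softmax_kernel can be emulated on the CPU. In the sketch below, Python lists stand in for a block, the hypothetical `next_power_of_2` re-implements the value `triton.next_power_of_2` computes, and the hypothetical `masked_load` mimics `tl.load(..., mask=mask, other=-inf)`. Padding with -inf is harmless because it never wins the max and exp(-inf) = 0, so padded lanes add nothing to the denominator.

```python
import math


def next_power_of_2(n):
    # Smallest power of two >= n, for n >= 1
    return 1 << (n - 1).bit_length()


def masked_load(row, block_size, other=-math.inf):
    # Lanes with col_offset >= n_cols are filled with `other`, like tl.load(..., mask=mask, other=other)
    n_cols = len(row)
    return [row[i] if i < n_cols else other for i in range(block_size)]


def padded_softmax(row):
    block = masked_load(row, next_power_of_2(len(row)))
    row_max = max(block)                                # -inf padding never wins the max
    numerator = [math.exp(v - row_max) for v in block]  # exp(-inf) == 0.0 for padded lanes
    denominator = sum(numerator)
    # Keep only the first n_cols lanes, like the masked tl.store
    return [v / denominator for v in numerator[:len(row)]]


print(next_power_of_2(781))  # 1024
print(padded_softmax([1.0, 2.0, 3.0]))
```
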

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular (general-purpose) registers. On CDNA architectures this is half of
        # all registers available. However, this is not always the case: in most cases all registers can be used as
        # regular registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS is the maximum number of resident threads per multiprocessor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multiprocessor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
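
The occupancy arithmetic in softmax can be followed with concrete numbers. The sketch below restates both branches as plain functions with hypothetical names. Every device figure in it (register file size, SM count, shared-memory sizes, per-thread register count) is a made-up illustrative value, not a measurement of any particular GPU, and the guard against a zero shared-memory footprint is an addition not present in the helper above.

```python
def nvidia_occupancy(num_regs, n_regs, warp_size, num_warps, smem_total, smem_per_program):
    # Programs per SM limited by the register file: each program uses n_regs * warp_size * num_warps registers
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # ...and by shared memory (guard added in this sketch in case the kernel uses none)
    if smem_per_program > 0:
        occupancy = min(occupancy, smem_total // smem_per_program)
    return occupancy


def hip_occupancy(num_regs, cdna, n_regs, warp_size, num_warps, max_threads_per_sm):
    # On CDNA the reported register count covers only one of the two VGPR pools.
    # The shared-memory cap from the helper applies here too; it is omitted for brevity.
    num_gprs = num_regs * 2 if cdna else num_regs
    max_num_waves = max_threads_per_sm // warp_size
    return min(num_gprs // warp_size // n_regs, max_num_waves) // num_warps


def num_persistent_programs(num_sm, occupancy, n_rows):
    # One wave of resident programs, but never more programs than rows
    return min(num_sm * occupancy, n_rows)


# Illustrative (hypothetical) numbers only
occ = nvidia_occupancy(num_regs=65536, n_regs=32, warp_size=32, num_warps=8,
                       smem_total=101376, smem_per_program=64)
print(occ)                                          # 8
print(num_persistent_programs(108, occ, 1823))      # 864
print(hip_occupancy(65536, True, 32, 64, 8, 2048))  # 4
```

With these made-up numbers each SM hosts 8 programs, so the grid has 864 persistent programs for a 1823-row matrix, each looping over two or three rows.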

Unit Test

We test our kernel on a matrix with an irregular number of rows and columns, which verifies that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within torch.allclose’s default tolerances.
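
torch.allclose passes when every pair of elements satisfies |a - b| <= atol + rtol * |b|, with defaults rtol = 1e-05 and atol = 1e-08, so the test checks agreement within a tolerance rather than bitwise equality (which the kernel’s approximate tl.exp makes unlikely anyway). A pure-Python restatement of that criterion, for lists instead of tensors and ignoring NaN handling and broadcasting, might look like this (the name `allclose_lists` is hypothetical):

```python
def allclose_lists(a, b, rtol=1e-05, atol=1e-08):
    # Same elementwise criterion as torch.allclose(a, b): |a - b| <= atol + rtol * |b|
    if len(a) != len(b):
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(allclose_lists([0.25, 0.75], [0.25000001, 0.75]))  # True: within tolerance
print(allclose_lists([0.25, 0.75], [0.2501, 0.7499]))    # False: off by 1e-4
```

Note that the tolerance is relative to the second argument, which is why the reference result y_torch is passed second.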

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows, and compare its performance against torch.softmax.
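
The benchmark reports effective bandwidth: its `gbps` lambda assumes X is read once and Y is written once, i.e., 2 * M * N * element_size bytes, regardless of how much traffic an implementation actually generates. The same arithmetic as a standalone function (the name `effective_gbps` is hypothetical):

```python
def effective_gbps(M, N, element_size, ms):
    # Bytes of an ideal fused softmax: read M*N inputs once, write M*N outputs once
    total_bytes = 2 * M * N * element_size
    return total_bytes * 1e-9 / (ms * 1e-3)


# 4096 x 1024 float32 matrix processed in 1 ms (an illustrative timing, not a measurement)
print(effective_gbps(4096, 1024, 4, 1.0))  # about 33.55 GB/s
```

Because the byte count is fixed for a given shape, a higher GB/s figure always means a shorter runtime, so the providers can be compared directly.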

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `line_arg`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N       Triton        Torch
0     256.0   472.862396   699.756800
1     384.0   651.484514   818.132292
2     512.0   811.372182   910.981374
3     640.0   814.815852   951.199175
4     768.0   887.696040  1013.714039
5     896.0   944.024510  1070.809111
6    1024.0  1009.288834  1124.303264
7    1152.0  1098.129482   612.333855
8    1280.0  1149.415129   670.749920
9    1408.0  1157.755917   723.087728
10   1536.0  1186.521658   779.984566
11   1664.0  1220.031133   813.103697
12   1792.0  1233.592125   860.443754
13   1920.0  1257.807086   909.920411
14   2048.0  1279.293340   960.264478
15   2176.0  1262.012475   976.516538
16   2304.0  1257.687651  1012.200991
17   2432.0  1295.234727  1053.631793
18   2560.0  1306.036295  1084.478062
19   2688.0  1305.711150  1101.632895
20   2816.0  1325.406297  1132.903000
21   2944.0  1320.628778  1168.861528
22   3072.0  1345.113165  1182.321087
23   3200.0  1352.509539  1193.519784
24   3328.0  1358.600741  1228.811358
25   3456.0  1371.922444  1247.347143
26   3584.0  1377.727442  1264.102519
27   3712.0  1378.660793  1268.178949
28   3840.0  1381.682227  1298.776892
29   3968.0  1391.266351  1318.796461
30   4096.0  1392.908893  1328.805598
31   4224.0  1335.949028  1157.249456
32   4352.0  1332.794908  1172.282930
33   4480.0  1344.923715  1180.426870
34   4608.0  1356.642316  1194.205931
35   4736.0  1353.108713  1199.312644
36   4864.0  1371.082307  1221.315990
37   4992.0  1369.167399  1232.013491
38   5120.0  1365.956607  1248.304787
39   5248.0  1375.794898  1261.895356
40   5376.0  1372.388729  1285.443032
41   5504.0  1373.963112  1300.965509
42   5632.0  1383.782508  1315.589144
43   5760.0  1385.675823  1326.471747
44   5888.0  1387.650753  1346.726776
45   6016.0  1390.567121  1353.657159
46   6144.0  1405.255358  1373.408063
47   6272.0  1408.341351  1375.927978
48   6400.0  1407.039883  1385.626975
49   6528.0  1408.358404  1392.005945
50   6656.0  1418.426375  1402.485653
51   6784.0  1407.833992  1414.641674
52   6912.0  1419.522391  1420.735984
53   7040.0  1414.992699  1432.565566
54   7168.0  1422.078532  1433.541635
55   7296.0  1423.228611  1441.673207
56   7424.0  1421.330637  1445.291406
57   7552.0  1425.615342  1455.520673
58   7680.0  1432.973900  1459.093589
59   7808.0  1431.048238  1463.433996
60   7936.0  1431.169756  1467.436816
61   8064.0  1434.767175  1472.245079
62   8192.0  1431.303468  1486.995142
63   8320.0  1381.802364  1404.224648
64   8448.0  1377.469291  1407.548449
65   8576.0  1392.370158  1395.305976
66   8704.0  1384.866712  1400.226696
67   8832.0  1380.795627  1403.900363
68   8960.0  1392.689736  1410.657442
69   9088.0  1405.334189  1416.571863
70   9216.0  1391.399603  1422.585025
71   9344.0  1394.117981  1425.715247
72   9472.0  1396.154348  1433.834850
73   9600.0  1383.956960  1433.952205
74   9728.0  1393.178018  1441.260455
75   9856.0  1407.849769  1444.561983
76   9984.0  1396.719730  1453.737252
77  10112.0  1402.232239  1457.585595
78  10240.0  1414.240082  1467.189401
79  10368.0  1406.514416  1464.605995
80  10496.0  1403.841382  1463.282013
81  10624.0  1405.385315  1469.296199
82  10752.0  1398.634149  1469.379374
83  10880.0  1392.687410  1481.280773
84  11008.0  1414.261085  1474.270856
85  11136.0  1410.275136  1483.614131
86  11264.0  1420.611967  1488.153231
87  11392.0  1408.010939  1488.529851
88  11520.0  1413.224806  1496.158861
89  11648.0  1419.125024  1499.139291
90  11776.0  1420.978199  1501.699769
91  11904.0  1434.961249  1509.173618
92  12032.0  1415.075986  1508.763999
93  12160.0  1411.006496  1512.082697
94  12288.0  1424.858431  1393.581615
95  12416.0  1443.339491  1388.857475
96  12544.0  1432.562594  1394.972408
97  12672.0  1438.338791  1391.244828
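In the results above, we can see that:
  • Triton outperforms torch.softmax from N = 1152 to N = 6656 (e.g., about 1279 GB/s vs. 960 GB/s at N = 2048) and again from N = 12288 onward, in addition to being easier to read, understand and maintain.

  • torch.softmax is faster for N up to 1024 and from N = 6784 to N = 12160. Note also that the PyTorch softmax operation is more general and will work on tensors of any shape.

  • This benchmark does not include naive_softmax or the Torch JIT, so the ~4x speed-up over them estimated in the Motivations section is not measured here.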
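
Since both providers are scored with the same byte count, the ratio of two GB/s figures at the same N is the inverse ratio of their runtimes. A short sketch using a handful of rows copied from the table above (the helper name `speedup` is hypothetical) makes the crossover points explicit:

```python
# (N, Triton GB/s, Torch GB/s), copied from rows of the benchmark table above
rows = [
    (1024, 1009.288834, 1124.303264),
    (1152, 1098.129482, 612.333855),
    (2048, 1279.293340, 960.264478),
    (6656, 1418.426375, 1402.485653),
    (6784, 1407.833992, 1414.641674),
    (12160, 1411.006496, 1512.082697),
    (12288, 1424.858431, 1393.581615),
]


def speedup(triton_gbps, torch_gbps):
    # Same bytes moved by both, so the GB/s ratio equals torch_time / triton_time
    return triton_gbps / torch_gbps


for n, triton_gbps, torch_gbps in rows:
    winner = "Triton" if speedup(triton_gbps, torch_gbps) > 1 else "Torch"
    print(f"N={n:>5}: Triton/Torch = {speedup(triton_gbps, torch_gbps):.3f} ({winner} faster)")
```
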

Total running time of the script: (0 minutes 23.234 seconds)
